LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
263 print(dbgs());
264 dbgs() << '\n';
265}
266#endif
267
269 getPointer()->print(OS);
270 SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(getInt());
271 if (Flags & SCEV::FlagNUW)
272 OS << "(u nuw)";
273 if (Flags & SCEV::FlagNSW)
274 OS << "(u nsw)";
275}
276
277//===----------------------------------------------------------------------===//
278// Implementation of the SCEV class.
279//
280
281#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
283 print(dbgs());
284 dbgs() << '\n';
285}
286#endif
287
288void SCEV::print(raw_ostream &OS) const {
289 switch (getSCEVType()) {
290 case scConstant:
291 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
292 return;
293 case scVScale:
294 OS << "vscale";
295 return;
296 case scPtrToAddr:
297 case scPtrToInt: {
298 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
299 const SCEV *Op = PtrCast->getOperand();
300 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
301 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
302 << *PtrCast->getType() << ")";
303 return;
304 }
305 case scTruncate: {
306 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
307 const SCEV *Op = Trunc->getOperand();
308 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
309 << *Trunc->getType() << ")";
310 return;
311 }
312 case scZeroExtend: {
314 const SCEV *Op = ZExt->getOperand();
315 OS << "(zext " << *Op->getType() << " " << *Op << " to "
316 << *ZExt->getType() << ")";
317 return;
318 }
319 case scSignExtend: {
321 const SCEV *Op = SExt->getOperand();
322 OS << "(sext " << *Op->getType() << " " << *Op << " to "
323 << *SExt->getType() << ")";
324 return;
325 }
326 case scAddRecExpr: {
327 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
328 OS << "{" << *AR->getOperand(0);
329 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
330 OS << ",+," << *AR->getOperand(i);
331 OS << "}<";
332 if (AR->hasNoUnsignedWrap())
333 OS << "nuw><";
334 if (AR->hasNoSignedWrap())
335 OS << "nsw><";
336 if (AR->hasNoSelfWrap() &&
338 OS << "nw><";
339 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
340 OS << ">";
341 return;
342 }
343 case scAddExpr:
344 case scMulExpr:
345 case scUMaxExpr:
346 case scSMaxExpr:
347 case scUMinExpr:
348 case scSMinExpr:
350 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
351 const char *OpStr = nullptr;
352 switch (NAry->getSCEVType()) {
353 case scAddExpr: OpStr = " + "; break;
354 case scMulExpr: OpStr = " * "; break;
355 case scUMaxExpr: OpStr = " umax "; break;
356 case scSMaxExpr: OpStr = " smax "; break;
357 case scUMinExpr:
358 OpStr = " umin ";
359 break;
360 case scSMinExpr:
361 OpStr = " smin ";
362 break;
364 OpStr = " umin_seq ";
365 break;
366 default:
367 llvm_unreachable("There are no other nary expression types.");
368 }
369 OS << "("
371 << ")";
372 switch (NAry->getSCEVType()) {
373 case scAddExpr:
374 case scMulExpr:
375 if (NAry->hasNoUnsignedWrap())
376 OS << "<nuw>";
377 if (NAry->hasNoSignedWrap())
378 OS << "<nsw>";
379 break;
380 default:
381 // Nothing to print for other nary expressions.
382 break;
383 }
384 return;
385 }
386 case scUDivExpr: {
387 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
388 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
389 return;
390 }
391 case scUnknown:
392 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
393 return;
395 OS << "***COULDNOTCOMPUTE***";
396 return;
397 }
398 llvm_unreachable("Unknown SCEV kind!");
399}
400
402 switch (getSCEVType()) {
403 case scConstant:
404 return cast<SCEVConstant>(this)->getType();
405 case scVScale:
406 return cast<SCEVVScale>(this)->getType();
407 case scPtrToAddr:
408 case scPtrToInt:
409 case scTruncate:
410 case scZeroExtend:
411 case scSignExtend:
412 return cast<SCEVCastExpr>(this)->getType();
413 case scAddRecExpr:
414 return cast<SCEVAddRecExpr>(this)->getType();
415 case scMulExpr:
416 return cast<SCEVMulExpr>(this)->getType();
417 case scUMaxExpr:
418 case scSMaxExpr:
419 case scUMinExpr:
420 case scSMinExpr:
421 return cast<SCEVMinMaxExpr>(this)->getType();
423 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
424 case scAddExpr:
425 return cast<SCEVAddExpr>(this)->getType();
426 case scUDivExpr:
427 return cast<SCEVUDivExpr>(this)->getType();
428 case scUnknown:
429 return cast<SCEVUnknown>(this)->getType();
431 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
432 }
433 llvm_unreachable("Unknown SCEV kind!");
434}
435
437 switch (getSCEVType()) {
438 case scConstant:
439 case scVScale:
440 case scUnknown:
441 return {};
442 case scPtrToAddr:
443 case scPtrToInt:
444 case scTruncate:
445 case scZeroExtend:
446 case scSignExtend:
447 return cast<SCEVCastExpr>(this)->operands();
448 case scAddRecExpr:
449 case scAddExpr:
450 case scMulExpr:
451 case scUMaxExpr:
452 case scSMaxExpr:
453 case scUMinExpr:
454 case scSMinExpr:
456 return cast<SCEVNAryExpr>(this)->operands();
457 case scUDivExpr:
458 return cast<SCEVUDivExpr>(this)->operands();
460 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
461 }
462 llvm_unreachable("Unknown SCEV kind!");
463}
464
465bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
466
467bool SCEV::isOne() const { return match(this, m_scev_One()); }
468
469bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
470
473 if (!Mul) return false;
474
475 // If there is a constant factor, it will be first.
476 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
477 if (!SC) return false;
478
479 // Return true if the value is negative, this matches things like (-42 * V).
480 return SC->getAPInt().isNegative();
481}
482
485
487 return S->getSCEVType() == scCouldNotCompute;
488}
489
492 ID.AddInteger(scConstant);
493 ID.AddPointer(V);
494 void *IP = nullptr;
495 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
496 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
497 UniqueSCEVs.InsertNode(S, IP);
498 return S;
499}
500
502 return getConstant(ConstantInt::get(getContext(), Val));
503}
504
505const SCEV *
508 // TODO: Avoid implicit trunc?
509 // See https://github.com/llvm/llvm-project/issues/112510.
510 return getConstant(
511 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
512}
513
516 ID.AddInteger(scVScale);
517 ID.AddPointer(Ty);
518 void *IP = nullptr;
519 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
520 return S;
521 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
522 UniqueSCEVs.InsertNode(S, IP);
523 return S;
524}
525
527 SCEV::NoWrapFlags Flags) {
528 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
529 if (EC.isScalable())
530 Res = getMulExpr(Res, getVScale(Ty), Flags);
531 return Res;
532}
533
537
538SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
539 const SCEV *Op, Type *ITy)
540 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
541 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
542 "Must be a non-bit-width-changing pointer-to-integer cast!");
543}
544
// Construct a ptrtoint cast node: reinterpret the pointer-typed operand Op as
// the integer type ITy. The assertion enforces that this is a
// non-bit-width-changing cast: pointer in, integer out.
545SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
546 Type *ITy)
547 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
548 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
549 "Must be a non-bit-width-changing pointer-to-integer cast!");
550}
551
556
557SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
558 Type *ty)
560 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
561 "Cannot truncate non-integer value!");
562}
563
564SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
565 Type *ty)
567 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
568 "Cannot zero extend non-integer value!");
569}
570
571SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
572 Type *ty)
574 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
575 "Cannot sign extend non-integer value!");
576}
577
579 // Clear this SCEVUnknown from various maps.
580 SE->forgetMemoizedResults({this});
581
582 // Remove this SCEVUnknown from the uniquing map.
583 SE->UniqueSCEVs.RemoveNode(this);
584
585 // Release the value.
586 setValPtr(nullptr);
587}
588
// Invoked when every use of the wrapped Value has been replaced with New
// (RAUW). NOTE(review): presumably fired through the value-handle callback
// mechanism SCEVUnknown registers with -- confirm against the class
// declaration, which is outside this view.
589void SCEVUnknown::allUsesReplacedWith(Value *New) {
590 // Clear this SCEVUnknown from various maps.
// Everything ScalarEvolution memoized about this node was derived from the
// old value, so it must be dropped first.
591 SE->forgetMemoizedResults({this});
592
593 // Remove this SCEVUnknown from the uniquing map.
594 SE->UniqueSCEVs.RemoveNode(this);
595
596 // Replace the value pointer in case someone is still using this SCEVUnknown.
// Unlike deleted(), which nulls the pointer, the node is kept usable by
// redirecting it at the replacement value.
597 setValPtr(New);
598}
599
600//===----------------------------------------------------------------------===//
601// SCEV Utilities
602//===----------------------------------------------------------------------===//
603
604/// Compare the two values \p LV and \p RV in terms of their "complexity" where
605/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
606/// operands in SCEV expressions.
607static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
608 Value *RV, unsigned Depth) {
610 return 0;
611
612 // Order pointer values after integer values. This helps SCEVExpander form
613 // GEPs.
614 bool LIsPointer = LV->getType()->isPointerTy(),
615 RIsPointer = RV->getType()->isPointerTy();
616 if (LIsPointer != RIsPointer)
617 return (int)LIsPointer - (int)RIsPointer;
618
619 // Compare getValueID values.
620 unsigned LID = LV->getValueID(), RID = RV->getValueID();
621 if (LID != RID)
622 return (int)LID - (int)RID;
623
624 // Sort arguments by their position.
625 if (const auto *LA = dyn_cast<Argument>(LV)) {
626 const auto *RA = cast<Argument>(RV);
627 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
628 return (int)LArgNo - (int)RArgNo;
629 }
630
631 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
632 const auto *RGV = cast<GlobalValue>(RV);
633
634 if (auto L = LGV->getLinkage() - RGV->getLinkage())
635 return L;
636
637 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
638 auto LT = GV->getLinkage();
639 return !(GlobalValue::isPrivateLinkage(LT) ||
641 };
642
643 // Use the names to distinguish the two values, but only if the
644 // names are semantically important.
645 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
646 return LGV->getName().compare(RGV->getName());
647 }
648
649 // For instructions, compare their loop depth, and their operand count. This
650 // is pretty loose.
651 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
652 const auto *RInst = cast<Instruction>(RV);
653
654 // Compare loop depths.
655 const BasicBlock *LParent = LInst->getParent(),
656 *RParent = RInst->getParent();
657 if (LParent != RParent) {
658 unsigned LDepth = LI->getLoopDepth(LParent),
659 RDepth = LI->getLoopDepth(RParent);
660 if (LDepth != RDepth)
661 return (int)LDepth - (int)RDepth;
662 }
663
664 // Compare the number of operands.
665 unsigned LNumOps = LInst->getNumOperands(),
666 RNumOps = RInst->getNumOperands();
667 if (LNumOps != RNumOps)
668 return (int)LNumOps - (int)RNumOps;
669
670 for (unsigned Idx : seq(LNumOps)) {
671 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
672 RInst->getOperand(Idx), Depth + 1);
673 if (Result != 0)
674 return Result;
675 }
676 }
677
678 return 0;
679}
680
681// Return negative, zero, or positive, if LHS is less than, equal to, or greater
682// than RHS, respectively. A three-way result allows recursive comparisons to be
683// more efficient.
684// If the max analysis depth was reached, return std::nullopt, assuming we do
685// not know if they are equivalent for sure.
686static std::optional<int>
687CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
688 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
689 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
690 if (LHS == RHS)
691 return 0;
692
693 // Primarily, sort the SCEVs by their getSCEVType().
694 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
695 if (LType != RType)
696 return (int)LType - (int)RType;
697
699 return std::nullopt;
700
701 // Aside from the getSCEVType() ordering, the particular ordering
702 // isn't very important except that it's beneficial to be consistent,
703 // so that (a + b) and (b + a) don't end up as different expressions.
704 switch (LType) {
705 case scUnknown: {
706 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
707 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
708
709 int X =
710 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
711 return X;
712 }
713
714 case scConstant: {
717
718 // Compare constant values.
719 const APInt &LA = LC->getAPInt();
720 const APInt &RA = RC->getAPInt();
721 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
722 if (LBitWidth != RBitWidth)
723 return (int)LBitWidth - (int)RBitWidth;
724 return LA.ult(RA) ? -1 : 1;
725 }
726
727 case scVScale: {
728 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
729 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
730 return LTy->getBitWidth() - RTy->getBitWidth();
731 }
732
733 case scAddRecExpr: {
736
737 // There is always a dominance between two recs that are used by one SCEV,
738 // so we can safely sort recs by loop header dominance. We require such
739 // order in getAddExpr.
740 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
741 if (LLoop != RLoop) {
742 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
743 assert(LHead != RHead && "Two loops share the same header?");
744 if (DT.dominates(LHead, RHead))
745 return 1;
746 assert(DT.dominates(RHead, LHead) &&
747 "No dominance between recurrences used by one SCEV?");
748 return -1;
749 }
750
751 [[fallthrough]];
752 }
753
754 case scTruncate:
755 case scZeroExtend:
756 case scSignExtend:
757 case scPtrToAddr:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
767 ArrayRef<SCEVUse> LOps = LHS->operands();
768 ArrayRef<SCEVUse> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
777 ROps[i].getPointer(), DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 return 0;
782 }
783
785 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
786 }
787 llvm_unreachable("Unknown SCEV kind!");
788}
789
790/// Given a list of SCEV objects, order them by their complexity, and group
791/// objects of the same complexity together by value. When this routine is
792/// finished, we know that any duplicates in the vector are consecutive and that
793/// complexity is monotonically increasing.
794///
795/// Note that we go take special precautions to ensure that we get deterministic
796/// results from this routine. In other words, we don't want the results of
797/// this to depend on where the addresses of various SCEV objects happened to
798/// land in memory.
800 DominatorTree &DT) {
801 if (Ops.size() < 2) return; // Noop
802
803 // Whether LHS has provably less complexity than RHS.
804 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
805 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
806 return Complexity && *Complexity < 0;
807 };
808 if (Ops.size() == 2) {
809 // This is the common case, which also happens to be trivially simple.
810 // Special case it.
811 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
812 if (IsLessComplex(RHS, LHS))
813 std::swap(LHS, RHS);
814 return;
815 }
816
817 // Do the rough sort by complexity.
819 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
820
821 // Now that we are sorted by complexity, group elements of the same
822 // complexity. Note that this is, at worst, N^2, but the vector is likely to
823 // be extremely short in practice. Note that we take this approach because we
824 // do not want to depend on the addresses of the objects we are grouping.
825 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
826 const SCEV *S = Ops[i];
827 unsigned Complexity = S->getSCEVType();
828
829 // If there are any objects of the same complexity and same value as this
830 // one, group them.
831 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
832 if (Ops[j] == S) { // Found a duplicate.
833 // Move it to immediately after i'th element.
834 std::swap(Ops[i+1], Ops[j]);
835 ++i; // no need to rescan it.
836 if (i == e-2) return; // Done!
837 }
838 }
839 }
840}
841
842/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
843/// least HugeExprThreshold nodes).
845 return any_of(Ops, [](const SCEV *S) {
847 });
848}
849
850/// Performs a number of common optimizations on the passed \p Ops. If the
851/// whole expression reduces down to a single operand, it will be returned.
852///
853/// The following optimizations are performed:
854/// * Fold constants using the \p Fold function.
855/// * Remove identity constants satisfying \p IsIdentity.
856/// * If a constant satisfies \p IsAbsorber, return it.
857/// * Sort operands by complexity.
858template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
859static const SCEV *
861 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
862 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
863 const SCEVConstant *Folded = nullptr;
864 for (unsigned Idx = 0; Idx < Ops.size();) {
865 const SCEV *Op = Ops[Idx];
866 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
867 if (!Folded)
868 Folded = C;
869 else
870 Folded = cast<SCEVConstant>(
871 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
872 Ops.erase(Ops.begin() + Idx);
873 continue;
874 }
875 ++Idx;
876 }
877
878 if (Ops.empty()) {
879 assert(Folded && "Must have folded value");
880 return Folded;
881 }
882
883 if (Folded && IsAbsorber(Folded->getAPInt()))
884 return Folded;
885
886 GroupByComplexity(Ops, &LI, DT);
887 if (Folded && !IsIdentity(Folded->getAPInt()))
888 Ops.insert(Ops.begin(), Folded);
889
890 return Ops.size() == 1 ? Ops[0] : nullptr;
891}
892
893//===----------------------------------------------------------------------===//
894// Simple SCEV method implementations
895//===----------------------------------------------------------------------===//
896
897/// Compute BC(It, K). The result has width W. Assume, K > 0.
// Compute the binomial coefficient BC(It, K) as a SCEV at the width of
// ResultTy, using exact division by the odd part of K! via a multiplicative
// inverse modulo 2^W (see the long derivation comment below). Returns
// CouldNotCompute for unreasonably large K. Callers guarantee K > 0.
898static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
899 ScalarEvolution &SE,
900 Type *ResultTy) {
901 // Handle the simplest case efficiently.
// BC(It, 1) == It, just coerced to the result width.
902 if (K == 1)
903 return SE.getTruncateOrZeroExtend(It, ResultTy);
904
905 // We are using the following formula for BC(It, K):
906 //
907 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
908 //
909 // Suppose, W is the bitwidth of the return value. We must be prepared for
910 // overflow. Hence, we must assure that the result of our computation is
911 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
912 // safe in modular arithmetic.
913 //
914 // However, this code doesn't use exactly that formula; the formula it uses
915 // is something like the following, where T is the number of factors of 2 in
916 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
917 // exponentiation:
918 //
919 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
920 //
921 // This formula is trivially equivalent to the previous formula. However,
922 // this formula can be implemented much more efficiently. The trick is that
923 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
924 // arithmetic. To do exact division in modular arithmetic, all we have
925 // to do is multiply by the inverse. Therefore, this step can be done at
926 // width W.
927 //
928 // The next issue is how to safely do the division by 2^T. The way this
929 // is done is by doing the multiplication step at a width of at least W + T
930 // bits. This way, the bottom W+T bits of the product are accurate. Then,
931 // when we perform the division by 2^T (which is equivalent to a right shift
932 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
933 // truncated out after the division by 2^T.
934 //
935 // In comparison to just directly using the first formula, this technique
936 // is much more efficient; using the first formula requires W * K bits,
937 // but this formula less than W + K bits. Also, the first formula requires
938 // a division step, whereas this formula only requires multiplies and shifts.
939 //
940 // It doesn't matter whether the subtraction step is done in the calculation
941 // width or the input iteration count's width; if the subtraction overflows,
942 // the result must be zero anyway. We prefer here to do it in the width of
943 // the induction variable because it helps a lot for certain cases; CodeGen
944 // isn't smart enough to ignore the overflow, which leads to much less
945 // efficient code if the width of the subtraction is wider than the native
946 // register width.
947 //
948 // (It's possible to not widen at all by pulling out factors of 2 before
949 // the multiplication; for example, K=2 can be calculated as
950 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
951 // extra arithmetic, so it's not an obvious win, and it gets
952 // much more complicated for K > 3.)
953
954 // Protection from insane SCEVs; this bound is conservative,
955 // but it probably doesn't matter.
956 if (K > 1000)
957 return SE.getCouldNotCompute();
958
959 unsigned W = SE.getTypeSizeInBits(ResultTy);
960
961 // Calculate K! / 2^T and T; we divide out the factors of two before
962 // multiplying for calculating K! / 2^T to avoid overflow.
963 // Other overflow doesn't matter because we only care about the bottom
964 // W bits of the result.
// The loop starts at i = 3 with T = 1 because the factors 1 and 2 of K! are
// already accounted for: 2 contributes exactly one factor of two and an odd
// part of 1. (K >= 2 holds here since K == 1 returned above.)
965 APInt OddFactorial(W, 1);
966 unsigned T = 1;
967 for (unsigned i = 3; i <= K; ++i) {
968 unsigned TwoFactors = countr_zero(i);
969 T += TwoFactors;
970 OddFactorial *= (i >> TwoFactors);
971 }
972
973 // We need at least W + T bits for the multiplication step
974 unsigned CalculationBits = W + T;
975
976 // Calculate 2^T, at width T+W.
977 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
978
979 // Calculate the multiplicative inverse of K! / 2^T;
980 // this multiplication factor will perform the exact division by
981 // K! / 2^T.
// OddFactorial is odd by construction, so its inverse modulo 2^W exists.
982 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
983
984 // Calculate the product, at width T+W
985 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
986 CalculationBits);
// Dividend = It * (It - 1) * ... * (It - K + 1); each subtraction is done at
// the iteration count's own width (see the derivation comment above for why).
987 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
988 for (unsigned i = 1; i != K; ++i) {
989 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
990 Dividend = SE.getMulExpr(Dividend,
991 SE.getTruncateOrZeroExtend(S, CalculationTy));
992 }
993
994 // Divide by 2^T
995 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
996
997 // Truncate the result, and divide by K! / 2^T.
998
999 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1000 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1001}
1002
1003/// Return the value of this chain of recurrences at the specified iteration
1004/// number. We can evaluate this recurrence by multiplying each element in the
1005/// chain by the binomial coefficient corresponding to it. In other words, we
1006/// can evaluate {A,+,B,+,C,+,D} as:
1007///
1008/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1009///
1010/// where BC(It, k) stands for binomial coefficient.
1012 ScalarEvolution &SE) const {
1013 return evaluateAtIteration(operands(), It, SE);
1014}
1015
1017 const SCEV *It,
1018 ScalarEvolution &SE) {
1019 assert(Operands.size() > 0);
1020 const SCEV *Result = Operands[0].getPointer();
1021 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1022 // The computation is correct in the face of overflow provided that the
1023 // multiplication is performed _after_ the evaluation of the binomial
1024 // coefficient.
1025 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1026 if (isa<SCEVCouldNotCompute>(Coeff))
1027 return Coeff;
1028
1029 Result =
1030 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1031 }
1032 return Result;
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// SCEV Expression folder implementations
1037//===----------------------------------------------------------------------===//
1038
/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
/// which computes a pointer-typed value, and rewrites the whole expression
/// tree so that *all* the computations are done on integers, and the only
/// pointer-typed operands in the expression are SCEVUnknown.
/// The CreatePtrCast callback is invoked to create the actual conversion
/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
class SCEVCastSinkingRewriter
    : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
  using Base = SCEVRewriteVisitor<SCEVCastSinkingRewriter>;
  /// Callback that materializes the cast at a pointer-typed SCEVUnknown leaf.
  using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
  Type *TargetTy;
  ConversionFn CreatePtrCast;

public:
  SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy,
                          ConversionFn CreatePtrCast)
      : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}

  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
                             Type *TargetTy, ConversionFn CreatePtrCast) {
    SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
    return Rewriter.visit(Scev);
  }

  const SCEV *visit(const SCEV *S) {
    Type *STy = S->getType();
    // If the expression is not pointer-typed, just keep it as-is.
    if (!STy->isPointerTy())
      return S;
    // Else, recursively sink the cast down into it.
    return Base::visit(S);
  }

  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
    // Preserve wrap flags on rewritten SCEVAddExpr, which the default
    // implementation drops.
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
  }

  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    // As in visitAddExpr, keep the no-wrap flags on the rewritten multiply.
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    assert(Expr->getType()->isPointerTy() &&
           "Should only reach pointer-typed SCEVUnknown's.");
    // Perform some basic constant folding. If the operand of the cast is a
    // null pointer, don't create a cast SCEV expression (that will be left
    // as-is), but produce a zero constant.
    if (isa<ConstantPointerNull>(Expr->getValue()))
      return SE.getZero(TargetTy);
    return CreatePtrCast(Expr);
  }
};
1105
/// Produce an integer-typed SCEV equivalent to ptrtoint(Op), or
/// SCEVCouldNotCompute if the pointer cannot be modeled losslessly
/// (non-integral pointers, or pointers wider than SCEV's effective type).
const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
  const SCEV *IntOp = SCEVCastSinkingRewriter::rewrite(
      Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
        // Look up or create the unique SCEVPtrToIntExpr node for U.
        FoldingSetNodeID ID;
        ID.AddInteger(scPtrToInt);
        ID.AddPointer(U);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
        UniqueSCEVs.InsertNode(S, IP);
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1144
/// Produce an address-typed SCEV equivalent to ptrtoaddr(Op), or
/// SCEVCouldNotCompute for pointers whose representation is unstable.
const SCEV *ScalarEvolution::getLosslessPtrToAddrExpr(const SCEV *Op) {
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // Treat pointers with unstable representation conservatively, since the
  // address bits may change.
  if (DL.hasUnstableRepresentation(Op->getType()))
    return getCouldNotCompute();

  Type *Ty = DL.getAddressType(Op->getType());

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
  // The rewriter handles null pointer constant folding.
  const SCEV *IntOp = SCEVCastSinkingRewriter::rewrite(
      Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
        // Look up or create the unique SCEVPtrToAddrExpr node for U.
        FoldingSetNodeID ID;
        ID.AddInteger(scPtrToAddr);
        ID.AddPointer(U);
        ID.AddPointer(Ty);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
        UniqueSCEVs.InsertNode(S, IP);
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1177
/// Produce an integer SCEV of type \p Ty equivalent to ptrtoint(Op),
/// truncating or zero-extending the lossless conversion as needed.
/// Returns SCEVCouldNotCompute if the pointer cannot be modeled.
const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}
1187
/// Return a SCEV representing \p Op truncated to \p Ty, folding where
/// possible (constants, nested casts, add/mul/addrec distributions).
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  // Cap the recursion: beyond MaxCastDepth just emit the cast node.
  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<SCEVUse, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and different modification ID was inserted
    // into the cache. So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<SCEVUse, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  registerUser(S, Op);
  return S;
}
1281
1282// Get the limit of a recurrence such that incrementing by Step cannot cause
1283// signed overflow as long as the value of the recurrence within the
1284// loop does not exceed this limit before incrementing.
1285static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1286 ICmpInst::Predicate *Pred,
1287 ScalarEvolution *SE) {
1288 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1289 if (SE->isKnownPositive(Step)) {
1290 *Pred = ICmpInst::ICMP_SLT;
1292 SE->getSignedRangeMax(Step));
1293 }
1294 if (SE->isKnownNegative(Step)) {
1295 *Pred = ICmpInst::ICMP_SGT;
1297 SE->getSignedRangeMin(Step));
1298 }
1299 return nullptr;
1300}
1301
1302// Get the limit of a recurrence such that incrementing by Step cannot cause
1303// unsigned overflow as long as the value of the recurrence within the loop does
1304// not exceed this limit before incrementing.
1306 ICmpInst::Predicate *Pred,
1307 ScalarEvolution *SE) {
1308 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1309 *Pred = ICmpInst::ICMP_ULT;
1310
1312 SE->getUnsignedRangeMax(Step));
1313}
1314
1315namespace {
1316
1317struct ExtendOpTraitsBase {
1318 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1319 unsigned);
1320};
1321
1322// Used to make code generic over signed and unsigned overflow.
1323template <typename ExtendOp> struct ExtendOpTraits {
1324 // Members present:
1325 //
1326 // static const SCEV::NoWrapFlags WrapType;
1327 //
1328 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1329 //
1330 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1331 // ICmpInst::Predicate *Pred,
1332 // ScalarEvolution *SE);
1333};
1334
1335template <>
1336struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1337 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1338
1339 static const GetExtendExprTy GetExtendExpr;
1340
1341 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1342 ICmpInst::Predicate *Pred,
1343 ScalarEvolution *SE) {
1344 return getSignedOverflowLimitForStep(Step, Pred, SE);
1345 }
1346};
1347
1348const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1350
1351template <>
1352struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1353 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1354
1355 static const GetExtendExprTy GetExtendExpr;
1356
1357 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1358 ICmpInst::Predicate *Pred,
1359 ScalarEvolution *SE) {
1360 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1361 }
1362};
1363
1364const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1366
1367} // end anonymous namespace
1368
1369// The recurrence AR has been shown to have no signed/unsigned wrap or something
1370// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1371// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1372// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1373// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1374// expression "Step + sext/zext(PreIncAR)" is congruent with
1375// "sext/zext(PostIncAR)"
1376template <typename ExtendOpTy>
1377static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1378 ScalarEvolution *SE, unsigned Depth) {
1379 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1380 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1381
1382 const Loop *L = AR->getLoop();
1383 const SCEV *Start = AR->getStart();
1384 const SCEV *Step = AR->getStepRecurrence(*SE);
1385
1386 // Check for a simple looking step prior to loop entry.
1387 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1388 if (!SA)
1389 return nullptr;
1390
1391 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1392 // subtraction is expensive. For this purpose, perform a quick and dirty
1393 // difference, by checking for Step in the operand list. Note, that
1394 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1395 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1396 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1397 if (*It == Step) {
1398 DiffOps.erase(It);
1399 break;
1400 }
1401
1402 if (DiffOps.size() == SA->getNumOperands())
1403 return nullptr;
1404
1405 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1406 // `Step`:
1407
1408 // 1. NSW/NUW flags on the step increment.
1409 auto PreStartFlags =
1411 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1413 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1414
1415 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1416 // "S+X does not sign/unsign-overflow".
1417 //
1418
1419 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1420 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1421 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1422 return PreStart;
1423
1424 // 2. Direct overflow check on the step operation's expression.
1425 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1426 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1427 const SCEV *OperandExtendedStart =
1428 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1429 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1430 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1431 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1432 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1433 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1434 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1435 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1436 }
1437 return PreStart;
1438 }
1439
1440 // 3. Loop precondition.
1442 const SCEV *OverflowLimit =
1443 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1444
1445 if (OverflowLimit &&
1446 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1447 return PreStart;
1448
1449 return nullptr;
1450}
1451
1452// Get the normalized zero or sign extended expression for this AddRec's Start.
1453template <typename ExtendOpTy>
1454static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1455 ScalarEvolution *SE,
1456 unsigned Depth) {
1457 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1458
1459 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1460 if (!PreStart)
1461 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1462
1463 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1464 Depth),
1465 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1466}
1467
1468// Try to prove away overflow by looking at "nearby" add recurrences. A
1469// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1470// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1471//
1472// Formally:
1473//
1474// {S,+,X} == {S-T,+,X} + T
1475// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1476//
1477// If ({S-T,+,X} + T) does not overflow ... (1)
1478//
1479// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1480//
1481// If {S-T,+,X} does not overflow ... (2)
1482//
1483// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1484// == {Ext(S-T)+Ext(T),+,Ext(X)}
1485//
1486// If (S-T)+T does not overflow ... (3)
1487//
1488// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1489// == {Ext(S),+,Ext(X)} == LHS
1490//
1491// Thus, if (1), (2) and (3) are true for some T, then
1492// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1493//
1494// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1495// does not overflow" restricted to the 0th iteration. Therefore we only need
1496// to check for (1) and (2).
1497//
1498// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1499// is `Delta` (defined below).
1500template <typename ExtendOpTy>
1501bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1502 const SCEV *Step,
1503 const Loop *L) {
1504 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1505
1506 // We restrict `Start` to a constant to prevent SCEV from spending too much
1507 // time here. It is correct (but more expensive) to continue with a
1508 // non-constant `Start` and do a general SCEV subtraction to compute
1509 // `PreStart` below.
1510 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1511 if (!StartC)
1512 return false;
1513
1514 APInt StartAI = StartC->getAPInt();
1515
1516 for (unsigned Delta : {-2, -1, 1, 2}) {
1517 const SCEV *PreStart = getConstant(StartAI - Delta);
1518
1519 FoldingSetNodeID ID;
1520 ID.AddInteger(scAddRecExpr);
1521 ID.AddPointer(PreStart);
1522 ID.AddPointer(Step);
1523 ID.AddPointer(L);
1524 void *IP = nullptr;
1525 const auto *PreAR =
1526 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1527
1528 // Give up if we don't already have the add recurrence we need because
1529 // actually constructing an add recurrence is relatively expensive.
1530 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1531 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1533 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1534 DeltaS, &Pred, this);
1535 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1536 return true;
1537 }
1538 }
1539
1540 return false;
1541}
1542
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeroes of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  return APInt(BitWidth, 0);
}
1564
// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  const uint32_t TZ = SE.getMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  return APInt(BitWidth, 0);
}
1579
// Insert (ID -> S) into \p FoldCache, keeping the reverse map \p FoldCacheUser
// in sync: when ID previously mapped to a different SCEV, the stale
// reverse-map entry is removed before recording the new association.
static void insertFoldCacheEntry(
    const ScalarEvolution::FoldID &ID, const SCEV *S,
    DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
    DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
        &FoldCacheUser) {
  auto I = FoldCache.insert({ID, S});
  if (!I.second) {
    // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
    // entry.
    auto &UserIDs = FoldCacheUser[I.first->second];
    assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
    for (unsigned I = 0; I != UserIDs.size(); ++I)
      if (UserIDs[I] == ID) {
        std::swap(UserIDs[I], UserIDs.back());
        break;
      }
    UserIDs.pop_back();
    I.first->second = S;
  }
  FoldCacheUser[S].push_back(ID);
}
1601
/// Return a SCEV representing \p Op zero-extended to \p Ty, consulting the
/// fold cache before delegating to getZeroExtendExprImpl. Only folded (i.e.
/// non-SCEVZeroExtendExpr) results are cached, since unfolded nodes are
/// already uniqued.
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  FoldID ID(scZeroExtend, Op, Ty);
  if (const SCEV *S = FoldCache.lookup(ID))
    return S;

  const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
  if (!isa<SCEVZeroExtendExpr>(S))
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
  return S;
}
1620
1622 unsigned Depth) {
1623 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1624 "This is not an extending conversion!");
1625 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1626 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1627
1628 // Fold if the operand is constant.
1629 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1630 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1631
1632 // zext(zext(x)) --> zext(x)
1634 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1635
1636 // Before doing any expensive analysis, check to see if we've already
1637 // computed a SCEV for this Op and Ty.
1639 ID.AddInteger(scZeroExtend);
1640 ID.AddPointer(Op);
1641 ID.AddPointer(Ty);
1642 void *IP = nullptr;
1643 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1644 if (Depth > MaxCastDepth) {
1645 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1646 Op, Ty);
1647 UniqueSCEVs.InsertNode(S, IP);
1648 registerUser(S, Op);
1649 return S;
1650 }
1651
1652 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1654 // It's possible the bits taken off by the truncate were all zero bits. If
1655 // so, we should be able to simplify this further.
1656 const SCEV *X = ST->getOperand();
1658 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1659 unsigned NewBits = getTypeSizeInBits(Ty);
1660 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1661 CR.zextOrTrunc(NewBits)))
1662 return getTruncateOrZeroExtend(X, Ty, Depth);
1663 }
1664
1665 // If the input value is a chrec scev, and we can prove that the value
1666 // did not overflow the old, smaller, value, we can zero extend all of the
1667 // operands (often constants). This allows analysis of something like
1668 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1670 if (AR->isAffine()) {
1671 const SCEV *Start = AR->getStart();
1672 const SCEV *Step = AR->getStepRecurrence(*this);
1673 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1674 const Loop *L = AR->getLoop();
1675
1676 // If we have special knowledge that this addrec won't overflow,
1677 // we don't need to do any further analysis.
1678 if (AR->hasNoUnsignedWrap()) {
1679 Start =
1681 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1682 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1683 }
1684
1685 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1686 // Note that this serves two purposes: It filters out loops that are
1687 // simply not analyzable, and it covers the case where this code is
1688 // being called from within backedge-taken count analysis, such that
1689 // attempting to ask for the backedge-taken count would likely result
1690 // in infinite recursion. In the later case, the analysis code will
1691 // cope with a conservative value, and it will take care to purge
1692 // that value once it has finished.
1693 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1694 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1695 // Manually compute the final value for AR, checking for overflow.
1696
1697 // Check whether the backedge-taken count can be losslessly casted to
1698 // the addrec's type. The count is always unsigned.
1699 const SCEV *CastedMaxBECount =
1700 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1701 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1702 CastedMaxBECount, MaxBECount->getType(), Depth);
1703 if (MaxBECount == RecastedMaxBECount) {
1704 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1705 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1706 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1708 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1710 Depth + 1),
1711 WideTy, Depth + 1);
1712 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1713 const SCEV *WideMaxBECount =
1714 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1715 const SCEV *OperandExtendedAdd =
1716 getAddExpr(WideStart,
1717 getMulExpr(WideMaxBECount,
1718 getZeroExtendExpr(Step, WideTy, Depth + 1),
1721 if (ZAdd == OperandExtendedAdd) {
1722 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1723 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1724 // Return the expression with the addrec on the outside.
1725 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1726 Depth + 1);
1727 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1728 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1729 }
1730 // Similar to above, only this time treat the step value as signed.
1731 // This covers loops that count down.
1732 OperandExtendedAdd =
1733 getAddExpr(WideStart,
1734 getMulExpr(WideMaxBECount,
1735 getSignExtendExpr(Step, WideTy, Depth + 1),
1738 if (ZAdd == OperandExtendedAdd) {
1739 // Cache knowledge of AR NW, which is propagated to this AddRec.
1740 // Negative step causes unsigned wrap, but it still can't self-wrap.
1741 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1742 // Return the expression with the addrec on the outside.
1743 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1744 Depth + 1);
1745 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1746 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1747 }
1748 }
1749 }
1750
1751 // Normally, in the cases we can prove no-overflow via a
1752 // backedge guarding condition, we can also compute a backedge
1753 // taken count for the loop. The exceptions are assumptions and
1754 // guards present in the loop -- SCEV is not great at exploiting
1755 // these to compute max backedge taken counts, but can still use
1756 // these to prove lack of overflow. Use this fact to avoid
1757 // doing extra work that may not pay off.
1758 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1759 !AC.assumptions().empty()) {
1760
1761 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1762 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1763 if (AR->hasNoUnsignedWrap()) {
1764 // Same as nuw case above - duplicated here to avoid a compile time
1765 // issue. It's not clear that the order of checks does matter, but
1766 // it's one of two issue possible causes for a change which was
1767 // reverted. Be conservative for the moment.
1768 Start =
1770 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1771 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1772 }
1773
1774 // For a negative step, we can extend the operands iff doing so only
1775 // traverses values in the range zext([0,UINT_MAX]).
1776 if (isKnownNegative(Step)) {
1778 getSignedRangeMin(Step));
1781 // Cache knowledge of AR NW, which is propagated to this
1782 // AddRec. Negative step causes unsigned wrap, but it
1783 // still can't self-wrap.
1784 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1785 // Return the expression with the addrec on the outside.
1786 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1787 Depth + 1);
1788 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1789 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1790 }
1791 }
1792 }
1793
1794 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1795 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1796 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1797 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1798 const APInt &C = SC->getAPInt();
1799 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1800 if (D != 0) {
1801 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1802 const SCEV *SResidual =
1803 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1804 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1805 return getAddExpr(SZExtD, SZExtR,
1807 Depth + 1);
1808 }
1809 }
1810
1811 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1812 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1813 Start =
1815 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1816 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1817 }
1818 }
1819
1820 // zext(A % B) --> zext(A) % zext(B)
1821 {
1822 const SCEV *LHS;
1823 const SCEV *RHS;
1824 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1825 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1826 getZeroExtendExpr(RHS, Ty, Depth + 1));
1827 }
1828
1829 // zext(A / B) --> zext(A) / zext(B).
1830 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1831 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1832 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1833
1834 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1835 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1836 if (SA->hasNoUnsignedWrap()) {
1837 // If the addition does not unsign overflow then we can, by definition,
1838 // commute the zero extension with the addition operation.
1840 for (SCEVUse Op : SA->operands())
1841 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1842 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1843 }
1844
1845 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1846 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1847 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1848 //
1849 // Often address arithmetics contain expressions like
1850 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1851 // This transformation is useful while proving that such expressions are
1852 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1853 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1854 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1855 if (D != 0) {
1856 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1857 const SCEV *SResidual =
1859 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1860 return getAddExpr(SZExtD, SZExtR,
1862 Depth + 1);
1863 }
1864 }
1865 }
1866
1867 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1868 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1869 if (SM->hasNoUnsignedWrap()) {
1870 // If the multiply does not unsign overflow then we can, by definition,
1871 // commute the zero extension with the multiply operation.
1873 for (SCEVUse Op : SM->operands())
1874 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1875 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1876 }
1877
1878 // zext(2^K * (trunc X to iN)) to iM ->
1879 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1880 //
1881 // Proof:
1882 //
1883 // zext(2^K * (trunc X to iN)) to iM
1884 // = zext((trunc X to iN) << K) to iM
1885 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1886 // (because shl removes the top K bits)
1887 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1888 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1889 //
1890 const APInt *C;
1891 const SCEV *TruncRHS;
1892 if (match(SM,
1893 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1894 C->isPowerOf2()) {
1895 int NewTruncBits =
1896 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1897 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1898 return getMulExpr(
1899 getZeroExtendExpr(SM->getOperand(0), Ty),
1900 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1901 SCEV::FlagNUW, Depth + 1);
1902 }
1903 }
1904
1905 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1906 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1909 SmallVector<SCEVUse, 4> Operands;
1910 for (SCEVUse Operand : MinMax->operands())
1911 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1913 return getUMinExpr(Operands);
1914 return getUMaxExpr(Operands);
1915 }
1916
1917 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1919 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1920 SmallVector<SCEVUse, 4> Operands;
1921 for (SCEVUse Operand : MinMax->operands())
1922 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1923 return getUMinExpr(Operands, /*Sequential*/ true);
1924 }
1925
1926 // The cast wasn't folded; create an explicit cast node.
1927 // Recompute the insert position, as it may have been invalidated.
1928 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1929 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1930 Op, Ty);
1931 UniqueSCEVs.InsertNode(S, IP);
1932 registerUser(S, Op);
1933 return S;
1934}
1935
// Public wrapper for building a sign-extension SCEV: validate the request,
// consult the fold cache, and delegate the actual folding work to
// getSignExtendExprImpl.
// NOTE(review): this rendered listing is missing the line carrying the
// function name/parameter list (original line 1937) and one epilogue line
// (original line 1950) -- confirm the exact signature against the real
// source before relying on it.
1936const SCEV *
1938 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1939 "This is not an extending conversion!");
1940 assert(isSCEVable(Ty) &&
1941 "This is not a conversion to a SCEVable type!");
1942 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1943 Ty = getEffectiveSCEVType(Ty);
1944
// Fast path: reuse a previously computed sign-extend for this (Op, Ty) pair.
1945 FoldID ID(scSignExtend, Op, Ty);
1946 if (const SCEV *S = FoldCache.lookup(ID))
1947 return S;
1948
1949 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
// Record the folded result so later queries for the same (Op, Ty) hit cache.
1951 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1952 return S;
1953}
1954
// Implementation of sign-extension folding (reached via getSignExtendExpr).
// Tries a sequence of rewrites -- constant folding, collapsing sext(sext)/
// sext(zext), trunc simplification via ranges, distributing over <nsw> adds,
// and widening affine addrecs when no-overflow can be proven -- before
// falling back to creating an explicit SCEVSignExtendExpr node.
// NOTE(review): this rendered listing has dropped a number of original lines
// (e.g. the signature line 1955 and continuation/statement lines such as
// 1968, 1972, 1977, 1993, 1997, 2010, 2030, 2033, 2042, 2053, 2081, 2083,
// 2093-2094, 2110-2111, 2140, 2157, 2165, 2173, 2178-2179, 2183); statements
// that appear truncated below are rendering artifacts -- consult the real
// source.
1956 unsigned Depth) {
1957 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1958 "This is not an extending conversion!");
1959 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1960 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1961 Ty = getEffectiveSCEVType(Ty);
1962
1963 // Fold if the operand is constant.
1964 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1965 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1966
1967 // sext(sext(x)) --> sext(x)
1969 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1970
1971 // sext(zext(x)) --> zext(x)
1973 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1974
1975 // Before doing any expensive analysis, check to see if we've already
1976 // computed a SCEV for this Op and Ty.
1978 ID.AddInteger(scSignExtend);
1979 ID.AddPointer(Op);
1980 ID.AddPointer(Ty);
1981 void *IP = nullptr;
1982 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1983 // Limit recursion depth.
1984 if (Depth > MaxCastDepth) {
// Too deep: give up on folding and intern a plain sext node.
1985 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1986 Op, Ty);
1987 UniqueSCEVs.InsertNode(S, IP);
1988 registerUser(S, Op);
1989 return S;
1990 }
1991
1992 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1994 // It's possible the bits taken off by the truncate were all sign bits. If
1995 // so, we should be able to simplify this further.
1996 const SCEV *X = ST->getOperand();
1998 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1999 unsigned NewBits = getTypeSizeInBits(Ty);
2000 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2001 CR.sextOrTrunc(NewBits)))
2002 return getTruncateOrSignExtend(X, Ty, Depth);
2003 }
2004
2005 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2006 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2007 if (SA->hasNoSignedWrap()) {
2008 // If the addition does not sign overflow then we can, by definition,
2009 // commute the sign extension with the addition operation.
2011 for (SCEVUse Op : SA->operands())
2012 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2013 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2014 }
2015
2016 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2017 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2018 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2019 //
2020 // For instance, this will bring two seemingly different expressions:
2021 // 1 + sext(5 + 20 * %x + 24 * %y) and
2022 // sext(6 + 20 * %x + 24 * %y)
2023 // to the same form:
2024 // 2 + sext(4 + 20 * %x + 24 * %y)
2025 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2026 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2027 if (D != 0) {
2028 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2029 const SCEV *SResidual =
2031 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2032 return getAddExpr(SSExtD, SSExtR,
2034 Depth + 1);
2035 }
2036 }
2037 }
2038 // If the input value is a chrec scev, and we can prove that the value
2039 // did not overflow the old, smaller, value, we can sign extend all of the
2040 // operands (often constants). This allows analysis of something like
2041 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2043 if (AR->isAffine()) {
2044 const SCEV *Start = AR->getStart();
2045 const SCEV *Step = AR->getStepRecurrence(*this);
2046 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2047 const Loop *L = AR->getLoop();
2048
2049 // If we have special knowledge that this addrec won't overflow,
2050 // we don't need to do any further analysis.
2051 if (AR->hasNoSignedWrap()) {
2052 Start =
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2056 }
2057
2058 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2059 // Note that this serves two purposes: It filters out loops that are
2060 // simply not analyzable, and it covers the case where this code is
2061 // being called from within backedge-taken count analysis, such that
2062 // attempting to ask for the backedge-taken count would likely result
2063 // in infinite recursion. In the later case, the analysis code will
2064 // cope with a conservative value, and it will take care to purge
2065 // that value once it has finished.
2066 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2067 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2068 // Manually compute the final value for AR, checking for
2069 // overflow.
2070
2071 // Check whether the backedge-taken count can be losslessly casted to
2072 // the addrec's type. The count is always unsigned.
2073 const SCEV *CastedMaxBECount =
2074 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2075 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2076 CastedMaxBECount, MaxBECount->getType(), Depth);
2077 if (MaxBECount == RecastedMaxBECount) {
2078 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2079 // Check whether Start+Step*MaxBECount has no signed overflow.
2080 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2082 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2084 Depth + 1),
2085 WideTy, Depth + 1);
2086 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2087 const SCEV *WideMaxBECount =
2088 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2089 const SCEV *OperandExtendedAdd =
2090 getAddExpr(WideStart,
2091 getMulExpr(WideMaxBECount,
2092 getSignExtendExpr(Step, WideTy, Depth + 1),
2095 if (SAdd == OperandExtendedAdd) {
2096 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2098 // Return the expression with the addrec on the outside.
2099 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2100 Depth + 1);
2101 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2102 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2103 }
2104 // Similar to above, only this time treat the step value as unsigned.
2105 // This covers loops that count up with an unsigned step.
2106 OperandExtendedAdd =
2107 getAddExpr(WideStart,
2108 getMulExpr(WideMaxBECount,
2109 getZeroExtendExpr(Step, WideTy, Depth + 1),
2112 if (SAdd == OperandExtendedAdd) {
2113 // If AR wraps around then
2114 //
2115 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2116 // => SAdd != OperandExtendedAdd
2117 //
2118 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2119 // (SAdd == OperandExtendedAdd => AR is NW)
2120
2121 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2122
2123 // Return the expression with the addrec on the outside.
2124 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2125 Depth + 1);
2126 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2127 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2128 }
2129 }
2130 }
2131
2132 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2133 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2134 if (AR->hasNoSignedWrap()) {
2135 // Same as nsw case above - duplicated here to avoid a compile time
2136 // issue. It's not clear that the order of checks does matter, but
2137 // it's one of two issue possible causes for a change which was
2138 // reverted. Be conservative for the moment.
2139 Start =
2141 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2142 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2143 }
2144
2145 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2146 // if D + (C - D + Step * n) could be proven to not signed wrap
2147 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2148 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2149 const APInt &C = SC->getAPInt();
2150 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2151 if (D != 0) {
2152 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2153 const SCEV *SResidual =
2154 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2155 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2156 return getAddExpr(SSExtD, SSExtR,
2158 Depth + 1);
2159 }
2160 }
2161
2162 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2163 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2164 Start =
2166 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2167 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2168 }
2169 }
2170
2171 // If the input value is provably positive and we could not simplify
2172 // away the sext build a zext instead.
2174 return getZeroExtendExpr(Op, Ty, Depth + 1);
2175
2176 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2177 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2180 SmallVector<SCEVUse, 4> Operands;
2181 for (SCEVUse Operand : MinMax->operands())
2182 Operands.push_back(getSignExtendExpr(Operand, Ty));
2184 return getSMinExpr(Operands);
2185 return getSMaxExpr(Operands);
2186 }
2187
2188 // The cast wasn't folded; create an explicit cast node.
2189 // Recompute the insert position, as it may have been invalidated.
2190 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2191 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2192 Op, Ty);
2193 UniqueSCEVs.InsertNode(S, IP);
2194 registerUser(S, Op);
2195 return S;
2196}
2197
// Dispatch helper: build the requested kind of SCEV cast (truncate, zero-
// extend, sign-extend, or ptr-to-int) of Op to type Ty by forwarding to the
// corresponding builder; asserts on any non-cast SCEVTypes value.
// NOTE(review): the signature line (original 2198) is missing from this
// rendered listing -- confirm the exact parameter list against the real
// source.
2199 Type *Ty) {
2200 switch (Kind) {
2201 case scTruncate:
2202 return getTruncateExpr(Op, Ty);
2203 case scZeroExtend:
2204 return getZeroExtendExpr(Op, Ty);
2205 case scSignExtend:
2206 return getSignExtendExpr(Op, Ty);
2207 case scPtrToInt:
2208 return getPtrToIntExpr(Op, Ty);
2209 default:
2210 llvm_unreachable("Not a SCEV cast expression!");
2211 }
2212}
2213
2214/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2215/// unspecified bits out to the given type.
2216///
2217/// Strategy: prefer whichever of sext/zext folds away; negative constants
2218/// are sign-extended, truncates are peeled first, and for addrecs the
2219/// extension is forced into the operands with only the no-self-wrap flag.
// NOTE(review): this rendered listing is missing the signature line
// (original 2216) and two statement lines (2230, the truncate dyn_cast; and
// 2249, the operand-vector declaration) -- consult the real source.
2217 Type *Ty) {
2218 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2219 "This is not an extending conversion!");
2220 assert(isSCEVable(Ty) &&
2221 "This is not a conversion to a SCEVable type!");
2222 Ty = getEffectiveSCEVType(Ty);
2223
2224 // Sign-extend negative constants.
2225 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2226 if (SC->getAPInt().isNegative())
2227 return getSignExtendExpr(Op, Ty);
2228
2229 // Peel off a truncate cast.
2231 const SCEV *NewOp = T->getOperand();
2232 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2233 return getAnyExtendExpr(NewOp, Ty);
2234 return getTruncateOrNoop(NewOp, Ty);
2235 }
2236
2237 // Next try a zext cast. If the cast is folded, use it.
2238 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2239 if (!isa<SCEVZeroExtendExpr>(ZExt))
2240 return ZExt;
2241
2242 // Next try a sext cast. If the cast is folded, use it.
2243 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2244 if (!isa<SCEVSignExtendExpr>(SExt))
2245 return SExt;
2246
2247 // Force the cast to be folded into the operands of an addrec.
2248 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2250 for (const SCEV *Op : AR->operands())
2251 Ops.push_back(getAnyExtendExpr(Op, Ty));
2252 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2253 }
2254
2255 // If the expression is obviously signed, use the sext cast value.
2256 if (isa<SCEVSMaxExpr>(Op))
2257 return SExt;
2258
2259 // Absent any other information, use the zext cast value.
2260 return ZExt;
2261}
2262
2263/// Process the given Ops list, which is a list of operands to be added under
2264/// the given scale, update the given map. This is a helper function for
2265/// getAddRecExpr. As an example of what it does, given a sequence of operands
2266/// that would form an add expression like this:
2267///
2268/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2269///
2270/// where A and B are constants, update the map with these values:
2271///
2272/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2273///
2274/// and add 13 + A*B*29 to AccumulatedConstant.
2275/// This will allow getAddRecExpr to produce this:
2276///
2277/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2278///
2279/// This form often exposes folding opportunities that are hidden in
2280/// the original operand list.
2281///
2282/// Return true iff it appears that any interesting folding opportunities
2283/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2284/// the common case where no interesting opportunities are present, and
2285/// is also used as a check to avoid infinite recursion.
// (The algorithm is described by the /// block comment immediately above.)
// NOTE(review): this rendered listing is missing the first signature lines
// (original 2286-2287) and two statement lines (2289, presumably the M/NewOps
// parameters' closing and 2307, the SCEVMulExpr dyn_cast) -- consult the real
// source. The leading while loop also appears to assume Ops is non-empty and
// ends with a non-constant operand; confirm that invariant against callers.
2288 APInt &AccumulatedConstant,
2290 const APInt &Scale,
2291 ScalarEvolution &SE) {
2292 bool Interesting = false;
2293
2294 // Iterate over the add operands. They are sorted, with constants first.
2295 unsigned i = 0;
2296 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2297 ++i;
2298 // Pull a buried constant out to the outside.
2299 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2300 Interesting = true;
2301 AccumulatedConstant += Scale * C->getAPInt();
2302 }
2303
2304 // Next comes everything else. We're especially interested in multiplies
2305 // here, but they're in the middle, so just visit the rest with one loop.
2306 for (; i != Ops.size(); ++i) {
2308 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2309 APInt NewScale =
2310 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2311 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2312 // A multiplication of a constant with another add; recurse.
2313 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2314 Interesting |= CollectAddOperandsWithScales(
2315 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2316 } else {
2317 // A multiplication of a constant with some other value. Update
2318 // the map.
2319 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2320 const SCEV *Key = SE.getMulExpr(MulOps);
2321 auto Pair = M.insert({Key, NewScale});
2322 if (Pair.second) {
2323 NewOps.push_back(Pair.first->first);
2324 } else {
2325 Pair.first->second += NewScale;
2326 // The map already had an entry for this value, which may indicate
2327 // a folding opportunity.
2328 Interesting = true;
2329 }
2330 }
2331 } else {
2332 // An ordinary operand. Update the map.
2333 auto Pair = M.insert({Ops[i], Scale});
2334 if (Pair.second) {
2335 NewOps.push_back(Pair.first->first);
2336 } else {
2337 Pair.first->second += Scale;
2338 // The map already had an entry for this value, which may indicate
2339 // a folding opportunity.
2340 Interesting = true;
2341 }
2342 }
2343 }
2344
2345 return Interesting;
2346}
2347
// Tries to prove that "LHS BinOp RHS" cannot overflow, first structurally --
// by checking ext(LHS op RHS) == ext(LHS) op ext(RHS) in a doubled-width
// type -- and then, for add/sub with a constant RHS, by context-sensitive
// predicate reasoning at CtxI against the min/max limit of the value range.
// NOTE(review): the signature (original 2348) and the member-function-pointer
// initializer lines (2351, 2357, 2360, 2363, 2368-2369) plus the predicate
// selection line (2408) are missing from this rendered listing -- the
// Operation/Extension pointers below are set by those dropped lines; consult
// the real source.
2349 const SCEV *LHS, const SCEV *RHS,
2350 const Instruction *CtxI) {
2352 unsigned);
2353 switch (BinOp) {
2354 default:
2355 llvm_unreachable("Unsupported binary op");
2356 case Instruction::Add:
2358 break;
2359 case Instruction::Sub:
2361 break;
2362 case Instruction::Mul:
2364 break;
2365 }
2366
2367 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2370
2371 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2372 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2373 auto *WideTy =
2374 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2375
2376 const SCEV *A = (this->*Extension)(
2377 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2378 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2379 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2380 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2381 if (A == B)
2382 return true;
2383 // Can we use context to prove the fact we need?
2384 if (!CtxI)
2385 return false;
2386 // TODO: Support mul.
2387 if (BinOp == Instruction::Mul)
2388 return false;
2389 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2390 // TODO: Lift this limitation.
2391 if (!RHSC)
2392 return false;
2393 APInt C = RHSC->getAPInt();
2394 unsigned NumBits = C.getBitWidth();
2395 bool IsSub = (BinOp == Instruction::Sub);
2396 bool IsNegativeConst = (Signed && C.isNegative());
2397 // Compute the direction and magnitude by which we need to check overflow.
2398 bool OverflowDown = IsSub ^ IsNegativeConst;
2399 APInt Magnitude = C;
2400 if (IsNegativeConst) {
2401 if (C == APInt::getSignedMinValue(NumBits))
2402 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2403 // want to deal with that.
2404 return false;
2405 Magnitude = -C;
2406 }
2407
2409 if (OverflowDown) {
2410 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2411 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2412 : APInt::getMinValue(NumBits);
2413 APInt Limit = Min + Magnitude;
2414 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2415 } else {
2416 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2417 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2418 : APInt::getMaxValue(NumBits);
2419 APInt Limit = Max - Magnitude;
2420 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2421 }
2422}
2423
// Attempts to strengthen the nuw/nsw flags of an overflowing binary operator
// (add/sub/mul only) using SCEV-level overflow reasoning via willNotOverflow;
// returns std::nullopt when the operator already has both flags or nothing
// new can be deduced.
// NOTE(review): the signature line (original 2425) and several statement
// lines (2431 Flags init, 2434/2436 setFlags calls, 2449 CtxI initializer,
// 2451/2453 and 2458/2460 willNotOverflow calls and flag updates) are missing
// from this rendered listing -- consult the real source.
2424std::optional<SCEV::NoWrapFlags>
2426 const OverflowingBinaryOperator *OBO) {
2427 // It cannot be done any better.
2428 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2429 return std::nullopt;
2430
2432
2433 if (OBO->hasNoUnsignedWrap())
2435 if (OBO->hasNoSignedWrap())
2437
2438 bool Deduced = false;
2439
2440 if (OBO->getOpcode() != Instruction::Add &&
2441 OBO->getOpcode() != Instruction::Sub &&
2442 OBO->getOpcode() != Instruction::Mul)
2443 return std::nullopt;
2444
2445 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2446 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2447
2448 const Instruction *CtxI =
2450 if (!OBO->hasNoUnsignedWrap() &&
2452 /* Signed */ false, LHS, RHS, CtxI)) {
2454 Deduced = true;
2455 }
2456
2457 if (!OBO->hasNoSignedWrap() &&
2459 /* Signed */ true, LHS, RHS, CtxI)) {
2461 Deduced = true;
2462 }
2463
2464 if (Deduced)
2465 return Flags;
2466 return std::nullopt;
2467}
2468
2469// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2470// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2471// can't-overflow flags for the operation if possible.
// NOTE(review): the signature lines (original 2472-2474) and several flag-
// setting statement lines (2519, 2522, 2527, 2530, 2536, 2539, 2542, 2546,
// 2549) are missing from this rendered listing; the empty-looking branches
// below set flags on those dropped lines -- consult the real source.
2475 SCEV::NoWrapFlags Flags) {
2476 using namespace std::placeholders;
2477
2478 using OBO = OverflowingBinaryOperator;
2479
// Only add/mul/addrec-style expressions may carry nuw/nsw flags here.
2480 bool CanAnalyze =
2482 (void)CanAnalyze;
2483 assert(CanAnalyze && "don't call from other places!");
2484
2485 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2486 SCEV::NoWrapFlags SignOrUnsignWrap =
2487 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2488
2489 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2490 auto IsKnownNonNegative = [&](SCEVUse U) {
2491 return SE->isKnownNonNegative(U);
2492 };
2493
2494 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2495 Flags =
2496 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2497
2498 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2499
// For a two-operand add/mul with a leading constant, consult the constant
// no-wrap region of the opcode to upgrade nsw/nuw from the operand's range.
2500 if (SignOrUnsignWrap != SignOrUnsignMask &&
2501 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2502 isa<SCEVConstant>(Ops[0])) {
2503
2504 auto Opcode = [&] {
2505 switch (Type) {
2506 case scAddExpr:
2507 return Instruction::Add;
2508 case scMulExpr:
2509 return Instruction::Mul;
2510 default:
2511 llvm_unreachable("Unexpected SCEV op.");
2512 }
2513 }();
2514
2515 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2516
2517 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2518 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2520 Opcode, C, OBO::NoSignedWrap);
2521 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2523 }
2524
2525 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2526 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2528 Opcode, C, OBO::NoUnsignedWrap);
2529 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2531 }
2532 }
2533
2534 // <0,+,nonnegative><nw> is also nuw
2535 // TODO: Add corresponding nsw case
2537 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2538 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2540
2541 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2543 Ops.size() == 2) {
2544 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2545 if (UDiv->getOperand(1) == Ops[1])
2547 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2548 if (UDiv->getOperand(1) == Ops[0])
2550 }
2551
2552 return Flags;
2553}
2554
// Whether S can be evaluated at the entry of loop L: it must be invariant in
// L and its definition must properly dominate the loop header.
// NOTE(review): the signature line (original 2555) is missing from this
// rendered listing -- confirm the exact name/parameters against the real
// source.
2556 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2557}
2558
2559/// Get a canonical add expression, or something simpler if possible.
2561 SCEV::NoWrapFlags OrigFlags,
2562 unsigned Depth) {
2563 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2564 "only nuw or nsw allowed");
2565 assert(!Ops.empty() && "Cannot get empty add!");
2566 if (Ops.size() == 1) return Ops[0];
2567#ifndef NDEBUG
2568 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2569 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2570 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2571 "SCEVAddExpr operand types don't match!");
2572 unsigned NumPtrs = count_if(
2573 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2574 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2575#endif
2576
2577 const SCEV *Folded = constantFoldAndGroupOps(
2578 *this, LI, DT, Ops,
2579 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2580 [](const APInt &C) { return C.isZero(); }, // identity
2581 [](const APInt &C) { return false; }); // absorber
2582 if (Folded)
2583 return Folded;
2584
2585 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2586
2587 // Delay expensive flag strengthening until necessary.
2588 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2589 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2590 };
2591
2592 // Limit recursion calls depth.
2594 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2595
2596 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2597 // Don't strengthen flags if we have no new information.
2598 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2599 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2600 Add->setNoWrapFlags(ComputeFlags(Ops));
2601 return S;
2602 }
2603
2604 // Okay, check to see if the same value occurs in the operand list more than
2605 // once. If so, merge them together into an multiply expression. Since we
2606 // sorted the list, these values are required to be adjacent.
2607 Type *Ty = Ops[0]->getType();
2608 bool FoundMatch = false;
2609 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2610 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2611 // Scan ahead to count how many equal operands there are.
2612 unsigned Count = 2;
2613 while (i+Count != e && Ops[i+Count] == Ops[i])
2614 ++Count;
2615 // Merge the values into a multiply.
2616 SCEVUse Scale = getConstant(Ty, Count);
2617 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2618 if (Ops.size() == Count)
2619 return Mul;
2620 Ops[i] = Mul;
2621 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2622 --i; e -= Count - 1;
2623 FoundMatch = true;
2624 }
2625 if (FoundMatch)
2626 return getAddExpr(Ops, OrigFlags, Depth + 1);
2627
2628 // Check for truncates. If all the operands are truncated from the same
2629 // type, see if factoring out the truncate would permit the result to be
2630 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2631 // if the contents of the resulting outer trunc fold to something simple.
2632 auto FindTruncSrcType = [&]() -> Type * {
2633 // We're ultimately looking to fold an addrec of truncs and muls of only
2634 // constants and truncs, so if we find any other types of SCEV
2635 // as operands of the addrec then we bail and return nullptr here.
2636 // Otherwise, we return the type of the operand of a trunc that we find.
2637 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2638 return T->getOperand()->getType();
2639 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2640 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2641 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2642 return T->getOperand()->getType();
2643 }
2644 return nullptr;
2645 };
2646 if (auto *SrcType = FindTruncSrcType()) {
2647 SmallVector<SCEVUse, 8> LargeOps;
2648 bool Ok = true;
2649 // Check all the operands to see if they can be represented in the
2650 // source type of the truncate.
2651 for (const SCEV *Op : Ops) {
2653 if (T->getOperand()->getType() != SrcType) {
2654 Ok = false;
2655 break;
2656 }
2657 LargeOps.push_back(T->getOperand());
2658 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2659 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2660 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2661 SmallVector<SCEVUse, 8> LargeMulOps;
2662 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2663 if (const SCEVTruncateExpr *T =
2664 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2665 if (T->getOperand()->getType() != SrcType) {
2666 Ok = false;
2667 break;
2668 }
2669 LargeMulOps.push_back(T->getOperand());
2670 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2671 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2672 } else {
2673 Ok = false;
2674 break;
2675 }
2676 }
2677 if (Ok)
2678 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2679 } else {
2680 Ok = false;
2681 break;
2682 }
2683 }
2684 if (Ok) {
2685 // Evaluate the expression in the larger type.
2686 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2687 // If it folds to something simple, use it. Otherwise, don't.
2688 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2689 return getTruncateExpr(Fold, Ty);
2690 }
2691 }
2692
2693 if (Ops.size() == 2) {
2694 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2695 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2696 // C1).
2697 const SCEV *A = Ops[0];
2698 const SCEV *B = Ops[1];
2699 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2700 auto *C = dyn_cast<SCEVConstant>(A);
2701 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2702 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2703 auto C2 = C->getAPInt();
2704 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2705
2706 APInt ConstAdd = C1 + C2;
2707 auto AddFlags = AddExpr->getNoWrapFlags();
2708 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2710 ConstAdd.ule(C1)) {
2711 PreservedFlags =
2713 }
2714
2715 // Adding a constant with the same sign and small magnitude is NSW, if the
2716 // original AddExpr was NSW.
2718 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2719 ConstAdd.abs().ule(C1.abs())) {
2720 PreservedFlags =
2722 }
2723
2724 if (PreservedFlags != SCEV::FlagAnyWrap) {
2725 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2726 NewOps[0] = getConstant(ConstAdd);
2727 return getAddExpr(NewOps, PreservedFlags);
2728 }
2729 }
2730
2731 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2732 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2733 const SCEVAddExpr *InnerAdd;
2734 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2735 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2736 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2737 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2738 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2740 SCEV::FlagNUW)) {
2741 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2742 }
2743 }
2744 }
2745
2746 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2747 const SCEV *Y;
2748 if (Ops.size() == 2 &&
2749 match(Ops[0],
2751 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2752 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2753
2754 // Skip past any other cast SCEVs.
2755 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2756 ++Idx;
2757
2758 // If there are add operands they would be next.
2759 if (Idx < Ops.size()) {
2760 bool DeletedAdd = false;
2761 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2762 // common NUW flag for expression after inlining. Other flags cannot be
2763 // preserved, because they may depend on the original order of operations.
2764 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2765 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2766 if (Ops.size() > AddOpsInlineThreshold ||
2767 Add->getNumOperands() > AddOpsInlineThreshold)
2768 break;
2769 // If we have an add, expand the add operands onto the end of the operands
2770 // list.
2771 Ops.erase(Ops.begin()+Idx);
2772 append_range(Ops, Add->operands());
2773 DeletedAdd = true;
2774 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2775 }
2776
2777 // If we deleted at least one add, we added operands to the end of the list,
2778 // and they are not necessarily sorted. Recurse to resort and resimplify
2779 // any operands we just acquired.
2780 if (DeletedAdd)
2781 return getAddExpr(Ops, CommonFlags, Depth + 1);
2782 }
2783
2784 // Skip over the add expression until we get to a multiply.
2785 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2786 ++Idx;
2787
2788 // Check to see if there are any folding opportunities present with
2789 // operands multiplied by constant values.
2790 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2794 APInt AccumulatedConstant(BitWidth, 0);
2795 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2796 Ops, APInt(BitWidth, 1), *this)) {
2797 struct APIntCompare {
2798 bool operator()(const APInt &LHS, const APInt &RHS) const {
2799 return LHS.ult(RHS);
2800 }
2801 };
2802
2803 // Some interesting folding opportunity is present, so its worthwhile to
2804 // re-generate the operands list. Group the operands by constant scale,
2805 // to avoid multiplying by the same constant scale multiple times.
2806 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2807 for (const SCEV *NewOp : NewOps)
2808 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2809 // Re-generate the operands list.
2810 Ops.clear();
2811 if (AccumulatedConstant != 0)
2812 Ops.push_back(getConstant(AccumulatedConstant));
2813 for (auto &MulOp : MulOpLists) {
2814 if (MulOp.first == 1) {
2815 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2816 } else if (MulOp.first != 0) {
2817 Ops.push_back(getMulExpr(
2818 getConstant(MulOp.first),
2819 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2820 SCEV::FlagAnyWrap, Depth + 1));
2821 }
2822 }
2823 if (Ops.empty())
2824 return getZero(Ty);
2825 if (Ops.size() == 1)
2826 return Ops[0];
2827 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2828 }
2829 }
2830
2831 // If we are adding something to a multiply expression, make sure the
2832 // something is not already an operand of the multiply. If so, merge it into
2833 // the multiply.
2834 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2835 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2836 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2837 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2838 if (isa<SCEVConstant>(MulOpSCEV))
2839 continue;
2840 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2841 if (MulOpSCEV == Ops[AddOp]) {
2842 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2843 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2844 if (Mul->getNumOperands() != 2) {
2845 // If the multiply has more than two operands, we must get the
2846 // Y*Z term.
2847 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2848 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2849 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 const SCEV *AddOne =
2852 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2853 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2855 if (Ops.size() == 2) return OuterMul;
2856 if (AddOp < Idx) {
2857 Ops.erase(Ops.begin()+AddOp);
2858 Ops.erase(Ops.begin()+Idx-1);
2859 } else {
2860 Ops.erase(Ops.begin()+Idx);
2861 Ops.erase(Ops.begin()+AddOp-1);
2862 }
2863 Ops.push_back(OuterMul);
2864 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2865 }
2866
2867 // Check this multiply against other multiplies being added together.
2868 for (unsigned OtherMulIdx = Idx+1;
2869 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2870 ++OtherMulIdx) {
2871 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2872 // If MulOp occurs in OtherMul, we can fold the two multiplies
2873 // together.
2874 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2875 OMulOp != e; ++OMulOp)
2876 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2877 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2878 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2879 if (Mul->getNumOperands() != 2) {
2880 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2881 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2882 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2883 }
2884 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2885 if (OtherMul->getNumOperands() != 2) {
2887 OtherMul->operands().take_front(OMulOp));
2888 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2889 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2890 }
2891 const SCEV *InnerMulSum =
2892 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2893 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2895 if (Ops.size() == 2) return OuterMul;
2896 Ops.erase(Ops.begin()+Idx);
2897 Ops.erase(Ops.begin()+OtherMulIdx-1);
2898 Ops.push_back(OuterMul);
2899 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2900 }
2901 }
2902 }
2903 }
2904
2905 // If there are any add recurrences in the operands list, see if any other
2906 // added values are loop invariant. If so, we can fold them into the
2907 // recurrence.
2908 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2909 ++Idx;
2910
2911 // Scan over all recurrences, trying to fold loop invariants into them.
2912 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2913 // Scan all of the other operands to this add and add them to the vector if
2914 // they are loop invariant w.r.t. the recurrence.
2916 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2917 const Loop *AddRecLoop = AddRec->getLoop();
2918 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2919 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2920 LIOps.push_back(Ops[i]);
2921 Ops.erase(Ops.begin()+i);
2922 --i; --e;
2923 }
2924
2925 // If we found some loop invariants, fold them into the recurrence.
2926 if (!LIOps.empty()) {
2927 // Compute nowrap flags for the addition of the loop-invariant ops and
2928 // the addrec. Temporarily push it as an operand for that purpose. These
2929 // flags are valid in the scope of the addrec only.
2930 LIOps.push_back(AddRec);
2931 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2932 LIOps.pop_back();
2933
2934 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2935 LIOps.push_back(AddRec->getStart());
2936
2937 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
2938
2939 // It is not in general safe to propagate flags valid on an add within
2940 // the addrec scope to one outside it. We must prove that the inner
2941 // scope is guaranteed to execute if the outer one does to be able to
2942 // safely propagate. We know the program is undefined if poison is
2943 // produced on the inner scoped addrec. We also know that *for this use*
2944 // the outer scoped add can't overflow (because of the flags we just
2945 // computed for the inner scoped add) without the program being undefined.
2946 // Proving that entry to the outer scope neccesitates entry to the inner
2947 // scope, thus proves the program undefined if the flags would be violated
2948 // in the outer scope.
2949 SCEV::NoWrapFlags AddFlags = Flags;
2950 if (AddFlags != SCEV::FlagAnyWrap) {
2951 auto *DefI = getDefiningScopeBound(LIOps);
2952 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2953 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2954 AddFlags = SCEV::FlagAnyWrap;
2955 }
2956 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2957
2958 // Build the new addrec. Propagate the NUW and NSW flags if both the
2959 // outer add and the inner addrec are guaranteed to have no overflow.
2960 // Always propagate NW.
2961 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2962 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2963
2964 // If all of the other operands were loop invariant, we are done.
2965 if (Ops.size() == 1) return NewRec;
2966
2967 // Otherwise, add the folded AddRec by the non-invariant parts.
2968 for (unsigned i = 0;; ++i)
2969 if (Ops[i] == AddRec) {
2970 Ops[i] = NewRec;
2971 break;
2972 }
2973 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2974 }
2975
2976 // Okay, if there weren't any loop invariants to be folded, check to see if
2977 // there are multiple AddRec's with the same loop induction variable being
2978 // added together. If so, we can fold them.
2979 for (unsigned OtherIdx = Idx+1;
2980 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2981 ++OtherIdx) {
2982 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2983 // so that the 1st found AddRecExpr is dominated by all others.
2984 assert(DT.dominates(
2985 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2986 AddRec->getLoop()->getHeader()) &&
2987 "AddRecExprs are not sorted in reverse dominance order?");
2988 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2989 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2990 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
2991 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2992 ++OtherIdx) {
2993 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2994 if (OtherAddRec->getLoop() == AddRecLoop) {
2995 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2996 i != e; ++i) {
2997 if (i >= AddRecOps.size()) {
2998 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2999 break;
3000 }
3001 AddRecOps[i] =
3002 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3004 }
3005 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3006 }
3007 }
3008 // Step size has changed, so we cannot guarantee no self-wraparound.
3009 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3010 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3011 }
3012 }
3013
3014 // Otherwise couldn't fold anything into this recurrence. Move onto the
3015 // next one.
3016 }
3017
3018 // Okay, it looks like we really DO need an add expr. Check to see if we
3019 // already have one, otherwise create a new one.
3020 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3021}
3022
3023const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3024 SCEV::NoWrapFlags Flags) {
3026 ID.AddInteger(scAddExpr);
3027 for (const SCEV *Op : Ops)
3028 ID.AddPointer(Op);
3029 void *IP = nullptr;
3030 SCEVAddExpr *S =
3031 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3032 if (!S) {
3033 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3035 S = new (SCEVAllocator)
3036 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3037 UniqueSCEVs.InsertNode(S, IP);
3038 registerUser(S, Ops);
3039 }
3040 S->setNoWrapFlags(Flags);
3041 return S;
3042}
3043
3044const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3045 const Loop *L,
3046 SCEV::NoWrapFlags Flags) {
3047 FoldingSetNodeID ID;
3048 ID.AddInteger(scAddRecExpr);
3049 for (const SCEV *Op : Ops)
3050 ID.AddPointer(Op);
3051 ID.AddPointer(L);
3052 void *IP = nullptr;
3053 SCEVAddRecExpr *S =
3054 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3055 if (!S) {
3056 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3058 S = new (SCEVAllocator)
3059 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3060 UniqueSCEVs.InsertNode(S, IP);
3061 LoopUsers[L].push_back(S);
3062 registerUser(S, Ops);
3063 }
3064 setNoWrapFlags(S, Flags);
3065 return S;
3066}
3067
3068const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3069 SCEV::NoWrapFlags Flags) {
3070 FoldingSetNodeID ID;
3071 ID.AddInteger(scMulExpr);
3072 for (const SCEV *Op : Ops)
3073 ID.AddPointer(Op);
3074 void *IP = nullptr;
3075 SCEVMulExpr *S =
3076 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3077 if (!S) {
3078 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3080 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3081 O, Ops.size());
3082 UniqueSCEVs.InsertNode(S, IP);
3083 registerUser(S, Ops);
3084 }
3085 S->setNoWrapFlags(Flags);
3086 return S;
3087}
3088
/// Multiply two 64-bit values, detecting unsigned wrap-around.
/// On overflow \p Overflow is set to true; it is never cleared, so callers
/// can accumulate the flag across a sequence of multiplications.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  const uint64_t Product = i * j;
  // The multiplication wrapped iff dividing the (wrapped) product by one
  // operand fails to recover the other. Guard against j <= 1, where the
  // division check is unnecessary (and would divide by zero for j == 0).
  if (j > 1 && Product / j != i)
    Overflow = true;
  return Product;
}
3094
/// Compute the binomial coefficient "n choose k". If any intermediate
/// multiplication wraps, \p Overflow is set and the returned value is
/// meaningless; Overflow is never cleared when no overflow occurs.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // Multiplicative formula: n(n-1)...(n-k+1) / k(k-1)...1. Each partial
  // product is divided immediately, which keeps every intermediate value
  // equal to a smaller binomial coefficient (always integral) and reduces
  // — but does not eliminate — the chance of 64-bit overflow.

  // C(n, n) == 1; the n == 0 case is folded in here as well.
  if (n == 0 || n == k)
    return 1;
  if (k > n)
    return 0;

  // Symmetry C(n, k) == C(n, n-k) shortens the loop.
  if (k > n / 2)
    k = n - k;

  uint64_t Result = 1;
  for (uint64_t Step = 0; Step < k; ++Step) {
    const uint64_t Factor = n - Step;
    const uint64_t Product = Result * Factor;
    // Inlined overflow check (same test umul_ov performs): the multiply
    // wrapped iff dividing back does not recover the multiplicand.
    if (Factor > 1 && Product / Factor != Result)
      Overflow = true;
    // Divide by the next denominator term; Result is now C(n, Step + 1).
    Result = Product / (Step + 1);
  }
  return Result;
}
3120
3121/// Determine if any of the operands in this SCEV are a constant or if
3122/// any of the add or multiply expressions in this SCEV contain a constant.
3123static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3124 struct FindConstantInAddMulChain {
3125 bool FoundConstant = false;
3126
3127 bool follow(const SCEV *S) {
3128 FoundConstant |= isa<SCEVConstant>(S);
3129 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3130 }
3131
3132 bool isDone() const {
3133 return FoundConstant;
3134 }
3135 };
3136
3137 FindConstantInAddMulChain F;
3139 ST.visitAll(StartExpr);
3140 return F.FoundConstant;
3141}
3142
3143 /// Get a canonical multiply expression, or something simpler if possible.
// NOTE(review): the function's opening signature line (upstream line 3144,
// presumably `const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<...>
// &Ops,`) is missing from this extraction; the parameter list continues
// below. Several other upstream lines are elided; each gap is flagged.
3145 SCEV::NoWrapFlags OrigFlags,
3146 unsigned Depth) {
// Only NUW/NSW may be requested by callers; other flags are computed here.
3147 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3148 "only nuw or nsw allowed");
3149 assert(!Ops.empty() && "Cannot get empty mul!");
3150 if (Ops.size() == 1) return Ops[0];
3151#ifndef NDEBUG
3152 Type *ETy = Ops[0]->getType();
3153 assert(!ETy->isPointerTy());
3154 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3155 assert(Ops[i]->getType() == ETy &&
3156 "SCEVMulExpr operand types don't match!");
3157#endif
3158
// Fold/group constant operands first: multiply constants together, drop
// the multiplicative identity (1), and short-circuit on the absorber (0).
3159 const SCEV *Folded = constantFoldAndGroupOps(
3160 *this, LI, DT, Ops,
3161 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3162 [](const APInt &C) { return C.isOne(); }, // identity
3163 [](const APInt &C) { return C.isZero(); }); // absorber
3164 if (Folded)
3165 return Folded;
3166
3167 // Delay expensive flag strengthening until necessary.
3168 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3169 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3170 };
3171
3172 // Limit recursion calls depth.
// NOTE(review): the guard condition (upstream line 3173, presumably a
// `Depth > MaxArithDepth(...)` check) is missing from this extraction.
3174 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3175
3176 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3177 // Don't strengthen flags if we have no new information.
3178 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3179 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3180 Mul->setNoWrapFlags(ComputeFlags(Ops));
3181 return S;
3182 }
3183
// Constant operands sort first, so a leading constant enables the
// constant-distribution folds below.
3184 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3185 if (Ops.size() == 2) {
3186 // C1*(C2+V) -> C1*C2 + C1*V
3187 // If any of Add's ops are Adds or Muls with a constant, apply this
3188 // transformation as well.
3189 //
3190 // TODO: There are some cases where this transformation is not
3191 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3192 // this transformation should be narrowed down.
3193 const SCEV *Op0, *Op1;
// NOTE(review): the second half of this condition (upstream line 3195,
// presumably the containsConstantInAddMulChain profitability check) is
// missing from this extraction.
3194 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3196 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3197 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3198 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3199 }
3200
3201 if (Ops[0]->isAllOnesValue()) {
3202 // If we have a mul by -1 of an add, try distributing the -1 among the
3203 // add operands.
3204 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
// NOTE(review): the declaration of NewOps (upstream line 3205, presumably
// `SmallVector<SCEVUse, 4> NewOps;`) is missing from this extraction.
3206 bool AnyFolded = false;
3207 for (const SCEV *AddOp : Add->operands()) {
// NOTE(review): the trailing arguments of this call (upstream line 3209)
// are missing from this extraction.
3208 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
// Only rebuild as an add if at least one term actually simplified.
3210 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3211 NewOps.push_back(Mul);
3212 }
3213 if (AnyFolded)
3214 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3215 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3216 // Negation preserves a recurrence's no self-wrap property.
3217 SmallVector<SCEVUse, 4> Operands;
3218 for (const SCEV *AddRecOp : AddRec->operands())
3219 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3220 SCEV::FlagAnyWrap, Depth + 1));
3221 // Let M be the minimum representable signed value. AddRec with nsw
3222 // multiplied by -1 can have signed overflow if and only if it takes a
3223 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3224 // maximum signed value. In all other cases signed overflow is
3225 // impossible.
3226 auto FlagsMask = SCEV::FlagNW;
3227 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3228 auto MinInt =
3229 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3230 if (getSignedRangeMin(AddRec) != MinInt)
3231 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3232 }
3233 return getAddRecExpr(Operands, AddRec->getLoop(),
3234 AddRec->getNoWrapFlags(FlagsMask));
3235 }
3236 }
3237
3238 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3239 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3240 const SCEVAddExpr *InnerAdd;
3241 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3242 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
// The fold is only valid when C round-trips through trunc/zext unchanged
// and the narrow multiply is provably NUW.
3243 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3244 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3245 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
// NOTE(review): the flags argument (upstream line 3246) is missing from
// this extraction.
3247 SCEV::FlagNUW)) {
3248 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3249 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3250 };
3251 }
3252
3253 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3254 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3255 // of C1, fold to (D /u (C2 /u C1)).
3256 const SCEV *D;
3257 APInt C1V = LHSC->getAPInt();
3258 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3259 // as -1 * 1, as it won't enable additional folds.
3260 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3261 C1V = C1V.abs();
3262 const SCEVConstant *C2;
// NOTE(review): the udiv match clause (upstream line 3264, presumably a
// m_scev_UDiv match binding D and C2 against Ops[1]) is missing from this
// extraction.
3263 if (C1V.isPowerOf2() &&
3265 C2->getAPInt().isPowerOf2() &&
3266 C1V.logBase2() <= getMinTrailingZeros(D)) {
3267 const SCEV *NewMul = nullptr;
3268 if (C1V.uge(C2->getAPInt())) {
3269 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3270 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3271 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3272 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3273 }
// If C1 was negated above, restore the sign on the folded result.
3274 if (NewMul)
3275 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3276 }
3277 }
3278 }
3279
3280 // Skip over the add expression until we get to a multiply.
3281 unsigned Idx = 0;
3282 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3283 ++Idx;
3284
3285 // If there are mul operands inline them all into this expression.
3286 if (Idx < Ops.size()) {
3287 bool DeletedMul = false;
3288 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3289 if (Ops.size() > MulOpsInlineThreshold)
3290 break;
3291 // If we have an mul, expand the mul operands onto the end of the
3292 // operands list.
3293 Ops.erase(Ops.begin()+Idx);
3294 append_range(Ops, Mul->operands());
3295 DeletedMul = true;
3296 }
3297
3298 // If we deleted at least one mul, we added operands to the end of the
3299 // list, and they are not necessarily sorted. Recurse to resort and
3300 // resimplify any operands we just acquired.
3301 if (DeletedMul)
3302 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3303 }
3304
3305 // If there are any add recurrences in the operands list, see if any other
3306 // added values are loop invariant. If so, we can fold them into the
3307 // recurrence.
3308 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3309 ++Idx;
3310
3311 // Scan over all recurrences, trying to fold loop invariants into them.
3312 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3313 // Scan all of the other operands to this mul and add them to the vector
3314 // if they are loop invariant w.r.t. the recurrence.
// NOTE(review): the declaration of LIOps (upstream line 3315, presumably
// `SmallVector<SCEVUse, 8> LIOps;`) is missing from this extraction.
3316 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3317 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3318 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3319 LIOps.push_back(Ops[i]);
3320 Ops.erase(Ops.begin()+i);
3321 --i; --e;
3322 }
3323
3324 // If we found some loop invariants, fold them into the recurrence.
3325 if (!LIOps.empty()) {
3326 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
// NOTE(review): the declaration of NewOps (upstream line 3327) is missing
// from this extraction.
3328 NewOps.reserve(AddRec->getNumOperands());
3329 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3330
3331 // If both the mul and addrec are nuw, we can preserve nuw.
3332 // If both the mul and addrec are nsw, we can only preserve nsw if either
3333 // a) they are also nuw, or
3334 // b) all multiplications of addrec operands with scale are nsw.
3335 SCEV::NoWrapFlags Flags =
3336 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3337
3338 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3339 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3340 SCEV::FlagAnyWrap, Depth + 1));
3341
// NOTE(review): the construction of NSWRegion (upstream lines 3343/3345,
// presumably via ConstantRange::makeGuaranteedNoWrapRegion) is partially
// missing from this extraction.
3342 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3344 Instruction::Mul, getSignedRange(Scale),
3346 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3347 Flags = clearFlags(Flags, SCEV::FlagNSW);
3348 }
3349 }
3350
3351 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3352
3353 // If all of the other operands were loop invariant, we are done.
3354 if (Ops.size() == 1) return NewRec;
3355
3356 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3357 for (unsigned i = 0;; ++i)
3358 if (Ops[i] == AddRec) {
3359 Ops[i] = NewRec;
3360 break;
3361 }
3362 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3363 }
3364
3365 // Okay, if there weren't any loop invariants to be folded, check to see
3366 // if there are multiple AddRec's with the same loop induction variable
3367 // being multiplied together. If so, we can fold them.
3368
3369 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3370 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3371 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3372 // ]]],+,...up to x=2n}.
3373 // Note that the arguments to choose() are always integers with values
3374 // known at compile time, never SCEV objects.
3375 //
3376 // The implementation avoids pointless extra computations when the two
3377 // addrec's are of different length (mathematically, it's equivalent to
3378 // an infinite stream of zeros on the right).
3379 bool OpsModified = false;
3380 for (unsigned OtherIdx = Idx+1;
3381 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3382 ++OtherIdx) {
3383 const SCEVAddRecExpr *OtherAddRec =
3384 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3385 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3386 continue;
3387
3388 // Limit max number of arguments to avoid creation of unreasonably big
3389 // SCEVAddRecs with very complex operands.
3390 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3391 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3392 continue;
3393
// The Choose()/umul_ov() coefficient arithmetic below is done in uint64_t,
// so any wrap is tracked in Overflow and the fold abandoned if it is set.
3394 bool Overflow = false;
3395 Type *Ty = AddRec->getType();
3396 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3397 SmallVector<SCEVUse, 7> AddRecOps;
3398 for (int x = 0, xe = AddRec->getNumOperands() +
3399 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
// NOTE(review): the declaration of SumOps (upstream line 3400, presumably
// `SmallVector<SCEVUse, 7> SumOps;`) is missing from this extraction.
3401 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3402 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3403 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3404 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3405 z < ze && !Overflow; ++z) {
3406 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3407 uint64_t Coeff;
3408 if (LargerThan64Bits)
3409 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3410 else
3411 Coeff = Coeff1*Coeff2;
3412 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3413 const SCEV *Term1 = AddRec->getOperand(y-z);
3414 const SCEV *Term2 = OtherAddRec->getOperand(z);
3415 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3416 SCEV::FlagAnyWrap, Depth + 1));
3417 }
3418 }
3419 if (SumOps.empty())
3420 SumOps.push_back(getZero(Ty));
3421 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3422 }
3423 if (!Overflow) {
// NOTE(review): the flags argument of this call (upstream line 3425,
// presumably SCEV::FlagAnyWrap) is missing from this extraction.
3424 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3426 if (Ops.size() == 2) return NewAddRec;
3427 Ops[Idx] = NewAddRec;
3428 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3429 OpsModified = true;
3430 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3431 if (!AddRec)
3432 break;
3433 }
3434 }
3435 if (OpsModified)
3436 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3437
3438 // Otherwise couldn't fold anything into this recurrence. Move onto the
3439 // next one.
3440 }
3441
3442 // Okay, it looks like we really DO need an mul expr. Check to see if we
3443 // already have one, otherwise create a new one.
3444 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3445}
3446
3447 /// Represents an unsigned remainder expression based on unsigned division.
// NOTE(review): the function's signature line (upstream line 3448,
// presumably `const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
// const SCEV *RHS) {`) is missing from this extraction; the body below
// starts at its first assert.
3449 assert(getEffectiveSCEVType(LHS->getType()) ==
3450 getEffectiveSCEVType(RHS->getType()) &&
3451 "SCEVURemExpr operand types don't match!");
3452
3453 // Short-circuit easy cases
3454 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3455 // If constant is one, the result is trivial
3456 if (RHSC->getValue()->isOne())
3457 return getZero(LHS->getType()); // X urem 1 --> 0
3458
3459 // If constant is a power of two, fold into a zext(trunc(LHS)).
// urem by 2^k keeps exactly the low k bits, i.e. zext(trunc LHS to k bits).
3460 if (RHSC->getAPInt().isPowerOf2()) {
3461 Type *FullTy = LHS->getType();
3462 Type *TruncTy =
3463 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3464 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3465 }
3466 }
3467
3468 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
// (x udiv y) * y never exceeds x, so both the multiply and the subtraction
// are built with NUW.
3469 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3470 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3471 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3472}
3473
3474/// Get a canonical unsigned division expression, or something simpler if
3475/// possible.
3477 assert(!LHS->getType()->isPointerTy() &&
3478 "SCEVUDivExpr operand can't be pointer!");
3479 assert(LHS->getType() == RHS->getType() &&
3480 "SCEVUDivExpr operand types don't match!");
3481
3483 ID.AddInteger(scUDivExpr);
3484 ID.AddPointer(LHS);
3485 ID.AddPointer(RHS);
3486 void *IP = nullptr;
3487 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3488 return S;
3489
3490 // 0 udiv Y == 0
3491 if (match(LHS, m_scev_Zero()))
3492 return LHS;
3493
3494 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3495 if (RHSC->getValue()->isOne())
3496 return LHS; // X udiv 1 --> x
3497 // If the denominator is zero, the result of the udiv is undefined. Don't
3498 // try to analyze it, because the resolution chosen here may differ from
3499 // the resolution chosen in other parts of the compiler.
3500 if (!RHSC->getValue()->isZero()) {
3501 // Determine if the division can be folded into the operands of
3502 // its operands.
3503 // TODO: Generalize this to non-constants by using known-bits information.
3504 Type *Ty = LHS->getType();
3505 unsigned LZ = RHSC->getAPInt().countl_zero();
3506 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3507 // For non-power-of-two values, effectively round the value up to the
3508 // nearest power of two.
3509 if (!RHSC->getAPInt().isPowerOf2())
3510 ++MaxShiftAmt;
3511 IntegerType *ExtTy =
3512 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3513 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3514 if (const SCEVConstant *Step =
3515 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3516 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3517 const APInt &StepInt = Step->getAPInt();
3518 const APInt &DivInt = RHSC->getAPInt();
3519 if (!StepInt.urem(DivInt) &&
3520 getZeroExtendExpr(AR, ExtTy) ==
3521 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3522 getZeroExtendExpr(Step, ExtTy),
3523 AR->getLoop(), SCEV::FlagAnyWrap)) {
3524 SmallVector<SCEVUse, 4> Operands;
3525 for (const SCEV *Op : AR->operands())
3526 Operands.push_back(getUDivExpr(Op, RHS));
3527 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3528 }
3529 /// Get a canonical UDivExpr for a recurrence.
3530 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3531 const APInt *StartRem;
3532 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3533 m_scev_APInt(StartRem))) {
3534 bool NoWrap =
3535 getZeroExtendExpr(AR, ExtTy) ==
3536 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3537 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3539
3540 // With N <= C and both N, C as powers-of-2, the transformation
3541 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3542 // if wrapping occurs, as the division results remain equivalent for
3543 // all offsets in [[(X - X%N), X).
3544 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3545 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3546 // Only fold if the subtraction can be folded in the start
3547 // expression.
3548 const SCEV *NewStart =
3549 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3550 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3551 !isa<SCEVAddExpr>(NewStart)) {
3552 const SCEV *NewLHS =
3553 getAddRecExpr(NewStart, Step, AR->getLoop(),
3554 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3555 if (LHS != NewLHS) {
3556 LHS = NewLHS;
3557
3558 // Reset the ID to include the new LHS, and check if it is
3559 // already cached.
3560 ID.clear();
3561 ID.AddInteger(scUDivExpr);
3562 ID.AddPointer(LHS);
3563 ID.AddPointer(RHS);
3564 IP = nullptr;
3565 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3566 return S;
3567 }
3568 }
3569 }
3570 }
3571 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3572 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3573 SmallVector<SCEVUse, 4> Operands;
3574 for (const SCEV *Op : M->operands())
3575 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3576 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3577 // Find an operand that's safely divisible.
3578 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3579 const SCEV *Op = M->getOperand(i);
3580 const SCEV *Div = getUDivExpr(Op, RHSC);
3581 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3582 Operands = SmallVector<SCEVUse, 4>(M->operands());
3583 Operands[i] = Div;
3584 return getMulExpr(Operands);
3585 }
3586 }
3587 }
3588
3589 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3590 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3591 if (auto *DivisorConstant =
3592 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3593 bool Overflow = false;
3594 APInt NewRHS =
3595 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3596 if (Overflow) {
3597 return getConstant(RHSC->getType(), 0, false);
3598 }
3599 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3600 }
3601 }
3602
3603 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3604 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3605 SmallVector<SCEVUse, 4> Operands;
3606 for (const SCEV *Op : A->operands())
3607 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3608 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3609 Operands.clear();
3610 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3611 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3612 if (isa<SCEVUDivExpr>(Op) ||
3613 getMulExpr(Op, RHS) != A->getOperand(i))
3614 break;
3615 Operands.push_back(Op);
3616 }
3617 if (Operands.size() == A->getNumOperands())
3618 return getAddExpr(Operands);
3619 }
3620 }
3621
3622 // Fold if both operands are constant.
3623 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3624 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3625 }
3626 }
3627
3628 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3629 const APInt *NegC, *C;
3630 if (match(LHS,
3633 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3634 return getZero(LHS->getType());
3635
3636 // TODO: Generalize to handle any common factors.
3637 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3638 const SCEV *NewLHS, *NewRHS;
3639 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3640 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3641 return getUDivExpr(NewLHS, NewRHS);
3642
3643 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3644 // changes). Make sure we get a new one.
3645 IP = nullptr;
3646 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3647 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3648 LHS, RHS);
3649 UniqueSCEVs.InsertNode(S, IP);
3650 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3651 return S;
3652}
3653
3654APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3655 APInt A = C1->getAPInt().abs();
3656 APInt B = C2->getAPInt().abs();
3657 uint32_t ABW = A.getBitWidth();
3658 uint32_t BBW = B.getBitWidth();
3659
3660 if (ABW > BBW)
3661 B = B.zext(ABW);
3662 else if (ABW < BBW)
3663 A = A.zext(BBW);
3664
3665 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3666}
3667
3668/// Get a canonical unsigned division expression, or something simpler if
3669/// possible. There is no representation for an exact udiv in SCEV IR, but we
3670/// can attempt to remove factors from the LHS and RHS. We can't do this when
3671/// it's not exact because the udiv may be clearing bits.
3673 // TODO: we could try to find factors in all sorts of things, but for now we
3674 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3675 // end of this file for inspiration.
3676
  // Without a no-unsigned-wrap multiply on the LHS we cannot cancel common
  // factors: wrapping could have destroyed divisibility. Fall back to the
  // generic udiv.
3678 if (!Mul || !Mul->hasNoUnsignedWrap())
3679 return getUDivExpr(LHS, RHS);
3680
3681 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3682 // If the mulexpr multiplies by a constant, then that constant must be the
3683 // first element of the mulexpr.
3684 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
      // (C * X) /u C --> X: drop the leading constant entirely.
3685 if (LHSCst == RHSCst) {
3686 SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
3687 return getMulExpr(Operands);
3688 }
3689
3690 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3691 // that there's a factor provided by one of the other terms. We need to
3692 // check.
3693 APInt Factor = gcd(LHSCst, RHSCst);
      // isIntN(1) means the gcd is 0 or 1, i.e. there is nothing to cancel.
3694 if (!Factor.isIntN(1)) {
        // Divide the common factor out of both constants and rebuild LHS/RHS.
3695 LHSCst =
3696 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3697 RHSCst =
3698 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3699 SmallVector<SCEVUse, 2> Operands;
3700 Operands.push_back(LHSCst);
3701 append_range(Operands, Mul->operands().drop_front());
3702 LHS = getMulExpr(Operands);
3703 RHS = RHSCst;
        // If the simplified LHS is no longer a multiply, retry on the reduced
        // pair; otherwise continue below with the updated Mul.
        // NOTE(review): the re-assignment of Mul happens on a line elided from
        // this view — confirm against the full source.
3705 if (!Mul)
3706 return getUDivExactExpr(LHS, RHS);
3707 }
3708 }
3709 }
3710
  // (X * RHS * Y) /u RHS --> X * Y: cancel an operand that matches RHS
  // exactly.
3711 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3712 if (Mul->getOperand(i) == RHS) {
3713 SmallVector<SCEVUse, 2> Operands;
3714 append_range(Operands, Mul->operands().take_front(i));
3715 append_range(Operands, Mul->operands().drop_front(i + 1));
3716 return getMulExpr(Operands);
3717 }
3718 }
3719
  // No factor could be removed; emit a plain (canonical) udiv.
3720 return getUDivExpr(LHS, RHS);
3721}
3722
3723/// Get an add recurrence expression for the specified loop. Simplify the
3724/// expression as much as possible.
3726 const Loop *L,
3727 SCEV::NoWrapFlags Flags) {
3728 SmallVector<SCEVUse, 4> Operands;
3729 Operands.push_back(Start);
  // If the step is itself an addrec on the same loop, flatten
  // {X,+,{Y,+,Z}<L>}<L> into the multi-operand form {X,+,Y,+,Z}<L>.
  // Only the NW flag can be preserved across this rewrite.
3730 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3731 if (StepChrec->getLoop() == L) {
3732 append_range(Operands, StepChrec->operands());
3733 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3734 }
3735
  // Otherwise build the plain two-operand recurrence {Start,+,Step}<L>.
3736 Operands.push_back(Step);
3737 return getAddRecExpr(Operands, L, Flags);
3738}
3739
3740/// Get an add recurrence expression for the specified loop. Simplify the
3741/// expression as much as possible.
3743 const Loop *L,
3744 SCEV::NoWrapFlags Flags) {
  // A one-element recurrence {X} is just X.
3745 if (Operands.size() == 1) return Operands[0];
3746#ifndef NDEBUG
3747 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3748 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3749 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3750 "SCEVAddRecExpr operand types don't match!");
3751 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3752 }
3753 for (const SCEV *Op : Operands)
3755 "SCEVAddRecExpr operand is not available at loop entry!");
3756#endif
3757
  // Drop a trailing zero step: {X,...,+,0} is the same recurrence without it.
3758 if (Operands.back()->isZero()) {
3759 Operands.pop_back();
3760 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3761 }
3762
3763 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3764 // use that information to infer NUW and NSW flags. However, computing a
3765 // BE count requires calling getAddRecExpr, so we may not yet have a
3766 // meaningful BE count at this point (and if we don't, we'd be stuck
3767 // with a SCEVCouldNotCompute as the cached BE count).
3768
3769 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3770
3771 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3772 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3773 const Loop *NestedLoop = NestedAR->getLoop();
    // Swap nesting when the inner addrec's loop should be outermost: either
    // L contains NestedLoop but is shallower, or the loops are disjoint and
    // L's header dominates NestedLoop's header.
3774 if (L->contains(NestedLoop)
3775 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3776 : (!NestedLoop->contains(L) &&
3777 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3778 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3779 Operands[0] = NestedAR->getStart();
3780 // AddRecs require their operands be loop-invariant with respect to their
3781 // loops. Don't perform this transformation if it would break this
3782 // requirement.
3783 bool AllInvariant = all_of(
3784 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3785
3786 if (AllInvariant) {
3787 // Create a recurrence for the outer loop with the same step size.
3788 //
3789 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3790 // inner recurrence has the same property.
3791 SCEV::NoWrapFlags OuterFlags =
3792 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3793
3794 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3795 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3796 return isLoopInvariant(Op, NestedLoop);
3797 });
3798
3799 if (AllInvariant) {
3800 // Ok, both add recurrences are valid after the transformation.
3801 //
3802 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3803 // the outer recurrence has the same property.
3804 SCEV::NoWrapFlags InnerFlags =
3805 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3806 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3807 }
3808 }
3809 // Reset Operands to its original state.
3810 Operands[0] = NestedAR;
3811 }
3812 }
3813
3814 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3815 // already have one, otherwise create a new one.
3816 return getOrCreateAddRecExpr(Operands, L, Flags);
3817}
3818
3820 ArrayRef<SCEVUse> IndexExprs) {
3821 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3822 // getSCEV(Base)->getType() has the same address space as Base->getType()
3823 // because SCEV::getType() preserves the address space.
3824 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3825 if (NW != GEPNoWrapFlags::none()) {
3826 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3827 // but to do that, we have to ensure that said flag is valid in the entire
3828 // defined scope of the SCEV.
3829 // TODO: non-instructions have global scope. We might be able to prove
3830 // some global scope cases
3831 auto *GEPI = dyn_cast<Instruction>(GEP);
    // If the GEP (as an instruction) could execute speculatively / be poison,
    // its IR wrap flags cannot be trusted for the SCEV; drop them.
3832 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3833 NW = GEPNoWrapFlags::none();
3834 }
3835
  // Delegate to the index-list based implementation with the validated flags.
3836 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3837}
3838
3840 ArrayRef<SCEVUse> IndexExprs,
3841 Type *SrcElementTy, GEPNoWrapFlags NW) {
  // Translate the GEP's nusw/nuw flags into SCEV nsw/nuw for the offset math.
3843 if (NW.hasNoUnsignedSignedWrap())
3844 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3845 if (NW.hasNoUnsignedWrap())
3846 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3847
  // CurTy tracks the type currently being indexed; IntIdxTy is the pointer
  // index integer type all offsets are computed in.
3848 Type *CurTy = BaseExpr->getType();
3849 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3850 bool FirstIter = true;
3852 for (SCEVUse IndexExpr : IndexExprs) {
3853 // Compute the (potentially symbolic) offset in bytes for this index.
3854 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3855 // For a struct, add the member offset.
3856 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3857 unsigned FieldNo = Index->getZExtValue();
3858 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3859 Offsets.push_back(FieldOffset);
3860
3861 // Update CurTy to the type of the field at Index.
3862 CurTy = STy->getTypeAtIndex(Index);
3863 } else {
3864 // Update CurTy to its element type.
3865 if (FirstIter) {
3866 assert(isa<PointerType>(CurTy) &&
3867 "The first index of a GEP indexes a pointer");
3868 CurTy = SrcElementTy;
3869 FirstIter = false;
3870 } else {
3872 }
3873 // For an array, add the element offset, explicitly scaled.
3874 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3875 // Getelementptr indices are signed.
3876 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3877
3878 // Multiply the index by the element size to compute the element offset.
3879 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3880 Offsets.push_back(LocalOffset);
3881 }
3882 }
3883
3884 // Handle degenerate case of GEP without offsets.
3885 if (Offsets.empty())
3886 return BaseExpr;
3887
3888 // Add the offsets together, assuming nsw if inbounds.
3889 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3890 // Add the base address and the offset. We cannot use the nsw flag, as the
3891 // base address is unsigned. However, if we know that the offset is
3892 // non-negative, we can use nuw.
  // NOTE(review): the remainder of the NUW condition and the BaseWrap
  // computation are on lines elided from this view — confirm against the
  // full source.
3893 bool NUW = NW.hasNoUnsignedWrap() ||
3896 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3897 assert(BaseExpr->getType() == GEPExpr->getType() &&
3898 "GEP should not change type mid-flight.");
3899 return GEPExpr;
3900}
3901
/// Look up an already-uniqued SCEV node of the given kind with exactly the
/// given operand list; returns the cached node, or null if none exists.
/// (Builds the same FoldingSet profile used when creating such nodes.)
3902SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3905 ID.AddInteger(SCEVType);
3906 for (const SCEV *Op : Ops)
3907 ID.AddPointer(Op);
3908 void *IP = nullptr;
3909 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3910}
3911
/// Overload of the cache lookup above for a different operand-array type
/// (parameter type is on a line elided from this view — confirm).
/// Semantics are identical: return the cached node or null.
3912SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3915 ID.AddInteger(SCEVType);
3916 for (const SCEV *Op : Ops)
3917 ID.AddPointer(Op);
3918 void *IP = nullptr;
3919 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3920}
3921
/// abs(Op) is canonicalized as smax(Op, -Op).
/// NOTE(review): Flags used for the negation is defined on an elided line —
/// presumably derived from IsNSW; confirm against the full source.
3922const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3924 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3925}
3926
3929 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3930 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  // min/max of a single operand is that operand.
3931 if (Ops.size() == 1) return Ops[0];
3932#ifndef NDEBUG
3933 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3934 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3935 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3936 "Operand types don't match!");
3937 assert(Ops[0]->getType()->isPointerTy() ==
3938 Ops[i]->getType()->isPointerTy() &&
3939 "min/max should be consistently pointerish");
3940 }
3941#endif
3942
  // Decompose the kind into signedness and direction; used to select the
  // constant fold, identity and absorber below.
3943 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3944 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3945
  // Fold constants together, drop identity elements, and collapse to the
  // absorber when present; may fully fold the expression.
3946 const SCEV *Folded = constantFoldAndGroupOps(
3947 *this, LI, DT, Ops,
3948 [&](const APInt &C1, const APInt &C2) {
3949 switch (Kind) {
3950 case scSMaxExpr:
3951 return APIntOps::smax(C1, C2);
3952 case scSMinExpr:
3953 return APIntOps::smin(C1, C2);
3954 case scUMaxExpr:
3955 return APIntOps::umax(C1, C2);
3956 case scUMinExpr:
3957 return APIntOps::umin(C1, C2);
3958 default:
3959 llvm_unreachable("Unknown SCEV min/max opcode");
3960 }
3961 },
3962 [&](const APInt &C) {
3963 // identity
3964 if (IsMax)
3965 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3966 else
3967 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3968 },
3969 [&](const APInt &C) {
3970 // absorber
3971 if (IsMax)
3972 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3973 else
3974 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3975 });
3976 if (Folded)
3977 return Folded;
3978
3979 // Check if we have created the same expression before.
3980 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3981 return S;
3982 }
3983
3984 // Find the first operation of the same kind
  // (the operand list is ordered by SCEV type, so same-kind operands form a
  // contiguous run starting here).
3985 unsigned Idx = 0;
3986 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3987 ++Idx;
3988
3989 // Check to see if one of the operands is of the same kind. If so, expand its
3990 // operands onto our operand list, and recurse to simplify.
3991 if (Idx < Ops.size()) {
3992 bool DeletedAny = false;
3993 while (Ops[Idx]->getSCEVType() == Kind) {
3994 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3995 Ops.erase(Ops.begin()+Idx);
3996 append_range(Ops, SMME->operands());
3997 DeletedAny = true;
3998 }
3999
4000 if (DeletedAny)
4001 return getMinMaxExpr(Kind, Ops);
4002 }
4003
4004 // Okay, check to see if the same value occurs in the operand list twice. If
4005 // so, delete one. Since we sorted the list, these values are required to
4006 // be adjacent.
  // NOTE(review): the predicate definitions (GEPred/LEPred) are on lines
  // elided from this view — confirm against the full source.
4011 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4012 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4013 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4014 if (Ops[i] == Ops[i + 1] ||
4015 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4016 // X op Y op Y --> X op Y
4017 // X op Y --> X, if we know X, Y are ordered appropriately
4018 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4019 --i;
4020 --e;
4021 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4022 Ops[i + 1])) {
4023 // X op Y --> Y, if we know X, Y are ordered appropriately
4024 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4025 --i;
4026 --e;
4027 }
4028 }
4029
4030 if (Ops.size() == 1) return Ops[0];
4031
4032 assert(!Ops.empty() && "Reduced smax down to nothing!");
4033
4034 // Okay, it looks like we really DO need an expr. Check to see if we
4035 // already have one, otherwise create a new one.
4037 ID.AddInteger(Kind);
4038 for (const SCEV *Op : Ops)
4039 ID.AddPointer(Op);
4040 void *IP = nullptr;
4041 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4042 if (ExistingSCEV)
4043 return ExistingSCEV;
  // Copy the operands into allocator-owned storage and intern the node.
4044 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4046 SCEV *S = new (SCEVAllocator)
4047 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4048
4049 UniqueSCEVs.InsertNode(S, IP);
4050 registerUser(S, Ops);
4051 return S;
4052}
4053
4054namespace {
4055
/// Rewrites the operand tree of a sequential min/max expression, keeping
/// only the first occurrence of each operand encountered while recursing
/// through min/max sub-expressions of the same effective kind (sequential
/// or its non-sequential variant). All other expression kinds are returned
/// unchanged. A std::nullopt result means "operand already seen; drop it".
4056class SCEVSequentialMinMaxDeduplicatingVisitor final
4057 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4058 std::optional<const SCEV *>> {
4059 using RetVal = std::optional<const SCEV *>;
4061
4062 ScalarEvolution &SE;
4063 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4064 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4066
4067 bool canRecurseInto(SCEVTypes Kind) const {
4068 // We can only recurse into the SCEV expression of the same effective type
4069 // as the type of our root SCEV expression.
4070 return RootKind == Kind || NonSequentialRootKind == Kind;
4071 };
4072
  // Shared handler for all min/max visitors: recurse through same-kinded
  // nodes and rebuild them if any operand was deduplicated away.
4073 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4075 "Only for min/max expressions.");
4076 SCEVTypes Kind = S->getSCEVType();
4077
4078 if (!canRecurseInto(Kind))
4079 return S;
4080
4081 auto *NAry = cast<SCEVNAryExpr>(S);
4082 SmallVector<SCEVUse> NewOps;
4083 bool Changed = visit(Kind, NAry->operands(), NewOps);
4084
4085 if (!Changed)
4086 return S;
    // All operands were duplicates: signal the caller to drop this node.
4087 if (NewOps.empty())
4088 return std::nullopt;
4089
4091 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4092 : SE.getMinMaxExpr(Kind, NewOps);
4093 }
4094
4095 RetVal visit(const SCEV *S) {
4096 // Has the whole operand been seen already?
4097 if (!SeenOps.insert(S).second)
4098 return std::nullopt;
4099 return Base::visit(S);
4100 }
4101
4102public:
4103 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4104 SCEVTypes RootKind)
4105 : SE(SE), RootKind(RootKind),
4106 NonSequentialRootKind(
4107 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4108 RootKind)) {}
4109
  // Returns true iff any operand was removed or rewritten; when true,
  // NewOps receives the (possibly shorter) rewritten operand list.
4110 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4111 SmallVectorImpl<SCEVUse> &NewOps) {
4112 bool Changed = false;
4114 Ops.reserve(OrigOps.size());
4115
4116 for (const SCEV *Op : OrigOps) {
4117 RetVal NewOp = visit(Op);
4118 if (NewOp != Op)
4119 Changed = true;
4120 if (NewOp)
4121 Ops.emplace_back(*NewOp);
4122 }
4123
4124 if (Changed)
4125 NewOps = std::move(Ops);
4126 return Changed;
4127 }
4128
  // Leaf and non-min/max expressions pass through unchanged.
4129 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4130
4131 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4132
4133 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4134
4135 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4136
4137 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4138
4139 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4140
4141 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4142
4143 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4144
4145 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4146
4147 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4148
4149 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4150
4151 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4152 return visitAnyMinMaxExpr(Expr);
4153 }
4154
4155 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4156 return visitAnyMinMaxExpr(Expr);
4157 }
4158
4159 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4160 return visitAnyMinMaxExpr(Expr);
4161 }
4162
4163 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4164 return visitAnyMinMaxExpr(Expr);
4165 }
4166
4167 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4168 return visitAnyMinMaxExpr(Expr);
4169 }
4170
4171 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4172
4173 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4174};
4175
4176} // namespace
4177
  // Every SCEV kind except the sequential min/max (handled below) propagates
  // poison unconditionally from any operand to the whole expression.
4179 switch (Kind) {
4180 case scConstant:
4181 case scVScale:
4182 case scTruncate:
4183 case scZeroExtend:
4184 case scSignExtend:
4185 case scPtrToAddr:
4186 case scPtrToInt:
4187 case scAddExpr:
4188 case scMulExpr:
4189 case scUDivExpr:
4190 case scAddRecExpr:
4191 case scUMaxExpr:
4192 case scSMaxExpr:
4193 case scUMinExpr:
4194 case scSMinExpr:
4195 case scUnknown:
4196 // If any operand is poison, the whole expression is poison.
4197 return true;
4199 // FIXME: if the *first* operand is poison, the whole expression is poison.
4200 return false; // Pessimistically, say that it does not propagate poison.
4201 case scCouldNotCompute:
4202 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4203 }
4204 llvm_unreachable("Unknown SCEV kind!");
4205}
4206
4207namespace {
4208// The only way poison may be introduced in a SCEV expression is from a
4209// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4210// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4211// introduce poison -- they encode guaranteed, non-speculated knowledge.
4212//
4213// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4214// with the notable exception of umin_seq, where only poison from the first
4215// operand is (unconditionally) propagated.
4216struct SCEVPoisonCollector {
  // When set, descend even through nodes that may block poison propagation
  // (collects *potential* contributors rather than guaranteed ones).
4217 bool LookThroughMaybePoisonBlocking;
  // The SCEVUnknowns found whose underlying value is not proven non-poison.
4218 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4219 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4220 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4221
4222 bool follow(const SCEV *S) {
4223 if (!LookThroughMaybePoisonBlocking &&
4225 return false;
4226
4227 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4228 if (!isGuaranteedNotToBePoison(SU->getValue()))
4229 MaybePoison.insert(SU);
4230 }
4231 return true;
4232 }
  // Never terminate early; visit the entire expression tree.
4233 bool isDone() const { return false; }
4234};
4235} // namespace
4236
4237/// Return true if V is poison given that AssumedPoison is already poison.
4238static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4239 // First collect all SCEVs that might result in AssumedPoison to be poison.
4240 // We need to look through potentially poison-blocking operations here,
4241 // because we want to find all SCEVs that *might* result in poison, not only
4242 // those that are *required* to.
4243 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4244 visitAll(AssumedPoison, PC1);
4245
4246 // AssumedPoison is never poison. As the assumption is false, the implication
4247 // is true. Don't bother walking the other SCEV in this case.
4248 if (PC1.MaybePoison.empty())
4249 return true;
4250
4251 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4252 // as well. We cannot look through potentially poison-blocking operations
4253 // here, as their arguments only *may* make the result poison.
4254 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4255 visitAll(S, PC2);
4256
4257 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4258 // it will also make S poison by being part of PC2.MaybePoison.
4259 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4260}
4261
4263 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
  // Collect the SCEVUnknowns that, if poison, make S poison (without looking
  // through poison-blocking operations), and report their IR values.
4264 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4265 visitAll(S, PC);
4266 for (const SCEVUnknown *SU : PC.MaybePoison)
4267 Result.insert(SU->getValue());
4268}
4269
4271 const SCEV *S, Instruction *I,
4272 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4273 // If the instruction cannot be poison, it's always safe to reuse.
4275 return true;
4276
4277 // Otherwise, it is possible that I is more poisonous that S. Collect the
4278 // poison-contributors of S, and then check whether I has any additional
4279 // poison-contributors. Poison that is contributed through poison-generating
4280 // flags is handled by dropping those flags instead.
4282 getPoisonGeneratingValues(PoisonVals, S);
4283
  // Depth-first walk over I's operand graph looking for poison sources that
  // are not already accounted for in S.
4284 SmallVector<Value *> Worklist;
4286 Worklist.push_back(I);
4287 while (!Worklist.empty()) {
4288 Value *V = Worklist.pop_back_val();
4289 if (!Visited.insert(V).second)
4290 continue;
4291
4292 // Avoid walking large instruction graphs.
4293 if (Visited.size() > 16)
4294 return false;
4295
4296 // Either the value can't be poison, or the S would also be poison if it
4297 // is.
4298 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4299 continue;
4300
4301 auto *I = dyn_cast<Instruction>(V);
4302 if (!I)
4303 return false;
4304
4305 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4306 // can't replace an arbitrary add with disjoint or, even if we drop the
4307 // flag. We would need to convert the or into an add.
4308 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4309 if (PDI->isDisjoint())
4310 return false;
4311
4312 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4313 // because SCEV currently assumes it can't be poison. Remove this special
4314 // case once we proper model when vscale can be poison.
4315 if (auto *II = dyn_cast<IntrinsicInst>(I);
4316 II && II->getIntrinsicID() == Intrinsic::vscale)
4317 continue;
4318
4319 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4320 return false;
4321
4322 // If the instruction can't create poison, we can recurse to its operands.
4323 if (I->hasPoisonGeneratingAnnotations())
4324 DropPoisonGeneratingInsts.push_back(I);
4325
4326 llvm::append_range(Worklist, I->operands());
4327 }
  // No extra poison sources found: I is no more poisonous than S.
4328 return true;
4329}
4330
4331const SCEV *
4334 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4335 "Not a SCEVSequentialMinMaxExpr!");
4336 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  // A single-operand sequential min/max is just that operand.
4337 if (Ops.size() == 1)
4338 return Ops[0];
4339#ifndef NDEBUG
4340 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4341 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4342 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4343 "Operand types don't match!");
4344 assert(Ops[0]->getType()->isPointerTy() ==
4345 Ops[i]->getType()->isPointerTy() &&
4346 "min/max should be consistently pointerish");
4347 }
4348#endif
4349
4350 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4351 // so we can *NOT* do any kind of sorting of the expressions!
4352
4353 // Check if we have created the same expression before.
4354 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4355 return S;
4356
4357 // FIXME: there are *some* simplifications that we can do here.
4358
4359 // Keep only the first instance of an operand.
4360 {
4361 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4362 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4363 if (Changed)
4364 return getSequentialMinMaxExpr(Kind, Ops);
4365 }
4366
4367 // Check to see if one of the operands is of the same kind. If so, expand its
4368 // operands onto our operand list, and recurse to simplify.
4369 {
4370 unsigned Idx = 0;
4371 bool DeletedAny = false;
4372 while (Idx < Ops.size()) {
4373 if (Ops[Idx]->getSCEVType() != Kind) {
4374 ++Idx;
4375 continue;
4376 }
      // Splice the nested sequential min/max's operands in place (order must
      // be preserved — this expression is not commutative).
4377 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4378 Ops.erase(Ops.begin() + Idx);
4379 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4380 SMME->operands().end());
4381 DeletedAny = true;
4382 }
4383
4384 if (DeletedAny)
4385 return getSequentialMinMaxExpr(Kind, Ops);
4386 }
4387
  // The value at which this sequential operation saturates (zero for
  // umin_seq) and the predicate for the "already ordered" fold below.
4388 const SCEV *SaturationPoint;
4390 switch (Kind) {
4392 SaturationPoint = getZero(Ops[0]->getType());
4393 Pred = ICmpInst::ICMP_ULE;
4394 break;
4395 default:
4396 llvm_unreachable("Not a sequential min/max type.");
4397 }
4398
4399 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4400 if (!isGuaranteedNotToCauseUB(Ops[i]))
4401 continue;
4402 // We can replace %x umin_seq %y with %x umin %y if either:
4403 // * %y being poison implies %x is also poison.
4404 // * %x cannot be the saturating value (e.g. zero for umin).
4405 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4406 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4407 SaturationPoint)) {
4408 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4409 Ops[i - 1] = getMinMaxExpr(
4411 SeqOps);
4412 Ops.erase(Ops.begin() + i);
4413 return getSequentialMinMaxExpr(Kind, Ops);
4414 }
4415 // Fold %x umin_seq %y to %x if %x ule %y.
4416 // TODO: We might be able to prove the predicate for a later operand.
4417 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4418 Ops.erase(Ops.begin() + i);
4419 return getSequentialMinMaxExpr(Kind, Ops);
4420 }
4421 }
4422
4423 // Okay, it looks like we really DO need an expr. Check to see if we
4424 // already have one, otherwise create a new one.
4426 ID.AddInteger(Kind);
4427 for (const SCEV *Op : Ops)
4428 ID.AddPointer(Op);
4429 void *IP = nullptr;
4430 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4431 if (ExistingSCEV)
4432 return ExistingSCEV;
4433
  // Copy operands into allocator-owned storage and intern the new node.
4434 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4436 SCEV *S = new (SCEVAllocator)
4437 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4438
4439 UniqueSCEVs.InsertNode(S, IP);
4440 registerUser(S, Ops);
4441 return S;
4442}
4443
4448
4452
4457
4461
4466
4470
4472 bool Sequential) {
  // Binary form: forward to the operand-list overload; Sequential selects
  // umin_seq semantics.
4473 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4474 return getUMinExpr(Ops, Sequential);
4475}
4476
4482
/// Build a SCEV for a TypeSize: the known-minimum size, multiplied by
/// vscale when the size is scalable.
4483const SCEV *
4485 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4486 if (Size.isScalable())
4487 Res = getMulExpr(Res, getVScale(IntTy));
4488 return Res;
4489}
4490
  // Uses DataLayout's alloc size (the in-memory footprint) for AllocTy.
4492 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4493}
4494
  // Uses DataLayout's store size for StoreTy.
4496 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4497}
4498
4500 StructType *STy,
4501 unsigned FieldNo) {
4502 // We can bypass creating a target-independent constant expression and then
4503 // folding it back into a ConstantInt. This is just a compile-time
4504 // optimization.
4505 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4506 assert(!SL->getSizeInBits().isScalable() &&
4507 "Cannot get offset for structure containing scalable vector types");
  // The byte offset of FieldNo within STy, as a constant of type IntTy.
4508 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4509}
4510
4512 // Don't attempt to do anything other than create a SCEVUnknown object
4513 // here. createSCEV only calls getUnknown after checking for all other
4514 // interesting possibilities, and any other code that calls getUnknown
4515 // is doing so in order to hide a value from SCEV canonicalization.
4516
4518 ID.AddInteger(scUnknown);
4519 ID.AddPointer(V);
4520 void *IP = nullptr;
  // Reuse the existing node if this Value already has a SCEVUnknown.
4521 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4522 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4523 "Stale SCEVUnknown in uniquing map!");
4524 return S;
4525 }
  // Otherwise create a new node and link it into the SCEVUnknown list so it
  // can be invalidated if V is deleted.
4526 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4527 FirstUnknown);
4528 FirstUnknown = cast<SCEVUnknown>(S);
4529 UniqueSCEVs.InsertNode(S, IP);
4530 return S;
4531}
4532
4533//===----------------------------------------------------------------------===//
4534// Basic SCEV Analysis and PHI Idiom Recognition Code
4535//
4536
4537/// Test if values of the given type are analyzable within the SCEV
4538/// framework. This primarily includes integer types, and it can optionally
4539/// include pointer types if the ScalarEvolution class has access to
4540/// target-specific information.
4542 // Integers and pointers are always SCEVable.
4543 return Ty->isIntOrPtrTy();
4544}
4545
/// Return the size in bits of the specified type, for which isSCEVable must
/// return true.
  assert(isSCEVable(Ty) && "Type is not SCEVable!");
  if (Ty->isPointerTy())
    // NOTE(review): the pointer branch's return statement was lost in
    // extraction; upstream returns the index type size for pointers — TODO
    // confirm.
  return getDataLayout().getTypeSizeInBits(Ty);
}
4554
/// Return a type with the same bitwidth as the given type and which represents
/// how SCEV will treat the given type, for which isSCEVable must return
/// true. For pointer types, this is the pointer index sized integer type.
  assert(isSCEVable(Ty) && "Type is not SCEVable!");

  // Integer types are used as-is.
  if (Ty->isIntegerTy())
    return Ty;

  // The only other supported type is a pointer; model it by its index type.
  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
  return getDataLayout().getIndexType(Ty);
}
4568
  // Return whichever type has at least as many bits; ties prefer T1.
  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
}
4572
                                                  const SCEV *B) {
  /// For a valid use point to exist, the defining scope of one operand
  /// must dominate the other.
  bool PreciseA, PreciseB;
  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
  // Imprecise scope bounds cannot establish dominance either way.
  if (!PreciseA || !PreciseB)
    // Can't tell.
    return false;
  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
         DT.dominates(ScopeB, ScopeA);
}
4586
  // Hand out the single shared SCEVCouldNotCompute sentinel.
  return CouldNotCompute.get();
}
4590
4591bool ScalarEvolution::checkValidity(const SCEV *S) const {
4592 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4593 auto *SU = dyn_cast<SCEVUnknown>(S);
4594 return SU && SU->getValue() == nullptr;
4595 });
4596
4597 return !ContainsNulls;
4598}
4599
  // Answer from the memoization cache when available.
  HasRecMapType::iterator I = HasRecMap.find(S);
  if (I != HasRecMap.end())
    return I->second;

  bool FoundAddRec =
      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
  // Cache the result for subsequent queries.
  HasRecMap.insert({S, FoundAddRec});
  return FoundAddRec;
}
4610
4611/// Return the ValueOffsetPair set for \p S. \p S can be represented
4612/// by the value and offset from any ValueOffsetPair in the set.
4613ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4614 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4615 if (SI == ExprValueMap.end())
4616 return {};
4617 return SI->second.getArrayRef();
4618}
4619
4620/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4621/// cannot be used separately. eraseValueFromMap should be used to remove
4622/// V from ValueExprMap and ExprValueMap at the same time.
4623void ScalarEvolution::eraseValueFromMap(Value *V) {
4624 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4625 if (I != ValueExprMap.end()) {
4626 auto EVIt = ExprValueMap.find(I->second);
4627 bool Removed = EVIt->second.remove(V);
4628 (void) Removed;
4629 assert(Removed && "Value not in ExprValueMap?");
4630 ValueExprMap.erase(I);
4631 }
4632}
4633
4634void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4635 // A recursive query may have already computed the SCEV. It should be
4636 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4637 // inferred nowrap flags.
4638 auto It = ValueExprMap.find_as(V);
4639 if (It == ValueExprMap.end()) {
4640 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4641 ExprValueMap[S].insert(V);
4642 }
4643}
4644
/// Return an existing SCEV if it exists, otherwise analyze the expression and
/// create a new one.
  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

  // Fast path: reuse a cached expression for V when one exists.
  if (const SCEV *S = getExistingSCEV(V))
    return S;
  return createSCEVIter(V);
}
4654
  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

  // Look up V in the forward map; a hit must still be a valid expression.
  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
  if (I != ValueExprMap.end()) {
    const SCEV *S = I->second;
    assert(checkValidity(S) &&
           "existing SCEV has not been properly invalidated");
    return S;
  }
  return nullptr;
}
4667
/// Return a SCEV corresponding to -V = -1*V
  // Constants fold directly via the constant-expression negation.
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));

  // Otherwise represent the negation as a multiply by -1.
  Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMulExpr(V, getMinusOne(Ty), Flags);
}
4679
4680/// If Expr computes ~A, return A else return nullptr
4681static const SCEV *MatchNotExpr(const SCEV *Expr) {
4682 const SCEV *MulOp;
4683 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4684 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4685 return MulOp;
4686 return nullptr;
4687}
4688
/// Return a SCEV corresponding to ~V = -1-V
  assert(!V->getType()->isPointerTy() && "Can't negate pointer");

  // Constants fold directly via the constant-expression bitwise-not.
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));

  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
    auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
      SmallVector<SCEVUse, 2> MatchedOperands;
      // All operands must themselves be of the form ~x for the fold to apply.
      for (const SCEV *Operand : MME->operands()) {
        const SCEV *Matched = MatchNotExpr(Operand);
        if (!Matched)
          return (const SCEV *)nullptr;
        MatchedOperands.push_back(Matched);
      }
      // Swap min<->max (same signedness) over the unwrapped operands.
      return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
                           MatchedOperands);
    };
    if (const SCEV *Replaced = MatchMinMaxNegation(MME))
      return Replaced;
  }

  // General case: ~V == -1 - V.
  Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMinusSCEV(getMinusOne(Ty), V);
}
4718
  assert(P->getType()->isPointerTy());

  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
    // The base of an AddRec is the first operand.
    SmallVector<SCEVUse> Ops{AddRec->operands()};
    Ops[0] = removePointerBase(Ops[0]);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if pointer operand of the AddRec is a SCEVUnknown).
    return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }
  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
    // The base of an Add is the pointer operand.
    SmallVector<SCEVUse> Ops{Add->operands()};
    SCEVUse *PtrOp = nullptr;
    // Exactly one add operand may be a pointer; find it.
    for (SCEVUse &AddOp : Ops) {
      if (AddOp->getType()->isPointerTy()) {
        assert(!PtrOp && "Cannot have multiple pointer ops");
        PtrOp = &AddOp;
      }
    }
    *PtrOp = removePointerBase(*PtrOp);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if the pointer operand of the Add is a SCEVUnknown).
    return getAddExpr(Ops);
  }
  // Any other expression must be a pointer base.
  return getZero(P->getType());
}
4748
                                          SCEV::NoWrapFlags Flags,
                                          unsigned Depth) {
  // Fast path: X - X --> 0.
  if (LHS == RHS)
    return getZero(LHS->getType());

  // If we subtract two pointers with different pointer bases, bail.
  // Eventually, we're going to add an assertion to getMulExpr that we
  // can't multiply by a pointer.
  if (RHS->getType()->isPointerTy()) {
    if (!LHS->getType()->isPointerTy() ||
        getPointerBase(LHS) != getPointerBase(RHS))
      return getCouldNotCompute();
    // Same base: subtract the offsets instead.
    LHS = removePointerBase(LHS);
    RHS = removePointerBase(RHS);
  }

  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
  // makes it so that we cannot make much use of NUW.
  auto AddFlags = SCEV::FlagAnyWrap;
  const bool RHSIsNotMinSigned =
      // NOTE(review): the initializer expression was lost in extraction;
      // upstream checks that the signed-range minimum of RHS is not the
      // minimum signed value — TODO confirm against upstream.
  if (hasFlags(Flags, SCEV::FlagNSW)) {
    // Let M be the minimum representable signed value. Then (-1)*RHS
    // signed-wraps if and only if RHS is M. That can happen even for
    // a NSW subtraction because e.g. (-1)*M signed-wraps even though
    // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
    // (-1)*RHS, we need to prove that RHS != M.
    //
    // If LHS is non-negative and we know that LHS - RHS does not
    // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
    // either by proving that RHS > M or that LHS >= 0.
    if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
      AddFlags = SCEV::FlagNSW;
    }
  }

  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
  // RHS is NSW and LHS >= 0.
  //
  // The difficulty here is that the NSW flag may have been proven
  // relative to a loop that is to be found in a recurrence in LHS and
  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
  // larger scope than intended.
  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;

  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
}
4798
                                                     unsigned Depth) {
  // Convert V to type Ty, truncating or zero-extending as the widths demand.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty, Depth);
  return getZeroExtendExpr(V, Ty, Depth);
}
4810
                                                     unsigned Depth) {
  // Convert V to type Ty, truncating or sign-extending as the widths demand.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
    return getTruncateExpr(V, Ty, Depth);
  return getSignExtendExpr(V, Ty, Depth);
}
4822
const SCEV *
// NOTE(review): the signature line and the first line of the widening assert
// were lost in extraction.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or zero extend with non-integer arguments!");
         "getNoopOrZeroExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getZeroExtendExpr(V, Ty);
}
4834
const SCEV *
// NOTE(review): the signature line and the first line of the widening assert
// were lost in extraction.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or sign extend with non-integer arguments!");
         "getNoopOrSignExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getSignExtendExpr(V, Ty);
}
4846
const SCEV *
// NOTE(review): the signature line and the first line of the widening assert
// were lost in extraction.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot noop or any extend with non-integer arguments!");
         "getNoopOrAnyExtend cannot truncate!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getAnyExtendExpr(V, Ty);
}
4858
const SCEV *
// NOTE(review): the signature line and the first line of the narrowing assert
// were lost in extraction.
  Type *SrcTy = V->getType();
  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or noop with non-integer arguments!");
         "getTruncateOrNoop cannot extend!");
  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
    return V; // No conversion
  return getTruncateExpr(V, Ty);
}
4870
                                           const SCEV *RHS) {
  // Zero-extend the narrower operand to the wider type, then take umax.
  const SCEV *PromotedLHS = LHS;
  const SCEV *PromotedRHS = RHS;

  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
  else
    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());

  return getUMaxExpr(PromotedLHS, PromotedRHS);
}
4883
                                           const SCEV *RHS,
                                           bool Sequential) {
  // Two-operand convenience wrapper over the N-ary overload.
  SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
  return getUMinFromMismatchedTypes(Ops, Sequential);
}
4890
const SCEV *
// NOTE(review): the signature continuation line was lost in extraction.
                                               bool Sequential) {
  assert(!Ops.empty() && "At least one operand must be!");
  // Trivial case.
  if (Ops.size() == 1)
    return Ops[0];

  // Find the max type first.
  Type *MaxType = nullptr;
  for (SCEVUse S : Ops)
    if (MaxType)
      MaxType = getWiderType(MaxType, S->getType());
    else
      MaxType = S->getType();
  assert(MaxType && "Failed to find maximum type!");

  // Extend all ops to max type.
  SmallVector<SCEVUse, 2> PromotedOps;
  for (SCEVUse S : Ops)
    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));

  // Generate umin.
  return getUMinExpr(PromotedOps, Sequential);
}
4916
  // A pointer operand may evaluate to a nonpointer expression, such as null.
  if (!V->getType()->isPointerTy())
    return V;

  // Walk down through AddRec starts and the single pointer operand of Adds
  // until no further decomposition is possible.
  while (true) {
    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
      V = AddRec->getStart();
    } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
      const SCEV *PtrOp = nullptr;
      for (const SCEV *AddOp : Add->operands()) {
        if (AddOp->getType()->isPointerTy()) {
          assert(!PtrOp && "Cannot have multiple pointer ops");
          PtrOp = AddOp;
        }
      }
      assert(PtrOp && "Must have pointer op");
      V = PtrOp;
    } else // Not something we can look further into.
      return V;
  }
}
4939
/// Push users of the given Instruction onto the given Worklist.
  // Push the def-use children onto the Worklist stack, de-duplicating via
  // the Visited set.
  for (User *U : I->users()) {
    auto *UserInsn = cast<Instruction>(U);
    if (Visited.insert(UserInsn).second)
      Worklist.push_back(UserInsn);
  }
}
4951
4952namespace {
4953
4954/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4955/// expression in case its Loop is L. If it is not L then
4956/// if IgnoreOtherLoops is true then use AddRec itself
4957/// otherwise rewrite cannot be done.
4958/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4959class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4960public:
4961 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4962 bool IgnoreOtherLoops = true) {
4963 SCEVInitRewriter Rewriter(L, SE);
4964 const SCEV *Result = Rewriter.visit(S);
4965 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4966 return SE.getCouldNotCompute();
4967 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4968 ? SE.getCouldNotCompute()
4969 : Result;
4970 }
4971
4972 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4973 if (!SE.isLoopInvariant(Expr, L))
4974 SeenLoopVariantSCEVUnknown = true;
4975 return Expr;
4976 }
4977
4978 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4979 // Only re-write AddRecExprs for this loop.
4980 if (Expr->getLoop() == L)
4981 return Expr->getStart();
4982 SeenOtherLoops = true;
4983 return Expr;
4984 }
4985
4986 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4987
4988 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4989
4990private:
4991 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4992 : SCEVRewriteVisitor(SE), L(L) {}
4993
4994 const Loop *L;
4995 bool SeenLoopVariantSCEVUnknown = false;
4996 bool SeenOtherLoops = false;
4997};
4998
4999/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5000/// increment expression in case its Loop is L. If it is not L then
5001/// use AddRec itself.
5002/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5003class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5004public:
5005 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5006 SCEVPostIncRewriter Rewriter(L, SE);
5007 const SCEV *Result = Rewriter.visit(S);
5008 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5009 ? SE.getCouldNotCompute()
5010 : Result;
5011 }
5012
5013 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5014 if (!SE.isLoopInvariant(Expr, L))
5015 SeenLoopVariantSCEVUnknown = true;
5016 return Expr;
5017 }
5018
5019 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5020 // Only re-write AddRecExprs for this loop.
5021 if (Expr->getLoop() == L)
5022 return Expr->getPostIncExpr(SE);
5023 SeenOtherLoops = true;
5024 return Expr;
5025 }
5026
5027 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5028
5029 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5030
5031private:
5032 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5033 : SCEVRewriteVisitor(SE), L(L) {}
5034
5035 const Loop *L;
5036 bool SeenLoopVariantSCEVUnknown = false;
5037 bool SeenOtherLoops = false;
5038};
5039
/// This class evaluates the compare condition by matching it against the
/// condition of loop latch. If there is a match we assume a true value
/// for the condition while building SCEV nodes.
class SCEVBackedgeConditionFolder
    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                             ScalarEvolution &SE) {
    bool IsPosBECond = false;
    Value *BECond = nullptr;
    // Extract the latch's branch condition and which way it re-enters the
    // loop; anything other than a conditional branch means no folding.
    if (BasicBlock *Latch = L->getLoopLatch()) {
      if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
        assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
               "Both outgoing branches should not target same header!");
        BECond = BI->getCondition();
        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
      } else {
        return S;
      }
    }
    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
    return Rewriter.visit(S);
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    const SCEV *Result = Expr;
    bool InvariantF = SE.isLoopInvariant(Expr, L);

    if (!InvariantF) {
      // NOTE(review): the declaration of `I` (the Instruction behind this
      // SCEVUnknown) was lost in extraction — TODO confirm against upstream.
      switch (I->getOpcode()) {
      case Instruction::Select: {
        SelectInst *SI = cast<SelectInst>(I);
        // If the select condition matches the backedge condition, pick the
        // arm the loop actually takes.
        std::optional<const SCEV *> Res =
            compareWithBackedgeCondition(SI->getCondition());
        if (Res) {
          bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
        }
        break;
      }
      default: {
        std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
        if (Res)
          Result = *Res;
        break;
      }
      }
    }
    return Result;
  }

private:
  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
                                       bool IsPosBECond, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
        IsPositiveBECond(IsPosBECond) {}

  std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);

  const Loop *L;
  /// Loop back condition.
  Value *BackedgeCond = nullptr;
  /// Set to true if loop back is on positive branch condition.
  bool IsPositiveBECond;
};
5106
std::optional<const SCEV *>
SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {

  // If value matches the backedge condition for loop latch,
  // then return a constant evolution node based on loopback
  // branch taken.
  if (BackedgeCond == IC)
    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
    // NOTE(review): the false-branch operand (upstream returns the i1 zero
    // constant here) was lost in extraction — TODO confirm.
  return std::nullopt;
}
5118
5119class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5120public:
5121 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5122 ScalarEvolution &SE) {
5123 SCEVShiftRewriter Rewriter(L, SE);
5124 const SCEV *Result = Rewriter.visit(S);
5125 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5126 }
5127
5128 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5129 // Only allow AddRecExprs for this loop.
5130 if (!SE.isLoopInvariant(Expr, L))
5131 Valid = false;
5132 return Expr;
5133 }
5134
5135 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5136 if (Expr->getLoop() == L && Expr->isAffine())
5137 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5138 Valid = false;
5139 return Expr;
5140 }
5141
5142 bool isValid() { return Valid; }
5143
5144private:
5145 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5146 : SCEVRewriteVisitor(SE), L(L) {}
5147
5148 const Loop *L;
5149 bool Valid = true;
5150};
5151
5152} // end anonymous namespace
5153
ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
  // Only affine recurrences are handled.
  if (!AR->isAffine())
    return SCEV::FlagAnyWrap;

  using OBO = OverflowingBinaryOperator;

  // NOTE(review): the declaration/initialization of `Result` was lost in
  // extraction — TODO confirm against upstream.
  if (!AR->hasNoSelfWrap()) {
    // If (max trip count bits + signed step bits) fit in the type, the
    // recurrence cannot self-wrap.
    const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
    if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
      ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
      const APInt &BECountAP = BECountMax->getAPInt();
      unsigned NoOverflowBitWidth =
          BECountAP.getActiveBits() + StepCR.getMinSignedBits();
      if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
        // NOTE(review): the statement setting FlagNW on Result was lost in
        // extraction.
    }
  }

  if (!AR->hasNoSignedWrap()) {
    ConstantRange AddRecRange = getSignedRange(AR);
    ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));

    // NOTE(review): the makeGuaranteedNoWrapRegion call head was lost in
    // extraction; only its argument continuation remains.
        Instruction::Add, IncRange, OBO::NoSignedWrap);
    if (NSWRegion.contains(AddRecRange))
      // NOTE(review): the statement setting FlagNSW on Result was lost in
      // extraction.
  }

  if (!AR->hasNoUnsignedWrap()) {
    ConstantRange AddRecRange = getUnsignedRange(AR);
    ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));

    // NOTE(review): the makeGuaranteedNoWrapRegion call head was lost in
    // extraction; only its argument continuation remains.
        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
    if (NUWRegion.contains(AddRecRange))
      // NOTE(review): the statement setting FlagNUW on Result was lost in
      // extraction.
  }

  return Result;
}
5197
ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
  // NOTE(review): the initialization of `Result` (upstream starts from the
  // AddRec's current flags) was lost in extraction — TODO confirm.

  // Nothing to prove if NSW is already known.
  if (AR->hasNoSignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive, only try to prove NSW once per AddRec.
  if (!SignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the later case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
  // NOTE(review): the declaration of the predicate variable `Pred` was lost
  // in extraction.
  const SCEV *OverflowLimit =
      getSignedOverflowLimitForStep(Step, &Pred, this);
  if (OverflowLimit &&
      (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
       isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
    Result = setFlags(Result, SCEV::FlagNSW);
  }
  return Result;
}
ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
  // NOTE(review): the initialization of `Result` (upstream starts from the
  // AddRec's current flags) was lost in extraction — TODO confirm.

  // Nothing to prove if NUW is already known.
  if (AR->hasNoUnsignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive, only try to prove NUW once per AddRec.
  if (!UnsignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  unsigned BitWidth = getTypeSizeInBits(AR->getType());
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the later case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
  if (isKnownPositive(Step)) {
    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                getUnsignedRangeMax(Step));
    // NOTE(review): the guard condition lines (upstream checks the backedge
    // is guarded by an unsigned comparison against N) were lost in
    // extraction — TODO confirm.
      Result = setFlags(Result, SCEV::FlagNUW);
    }
  }

  return Result;
}
5306
5307namespace {
5308
5309/// Represents an abstract binary operation. This may exist as a
5310/// normal instruction or constant expression, or may have been
5311/// derived from an expression tree.
5312struct BinaryOp {
5313 unsigned Opcode;
5314 Value *LHS;
5315 Value *RHS;
5316 bool IsNSW = false;
5317 bool IsNUW = false;
5318
5319 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5320 /// constant expression.
5321 Operator *Op = nullptr;
5322
5323 explicit BinaryOp(Operator *Op)
5324 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5325 Op(Op) {
5326 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5327 IsNSW = OBO->hasNoSignedWrap();
5328 IsNUW = OBO->hasNoUnsignedWrap();
5329 }
5330 }
5331
5332 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5333 bool IsNUW = false)
5334 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5335};
5336
5337} // end anonymous namespace
5338
/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
                                             AssumptionCache &AC,
                                             const DominatorTree &DT,
                                             const Instruction *CxtI) {
  auto *Op = dyn_cast<Operator>(V);
  if (!Op)
    return std::nullopt;

  // Implementation detail: all the cleverness here should happen without
  // creating new SCEV expressions -- our caller knows tricks to avoid creating
  // SCEV expressions when possible, and we should not break that.

  switch (Op->getOpcode()) {
  // These map to a BinaryOp directly, keeping any wrap flags.
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::And:
  case Instruction::AShr:
  case Instruction::Shl:
    return BinaryOp(Op);

  case Instruction::Or: {
    // Convert or disjoint into add nuw nsw.
    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
                      /*IsNSW=*/true, /*IsNUW=*/true);
    return BinaryOp(Op);
  }

  case Instruction::Xor:
    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
      // If the RHS of the xor is a signmask, then this is just an add.
      // Instcombine turns add of signmask into xor as a strength reduction step.
      if (RHSC->getValue().isSignMask())
        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    // Binary `xor` is a bit-wise `add`.
    if (V->getType()->isIntegerTy(1))
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    return BinaryOp(Op);

  case Instruction::LShr:
    // Turn logical shift right of a constant into a unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();

      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined. Don't try to analyze it, because the
      // resolution chosen here may differ from the resolution chosen in
      // other parts of the compiler.
      if (SA->getValue().ult(BitWidth)) {
        Constant *X =
            ConstantInt::get(SA->getContext(),
                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
      }
    }
    return BinaryOp(Op);

  case Instruction::ExtractValue: {
    // Only the value component (index 0) of an overflow intrinsic result is
    // interesting here.
    auto *EVI = cast<ExtractValueInst>(Op);
    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
      break;

    auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
    if (!WO)
      break;

    Instruction::BinaryOps BinOp = WO->getBinaryOp();
    bool Signed = WO->isSigned();
    // TODO: Should add nuw/nsw flags for mul as well.
    if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
      return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());

    // Now that we know that all uses of the arithmetic-result component of
    // CI are guarded by the overflow check, we can go ahead and pretend
    // that the arithmetic is non-overflowing.
    return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
                    /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
  }

  default:
    break;
  }

  // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
  // semantics as a Sub, return a binary sub expression.
  if (auto *II = dyn_cast<IntrinsicInst>(V))
    if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
      return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));

  return std::nullopt;
}
5434
5435/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5436/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5437/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5438/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5439/// follows one of the following patterns:
5440/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5441/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5442/// If the SCEV expression of \p Op conforms with one of the expected patterns
5443/// we return the type of the truncation operation, and indicate whether the
5444/// truncated type should be treated as signed/unsigned by setting
5445/// \p Signed to true/false, respectively.
5446static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5447 bool &Signed, ScalarEvolution &SE) {
5448 // The case where Op == SymbolicPHI (that is, with no type conversions on
5449 // the way) is handled by the regular add recurrence creating logic and
5450 // would have already been triggered in createAddRecForPHI. Reaching it here
5451 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5452 // because one of the other operands of the SCEVAddExpr updating this PHI is
5453 // not invariant).
5454 //
5455 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5456 // this case predicates that allow us to prove that Op == SymbolicPHI will
5457 // be added.
5458 if (Op == SymbolicPHI)
5459 return nullptr;
5460
5461 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5462 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5463 if (SourceBits != NewBits)
5464 return nullptr;
5465
5466 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5467 Signed = true;
5468 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5469 }
5470 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5471 Signed = false;
5472 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5473 }
5474 return nullptr;
5475}
5476
5477static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5478 if (!PN->getType()->isIntegerTy())
5479 return nullptr;
5480 const Loop *L = LI.getLoopFor(PN->getParent());
5481 if (!L || L->getHeader() != PN->getParent())
5482 return nullptr;
5483 return L;
5484}
5485
5486// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5487// computation that updates the phi follows the following pattern:
5488// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5489// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5490// If so, try to see if it can be rewritten as an AddRecExpr under some
5491// Predicates. If successful, return them as a pair. Also cache the results
5492// of the analysis.
5493//
5494// Example usage scenario:
5495// Say the Rewriter is called for the following SCEV:
5496// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5497// where:
5498// %X = phi i64 (%Start, %BEValue)
5499// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5500// and call this function with %SymbolicPHI = %X.
5501//
5502// The analysis will find that the value coming around the backedge has
5503// the following SCEV:
5504// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5505// Upon concluding that this matches the desired pattern, the function
5506// will return the pair {NewAddRec, SmallPredsVec} where:
5507// NewAddRec = {%Start,+,%Step}
5508// SmallPredsVec = {P1, P2, P3} as follows:
5509// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5510// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5511// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5512// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5513// under the predicates {P1,P2,P3}.
5514// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5515// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5516//
5517// TODO's:
5518//
5519// 1) Extend the Induction descriptor to also support inductions that involve
5520// casts: When needed (namely, when we are called in the context of the
5521// vectorizer induction analysis), a Set of cast instructions will be
5522// populated by this method, and provided back to isInductionPHI. This is
5523// needed to allow the vectorizer to properly record them to be ignored by
5524// the cost model and to avoid vectorizing them (otherwise these casts,
5525// which are redundant under the runtime overflow checks, will be
5526// vectorized, which can be costly).
5527//
5528// 2) Support additional induction/PHISCEV patterns: We also want to support
5529// inductions where the sext-trunc / zext-trunc operations (partly) occur
5530// after the induction update operation (the induction increment):
5531//
5532// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5533// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5534//
5535// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5536// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5537//
5538// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5539std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5540ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5542
5543 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5544 // return an AddRec expression under some predicate.
5545
5546 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5547 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5548 assert(L && "Expecting an integer loop header phi");
5549
5550 // The loop may have multiple entrances or multiple exits; we can analyze
5551 // this phi as an addrec if it has a unique entry value and a unique
5552 // backedge value.
5553 Value *BEValueV = nullptr, *StartValueV = nullptr;
5554 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5555 Value *V = PN->getIncomingValue(i);
5556 if (L->contains(PN->getIncomingBlock(i))) {
5557 if (!BEValueV) {
5558 BEValueV = V;
5559 } else if (BEValueV != V) {
5560 BEValueV = nullptr;
5561 break;
5562 }
5563 } else if (!StartValueV) {
5564 StartValueV = V;
5565 } else if (StartValueV != V) {
5566 StartValueV = nullptr;
5567 break;
5568 }
5569 }
5570 if (!BEValueV || !StartValueV)
5571 return std::nullopt;
5572
5573 const SCEV *BEValue = getSCEV(BEValueV);
5574
5575 // If the value coming around the backedge is an add with the symbolic
5576 // value we just inserted, possibly with casts that we can ignore under
5577 // an appropriate runtime guard, then we found a simple induction variable!
5578 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5579 if (!Add)
5580 return std::nullopt;
5581
5582 // If there is a single occurrence of the symbolic value, possibly
5583 // casted, replace it with a recurrence.
5584 unsigned FoundIndex = Add->getNumOperands();
5585 Type *TruncTy = nullptr;
5586 bool Signed;
5587 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5588 if ((TruncTy =
5589 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5590 if (FoundIndex == e) {
5591 FoundIndex = i;
5592 break;
5593 }
5594
5595 if (FoundIndex == Add->getNumOperands())
5596 return std::nullopt;
5597
5598 // Create an add with everything but the specified operand.
5600 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5601 if (i != FoundIndex)
5602 Ops.push_back(Add->getOperand(i));
5603 const SCEV *Accum = getAddExpr(Ops);
5604
5605 // The runtime checks will not be valid if the step amount is
5606 // varying inside the loop.
5607 if (!isLoopInvariant(Accum, L))
5608 return std::nullopt;
5609
5610 // *** Part2: Create the predicates
5611
5612 // Analysis was successful: we have a phi-with-cast pattern for which we
5613 // can return an AddRec expression under the following predicates:
5614 //
5615 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5616 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5617 // P2: An Equal predicate that guarantees that
5618 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5619 // P3: An Equal predicate that guarantees that
5620 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5621 //
5622 // As we next prove, the above predicates guarantee that:
5623 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5624 //
5625 //
5626 // More formally, we want to prove that:
5627 // Expr(i+1) = Start + (i+1) * Accum
5628 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5629 //
5630 // Given that:
5631 // 1) Expr(0) = Start
5632 // 2) Expr(1) = Start + Accum
5633 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5634 // 3) Induction hypothesis (step i):
5635 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5636 //
5637 // Proof:
5638 // Expr(i+1) =
5639 // = Start + (i+1)*Accum
5640 // = (Start + i*Accum) + Accum
5641 // = Expr(i) + Accum
5642 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5643 // :: from step i
5644 //
5645 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5646 //
5647 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5648 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5649 // + Accum :: from P3
5650 //
5651 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5652 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5653 //
5654 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5655 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5656 //
5657 // By induction, the same applies to all iterations 1<=i<n:
5658 //
5659
5660 // Create a truncated addrec for which we will add a no overflow check (P1).
5661 const SCEV *StartVal = getSCEV(StartValueV);
5662 const SCEV *PHISCEV =
5663 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5664 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5665
5666 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5667 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5668 // will be constant.
5669 //
5670 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5671 // add P1.
5672 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5676 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5677 Predicates.push_back(AddRecPred);
5678 }
5679
5680 // Create the Equal Predicates P2,P3:
5681
5682 // It is possible that the predicates P2 and/or P3 are computable at
5683 // compile time due to StartVal and/or Accum being constants.
5684 // If either one is, then we can check that now and escape if either P2
5685 // or P3 is false.
5686
5687 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5688 // for each of StartVal and Accum
5689 auto getExtendedExpr = [&](const SCEV *Expr,
5690 bool CreateSignExtend) -> const SCEV * {
5691 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5692 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5693 const SCEV *ExtendedExpr =
5694 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5695 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5696 return ExtendedExpr;
5697 };
5698
5699 // Given:
5700 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5701 // = getExtendedExpr(Expr)
5702 // Determine whether the predicate P: Expr == ExtendedExpr
5703 // is known to be false at compile time
5704 auto PredIsKnownFalse = [&](const SCEV *Expr,
5705 const SCEV *ExtendedExpr) -> bool {
5706 return Expr != ExtendedExpr &&
5707 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5708 };
5709
5710 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5711 if (PredIsKnownFalse(StartVal, StartExtended)) {
5712 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5713 return std::nullopt;
5714 }
5715
5716 // The Step is always Signed (because the overflow checks are either
5717 // NSSW or NUSW)
5718 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5719 if (PredIsKnownFalse(Accum, AccumExtended)) {
5720 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5721 return std::nullopt;
5722 }
5723
5724 auto AppendPredicate = [&](const SCEV *Expr,
5725 const SCEV *ExtendedExpr) -> void {
5726 if (Expr != ExtendedExpr &&
5727 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5728 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5729 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5730 Predicates.push_back(Pred);
5731 }
5732 };
5733
5734 AppendPredicate(StartVal, StartExtended);
5735 AppendPredicate(Accum, AccumExtended);
5736
5737 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5738 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5739 // into NewAR if it will also add the runtime overflow checks specified in
5740 // Predicates.
5741 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5742
5743 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5744 std::make_pair(NewAR, Predicates);
5745 // Remember the result of the analysis for this SCEV at this locayyytion.
5746 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5747 return PredRewrite;
5748}
5749
5750std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5752 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5753 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5754 if (!L)
5755 return std::nullopt;
5756
5757 // Check to see if we already analyzed this PHI.
5758 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5759 if (I != PredicatedSCEVRewrites.end()) {
5760 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5761 I->second;
5762 // Analysis was done before and failed to create an AddRec:
5763 if (Rewrite.first == SymbolicPHI)
5764 return std::nullopt;
5765 // Analysis was done before and succeeded to create an AddRec under
5766 // a predicate:
5767 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5768 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5769 return Rewrite;
5770 }
5771
5772 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5773 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5774
5775 // Record in the cache that the analysis failed
5776 if (!Rewrite) {
5778 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5779 return std::nullopt;
5780 }
5781
5782 return Rewrite;
5783}
5784
5785// FIXME: This utility is currently required because the Rewriter currently
5786// does not rewrite this expression:
5787// {0, +, (sext ix (trunc iy to ix) to iy)}
5788// into {0, +, %step},
5789// even when the following Equal predicate exists:
5790// "%step == (sext ix (trunc iy to ix) to iy)".
5792 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5793 if (AR1 == AR2)
5794 return true;
5795
5796 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5797 if (Expr1 != Expr2 &&
5798 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5799 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5800 return false;
5801 return true;
5802 };
5803
5804 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5805 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5806 return false;
5807 return true;
5808}
5809
5810/// A helper function for createAddRecFromPHI to handle simple cases.
5811///
5812/// This function tries to find an AddRec expression for the simplest (yet most
5813/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5814/// If it fails, createAddRecFromPHI will use a more general, but slow,
5815/// technique for finding the AddRec expression.
5816const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5817 Value *BEValueV,
5818 Value *StartValueV) {
5819 const Loop *L = LI.getLoopFor(PN->getParent());
5820 assert(L && L->getHeader() == PN->getParent());
5821 assert(BEValueV && StartValueV);
5822
5823 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5824 if (!BO)
5825 return nullptr;
5826
5827 if (BO->Opcode != Instruction::Add)
5828 return nullptr;
5829
5830 const SCEV *Accum = nullptr;
5831 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5832 Accum = getSCEV(BO->RHS);
5833 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5834 Accum = getSCEV(BO->LHS);
5835
5836 if (!Accum)
5837 return nullptr;
5838
5840 if (BO->IsNUW)
5841 Flags = setFlags(Flags, SCEV::FlagNUW);
5842 if (BO->IsNSW)
5843 Flags = setFlags(Flags, SCEV::FlagNSW);
5844
5845 const SCEV *StartVal = getSCEV(StartValueV);
5846 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5847 insertValueToMap(PN, PHISCEV);
5848
5849 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5850 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5852 proveNoWrapViaConstantRanges(AR)));
5853 }
5854
5855 // We can add Flags to the post-inc expression only if we
5856 // know that it is *undefined behavior* for BEValueV to
5857 // overflow.
5858 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5859 assert(isLoopInvariant(Accum, L) &&
5860 "Accum is defined outside L, but is not invariant?");
5861 if (isAddRecNeverPoison(BEInst, L))
5862 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5863 }
5864
5865 return PHISCEV;
5866}
5867
5868const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5869 const Loop *L = LI.getLoopFor(PN->getParent());
5870 if (!L || L->getHeader() != PN->getParent())
5871 return nullptr;
5872
5873 // The loop may have multiple entrances or multiple exits; we can analyze
5874 // this phi as an addrec if it has a unique entry value and a unique
5875 // backedge value.
5876 Value *BEValueV = nullptr, *StartValueV = nullptr;
5877 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5878 Value *V = PN->getIncomingValue(i);
5879 if (L->contains(PN->getIncomingBlock(i))) {
5880 if (!BEValueV) {
5881 BEValueV = V;
5882 } else if (BEValueV != V) {
5883 BEValueV = nullptr;
5884 break;
5885 }
5886 } else if (!StartValueV) {
5887 StartValueV = V;
5888 } else if (StartValueV != V) {
5889 StartValueV = nullptr;
5890 break;
5891 }
5892 }
5893 if (!BEValueV || !StartValueV)
5894 return nullptr;
5895
5896 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5897 "PHI node already processed?");
5898
5899 // First, try to find AddRec expression without creating a fictituos symbolic
5900 // value for PN.
5901 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5902 return S;
5903
5904 // Handle PHI node value symbolically.
5905 const SCEV *SymbolicName = getUnknown(PN);
5906 insertValueToMap(PN, SymbolicName);
5907
5908 // Using this symbolic name for the PHI, analyze the value coming around
5909 // the back-edge.
5910 const SCEV *BEValue = getSCEV(BEValueV);
5911
5912 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5913 // has a special value for the first iteration of the loop.
5914
5915 // If the value coming around the backedge is an add with the symbolic
5916 // value we just inserted, then we found a simple induction variable!
5917 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5918 // If there is a single occurrence of the symbolic value, replace it
5919 // with a recurrence.
5920 unsigned FoundIndex = Add->getNumOperands();
5921 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5922 if (Add->getOperand(i) == SymbolicName)
5923 if (FoundIndex == e) {
5924 FoundIndex = i;
5925 break;
5926 }
5927
5928 if (FoundIndex != Add->getNumOperands()) {
5929 // Create an add with everything but the specified operand.
5931 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5932 if (i != FoundIndex)
5933 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5934 L, *this));
5935 const SCEV *Accum = getAddExpr(Ops);
5936
5937 // This is not a valid addrec if the step amount is varying each
5938 // loop iteration, but is not itself an addrec in this loop.
5939 if (isLoopInvariant(Accum, L) ||
5940 (isa<SCEVAddRecExpr>(Accum) &&
5941 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5943
5944 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5945 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5946 if (BO->IsNUW)
5947 Flags = setFlags(Flags, SCEV::FlagNUW);
5948 if (BO->IsNSW)
5949 Flags = setFlags(Flags, SCEV::FlagNSW);
5950 }
5951 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5952 if (GEP->getOperand(0) == PN) {
5953 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5954 // If the increment has any nowrap flags, then we know the address
5955 // space cannot be wrapped around.
5956 if (NW != GEPNoWrapFlags::none())
5957 Flags = setFlags(Flags, SCEV::FlagNW);
5958 // If the GEP is nuw or nusw with non-negative offset, we know that
5959 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5960 // offset is treated as signed, while the base is unsigned.
5961 if (NW.hasNoUnsignedWrap() ||
5963 Flags = setFlags(Flags, SCEV::FlagNUW);
5964 }
5965
5966 // We cannot transfer nuw and nsw flags from subtraction
5967 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5968 // for instance.
5969 }
5970
5971 const SCEV *StartVal = getSCEV(StartValueV);
5972 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5973
5974 // Okay, for the entire analysis of this edge we assumed the PHI
5975 // to be symbolic. We now need to go back and purge all of the
5976 // entries for the scalars that use the symbolic expression.
5977 forgetMemoizedResults({SymbolicName});
5978 insertValueToMap(PN, PHISCEV);
5979
5980 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5981 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5983 proveNoWrapViaConstantRanges(AR)));
5984 }
5985
5986 // We can add Flags to the post-inc expression only if we
5987 // know that it is *undefined behavior* for BEValueV to
5988 // overflow.
5989 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5990 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5991 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5992
5993 return PHISCEV;
5994 }
5995 }
5996 } else {
5997 // Otherwise, this could be a loop like this:
5998 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5999 // In this case, j = {1,+,1} and BEValue is j.
6000 // Because the other in-value of i (0) fits the evolution of BEValue
6001 // i really is an addrec evolution.
6002 //
6003 // We can generalize this saying that i is the shifted value of BEValue
6004 // by one iteration:
6005 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6006
6007 // Do not allow refinement in rewriting of BEValue.
6008 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6009 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6010 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6011 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6012 const SCEV *StartVal = getSCEV(StartValueV);
6013 if (Start == StartVal) {
6014 // Okay, for the entire analysis of this edge we assumed the PHI
6015 // to be symbolic. We now need to go back and purge all of the
6016 // entries for the scalars that use the symbolic expression.
6017 forgetMemoizedResults({SymbolicName});
6018 insertValueToMap(PN, Shifted);
6019 return Shifted;
6020 }
6021 }
6022 }
6023
6024 // Remove the temporary PHI node SCEV that has been inserted while intending
6025 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6026 // as it will prevent later (possibly simpler) SCEV expressions to be added
6027 // to the ValueExprMap.
6028 eraseValueFromMap(PN);
6029
6030 return nullptr;
6031}
6032
6033// Try to match a control flow sequence that branches out at BI and merges back
6034// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6035// match.
6037 Value *&C, Value *&LHS, Value *&RHS) {
6038 C = BI->getCondition();
6039
6040 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6041 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6042
6043 Use &LeftUse = Merge->getOperandUse(0);
6044 Use &RightUse = Merge->getOperandUse(1);
6045
6046 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6047 LHS = LeftUse;
6048 RHS = RightUse;
6049 return true;
6050 }
6051
6052 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6053 LHS = RightUse;
6054 RHS = LeftUse;
6055 return true;
6056 }
6057
6058 return false;
6059}
6060
6062 Value *&Cond, Value *&LHS,
6063 Value *&RHS) {
6064 auto IsReachable =
6065 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6066 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6067 // Try to match
6068 //
6069 // br %cond, label %left, label %right
6070 // left:
6071 // br label %merge
6072 // right:
6073 // br label %merge
6074 // merge:
6075 // V = phi [ %x, %left ], [ %y, %right ]
6076 //
6077 // as "select %cond, %x, %y"
6078
6079 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6080 assert(IDom && "At least the entry block should dominate PN");
6081
6082 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6083 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6084 }
6085 return false;
6086}
6087
/// Model a two-input phi fed by a conditional branch as a select SCEV, via
/// getOperandsForSelectLikePHI + createNodeForSelectOrPHI.
const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  // NOTE(review): the `if` condition below is truncated in this copy of the
  // file — the guard that follows the `&&` (between the pattern match and the
  // rewrite) is missing. Restore it from the upstream sources before use.
  if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
    return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);

  return nullptr;
}
6097
6099 BinaryOperator *CommonInst = nullptr;
6100 // Check if instructions are identical.
6101 for (Value *Incoming : PN->incoming_values()) {
6102 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6103 if (!IncomingInst)
6104 return nullptr;
6105 if (CommonInst) {
6106 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6107 return nullptr; // Not identical, give up
6108 } else {
6109 // Remember binary operator
6110 CommonInst = IncomingInst;
6111 }
6112 }
6113 return CommonInst;
6114}
6115
6116/// Returns SCEV for the first operand of a phi if all phi operands have
6117/// identical opcodes and operands
6118/// eg.
6119/// a: %add = %a + %b
6120/// br %c
6121/// b: %add1 = %a + %b
6122/// br %c
6123/// c: %phi = phi [%add, a], [%add1, b]
6124/// scev(%phi) => scev(%add)
6125const SCEV *
6126ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6127 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6128 if (!CommonInst)
6129 return nullptr;
6130
6131 // Check if SCEV exprs for instructions are identical.
6132 const SCEV *CommonSCEV = getSCEV(CommonInst);
6133 bool SCEVExprsIdentical =
6135 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6136 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6137}
6138
6139const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6140 if (const SCEV *S = createAddRecFromPHI(PN))
6141 return S;
6142
6143 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6144 // phi node for X.
6145 if (Value *V = simplifyInstruction(
6146 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6147 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6148 return getSCEV(V);
6149
6150 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6151 return S;
6152
6153 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6154 return S;
6155
6156 // If it's not a loop phi, we can't handle it yet.
6157 return getUnknown(PN);
6158}
6159
6160bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6161 SCEVTypes RootKind) {
6162 struct FindClosure {
6163 const SCEV *OperandToFind;
6164 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6165 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6166
6167 bool Found = false;
6168
6169 bool canRecurseInto(SCEVTypes Kind) const {
6170 // We can only recurse into the SCEV expression of the same effective type
6171 // as the type of our root SCEV expression, and into zero-extensions.
6172 return RootKind == Kind || NonSequentialRootKind == Kind ||
6173 scZeroExtend == Kind;
6174 };
6175
6176 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6177 : OperandToFind(OperandToFind), RootKind(RootKind),
6178 NonSequentialRootKind(
6180 RootKind)) {}
6181
6182 bool follow(const SCEV *S) {
6183 Found = S == OperandToFind;
6184
6185 return !isDone() && canRecurseInto(S->getSCEVType());
6186 }
6187
6188 bool isDone() const { return Found; }
6189 };
6190
6191 FindClosure FC(OperandToFind, RootKind);
6192 visitAll(Root, FC);
6193 return FC.Found;
6194}
6195
6196std::optional<const SCEV *>
6197ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6198 ICmpInst *Cond,
6199 Value *TrueVal,
6200 Value *FalseVal) {
6201 // Try to match some simple smax or umax patterns.
6202 auto *ICI = Cond;
6203
6204 Value *LHS = ICI->getOperand(0);
6205 Value *RHS = ICI->getOperand(1);
6206
6207 switch (ICI->getPredicate()) {
6208 case ICmpInst::ICMP_SLT:
6209 case ICmpInst::ICMP_SLE:
6210 case ICmpInst::ICMP_ULT:
6211 case ICmpInst::ICMP_ULE:
6212 std::swap(LHS, RHS);
6213 [[fallthrough]];
6214 case ICmpInst::ICMP_SGT:
6215 case ICmpInst::ICMP_SGE:
6216 case ICmpInst::ICMP_UGT:
6217 case ICmpInst::ICMP_UGE:
6218 // a > b ? a+x : b+x -> max(a, b)+x
6219 // a > b ? b+x : a+x -> min(a, b)+x
6221 bool Signed = ICI->isSigned();
6222 const SCEV *LA = getSCEV(TrueVal);
6223 const SCEV *RA = getSCEV(FalseVal);
6224 const SCEV *LS = getSCEV(LHS);
6225 const SCEV *RS = getSCEV(RHS);
6226 if (LA->getType()->isPointerTy()) {
6227 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6228 // Need to make sure we can't produce weird expressions involving
6229 // negated pointers.
6230 if (LA == LS && RA == RS)
6231 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6232 if (LA == RS && RA == LS)
6233 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6234 }
6235 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6236 if (Op->getType()->isPointerTy()) {
6239 return Op;
6240 }
6241 if (Signed)
6242 Op = getNoopOrSignExtend(Op, Ty);
6243 else
6244 Op = getNoopOrZeroExtend(Op, Ty);
6245 return Op;
6246 };
6247 LS = CoerceOperand(LS);
6248 RS = CoerceOperand(RS);
6250 break;
6251 const SCEV *LDiff = getMinusSCEV(LA, LS);
6252 const SCEV *RDiff = getMinusSCEV(RA, RS);
6253 if (LDiff == RDiff)
6254 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6255 LDiff);
6256 LDiff = getMinusSCEV(LA, RS);
6257 RDiff = getMinusSCEV(RA, LS);
6258 if (LDiff == RDiff)
6259 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6260 LDiff);
6261 }
6262 break;
6263 case ICmpInst::ICMP_NE:
6264 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6265 std::swap(TrueVal, FalseVal);
6266 [[fallthrough]];
6267 case ICmpInst::ICMP_EQ:
6268 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6271 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6272 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6273 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6274 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6275 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6276 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6277 return getAddExpr(getUMaxExpr(X, C), Y);
6278 }
6279 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6280 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6281 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6282 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6284 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6285 const SCEV *X = getSCEV(LHS);
6286 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6287 X = ZExt->getOperand();
6288 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6289 const SCEV *FalseValExpr = getSCEV(FalseVal);
6290 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6291 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6292 /*Sequential=*/true);
6293 }
6294 }
6295 break;
6296 default:
6297 break;
6298 }
6299
6300 return std::nullopt;
6301}
6302
6303static std::optional<const SCEV *>
6305 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6306 assert(CondExpr->getType()->isIntegerTy(1) &&
6307 TrueExpr->getType() == FalseExpr->getType() &&
6308 TrueExpr->getType()->isIntegerTy(1) &&
6309 "Unexpected operands of a select.");
6310
6311 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6312 // --> C + (umin_seq cond, x - C)
6313 //
6314 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6315 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6316 // --> C + (umin_seq ~cond, x - C)
6317
6318 // FIXME: while we can't legally model the case where both of the hands
6319 // are fully variable, we only require that the *difference* is constant.
6320 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6321 return std::nullopt;
6322
6323 const SCEV *X, *C;
6324 if (isa<SCEVConstant>(TrueExpr)) {
6325 CondExpr = SE->getNotSCEV(CondExpr);
6326 X = FalseExpr;
6327 C = TrueExpr;
6328 } else {
6329 X = TrueExpr;
6330 C = FalseExpr;
6331 }
6332 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6333 /*Sequential=*/true));
6334}
6335
6336static std::optional<const SCEV *>
6338 Value *FalseVal) {
6339 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6340 return std::nullopt;
6341
6342 const auto *SECond = SE->getSCEV(Cond);
6343 const auto *SETrue = SE->getSCEV(TrueVal);
6344 const auto *SEFalse = SE->getSCEV(FalseVal);
6345 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6346}
6347
6348const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6349 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6350 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6351 assert(TrueVal->getType() == FalseVal->getType() &&
6352 V->getType() == TrueVal->getType() &&
6353 "Types of select hands and of the result must match.");
6354
6355 // For now, only deal with i1-typed `select`s.
6356 if (!V->getType()->isIntegerTy(1))
6357 return getUnknown(V);
6358
6359 if (std::optional<const SCEV *> S =
6360 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6361 return *S;
6362
6363 return getUnknown(V);
6364}
6365
6366const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6367 Value *TrueVal,
6368 Value *FalseVal) {
6369 // Handle "constant" branch or select. This can occur for instance when a
6370 // loop pass transforms an inner loop and moves on to process the outer loop.
6371 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6372 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6373
6374 if (auto *I = dyn_cast<Instruction>(V)) {
6375 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6376 if (std::optional<const SCEV *> S =
6377 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6378 TrueVal, FalseVal))
6379 return *S;
6380 }
6381 }
6382
6383 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6384}
6385
6386/// Expand GEP instructions into add and multiply operations. This allows them
6387/// to be analyzed by regular SCEV code.
6388const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6389 assert(GEP->getSourceElementType()->isSized() &&
6390 "GEP source element type must be sized");
6391
6392 SmallVector<SCEVUse, 4> IndexExprs;
6393 for (Value *Index : GEP->indices())
6394 IndexExprs.push_back(getSCEV(Index));
6395 return getGEPExpr(GEP, IndexExprs);
6396}
6397
6398APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6399 const Instruction *CtxI) {
6400 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6401 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6402 return TrailingZeros >= BitWidth
6404 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6405 };
6406 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6407 // The result is GCD of all operands results.
6408 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6409 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6411 Res, getConstantMultiple(N->getOperand(I), CtxI));
6412 return Res;
6413 };
6414
6415 switch (S->getSCEVType()) {
6416 case scConstant:
6417 return cast<SCEVConstant>(S)->getAPInt();
6418 case scPtrToAddr:
6419 case scPtrToInt:
6420 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6421 case scUDivExpr:
6422 case scVScale:
6423 return APInt(BitWidth, 1);
6424 case scTruncate: {
6425 // Only multiples that are a power of 2 will hold after truncation.
6426 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6427 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6428 return GetShiftedByZeros(TZ);
6429 }
6430 case scZeroExtend: {
6431 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6432 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6433 }
6434 case scSignExtend: {
6435 // Only multiples that are a power of 2 will hold after sext.
6436 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6437 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6438 return GetShiftedByZeros(TZ);
6439 }
6440 case scMulExpr: {
6441 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6442 if (M->hasNoUnsignedWrap()) {
6443 // The result is the product of all operand results.
6444 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6445 for (const SCEV *Operand : M->operands().drop_front())
6446 Res = Res * getConstantMultiple(Operand, CtxI);
6447 return Res;
6448 }
6449
6450 // If there are no wrap guarentees, find the trailing zeros, which is the
6451 // sum of trailing zeros for all its operands.
6452 uint32_t TZ = 0;
6453 for (const SCEV *Operand : M->operands())
6454 TZ += getMinTrailingZeros(Operand, CtxI);
6455 return GetShiftedByZeros(TZ);
6456 }
6457 case scAddExpr:
6458 case scAddRecExpr: {
6459 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6460 if (N->hasNoUnsignedWrap())
6461 return GetGCDMultiple(N);
6462 // Find the trailing bits, which is the minimum of its operands.
6463 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6464 for (const SCEV *Operand : N->operands().drop_front())
6465 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6466 return GetShiftedByZeros(TZ);
6467 }
6468 case scUMaxExpr:
6469 case scSMaxExpr:
6470 case scUMinExpr:
6471 case scSMinExpr:
6473 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6474 case scUnknown: {
6475 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6476 // the point their underlying IR instruction has been defined. If CtxI was
6477 // not provided, use:
6478 // * the first instruction in the entry block if it is an argument
6479 // * the instruction itself otherwise.
6480 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6481 if (!CtxI) {
6482 if (isa<Argument>(U->getValue()))
6483 CtxI = &*F.getEntryBlock().begin();
6484 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6485 CtxI = I;
6486 }
6487 unsigned Known =
6488 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6489 .countMinTrailingZeros();
6490 return GetShiftedByZeros(Known);
6491 }
6492 case scCouldNotCompute:
6493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6494 }
6495 llvm_unreachable("Unknown SCEV kind!");
6496}
6497
6499 const Instruction *CtxI) {
6500 // Skip looking up and updating the cache if there is a context instruction,
6501 // as the result will only be valid in the specified context.
6502 if (CtxI)
6503 return getConstantMultipleImpl(S, CtxI);
6504
6505 auto I = ConstantMultipleCache.find(S);
6506 if (I != ConstantMultipleCache.end())
6507 return I->second;
6508
6509 APInt Result = getConstantMultipleImpl(S, CtxI);
6510 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6511 assert(InsertPair.second && "Should insert a new key");
6512 return InsertPair.first->second;
6513}
6514
6516 APInt Multiple = getConstantMultiple(S);
6517 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6518}
6519
6521 const Instruction *CtxI) {
6522 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6523 (unsigned)getTypeSizeInBits(S->getType()));
6524}
6525
6526/// Helper method to assign a range to V from metadata present in the IR.
6527static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6529 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6530 return getConstantRangeFromMetadata(*MD);
6531 if (const auto *CB = dyn_cast<CallBase>(V))
6532 if (std::optional<ConstantRange> Range = CB->getRange())
6533 return Range;
6534 }
6535 if (auto *A = dyn_cast<Argument>(V))
6536 if (std::optional<ConstantRange> Range = A->getRange())
6537 return Range;
6538
6539 return std::nullopt;
6540}
6541
6543 SCEV::NoWrapFlags Flags) {
6544 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6545 AddRec->setNoWrapFlags(Flags);
6546 UnsignedRanges.erase(AddRec);
6547 SignedRanges.erase(AddRec);
6548 ConstantMultipleCache.erase(AddRec);
6549 }
6550}
6551
/// Compute a conservative range for a SCEVUnknown that is really a shift
/// recurrence (ashr/lshr/shl of a phi by a possibly loop-varying step),
/// using the loop's constant max trip count to bound the total shift.
/// Returns the full set whenever the pattern or the bounds don't apply.
ConstantRange ScalarEvolution::
getRangeForUnknownRecurrence(const SCEVUnknown *U) {
  const DataLayout &DL = getDataLayout();

  unsigned BitWidth = getTypeSizeInBits(U->getType());
  const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);

  // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
  // use information about the trip count to improve our available range. Note
  // that the trip count independent cases are already handled by known bits.
  // WARNING: The definition of recurrence used here is subtly different than
  // the one used by AddRec (and thus most of this file). Step is allowed to
  // be arbitrarily loop varying here, where AddRec allows only loop invariant
  // and other addrecs in the same loop (for non-affine addrecs). The code
  // below intentionally handles the case where step is not loop invariant.
  auto *P = dyn_cast<PHINode>(U->getValue());
  if (!P)
    return FullSet;

  // Make sure that no Phi input comes from an unreachable block. Otherwise,
  // even the values that are not available in these blocks may come from them,
  // and this leads to false-positive recurrence test.
  for (auto *Pred : predecessors(P->getParent()))
    if (!DT.isReachableFromEntry(Pred))
      return FullSet;

  BinaryOperator *BO;
  Value *Start, *Step;
  if (!matchSimpleRecurrence(P, BO, Start, Step))
    return FullSet;

  // If we found a recurrence in reachable code, we must be in a loop. Note
  // that BO might be in some subloop of L, and that's completely okay.
  auto *L = LI.getLoopFor(P->getParent());
  assert(L && L->getHeader() == P->getParent());
  if (!L->contains(BO->getParent()))
    // NOTE: This bailout should be an assert instead. However, asserting
    // the condition here exposes a case where LoopFusion is querying SCEV
    // with malformed loop information during the midst of the transform.
    // There doesn't appear to be an obvious fix, so for the moment bailout
    // until the caller issue can be fixed. PR49566 tracks the bug.
    return FullSet;

  // TODO: Extend to other opcodes such as mul, and div
  switch (BO->getOpcode()) {
  default:
    return FullSet;
  case Instruction::AShr:
  case Instruction::LShr:
  case Instruction::Shl:
    break;
  };

  if (BO->getOperand(0) != P)
    // TODO: Handle the power function forms some day.
    return FullSet;

  // A trip count of 0 (unknown) or >= BitWidth gives us nothing: the total
  // shift below could saturate or is simply unbounded.
  unsigned TC = getSmallConstantMaxTripCount(L);
  if (!TC || TC >= BitWidth)
    return FullSet;

  auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
  auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
  assert(KnownStart.getBitWidth() == BitWidth &&
         KnownStep.getBitWidth() == BitWidth);

  // Compute total shift amount, being careful of overflow and bitwidths.
  // The recurrence shifts at most TC-1 times (once per latch execution).
  auto MaxShiftAmt = KnownStep.getMaxValue();
  APInt TCAP(BitWidth, TC-1);
  bool Overflow = false;
  auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
  if (Overflow)
    return FullSet;

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("filtered out above");
  case Instruction::AShr: {
    // For each ashr, three cases:
    //   shift = 0 => unchanged value
    //   saturation => 0 or -1
    //   other => a value closer to zero (of the same sign)
    // Thus, the end value is closer to zero than the start.
    auto KnownEnd = KnownBits::ashr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    if (KnownStart.isNonNegative())
      // Analogous to lshr (simply not yet canonicalized)
      return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                        KnownStart.getMaxValue() + 1);
    if (KnownStart.isNegative())
      // End >=u Start && End <=s Start
      return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
                                        KnownEnd.getMaxValue() + 1);
    // Sign of Start unknown: no simple bound on the end value.
    break;
  }
  case Instruction::LShr: {
    // For each lshr, three cases:
    //   shift = 0 => unchanged value
    //   saturation => 0
    //   other => a smaller positive number
    // Thus, the low end of the unsigned range is the last value produced.
    auto KnownEnd = KnownBits::lshr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                      KnownStart.getMaxValue() + 1);
  }
  case Instruction::Shl: {
    // Iff no bits are shifted out, value increases on every shift.
    auto KnownEnd = KnownBits::shl(KnownStart,
                                   KnownBits::makeConstant(TotalShift));
    if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
      return ConstantRange(KnownStart.getMinValue(),
                           KnownEnd.getMaxValue() + 1);
    break;
  }
  };
  return FullSet;
}
6670
// The goal of this function is to check if recursively visiting the operands
// of this PHI might lead to an infinite loop. If we do see such a loop,
// there's no good way to break it, so we avoid analyzing such cases.
//
// getRangeRef previously used a visited set to avoid infinite loops, but this
// caused other issues: the result was dependent on the order of getRangeRef
// calls, and the interaction with createSCEVIter could cause a stack overflow
// in some cases (see issue #148253).
//
// FIXME: The way this is implemented is overly conservative; this checks
// for a few obviously safe patterns, but anything that doesn't lead to
// recursion is fine.
// NOTE(review): the extracted source is missing two lines here (original
// lines 6683 and 6685): the function signature — presumably
// `static bool RangeRefPHIAllowedOperands(const DominatorTree &DT, PHINode *PHI)`,
// judging by the call sites in getRangeRefIter/getRangeRef — and the guard
// condition that consumes Cond/LHS/RHS below. Restore both from upstream
// before building; the declared-but-unused locals are otherwise dead.
Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  // Matched a known-safe PHI shape; recursion over its operands is allowed.
  return true;

  // A PHI whose every operand dominates the PHI itself cannot (transitively)
  // reach the PHI again, so visiting its operands cannot recurse.
  if (all_of(PHI->operands(),
             [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
    return true;

  return false;
}
6694
6695const ConstantRange &
6696ScalarEvolution::getRangeRefIter(const SCEV *S,
6697 ScalarEvolution::RangeSignHint SignHint) {
6698 DenseMap<const SCEV *, ConstantRange> &Cache =
6699 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6700 : SignedRanges;
6701 SmallVector<SCEVUse> WorkList;
6702 SmallPtrSet<const SCEV *, 8> Seen;
6703
6704 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6705 // SCEVUnknown PHI node.
6706 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6707 if (!Seen.insert(Expr).second)
6708 return;
6709 if (Cache.contains(Expr))
6710 return;
6711 switch (Expr->getSCEVType()) {
6712 case scUnknown:
6713 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6714 break;
6715 [[fallthrough]];
6716 case scConstant:
6717 case scVScale:
6718 case scTruncate:
6719 case scZeroExtend:
6720 case scSignExtend:
6721 case scPtrToAddr:
6722 case scPtrToInt:
6723 case scAddExpr:
6724 case scMulExpr:
6725 case scUDivExpr:
6726 case scAddRecExpr:
6727 case scUMaxExpr:
6728 case scSMaxExpr:
6729 case scUMinExpr:
6730 case scSMinExpr:
6732 WorkList.push_back(Expr);
6733 break;
6734 case scCouldNotCompute:
6735 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6736 }
6737 };
6738 AddToWorklist(S);
6739
6740 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6741 for (unsigned I = 0; I != WorkList.size(); ++I) {
6742 const SCEV *P = WorkList[I];
6743 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6744 // If it is not a `SCEVUnknown`, just recurse into operands.
6745 if (!UnknownS) {
6746 for (const SCEV *Op : P->operands())
6747 AddToWorklist(Op);
6748 continue;
6749 }
6750 // `SCEVUnknown`'s require special treatment.
6751 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6752 if (!RangeRefPHIAllowedOperands(DT, P))
6753 continue;
6754 for (auto &Op : reverse(P->operands()))
6755 AddToWorklist(getSCEV(Op));
6756 }
6757 }
6758
6759 if (!WorkList.empty()) {
6760 // Use getRangeRef to compute ranges for items in the worklist in reverse
6761 // order. This will force ranges for earlier operands to be computed before
6762 // their users in most cases.
6763 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6764 getRangeRef(P, SignHint);
6765 }
6766 }
6767
6768 return getRangeRef(S, SignHint, 0);
6769}
6770
6771/// Determine the range for a particular SCEV. If SignHint is
6772/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6773/// with a "cleaner" unsigned (resp. signed) representation.
6774const ConstantRange &ScalarEvolution::getRangeRef(
6775 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6776 DenseMap<const SCEV *, ConstantRange> &Cache =
6777 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6778 : SignedRanges;
6780 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6782
6783 // See if we've computed this range already.
6785 if (I != Cache.end())
6786 return I->second;
6787
6788 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6789 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6790
6791 // Switch to iteratively computing the range for S, if it is part of a deeply
6792 // nested expression.
6794 return getRangeRefIter(S, SignHint);
6795
6796 unsigned BitWidth = getTypeSizeInBits(S->getType());
6797 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6798 using OBO = OverflowingBinaryOperator;
6799
6800 // If the value has known zeros, the maximum value will have those known zeros
6801 // as well.
6802 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6803 APInt Multiple = getNonZeroConstantMultiple(S);
6804 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6805 if (!Remainder.isZero())
6806 ConservativeResult =
6807 ConstantRange(APInt::getMinValue(BitWidth),
6808 APInt::getMaxValue(BitWidth) - Remainder + 1);
6809 }
6810 else {
6811 uint32_t TZ = getMinTrailingZeros(S);
6812 if (TZ != 0) {
6813 ConservativeResult = ConstantRange(
6815 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6816 }
6817 }
6818
6819 switch (S->getSCEVType()) {
6820 case scConstant:
6821 llvm_unreachable("Already handled above.");
6822 case scVScale:
6823 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6824 case scTruncate: {
6825 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6826 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6827 return setRange(
6828 Trunc, SignHint,
6829 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6830 }
6831 case scZeroExtend: {
6832 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6833 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6834 return setRange(
6835 ZExt, SignHint,
6836 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6837 }
6838 case scSignExtend: {
6839 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6840 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6841 return setRange(
6842 SExt, SignHint,
6843 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6844 }
6845 case scPtrToAddr:
6846 case scPtrToInt: {
6847 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6848 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6849 return setRange(Cast, SignHint, X);
6850 }
6851 case scAddExpr: {
6852 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6853 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6854 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6855 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6856 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6857 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6858 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6859 ConservativeResult =
6860 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6861 }
6862 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6863 unsigned WrapType = OBO::AnyWrap;
6864 if (Add->hasNoSignedWrap())
6865 WrapType |= OBO::NoSignedWrap;
6866 if (Add->hasNoUnsignedWrap())
6867 WrapType |= OBO::NoUnsignedWrap;
6868 for (const SCEV *Op : drop_begin(Add->operands()))
6869 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6870 RangeType);
6871 return setRange(Add, SignHint,
6872 ConservativeResult.intersectWith(X, RangeType));
6873 }
6874 case scMulExpr: {
6875 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6876 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6877 for (const SCEV *Op : drop_begin(Mul->operands()))
6878 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6879 return setRange(Mul, SignHint,
6880 ConservativeResult.intersectWith(X, RangeType));
6881 }
6882 case scUDivExpr: {
6883 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6884 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6885 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6886 return setRange(UDiv, SignHint,
6887 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6888 }
6889 case scAddRecExpr: {
6890 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6891 // If there's no unsigned wrap, the value will never be less than its
6892 // initial value.
6893 if (AddRec->hasNoUnsignedWrap()) {
6894 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6895 if (!UnsignedMinValue.isZero())
6896 ConservativeResult = ConservativeResult.intersectWith(
6897 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6898 }
6899
6900 // If there's no signed wrap, and all the operands except initial value have
6901 // the same sign or zero, the value won't ever be:
6902 // 1: smaller than initial value if operands are non negative,
6903 // 2: bigger than initial value if operands are non positive.
6904 // For both cases, value can not cross signed min/max boundary.
6905 if (AddRec->hasNoSignedWrap()) {
6906 bool AllNonNeg = true;
6907 bool AllNonPos = true;
6908 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6909 if (!isKnownNonNegative(AddRec->getOperand(i)))
6910 AllNonNeg = false;
6911 if (!isKnownNonPositive(AddRec->getOperand(i)))
6912 AllNonPos = false;
6913 }
6914 if (AllNonNeg)
6915 ConservativeResult = ConservativeResult.intersectWith(
6918 RangeType);
6919 else if (AllNonPos)
6920 ConservativeResult = ConservativeResult.intersectWith(
6922 getSignedRangeMax(AddRec->getStart()) +
6923 1),
6924 RangeType);
6925 }
6926
6927 // TODO: non-affine addrec
6928 if (AddRec->isAffine()) {
6929 const SCEV *MaxBEScev =
6931 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6932 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6933
6934 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6935 // MaxBECount's active bits are all <= AddRec's bit width.
6936 if (MaxBECount.getBitWidth() > BitWidth &&
6937 MaxBECount.getActiveBits() <= BitWidth)
6938 MaxBECount = MaxBECount.trunc(BitWidth);
6939 else if (MaxBECount.getBitWidth() < BitWidth)
6940 MaxBECount = MaxBECount.zext(BitWidth);
6941
6942 if (MaxBECount.getBitWidth() == BitWidth) {
6943 auto RangeFromAffine = getRangeForAffineAR(
6944 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6945 ConservativeResult =
6946 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6947
6948 auto RangeFromFactoring = getRangeViaFactoring(
6949 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6950 ConservativeResult =
6951 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6952 }
6953 }
6954
6955 // Now try symbolic BE count and more powerful methods.
6957 const SCEV *SymbolicMaxBECount =
6959 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6960 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6961 AddRec->hasNoSelfWrap()) {
6962 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6963 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6964 ConservativeResult =
6965 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6966 }
6967 }
6968 }
6969
6970 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6971 }
6972 case scUMaxExpr:
6973 case scSMaxExpr:
6974 case scUMinExpr:
6975 case scSMinExpr:
6976 case scSequentialUMinExpr: {
6978 switch (S->getSCEVType()) {
6979 case scUMaxExpr:
6980 ID = Intrinsic::umax;
6981 break;
6982 case scSMaxExpr:
6983 ID = Intrinsic::smax;
6984 break;
6985 case scUMinExpr:
6987 ID = Intrinsic::umin;
6988 break;
6989 case scSMinExpr:
6990 ID = Intrinsic::smin;
6991 break;
6992 default:
6993 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6994 }
6995
6996 const auto *NAry = cast<SCEVNAryExpr>(S);
6997 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6998 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6999 X = X.intrinsic(
7000 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7001 return setRange(S, SignHint,
7002 ConservativeResult.intersectWith(X, RangeType));
7003 }
7004 case scUnknown: {
7005 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7006 Value *V = U->getValue();
7007
7008 // Check if the IR explicitly contains !range metadata.
7009 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7010 if (MDRange)
7011 ConservativeResult =
7012 ConservativeResult.intersectWith(*MDRange, RangeType);
7013
7014 // Use facts about recurrences in the underlying IR. Note that add
7015 // recurrences are AddRecExprs and thus don't hit this path. This
7016 // primarily handles shift recurrences.
7017 auto CR = getRangeForUnknownRecurrence(U);
7018 ConservativeResult = ConservativeResult.intersectWith(CR);
7019
7020 // See if ValueTracking can give us a useful range.
7021 const DataLayout &DL = getDataLayout();
7022 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7023 if (Known.getBitWidth() != BitWidth)
7024 Known = Known.zextOrTrunc(BitWidth);
7025
7026 // ValueTracking may be able to compute a tighter result for the number of
7027 // sign bits than for the value of those sign bits.
7028 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7029 if (U->getType()->isPointerTy()) {
7030 // If the pointer size is larger than the index size type, this can cause
7031 // NS to be larger than BitWidth. So compensate for this.
7032 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7033 int ptrIdxDiff = ptrSize - BitWidth;
7034 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7035 NS -= ptrIdxDiff;
7036 }
7037
7038 if (NS > 1) {
7039 // If we know any of the sign bits, we know all of the sign bits.
7040 if (!Known.Zero.getHiBits(NS).isZero())
7041 Known.Zero.setHighBits(NS);
7042 if (!Known.One.getHiBits(NS).isZero())
7043 Known.One.setHighBits(NS);
7044 }
7045
7046 if (Known.getMinValue() != Known.getMaxValue() + 1)
7047 ConservativeResult = ConservativeResult.intersectWith(
7048 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7049 RangeType);
7050 if (NS > 1)
7051 ConservativeResult = ConservativeResult.intersectWith(
7052 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7053 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7054 RangeType);
7055
7056 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7057 // Strengthen the range if the underlying IR value is a
7058 // global/alloca/heap allocation using the size of the object.
7059 bool CanBeNull, CanBeFreed;
7060 uint64_t DerefBytes =
7061 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7062 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7063 // The highest address the object can start is DerefBytes bytes before
7064 // the end (unsigned max value). If this value is not a multiple of the
7065 // alignment, the last possible start value is the next lowest multiple
7066 // of the alignment. Note: The computations below cannot overflow,
7067 // because if they would there's no possible start address for the
7068 // object.
7069 APInt MaxVal =
7070 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7071 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7072 uint64_t Rem = MaxVal.urem(Align);
7073 MaxVal -= APInt(BitWidth, Rem);
7074 APInt MinVal = APInt::getZero(BitWidth);
7075 if (llvm::isKnownNonZero(V, DL))
7076 MinVal = Align;
7077 ConservativeResult = ConservativeResult.intersectWith(
7078 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7079 }
7080 }
7081
7082 // A range of Phi is a subset of union of all ranges of its input.
7083 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7084 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7085 // AddRecs; return the range for the corresponding AddRec.
7086 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7087 return getRangeRef(AR, SignHint, Depth + 1);
7088
7089 // Make sure that we do not run over cycled Phis.
7090 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7091 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7092
7093 for (const auto &Op : Phi->operands()) {
7094 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7095 RangeFromOps = RangeFromOps.unionWith(OpRange);
7096 // No point to continue if we already have a full set.
7097 if (RangeFromOps.isFullSet())
7098 break;
7099 }
7100 ConservativeResult =
7101 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7102 }
7103 }
7104
7105 // vscale can't be equal to zero
7106 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7107 if (II->getIntrinsicID() == Intrinsic::vscale) {
7108 ConstantRange Disallowed = APInt::getZero(BitWidth);
7109 ConservativeResult = ConservativeResult.difference(Disallowed);
7110 }
7111
7112 return setRange(U, SignHint, std::move(ConservativeResult));
7113 }
7114 case scCouldNotCompute:
7115 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7116 }
7117
7118 return setRange(S, SignHint, std::move(ConservativeResult));
7119}
7120
// Given a StartRange, Step and MaxBECount for an expression compute a range of
// values that the expression can take. Initially, the expression has a value
// from StartRange and then is changed by Step up to MaxBECount times. Signed
// argument defines if we treat Step as signed or unsigned.
                                               const ConstantRange &StartRange,
                                               const APInt &MaxBECount,
                                               bool Signed) {
  unsigned BitWidth = Step.getBitWidth();
  assert(BitWidth == StartRange.getBitWidth() &&
         BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
  // If either Step or MaxBECount is 0, then the expression won't change, and we
  // just need to return the initial range.
  if (Step == 0 || MaxBECount == 0)
    return StartRange;

  // If we don't know anything about the initial value (i.e. StartRange is
  // FullRange), then we don't know anything about the final range either.
  // Return FullRange.
  if (StartRange.isFullSet())
    return ConstantRange::getFull(BitWidth);

  // If Step is signed and negative, then we use its absolute value, but we also
  // note that we're moving in the opposite direction.
  bool Descending = Signed && Step.isNegative();

  if (Signed)
    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
    // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
    // This equations hold true due to the well-defined wrap-around behavior of
    // APInt.
    Step = Step.abs();

  // Check if Offset is more than full span of BitWidth. If it is, the
  // expression is guaranteed to overflow.
  // (MaxValue.udiv(Step) < MaxBECount) <=> (Step * MaxBECount > MaxValue),
  // i.e. the total movement wraps the full bit width at least once.
  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
    return ConstantRange::getFull(BitWidth);

  // Offset is by how much the expression can change. Checks above guarantee no
  // overflow here.
  APInt Offset = Step * MaxBECount;

  // Minimum value of the final range will match the minimal value of StartRange
  // if the expression is increasing and will be decreased by Offset otherwise.
  // Maximum value of the final range will match the maximal value of StartRange
  // if the expression is decreasing and will be increased by Offset otherwise.
  APInt StartLower = StartRange.getLower();
  APInt StartUpper = StartRange.getUpper() - 1;
  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
                                   : (StartUpper + std::move(Offset));

  // It's possible that the new minimum/maximum value will fall into the initial
  // range (due to wrap around). This means that the expression can take any
  // value in this bitwidth, and we have to return full range.
  if (StartRange.contains(MovedBoundary))
    return ConstantRange::getFull(BitWidth);

  APInt NewLower =
      Descending ? std::move(MovedBoundary) : std::move(StartLower);
  APInt NewUpper =
      Descending ? std::move(StartUpper) : std::move(MovedBoundary);
  NewUpper += 1;

  // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
  return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
}
7187
/// Compute a conservative range for {Start,+,Step} after at most MaxBECount
/// iterations, analyzing Step under both its signed and unsigned
/// interpretations and combining the two results.
ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
                                                   const SCEV *Step,
                                                   const APInt &MaxBECount) {
  assert(getTypeSizeInBits(Start->getType()) ==
             getTypeSizeInBits(Step->getType()) &&
         getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
         "mismatched bit widths");

  // First, consider step signed.
  ConstantRange StartSRange = getSignedRange(Start);
  ConstantRange StepSRange = getSignedRange(Step);

  // If Step can be both positive and negative, we need to find ranges for the
  // maximum absolute step values in both directions and union them.
  ConstantRange SR = getRangeForAffineARHelper(
      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
                                              StartSRange, MaxBECount,
                                              /* Signed = */ true));

  // Next, consider step unsigned.
  ConstantRange UR = getRangeForAffineARHelper(
      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
      /* Signed = */ false);

  // Finally, intersect signed and unsigned ranges.
}
7216
/// Compute a range for an affine AddRec with the no-self-wrap flag by bounding
/// its values between Start and End = AddRec(MaxBECount), when the step
/// direction and a Start-vs-End comparison prove the iteration stays between
/// the two.
ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
    const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
    ScalarEvolution::RangeSignHint SignHint) {
  assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
  assert(AddRec->hasNoSelfWrap() &&
         "This only works for non-self-wrapping AddRecs!");
  const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
  const SCEV *Step = AddRec->getStepRecurrence(*this);
  // Only deal with constant step to save compile time.
  if (!isa<SCEVConstant>(Step))
    return ConstantRange::getFull(BitWidth);
  // Let's make sure that we can prove that we do not self-wrap during
  // MaxBECount iterations. We need this because MaxBECount is a maximum
  // iteration count estimate, and we might infer nw from some exit for which we
  // do not know max exit count (or any other side reasoning).
  // TODO: Turn into assert at some point.
  if (getTypeSizeInBits(MaxBECount->getType()) >
      getTypeSizeInBits(AddRec->getType()))
    return ConstantRange::getFull(BitWidth);
  MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
  // RangeWidth (-1, i.e. all-ones) is the number of distinct values of the
  // type minus one; MaxBECount must not exceed RangeWidth / |Step| or the
  // recurrence could traverse the whole space and self-wrap.
  const SCEV *RangeWidth = getMinusOne(AddRec->getType());
  const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
  const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
  if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
                                         MaxItersWithoutWrap))
    return ConstantRange::getFull(BitWidth);

  ICmpInst::Predicate LEPred =
  ICmpInst::Predicate GEPred =
  const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);

  // We know that there is no self-wrap. Let's take Start and End values and
  // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
  // the iteration. They either lie inside the range [Min(Start, End),
  // Max(Start, End)] or outside it:
  //
  // Case 1:   RangeMin    ...    Start V1 ... VN End ...           RangeMax;
  // Case 2:   RangeMin Vk ... V1 Start    ...    End Vn ... Vk + 1 RangeMax;
  //
  // No self wrap flag guarantees that the intermediate values cannot be BOTH
  // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
  // knowledge, let's try to prove that we are dealing with Case 1. It is so if
  // Start <= End and step is positive, or Start >= End and step is negative.
  const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
  ConstantRange StartRange = getRangeRef(Start, SignHint);
  ConstantRange EndRange = getRangeRef(End, SignHint);
  ConstantRange RangeBetween = StartRange.unionWith(EndRange);
  // If they already cover full iteration space, we will know nothing useful
  // even if we prove what we want to prove.
  if (RangeBetween.isFullSet())
    return RangeBetween;
  // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
  bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
                               : RangeBetween.isWrappedSet();
  if (IsWrappedSet)
    return ConstantRange::getFull(BitWidth);

  if (isKnownPositive(Step) &&
      isKnownPredicateViaConstantRanges(LEPred, Start, End))
    return RangeBetween;
  if (isKnownNegative(Step) &&
      isKnownPredicateViaConstantRanges(GEPred, Start, End))
    return RangeBetween;
  return ConstantRange::getFull(BitWidth);
}
7284
ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
                                                    const SCEV *Step,
                                                    const APInt &MaxBECount) {
  // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
  //                          == RangeOf({A,+,P}) union RangeOf({B,+,Q})

  unsigned BitWidth = MaxBECount.getBitWidth();
  assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
         getTypeSizeInBits(Step->getType()) == BitWidth &&
         "mismatched bit widths");

  // Recognizes a SCEV of the shape select(Cond, TrueVal, FalseVal), possibly
  // behind an integral cast and/or with a peeled constant offset, and records
  // the condition plus the constant value each arm evaluates to.
  struct SelectPattern {
    Value *Condition = nullptr;
    APInt TrueValue;
    APInt FalseValue;

    explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
                           const SCEV *S) {
      std::optional<unsigned> CastOp;
      APInt Offset(BitWidth, 0);

             "Should be!");

      // Peel off a constant offset. In the future we could consider being
      // smarter here and handle {Start+Step,+,Step} too.
      const APInt *Off;
      if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
        Offset = *Off;

      // Peel off a cast operation
      if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
        CastOp = SCast->getSCEVType();
        S = SCast->getOperand();
      }

      using namespace llvm::PatternMatch;

      auto *SU = dyn_cast<SCEVUnknown>(S);
      const APInt *TrueVal, *FalseVal;
      if (!SU ||
          !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
                                          m_APInt(FalseVal)))) {
        // Not a select of two constants; leave Condition null so
        // isRecognized() reports failure.
        Condition = nullptr;
        return;
      }

      TrueValue = *TrueVal;
      FalseValue = *FalseVal;

      // Re-apply the cast we peeled off earlier
      if (CastOp)
        switch (*CastOp) {
        default:
          llvm_unreachable("Unknown SCEV cast type!");

        case scTruncate:
          TrueValue = TrueValue.trunc(BitWidth);
          FalseValue = FalseValue.trunc(BitWidth);
          break;
        case scZeroExtend:
          TrueValue = TrueValue.zext(BitWidth);
          FalseValue = FalseValue.zext(BitWidth);
          break;
        case scSignExtend:
          TrueValue = TrueValue.sext(BitWidth);
          FalseValue = FalseValue.sext(BitWidth);
          break;
        }

      // Re-apply the constant offset we peeled off earlier
      TrueValue += Offset;
      FalseValue += Offset;
    }

    bool isRecognized() { return Condition != nullptr; }
  };

  SelectPattern StartPattern(*this, BitWidth, Start);
  if (!StartPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  SelectPattern StepPattern(*this, BitWidth, Step);
  if (!StepPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  if (StartPattern.Condition != StepPattern.Condition) {
    // We don't handle this case today; but we could, by considering four
    // possibilities below instead of two. I'm not sure if there are cases where
    // that will help over what getRange already does, though.
    return ConstantRange::getFull(BitWidth);
  }

  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
  // construct arbitrary general SCEV expressions here. This function is called
  // from deep in the call stack, and calling getSCEV (on a sext instruction,
  // say) can end up caching a suboptimal value.

  // FIXME: without the explicit `this` receiver below, MSVC errors out with
  // C2352 and C2512 (otherwise it isn't needed).

  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);

  // Range of the factored expression is the union of the ranges of the two
  // affine recurrences, one per select arm.
  ConstantRange TrueRange =
      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
  ConstantRange FalseRange =
      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);

  return TrueRange.unionWith(FalseRange);
}
7398
/// Translate nuw/nsw flags on the IR binary operator \p V into SCEV no-wrap
/// flags. The flags are only kept when isSCEVExprNeverPoison proves it is safe
/// to attach them to the (potentially shared) SCEV expression.
SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
  // Constant expressions have no instruction context to take flags from.
  if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
  const BinaryOperator *BinOp = cast<BinaryOperator>(V);

  // Return early if there are no flags to propagate to the SCEV.
  if (BinOp->hasNoUnsignedWrap())
  if (BinOp->hasNoSignedWrap())
  if (Flags == SCEV::FlagAnyWrap)
    return SCEV::FlagAnyWrap;

  return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
}
7414
7415const Instruction *
7416ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7417 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7418 return &*AddRec->getLoop()->getHeader()->begin();
7419 if (auto *U = dyn_cast<SCEVUnknown>(S))
7420 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7421 return I;
7422 return nullptr;
7423}
7424
7425const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7426 bool &Precise) {
7427 Precise = true;
7428 // Do a bounded search of the def relation of the requested SCEVs.
7429 SmallPtrSet<const SCEV *, 16> Visited;
7430 SmallVector<SCEVUse> Worklist;
7431 auto pushOp = [&](const SCEV *S) {
7432 if (!Visited.insert(S).second)
7433 return;
7434 // Threshold of 30 here is arbitrary.
7435 if (Visited.size() > 30) {
7436 Precise = false;
7437 return;
7438 }
7439 Worklist.push_back(S);
7440 };
7441
7442 for (SCEVUse S : Ops)
7443 pushOp(S);
7444
7445 const Instruction *Bound = nullptr;
7446 while (!Worklist.empty()) {
7447 SCEVUse S = Worklist.pop_back_val();
7448 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7449 if (!Bound || DT.dominates(Bound, DefI))
7450 Bound = DefI;
7451 } else {
7452 for (SCEVUse Op : S->operands())
7453 pushOp(Op);
7454 }
7455 }
7456 return Bound ? Bound : &*F.getEntryBlock().begin();
7457}
7458
7459const Instruction *
7460ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7461 bool Discard;
7462 return getDefiningScopeBound(Ops, Discard);
7463}
7464
/// Return true if executing \p A guarantees that \p B is subsequently
/// executed.
bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
                                                        const Instruction *B) {
  // Case 1: A and B are in the same basic block.
  if (A->getParent() == B->getParent() &&
                                                 B->getIterator()))
    return true;

  // Case 2: A sits in the preheader of the loop whose header contains B, so
  // execution falls from A's block straight into B's block.
  auto *BLoop = LI.getLoopFor(B->getParent());
  if (BLoop && BLoop->getHeader() == B->getParent() &&
      BLoop->getLoopPreheader() == A->getParent() &&
                                                 A->getParent()->end()) &&
      isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
                                                 B->getIterator()))
    return true;
  return false;
}
7482
7483bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7484 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7485 visitAll(Op, PC);
7486 return PC.MaybePoison.empty();
7487}
7488
7489bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7490 return !SCEVExprContains(Op, [this](const SCEV *S) {
7491 const SCEV *Op1;
7492 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7493 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7494 // is a non-zero constant, we have to assume the UDiv may be UB.
7495 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7496 });
7497}
7498
bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
  // Only proceed if we can prove that I does not yield poison.
    return false;

  // At this point we know that if I is executed, then it does not wrap
  // according to at least one of NSW or NUW. If I is not executed, then we do
  // not know if the calculation that I represents would wrap. Multiple
  // instructions can map to the same SCEV. If we apply NSW or NUW from I to
  // the SCEV, we must guarantee no wrapping for that SCEV also when it is
  // derived from other instructions that map to the same SCEV. We cannot make
  // that guarantee for cases where I is not executed. So we need to find a
  // upper bound on the defining scope for the SCEV, and prove that I is
  // executed every time we enter that scope. When the bounding scope is a
  // loop (the common case), this is equivalent to proving I executes on every
  // iteration of that loop.
  // Gather SCEVs for every SCEVable operand; their defining scope bounds
  // where I's execution must be guaranteed.
  SmallVector<SCEVUse> SCEVOps;
  for (const Use &Op : I->operands()) {
    // I could be an extractvalue from a call to an overflow intrinsic.
    // TODO: We can do better here in some cases.
    if (isSCEVable(Op->getType()))
      SCEVOps.push_back(getSCEV(Op));
  }
  auto *DefI = getDefiningScopeBound(SCEVOps);
  return isGuaranteedToTransferExecutionTo(DefI, I);
}
7525
bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
  // If we know that \c I can never be poison period, then that's enough.
  if (isSCEVExprNeverPoison(I))
    return true;

  // If the loop only has one exit, then we know that, if the loop is entered,
  // any instruction dominating that exit will be executed. If any such
  // instruction would result in UB, the addrec cannot be poison.
  //
  // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
  // also handles uses outside the loop header (they just need to dominate the
  // single exit).

  auto *ExitingBB = L->getExitingBlock();
  if (!ExitingBB || !loopHasNoAbnormalExits(L))
    return false;

  SmallPtrSet<const Value *, 16> KnownPoison;

  // We start by assuming \c I, the post-inc add recurrence, is poison. Only
  // things that are known to be poison under that assumption go on the
  // Worklist.
  KnownPoison.insert(I);
  Worklist.push_back(I);

  // Propagate the assumed poison through the use chains; if some user would
  // necessarily trigger UB and dominates the sole exit, a poison addrec would
  // make every trip through the loop UB, so it cannot be poison.
  while (!Worklist.empty()) {
    const Instruction *Poison = Worklist.pop_back_val();

    for (const Use &U : Poison->uses()) {
      const Instruction *PoisonUser = cast<Instruction>(U.getUser());
      if (mustTriggerUB(PoisonUser, KnownPoison) &&
          DT.dominates(PoisonUser->getParent(), ExitingBB))
        return true;

      if (propagatesPoison(U) && L->contains(PoisonUser))
        if (KnownPoison.insert(PoisonUser).second)
          Worklist.push_back(PoisonUser);
    }
  }

  return false;
}
7569
/// Compute, and cache in LoopPropertiesCache, whether \p L is free of
/// abnormal exits and of side effects.
ScalarEvolution::LoopProperties
ScalarEvolution::getLoopProperties(const Loop *L) {
  using LoopProperties = ScalarEvolution::LoopProperties;

  auto Itr = LoopPropertiesCache.find(L);
  if (Itr == LoopPropertiesCache.end()) {
    auto HasSideEffects = [](Instruction *I) {
      // Non-simple (volatile or atomic) stores count as side effects.
      if (auto *SI = dyn_cast<StoreInst>(I))
        return !SI->isSimple();

      if (I->mayThrow())
        return true;

      // Non-volatile memset / memcpy do not count as side-effect for forward
      // progress.
      if (isa<MemIntrinsic>(I) && !I->isVolatile())
        return false;

      return I->mayWriteToMemory();
    };

    LoopProperties LP = {/* HasNoAbnormalExits */ true,
                         /*HasNoSideEffects*/ true};

    for (auto *BB : L->getBlocks())
      for (auto &I : *BB) {
          LP.HasNoAbnormalExits = false;
        if (HasSideEffects(&I))
          LP.HasNoSideEffects = false;
        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
          break; // We're already as pessimistic as we can get.
      }

    auto InsertPair = LoopPropertiesCache.insert({L, LP});
    assert(InsertPair.second && "We just checked!");
    Itr = InsertPair.first;
  }

  return Itr->second;
}
7611
  // A mustprogress loop without side effects must be finite.
  // A mustprogress loop is required to make forward progress; if it also has
  // no (observable) side effects, the only legal behavior is to terminate.
  // TODO: The check used here is very conservative. It's only *specific*
  // side effects which are well defined in infinite loops.
  return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
}
7618
/// Iterative, worklist-based construction of the SCEV for \p V. Each value is
/// pushed twice: once (bool = false) to collect the operands whose SCEVs must
/// exist first, and once (bool = true) to build its own SCEV afterwards. This
/// avoids deep recursion in createSCEV.
const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
  // Worklist item with a Value and a bool indicating whether all operands have
  // been visited already.

  Stack.emplace_back(V, true);
  Stack.emplace_back(V, false);
  while (!Stack.empty()) {
    auto E = Stack.pop_back_val();
    Value *CurV = E.getPointer();

    // Already constructed (possibly by an earlier worklist entry); skip.
    if (getExistingSCEV(CurV))
      continue;

    const SCEV *CreatedSCEV = nullptr;
    // If all operands have been visited already, create the SCEV.
    if (E.getInt()) {
      CreatedSCEV = createSCEV(CurV);
    } else {
      // Otherwise get the operands we need to create SCEV's for before creating
      // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
      // just use it.
      CreatedSCEV = getOperandsToCreate(CurV, Ops);
    }

    if (CreatedSCEV) {
      insertValueToMap(CurV, CreatedSCEV);
    } else {
      // Queue CurV for SCEV creation, followed by its operands which need to
      // be constructed first.
      Stack.emplace_back(CurV, true);
      for (Value *Op : Ops)
        Stack.emplace_back(Op, false);
    }
  }

  return getExistingSCEV(V);
}
7659
/// Collect in \p Ops the operand Values whose SCEVs must be constructed
/// before a SCEV for \p V can be created, or return a SCEV directly when \p V
/// is trivial (constant, unanalyzable, etc.). A nullptr return tells the
/// caller (createSCEVIter) to build the operands first and retry.
const SCEV *
ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
  if (!isSCEVable(V->getType()))
    return getUnknown(V);

  if (Instruction *I = dyn_cast<Instruction>(V)) {
    // Don't attempt to analyze instructions in blocks that aren't
    // reachable. Such instructions don't matter, and they aren't required
    // to obey basic rules for definitions dominating uses which this
    // analysis depends on.
    if (!DT.isReachableFromEntry(I->getParent()))
      return getUnknown(PoisonValue::get(V->getType()));
  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
    return getConstant(CI);
  else if (isa<GlobalAlias>(V))
    return getUnknown(V);
  else if (!isa<ConstantExpr>(V))
    return getUnknown(V);

  if (auto BO =
    bool IsConstArg = isa<ConstantInt>(BO->RHS);
    switch (BO->Opcode) {
    case Instruction::Add:
    case Instruction::Mul: {
      // For additions and multiplications, traverse add/mul chains for which we
      // can potentially create a single SCEV, to reduce the number of
      // get{Add,Mul}Expr calls.
      do {
        if (BO->Op) {
          if (BO->Op != V && getExistingSCEV(BO->Op)) {
            Ops.push_back(BO->Op);
            break;
          }
        }
        Ops.push_back(BO->RHS);
        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
        if (!NewBO ||
            (BO->Opcode == Instruction::Add &&
             (NewBO->Opcode != Instruction::Add &&
              NewBO->Opcode != Instruction::Sub)) ||
            (BO->Opcode == Instruction::Mul &&
             NewBO->Opcode != Instruction::Mul)) {
          Ops.push_back(BO->LHS);
          break;
        }
        // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
        // requires a SCEV for the LHS.
        if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
          auto *I = dyn_cast<Instruction>(BO->Op);
          if (I && programUndefinedIfPoison(I)) {
            Ops.push_back(BO->LHS);
            break;
          }
        }
        BO = NewBO;
      } while (true);
      return nullptr;
    }
    case Instruction::Sub:
    case Instruction::UDiv:
    case Instruction::URem:
      break;
    case Instruction::AShr:
    case Instruction::Shl:
    case Instruction::Xor:
      // These are only analyzable with a constant RHS (see createSCEV).
      if (!IsConstArg)
        return nullptr;
      break;
    case Instruction::And:
    case Instruction::Or:
      // And/or are analyzable with a constant RHS or as i1 umin/umax.
      if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
        return nullptr;
      break;
    case Instruction::LShr:
      return getUnknown(V);
    default:
      llvm_unreachable("Unhandled binop");
      break;
    }

    Ops.push_back(BO->LHS);
    Ops.push_back(BO->RHS);
    return nullptr;
  }

  switch (U->getOpcode()) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::PtrToAddr:
  case Instruction::PtrToInt:
    Ops.push_back(U->getOperand(0));
    return nullptr;

  case Instruction::BitCast:
    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
      Ops.push_back(U->getOperand(0));
      return nullptr;
    }
    return getUnknown(V);

  case Instruction::SDiv:
  case Instruction::SRem:
    Ops.push_back(U->getOperand(0));
    Ops.push_back(U->getOperand(1));
    return nullptr;

  case Instruction::GetElementPtr:
    assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
           "GEP source element type must be sized");
    llvm::append_range(Ops, U->operands());
    return nullptr;

  case Instruction::IntToPtr:
    return getUnknown(V);

  case Instruction::PHI:
    // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
    // relevant nodes for each of them.
    //
    // The first is just to call simplifyInstruction, and get something back
    // that isn't a PHI.
    if (Value *V = simplifyInstruction(
            cast<PHINode>(U),
            {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
             /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
      assert(V);
      Ops.push_back(V);
      return nullptr;
    }
    // The second is createNodeForPHIWithIdenticalOperands: this looks for
    // operands which all perform the same operation, but haven't been
    // CSE'ed for whatever reason.
    if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
      assert(BO);
      Ops.push_back(BO);
      return nullptr;
    }
    // The third is createNodeFromSelectLikePHI; this takes a PHI which
    // is equivalent to a select, and analyzes it like a select.
    {
      Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
        assert(Cond);
        assert(LHS);
        assert(RHS);
        if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
          Ops.push_back(CondICmp->getOperand(0));
          Ops.push_back(CondICmp->getOperand(1));
        }
        Ops.push_back(Cond);
        Ops.push_back(LHS);
        Ops.push_back(RHS);
        return nullptr;
      }
    }
    // The fourth way is createAddRecFromPHI. It's complicated to handle here,
    // so just construct it recursively.
    //
    // In addition to getNodeForPHI, also construct nodes which might be needed
    // by getRangeRef.
      for (Value *V : cast<PHINode>(U)->operands())
        Ops.push_back(V);
      return nullptr;
    }
    return nullptr;

  case Instruction::Select: {
    // Check if U is a select that can be simplified to a SCEVUnknown.
    auto CanSimplifyToUnknown = [this, U]() {
      if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
        return false;

      auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
      if (!ICI)
        return false;
      Value *LHS = ICI->getOperand(0);
      Value *RHS = ICI->getOperand(1);
      if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
          ICI->getPredicate() == CmpInst::ICMP_NE) {
          return true;
      } else if (getTypeSizeInBits(LHS->getType()) >
                 getTypeSizeInBits(U->getType()))
        return true;
      return false;
    };
    if (CanSimplifyToUnknown())
      return getUnknown(U);

    llvm::append_range(Ops, U->operands());
    return nullptr;
    break;
  }
  case Instruction::Call:
  case Instruction::Invoke:
    // A call returning one of its arguments is equivalent to that argument.
    if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
      Ops.push_back(RV);
      return nullptr;
    }

    if (auto *II = dyn_cast<IntrinsicInst>(U)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
        Ops.push_back(II->getArgOperand(0));
        return nullptr;
      case Intrinsic::umax:
      case Intrinsic::umin:
      case Intrinsic::smax:
      case Intrinsic::smin:
      case Intrinsic::usub_sat:
      case Intrinsic::uadd_sat:
        Ops.push_back(II->getArgOperand(0));
        Ops.push_back(II->getArgOperand(1));
        return nullptr;
      case Intrinsic::start_loop_iterations:
      case Intrinsic::annotation:
      case Intrinsic::ptr_annotation:
        Ops.push_back(II->getArgOperand(0));
        return nullptr;
      default:
        break;
      }
    }
    break;
  }

  return nullptr;
}
7893
7894const SCEV *ScalarEvolution::createSCEV(Value *V) {
7895 if (!isSCEVable(V->getType()))
7896 return getUnknown(V);
7897
7898 if (Instruction *I = dyn_cast<Instruction>(V)) {
7899 // Don't attempt to analyze instructions in blocks that aren't
7900 // reachable. Such instructions don't matter, and they aren't required
7901 // to obey basic rules for definitions dominating uses which this
7902 // analysis depends on.
7903 if (!DT.isReachableFromEntry(I->getParent()))
7904 return getUnknown(PoisonValue::get(V->getType()));
7905 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7906 return getConstant(CI);
7907 else if (isa<GlobalAlias>(V))
7908 return getUnknown(V);
7909 else if (!isa<ConstantExpr>(V))
7910 return getUnknown(V);
7911
7912 const SCEV *LHS;
7913 const SCEV *RHS;
7914
7916 if (auto BO =
7918 switch (BO->Opcode) {
7919 case Instruction::Add: {
7920 // The simple thing to do would be to just call getSCEV on both operands
7921 // and call getAddExpr with the result. However if we're looking at a
7922 // bunch of things all added together, this can be quite inefficient,
7923 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7924 // Instead, gather up all the operands and make a single getAddExpr call.
7925 // LLVM IR canonical form means we need only traverse the left operands.
7927 do {
7928 if (BO->Op) {
7929 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7930 AddOps.push_back(OpSCEV);
7931 break;
7932 }
7933
7934 // If a NUW or NSW flag can be applied to the SCEV for this
7935 // addition, then compute the SCEV for this addition by itself
7936 // with a separate call to getAddExpr. We need to do that
7937 // instead of pushing the operands of the addition onto AddOps,
7938 // since the flags are only known to apply to this particular
7939 // addition - they may not apply to other additions that can be
7940 // formed with operands from AddOps.
7941 const SCEV *RHS = getSCEV(BO->RHS);
7942 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7943 if (Flags != SCEV::FlagAnyWrap) {
7944 const SCEV *LHS = getSCEV(BO->LHS);
7945 if (BO->Opcode == Instruction::Sub)
7946 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7947 else
7948 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7949 break;
7950 }
7951 }
7952
7953 if (BO->Opcode == Instruction::Sub)
7954 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7955 else
7956 AddOps.push_back(getSCEV(BO->RHS));
7957
7958 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7960 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7961 NewBO->Opcode != Instruction::Sub)) {
7962 AddOps.push_back(getSCEV(BO->LHS));
7963 break;
7964 }
7965 BO = NewBO;
7966 } while (true);
7967
7968 return getAddExpr(AddOps);
7969 }
7970
7971 case Instruction::Mul: {
7973 do {
7974 if (BO->Op) {
7975 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7976 MulOps.push_back(OpSCEV);
7977 break;
7978 }
7979
7980 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7981 if (Flags != SCEV::FlagAnyWrap) {
7982 LHS = getSCEV(BO->LHS);
7983 RHS = getSCEV(BO->RHS);
7984 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7985 break;
7986 }
7987 }
7988
7989 MulOps.push_back(getSCEV(BO->RHS));
7990 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7992 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7993 MulOps.push_back(getSCEV(BO->LHS));
7994 break;
7995 }
7996 BO = NewBO;
7997 } while (true);
7998
7999 return getMulExpr(MulOps);
8000 }
8001 case Instruction::UDiv:
8002 LHS = getSCEV(BO->LHS);
8003 RHS = getSCEV(BO->RHS);
8004 return getUDivExpr(LHS, RHS);
8005 case Instruction::URem:
8006 LHS = getSCEV(BO->LHS);
8007 RHS = getSCEV(BO->RHS);
8008 return getURemExpr(LHS, RHS);
8009 case Instruction::Sub: {
8011 if (BO->Op)
8012 Flags = getNoWrapFlagsFromUB(BO->Op);
8013 LHS = getSCEV(BO->LHS);
8014 RHS = getSCEV(BO->RHS);
8015 return getMinusSCEV(LHS, RHS, Flags);
8016 }
8017 case Instruction::And:
8018 // For an expression like x&255 that merely masks off the high bits,
8019 // use zext(trunc(x)) as the SCEV expression.
8020 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8021 if (CI->isZero())
8022 return getSCEV(BO->RHS);
8023 if (CI->isMinusOne())
8024 return getSCEV(BO->LHS);
8025 const APInt &A = CI->getValue();
8026
8027 // Instcombine's ShrinkDemandedConstant may strip bits out of
8028 // constants, obscuring what would otherwise be a low-bits mask.
8029 // Use computeKnownBits to compute what ShrinkDemandedConstant
8030 // knew about to reconstruct a low-bits mask value.
8031 unsigned LZ = A.countl_zero();
8032 unsigned TZ = A.countr_zero();
8033 unsigned BitWidth = A.getBitWidth();
8034 KnownBits Known(BitWidth);
8035 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8036
8037 APInt EffectiveMask =
8038 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8039 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8040 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8041 const SCEV *LHS = getSCEV(BO->LHS);
8042 const SCEV *ShiftedLHS = nullptr;
8043 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8044 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8045 // For an expression like (x * 8) & 8, simplify the multiply.
8046 unsigned MulZeros = OpC->getAPInt().countr_zero();
8047 unsigned GCD = std::min(MulZeros, TZ);
8048 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8050 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8051 append_range(MulOps, LHSMul->operands().drop_front());
8052 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8053 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8054 }
8055 }
8056 if (!ShiftedLHS)
8057 ShiftedLHS = getUDivExpr(LHS, MulCount);
8058 return getMulExpr(
8060 getTruncateExpr(ShiftedLHS,
8061 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8062 BO->LHS->getType()),
8063 MulCount);
8064 }
8065 }
8066 // Binary `and` is a bit-wise `umin`.
8067 if (BO->LHS->getType()->isIntegerTy(1)) {
8068 LHS = getSCEV(BO->LHS);
8069 RHS = getSCEV(BO->RHS);
8070 return getUMinExpr(LHS, RHS);
8071 }
8072 break;
8073
8074 case Instruction::Or:
8075 // Binary `or` is a bit-wise `umax`.
8076 if (BO->LHS->getType()->isIntegerTy(1)) {
8077 LHS = getSCEV(BO->LHS);
8078 RHS = getSCEV(BO->RHS);
8079 return getUMaxExpr(LHS, RHS);
8080 }
8081 break;
8082
8083 case Instruction::Xor:
8084 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8085 // If the RHS of xor is -1, then this is a not operation.
8086 if (CI->isMinusOne())
8087 return getNotSCEV(getSCEV(BO->LHS));
8088
8089 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8090 // This is a variant of the check for xor with -1, and it handles
8091 // the case where instcombine has trimmed non-demanded bits out
8092 // of an xor with -1.
8093 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8094 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8095 if (LBO->getOpcode() == Instruction::And &&
8096 LCI->getValue() == CI->getValue())
8097 if (const SCEVZeroExtendExpr *Z =
8099 Type *UTy = BO->LHS->getType();
8100 const SCEV *Z0 = Z->getOperand();
8101 Type *Z0Ty = Z0->getType();
8102 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8103
8104 // If C is a low-bits mask, the zero extend is serving to
8105 // mask off the high bits. Complement the operand and
8106 // re-apply the zext.
8107 if (CI->getValue().isMask(Z0TySize))
8108 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8109
8110 // If C is a single bit, it may be in the sign-bit position
8111 // before the zero-extend. In this case, represent the xor
8112 // using an add, which is equivalent, and re-apply the zext.
8113 APInt Trunc = CI->getValue().trunc(Z0TySize);
8114 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8115 Trunc.isSignMask())
8116 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8117 UTy);
8118 }
8119 }
8120 break;
8121
8122 case Instruction::Shl:
8123 // Turn shift left of a constant amount into a multiply.
8124 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8125 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8126
8127 // If the shift count is not less than the bitwidth, the result of
8128 // the shift is undefined. Don't try to analyze it, because the
8129 // resolution chosen here may differ from the resolution chosen in
8130 // other parts of the compiler.
8131 if (SA->getValue().uge(BitWidth))
8132 break;
8133
8134 // We can safely preserve the nuw flag in all cases. It's also safe to
8135 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8136 // requires special handling. It can be preserved as long as we're not
8137 // left shifting by bitwidth - 1.
8138 auto Flags = SCEV::FlagAnyWrap;
8139 if (BO->Op) {
8140 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8141 if ((MulFlags & SCEV::FlagNSW) &&
8142 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8144 if (MulFlags & SCEV::FlagNUW)
8146 }
8147
8148 ConstantInt *X = ConstantInt::get(
8149 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8150 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8151 }
8152 break;
8153
8154 case Instruction::AShr:
8155 // AShr X, C, where C is a constant.
8156 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8157 if (!CI)
8158 break;
8159
8160 Type *OuterTy = BO->LHS->getType();
8161 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8162 // If the shift count is not less than the bitwidth, the result of
8163 // the shift is undefined. Don't try to analyze it, because the
8164 // resolution chosen here may differ from the resolution chosen in
8165 // other parts of the compiler.
8166 if (CI->getValue().uge(BitWidth))
8167 break;
8168
8169 if (CI->isZero())
8170 return getSCEV(BO->LHS); // shift by zero --> noop
8171
8172 uint64_t AShrAmt = CI->getZExtValue();
8173 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8174
8175 Operator *L = dyn_cast<Operator>(BO->LHS);
8176 const SCEV *AddTruncateExpr = nullptr;
8177 ConstantInt *ShlAmtCI = nullptr;
8178 const SCEV *AddConstant = nullptr;
8179
8180 if (L && L->getOpcode() == Instruction::Add) {
8181 // X = Shl A, n
8182 // Y = Add X, c
8183 // Z = AShr Y, m
8184 // n, c and m are constants.
8185
8186 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8187 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8188 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8189 if (AddOperandCI) {
8190 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8191 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8192 // since we truncate to TruncTy, the AddConstant should be of the
8193 // same type, so create a new Constant with type same as TruncTy.
8194 // Also, the Add constant should be shifted right by AShr amount.
8195 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8196 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8197 // we model the expression as sext(add(trunc(A), c << n)), since the
8198 // sext(trunc) part is already handled below, we create a
8199 // AddExpr(TruncExp) which will be used later.
8200 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8201 }
8202 }
8203 } else if (L && L->getOpcode() == Instruction::Shl) {
8204 // X = Shl A, n
8205 // Y = AShr X, m
8206 // Both n and m are constant.
8207
8208 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8209 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8210 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8211 }
8212
8213 if (AddTruncateExpr && ShlAmtCI) {
8214 // We can merge the two given cases into a single SCEV statement,
      // in case n = m, the mul expression will be 2^0, so it gets resolved to
8216 // a simpler case. The following code handles the two cases:
8217 //
8218 // 1) For a two-shift sext-inreg, i.e. n = m,
8219 // use sext(trunc(x)) as the SCEV expression.
8220 //
8221 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8222 // expression. We already checked that ShlAmt < BitWidth, so
8223 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8224 // ShlAmt - AShrAmt < Amt.
8225 const APInt &ShlAmt = ShlAmtCI->getValue();
8226 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8227 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8228 ShlAmtCI->getZExtValue() - AShrAmt);
8229 const SCEV *CompositeExpr =
8230 getMulExpr(AddTruncateExpr, getConstant(Mul));
8231 if (L->getOpcode() != Instruction::Shl)
8232 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8233
8234 return getSignExtendExpr(CompositeExpr, OuterTy);
8235 }
8236 }
8237 break;
8238 }
8239 }
8240
8241 switch (U->getOpcode()) {
8242 case Instruction::Trunc:
8243 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8244
8245 case Instruction::ZExt:
8246 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8247
8248 case Instruction::SExt:
8249 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8251 // The NSW flag of a subtract does not always survive the conversion to
8252 // A + (-1)*B. By pushing sign extension onto its operands we are much
8253 // more likely to preserve NSW and allow later AddRec optimisations.
8254 //
8255 // NOTE: This is effectively duplicating this logic from getSignExtend:
8256 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8257 // but by that point the NSW information has potentially been lost.
8258 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8259 Type *Ty = U->getType();
8260 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8261 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8262 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8263 }
8264 }
8265 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8266
8267 case Instruction::BitCast:
8268 // BitCasts are no-op casts so we just eliminate the cast.
8269 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8270 return getSCEV(U->getOperand(0));
8271 break;
8272
8273 case Instruction::PtrToAddr: {
8274 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8275 if (isa<SCEVCouldNotCompute>(IntOp))
8276 return getUnknown(V);
8277 return IntOp;
8278 }
8279
8280 case Instruction::PtrToInt: {
8281 // Pointer to integer cast is straight-forward, so do model it.
8282 const SCEV *Op = getSCEV(U->getOperand(0));
8283 Type *DstIntTy = U->getType();
8284 // But only if effective SCEV (integer) type is wide enough to represent
8285 // all possible pointer values.
8286 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8287 if (isa<SCEVCouldNotCompute>(IntOp))
8288 return getUnknown(V);
8289 return IntOp;
8290 }
8291 case Instruction::IntToPtr:
8292 // Just don't deal with inttoptr casts.
8293 return getUnknown(V);
8294
8295 case Instruction::SDiv:
8296 // If both operands are non-negative, this is just an udiv.
8297 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8298 isKnownNonNegative(getSCEV(U->getOperand(1))))
8299 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8300 break;
8301
8302 case Instruction::SRem:
8303 // If both operands are non-negative, this is just an urem.
8304 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8305 isKnownNonNegative(getSCEV(U->getOperand(1))))
8306 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8307 break;
8308
8309 case Instruction::GetElementPtr:
8310 return createNodeForGEP(cast<GEPOperator>(U));
8311
8312 case Instruction::PHI:
8313 return createNodeForPHI(cast<PHINode>(U));
8314
8315 case Instruction::Select:
8316 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8317 U->getOperand(2));
8318
8319 case Instruction::Call:
8320 case Instruction::Invoke:
8321 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8322 return getSCEV(RV);
8323
8324 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8325 switch (II->getIntrinsicID()) {
8326 case Intrinsic::abs:
8327 return getAbsExpr(
8328 getSCEV(II->getArgOperand(0)),
8329 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8330 case Intrinsic::umax:
8331 LHS = getSCEV(II->getArgOperand(0));
8332 RHS = getSCEV(II->getArgOperand(1));
8333 return getUMaxExpr(LHS, RHS);
8334 case Intrinsic::umin:
8335 LHS = getSCEV(II->getArgOperand(0));
8336 RHS = getSCEV(II->getArgOperand(1));
8337 return getUMinExpr(LHS, RHS);
8338 case Intrinsic::smax:
8339 LHS = getSCEV(II->getArgOperand(0));
8340 RHS = getSCEV(II->getArgOperand(1));
8341 return getSMaxExpr(LHS, RHS);
8342 case Intrinsic::smin:
8343 LHS = getSCEV(II->getArgOperand(0));
8344 RHS = getSCEV(II->getArgOperand(1));
8345 return getSMinExpr(LHS, RHS);
8346 case Intrinsic::usub_sat: {
8347 const SCEV *X = getSCEV(II->getArgOperand(0));
8348 const SCEV *Y = getSCEV(II->getArgOperand(1));
8349 const SCEV *ClampedY = getUMinExpr(X, Y);
8350 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8351 }
8352 case Intrinsic::uadd_sat: {
8353 const SCEV *X = getSCEV(II->getArgOperand(0));
8354 const SCEV *Y = getSCEV(II->getArgOperand(1));
8355 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8356 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8357 }
8358 case Intrinsic::start_loop_iterations:
8359 case Intrinsic::annotation:
8360 case Intrinsic::ptr_annotation:
        // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
        // just equivalent to the first operand for SCEV purposes.
8363 return getSCEV(II->getArgOperand(0));
8364 case Intrinsic::vscale:
8365 return getVScale(II->getType());
8366 default:
8367 break;
8368 }
8369 }
8370 break;
8371 }
8372
8373 return getUnknown(V);
8374}
8375
8376//===----------------------------------------------------------------------===//
8377// Iteration Count Computation Code
8378//
8379
  // An unanalyzable exit count has no meaningful trip count.
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  auto *ExitCountType = ExitCount->getType();
  assert(ExitCountType->isIntegerTy());
  // Evaluate in a type one bit wider than the exit count so the +1 in
  // "trip count = exit count + 1" can never overflow.
  auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
                                 1 + ExitCountType->getScalarSizeInBits());
  // No loop is supplied, so the overflow reasoning in the three-argument
  // overload cannot use loop-entry guards.
  return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
}
8390
                                                   Type *EvalTy,
                                                   const Loop *L) {
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
  unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();

  // True when "ExitCount + 1" provably does not wrap in ExitCount's own
  // type: either the unsigned range excludes the all-ones value, or (when a
  // loop is given) loop entry is guarded by ExitCount != -1.
  auto CanAddOneWithoutOverflow = [&]() {
    ConstantRange ExitCountRange =
        getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
    if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
      return true;

    return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
                                         getMinusOne(ExitCount->getType()));
  };

  // If we need to zero extend the backedge count, check if we can add one to
  // it prior to zero extending without overflow. Provided this is safe, it
  // allows better simplification of the +1.
  if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
    return getZeroExtendExpr(
        getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);

  // Get the total trip count from the count by adding 1. This may wrap.
  return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
}
8420
8421static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8422 if (!ExitCount)
8423 return 0;
8424
8425 ConstantInt *ExitConst = ExitCount->getValue();
8426
8427 // Guard against huge trip counts.
8428 if (ExitConst->getValue().getActiveBits() > 32)
8429 return 0;
8430
8431 // In case of integer overflow, this returns 0, which is correct.
8432 return ((unsigned)ExitConst->getZExtValue()) + 1;
8433}
8434
  // Only a constant backedge-taken count yields a constant trip count; the
  // dyn_cast returns null otherwise and getConstantTripCount maps null to 0.
  auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
  return getConstantTripCount(ExitCount);
}
8439
// Trip count assuming control exits the loop via ExitingBlock; 0 when the
// per-exit count is not a small constant.
unsigned
                                           const BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  const SCEVConstant *ExitCount =
      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
  return getConstantTripCount(ExitCount);
}
8450
    const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {

  // With a predicate sink we may use the (larger) predicated constant max;
  // otherwise only the unconditional constant max is usable.
  const auto *MaxExitCount =
      Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
  return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
}
8459
  SmallVector<BasicBlock *, 8> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  // The trip count is a multiple of each per-exit multiple, so the overall
  // guaranteed multiple is the gcd across all exiting blocks.
  std::optional<unsigned> Res;
  for (auto *ExitingBB : ExitingBlocks) {
    unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
    if (!Res)
      Res = Multiple;
    // gcd(M, M) == M, so folding the seed element again is harmless.
    Res = std::gcd(*Res, Multiple);
  }
  // A loop with no exiting blocks gives the trivial multiple of 1.
  return Res.value_or(1);
}
8473
                                                       const SCEV *ExitCount) {
  // Unknown exit count: only the trivial multiple of 1 is guaranteed.
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return 1;

  // Get the trip count (exit count + 1), tightened by loop guards.
  const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));

  APInt Multiple = getNonZeroConstantMultiple(TCExpr);
  // If a trip multiple is huge (>=2^32), the trip count is still divisible by
  // the greatest power of 2 divisor less than 2^32.
  return Multiple.getActiveBits() > 32
             ? 1U << std::min(31U, Multiple.countTrailingZeros())
             : (unsigned)Multiple.getZExtValue();
}
8489
8490/// Returns the largest constant divisor of the trip count of this loop as a
8491/// normal unsigned value, if possible. This means that the actual trip count is
8492/// always a multiple of the returned value (don't forget the trip count could
8493/// very well be zero as well!).
8494///
8495/// Returns 1 if the trip count is unknown or not guaranteed to be the
8496/// multiple of a constant (which is also the case if the trip count is simply
8497/// constant, use getSmallConstantTripCount for that case), Will also return 1
8498/// if the trip count is very large (>= 2^32).
8499///
8500/// As explained in the comments for getSmallConstantTripCount, this assumes
8501/// that control exits the loop via ExitingBlock.
unsigned
                                              const BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  const SCEV *ExitCount = getExitCount(L, ExitingBlock);
  // Delegate to the exit-count based overload.
  return getSmallConstantTripMultiple(L, ExitCount);
}
8511
                                          const BasicBlock *ExitingBlock,
                                          ExitCountKind Kind) {
  // Dispatch to the cached BackedgeTakenInfo for the requested flavor of
  // per-exit count.
  switch (Kind) {
  case Exact:
    return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
  case SymbolicMaximum:
    return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
  case ConstantMaximum:
    return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8525
    const Loop *L, const BasicBlock *ExitingBlock,
  // Same dispatch as getExitCount, but using the predicated info and
  // collecting the predicates the result depends on.
  switch (Kind) {
  case Exact:
    return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
                                                      Predicates);
  case SymbolicMaximum:
    return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
                                                            Predicates);
  case ConstantMaximum:
    return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
                                                            Predicates);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8542
8545 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8546}
8547
                                                   ExitCountKind Kind) {
  // Whole-loop counts: dispatch on the requested flavor.
  switch (Kind) {
  case Exact:
    return getBackedgeTakenInfo(L).getExact(L, this);
  case ConstantMaximum:
    return getBackedgeTakenInfo(L).getConstantMax(this);
  case SymbolicMaximum:
    return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8560
8563 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8564}
8565
8568 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8569}
8570
8572 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8573}
8574
8575/// Push PHI nodes in the header of the given loop onto the given Worklist.
static void PushLoopPHIs(const Loop *L,
  BasicBlock *Header = L->getHeader();

  // Push all Loop-header PHIs onto the Worklist stack; Visited ensures each
  // PHI is enqueued at most once across repeated calls.
  for (PHINode &PN : Header->phis())
    if (Visited.insert(&PN).second)
      Worklist.push_back(&PN);
}
8586
8587ScalarEvolution::BackedgeTakenInfo &
8588ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8589 auto &BTI = getBackedgeTakenInfo(L);
8590 if (BTI.hasFullInfo())
8591 return BTI;
8592
8593 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8594
8595 if (!Pair.second)
8596 return Pair.first->second;
8597
8598 BackedgeTakenInfo Result =
8599 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8600
8601 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8602}
8603
ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
  // Initially insert an invalid entry for this loop. If the insertion
  // succeeds, proceed to actually compute a backedge-taken count and
  // update the value. The temporary CouldNotCompute value tells SCEV
  // code elsewhere that it shouldn't attempt to request a new
  // backedge-taken count, which could result in infinite recursion.
  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
      BackedgeTakenCounts.try_emplace(L);
  if (!Pair.second)
    return Pair.first->second;

  // computeBackedgeTakenCount may allocate memory for its result. Inserting it
  // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
  // must be cleared in this scope.
  BackedgeTakenInfo Result = computeBackedgeTakenCount(L);

  // Now that we know more about the trip count for this loop, forget any
  // existing SCEV values for PHI nodes in this loop since they are only
  // conservative estimates made without the benefit of trip count
  // information. This invalidation is not necessary for correctness, and is
  // only done to produce more precise results.
  if (Result.hasAnyInfo()) {
    // Invalidate any expression using an addrec in this loop.
    SmallVector<SCEVUse, 8> ToForget;
    auto LoopUsersIt = LoopUsers.find(L);
    if (LoopUsersIt != LoopUsers.end())
      append_range(ToForget, LoopUsersIt->second);
    forgetMemoizedResults(ToForget);

    // Invalidate constant-evolved loop header phis.
    for (PHINode &PN : L->getHeader()->phis())
      ConstantEvolutionLoopExitValue.erase(&PN);
  }

  // Re-lookup the insert position, since the call to
  // computeBackedgeTakenCount above could result in a
  // recursive call to getBackedgeTakenInfo (on a different
  // loop), which would invalidate the iterator computed
  // earlier.
  return BackedgeTakenCounts.find(L)->second = std::move(Result);
}
8646
  // This method is intended to forget all info about loops. It should
  // invalidate caches as if the following happened:
  // - The trip counts of all loops have changed arbitrarily
  // - Every llvm::Value has been updated in place to produce a different
  // result.
  // Trip-count related caches.
  BackedgeTakenCounts.clear();
  PredicatedBackedgeTakenCounts.clear();
  BECountUsers.clear();
  LoopPropertiesCache.clear();
  ConstantEvolutionLoopExitValue.clear();
  // Value <-> SCEV mappings and per-scope evaluations.
  ValueExprMap.clear();
  ValuesAtScopes.clear();
  ValuesAtScopesUsers.clear();
  // Disposition and range caches.
  LoopDispositions.clear();
  BlockDispositions.clear();
  UnsignedRanges.clear();
  SignedRanges.clear();
  // Expression-level caches.
  ExprValueMap.clear();
  HasRecMap.clear();
  ConstantMultipleCache.clear();
  // Predicated rewrites and folding caches.
  PredicatedSCEVRewrites.clear();
  FoldCache.clear();
  FoldCacheUser.clear();
}
// Pop instructions off Worklist, erase their cached SCEVs (collecting the
// erased expressions into ToForget), and queue their def-use children so the
// whole dependent region is invalidated.
void ScalarEvolution::visitAndClearUsers(
    SmallVectorImpl<SCEVUse> &ToForget) {
  while (!Worklist.empty()) {
    Instruction *I = Worklist.pop_back_val();
    // Only SCEVable values (plus with.overflow results) can have cache
    // entries worth erasing.
    if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
      continue;

        ValueExprMap.find_as(static_cast<Value *>(I));
    if (It != ValueExprMap.end()) {
      eraseValueFromMap(It->first);
      ToForget.push_back(It->second);
      if (PHINode *PN = dyn_cast<PHINode>(I))
        ConstantEvolutionLoopExitValue.erase(PN);
    }

    PushDefUseChildren(I, Worklist, Visited);
  }
}
8693
  SmallVector<const Loop *, 16> LoopWorklist(1, L);
  SmallVector<SCEVUse, 16> ToForget;

  // Iterate over all the loops and sub-loops to drop SCEV information.
  while (!LoopWorklist.empty()) {
    auto *CurrL = LoopWorklist.pop_back_val();

    // Drop any stored trip count value.
    forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
    forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);

    // Drop information about predicated SCEV rewrites for this loop.
    for (auto I = PredicatedSCEVRewrites.begin();
         I != PredicatedSCEVRewrites.end();) {
      std::pair<const SCEV *, const Loop *> Entry = I->first;
      if (Entry.second == CurrL)
        PredicatedSCEVRewrites.erase(I++);
      else
        ++I;
    }

    // Queue every SCEV known to use an expression from this loop for
    // invalidation.
    auto LoopUsersItr = LoopUsers.find(CurrL);
    if (LoopUsersItr != LoopUsers.end())
      llvm::append_range(ToForget, LoopUsersItr->second);

    // Drop information about expressions based on loop-header PHIs.
    PushLoopPHIs(CurrL, Worklist, Visited);
    visitAndClearUsers(Worklist, Visited, ToForget);

    LoopPropertiesCache.erase(CurrL);
    // Forget all contained loops too, to avoid dangling entries in the
    // ValuesAtScopes map.
    LoopWorklist.append(CurrL->begin(), CurrL->end());
  }
  forgetMemoizedResults(ToForget);
}
8733
8735 forgetLoop(L->getOutermostLoop());
8736}
8737
  // Non-instruction values have no def-use chain to invalidate.
  if (!I) return;

  // Drop information about expressions based on loop-header PHIs.
  SmallVector<SCEVUse, 8> ToForget;
  // Seed the walk with I itself; visitAndClearUsers follows the def-use
  // chain and erases every cached SCEV that transitively depends on I.
  Worklist.push_back(I);
  Visited.insert(I);
  visitAndClearUsers(Worklist, Visited, ToForget);

  forgetMemoizedResults(ToForget);
}
8752
  if (!isSCEVable(V->getType()))
    return;

  // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
  // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
  // extra predecessor is added, this is no longer valid. Find all Unknowns and
  // AddRecs defined in the loop and invalidate any SCEV's making use of them.
  if (const SCEV *S = getExistingSCEV(V)) {
    // Visitor that records every sub-expression rooted in the loop L:
    // SCEVUnknowns wrapping loop-resident instructions and AddRecs for L or
    // any loop it contains.
    struct InvalidationRootCollector {
      Loop *L;

      InvalidationRootCollector(Loop *L) : L(L) {}

      bool follow(const SCEV *S) {
        if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
          if (auto *I = dyn_cast<Instruction>(SU->getValue()))
            if (L->contains(I))
              Roots.push_back(S);
        } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
          if (L->contains(AddRec->getLoop()))
            Roots.push_back(S);
        }
        return true;
      }
      // Never stop early: the whole expression tree must be scanned.
      bool isDone() const { return false; }
    };

    InvalidationRootCollector C(L);
    visitAll(S, C);
    forgetMemoizedResults(C.Roots);
  }

  // Also perform the normal invalidation.
  forgetValue(V);
}
8790
8791void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8792
  // Unless a specific value is passed to invalidation, completely clear both
  // caches.
  if (!V) {
    BlockDispositions.clear();
    LoopDispositions.clear();
    return;
  }

  if (!isSCEVable(V->getType()))
    return;

  // Only values with an already-computed SCEV can have cached dispositions.
  const SCEV *S = getExistingSCEV(V);
  if (!S)
    return;

  // Invalidate the block and loop dispositions cached for S. Dispositions of
  // S's users may change if S's disposition changes (i.e. a user may change to
  // loop-invariant, if S changes to loop invariant), so also invalidate
  // dispositions of S's users recursively.
  SmallVector<SCEVUse, 8> Worklist = {S};
  while (!Worklist.empty()) {
    const SCEV *Curr = Worklist.pop_back_val();
    bool LoopDispoRemoved = LoopDispositions.erase(Curr);
    bool BlockDispoRemoved = BlockDispositions.erase(Curr);
    // If neither cache held Curr, its users cannot be affected either.
    if (!LoopDispoRemoved && !BlockDispoRemoved)
      continue;
    auto Users = SCEVUsers.find(Curr);
    if (Users != SCEVUsers.end())
      for (const auto *User : Users->second)
        if (Seen.insert(User).second)
          Worklist.push_back(User);
  }
}
8828
8829/// Get the exact loop backedge taken count considering all loop exits. A
8830/// computable result can only be returned for loops with all exiting blocks
8831/// dominating the latch. howFarToZero assumes that the limit of each loop test
8832/// is never skipped. This is a valid assumption as long as the loop exits via
8833/// that test. For precise results, it is the caller's responsibility to specify
8834/// the relevant loop exiting block using getExact(ExitingBlock, SE).
const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
    const Loop *L, ScalarEvolution *SE,
  // If any exits were not computable, the loop is not computable.
  if (!isComplete() || ExitNotTaken.empty())
    return SE->getCouldNotCompute();

  const BasicBlock *Latch = L->getLoopLatch();
  // All exiting blocks we have collected must dominate the only backedge.
  if (!Latch)
    return SE->getCouldNotCompute();

  // All exiting blocks we have gathered dominate loop's latch, so exact trip
  // count is simply a minimum out of all these calculated exit counts.
  for (const auto &ENT : ExitNotTaken) {
    const SCEV *BECount = ENT.ExactNotTaken;
    assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
    assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
           "We should only have known counts for exiting blocks that dominate "
           "latch!");

    Ops.push_back(BECount);

    // Record the predicates this per-exit count is conditional on.
    if (Preds)
      append_range(*Preds, ENT.Predicates);

    assert((Preds || ENT.hasAlwaysTruePredicate()) &&
           "Predicate should be always true!");
  }

  // If an earlier exit exits on the first iteration (exit count zero), then
  // a later poison exit count should not propagate into the result. These are
  // exactly the semantics provided by umin_seq.
  return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
}
8871
8872const ScalarEvolution::ExitNotTakenInfo *
8873ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8874 const BasicBlock *ExitingBlock,
8875 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8876 for (const auto &ENT : ExitNotTaken)
8877 if (ENT.ExitingBlock == ExitingBlock) {
8878 if (ENT.hasAlwaysTruePredicate())
8879 return &ENT;
8880 else if (Predicates) {
8881 append_range(*Predicates, ENT.Predicates);
8882 return &ENT;
8883 }
8884 }
8885
8886 return nullptr;
8887}
8888
8889/// getConstantMax - Get the constant max backedge taken count for the loop.
8890const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8891 ScalarEvolution *SE,
8892 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8893 if (!getConstantMax())
8894 return SE->getCouldNotCompute();
8895
8896 for (const auto &ENT : ExitNotTaken)
8897 if (!ENT.hasAlwaysTruePredicate()) {
8898 if (!Predicates)
8899 return SE->getCouldNotCompute();
8900 append_range(*Predicates, ENT.Predicates);
8901 }
8902
8903 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8904 isa<SCEVConstant>(getConstantMax())) &&
8905 "No point in having a non-constant max backedge taken count!");
8906 return getConstantMax();
8907}
8908
8909const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8910 const Loop *L, ScalarEvolution *SE,
8911 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8912 if (!SymbolicMax) {
8913 // Form an expression for the maximum exit count possible for this loop. We
8914 // merge the max and exact information to approximate a version of
8915 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8916 // constants.
8917 SmallVector<SCEVUse, 4> ExitCounts;
8918
8919 for (const auto &ENT : ExitNotTaken) {
8920 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8921 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8922 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8923 "We should only have known counts for exiting blocks that "
8924 "dominate latch!");
8925 ExitCounts.push_back(ExitCount);
8926 if (Predicates)
8927 append_range(*Predicates, ENT.Predicates);
8928
8929 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8930 "Predicate should be always true!");
8931 }
8932 }
8933 if (ExitCounts.empty())
8934 SymbolicMax = SE->getCouldNotCompute();
8935 else
8936 SymbolicMax =
8937 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8938 }
8939 return SymbolicMax;
8940}
8941
8942bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8943 ScalarEvolution *SE) const {
8944 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8945 return !ENT.hasAlwaysTruePredicate();
8946 };
8947 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8948}
8949
// ScalarEvolution::ExitLimit constructor (predicated form).
// NOTE(review): the extractor dropped several lines here -- the line carrying
// the constructor's qualified name / parameter list head (original lines
// 8950-8951, 8953, 8956-8958) and the condition halves of the asserts whose
// message strings survive below (8967-8968, 8970-8971, 8973-8974, 8976-8977,
// 8979, 8990). Confirm against upstream before relying on exact semantics.
8952
8954 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8955 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8959 // If we prove the max count is zero, so is the symbolic bound. This happens
8960 // in practice due to differences in a) how context sensitive we've chosen
8961 // to be and b) how we reason about bounds implied by UB.
8962 if (ConstantMaxNotTaken->isZero()) {
8963 this->ExactNotTaken = E = ConstantMaxNotTaken;
8964 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8965 }
8966
// Sanity-check the precision ordering between the three counts (the assert
// condition expressions were lost in extraction; only the messages remain).
8969 "Exact is not allowed to be less precise than Constant Max");
8972 "Exact is not allowed to be less precise than Symbolic Max");
8975 "Symbolic Max is not allowed to be less precise than Constant Max");
8978 "No point in having a non-constant max backedge taken count!");
// Flatten the per-exit predicate lists into the member Predicates vector,
// skipping duplicates and rejecting union predicates (only leaves allowed).
8980 for (const auto PredList : PredLists)
8981 for (const auto *P : PredList) {
8982 if (SeenPreds.contains(P))
8983 continue;
8984 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8985 SeenPreds.insert(P);
8986 Predicates.push_back(P);
8987 }
// Backedge-taken counts are iteration counts; they must never be pointers.
8988 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8989 "Backedge count should be int");
8991 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8992 "Max backedge count should be int");
8993}
8994
9002
9003/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9004/// computable exit into a persistent ExitNotTakenInfo array.
// NOTE(review): the extractor dropped original line 9006 here -- presumably
// the first constructor parameter (the ExitCounts vector of EdgeExitInfo
// consumed below). Confirm against upstream.
9005ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9007 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9008 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9009 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9010
// Convert each (exiting block, exit limit) pair into a persistent
// ExitNotTakenInfo entry; reserve first so the transform never reallocates.
9011 ExitNotTaken.reserve(ExitCounts.size());
9012 std::transform(ExitCounts.begin(), ExitCounts.end(),
9013 std::back_inserter(ExitNotTaken),
9014 [&](const EdgeExitInfo &EEI) {
9015 BasicBlock *ExitBB = EEI.first;
9016 const ExitLimit &EL = EEI.second;
9017 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9018 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9019 EL.Predicates);
9020 });
// The loop-wide constant max must itself be a constant (or CouldNotCompute).
9021 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9022 isa<SCEVConstant>(ConstantMax)) &&
9023 "No point in having a non-constant max backedge taken count!");
9024}
9025
9026/// Compute the number of times the backedge of the specified loop will execute.
9027ScalarEvolution::BackedgeTakenInfo
9028ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9029 bool AllowPredicates) {
9030 SmallVector<BasicBlock *, 8> ExitingBlocks;
9031 L->getExitingBlocks(ExitingBlocks);
9032
9033 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9034
// NOTE(review): original line 9035 was dropped by the extractor -- presumably
// the local "ExitCounts" SmallVector of EdgeExitInfo populated below.
9036 bool CouldComputeBECount = true;
9037 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9038 const SCEV *MustExitMaxBECount = nullptr;
9039 const SCEV *MayExitMaxBECount = nullptr;
9040 bool MustExitMaxOrZero = false;
9041 bool IsOnlyExit = ExitingBlocks.size() == 1;
9042
9043 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9044 // and compute maxBECount.
9045 // Do a union of all the predicates here.
9046 for (BasicBlock *ExitBB : ExitingBlocks) {
9047 // We canonicalize untaken exits to br (constant), ignore them so that
9048 // proving an exit untaken doesn't negatively impact our ability to reason
9049 // about the loop as whole.
9050 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9051 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9052 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9053 if (ExitIfTrue == CI->isZero())
9054 continue;
9055 }
9056
9057 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9058
9059 assert((AllowPredicates || EL.Predicates.empty()) &&
9060 "Predicated exit limit when predicates are not allowed!");
9061
9062 // 1. For each exit that can be computed, add an entry to ExitCounts.
9063 // CouldComputeBECount is true only if all exits can be computed.
9064 if (EL.ExactNotTaken != getCouldNotCompute())
9065 ++NumExitCountsComputed;
9066 else
9067 // We couldn't compute an exact value for this exit, so
9068 // we won't be able to compute an exact value for the loop.
9069 CouldComputeBECount = false;
9070 // Remember exit count if either exact or symbolic is known. Because
9071 // Exact always implies symbolic, only check symbolic.
9072 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9073 ExitCounts.emplace_back(ExitBB, EL);
9074 else {
9075 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9076 "Exact is known but symbolic isn't?");
9077 ++NumExitCountsNotComputed;
9078 }
9079
9080 // 2. Derive the loop's MaxBECount from each exit's max number of
9081 // non-exiting iterations. Partition the loop exits into two kinds:
9082 // LoopMustExits and LoopMayExits.
9083 //
9084 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9085 // is a LoopMayExit. If any computable LoopMustExit is found, then
9086 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9087 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9088 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9089 // any
9090 // computable EL.ConstantMaxNotTaken.
9091 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9092 DT.dominates(ExitBB, Latch)) {
9093 if (!MustExitMaxBECount) {
9094 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9095 MustExitMaxOrZero = EL.MaxOrZero;
9096 } else {
9097 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9098 EL.ConstantMaxNotTaken);
9099 }
9100 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9101 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9102 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9103 else {
9104 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9105 EL.ConstantMaxNotTaken);
9106 }
9107 }
9108 }
// Prefer the must-exit bound (every iteration reaching the latch crosses such
// an exit); otherwise fall back to the conservative may-exit bound.
9109 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9110 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9111 // The loop backedge will be taken the maximum or zero times if there's
9112 // a single exit that must be taken the maximum or zero times.
9113 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9114
9115 // Remember which SCEVs are used in exit limits for invalidation purposes.
9116 // We only care about non-constant SCEVs here, so we can ignore
9117 // EL.ConstantMaxNotTaken
9118 // and MaxBECount, which must be SCEVConstant.
9119 for (const auto &Pair : ExitCounts) {
9120 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9121 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9122 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9123 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9124 {L, AllowPredicates});
9125 }
9126 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9127 MaxBECount, MaxOrZero);
9128}
9129
9130ScalarEvolution::ExitLimit
9131ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9132 bool IsOnlyExit, bool AllowPredicates) {
9133 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9134 // If our exiting block does not dominate the latch, then its connection with
9135 // loop's exit limit may be far from trivial.
9136 const BasicBlock *Latch = L->getLoopLatch();
9137 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9138 return getCouldNotCompute();
9139
9140 Instruction *Term = ExitingBlock->getTerminator();
9141 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9142 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9143 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9144 "It should have one successor in loop and one exit block!");
9145 // Proceed to the next level to examine the exit condition expression.
9146 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9147 /*ControlsOnlyExit=*/IsOnlyExit,
9148 AllowPredicates);
9149 }
9150
9151 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9152 // For switch, make sure that there is a single exit from the loop.
9153 BasicBlock *Exit = nullptr;
9154 for (auto *SBB : successors(ExitingBlock))
9155 if (!L->contains(SBB)) {
9156 if (Exit) // Multiple exit successors.
9157 return getCouldNotCompute();
9158 Exit = SBB;
9159 }
9160 assert(Exit && "Exiting block must have at least one exit");
9161 return computeExitLimitFromSingleExitSwitch(
9162 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9163 }
9164
9165 return getCouldNotCompute();
9166}
9167
// Entry point for exit-condition analysis: build a fresh per-walk
// ExitLimitCache and forward to the cached computation.
// NOTE(review): the extractor dropped original line 9168, which presumably
// carried the return type and name (computeExitLimitFromCond, judging by the
// delegation below). Confirm against upstream.
9169 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9170 bool AllowPredicates) {
9171 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9172 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9173 ControlsOnlyExit, AllowPredicates);
9174}
9175
9176std::optional<ScalarEvolution::ExitLimit>
9177ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9178 bool ExitIfTrue, bool ControlsOnlyExit,
9179 bool AllowPredicates) {
9180 (void)this->L;
9181 (void)this->ExitIfTrue;
9182 (void)this->AllowPredicates;
9183
9184 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9185 this->AllowPredicates == AllowPredicates &&
9186 "Variance in assumed invariant key components!");
9187 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9188 if (Itr == TripCountMap.end())
9189 return std::nullopt;
9190 return Itr->second;
9191}
9192
9193void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9194 bool ExitIfTrue,
9195 bool ControlsOnlyExit,
9196 bool AllowPredicates,
9197 const ExitLimit &EL) {
9198 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9199 this->AllowPredicates == AllowPredicates &&
9200 "Variance in assumed invariant key components!");
9201
9202 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9203 assert(InsertResult.second && "Expected successful insertion!");
9204 (void)InsertResult;
9205 (void)ExitIfTrue;
9206}
9207
9208ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9209 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9210 bool ControlsOnlyExit, bool AllowPredicates) {
9211
9212 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9213 AllowPredicates))
9214 return *MaybeEL;
9215
9216 ExitLimit EL = computeExitLimitFromCondImpl(
9217 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9218 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9219 return EL;
9220}
9221
/// Compute an exit limit for an arbitrary exit condition by dispatching on
/// its shape: and/or, icmp, constant, overflow-intrinsic flag, and finally
/// exhaustive (brute-force) evaluation.
9222ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9223 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9224 bool ControlsOnlyExit, bool AllowPredicates) {
9225 // Handle BinOp conditions (And, Or).
9226 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9227 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9228 return *LimitFromBinOp;
9229
9230 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9231 // Proceed to the next level to examine the icmp.
9232 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9233 ExitLimit EL =
9234 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9235 if (EL.hasFullInfo() || !AllowPredicates)
9236 return EL;
9237
9238 // Try again, but use SCEV predicates this time.
9239 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9240 ControlsOnlyExit,
9241 /*AllowPredicates=*/true);
9242 }
9243
9244 // Check for a constant condition. These are normally stripped out by
9245 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9246 // preserve the CFG and is temporarily leaving constant conditions
9247 // in place.
9248 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9249 if (ExitIfTrue == !CI->getZExtValue())
9250 // The backedge is always taken.
9251 return getCouldNotCompute();
9252 // The backedge is never taken.
9253 return getZero(CI->getType());
9254 }
9255
9256 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9257 // with a constant step, we can form an equivalent icmp predicate and figure
9258 // out how many iterations will be taken before we exit.
9259 const WithOverflowInst *WO;
9260 const APInt *C;
9261 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9262 match(WO->getRHS(), m_APInt(C))) {
// NOTE(review): the extractor dropped original lines 9264 and 9273 -- likely
// the arguments of ConstantRange::makeExactNoWrapRegion(...) and the
// adjustment of LHS by the Offset returned from getEquivalentICmp. Confirm
// against upstream before relying on exact semantics.
9263 ConstantRange NWR =
9265 WO->getNoWrapKind());
9266 CmpInst::Predicate Pred;
9267 APInt NewRHSC, Offset;
9268 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9269 if (!ExitIfTrue)
9270 Pred = ICmpInst::getInversePredicate(Pred);
9271 auto *LHS = getSCEV(WO->getLHS());
9272 if (Offset != 0)
9274 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9275 ControlsOnlyExit, AllowPredicates);
9276 if (EL.hasAnyInfo())
9277 return EL;
9278 }
9279
9280 // If it's not an integer or pointer comparison then compute it the hard way.
9281 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9282}
9283
9284std::optional<ScalarEvolution::ExitLimit>
9285ScalarEvolution::computeExitLimitFromCondFromBinOp(
9286 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9287 bool ControlsOnlyExit, bool AllowPredicates) {
9288 // Check if the controlling expression for this loop is an And or Or.
9289 Value *Op0, *Op1;
9290 bool IsAnd = false;
9291 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9292 IsAnd = true;
9293 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9294 IsAnd = false;
9295 else
9296 return std::nullopt;
9297
9298 // EitherMayExit is true in these two cases:
9299 // br (and Op0 Op1), loop, exit
9300 // br (or Op0 Op1), exit, loop
9301 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9302 ExitLimit EL0 = computeExitLimitFromCondCached(
9303 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9304 AllowPredicates);
9305 ExitLimit EL1 = computeExitLimitFromCondCached(
9306 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9307 AllowPredicates);
9308
9309 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9310 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9311 if (isa<ConstantInt>(Op1))
9312 return Op1 == NeutralElement ? EL0 : EL1;
9313 if (isa<ConstantInt>(Op0))
9314 return Op0 == NeutralElement ? EL1 : EL0;
9315
9316 const SCEV *BECount = getCouldNotCompute();
9317 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9318 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9319 if (EitherMayExit) {
9320 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9321 // Both conditions must be same for the loop to continue executing.
9322 // Choose the less conservative count.
9323 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9324 EL1.ExactNotTaken != getCouldNotCompute()) {
9325 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9326 UseSequentialUMin);
9327 }
9328 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9329 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9330 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9331 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9332 else
9333 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9334 EL1.ConstantMaxNotTaken);
9335 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9336 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9337 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9338 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9339 else
9340 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9341 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9342 } else {
9343 // Both conditions must be same at the same time for the loop to exit.
9344 // For now, be conservative.
9345 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9346 BECount = EL0.ExactNotTaken;
9347 }
9348
9349 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9350 // to be more aggressive when computing BECount than when computing
9351 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9352 // and
9353 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9354 // EL1.ConstantMaxNotTaken to not.
9355 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9356 !isa<SCEVCouldNotCompute>(BECount))
9357 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9358 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9359 SymbolicMaxBECount =
9360 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9361 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9362 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9363}
9364
9365ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9366 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9367 bool AllowPredicates) {
9368 // If the condition was exit on true, convert the condition to exit on false
9369 CmpPredicate Pred;
9370 if (!ExitIfTrue)
9371 Pred = ExitCond->getCmpPredicate();
9372 else
9373 Pred = ExitCond->getInverseCmpPredicate();
9374 const ICmpInst::Predicate OriginalPred = Pred;
9375
9376 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9377 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9378
9379 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9380 AllowPredicates);
9381 if (EL.hasAnyInfo())
9382 return EL;
9383
9384 auto *ExhaustiveCount =
9385 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9386
9387 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9388 return ExhaustiveCount;
9389
9390 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9391 ExitCond->getOperand(1), L, OriginalPred);
9392}
/// Compute the exit limit for a canonicalized (Pred, LHS, RHS) comparison.
/// NOTE(review): this extraction dropped a number of original lines (9406,
/// 9410, 9474-9475, 9479-9480, 9492-9493, 9497-9498, 9529, 9547) -- comments
/// below flag the affected spots; surviving tokens are unchanged.
9393ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9394 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9395 bool ControlsOnlyExit, bool AllowPredicates) {
9396
9397 // Try to evaluate any dependencies out of the loop.
9398 LHS = getSCEVAtScope(LHS, L);
9399 RHS = getSCEVAtScope(RHS, L);
9400
9401 // At this point, we would like to compute how many iterations of the
9402 // loop the predicate will return true for these inputs.
9403 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9404 // If there is a loop-invariant, force it into the RHS.
9405 std::swap(LHS, RHS);
// NOTE(review): dropped line 9406 here -- presumably the predicate swap that
// must accompany the operand swap. Confirm against upstream.
9407 }
9408
// NOTE(review): dropped line 9410 -- presumably the finiteness term of the
// ControllingFiniteLoop conjunction. Confirm against upstream.
9409 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9411 // Simplify the operands before analyzing them.
9412 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9413
9414 // If we have a comparison of a chrec against a constant, try to use value
9415 // ranges to answer this query.
9416 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9417 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9418 if (AddRec->getLoop() == L) {
9419 // Form the constant range.
9420 ConstantRange CompRange =
9421 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9422
9423 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9424 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9425 }
9426
9427 // If this loop must exit based on this condition (or execute undefined
9428 // behaviour), see if we can improve wrap flags. This is essentially
9429 // a must execute style proof.
9430 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9431 // If we can prove the test sequence produced must repeat the same values
9432 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9433 // because if it did, we'd have an infinite (undefined) loop.
9434 // TODO: We can peel off any functions which are invertible *in L*. Loop
9435 // invariant terms are effectively constants for our purposes here.
9436 SCEVUse InnerLHS = LHS;
9437 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9438 InnerLHS = ZExt->getOperand();
9439 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9440 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9441 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9442 /*OrNegative=*/true)) {
9443 auto Flags = AR->getNoWrapFlags();
9444 Flags = setFlags(Flags, SCEV::FlagNW);
9445 SmallVector<SCEVUse> Operands{AR->operands()};
9446 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9447 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9448 }
9449
9450 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9451 // From no-self-wrap, this follows trivially from the fact that every
9452 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9453 // last value before (un)signed wrap. Since we know that last value
9454 // didn't exit, nor will any smaller one.
9455 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9456 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9457 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9458 AR && AR->getLoop() == L && AR->isAffine() &&
9459 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9460 isKnownPositive(AR->getStepRecurrence(*this))) {
9461 auto Flags = AR->getNoWrapFlags();
9462 Flags = setFlags(Flags, WrapType);
9463 SmallVector<SCEVUse> Operands{AR->operands()};
9464 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9465 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9466 }
9467 }
9468 }
9469
9470 switch (Pred) {
9471 case ICmpInst::ICMP_NE: { // while (X != Y)
9472 // Convert to: while (X-Y != 0)
// NOTE(review): in each of the four pointer-typed branches below the
// extractor dropped the lines between the isPointerTy() test and the
// "return LHS/RHS" -- presumably a pointer-to-int conversion whose
// CouldNotCompute result is returned. Confirm against upstream.
9473 if (LHS->getType()->isPointerTy()) {
9476 return LHS;
9477 }
9478 if (RHS->getType()->isPointerTy()) {
9481 return RHS;
9482 }
9483 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9484 AllowPredicates);
9485 if (EL.hasAnyInfo())
9486 return EL;
9487 break;
9488 }
9489 case ICmpInst::ICMP_EQ: { // while (X == Y)
9490 // Convert to: while (X-Y == 0)
9491 if (LHS->getType()->isPointerTy()) {
9494 return LHS;
9495 }
9496 if (RHS->getType()->isPointerTy()) {
9499 return RHS;
9500 }
9501 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9502 if (EL.hasAnyInfo()) return EL;
9503 break;
9504 }
9505 case ICmpInst::ICMP_SLE:
9506 case ICmpInst::ICMP_ULE:
9507 // Since the loop is finite, an invariant RHS cannot include the boundary
9508 // value, otherwise it would loop forever.
9509 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9510 !isLoopInvariant(RHS, L)) {
9511 // Otherwise, perform the addition in a wider type, to avoid overflow.
9512 // If the LHS is an addrec with the appropriate nowrap flag, the
9513 // extension will be sunk into it and the exit count can be analyzed.
9514 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9515 if (!OldType)
9516 break;
9517 // Prefer doubling the bitwidth over adding a single bit to make it more
9518 // likely that we use a legal type.
9519 auto *NewType =
9520 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9521 if (ICmpInst::isSigned(Pred)) {
9522 LHS = getSignExtendExpr(LHS, NewType);
9523 RHS = getSignExtendExpr(RHS, NewType);
9524 } else {
9525 LHS = getZeroExtendExpr(LHS, NewType);
9526 RHS = getZeroExtendExpr(RHS, NewType);
9527 }
9528 }
// Fall through to the strict-inequality handling below (LE/GE behave like
// LT/GT once the boundary concern above is addressed).
9530 [[fallthrough]];
9531 case ICmpInst::ICMP_SLT:
9532 case ICmpInst::ICMP_ULT: { // while (X < Y)
9533 bool IsSigned = ICmpInst::isSigned(Pred);
9534 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9535 AllowPredicates);
9536 if (EL.hasAnyInfo())
9537 return EL;
9538 break;
9539 }
9540 case ICmpInst::ICMP_SGE:
9541 case ICmpInst::ICMP_UGE:
9542 // Since the loop is finite, an invariant RHS cannot include the boundary
9543 // value, otherwise it would loop forever.
9544 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9545 !isLoopInvariant(RHS, L))
9546 break;
9548 [[fallthrough]];
9549 case ICmpInst::ICMP_SGT:
9550 case ICmpInst::ICMP_UGT: { // while (X > Y)
9551 bool IsSigned = ICmpInst::isSigned(Pred);
9552 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9553 AllowPredicates);
9554 if (EL.hasAnyInfo())
9555 return EL;
9556 break;
9557 }
9558 default:
9559 break;
9560 }
9561
9562 return getCouldNotCompute();
9563}
9564
9565ScalarEvolution::ExitLimit
9566ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9567 SwitchInst *Switch,
9568 BasicBlock *ExitingBlock,
9569 bool ControlsOnlyExit) {
9570 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9571
9572 // Give up if the exit is the default dest of a switch.
9573 if (Switch->getDefaultDest() == ExitingBlock)
9574 return getCouldNotCompute();
9575
9576 assert(L->contains(Switch->getDefaultDest()) &&
9577 "Default case must not exit the loop!");
9578 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9579 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9580
9581 // while (X != Y) --> while (X-Y != 0)
9582 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9583 if (EL.hasAnyInfo())
9584 return EL;
9585
9586 return getCouldNotCompute();
9587}
9588
// Evaluate an add-recurrence at a constant iteration number and return the
// folded constant result.
// NOTE(review): the extractor dropped original lines 9590 (the rest of this
// signature -- presumably the AddRec and ConstantInt parameters used below)
// and 9594 (the condition half of the assert). Confirm against upstream.
9589static ConstantInt *
9591 ScalarEvolution &SE) {
9592 const SCEV *InVal = SE.getConstant(C);
9593 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9595 "Evaluation of SCEV at constant didn't fold correctly?");
9596 return cast<SCEVConstant>(Val)->getValue();
9597}
9598
/// Bound the trip count of a loop whose exit compares a "shift recurrence"
/// (an IV updated by lshr/ashr/shl each iteration) against a constant: such
/// recurrences stabilize to 0 or -1 within bitwidth-many iterations.
/// NOTE(review): the extractor dropped original lines 9646, 9685 (presumably
/// local Instruction::BinaryOps declarations for OpC / OpCode used below) and
/// 9735 (the upper-bound expression). Confirm against upstream.
9599ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9600 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
// Only constant right-hand sides are handled.
9601 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9602 if (!RHS)
9603 return getCouldNotCompute();
9604
9605 const BasicBlock *Latch = L->getLoopLatch();
9606 if (!Latch)
9607 return getCouldNotCompute();
9608
9609 const BasicBlock *Predecessor = L->getLoopPredecessor();
9610 if (!Predecessor)
9611 return getCouldNotCompute();
9612
9613 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9614 // Return LHS in OutLHS and shift_opt in OutOpCode.
9615 auto MatchPositiveShift =
9616 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9617
9618 using namespace PatternMatch;
9619
9620 ConstantInt *ShiftAmt;
9621 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9622 OutOpCode = Instruction::LShr;
9623 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9624 OutOpCode = Instruction::AShr;
9625 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9626 OutOpCode = Instruction::Shl;
9627 else
9628 return false;
9629
// A zero shift amount would never make progress toward the stable value.
9630 return ShiftAmt->getValue().isStrictlyPositive();
9631 };
9632
9633 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9634 //
9635 // loop:
9636 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9637 // %iv.shifted = lshr i32 %iv, <positive constant>
9638 //
9639 // Return true on a successful match. Return the corresponding PHI node (%iv
9640 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9641 auto MatchShiftRecurrence =
9642 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9643 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9644
9645 {
9647 Value *V;
9648
9649 // If we encounter a shift instruction, "peel off" the shift operation,
9650 // and remember that we did so. Later when we inspect %iv's backedge
9651 // value, we will make sure that the backedge value uses the same
9652 // operation.
9653 //
9654 // Note: the peeled shift operation does not have to be the same
9655 // instruction as the one feeding into the PHI's backedge value. We only
9656 // really care about it being the same *kind* of shift instruction --
9657 // that's all that is required for our later inferences to hold.
9658 if (MatchPositiveShift(LHS, V, OpC)) {
9659 PostShiftOpCode = OpC;
9660 LHS = V;
9661 }
9662 }
9663
9664 PNOut = dyn_cast<PHINode>(LHS);
9665 if (!PNOut || PNOut->getParent() != L->getHeader())
9666 return false;
9667
9668 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9669 Value *OpLHS;
9670
9671 return
9672 // The backedge value for the PHI node must be a shift by a positive
9673 // amount
9674 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9675
9676 // of the PHI node itself
9677 OpLHS == PNOut &&
9678
9679 // and the kind of shift should be match the kind of shift we peeled
9680 // off, if any.
9681 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9682 };
9683
9684 PHINode *PN;
9686 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9687 return getCouldNotCompute();
9688
9689 const DataLayout &DL = getDataLayout();
9690
9691 // The key rationale for this optimization is that for some kinds of shift
9692 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9693 // within a finite number of iterations. If the condition guarding the
9694 // backedge (in the sense that the backedge is taken if the condition is true)
9695 // is false for the value the shift recurrence stabilizes to, then we know
9696 // that the backedge is taken only a finite number of times.
9697
9698 ConstantInt *StableValue = nullptr;
9699 switch (OpCode) {
9700 default:
9701 llvm_unreachable("Impossible case!");
9702
9703 case Instruction::AShr: {
9704 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9705 // bitwidth(K) iterations.
9706 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9707 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9708 Predecessor->getTerminator(), &DT);
9709 auto *Ty = cast<IntegerType>(RHS->getType());
// Only proceed if the sign of the initial value is known, so the stable
// value (0 for non-negative, -1 for negative) is determined.
9710 if (Known.isNonNegative())
9711 StableValue = ConstantInt::get(Ty, 0);
9712 else if (Known.isNegative())
9713 StableValue = ConstantInt::get(Ty, -1, true);
9714 else
9715 return getCouldNotCompute();
9716
9717 break;
9718 }
9719 case Instruction::LShr:
9720 case Instruction::Shl:
9721 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9722 // stabilize to 0 in at most bitwidth(K) iterations.
9723 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9724 break;
9725 }
9726
// Evaluate the exit comparison at the stable value: if it folds to false,
// the backedge can be taken at most bitwidth-many times.
9727 auto *Result =
9728 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9729 assert(Result->getType()->isIntegerTy(1) &&
9730 "Otherwise cannot be an operand to a branch instruction");
9731
9732 if (Result->isNullValue()) {
9733 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9734 const SCEV *UpperBound =
9736 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9737 }
9738
9739 return getCouldNotCompute();
9740}
9741
9742/// Return true if we can constant fold an instruction of the specified type,
9743/// assuming that all operands were constants.
9744static bool CanConstantFold(const Instruction *I) {
// NOTE(review): the extractor dropped original lines 9745-9747 -- presumably
// the isa<> tests (binary operators, casts, selects, ...) whose success falls
// through to this "return true". Confirm against upstream.
9748 return true;
9749
// Calls are foldable only when the callee is a known-foldable function.
9750 if (const CallInst *CI = dyn_cast<CallInst>(I))
9751 if (const Function *F = CI->getCalledFunction())
9752 return canConstantFoldCallTo(CI, F);
9753 return false;
9754}
9755
9756/// Determine whether this instruction can constant evolve within this loop
9757/// assuming its operands can all constant evolve.
9758static bool canConstantEvolve(Instruction *I, const Loop *L) {
9759 // An instruction outside of the loop can't be derived from a loop PHI.
9760 if (!L->contains(I)) return false;
9761
9762 if (isa<PHINode>(I)) {
9763 // We don't currently keep track of the control flow needed to evaluate
9764 // PHIs, so we cannot handle PHIs inside of loops.
9765 return L->getHeader() == I->getParent();
9766 }
9767
9768 // If we won't be able to constant fold this expression even if the operands
9769 // are constants, bail early.
9770 return CanConstantFold(I);
9771}
9772
9773/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9774/// recursing through each instruction operand until reaching a loop header phi.
// NOTE(review): the extractor dropped original lines 9776-9777 (the rest of
// this signature -- presumably UseInst, L and the PHIMap memoization map used
// below), 9779 (the depth-limit condition guarding the early return) and 9788
// (presumably the dyn_cast producing OpInst). Confirm against upstream.
9775static PHINode *
9778 unsigned Depth) {
9780 return nullptr;
9781
9782 // Otherwise, we can evaluate this instruction if all of its operands are
9783 // constant or derived from a PHI node themselves.
9784 PHINode *PHI = nullptr;
9785 for (Value *Op : UseInst->operands()) {
9786 if (isa<Constant>(Op)) continue;
9787
9789 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9790
9791 PHINode *P = dyn_cast<PHINode>(OpInst);
9792 if (!P)
9793 // If this operand is already visited, reuse the prior result.
9794 // We may have P != PHI if this is the deepest point at which the
9795 // inconsistent paths meet.
9796 P = PHIMap.lookup(OpInst);
9797 if (!P) {
9798 // Recurse and memoize the results, whether a phi is found or not.
9799 // This recursive call invalidates pointers into PHIMap.
9800 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9801 PHIMap[OpInst] = P;
9802 }
9803 if (!P)
9804 return nullptr; // Not evolving from PHI
9805 if (PHI && PHI != P)
9806 return nullptr; // Evolving from multiple different PHIs.
9807 PHI = P;
9808 }
9809 // This is a expression evolving from a constant PHI!
9810 return PHI;
9811}
9812
9813/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9814/// in the loop that V is derived from. We allow arbitrary operations along the
9815/// way, but the operands of an operation must either be constants or a value
9816/// derived from a constant PHI. If this expression does not fit with these
9817/// constraints, return null.
// NOTE(review): the extractor dropped original lines 9818-9819 (the function
// signature and, presumably, the dyn_cast producing I from the input value)
// and 9826 (presumably the local PHIMap memoization map passed below).
// Confirm against upstream.
9820 if (!I || !canConstantEvolve(I, L)) return nullptr;
9821
// A header PHI is itself the answer.
9822 if (PHINode *PN = dyn_cast<PHINode>(I))
9823 return PN;
9824
9825 // Record non-constant instructions contained by the loop.
9827 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9828}
9829
/// EvaluateExpression - Given an expression that passes the
/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
/// in the loop has the value PHIVal. If we can't fold this expression for some
/// reason, return null.
                                    const DataLayout &DL,
                                    const TargetLibraryInfo *TLI) {
  // Convenient constant check, but redundant for recursive calls.
  if (Constant *C = dyn_cast<Constant>(V)) return C;
  if (!I) return nullptr;

  // Reuse a result memoized earlier in this iteration's evaluation.
  if (Constant *C = Vals.lookup(I)) return C;

  // An instruction inside the loop depends on a value outside the loop that we
  // weren't given a mapping for, or a value such as a call inside the loop.
  if (!canConstantEvolve(I, L)) return nullptr;

  // An unmapped PHI can be due to a branch or another loop inside this loop,
  // or due to this not being the initial iteration through a loop where we
  // couldn't compute the evolution of this particular PHI last time.
  if (isa<PHINode>(I)) return nullptr;

  std::vector<Constant*> Operands(I->getNumOperands());

  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
    if (!Operand) {
      // Non-instruction operands must already be constants.
      Operands[i] = dyn_cast<Constant>(I->getOperand(i));
      if (!Operands[i]) return nullptr;
      continue;
    }
    // Recursively evaluate, memoizing the result even on failure.
    Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
    Vals[Operand] = C;
    if (!C) return nullptr;
    Operands[i] = C;
  }

  return ConstantFoldInstOperands(I, Operands, DL, TLI,
                                  /*AllowNonDeterministic=*/false);
}
9872
9873
// If every incoming value to PN except the one for BB is a specific Constant,
// return that, else return nullptr.
  Constant *IncomingVal = nullptr;

  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    // Skip the edge we were asked to ignore (typically the loop latch).
    if (PN->getIncomingBlock(i) == BB)
      continue;

    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
    if (!CurrentVal)
      return nullptr;

    // All remaining edges must agree on a single constant.
    if (IncomingVal != CurrentVal) {
      if (IncomingVal)
        return nullptr;
      IncomingVal = CurrentVal;
    }
  }

  return IncomingVal;
}
9896
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
/// involving constants, fold it.
Constant *
ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
                                                   const APInt &BEs,
                                                   const Loop *L) {
  // Results (including failures, cached as nullptr) are memoized per PHI.
  auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
  if (!Inserted)
    return I->second;

    return nullptr; // Not going to evaluate it.

  Constant *&RetVal = I->second;

  DenseMap<Instruction *, Constant *> CurrentIterVals;
  BasicBlock *Header = L->getHeader();
  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");

  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return nullptr;

  // Seed iteration 0 with every header PHI whose start value is a constant.
  for (PHINode &PHI : Header->phis()) {
    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
      CurrentIterVals[&PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return RetVal = nullptr;

  Value *BEValue = PN->getIncomingValueForBlock(Latch);

  // Execute the loop symbolically to determine the exit value.
  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");

  unsigned NumIterations = BEs.getZExtValue(); // must be in range
  unsigned IterationNum = 0;
  const DataLayout &DL = getDataLayout();
  for (; ; ++IterationNum) {
    if (IterationNum == NumIterations)
      return RetVal = CurrentIterVals[PN]; // Got exit value!

    // Compute the value of the PHIs for the next iteration.
    // EvaluateExpression adds non-phi values to the CurrentIterVals map.
    DenseMap<Instruction *, Constant *> NextIterVals;
    Constant *NextPHI =
        EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
    if (!NextPHI)
      return nullptr; // Couldn't evaluate!
    NextIterVals[PN] = NextPHI;

    bool StoppedEvolving = NextPHI == CurrentIterVals[PN];

    // Also evaluate the other PHI nodes. However, we don't get to stop if we
    // cease to be able to evaluate one of them or if they stop evolving,
    // because that doesn't necessarily prevent us from computing PN.
    for (const auto &I : CurrentIterVals) {
      PHINode *PHI = dyn_cast<PHINode>(I.first);
      if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
      PHIsToCompute.emplace_back(PHI, I.second);
    }
    // We use two distinct loops because EvaluateExpression may invalidate any
    // iterators into CurrentIterVals.
    for (const auto &I : PHIsToCompute) {
      PHINode *PHI = I.first;
      Constant *&NextPHI = NextIterVals[PHI];
      if (!NextPHI) { // Not already computed.
        Value *BEValue = PHI->getIncomingValueForBlock(Latch);
        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
      }
      if (NextPHI != I.second)
        StoppedEvolving = false;
    }

    // If all entries in CurrentIterVals == NextIterVals then we can stop
    // iterating, the loop can't continue to change.
    if (StoppedEvolving)
      return RetVal = CurrentIterVals[PN];

    CurrentIterVals.swap(NextIterVals);
  }
}
9983
9984const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9985 Value *Cond,
9986 bool ExitWhen) {
9987 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9988 if (!PN) return getCouldNotCompute();
9989
9990 // If the loop is canonicalized, the PHI will have exactly two entries.
9991 // That's the only form we support here.
9992 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9993
9994 DenseMap<Instruction *, Constant *> CurrentIterVals;
9995 BasicBlock *Header = L->getHeader();
9996 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9997
9998 BasicBlock *Latch = L->getLoopLatch();
9999 assert(Latch && "Should follow from NumIncomingValues == 2!");
10000
10001 for (PHINode &PHI : Header->phis()) {
10002 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10003 CurrentIterVals[&PHI] = StartCST;
10004 }
10005 if (!CurrentIterVals.count(PN))
10006 return getCouldNotCompute();
10007
10008 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10009 // the loop symbolically to determine when the condition gets a value of
10010 // "ExitWhen".
10011 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10012 const DataLayout &DL = getDataLayout();
10013 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10014 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10015 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10016
10017 // Couldn't symbolically evaluate.
10018 if (!CondVal) return getCouldNotCompute();
10019
10020 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10021 ++NumBruteForceTripCountsComputed;
10022 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10023 }
10024
10025 // Update all the PHI nodes for the next iteration.
10026 DenseMap<Instruction *, Constant *> NextIterVals;
10027
10028 // Create a list of which PHIs we need to compute. We want to do this before
10029 // calling EvaluateExpression on them because that may invalidate iterators
10030 // into CurrentIterVals.
10031 SmallVector<PHINode *, 8> PHIsToCompute;
10032 for (const auto &I : CurrentIterVals) {
10033 PHINode *PHI = dyn_cast<PHINode>(I.first);
10034 if (!PHI || PHI->getParent() != Header) continue;
10035 PHIsToCompute.push_back(PHI);
10036 }
10037 for (PHINode *PHI : PHIsToCompute) {
10038 Constant *&NextPHI = NextIterVals[PHI];
10039 if (NextPHI) continue; // Already computed!
10040
10041 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10042 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10043 }
10044 CurrentIterVals.swap(NextIterVals);
10045 }
10046
10047 // Too many iterations were needed to evaluate.
10048 return getCouldNotCompute();
10049}
10050
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
      ValuesAtScopes[V];
  // Check to see if we've folded this expression at this loop before.
  for (auto &LS : Values)
    if (LS.first == L)
      return LS.second ? LS.second : V;

  // Reserve a cache slot (nullptr marks "in progress" for recursive queries).
  Values.emplace_back(L, nullptr);

  // Otherwise compute it.
  const SCEV *C = computeSCEVAtScope(V, L);
  // computeSCEVAtScope may have recursed and grown the cache vector, so
  // re-scan (from the back) for the slot reserved above before filling it in.
  for (auto &LS : reverse(ValuesAtScopes[V]))
    if (LS.first == L) {
      LS.second = C;
      if (!isa<SCEVConstant>(C))
        ValuesAtScopesUsers[C].push_back({L, V});
      break;
    }
  return C;
}
10072
/// This builds up a Constant using the ConstantExpr interface. That way, we
/// will return Constants for objects which aren't represented by a
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant.
  switch (V->getSCEVType()) {
  case scCouldNotCompute:
  case scAddRecExpr:
  case scVScale:
    return nullptr;
  case scConstant:
    return cast<SCEVConstant>(V)->getValue();
  case scUnknown:
    // A SCEVUnknown wraps an arbitrary Value; only constant Values qualify.
    return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
  case scPtrToAddr: {
    if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
      return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());

    return nullptr;
  }
  case scPtrToInt: {
    if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
      return ConstantExpr::getPtrToInt(CastOp, P2I->getType());

    return nullptr;
  }
  case scTruncate: {
    if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
      return ConstantExpr::getTrunc(CastOp, ST->getType());
    return nullptr;
  }
  case scAddExpr: {
    const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
    Constant *C = nullptr;
    for (const SCEV *Op : SA->operands()) {
      if (!OpC)
        return nullptr;
      // The first foldable operand simply becomes the accumulator.
      if (!C) {
        C = OpC;
        continue;
      }
      assert(!C->getType()->isPointerTy() &&
             "Can only have one pointer, and it must be last");
      if (OpC->getType()->isPointerTy()) {
        // The offsets have been converted to bytes. We can add bytes using
        // an i8 GEP.
        C = ConstantExpr::getPtrAdd(OpC, C);
      } else {
        C = ConstantExpr::getAdd(C, OpC);
      }
    }
    return C;
  }
  case scMulExpr:
  case scSignExtend:
  case scZeroExtend:
  case scUDivExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
    return nullptr;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
10143
const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
                                             SmallVectorImpl<SCEVUse> &NewOps) {
  // Rebuild S with NewOps substituted for its original operand list,
  // preserving the expression kind and (where applicable) its wrap flags.
  switch (S->getSCEVType()) {
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
    return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
  case scAddRecExpr: {
    auto *AddRec = cast<SCEVAddRecExpr>(S);
    return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
  }
  case scAddExpr:
    return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
  case scMulExpr:
    return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
  case scUDivExpr:
    return getUDivExpr(NewOps[0], NewOps[1]);
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return getMinMaxExpr(S->getSCEVType(), NewOps);
    return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
  case scConstant:
  case scVScale:
  case scUnknown:
    // Leaf expressions have no operands to replace.
    return S;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
10179
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
  switch (V->getSCEVType()) {
  case scConstant:
  case scVScale:
    // Constants are invariant at every scope.
    return V;
  case scAddRecExpr: {
    // If this is a loop recurrence for a loop that does not contain L, then we
    // are dealing with the final value computed by the loop.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
    // First, attempt to evaluate each operand.
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
      if (OpAtScope == AddRec->getOperand(i))
        continue;

      // Okay, at least one of these operands is loop variant but might be
      // foldable. Build a new instance of the folded commutative expression.
      NewOps.reserve(AddRec->getNumOperands());
      append_range(NewOps, AddRec->operands().take_front(i));
      NewOps.push_back(OpAtScope);
      for (++i; i != e; ++i)
        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));

      const SCEV *FoldedRec = getAddRecExpr(
          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
      // The addrec may be folded to a nonrecurrence, for example, if the
      // induction variable is multiplied by zero after constant folding. Go
      // ahead and return the folded value.
      if (!AddRec)
        return FoldedRec;
      break;
    }

    // If the scope is outside the addrec's loop, evaluate it by using the
    // loop exit value of the addrec.
    if (!AddRec->getLoop()->contains(L)) {
      // To evaluate this recurrence, we need to know how many times the AddRec
      // loop iterates. Compute this now.
      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
      if (BackedgeTakenCount == getCouldNotCompute())
        return AddRec;

      // Then, evaluate the AddRec.
      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
    }

    return AddRec;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    ArrayRef<SCEVUse> Ops = V->operands();
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
      if (OpAtScope != Ops[i].getPointer()) {
        // Okay, at least one of these operands is loop variant but might be
        // foldable. Build a new instance of the folded commutative expression.
        NewOps.reserve(Ops.size());
        append_range(NewOps, Ops.take_front(i));
        NewOps.push_back(OpAtScope);

        for (++i; i != e; ++i) {
          OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
          NewOps.push_back(OpAtScope);
        }

        return getWithOperands(V, NewOps);
      }
    }
    // If we got here, all operands are loop invariant.
    return V;
  }
  case scUnknown: {
    // If this instruction is evolved from a constant-evolving PHI, compute the
    // exit value from the loop without using SCEVs.
    const SCEVUnknown *SU = cast<SCEVUnknown>(V);
    if (!I)
      return V; // This is some other type of SCEVUnknown, just return it.

    if (PHINode *PN = dyn_cast<PHINode>(I)) {
      const Loop *CurrLoop = this->LI[I->getParent()];
      // Looking for loop exit value.
      if (CurrLoop && CurrLoop->getParentLoop() == L &&
          PN->getParent() == CurrLoop->getHeader()) {
        // Okay, there is no closed form solution for the PHI node. Check
        // to see if the loop that contains it has a known backedge-taken
        // count. If so, we may be able to force computation of the exit
        // value.
        const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
        // This trivial case can show up in some degenerate cases where
        // the incoming IR has not yet been fully simplified.
        if (BackedgeTakenCount->isZero()) {
          Value *InitValue = nullptr;
          bool MultipleInitValues = false;
          // With zero backedges taken, the PHI's value is whatever flows in
          // from outside the loop; require that value to be unique.
          for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
            if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
              if (!InitValue)
                InitValue = PN->getIncomingValue(i);
              else if (InitValue != PN->getIncomingValue(i)) {
                MultipleInitValues = true;
                break;
              }
            }
          }
          if (!MultipleInitValues && InitValue)
            return getSCEV(InitValue);
        }
        // Do we have a loop invariant value flowing around the backedge
        // for a loop which must execute the backedge?
        if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
            isKnownNonZero(BackedgeTakenCount) &&
            PN->getNumIncomingValues() == 2) {

          unsigned InLoopPred =
              CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
          Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
          if (CurrLoop->isLoopInvariant(BackedgeVal))
            return getSCEV(BackedgeVal);
        }
        if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
          // Okay, we know how many times the containing loop executes. If
          // this is a constant evolving PHI node, get the final value at
          // the specified iteration number.
          Constant *RV =
              getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
          if (RV)
            return getSCEV(RV);
        }
      }
    }

    // Okay, this is an expression that we cannot symbolically evaluate
    // into a SCEV. Check to see if it's possible to symbolically evaluate
    // the arguments into constants, and if so, try to constant propagate the
    // result. This is particularly useful for computing loop exit values.
    if (!CanConstantFold(I))
      return V; // This is some other type of SCEVUnknown, just return it.

    SmallVector<Constant *, 4> Operands;
    Operands.reserve(I->getNumOperands());
    bool MadeImprovement = false;
    for (Value *Op : I->operands()) {
      if (Constant *C = dyn_cast<Constant>(Op)) {
        Operands.push_back(C);
        continue;
      }

      // If any of the operands is non-constant and if they are
      // non-integer and non-pointer, don't even try to analyze them
      // with scev techniques.
      if (!isSCEVable(Op->getType()))
        return V;

      const SCEV *OrigV = getSCEV(Op);
      const SCEV *OpV = getSCEVAtScope(OrigV, L);
      MadeImprovement |= OrigV != OpV;

      if (!C)
        return V;
      assert(C->getType() == Op->getType() && "Type mismatch");
      Operands.push_back(C);
    }

    // Check to see if getSCEVAtScope actually made an improvement.
    if (!MadeImprovement)
      return V; // This is some other type of SCEVUnknown, just return it.

    Constant *C = nullptr;
    const DataLayout &DL = getDataLayout();
    C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
                                 /*AllowNonDeterministic=*/false);
    if (!C)
      return V;
    return getSCEV(C);
  }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV type!");
}
10379
  // Value-based convenience overload: translate V to a SCEV first, then
  // evaluate that expression at scope L.
  return getSCEVAtScope(getSCEV(V), L);
}
10383
// Peel off zext/sext wrappers: these casts are injective, so an equation over
// the stripped operand has the same roots as the original expression.
const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
    return stripInjectiveFunctions(ZExt->getOperand());
    return stripInjectiveFunctions(SExt->getOperand());
  return S;
}
10391
/// Finds the minimum unsigned root of the following equation:
///
/// A * X = B (mod N)
///
/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
/// A and B isn't important.
///
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
static const SCEV *
                             ScalarEvolution &SE, const Loop *L) {
  uint32_t BW = A.getBitWidth();
  assert(BW == SE.getTypeSizeInBits(B->getType()));
  assert(A != 0 && "A must be non-zero.");

  // 1. D = gcd(A, N)
  //
  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity
  uint32_t Mult2 = A.countr_zero();
  // D = 2^Mult2

  // 2. Check if B is divisible by D.
  //
  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
  // is not less than multiplicity of this prime factor for D.
  unsigned MinTZ = SE.getMinTrailingZeros(B);
  // Try again with the terminator of the loop predecessor for context-specific
  // result, if MinTZ is too small.
  if (MinTZ < Mult2 && L->getLoopPredecessor())
    MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
  if (MinTZ < Mult2) {
    // Check if we can prove there's no remainder using URem.
    const SCEV *URem =
        SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
    const SCEV *Zero = SE.getZero(B->getType());
    if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
      // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
      if (!Predicates)
        return SE.getCouldNotCompute();

      // Avoid adding a predicate that is known to be false.
      if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
        return SE.getCouldNotCompute();
      Predicates->push_back(SE.getEqualPredicate(URem, Zero));
    }
  }

  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
  // modulo (N / D).
  //
  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
  // (N / D) in general. The inverse itself always fits into BW bits, though,
  // so we immediately truncate it.
  APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
  APInt I = AD.multiplicativeInverse().zext(BW);

  // 4. Compute the minimum unsigned root of the equation:
  // I * (B / D) mod (N / D)
  // To simplify the computation, we factor out the divide by D:
  // (I * B mod N) / D
  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}
10457
/// For a given quadratic addrec, generate coefficients of the corresponding
/// quadratic equation, multiplied by a common value to ensure that they are
/// integers.
/// The returned value is a tuple { A, B, C, M, BitWidth }, where
/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
/// were multiplied by, and BitWidth is the bit width of the original addrec
/// coefficients.
/// This function returns std::nullopt if the addrec coefficients are not
/// compile-time constants.
static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
  // A quadratic addrec has exactly three operands: {L,+,M,+,N}.
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
  LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
                    << *AddRec << '\n');

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
    return std::nullopt;
  }

  APInt L = LC->getAPInt();
  APInt M = MC->getAPInt();
  APInt N = NC->getAPInt();
  assert(!N.isZero() && "This is not a quadratic addrec");

  unsigned BitWidth = LC->getAPInt().getBitWidth();
  unsigned NewWidth = BitWidth + 1;
  LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
                    << BitWidth << '\n');
  // The sign-extension (as opposed to a zero-extension) here matches the
  // extension used in SolveQuadraticEquationWrap (with the same motivation).
  N = N.sext(NewWidth);
  M = M.sext(NewWidth);
  L = L.sext(NewWidth);

  // The increments are M, M+N, M+2N, ..., so the accumulated values are
  // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
  // L+M, L+2M+N, L+3M+3N, ...
  // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
  //
  // The equation Acc = 0 is then
  // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
  // In a quadratic form it becomes:
  // N n^2 + (2M-N) n + 2L = 0.

  APInt A = N;
  APInt B = 2 * M - A;
  APInt C = 2 * L;
  APInt T = APInt(NewWidth, 2);
  LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
                    << "x + " << C << ", coeff bw: " << NewWidth
                    << ", multiplied by " << T << '\n');
  return std::make_tuple(A, B, C, T, BitWidth);
}
10516
10517/// Helper function to compare optional APInts:
10518/// (a) if X and Y both exist, return min(X, Y),
10519/// (b) if neither X nor Y exist, return std::nullopt,
10520/// (c) if exactly one of X and Y exists, return that value.
10521static std::optional<APInt> MinOptional(std::optional<APInt> X,
10522 std::optional<APInt> Y) {
10523 if (X && Y) {
10524 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10525 APInt XW = X->sext(W);
10526 APInt YW = Y->sext(W);
10527 return XW.slt(YW) ? *X : *Y;
10528 }
10529 if (!X && !Y)
10530 return std::nullopt;
10531 return X ? *X : *Y;
10532}
10533
/// Helper function to truncate an optional APInt to a given BitWidth.
/// When solving addrec-related equations, it is preferable to return a value
/// that has the same bit width as the original addrec's coefficients. If the
/// solution fits in the original bit width, truncate it (except for i1).
/// Returning a value of a different bit width may inhibit some optimizations.
///
/// In general, a solution to a quadratic equation generated from an addrec
/// may require BW+1 bits, where BW is the bit width of the addrec's
/// coefficients. The reason is that the coefficients of the quadratic
/// equation are BW+1 bits wide (to avoid truncation when converting from
/// the addrec to the equation).
static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
                                            unsigned BitWidth) {
  if (!X)
    return std::nullopt;
  // Truncate only when the value can be represented in BitWidth bits;
  // otherwise return the wider value unchanged.
  unsigned W = X->getBitWidth();
    return X->trunc(BitWidth);
  return X;
}
10554
/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
/// iterations. The values L, M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
/// where BW is the bit width of the addrec's coefficients.
/// If the calculated value is a BW-bit integer (for BW > 1), it will be
/// returned as such, otherwise the bit width of the returned value may
/// be greater than BW.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
/// like x^2 = 5, no integer solutions exist, in other cases an integer
/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
static std::optional<APInt>
  APInt A, B, C, M;
  unsigned BitWidth;
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  std::tie(A, B, C, M, BitWidth) = *T;
  LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
  std::optional<APInt> X =
  if (!X)
    return std::nullopt;

  // Verify the candidate root by evaluating the chrec at X; reject it if the
  // chrec is nonzero there (the wrap-aware solver can over-approximate).
  ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
  ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
  if (!V->isZero())
    return std::nullopt;

  return TruncIfPossible(X, BitWidth);
}
10591
10592/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10593/// iterations. The values M, N are assumed to be signed, and they
10594/// should all have the same bit widths.
10595/// Find the least n such that c(n) does not belong to the given range,
10596/// while c(n-1) does.
10597///
10598/// This function returns std::nullopt if
10599/// (a) the addrec coefficients are not constant, or
10600/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10601/// bounds of the range.
10602static std::optional<APInt>
10604 const ConstantRange &Range, ScalarEvolution &SE) {
10605 assert(AddRec->getOperand(0)->isZero() &&
10606 "Starting value of addrec should be 0");
10607 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10608 << Range << ", addrec " << *AddRec << '\n');
10609 // This case is handled in getNumIterationsInRange. Here we can assume that
10610 // we start in the range.
10611 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10612 "Addrec's initial value should be in range");
10613
10614 APInt A, B, C, M;
10615 unsigned BitWidth;
10616 auto T = GetQuadraticEquation(AddRec);
10617 if (!T)
10618 return std::nullopt;
10619
10620 // Be careful about the return value: there can be two reasons for not
10621 // returning an actual number. First, if no solutions to the equations
10622 // were found, and second, if the solutions don't leave the given range.
10623 // The first case means that the actual solution is "unknown", the second
10624 // means that it's known, but not valid. If the solution is unknown, we
10625 // cannot make any conclusions.
10626 // Return a pair: the optional solution and a flag indicating if the
10627 // solution was found.
10628 auto SolveForBoundary =
10629 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10630 // Solve for signed overflow and unsigned overflow, pick the lower
10631 // solution.
10632 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10633 << Bound << " (before multiplying by " << M << ")\n");
10634 Bound *= M; // The quadratic equation multiplier.
10635
10636 std::optional<APInt> SO;
10637 if (BitWidth > 1) {
10638 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10639 "signed overflow\n");
10641 }
10642 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10643 "unsigned overflow\n");
10644 std::optional<APInt> UO =
10646
10647 auto LeavesRange = [&] (const APInt &X) {
10648 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10649 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10650 if (Range.contains(V0->getValue()))
10651 return false;
10652 // X should be at least 1, so X-1 is non-negative.
10653 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10654 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10655 if (Range.contains(V1->getValue()))
10656 return true;
10657 return false;
10658 };
10659
10660 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10661 // can be a solution, but the function failed to find it. We cannot treat it
10662 // as "no solution".
10663 if (!SO || !UO)
10664 return {std::nullopt, false};
10665
10666 // Check the smaller value first to see if it leaves the range.
10667 // At this point, both SO and UO must have values.
10668 std::optional<APInt> Min = MinOptional(SO, UO);
10669 if (LeavesRange(*Min))
10670 return { Min, true };
10671 std::optional<APInt> Max = Min == SO ? UO : SO;
10672 if (LeavesRange(*Max))
10673 return { Max, true };
10674
10675 // Solutions were found, but were eliminated, hence the "true".
10676 return {std::nullopt, true};
10677 };
10678
10679 std::tie(A, B, C, M, BitWidth) = *T;
10680 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10681 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10682 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10683 auto SL = SolveForBoundary(Lower);
10684 auto SU = SolveForBoundary(Upper);
10685 // If any of the solutions was unknown, no meaninigful conclusions can
10686 // be made.
10687 if (!SL.second || !SU.second)
10688 return std::nullopt;
10689
10690 // Claim: The correct solution is not some value between Min and Max.
10691 //
10692 // Justification: Assuming that Min and Max are different values, one of
10693 // them is when the first signed overflow happens, the other is when the
10694 // first unsigned overflow happens. Crossing the range boundary is only
10695 // possible via an overflow (treating 0 as a special case of it, modeling
10696 // an overflow as crossing k*2^W for some k).
10697 //
10698 // The interesting case here is when Min was eliminated as an invalid
10699 // solution, but Max was not. The argument is that if there was another
10700 // overflow between Min and Max, it would also have been eliminated if
10701 // it was considered.
10702 //
10703 // For a given boundary, it is possible to have two overflows of the same
10704 // type (signed/unsigned) without having the other type in between: this
10705 // can happen when the vertex of the parabola is between the iterations
10706 // corresponding to the overflows. This is only possible when the two
10707 // overflows cross k*2^W for the same k. In such case, if the second one
10708 // left the range (and was the first one to do so), the first overflow
10709 // would have to enter the range, which would mean that either we had left
10710 // the range before or that we started outside of it. Both of these cases
10711 // are contradictions.
10712 //
10713 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10714 // solution is not some value between the Max for this boundary and the
10715 // Min of the other boundary.
10716 //
10717 // Justification: Assume that we had such Max_A and Min_B corresponding
10718 // to range boundaries A and B and such that Max_A < Min_B. If there was
10719 // a solution between Max_A and Min_B, it would have to be caused by an
10720 // overflow corresponding to either A or B. It cannot correspond to B,
10721 // since Min_B is the first occurrence of such an overflow. If it
10722 // corresponded to A, it would have to be either a signed or an unsigned
10723 // overflow that is larger than both eliminated overflows for A. But
10724 // between the eliminated overflows and this overflow, the values would
10725 // cover the entire value space, thus crossing the other boundary, which
10726 // is a contradiction.
10727
10728 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10729}
10730
/// Compute an ExitLimit for a loop exit controlled by a "V != 0" test (the
/// caller has already rewritten an "x != y" exit test as V = x-y).  Returns
/// getCouldNotCompute() whenever no backedge-taken count can be derived.
10731ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10732 const Loop *L,
10733 bool ControlsOnlyExit,
10734 bool AllowPredicates) {
10735
10736 // This is only used for loops with a "x != y" exit test. The exit condition
10737 // is now expressed as a single expression, V = x-y. So the exit test is
10738 // effectively V != 0. We know and take advantage of the fact that this
10739 // expression only being used in a comparison by zero context.
10740
10742 // If the value is a constant
10743 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10744 // If the value is already zero, the branch will execute zero times.
10745 if (C->getValue()->isZero()) return C;
10746 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10747 }
10748
10749 const SCEVAddRecExpr *AddRec =
10750 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10751
10752 if (!AddRec && AllowPredicates)
10753 // Try to make this an AddRec using runtime tests, in the first X
10754 // iterations of this loop, where X is the SCEV expression found by the
10755 // algorithm below.
10756 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10757
  // Only recurrences attached to this loop can be solved here.
10758 if (!AddRec || AddRec->getLoop() != L)
10759 return getCouldNotCompute();
10760
10761 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10762 // the quadratic equation to solve it.
10763 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10764 // We can only use this value if the chrec ends up with an exact zero
10765 // value at this index. When solving for "X*X != 5", for example, we
10766 // should not accept a root of 2.
10767 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10768 const auto *R = cast<SCEVConstant>(getConstant(*S));
10769 return ExitLimit(R, R, R, false, Predicates);
10770 }
10771 return getCouldNotCompute();
10772 }
10773
10774 // Otherwise we can only handle this if it is affine.
10775 if (!AddRec->isAffine())
10776 return getCouldNotCompute();
10777
10778 // If this is an affine expression, the execution count of this branch is
10779 // the minimum unsigned root of the following equation:
10780 //
10781 // Start + Step*N = 0 (mod 2^BW)
10782 //
10783 // equivalent to:
10784 //
10785 // Step*N = -Start (mod 2^BW)
10786 //
10787 // where BW is the common bit width of Start and Step.
10788
10789 // Get the initial value for the loop.
10790 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10791 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10792
10793 if (!isLoopInvariant(Step, L))
10794 return getCouldNotCompute();
10795
10796 LoopGuards Guards = LoopGuards::collect(L, *this);
10797 // Specialize step for this loop so we get context sensitive facts below.
10798 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10799
10800 // For positive steps (counting up until unsigned overflow):
10801 // N = -Start/Step (as unsigned)
10802 // For negative steps (counting down to zero):
10803 // N = Start/-Step
10804 // First compute the unsigned distance from zero in the direction of Step.
10805 bool CountDown = isKnownNegative(StepWLG);
  // If the sign of the step is unknown, neither formula applies.
10806 if (!CountDown && !isKnownNonNegative(StepWLG))
10807 return getCouldNotCompute();
10808
10809 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10810 // Handle unitary steps, which cannot wraparound.
10811 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10812 // N = Distance (as unsigned)
10813
10814 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10815 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10816 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10817
10818 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10819 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10820 // case, and see if we can improve the bound.
10821 //
10822 // Explicitly handling this here is necessary because getUnsignedRange
10823 // isn't context-sensitive; it doesn't know that we only care about the
10824 // range inside the loop.
10825 const SCEV *Zero = getZero(Distance->getType());
10826 const SCEV *One = getOne(Distance->getType());
10827 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10828 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10829 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10830 // as "unsigned_max(Distance + 1) - 1".
10831 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10832 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10833 }
10834 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10835 Predicates);
10836 }
10837
10838 // If the condition controls loop exit (the loop exits only if the expression
10839 // is true) and the addition is no-wrap we can use unsigned divide to
10840 // compute the backedge count. In this case, the step may not divide the
10841 // distance, but we don't care because if the condition is "missed" the loop
10842 // will have undefined behavior due to wrapping.
10843 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10844 loopHasNoAbnormalExits(AddRec->getLoop())) {
10845
10846 // If the stride is zero and the start is non-zero, the loop must be
10847 // infinite. In C++, most loops are finite by assumption, in which case the
10848 // step being zero implies UB must execute if the loop is entered.
10849 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10850 !isKnownNonZero(StepWLG))
10851 return getCouldNotCompute();
10852
10853 const SCEV *Exact =
10854 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10855 const SCEV *ConstantMax = getCouldNotCompute();
10856 if (Exact != getCouldNotCompute()) {
10857 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10858 ConstantMax =
    // NOTE(review): the right-hand side of this assignment is missing — one
    // original source line was lost in extraction here.  Presumably it wraps
    // MaxInt in getConstant(umin(MaxInt, getUnsignedRangeMax(Exact))), in
    // parallel with the SolveLinEquationWithOverflow path below — restore
    // from the upstream file and verify.
10860 }
10861 const SCEV *SymbolicMax =
10862 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10863 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10864 }
10865
10866 // Solve the general equation.
10867 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10868 if (!StepC || StepC->getValue()->isZero())
10869 return getCouldNotCompute();
10870 const SCEV *E = SolveLinEquationWithOverflow(
10871 StepC->getAPInt(), getNegativeSCEV(Start),
10872 AllowPredicates ? &Predicates : nullptr, *this, L);
10873
  // M is the constant max: tighten E's unsigned range with loop-guard facts.
10874 const SCEV *M = E;
10875 if (E != getCouldNotCompute()) {
10876 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10877 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10878 }
  // The symbolic max falls back to the constant max when E is unknown.
10879 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10880 return ExitLimit(E, M, S, false, Predicates);
10881}
10882
10883ScalarEvolution::ExitLimit
10884ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10885 // Loops that look like: while (X == 0) are very strange indeed. We don't
10886 // handle them yet except for the trivial case. This could be expanded in the
10887 // future as needed.
10888
10889 // If the value is a constant, check to see if it is known to be non-zero
10890 // already. If so, the backedge will execute zero times.
10891 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10892 if (!C->getValue()->isZero())
10893 return getZero(C->getType());
10894 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10895 }
10896
10897 // We could implement others, but I really doubt anyone writes loops like
10898 // this, and if they did, they would already be constant folded.
10899 return getCouldNotCompute();
10900}
10901
10902std::pair<const BasicBlock *, const BasicBlock *>
10903ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10904 const {
10905 // If the block has a unique predecessor, then there is no path from the
10906 // predecessor to the block that does not go through the direct edge
10907 // from the predecessor to the block.
10908 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10909 return {Pred, BB};
10910
10911 // A loop's header is defined to be a block that dominates the loop.
10912 // If the header has a unique predecessor outside the loop, it must be
10913 // a block that has exactly one successor that can reach the loop.
10914 if (const Loop *L = LI.getLoopFor(BB))
10915 return {L->getLoopPredecessor(), L->getHeader()};
10916
10917 return {nullptr, BB};
10918}
10919
10920/// SCEV structural equivalence is usually sufficient for testing whether two
10921/// expressions are equal, however for the purposes of looking for a condition
10922/// guarding a loop, it can be useful to be a little more general, since a
10923/// front-end may have replicated the controlling expression.
10924static bool HasSameValue(const SCEV *A, const SCEV *B) {
10925 // Quick check to see if they are the same SCEV.
10926 if (A == B) return true;
10927
10928 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10929 // Not all instructions that are "identical" compute the same value. For
10930 // instance, two distinct alloca instructions allocating the same type are
10931 // identical and do not read memory; but compute distinct values.
10932 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10933 };
10934
10935 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10936 // two different instructions with the same value. Check for this case.
10937 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10938 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10939 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10940 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10941 if (ComputesEqualValues(AI, BI))
10942 return true;
10943
10944 // Otherwise assume they may have a different value.
10945 return false;
10946}
10947
10948static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
10949 const SCEV *Op0, *Op1;
10950 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10951 return false;
10952 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10953 LHS = Op1;
10954 return true;
10955 }
10956 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10957 LHS = Op0;
10958 return true;
10959 }
10960 return false;
10961}
10962
10964 SCEVUse &RHS, unsigned Depth) {
10965 bool Changed = false;
10966 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10967 // '0 != 0'.
10968 auto TrivialCase = [&](bool TriviallyTrue) {
10970 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10971 return true;
10972 };
10973 // If we hit the max recursion limit bail out.
10974 if (Depth >= 3)
10975 return false;
10976
10977 const SCEV *NewLHS, *NewRHS;
10978 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10979 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10980 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10981 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10982
10983 // (X * vscale) pred (Y * vscale) ==> X pred Y
10984 // when both multiples are NSW.
10985 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10986 // when both multiples are NUW.
10987 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10988 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10989 !ICmpInst::isSigned(Pred))) {
10990 LHS = NewLHS;
10991 RHS = NewRHS;
10992 Changed = true;
10993 }
10994 }
10995
10996 // Canonicalize a constant to the right side.
10997 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10998 // Check for both operands constant.
10999 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11000 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11001 return TrivialCase(false);
11002 return TrivialCase(true);
11003 }
11004 // Otherwise swap the operands to put the constant on the right.
11005 std::swap(LHS, RHS);
11007 Changed = true;
11008 }
11009
11010 // If we're comparing an addrec with a value which is loop-invariant in the
11011 // addrec's loop, put the addrec on the left. Also make a dominance check,
11012 // as both operands could be addrecs loop-invariant in each other's loop.
11013 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11014 const Loop *L = AR->getLoop();
11015 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11016 std::swap(LHS, RHS);
11018 Changed = true;
11019 }
11020 }
11021
11022 // If there's a constant operand, canonicalize comparisons with boundary
11023 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11024 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11025 const APInt &RA = RC->getAPInt();
11026
11027 bool SimplifiedByConstantRange = false;
11028
11029 if (!ICmpInst::isEquality(Pred)) {
11031 if (ExactCR.isFullSet())
11032 return TrivialCase(true);
11033 if (ExactCR.isEmptySet())
11034 return TrivialCase(false);
11035
11036 APInt NewRHS;
11037 CmpInst::Predicate NewPred;
11038 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11039 ICmpInst::isEquality(NewPred)) {
11040 // We were able to convert an inequality to an equality.
11041 Pred = NewPred;
11042 RHS = getConstant(NewRHS);
11043 Changed = SimplifiedByConstantRange = true;
11044 }
11045 }
11046
11047 if (!SimplifiedByConstantRange) {
11048 switch (Pred) {
11049 default:
11050 break;
11051 case ICmpInst::ICMP_EQ:
11052 case ICmpInst::ICMP_NE:
11053 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11054 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11055 Changed = true;
11056 break;
11057
11058 // The "Should have been caught earlier!" messages refer to the fact
11059 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11060 // should have fired on the corresponding cases, and canonicalized the
11061 // check to trivial case.
11062
11063 case ICmpInst::ICMP_UGE:
11064 assert(!RA.isMinValue() && "Should have been caught earlier!");
11065 Pred = ICmpInst::ICMP_UGT;
11066 RHS = getConstant(RA - 1);
11067 Changed = true;
11068 break;
11069 case ICmpInst::ICMP_ULE:
11070 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11071 Pred = ICmpInst::ICMP_ULT;
11072 RHS = getConstant(RA + 1);
11073 Changed = true;
11074 break;
11075 case ICmpInst::ICMP_SGE:
11076 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11077 Pred = ICmpInst::ICMP_SGT;
11078 RHS = getConstant(RA - 1);
11079 Changed = true;
11080 break;
11081 case ICmpInst::ICMP_SLE:
11082 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11083 Pred = ICmpInst::ICMP_SLT;
11084 RHS = getConstant(RA + 1);
11085 Changed = true;
11086 break;
11087 }
11088 }
11089 }
11090
11091 // Check for obvious equality.
11092 if (HasSameValue(LHS, RHS)) {
11093 if (ICmpInst::isTrueWhenEqual(Pred))
11094 return TrivialCase(true);
11096 return TrivialCase(false);
11097 }
11098
11099 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11100 // adding or subtracting 1 from one of the operands.
11101 switch (Pred) {
11102 case ICmpInst::ICMP_SLE:
11103 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11104 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11106 Pred = ICmpInst::ICMP_SLT;
11107 Changed = true;
11108 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11109 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11111 Pred = ICmpInst::ICMP_SLT;
11112 Changed = true;
11113 }
11114 break;
11115 case ICmpInst::ICMP_SGE:
11116 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11117 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11119 Pred = ICmpInst::ICMP_SGT;
11120 Changed = true;
11121 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11122 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11124 Pred = ICmpInst::ICMP_SGT;
11125 Changed = true;
11126 }
11127 break;
11128 case ICmpInst::ICMP_ULE:
11129 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11130 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11132 Pred = ICmpInst::ICMP_ULT;
11133 Changed = true;
11134 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11135 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11136 Pred = ICmpInst::ICMP_ULT;
11137 Changed = true;
11138 }
11139 break;
11140 case ICmpInst::ICMP_UGE:
11141 // If RHS is an op we can fold the -1, try that first.
11142 // Otherwise prefer LHS to preserve the nuw flag.
11143 if ((isa<SCEVConstant>(RHS) ||
11145 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11146 !getUnsignedRangeMin(RHS).isMinValue()) {
11147 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11148 Pred = ICmpInst::ICMP_UGT;
11149 Changed = true;
11150 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11151 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11153 Pred = ICmpInst::ICMP_UGT;
11154 Changed = true;
11155 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11156 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11157 Pred = ICmpInst::ICMP_UGT;
11158 Changed = true;
11159 }
11160 break;
11161 default:
11162 break;
11163 }
11164
11165 // TODO: More simplifications are possible here.
11166
11167 // Recursively simplify until we either hit a recursion limit or nothing
11168 // changes.
11169 if (Changed)
11170 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11171
11172 return Changed;
11173}
11174
11176 return getSignedRangeMax(S).isNegative();
11177}
11178
11182
11184 return !getSignedRangeMin(S).isNegative();
11185}
11186
11190
11192 // Query push down for cases where the unsigned range is
11193 // less than sufficient.
11194 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11195 return isKnownNonZero(SExt->getOperand(0));
11196 return getUnsignedRangeMin(S) != 0;
11197}
11198
11200 bool OrNegative) {
11201 auto NonRecursive = [OrNegative](const SCEV *S) {
11202 if (auto *C = dyn_cast<SCEVConstant>(S))
11203 return C->getAPInt().isPowerOf2() ||
11204 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11205
11206 // vscale is a power-of-two.
11207 return isa<SCEVVScale>(S);
11208 };
11209
11210 if (NonRecursive(S))
11211 return true;
11212
11213 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11214 if (!Mul)
11215 return false;
11216 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11217}
11218
11220 const SCEV *S, uint64_t M,
11222 if (M == 0)
11223 return false;
11224 if (M == 1)
11225 return true;
11226
11227 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11228 // starts with a multiple of M and at every iteration step S only adds
11229 // multiples of M.
11230 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11231 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11232 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11233
11234 // For a constant, check that "S % M == 0".
11235 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11236 APInt C = Cst->getAPInt();
11237 return C.urem(M) == 0;
11238 }
11239
11240 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11241
11242 // Basic tests have failed.
11243 // Check "S % M == 0" at compile time and record runtime Assumptions.
11244 auto *STy = dyn_cast<IntegerType>(S->getType());
11245 const SCEV *SmodM =
11246 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11247 const SCEV *Zero = getZero(STy);
11248
11249 // Check whether "S % M == 0" is known at compile time.
11250 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11251 return true;
11252
11253 // Check whether "S % M != 0" is known at compile time.
11254 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11255 return false;
11256
11258
11259 // Detect redundant predicates.
11260 for (auto *A : Assumptions)
11261 if (A->implies(P, *this))
11262 return true;
11263
11264 // Only record non-redundant predicates.
11265 Assumptions.push_back(P);
11266 return true;
11267}
11268
11270 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11272}
11273
11274std::pair<const SCEV *, const SCEV *>
11276 // Compute SCEV on entry of loop L.
11277 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11278 if (Start == getCouldNotCompute())
11279 return { Start, Start };
11280 // Compute post increment SCEV for loop L.
11281 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11282 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11283 return { Start, PostInc };
11284}
11285
11287 SCEVUse RHS) {
11288 // First collect all loops.
11290 getUsedLoops(LHS, LoopsUsed);
11291 getUsedLoops(RHS, LoopsUsed);
11292
11293 if (LoopsUsed.empty())
11294 return false;
11295
11296 // Domination relationship must be a linear order on collected loops.
11297#ifndef NDEBUG
11298 for (const auto *L1 : LoopsUsed)
11299 for (const auto *L2 : LoopsUsed)
11300 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11301 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11302 "Domination relationship is not a linear order");
11303#endif
11304
11305 const Loop *MDL =
11306 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11307 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11308 });
11309
11310 // Get init and post increment value for LHS.
11311 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11312 // if LHS contains unknown non-invariant SCEV then bail out.
11313 if (SplitLHS.first == getCouldNotCompute())
11314 return false;
11315 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11316 // Get init and post increment value for RHS.
11317 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11318 // if RHS contains unknown non-invariant SCEV then bail out.
11319 if (SplitRHS.first == getCouldNotCompute())
11320 return false;
11321 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11322 // It is possible that init SCEV contains an invariant load but it does
11323 // not dominate MDL and is not available at MDL loop entry, so we should
11324 // check it here.
11325 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11326 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11327 return false;
11328
11329 // It seems backedge guard check is faster than entry one so in some cases
11330 // it can speed up whole estimation by short circuit
11331 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11332 SplitRHS.second) &&
11333 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11334}
11335
11337 SCEVUse RHS) {
11338 // Canonicalize the inputs first.
11339 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11340
11341 if (isKnownViaInduction(Pred, LHS, RHS))
11342 return true;
11343
11344 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11345 return true;
11346
11347 // Otherwise see what can be done with some simple reasoning.
11348 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11349}
11350
11352 const SCEV *LHS,
11353 const SCEV *RHS) {
11354 if (isKnownPredicate(Pred, LHS, RHS))
11355 return true;
11357 return false;
11358 return std::nullopt;
11359}
11360
11362 const SCEV *RHS,
11363 const Instruction *CtxI) {
11364 // TODO: Analyze guards and assumes from Context's block.
11365 return isKnownPredicate(Pred, LHS, RHS) ||
11366 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11367}
11368
11369std::optional<bool>
11371 const SCEV *RHS, const Instruction *CtxI) {
11372 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11373 if (KnownWithoutContext)
11374 return KnownWithoutContext;
11375
11376 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11377 return true;
11379 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11380 return false;
11381 return std::nullopt;
11382}
11383
11385 const SCEVAddRecExpr *LHS,
11386 const SCEV *RHS) {
11387 const Loop *L = LHS->getLoop();
11388 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11389 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11390}
11391
11392std::optional<ScalarEvolution::MonotonicPredicateType>
11394 ICmpInst::Predicate Pred) {
11395 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11396
11397#ifndef NDEBUG
11398 // Verify an invariant: inverting the predicate should turn a monotonically
11399 // increasing change to a monotonically decreasing one, and vice versa.
11400 if (Result) {
11401 auto ResultSwapped =
11402 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11403
11404 assert(*ResultSwapped != *Result &&
11405 "monotonicity should flip as we flip the predicate");
11406 }
11407#endif
11408
11409 return Result;
11410}
11411
/// Worker for getMonotonicPredicateType: classify whether "LHS `Pred` RHS"
/// can only change from false to true (or true to false) as the recurrence
/// LHS advances, based on the predicate's signedness, the AddRec's no-wrap
/// flags, and the known sign of the step.
11412std::optional<ScalarEvolution::MonotonicPredicateType>
11413ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11414 ICmpInst::Predicate Pred) {
11415 // A zero step value for LHS means the induction variable is essentially a
11416 // loop invariant value. We don't really depend on the predicate actually
11417 // flipping from false to true (for increasing predicates, and the other way
11418 // around for decreasing predicates), all we care about is that *if* the
11419 // predicate changes then it only changes from false to true.
11420 //
11421 // A zero step value in itself is not very useful, but there may be places
11422 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11423 // as general as possible.
11424
11425 // Only handle LE/LT/GE/GT predicates.
11426 if (!ICmpInst::isRelational(Pred))
11427 return std::nullopt;
11428
11429 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11430 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11431 "Should be greater or less!");
11432
11433 // Check that AR does not wrap.
11434 if (ICmpInst::isUnsigned(Pred)) {
11435 if (!LHS->hasNoUnsignedWrap())
11436 return std::nullopt;
    // NOTE(review): one original line is missing here — the return for the
    // unsigned no-wrap case was lost in extraction (presumably selecting
    // between the increasing/decreasing result based on IsGreater).  Restore
    // from the upstream file.
11438 }
11439 assert(ICmpInst::isSigned(Pred) &&
11440 "Relational predicate is either signed or unsigned!");
11441 if (!LHS->hasNoSignedWrap())
11442 return std::nullopt;
11443
11444 const SCEV *Step = LHS->getStepRecurrence(*this);
11445
11446 if (isKnownNonNegative(Step))
    // NOTE(review): the guarded return was lost in extraction here
    // (presumably the monotonically-increasing result for IsGreater, or its
    // counterpart).  Restore from the upstream file.
11448
11449 if (isKnownNonPositive(Step))
    // NOTE(review): the guarded return was lost in extraction here
    // (presumably the mirror of the non-negative-step case).  Restore from
    // the upstream file.
11451
11452 return std::nullopt;
11453}
11454
11455std::optional<ScalarEvolution::LoopInvariantPredicate>
11457 const SCEV *RHS, const Loop *L,
11458 const Instruction *CtxI) {
11459 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11460 if (!isLoopInvariant(RHS, L)) {
11461 if (!isLoopInvariant(LHS, L))
11462 return std::nullopt;
11463
11464 std::swap(LHS, RHS);
11466 }
11467
11468 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11469 if (!ArLHS || ArLHS->getLoop() != L)
11470 return std::nullopt;
11471
11472 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11473 if (!MonotonicType)
11474 return std::nullopt;
11475 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11476 // true as the loop iterates, and the backedge is control dependent on
11477 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11478 //
11479 // * if the predicate was false in the first iteration then the predicate
11480 // is never evaluated again, since the loop exits without taking the
11481 // backedge.
11482 // * if the predicate was true in the first iteration then it will
11483 // continue to be true for all future iterations since it is
11484 // monotonically increasing.
11485 //
11486 // For both the above possibilities, we can replace the loop varying
11487 // predicate with its value on the first iteration of the loop (which is
11488 // loop invariant).
11489 //
11490 // A similar reasoning applies for a monotonically decreasing predicate, by
11491 // replacing true with false and false with true in the above two bullets.
11493 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11494
11495 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11497 RHS);
11498
11499 if (!CtxI)
11500 return std::nullopt;
11501 // Try to prove via context.
11502 // TODO: Support other cases.
11503 switch (Pred) {
11504 default:
11505 break;
11506 case ICmpInst::ICMP_ULE:
11507 case ICmpInst::ICMP_ULT: {
11508 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11509 // Given preconditions
11510 // (1) ArLHS does not cross the border of positive and negative parts of
11511 // range because of:
11512 // - Positive step; (TODO: lift this limitation)
11513 // - nuw - does not cross zero boundary;
11514 // - nsw - does not cross SINT_MAX boundary;
11515 // (2) ArLHS <s RHS
11516 // (3) RHS >=s 0
11517 // we can replace the loop variant ArLHS <u RHS condition with loop
11518 // invariant Start(ArLHS) <u RHS.
11519 //
11520 // Because of (1) there are two options:
11521 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11522 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11523 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11524 // Because of (2) ArLHS <u RHS is trivially true.
11525 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11526 // We can strengthen this to Start(ArLHS) <u RHS.
11527 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11528 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11529 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11530 isKnownNonNegative(RHS) &&
11531 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11533 RHS);
11534 }
11535 }
11536
11537 return std::nullopt;
11538}
11539
11540std::optional<ScalarEvolution::LoopInvariantPredicate>
11542 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11543 const Instruction *CtxI, const SCEV *MaxIter) {
11545 Pred, LHS, RHS, L, CtxI, MaxIter))
11546 return LIP;
11547 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11548 // Number of iterations expressed as UMIN isn't always great for expressing
11549 // the value on the last iteration. If the straightforward approach didn't
11550 // work, try the following trick: if the a predicate is invariant for X, it
11551 // is also invariant for umin(X, ...). So try to find something that works
11552 // among subexpressions of MaxIter expressed as umin.
11553 for (SCEVUse Op : UMin->operands())
11555 Pred, LHS, RHS, L, CtxI, Op))
11556 return LIP;
11557 return std::nullopt;
11558}
11559
11560std::optional<ScalarEvolution::LoopInvariantPredicate>
11562 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11563 const Instruction *CtxI, const SCEV *MaxIter) {
11564 // Try to prove the following set of facts:
11565 // - The predicate is monotonic in the iteration space.
11566 // - If the check does not fail on the 1st iteration:
11567 // - No overflow will happen during first MaxIter iterations;
11568 // - It will not fail on the MaxIter'th iteration.
11569 // If the check does fail on the 1st iteration, we leave the loop and no
11570 // other checks matter.
11571
11572 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11573 if (!isLoopInvariant(RHS, L)) {
11574 if (!isLoopInvariant(LHS, L))
11575 return std::nullopt;
11576
11577 std::swap(LHS, RHS);
11579 }
11580
11581 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11582 if (!AR || AR->getLoop() != L)
11583 return std::nullopt;
11584
11585 // Even if both are valid, we need to consistently chose the unsigned or the
11586 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11587 // predicate.
11588 Pred = Pred.dropSameSign();
11589
11590 // The predicate must be relational (i.e. <, <=, >=, >).
11591 if (!ICmpInst::isRelational(Pred))
11592 return std::nullopt;
11593
11594 // TODO: Support steps other than +/- 1.
11595 const SCEV *Step = AR->getStepRecurrence(*this);
11596 auto *One = getOne(Step->getType());
11597 auto *MinusOne = getNegativeSCEV(One);
11598 if (Step != One && Step != MinusOne)
11599 return std::nullopt;
11600
11601 // Type mismatch here means that MaxIter is potentially larger than max
11602 // unsigned value in start type, which mean we cannot prove no wrap for the
11603 // indvar.
11604 if (AR->getType() != MaxIter->getType())
11605 return std::nullopt;
11606
11607 // Value of IV on suggested last iteration.
11608 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11609 // Does it still meet the requirement?
11610 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11611 return std::nullopt;
11612 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11613 // not exceed max unsigned value of this type), this effectively proves
11614 // that there is no wrap during the iteration. To prove that there is no
11615 // signed/unsigned wrap, we need to check that
11616 // Start <= Last for step = 1 or Start >= Last for step = -1.
11617 ICmpInst::Predicate NoOverflowPred =
11619 if (Step == MinusOne)
11620 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11621 const SCEV *Start = AR->getStart();
11622 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11623 return std::nullopt;
11624
11625 // Everything is fine.
11626 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11627}
11628
11629bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11630 SCEVUse LHS,
11631 SCEVUse RHS) {
11632 if (HasSameValue(LHS, RHS))
11633 return ICmpInst::isTrueWhenEqual(Pred);
11634
11635 auto CheckRange = [&](bool IsSigned) {
11636 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11637 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11638 return RangeLHS.icmp(Pred, RangeRHS);
11639 };
11640
11641 // The check at the top of the function catches the case where the values are
11642 // known to be equal.
11643 if (Pred == CmpInst::ICMP_EQ)
11644 return false;
11645
11646 if (Pred == CmpInst::ICMP_NE) {
11647 if (CheckRange(true) || CheckRange(false))
11648 return true;
11649 auto *Diff = getMinusSCEV(LHS, RHS);
11650 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11651 }
11652
11653 return CheckRange(CmpInst::isSigned(Pred));
11654}
11655
11656bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11657 SCEVUse LHS, SCEVUse RHS) {
11658 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11659 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11660 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11661 // OutC1 and OutC2.
11662 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11663 APInt &OutC2,
11664 SCEV::NoWrapFlags ExpectedFlags) {
11665 SCEVUse XNonConstOp, XConstOp;
11666 SCEVUse YNonConstOp, YConstOp;
11667 SCEV::NoWrapFlags XFlagsPresent;
11668 SCEV::NoWrapFlags YFlagsPresent;
11669
11670 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11671 XConstOp = getZero(X->getType());
11672 XNonConstOp = X;
11673 XFlagsPresent = ExpectedFlags;
11674 }
11675 if (!isa<SCEVConstant>(XConstOp))
11676 return false;
11677
11678 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11679 YConstOp = getZero(Y->getType());
11680 YNonConstOp = Y;
11681 YFlagsPresent = ExpectedFlags;
11682 }
11683
11684 if (YNonConstOp != XNonConstOp)
11685 return false;
11686
11687 if (!isa<SCEVConstant>(YConstOp))
11688 return false;
11689
11690 // When matching ADDs with NUW flags (and unsigned predicates), only the
11691 // second ADD (with the larger constant) requires NUW.
11692 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11693 return false;
11694 if (ExpectedFlags != SCEV::FlagNUW &&
11695 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11696 return false;
11697 }
11698
11699 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11700 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11701
11702 return true;
11703 };
11704
11705 APInt C1;
11706 APInt C2;
11707
11708 switch (Pred) {
11709 default:
11710 break;
11711
11712 case ICmpInst::ICMP_SGE:
11713 std::swap(LHS, RHS);
11714 [[fallthrough]];
11715 case ICmpInst::ICMP_SLE:
11716 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11717 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11718 return true;
11719
11720 break;
11721
11722 case ICmpInst::ICMP_SGT:
11723 std::swap(LHS, RHS);
11724 [[fallthrough]];
11725 case ICmpInst::ICMP_SLT:
11726 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11727 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11728 return true;
11729
11730 break;
11731
11732 case ICmpInst::ICMP_UGE:
11733 std::swap(LHS, RHS);
11734 [[fallthrough]];
11735 case ICmpInst::ICMP_ULE:
11736 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11737 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11738 return true;
11739
11740 break;
11741
11742 case ICmpInst::ICMP_UGT:
11743 std::swap(LHS, RHS);
11744 [[fallthrough]];
11745 case ICmpInst::ICMP_ULT:
11746 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11747 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11748 return true;
11749 break;
11750 }
11751
11752 return false;
11753}
11754
11755bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11756 SCEVUse LHS, SCEVUse RHS) {
11757 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11758 return false;
11759
11760 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11761 // the stack can result in exponential time complexity.
11762 SaveAndRestore Restore(ProvingSplitPredicate, true);
11763
11764 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11765 //
11766 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11767 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11768 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11769 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11770 // use isKnownPredicate later if needed.
11771 return isKnownNonNegative(RHS) &&
11774}
11775
11776bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11777 const SCEV *LHS, const SCEV *RHS) {
11778 // No need to even try if we know the module has no guards.
11779 if (!HasGuards)
11780 return false;
11781
11782 return any_of(*BB, [&](const Instruction &I) {
11783 using namespace llvm::PatternMatch;
11784
11785 Value *Condition;
11787 m_Value(Condition))) &&
11788 isImpliedCond(Pred, LHS, RHS, Condition, false);
11789 });
11790}
11791
11792/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11793/// protected by a conditional between LHS and RHS. This is used to
11794/// to eliminate casts.
11796 CmpPredicate Pred,
11797 const SCEV *LHS,
11798 const SCEV *RHS) {
11799 // Interpret a null as meaning no loop, where there is obviously no guard
11800 // (interprocedural conditions notwithstanding). Do not bother about
11801 // unreachable loops.
11802 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11803 return true;
11804
11805 if (VerifyIR)
11806 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11807 "This cannot be done on broken IR!");
11808
11809
11810 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11811 return true;
11812
11813 BasicBlock *Latch = L->getLoopLatch();
11814 if (!Latch)
11815 return false;
11816
11817 CondBrInst *LoopContinuePredicate =
11819 if (LoopContinuePredicate &&
11820 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11821 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11822 return true;
11823
11824 // We don't want more than one activation of the following loops on the stack
11825 // -- that can lead to O(n!) time complexity.
11826 if (WalkingBEDominatingConds)
11827 return false;
11828
11829 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11830
11831 // See if we can exploit a trip count to prove the predicate.
11832 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11833 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11834 if (LatchBECount != getCouldNotCompute()) {
11835 // We know that Latch branches back to the loop header exactly
11836 // LatchBECount times. This means the backdege condition at Latch is
11837 // equivalent to "{0,+,1} u< LatchBECount".
11838 Type *Ty = LatchBECount->getType();
11839 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11840 const SCEV *LoopCounter =
11841 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11842 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11843 LatchBECount))
11844 return true;
11845 }
11846
11847 // Check conditions due to any @llvm.assume intrinsics.
11848 for (auto &AssumeVH : AC.assumptions()) {
11849 if (!AssumeVH)
11850 continue;
11851 auto *CI = cast<CallInst>(AssumeVH);
11852 if (!DT.dominates(CI, Latch->getTerminator()))
11853 continue;
11854
11855 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11856 return true;
11857 }
11858
11859 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11860 return true;
11861
11862 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11863 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11864 assert(DTN && "should reach the loop header before reaching the root!");
11865
11866 BasicBlock *BB = DTN->getBlock();
11867 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11868 return true;
11869
11870 BasicBlock *PBB = BB->getSinglePredecessor();
11871 if (!PBB)
11872 continue;
11873
11875 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11876 continue;
11877
11878 // If we have an edge `E` within the loop body that dominates the only
11879 // latch, the condition guarding `E` also guards the backedge. This
11880 // reasoning works only for loops with a single latch.
11881 // We're constructively (and conservatively) enumerating edges within the
11882 // loop body that dominate the latch. The dominator tree better agree
11883 // with us on this:
11884 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11885 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11886 BB != ContBr->getSuccessor(0)))
11887 return true;
11888 }
11889
11890 return false;
11891}
11892
11894 CmpPredicate Pred,
11895 const SCEV *LHS,
11896 const SCEV *RHS) {
11897 // Do not bother proving facts for unreachable code.
11898 if (!DT.isReachableFromEntry(BB))
11899 return true;
11900 if (VerifyIR)
11901 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11902 "This cannot be done on broken IR!");
11903
11904 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11905 // the facts (a >= b && a != b) separately. A typical situation is when the
11906 // non-strict comparison is known from ranges and non-equality is known from
11907 // dominating predicates. If we are proving strict comparison, we always try
11908 // to prove non-equality and non-strict comparison separately.
11909 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11910 const bool ProvingStrictComparison =
11911 Pred != NonStrictPredicate.dropSameSign();
11912 bool ProvedNonStrictComparison = false;
11913 bool ProvedNonEquality = false;
11914
11915 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11916 if (!ProvedNonStrictComparison)
11917 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11918 if (!ProvedNonEquality)
11919 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11920 if (ProvedNonStrictComparison && ProvedNonEquality)
11921 return true;
11922 return false;
11923 };
11924
11925 if (ProvingStrictComparison) {
11926 auto ProofFn = [&](CmpPredicate P) {
11927 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11928 };
11929 if (SplitAndProve(ProofFn))
11930 return true;
11931 }
11932
11933 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11934 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11935 const Instruction *CtxI = &BB->front();
11936 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11937 return true;
11938 if (ProvingStrictComparison) {
11939 auto ProofFn = [&](CmpPredicate P) {
11940 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11941 };
11942 if (SplitAndProve(ProofFn))
11943 return true;
11944 }
11945 return false;
11946 };
11947
11948 // Starting at the block's predecessor, climb up the predecessor chain, as long
11949 // as there are predecessors that can be found that have unique successors
11950 // leading to the original block.
11951 const Loop *ContainingLoop = LI.getLoopFor(BB);
11952 const BasicBlock *PredBB;
11953 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11954 PredBB = ContainingLoop->getLoopPredecessor();
11955 else
11956 PredBB = BB->getSinglePredecessor();
11957 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11958 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11959 const CondBrInst *BlockEntryPredicate =
11960 dyn_cast<CondBrInst>(Pair.first->getTerminator());
11961 if (!BlockEntryPredicate)
11962 continue;
11963
11964 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11965 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11966 return true;
11967 }
11968
11969 // Check conditions due to any @llvm.assume intrinsics.
11970 for (auto &AssumeVH : AC.assumptions()) {
11971 if (!AssumeVH)
11972 continue;
11973 auto *CI = cast<CallInst>(AssumeVH);
11974 if (!DT.dominates(CI, BB))
11975 continue;
11976
11977 if (ProveViaCond(CI->getArgOperand(0), false))
11978 return true;
11979 }
11980
11981 // Check conditions due to any @llvm.experimental.guard intrinsics.
11982 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11983 F.getParent(), Intrinsic::experimental_guard);
11984 if (GuardDecl)
11985 for (const auto *GU : GuardDecl->users())
11986 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11987 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11988 if (ProveViaCond(Guard->getArgOperand(0), false))
11989 return true;
11990 return false;
11991}
11992
11994 const SCEV *LHS,
11995 const SCEV *RHS) {
11996 // Interpret a null as meaning no loop, where there is obviously no guard
11997 // (interprocedural conditions notwithstanding).
11998 if (!L)
11999 return false;
12000
12001 // Both LHS and RHS must be available at loop entry.
12003 "LHS is not available at Loop Entry");
12005 "RHS is not available at Loop Entry");
12006
12007 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12008 return true;
12009
12010 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12011}
12012
12013bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12014 const SCEV *RHS,
12015 const Value *FoundCondValue, bool Inverse,
12016 const Instruction *CtxI) {
12017 // False conditions implies anything. Do not bother analyzing it further.
12018 if (FoundCondValue ==
12019 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12020 return true;
12021
12022 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12023 return false;
12024
12025 llvm::scope_exit ClearOnExit(
12026 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12027
12028 // Recursively handle And and Or conditions.
12029 const Value *Op0, *Op1;
12030 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12031 if (!Inverse)
12032 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12033 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12034 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12035 if (Inverse)
12036 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12037 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12038 }
12039
12040 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12041 if (!ICI) return false;
12042
12043 // Now that we found a conditional branch that dominates the loop or controls
12044 // the loop latch. Check to see if it is the comparison we are looking for.
12045 CmpPredicate FoundPred;
12046 if (Inverse)
12047 FoundPred = ICI->getInverseCmpPredicate();
12048 else
12049 FoundPred = ICI->getCmpPredicate();
12050
12051 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12052 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12053
12054 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12055}
12056
12057bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12058 const SCEV *RHS, CmpPredicate FoundPred,
12059 const SCEV *FoundLHS, const SCEV *FoundRHS,
12060 const Instruction *CtxI) {
12061 // Balance the types.
12062 if (getTypeSizeInBits(LHS->getType()) <
12063 getTypeSizeInBits(FoundLHS->getType())) {
12064 // For unsigned and equality predicates, try to prove that both found
12065 // operands fit into narrow unsigned range. If so, try to prove facts in
12066 // narrow types.
12067 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12068 !FoundRHS->getType()->isPointerTy()) {
12069 auto *NarrowType = LHS->getType();
12070 auto *WideType = FoundLHS->getType();
12071 auto BitWidth = getTypeSizeInBits(NarrowType);
12072 const SCEV *MaxValue = getZeroExtendExpr(
12074 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12075 MaxValue) &&
12076 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12077 MaxValue)) {
12078 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12079 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12080 // We cannot preserve samesign after truncation.
12081 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12082 TruncFoundLHS, TruncFoundRHS, CtxI))
12083 return true;
12084 }
12085 }
12086
12087 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12088 return false;
12089 if (CmpInst::isSigned(Pred)) {
12090 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12091 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12092 } else {
12093 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12094 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12095 }
12096 } else if (getTypeSizeInBits(LHS->getType()) >
12097 getTypeSizeInBits(FoundLHS->getType())) {
12098 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12099 return false;
12100 if (CmpInst::isSigned(FoundPred)) {
12101 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12102 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12103 } else {
12104 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12105 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12106 }
12107 }
12108 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12109 FoundRHS, CtxI);
12110}
12111
12112bool ScalarEvolution::isImpliedCondBalancedTypes(
12113 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12114 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12116 getTypeSizeInBits(FoundLHS->getType()) &&
12117 "Types should be balanced!");
12118 // Canonicalize the query to match the way instcombine will have
12119 // canonicalized the comparison.
12120 if (SimplifyICmpOperands(Pred, LHS, RHS))
12121 if (LHS == RHS)
12122 return CmpInst::isTrueWhenEqual(Pred);
12123 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12124 if (FoundLHS == FoundRHS)
12125 return CmpInst::isFalseWhenEqual(FoundPred);
12126
12127 // Check to see if we can make the LHS or RHS match.
12128 if (LHS == FoundRHS || RHS == FoundLHS) {
12129 if (isa<SCEVConstant>(RHS)) {
12130 std::swap(FoundLHS, FoundRHS);
12131 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12132 } else {
12133 std::swap(LHS, RHS);
12135 }
12136 }
12137
12138 // Check whether the found predicate is the same as the desired predicate.
12139 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12140 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12141
12142 // Check whether swapping the found predicate makes it the same as the
12143 // desired predicate.
12144 if (auto P = CmpPredicate::getMatching(
12145 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12146 // We can write the implication
12147 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12148 // using one of the following ways:
12149 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12150 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12151 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12152 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12153 // Forms 1. and 2. require swapping the operands of one condition. Don't
12154 // do this if it would break canonical constant/addrec ordering.
12156 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12157 LHS, FoundLHS, FoundRHS, CtxI);
12158 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12159 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12160
12161 // There's no clear preference between forms 3. and 4., try both. Avoid
12162 // forming getNotSCEV of pointer values as the resulting subtract is
12163 // not legal.
12164 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12165 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12166 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12167 FoundRHS, CtxI))
12168 return true;
12169
12170 if (!FoundLHS->getType()->isPointerTy() &&
12171 !FoundRHS->getType()->isPointerTy() &&
12172 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12173 getNotSCEV(FoundRHS), CtxI))
12174 return true;
12175
12176 return false;
12177 }
12178
12179 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12181 assert(P1 != P2 && "Handled earlier!");
12182 return CmpInst::isRelational(P2) &&
12184 };
12185 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12186 // Unsigned comparison is the same as signed comparison when both the
12187 // operands are non-negative or negative.
12188 if (haveSameSign(FoundLHS, FoundRHS))
12189 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12190 // Create local copies that we can freely swap and canonicalize our
12191 // conditions to "le/lt".
12192 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12193 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12194 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12195 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12196 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12197 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12198 std::swap(CanonicalLHS, CanonicalRHS);
12199 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12200 }
12201 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12202 "Must be!");
12203 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12204 ICmpInst::isLE(CanonicalFoundPred)) &&
12205 "Must be!");
12206 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12207 // Use implication:
12208 // x <u y && y >=s 0 --> x <s y.
12209 // If we can prove the left part, the right part is also proven.
12210 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12211 CanonicalRHS, CanonicalFoundLHS,
12212 CanonicalFoundRHS);
12213 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12214 // Use implication:
12215 // x <s y && y <s 0 --> x <u y.
12216 // If we can prove the left part, the right part is also proven.
12217 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12218 CanonicalRHS, CanonicalFoundLHS,
12219 CanonicalFoundRHS);
12220 }
12221
12222 // Check if we can make progress by sharpening ranges.
12223 if (FoundPred == ICmpInst::ICMP_NE &&
12224 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12225
12226 const SCEVConstant *C = nullptr;
12227 const SCEV *V = nullptr;
12228
12229 if (isa<SCEVConstant>(FoundLHS)) {
12230 C = cast<SCEVConstant>(FoundLHS);
12231 V = FoundRHS;
12232 } else {
12233 C = cast<SCEVConstant>(FoundRHS);
12234 V = FoundLHS;
12235 }
12236
12237 // The guarding predicate tells us that C != V. If the known range
12238 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12239 // range we consider has to correspond to same signedness as the
12240 // predicate we're interested in folding.
12241
12242 APInt Min = ICmpInst::isSigned(Pred) ?
12244
12245 if (Min == C->getAPInt()) {
12246 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12247 // This is true even if (Min + 1) wraps around -- in case of
12248 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12249
12250 APInt SharperMin = Min + 1;
12251
12252 switch (Pred) {
12253 case ICmpInst::ICMP_SGE:
12254 case ICmpInst::ICMP_UGE:
12255 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12256 // RHS, we're done.
12257 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12258 CtxI))
12259 return true;
12260 [[fallthrough]];
12261
12262 case ICmpInst::ICMP_SGT:
12263 case ICmpInst::ICMP_UGT:
12264 // We know from the range information that (V `Pred` Min ||
12265 // V == Min). We know from the guarding condition that !(V
12266 // == Min). This gives us
12267 //
12268 // V `Pred` Min || V == Min && !(V == Min)
12269 // => V `Pred` Min
12270 //
12271 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12272
12273 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12274 return true;
12275 break;
12276
12277 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12278 case ICmpInst::ICMP_SLE:
12279 case ICmpInst::ICMP_ULE:
12280 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12281 LHS, V, getConstant(SharperMin), CtxI))
12282 return true;
12283 [[fallthrough]];
12284
12285 case ICmpInst::ICMP_SLT:
12286 case ICmpInst::ICMP_ULT:
12287 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12288 LHS, V, getConstant(Min), CtxI))
12289 return true;
12290 break;
12291
12292 default:
12293 // No change
12294 break;
12295 }
12296 }
12297 }
12298
12299 // Check whether the actual condition is beyond sufficient.
12300 if (FoundPred == ICmpInst::ICMP_EQ)
12301 if (ICmpInst::isTrueWhenEqual(Pred))
12302 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12303 return true;
12304 if (Pred == ICmpInst::ICMP_NE)
12305 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12306 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12307 return true;
12308
12309 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12310 return true;
12311
12312 // Otherwise assume the worst.
12313 return false;
12314}
12315
12316bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12317 SCEV::NoWrapFlags &Flags) {
12318 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12319 return false;
12320
12321 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12322 return true;
12323}
12324
12325std::optional<APInt>
12327 // We avoid subtracting expressions here because this function is usually
12328 // fairly deep in the call stack (i.e. is called many times).
12329
12330 unsigned BW = getTypeSizeInBits(More->getType());
12331 APInt Diff(BW, 0);
12332 APInt DiffMul(BW, 1);
12333 // Try various simplifications to reduce the difference to a constant. Limit
12334 // the number of allowed simplifications to keep compile-time low.
12335 for (unsigned I = 0; I < 8; ++I) {
12336 if (More == Less)
12337 return Diff;
12338
12339 // Reduce addrecs with identical steps to their start value.
12341 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12342 const auto *MAR = cast<SCEVAddRecExpr>(More);
12343
12344 if (LAR->getLoop() != MAR->getLoop())
12345 return std::nullopt;
12346
12347 // We look at affine expressions only; not for correctness but to keep
12348 // getStepRecurrence cheap.
12349 if (!LAR->isAffine() || !MAR->isAffine())
12350 return std::nullopt;
12351
12352 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12353 return std::nullopt;
12354
12355 Less = LAR->getStart();
12356 More = MAR->getStart();
12357 continue;
12358 }
12359
12360 // Try to match a common constant multiply.
12361 auto MatchConstMul =
12362 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12363 const APInt *C;
12364 const SCEV *Op;
12365 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12366 return {{Op, *C}};
12367 return std::nullopt;
12368 };
12369 if (auto MatchedMore = MatchConstMul(More)) {
12370 if (auto MatchedLess = MatchConstMul(Less)) {
12371 if (MatchedMore->second == MatchedLess->second) {
12372 More = MatchedMore->first;
12373 Less = MatchedLess->first;
12374 DiffMul *= MatchedMore->second;
12375 continue;
12376 }
12377 }
12378 }
12379
12380 // Try to cancel out common factors in two add expressions.
12382 auto Add = [&](const SCEV *S, int Mul) {
12383 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12384 if (Mul == 1) {
12385 Diff += C->getAPInt() * DiffMul;
12386 } else {
12387 assert(Mul == -1);
12388 Diff -= C->getAPInt() * DiffMul;
12389 }
12390 } else
12391 Multiplicity[S] += Mul;
12392 };
12393 auto Decompose = [&](const SCEV *S, int Mul) {
12394 if (isa<SCEVAddExpr>(S)) {
12395 for (const SCEV *Op : S->operands())
12396 Add(Op, Mul);
12397 } else
12398 Add(S, Mul);
12399 };
12400 Decompose(More, 1);
12401 Decompose(Less, -1);
12402
12403 // Check whether all the non-constants cancel out, or reduce to new
12404 // More/Less values.
12405 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12406 for (const auto &[S, Mul] : Multiplicity) {
12407 if (Mul == 0)
12408 continue;
12409 if (Mul == 1) {
12410 if (NewMore)
12411 return std::nullopt;
12412 NewMore = S;
12413 } else if (Mul == -1) {
12414 if (NewLess)
12415 return std::nullopt;
12416 NewLess = S;
12417 } else
12418 return std::nullopt;
12419 }
12420
12421 // Values stayed the same, no point in trying further.
12422 if (NewMore == More || NewLess == Less)
12423 return std::nullopt;
12424
12425 More = NewMore;
12426 Less = NewLess;
12427
12428 // Reduced to constant.
12429 if (!More && !Less)
12430 return Diff;
12431
12432 // Left with variable on only one side, bail out.
12433 if (!More || !Less)
12434 return std::nullopt;
12435 }
12436
12437 // Did not reduce to constant.
12438 return std::nullopt;
12439}
12440
12441bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12442 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12443 const SCEV *FoundRHS, const Instruction *CtxI) {
12444 // Try to recognize the following pattern:
12445 //
12446 // FoundRHS = ...
12447 // ...
12448 // loop:
12449 // FoundLHS = {Start,+,W}
12450 // context_bb: // Basic block from the same loop
12451 // known(Pred, FoundLHS, FoundRHS)
12452 //
12453 // If some predicate is known in the context of a loop, it is also known on
12454 // each iteration of this loop, including the first iteration. Therefore, in
12455 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12456 // prove the original pred using this fact.
12457 if (!CtxI)
12458 return false;
12459 const BasicBlock *ContextBB = CtxI->getParent();
12460 // Make sure AR varies in the context block.
12461 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12462 const Loop *L = AR->getLoop();
12463 const auto *Latch = L->getLoopLatch();
12464 // Make sure that context belongs to the loop and executes on 1st iteration
12465 // (if it ever executes at all).
12466 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12467 return false;
12468 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12469 return false;
12470 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12471 }
12472
12473 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12474 const Loop *L = AR->getLoop();
12475 const auto *Latch = L->getLoopLatch();
12476 // Make sure that context belongs to the loop and executes on 1st iteration
12477 // (if it ever executes at all).
12478 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12479 return false;
12480 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12481 return false;
12482 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12483 }
12484
12485 return false;
12486}
12487
12488bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12489 const SCEV *LHS,
12490 const SCEV *RHS,
12491 const SCEV *FoundLHS,
12492 const SCEV *FoundRHS) {
12493 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12494 return false;
12495
12496 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12497 if (!AddRecLHS)
12498 return false;
12499
12500 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12501 if (!AddRecFoundLHS)
12502 return false;
12503
12504 // We'd like to let SCEV reason about control dependencies, so we constrain
12505 // both the inequalities to be about add recurrences on the same loop. This
12506 // way we can use isLoopEntryGuardedByCond later.
12507
12508 const Loop *L = AddRecFoundLHS->getLoop();
12509 if (L != AddRecLHS->getLoop())
12510 return false;
12511
12512 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12513 //
12514 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12515 // ... (2)
12516 //
12517 // Informal proof for (2), assuming (1) [*]:
12518 //
12519 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12520 //
12521 // Then
12522 //
12523 // FoundLHS s< FoundRHS s< INT_MIN - C
12524 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12525 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12526 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12527 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12528 // <=> FoundLHS + C s< FoundRHS + C
12529 //
12530 // [*]: (1) can be proved by ruling out overflow.
12531 //
12532 // [**]: This can be proved by analyzing all the four possibilities:
12533 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12534 // (A s>= 0, B s>= 0).
12535 //
12536 // Note:
12537 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12538 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12539 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12540 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12541 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12542 // C)".
12543
12544 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12545 if (!LDiff)
12546 return false;
12547 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12548 if (!RDiff || *LDiff != *RDiff)
12549 return false;
12550
12551 if (LDiff->isMinValue())
12552 return true;
12553
12554 APInt FoundRHSLimit;
12555
12556 if (Pred == CmpInst::ICMP_ULT) {
12557 FoundRHSLimit = -(*RDiff);
12558 } else {
12559 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12560 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12561 }
12562
12563 // Try to prove (1) or (2), as needed.
12564 return isAvailableAtLoopEntry(FoundRHS, L) &&
12565 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12566 getConstant(FoundRHSLimit));
12567}
12568
12569bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12570 const SCEV *RHS, const SCEV *FoundLHS,
12571 const SCEV *FoundRHS, unsigned Depth) {
12572 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12573
12574 llvm::scope_exit ClearOnExit([&]() {
12575 if (LPhi) {
12576 bool Erased = PendingMerges.erase(LPhi);
12577 assert(Erased && "Failed to erase LPhi!");
12578 (void)Erased;
12579 }
12580 if (RPhi) {
12581 bool Erased = PendingMerges.erase(RPhi);
12582 assert(Erased && "Failed to erase RPhi!");
12583 (void)Erased;
12584 }
12585 });
12586
12587 // Find respective Phis and check that they are not being pending.
12588 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12589 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12590 if (!PendingMerges.insert(Phi).second)
12591 return false;
12592 LPhi = Phi;
12593 }
12594 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12595 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12596 // If we detect a loop of Phi nodes being processed by this method, for
12597 // example:
12598 //
12599 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12600 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12601 //
12602 // we don't want to deal with a case that complex, so return conservative
12603 // answer false.
12604 if (!PendingMerges.insert(Phi).second)
12605 return false;
12606 RPhi = Phi;
12607 }
12608
12609 // If none of LHS, RHS is a Phi, nothing to do here.
12610 if (!LPhi && !RPhi)
12611 return false;
12612
12613 // If there is a SCEVUnknown Phi we are interested in, make it left.
12614 if (!LPhi) {
12615 std::swap(LHS, RHS);
12616 std::swap(FoundLHS, FoundRHS);
12617 std::swap(LPhi, RPhi);
12619 }
12620
12621 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12622 const BasicBlock *LBB = LPhi->getParent();
12623 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12624
12625 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12626 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12627 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12628 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12629 };
12630
12631 if (RPhi && RPhi->getParent() == LBB) {
12632 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12633 // If we compare two Phis from the same block, and for each entry block
12634 // the predicate is true for incoming values from this block, then the
12635 // predicate is also true for the Phis.
12636 for (const BasicBlock *IncBB : predecessors(LBB)) {
12637 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12638 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12639 if (!ProvedEasily(L, R))
12640 return false;
12641 }
12642 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12643 // Case two: RHS is also a Phi from the same basic block, and it is an
12644 // AddRec. It means that there is a loop which has both AddRec and Unknown
12645 // PHIs, for it we can compare incoming values of AddRec from above the loop
12646 // and latch with their respective incoming values of LPhi.
12647 // TODO: Generalize to handle loops with many inputs in a header.
12648 if (LPhi->getNumIncomingValues() != 2) return false;
12649
12650 auto *RLoop = RAR->getLoop();
12651 auto *Predecessor = RLoop->getLoopPredecessor();
12652 assert(Predecessor && "Loop with AddRec with no predecessor?");
12653 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12654 if (!ProvedEasily(L1, RAR->getStart()))
12655 return false;
12656 auto *Latch = RLoop->getLoopLatch();
12657 assert(Latch && "Loop with AddRec with no latch?");
12658 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12659 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12660 return false;
12661 } else {
12662 // In all other cases go over inputs of LHS and compare each of them to RHS,
12663 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12664 // At this point RHS is either a non-Phi, or it is a Phi from some block
12665 // different from LBB.
12666 for (const BasicBlock *IncBB : predecessors(LBB)) {
12667 // Check that RHS is available in this block.
12668 if (!dominates(RHS, IncBB))
12669 return false;
12670 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12671 // Make sure L does not refer to a value from a potentially previous
12672 // iteration of a loop.
12673 if (!properlyDominates(L, LBB))
12674 return false;
12675 // Addrecs are considered to properly dominate their loop, so are missed
12676 // by the previous check. Discard any values that have computable
12677 // evolution in this loop.
12678 if (auto *Loop = LI.getLoopFor(LBB))
12679 if (hasComputableLoopEvolution(L, Loop))
12680 return false;
12681 if (!ProvedEasily(L, RHS))
12682 return false;
12683 }
12684 }
12685 return true;
12686}
12687
12688bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12689 const SCEV *LHS,
12690 const SCEV *RHS,
12691 const SCEV *FoundLHS,
12692 const SCEV *FoundRHS) {
12693 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12694 // sure that we are dealing with same LHS.
12695 if (RHS == FoundRHS) {
12696 std::swap(LHS, RHS);
12697 std::swap(FoundLHS, FoundRHS);
12699 }
12700 if (LHS != FoundLHS)
12701 return false;
12702
12703 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12704 if (!SUFoundRHS)
12705 return false;
12706
12707 Value *Shiftee, *ShiftValue;
12708
12709 using namespace PatternMatch;
12710 if (match(SUFoundRHS->getValue(),
12711 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12712 auto *ShifteeS = getSCEV(Shiftee);
12713 // Prove one of the following:
12714 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12715 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12716 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12717 // ---> LHS <s RHS
12718 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12719 // ---> LHS <=s RHS
12720 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12721 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12722 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12723 if (isKnownNonNegative(ShifteeS))
12724 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12725 }
12726
12727 return false;
12728}
12729
12730bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12731 const SCEV *RHS,
12732 const SCEV *FoundLHS,
12733 const SCEV *FoundRHS,
12734 const Instruction *CtxI) {
12735 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12736 FoundRHS) ||
12737 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12738 FoundRHS) ||
12739 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12740 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12741 CtxI) ||
12742 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12743}
12744
12745/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12746template <typename MinMaxExprType>
12747static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12748 const SCEV *Candidate) {
12749 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12750 if (!MinMaxExpr)
12751 return false;
12752
12753 return is_contained(MinMaxExpr->operands(), Candidate);
12754}
12755
12757 CmpPredicate Pred, const SCEV *LHS,
12758 const SCEV *RHS) {
12759 // If both sides are affine addrecs for the same loop, with equal
12760 // steps, and we know the recurrences don't wrap, then we only
12761 // need to check the predicate on the starting values.
12762
12763 if (!ICmpInst::isRelational(Pred))
12764 return false;
12765
12766 const SCEV *LStart, *RStart, *Step;
12767 const Loop *L;
12768 if (!match(LHS,
12769 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12771 m_SpecificLoop(L))))
12772 return false;
12777 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12778 return false;
12779
12780 return SE.isKnownPredicate(Pred, LStart, RStart);
12781}
12782
12783/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12784/// expression?
12786 const SCEV *LHS, const SCEV *RHS) {
12787 switch (Pred) {
12788 default:
12789 return false;
12790
12791 case ICmpInst::ICMP_SGE:
12792 std::swap(LHS, RHS);
12793 [[fallthrough]];
12794 case ICmpInst::ICMP_SLE:
12795 return
12796 // min(A, ...) <= A
12798 // A <= max(A, ...)
12800
12801 case ICmpInst::ICMP_UGE:
12802 std::swap(LHS, RHS);
12803 [[fallthrough]];
12804 case ICmpInst::ICMP_ULE:
12805 return
12806 // min(A, ...) <= A
12807 // FIXME: what about umin_seq?
12809 // A <= max(A, ...)
12811 }
12812
12813 llvm_unreachable("covered switch fell through?!");
12814}
12815
12816bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12817 const SCEV *RHS,
12818 const SCEV *FoundLHS,
12819 const SCEV *FoundRHS,
12820 unsigned Depth) {
12823 "LHS and RHS have different sizes?");
12824 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12825 getTypeSizeInBits(FoundRHS->getType()) &&
12826 "FoundLHS and FoundRHS have different sizes?");
12827 // We want to avoid hurting the compile time with analysis of too big trees.
12829 return false;
12830
12831 // We only want to work with GT comparison so far.
12832 if (ICmpInst::isLT(Pred)) {
12834 std::swap(LHS, RHS);
12835 std::swap(FoundLHS, FoundRHS);
12836 }
12837
12839
12840 // For unsigned, try to reduce it to corresponding signed comparison.
12841 if (P == ICmpInst::ICMP_UGT)
12842 // We can replace unsigned predicate with its signed counterpart if all
12843 // involved values are non-negative.
12844 // TODO: We could have better support for unsigned.
12845 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12846 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12847 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12848 // use this fact to prove that LHS and RHS are non-negative.
12849 const SCEV *MinusOne = getMinusOne(LHS->getType());
12850 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12851 FoundRHS) &&
12852 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12853 FoundRHS))
12855 }
12856
12857 if (P != ICmpInst::ICMP_SGT)
12858 return false;
12859
12860 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12861 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12862 return Ext->getOperand();
12863 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12864 // the constant in some cases.
12865 return S;
12866 };
12867
12868 // Acquire values from extensions.
12869 auto *OrigLHS = LHS;
12870 auto *OrigFoundLHS = FoundLHS;
12871 LHS = GetOpFromSExt(LHS);
12872 FoundLHS = GetOpFromSExt(FoundLHS);
12873
12874 // Is the SGT predicate can be proved trivially or using the found context.
12875 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12876 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12877 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12878 FoundRHS, Depth + 1);
12879 };
12880
12881 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12882 // We want to avoid creation of any new non-constant SCEV. Since we are
12883 // going to compare the operands to RHS, we should be certain that we don't
12884 // need any size extensions for this. So let's decline all cases when the
12885 // sizes of types of LHS and RHS do not match.
12886 // TODO: Maybe try to get RHS from sext to catch more cases?
12888 return false;
12889
12890 // Should not overflow.
12891 if (!LHSAddExpr->hasNoSignedWrap())
12892 return false;
12893
12894 SCEVUse LL = LHSAddExpr->getOperand(0);
12895 SCEVUse LR = LHSAddExpr->getOperand(1);
12896 auto *MinusOne = getMinusOne(RHS->getType());
12897
12898 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12899 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12900 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12901 };
12902 // Try to prove the following rule:
12903 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12904 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12905 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12906 return true;
12907 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12908 Value *LL, *LR;
12909 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12910
12911 using namespace llvm::PatternMatch;
12912
12913 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12914 // Rules for division.
12915 // We are going to perform some comparisons with Denominator and its
12916 // derivative expressions. In general case, creating a SCEV for it may
12917 // lead to a complex analysis of the entire graph, and in particular it
12918 // can request trip count recalculation for the same loop. This would
12919 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12920 // this, we only want to create SCEVs that are constants in this section.
12921 // So we bail if Denominator is not a constant.
12922 if (!isa<ConstantInt>(LR))
12923 return false;
12924
12925 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12926
12927 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12928 // then a SCEV for the numerator already exists and matches with FoundLHS.
12929 auto *Numerator = getExistingSCEV(LL);
12930 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12931 return false;
12932
12933 // Make sure that the numerator matches with FoundLHS and the denominator
12934 // is positive.
12935 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12936 return false;
12937
12938 auto *DTy = Denominator->getType();
12939 auto *FRHSTy = FoundRHS->getType();
12940 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12941 // One of types is a pointer and another one is not. We cannot extend
12942 // them properly to a wider type, so let us just reject this case.
12943 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12944 // to avoid this check.
12945 return false;
12946
12947 // Given that:
12948 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12949 auto *WTy = getWiderType(DTy, FRHSTy);
12950 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12951 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12952
12953 // Try to prove the following rule:
12954 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12955 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12956 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12957 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12958 if (isKnownNonPositive(RHS) &&
12959 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12960 return true;
12961
12962 // Try to prove the following rule:
12963 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12964 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12965 // If we divide it by Denominator > 2, then:
12966 // 1. If FoundLHS is negative, then the result is 0.
12967 // 2. If FoundLHS is non-negative, then the result is non-negative.
12968 // Anyways, the result is non-negative.
12969 auto *MinusOne = getMinusOne(WTy);
12970 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12971 if (isKnownNegative(RHS) &&
12972 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12973 return true;
12974 }
12975 }
12976
12977 // If our expression contained SCEVUnknown Phis, and we split it down and now
12978 // need to prove something for them, try to prove the predicate for every
12979 // possible incoming values of those Phis.
12980 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12981 return true;
12982
12983 return false;
12984}
12985
12987 const SCEV *RHS) {
12988 // zext x u<= sext x, sext x s<= zext x
12989 const SCEV *Op;
12990 switch (Pred) {
12991 case ICmpInst::ICMP_SGE:
12992 std::swap(LHS, RHS);
12993 [[fallthrough]];
12994 case ICmpInst::ICMP_SLE: {
12995 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12996 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12998 }
12999 case ICmpInst::ICMP_UGE:
13000 std::swap(LHS, RHS);
13001 [[fallthrough]];
13002 case ICmpInst::ICMP_ULE: {
13003 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13004 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13006 }
13007 default:
13008 return false;
13009 };
13010 llvm_unreachable("unhandled case");
13011}
13012
13013bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13014 SCEVUse LHS,
13015 SCEVUse RHS) {
13016 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13017 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13018 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13019 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13020 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13021}
13022
13023bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13024 const SCEV *LHS,
13025 const SCEV *RHS,
13026 const SCEV *FoundLHS,
13027 const SCEV *FoundRHS) {
13028 switch (Pred) {
13029 default:
13030 llvm_unreachable("Unexpected CmpPredicate value!");
13031 case ICmpInst::ICMP_EQ:
13032 case ICmpInst::ICMP_NE:
13033 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13034 return true;
13035 break;
13036 case ICmpInst::ICMP_SLT:
13037 case ICmpInst::ICMP_SLE:
13038 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13039 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13040 return true;
13041 break;
13042 case ICmpInst::ICMP_SGT:
13043 case ICmpInst::ICMP_SGE:
13044 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13045 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13046 return true;
13047 break;
13048 case ICmpInst::ICMP_ULT:
13049 case ICmpInst::ICMP_ULE:
13050 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13051 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13052 return true;
13053 break;
13054 case ICmpInst::ICMP_UGT:
13055 case ICmpInst::ICMP_UGE:
13056 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13057 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13058 return true;
13059 break;
13060 }
13061
13062 // Maybe it can be proved via operations?
13063 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13064 return true;
13065
13066 return false;
13067}
13068
13069bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13070 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13071 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13072 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13073 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13074 // reduce the compile time impact of this optimization.
13075 return false;
13076
13077 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13078 if (!Addend)
13079 return false;
13080
13081 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13082
13083 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13084 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13085 ConstantRange FoundLHSRange =
13086 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13087
13088 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13089 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13090
13091 // We can also compute the range of values for `LHS` that satisfy the
13092 // consequent, "`LHS` `Pred` `RHS`":
13093 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13094 // The antecedent implies the consequent if every value of `LHS` that
13095 // satisfies the antecedent also satisfies the consequent.
13096 return LHSRange.icmp(Pred, ConstRHS);
13097}
13098
13099bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13100 bool IsSigned) {
13101 assert(isKnownPositive(Stride) && "Positive stride expected!");
13102
13103 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13104 const SCEV *One = getOne(Stride->getType());
13105
13106 if (IsSigned) {
13107 APInt MaxRHS = getSignedRangeMax(RHS);
13108 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13109 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13110
13111 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13112 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13113 }
13114
13115 APInt MaxRHS = getUnsignedRangeMax(RHS);
13116 APInt MaxValue = APInt::getMaxValue(BitWidth);
13117 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13118
13119 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13120 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13121}
13122
13123bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13124 bool IsSigned) {
13125
13126 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13127 const SCEV *One = getOne(Stride->getType());
13128
13129 if (IsSigned) {
13130 APInt MinRHS = getSignedRangeMin(RHS);
13131 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13132 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13133
13134 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13135 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13136 }
13137
13138 APInt MinRHS = getUnsignedRangeMin(RHS);
13139 APInt MinValue = APInt::getMinValue(BitWidth);
13140 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13141
13142 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13143 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13144}
13145
13147 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13148 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13149 // expression fixes the case of N=0.
13150 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13151 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13152 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13153}
13154
13155const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13156 const SCEV *Stride,
13157 const SCEV *End,
13158 unsigned BitWidth,
13159 bool IsSigned) {
13160 // The logic in this function assumes we can represent a positive stride.
13161 // If we can't, the backedge-taken count must be zero.
13162 if (IsSigned && BitWidth == 1)
13163 return getZero(Stride->getType());
13164
13165 // This code below only been closely audited for negative strides in the
13166 // unsigned comparison case, it may be correct for signed comparison, but
13167 // that needs to be established.
13168 if (IsSigned && isKnownNegative(Stride))
13169 return getCouldNotCompute();
13170
13171 // Calculate the maximum backedge count based on the range of values
13172 // permitted by Start, End, and Stride.
13173 APInt MinStart =
13174 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13175
13176 APInt MinStride =
13177 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13178
13179 // We assume either the stride is positive, or the backedge-taken count
13180 // is zero. So force StrideForMaxBECount to be at least one.
13181 APInt One(BitWidth, 1);
13182 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13183 : APIntOps::umax(One, MinStride);
13184
13185 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13186 : APInt::getMaxValue(BitWidth);
13187 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13188
13189 // Although End can be a MAX expression we estimate MaxEnd considering only
13190 // the case End = RHS of the loop termination condition. This is safe because
13191 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13192 // taken count.
13193 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13194 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13195
13196 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13197 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13198 : APIntOps::umax(MaxEnd, MinStart);
13199
13200 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13201 getConstant(StrideForMaxBECount) /* Step */);
13202}
13203
13205ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13206 const Loop *L, bool IsSigned,
13207 bool ControlsOnlyExit, bool AllowPredicates) {
13209
13211 bool PredicatedIV = false;
13212 if (!IV) {
13213 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13214 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13215 if (AR && AR->getLoop() == L && AR->isAffine()) {
13216 auto canProveNUW = [&]() {
13217 // We can use the comparison to infer no-wrap flags only if it fully
13218 // controls the loop exit.
13219 if (!ControlsOnlyExit)
13220 return false;
13221
13222 if (!isLoopInvariant(RHS, L))
13223 return false;
13224
13225 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13226 // We need the sequence defined by AR to strictly increase in the
13227 // unsigned integer domain for the logic below to hold.
13228 return false;
13229
13230 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13231 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13232 // If RHS <=u Limit, then there must exist a value V in the sequence
13233 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13234 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13235 // overflow occurs. This limit also implies that a signed comparison
13236 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13237 // the high bits on both sides must be zero.
13238 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13239 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13240 Limit = Limit.zext(OuterBitWidth);
13241 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13242 };
13243 auto Flags = AR->getNoWrapFlags();
13244 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13245 Flags = setFlags(Flags, SCEV::FlagNUW);
13246
13247 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13248 if (AR->hasNoUnsignedWrap()) {
13249 // Emulate what getZeroExtendExpr would have done during construction
13250 // if we'd been able to infer the fact just above at that time.
13251 const SCEV *Step = AR->getStepRecurrence(*this);
13252 Type *Ty = ZExt->getType();
13253 auto *S = getAddRecExpr(
13255 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13257 }
13258 }
13259 }
13260 }
13261
13262
13263 if (!IV && AllowPredicates) {
13264 // Try to make this an AddRec using runtime tests, in the first X
13265 // iterations of this loop, where X is the SCEV expression found by the
13266 // algorithm below.
13267 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13268 PredicatedIV = true;
13269 }
13270
13271 // Avoid weird loops
13272 if (!IV || IV->getLoop() != L || !IV->isAffine())
13273 return getCouldNotCompute();
13274
13275 // A precondition of this method is that the condition being analyzed
13276 // reaches an exiting branch which dominates the latch. Given that, we can
13277 // assume that an increment which violates the nowrap specification and
13278 // produces poison must cause undefined behavior when the resulting poison
13279 // value is branched upon and thus we can conclude that the backedge is
13280 // taken no more often than would be required to produce that poison value.
13281 // Note that a well defined loop can exit on the iteration which violates
13282 // the nowrap specification if there is another exit (either explicit or
13283 // implicit/exceptional) which causes the loop to execute before the
13284 // exiting instruction we're analyzing would trigger UB.
13285 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13286 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13288
13289 const SCEV *Stride = IV->getStepRecurrence(*this);
13290
13291 bool PositiveStride = isKnownPositive(Stride);
13292
13293 // Avoid negative or zero stride values.
13294 if (!PositiveStride) {
13295 // We can compute the correct backedge taken count for loops with unknown
13296 // strides if we can prove that the loop is not an infinite loop with side
13297 // effects. Here's the loop structure we are trying to handle -
13298 //
13299 // i = start
13300 // do {
13301 // A[i] = i;
13302 // i += s;
13303 // } while (i < end);
13304 //
13305 // The backedge taken count for such loops is evaluated as -
13306 // (max(end, start + stride) - start - 1) /u stride
13307 //
13308 // The additional preconditions that we need to check to prove correctness
13309 // of the above formula is as follows -
13310 //
13311 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13312 // NoWrap flag).
13313 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13314 // no side effects within the loop)
13315 // c) loop has a single static exit (with no abnormal exits)
13316 //
13317 // Precondition a) implies that if the stride is negative, this is a single
13318 // trip loop. The backedge taken count formula reduces to zero in this case.
13319 //
13320 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13321 // then a zero stride means the backedge can't be taken without executing
13322 // undefined behavior.
13323 //
13324 // The positive stride case is the same as isKnownPositive(Stride) returning
13325 // true (original behavior of the function).
13326 //
13327 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13329 return getCouldNotCompute();
13330
13331 if (!isKnownNonZero(Stride)) {
13332 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13333 // if it might eventually be greater than start and if so, on which
13334 // iteration. We can't even produce a useful upper bound.
13335 if (!isLoopInvariant(RHS, L))
13336 return getCouldNotCompute();
13337
13338 // We allow a potentially zero stride, but we need to divide by stride
13339 // below. Since the loop can't be infinite and this check must control
13340 // the sole exit, we can infer the exit must be taken on the first
13341 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13342 // we know the numerator in the divides below must be zero, so we can
13343 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13344 // and produce the right result.
13345 // FIXME: Handle the case where Stride is poison?
13346 auto wouldZeroStrideBeUB = [&]() {
13347 // Proof by contradiction. Suppose the stride were zero. If we can
13348 // prove that the backedge *is* taken on the first iteration, then since
13349 // we know this condition controls the sole exit, we must have an
13350 // infinite loop. We can't have a (well defined) infinite loop per
13351 // check just above.
13352 // Note: The (Start - Stride) term is used to get the start' term from
13353 // (start' + stride,+,stride). Remember that we only care about the
13354 // result of this expression when stride == 0 at runtime.
13355 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13356 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13357 };
13358 if (!wouldZeroStrideBeUB()) {
13359 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13360 }
13361 }
13362 } else if (!NoWrap) {
13363 // Avoid proven overflow cases: this will ensure that the backedge taken
13364 // count will not generate any unsigned overflow.
13365 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13366 return getCouldNotCompute();
13367 }
13368
13369 // On all paths just preceeding, we established the following invariant:
13370 // IV can be assumed not to overflow up to and including the exiting
13371 // iteration. We proved this in one of two ways:
13372 // 1) We can show overflow doesn't occur before the exiting iteration
13373 // 1a) canIVOverflowOnLT, and b) step of one
13374 // 2) We can show that if overflow occurs, the loop must execute UB
13375 // before any possible exit.
13376 // Note that we have not yet proved RHS invariant (in general).
13377
13378 const SCEV *Start = IV->getStart();
13379
13380 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13381 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13382 // Use integer-typed versions for actual computation; we can't subtract
13383 // pointers in general.
13384 const SCEV *OrigStart = Start;
13385 const SCEV *OrigRHS = RHS;
13386 if (Start->getType()->isPointerTy()) {
13388 if (isa<SCEVCouldNotCompute>(Start))
13389 return Start;
13390 }
13391 if (RHS->getType()->isPointerTy()) {
13394 return RHS;
13395 }
13396
13397 const SCEV *End = nullptr, *BECount = nullptr,
13398 *BECountIfBackedgeTaken = nullptr;
13399 if (!isLoopInvariant(RHS, L)) {
13400 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13401 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13402 RHSAddRec->getNoWrapFlags()) {
13403 // The structure of loop we are trying to calculate backedge count of:
13404 //
13405 // left = left_start
13406 // right = right_start
13407 //
13408 // while(left < right){
13409 // ... do something here ...
13410 // left += s1; // stride of left is s1 (s1 > 0)
13411 // right += s2; // stride of right is s2 (s2 < 0)
13412 // }
13413 //
13414
13415 const SCEV *RHSStart = RHSAddRec->getStart();
13416 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13417
13418 // If Stride - RHSStride is positive and does not overflow, we can write
13419 // backedge count as ->
13420 // ceil((End - Start) /u (Stride - RHSStride))
13421 // Where, End = max(RHSStart, Start)
13422
13423 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13424 if (isKnownNegative(RHSStride) &&
13425 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13426 RHSStride)) {
13427
13428 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13429 if (isKnownPositive(Denominator)) {
13430 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13431 : getUMaxExpr(RHSStart, Start);
13432
13433 // We can do this because End >= Start, as End = max(RHSStart, Start)
13434 const SCEV *Delta = getMinusSCEV(End, Start);
13435
13436 BECount = getUDivCeilSCEV(Delta, Denominator);
13437 BECountIfBackedgeTaken =
13438 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13439 }
13440 }
13441 }
13442 if (BECount == nullptr) {
13443 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13444 // given the start, stride and max value for the end bound of the
13445 // loop (RHS), and the fact that IV does not overflow (which is
13446 // checked above).
13447 const SCEV *MaxBECount = computeMaxBECountForLT(
13448 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13449 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13450 MaxBECount, false /*MaxOrZero*/, Predicates);
13451 }
13452 } else {
13453 // We use the expression (max(End,Start)-Start)/Stride to describe the
13454 // backedge count, as if the backedge is taken at least once
13455 // max(End,Start) is End and so the result is as above, and if not
13456 // max(End,Start) is Start so we get a backedge count of zero.
13457 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13458 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13459 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13460 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13461 // Can we prove (max(RHS,Start) > Start - Stride?
13462 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13463 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13464 // In this case, we can use a refined formula for computing backedge
13465 // taken count. The general formula remains:
13466 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13467 // We want to use the alternate formula:
13468 // "((End - 1) - (Start - Stride)) /u Stride"
13469 // Let's do a quick case analysis to show these are equivalent under
13470 // our precondition that max(RHS,Start) > Start - Stride.
13471 // * For RHS <= Start, the backedge-taken count must be zero.
13472 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13473 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13474 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13475 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13476 // reducing this to the stride of 1 case.
13477 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13478 // Stride".
13479 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13480 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13481 // "((RHS - (Start - Stride) - 1) /u Stride".
13482 // Our preconditions trivially imply no overflow in that form.
13483 const SCEV *MinusOne = getMinusOne(Stride->getType());
13484 const SCEV *Numerator =
13485 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13486 BECount = getUDivExpr(Numerator, Stride);
13487 }
13488
13489 if (!BECount) {
13490 auto canProveRHSGreaterThanEqualStart = [&]() {
13491 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13492 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13493 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13494
13495 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13496 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13497 return true;
13498
13499 // (RHS > Start - 1) implies RHS >= Start.
13500 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13501 // "Start - 1" doesn't overflow.
13502 // * For signed comparison, if Start - 1 does overflow, it's equal
13503 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13504 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13505 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13506 //
13507 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13508 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13509 auto *StartMinusOne =
13510 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13511 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13512 };
13513
13514 // If we know that RHS >= Start in the context of loop, then we know
13515 // that max(RHS, Start) = RHS at this point.
13516 if (canProveRHSGreaterThanEqualStart()) {
13517 End = RHS;
13518 } else {
13519 // If RHS < Start, the backedge will be taken zero times. So in
13520 // general, we can write the backedge-taken count as:
13521 //
13522 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13523 //
13524 // We convert it to the following to make it more convenient for SCEV:
13525 //
13526 // ceil(max(RHS, Start) - Start) / Stride
13527 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13528
13529 // See what would happen if we assume the backedge is taken. This is
13530 // used to compute MaxBECount.
13531 BECountIfBackedgeTaken =
13532 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13533 }
13534
13535 // At this point, we know:
13536 //
13537 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13538 // 2. The index variable doesn't overflow.
13539 //
13540 // Therefore, we know N exists such that
13541 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13542 // doesn't overflow.
13543 //
13544 // Using this information, try to prove whether the addition in
13545 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13546 const SCEV *One = getOne(Stride->getType());
13547 bool MayAddOverflow = [&] {
13548 if (isKnownToBeAPowerOfTwo(Stride)) {
13549 // Suppose Stride is a power of two, and Start/End are unsigned
13550 // integers. Let UMAX be the largest representable unsigned
13551 // integer.
13552 //
13553 // By the preconditions of this function, we know
13554 // "(Start + Stride * N) >= End", and this doesn't overflow.
13555 // As a formula:
13556 //
13557 // End <= (Start + Stride * N) <= UMAX
13558 //
13559 // Subtracting Start from all the terms:
13560 //
13561 // End - Start <= Stride * N <= UMAX - Start
13562 //
13563 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13564 //
13565 // End - Start <= Stride * N <= UMAX
13566 //
13567 // Stride * N is a multiple of Stride. Therefore,
13568 //
13569 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13570 //
13571 // Since Stride is a power of two, UMAX + 1 is divisible by
13572 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13573 // write:
13574 //
13575 // End - Start <= Stride * N <= UMAX - Stride - 1
13576 //
13577 // Dropping the middle term:
13578 //
13579 // End - Start <= UMAX - Stride - 1
13580 //
13581 // Adding Stride - 1 to both sides:
13582 //
13583 // (End - Start) + (Stride - 1) <= UMAX
13584 //
13585 // In other words, the addition doesn't have unsigned overflow.
13586 //
13587 // A similar proof works if we treat Start/End as signed values.
13588 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13589 // to use signed max instead of unsigned max. Note that we're
13590 // trying to prove a lack of unsigned overflow in either case.
13591 return false;
13592 }
13593 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13594 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13595 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13596 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13597 // 1 <s End.
13598 //
13599 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13600 // End.
13601 return false;
13602 }
13603 return true;
13604 }();
13605
13606 const SCEV *Delta = getMinusSCEV(End, Start);
13607 if (!MayAddOverflow) {
13608 // floor((D + (S - 1)) / S)
13609 // We prefer this formulation if it's legal because it's fewer
13610 // operations.
13611 BECount =
13612 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13613 } else {
13614 BECount = getUDivCeilSCEV(Delta, Stride);
13615 }
13616 }
13617 }
13618
13619 const SCEV *ConstantMaxBECount;
13620 bool MaxOrZero = false;
13621 if (isa<SCEVConstant>(BECount)) {
13622 ConstantMaxBECount = BECount;
13623 } else if (BECountIfBackedgeTaken &&
13624 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13625 // If we know exactly how many times the backedge will be taken if it's
13626 // taken at least once, then the backedge count will either be that or
13627 // zero.
13628 ConstantMaxBECount = BECountIfBackedgeTaken;
13629 MaxOrZero = true;
13630 } else {
13631 ConstantMaxBECount = computeMaxBECountForLT(
13632 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13633 }
13634
13635 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13636 !isa<SCEVCouldNotCompute>(BECount))
13637 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13638
13639 const SCEV *SymbolicMaxBECount =
13640 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13641 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13642 Predicates);
13643}
13644
13645ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13646 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13647 bool ControlsOnlyExit, bool AllowPredicates) {
13649 // We handle only IV > Invariant
13650 if (!isLoopInvariant(RHS, L))
13651 return getCouldNotCompute();
13652
13653 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13654 if (!IV && AllowPredicates)
13655 // Try to make this an AddRec using runtime tests, in the first X
13656 // iterations of this loop, where X is the SCEV expression found by the
13657 // algorithm below.
13658 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13659
13660 // Avoid weird loops
13661 if (!IV || IV->getLoop() != L || !IV->isAffine())
13662 return getCouldNotCompute();
13663
13664 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13665 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13667
13668 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13669
13670 // Avoid negative or zero stride values
13671 if (!isKnownPositive(Stride))
13672 return getCouldNotCompute();
13673
13674 // Avoid proven overflow cases: this will ensure that the backedge taken count
13675 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13676 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13677 // behaviors like the case of C language.
13678 if (!Stride->isOne() && !NoWrap)
13679 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13680 return getCouldNotCompute();
13681
13682 const SCEV *Start = IV->getStart();
13683 const SCEV *End = RHS;
13684 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13685 // If we know that Start >= RHS in the context of loop, then we know that
13686 // min(RHS, Start) = RHS at this point.
13688 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13689 End = RHS;
13690 else
13691 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13692 }
13693
13694 if (Start->getType()->isPointerTy()) {
13696 if (isa<SCEVCouldNotCompute>(Start))
13697 return Start;
13698 }
13699 if (End->getType()->isPointerTy()) {
13700 End = getLosslessPtrToIntExpr(End);
13701 if (isa<SCEVCouldNotCompute>(End))
13702 return End;
13703 }
13704
13705 // Compute ((Start - End) + (Stride - 1)) / Stride.
13706 // FIXME: This can overflow. Holding off on fixing this for now;
13707 // howManyGreaterThans will hopefully be gone soon.
13708 const SCEV *One = getOne(Stride->getType());
13709 const SCEV *BECount = getUDivExpr(
13710 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13711
13712 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13714
13715 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13716 : getUnsignedRangeMin(Stride);
13717
13718 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13719 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13720 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13721
13722 // Although End can be a MIN expression we estimate MinEnd considering only
13723 // the case End = RHS. This is safe because in the other case (Start - End)
13724 // is zero, leading to a zero maximum backedge taken count.
13725 APInt MinEnd =
13726 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13727 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13728
13729 const SCEV *ConstantMaxBECount =
13730 isa<SCEVConstant>(BECount)
13731 ? BECount
13732 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13733 getConstant(MinStride));
13734
13735 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13736 ConstantMaxBECount = BECount;
13737 const SCEV *SymbolicMaxBECount =
13738 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13739
13740 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13741 Predicates);
13742}
13743
13745 ScalarEvolution &SE) const {
13746 if (Range.isFullSet()) // Infinite loop.
13747 return SE.getCouldNotCompute();
13748
13749 // If the start is a non-zero constant, shift the range to simplify things.
13750 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13751 if (!SC->getValue()->isZero()) {
13753 Operands[0] = SE.getZero(SC->getType());
13754 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13756 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13757 return ShiftedAddRec->getNumIterationsInRange(
13758 Range.subtract(SC->getAPInt()), SE);
13759 // This is strange and shouldn't happen.
13760 return SE.getCouldNotCompute();
13761 }
13762
13763 // The only time we can solve this is when we have all constant indices.
13764 // Otherwise, we cannot determine the overflow conditions.
13765 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13766 return SE.getCouldNotCompute();
13767
13768 // Okay at this point we know that all elements of the chrec are constants and
13769 // that the start element is zero.
13770
13771 // First check to see if the range contains zero. If not, the first
13772 // iteration exits.
13773 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13774 if (!Range.contains(APInt(BitWidth, 0)))
13775 return SE.getZero(getType());
13776
13777 if (isAffine()) {
13778 // If this is an affine expression then we have this situation:
13779 // Solve {0,+,A} in Range === Ax in Range
13780
13781 // We know that zero is in the range. If A is positive then we know that
13782 // the upper value of the range must be the first possible exit value.
13783 // If A is negative then the lower of the range is the last possible loop
13784 // value. Also note that we already checked for a full range.
13785 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13786 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13787
13788 // The exit value should be (End+A)/A.
13789 APInt ExitVal = (End + A).udiv(A);
13790 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13791
13792 // Evaluate at the exit value. If we really did fall out of the valid
13793 // range, then we computed our trip count, otherwise wrap around or other
13794 // things must have happened.
13795 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13796 if (Range.contains(Val->getValue()))
13797 return SE.getCouldNotCompute(); // Something strange happened
13798
13799 // Ensure that the previous value is in the range.
13800 assert(Range.contains(
13802 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13803 "Linear scev computation is off in a bad way!");
13804 return SE.getConstant(ExitValue);
13805 }
13806
13807 if (isQuadratic()) {
13808 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13809 return SE.getConstant(*S);
13810 }
13811
13812 return SE.getCouldNotCompute();
13813}
13814
13815const SCEVAddRecExpr *
13817 assert(getNumOperands() > 1 && "AddRec with zero step?");
13818 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13819 // but in this case we cannot guarantee that the value returned will be an
13820 // AddRec because SCEV does not have a fixed point where it stops
13821 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13822 // may happen if we reach arithmetic depth limit while simplifying. So we
13823 // construct the returned value explicitly.
13825 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13826 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13827 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13828 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13829 // We know that the last operand is not a constant zero (otherwise it would
13830 // have been popped out earlier). This guarantees us that if the result has
13831 // the same last operand, then it will also not be popped out, meaning that
13832 // the returned value will be an AddRec.
13833 const SCEV *Last = getOperand(getNumOperands() - 1);
13834 assert(!Last->isZero() && "Recurrency with zero step?");
13835 Ops.push_back(Last);
13838}
13839
13840// Return true when S contains at least an undef value.
13842 return SCEVExprContains(
13843 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13844}
13845
13846// Return true when S contains a value that is a nullptr.
13848 return SCEVExprContains(S, [](const SCEV *S) {
13849 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13850 return SU->getValue() == nullptr;
13851 return false;
13852 });
13853}
13854
13855/// Return the size of an element read or written by Inst.
13857 Type *Ty;
13858 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13859 Ty = Store->getValueOperand()->getType();
13860 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13861 Ty = Load->getType();
13862 else
13863 return nullptr;
13864
13866 return getSizeOfExpr(ETy, Ty);
13867}
13868
13869//===----------------------------------------------------------------------===//
13870// SCEVCallbackVH Class Implementation
13871//===----------------------------------------------------------------------===//
13872
13874 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13875 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13876 SE->ConstantEvolutionLoopExitValue.erase(PN);
13877 SE->eraseValueFromMap(getValPtr());
13878 // this now dangles!
13879}
13880
13881void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13882 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13883
13884 // Forget all the expressions associated with users of the old value,
13885 // so that future queries will recompute the expressions using the new
13886 // value.
13887 SE->forgetValue(getValPtr());
13888 // this now dangles!
13889}
13890
/// Build a callback handle for \p V; the ScalarEvolution pointer is kept so
/// the deleted() and allUsesReplacedWith() callbacks above can purge its
/// caches.
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}
13893
13894//===----------------------------------------------------------------------===//
13895// ScalarEvolution Class Implementation
13896//===----------------------------------------------------------------------===//
13897
13900 LoopInfo &LI)
13901 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13902 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13903 LoopDispositions(64), BlockDispositions(64) {
13904 // To use guards for proving predicates, we need to scan every instruction in
13905 // relevant basic blocks, and not just terminators. Doing this is a waste of
13906 // time if the IR does not actually contain any calls to
13907 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13908 //
13909 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13910 // to _add_ guards to the module when there weren't any before, and wants
13911 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13912 // efficient in lieu of being smart in that rather obscure case.
13913
13914 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13915 F.getParent(), Intrinsic::experimental_guard);
13916 HasGuards = GuardDecl && !GuardDecl->use_empty();
13917}
13918
13920 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13921 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13922 ValueExprMap(std::move(Arg.ValueExprMap)),
13923 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13924 PendingMerges(std::move(Arg.PendingMerges)),
13925 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13926 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13927 PredicatedBackedgeTakenCounts(
13928 std::move(Arg.PredicatedBackedgeTakenCounts)),
13929 BECountUsers(std::move(Arg.BECountUsers)),
13930 ConstantEvolutionLoopExitValue(
13931 std::move(Arg.ConstantEvolutionLoopExitValue)),
13932 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13933 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13934 LoopDispositions(std::move(Arg.LoopDispositions)),
13935 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13936 BlockDispositions(std::move(Arg.BlockDispositions)),
13937 SCEVUsers(std::move(Arg.SCEVUsers)),
13938 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13939 SignedRanges(std::move(Arg.SignedRanges)),
13940 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13941 UniquePreds(std::move(Arg.UniquePreds)),
13942 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13943 LoopUsers(std::move(Arg.LoopUsers)),
13944 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13945 FirstUnknown(Arg.FirstUnknown) {
13946 Arg.FirstUnknown = nullptr;
13947}
13948
  // Iterate through all the SCEVUnknown instances and call their
  // destructors, so that they release their references to their values.
  // The nodes themselves are bump-allocated (see the placement-new uses of
  // SCEVAllocator elsewhere in this file), so no destructor runs for them
  // automatically; we must invoke it explicitly here.
  for (SCEVUnknown *U = FirstUnknown; U;) {
    SCEVUnknown *Tmp = U;
    U = U->Next;
    Tmp->~SCEVUnknown();
  }
  FirstUnknown = nullptr;

  ExprValueMap.clear();
  ValueExprMap.clear();
  HasRecMap.clear();
  BackedgeTakenCounts.clear();
  PredicatedBackedgeTakenCounts.clear();

  // Transient state used while proving implications must be empty between
  // queries; anything left behind indicates a bug in those routines.
  assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
  assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
  assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
  assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
}
13970
13974
13975/// When printing a top-level SCEV for trip counts, it's helpful to include
13976/// a type for constants which are otherwise hard to disambiguate.
13977static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13978 if (isa<SCEVConstant>(S))
13979 OS << *S->getType() << " ";
13980 OS << *S;
13981}
13982
13984 const Loop *L) {
13985 // Print all inner loops first
13986 for (Loop *I : *L)
13987 PrintLoopInfo(OS, SE, I);
13988
13989 OS << "Loop ";
13990 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13991 OS << ": ";
13992
13993 SmallVector<BasicBlock *, 8> ExitingBlocks;
13994 L->getExitingBlocks(ExitingBlocks);
13995 if (ExitingBlocks.size() != 1)
13996 OS << "<multiple exits> ";
13997
13998 auto *BTC = SE->getBackedgeTakenCount(L);
13999 if (!isa<SCEVCouldNotCompute>(BTC)) {
14000 OS << "backedge-taken count is ";
14001 PrintSCEVWithTypeHint(OS, BTC);
14002 } else
14003 OS << "Unpredictable backedge-taken count.";
14004 OS << "\n";
14005
14006 if (ExitingBlocks.size() > 1)
14007 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14008 OS << " exit count for " << ExitingBlock->getName() << ": ";
14009 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14010 PrintSCEVWithTypeHint(OS, EC);
14011 if (isa<SCEVCouldNotCompute>(EC)) {
14012 // Retry with predicates.
14014 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14015 if (!isa<SCEVCouldNotCompute>(EC)) {
14016 OS << "\n predicated exit count for " << ExitingBlock->getName()
14017 << ": ";
14018 PrintSCEVWithTypeHint(OS, EC);
14019 OS << "\n Predicates:\n";
14020 for (const auto *P : Predicates)
14021 P->print(OS, 4);
14022 }
14023 }
14024 OS << "\n";
14025 }
14026
14027 OS << "Loop ";
14028 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14029 OS << ": ";
14030
14031 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14032 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14033 OS << "constant max backedge-taken count is ";
14034 PrintSCEVWithTypeHint(OS, ConstantBTC);
14036 OS << ", actual taken count either this or zero.";
14037 } else {
14038 OS << "Unpredictable constant max backedge-taken count. ";
14039 }
14040
14041 OS << "\n"
14042 "Loop ";
14043 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14044 OS << ": ";
14045
14046 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14047 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14048 OS << "symbolic max backedge-taken count is ";
14049 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14051 OS << ", actual taken count either this or zero.";
14052 } else {
14053 OS << "Unpredictable symbolic max backedge-taken count. ";
14054 }
14055 OS << "\n";
14056
14057 if (ExitingBlocks.size() > 1)
14058 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14059 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14060 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14062 PrintSCEVWithTypeHint(OS, ExitBTC);
14063 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14064 // Retry with predicates.
14066 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14068 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14069 OS << "\n predicated symbolic max exit count for "
14070 << ExitingBlock->getName() << ": ";
14071 PrintSCEVWithTypeHint(OS, ExitBTC);
14072 OS << "\n Predicates:\n";
14073 for (const auto *P : Predicates)
14074 P->print(OS, 4);
14075 }
14076 }
14077 OS << "\n";
14078 }
14079
14081 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14082 if (PBT != BTC) {
14083 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14084 OS << "Loop ";
14085 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14086 OS << ": ";
14087 if (!isa<SCEVCouldNotCompute>(PBT)) {
14088 OS << "Predicated backedge-taken count is ";
14089 PrintSCEVWithTypeHint(OS, PBT);
14090 } else
14091 OS << "Unpredictable predicated backedge-taken count.";
14092 OS << "\n";
14093 OS << " Predicates:\n";
14094 for (const auto *P : Preds)
14095 P->print(OS, 4);
14096 }
14097 Preds.clear();
14098
14099 auto *PredConstantMax =
14101 if (PredConstantMax != ConstantBTC) {
14102 assert(!Preds.empty() &&
14103 "different predicated constant max BTC but no predicates");
14104 OS << "Loop ";
14105 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14106 OS << ": ";
14107 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14108 OS << "Predicated constant max backedge-taken count is ";
14109 PrintSCEVWithTypeHint(OS, PredConstantMax);
14110 } else
14111 OS << "Unpredictable predicated constant max backedge-taken count.";
14112 OS << "\n";
14113 OS << " Predicates:\n";
14114 for (const auto *P : Preds)
14115 P->print(OS, 4);
14116 }
14117 Preds.clear();
14118
14119 auto *PredSymbolicMax =
14121 if (SymbolicBTC != PredSymbolicMax) {
14122 assert(!Preds.empty() &&
14123 "Different predicated symbolic max BTC, but no predicates");
14124 OS << "Loop ";
14125 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14126 OS << ": ";
14127 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14128 OS << "Predicated symbolic max backedge-taken count is ";
14129 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14130 } else
14131 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14132 OS << "\n";
14133 OS << " Predicates:\n";
14134 for (const auto *P : Preds)
14135 P->print(OS, 4);
14136 }
14137
14139 OS << "Loop ";
14140 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14141 OS << ": ";
14142 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14143 }
14144}
14145
14146namespace llvm {
14147// Note: these overloaded operators need to be in the llvm namespace for them
14148// to be resolved correctly. If we put them outside the llvm namespace, the
14149//
14150// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14151//
14152// code below "breaks" and start printing raw enum values as opposed to the
14153// string values.
14156 switch (LD) {
14158 OS << "Variant";
14159 break;
14161 OS << "Invariant";
14162 break;
14164 OS << "Computable";
14165 break;
14166 }
14167 return OS;
14168}
14169
14172 switch (BD) {
14174 OS << "DoesNotDominate";
14175 break;
14177 OS << "Dominates";
14178 break;
14180 OS << "ProperlyDominates";
14181 break;
14182 }
14183 return OS;
14184}
14185} // namespace llvm
14186
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects though,
  // which technically conflicts with the const qualifier. This isn't
  // observable from outside the class though, so casting away the
  // const isn't dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);

  if (ClassifyExpressions) {
    OS << "Classifying expressions for: ";
    F.printAsOperand(OS, /*PrintType=*/false);
    OS << "\n";
    // Only SCEVable, non-comparison instructions are interesting to print.
    for (Instruction &I : instructions(F))
      if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
        OS << I << '\n';
        OS << " --> ";
        const SCEV *SV = SE.getSCEV(&I);
        SV->print(OS);
        // Print unsigned/signed ranges only for computable expressions.
        if (!isa<SCEVCouldNotCompute>(SV)) {
          OS << " U: ";
          SE.getUnsignedRange(SV).print(OS);
          OS << " S: ";
          SE.getSignedRange(SV).print(OS);
        }

        const Loop *L = LI.getLoopFor(I.getParent());

        // Also print the expression evaluated at the scope of its enclosing
        // loop, when that simplifies to something different.
        const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
        if (AtUse != SV) {
          OS << " --> ";
          AtUse->print(OS);
          if (!isa<SCEVCouldNotCompute>(AtUse)) {
            OS << " U: ";
            SE.getUnsignedRange(AtUse).print(OS);
            OS << " S: ";
            SE.getSignedRange(AtUse).print(OS);
          }
        }

        if (L) {
          // Print the value the expression takes on loop exit, if known to
          // be invariant outside L.
          OS << "\t\t" "Exits: ";
          const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
          if (!SE.isLoopInvariant(ExitValue, L)) {
            OS << "<<Unknown>>";
          } else {
            OS << *ExitValue;
          }

          // Dispositions w.r.t. L and all of its ancestors...
          ListSeparator LS(", ", "\t\tLoopDispositions: { ");
          for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
            OS << LS;
            Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
            OS << ": " << SE.getLoopDisposition(SV, Iter);
          }

          // ... and w.r.t. the loops nested inside L (L itself was already
          // printed above, so skip it here).
          for (const auto *InnerL : depth_first(L)) {
            if (InnerL == L)
              continue;
            OS << LS;
            InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
            OS << ": " << SE.getLoopDisposition(SV, InnerL);
          }

          OS << " }";
        }

        OS << "\n";
      }
  }

  OS << "Determining loop execution counts for: ";
  F.printAsOperand(OS, /*PrintType=*/false);
  OS << "\n";
  for (Loop *I : LI)
    PrintLoopInfo(OS, &SE, I);
}
14264
  // Fast path: return the cached disposition if we have already computed it.
  auto &Values = LoopDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == L)
      return V.getInt();
  }
  // Seed the cache with a conservative LoopVariant answer before recursing;
  // computeLoopDisposition re-enters this function via isLoopInvariant, so
  // the placeholder makes self-referential queries terminate.
  Values.emplace_back(L, LoopVariant);
  LoopDisposition D = computeLoopDisposition(S, L);
  // Re-look-up the entry: the recursive computation may have grown (and thus
  // reallocated) the cached vector, invalidating `Values`. Scan in reverse
  // since our placeholder was appended last.
  auto &Values2 = LoopDispositions[S];
  for (auto &V : llvm::reverse(Values2)) {
    if (V.getPointer() == L) {
      V.setInt(D);
      break;
    }
  }
  return D;
}
14283
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
  switch (S->getSCEVType()) {
  case scConstant:
  case scVScale:
    // Constants and vscale do not depend on any loop.
    return LoopInvariant;
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);

    // If L is the addrec's loop, it's computable.
    if (AR->getLoop() == L)
      return LoopComputable;

    // Add recurrences are never invariant in the function-body (null loop).
    if (!L)
      return LoopVariant;

    // Everything that is not defined at loop entry is variant.
    if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
      return LoopVariant;
    assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
           " dominate the contained loop's header?");

    // This recurrence is invariant w.r.t. L if AR's loop contains L.
    if (AR->getLoop()->contains(L))
      return LoopInvariant;

    // This recurrence is variant w.r.t. L if any of its operands
    // are variant.
    for (SCEVUse Op : AR->operands())
      if (!isLoopInvariant(Op, L))
        return LoopVariant;

    // Otherwise it's loop-invariant.
    return LoopInvariant;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    // An n-ary expression is variant if any operand is variant; computable
    // if some operand is computable and none are variant; invariant only if
    // every operand is invariant.
    bool HasVarying = false;
    for (SCEVUse Op : S->operands()) {
      if (D == LoopVariant)
        return LoopVariant;
      if (D == LoopComputable)
        HasVarying = true;
    }
    return HasVarying ? LoopComputable : LoopInvariant;
  }
  case scUnknown:
    // All non-instruction values are loop invariant. All instructions are loop
    // invariant if they are not contained in the specified loop.
    // Instructions are never considered invariant in the function body
    // (null loop) because they are defined within the "loop".
    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
      return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
    return LoopInvariant;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
14356
  // Thin wrapper: S is invariant in L iff its cached/computed disposition
  // is exactly LoopInvariant.
  return getLoopDisposition(S, L) == LoopInvariant;
}
14360
  // Thin wrapper: S has a computable evolution in L iff its disposition
  // is LoopComputable.
  return getLoopDisposition(S, L) == LoopComputable;
}
14364
  // Fast path: return the cached disposition if we have already computed it.
  auto &Values = BlockDispositions[S];
  for (auto &V : Values) {
    if (V.getPointer() == BB)
      return V.getInt();
  }
  // Seed the cache with the conservative DoesNotDominateBlock answer before
  // recursing, so self-referential queries terminate.
  Values.emplace_back(BB, DoesNotDominateBlock);
  BlockDisposition D = computeBlockDisposition(S, BB);
  // Re-look-up the entry: the recursive computation may have reallocated the
  // cached vector, invalidating `Values`. Scan in reverse since our
  // placeholder was appended last.
  auto &Values2 = BlockDispositions[S];
  for (auto &V : llvm::reverse(Values2)) {
    if (V.getPointer() == BB) {
      V.setInt(D);
      break;
    }
  }
  return D;
}
14383
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
  switch (S->getSCEVType()) {
  case scConstant:
  case scVScale:
  case scAddRecExpr: {
    // This uses a "dominates" query instead of "properly dominates" query
    // to test for proper dominance too, because the instruction which
    // produces the addrec's value is a PHI, and a PHI effectively properly
    // dominates its entire containing block.
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
    if (!DT.dominates(AR->getLoop()->getHeader(), BB))
      return DoesNotDominateBlock;

    // Fall through into SCEVNAryExpr handling.
    [[fallthrough]];
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    // The expression properly dominates BB only if every operand does; it
    // merely dominates BB if some operand dominates without properly
    // dominating.
    bool Proper = true;
    for (const SCEV *NAryOp : S->operands()) {
      if (D == DoesNotDominateBlock)
        return DoesNotDominateBlock;
      if (D == DominatesBlock)
        Proper = false;
    }
    return Proper ? ProperlyDominatesBlock : DominatesBlock;
  }
  case scUnknown:
    if (Instruction *I =
            dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
      if (I->getParent() == BB)
        return DominatesBlock;
      if (DT.properlyDominates(I->getParent(), BB))
      return DoesNotDominateBlock;
    }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
14440
14441bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14442 return getBlockDisposition(S, BB) >= DominatesBlock;
14443}
14444
14447}
14448
14449bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14450 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14451}
14452
14453void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14454 bool Predicated) {
14455 auto &BECounts =
14456 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14457 auto It = BECounts.find(L);
14458 if (It != BECounts.end()) {
14459 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14460 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14461 if (!isa<SCEVConstant>(S)) {
14462 auto UserIt = BECountUsers.find(S);
14463 assert(UserIt != BECountUsers.end());
14464 UserIt->second.erase({L, Predicated});
14465 }
14466 }
14467 }
14468 BECounts.erase(It);
14469 }
14470}
14471
14472void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14473 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14474 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14475
14476 while (!Worklist.empty()) {
14477 const SCEV *Curr = Worklist.pop_back_val();
14478 auto Users = SCEVUsers.find(Curr);
14479 if (Users != SCEVUsers.end())
14480 for (const auto *User : Users->second)
14481 if (ToForget.insert(User).second)
14482 Worklist.push_back(User);
14483 }
14484
14485 for (const auto *S : ToForget)
14486 forgetMemoizedResultsImpl(S);
14487
14488 for (auto I = PredicatedSCEVRewrites.begin();
14489 I != PredicatedSCEVRewrites.end();) {
14490 std::pair<const SCEV *, const Loop *> Entry = I->first;
14491 if (ToForget.count(Entry.first))
14492 PredicatedSCEVRewrites.erase(I++);
14493 else
14494 ++I;
14495 }
14496}
14497
14498void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14499 LoopDispositions.erase(S);
14500 BlockDispositions.erase(S);
14501 UnsignedRanges.erase(S);
14502 SignedRanges.erase(S);
14503 HasRecMap.erase(S);
14504 ConstantMultipleCache.erase(S);
14505
14506 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14507 UnsignedWrapViaInductionTried.erase(AR);
14508 SignedWrapViaInductionTried.erase(AR);
14509 }
14510
14511 auto ExprIt = ExprValueMap.find(S);
14512 if (ExprIt != ExprValueMap.end()) {
14513 for (Value *V : ExprIt->second) {
14514 auto ValueIt = ValueExprMap.find_as(V);
14515 if (ValueIt != ValueExprMap.end())
14516 ValueExprMap.erase(ValueIt);
14517 }
14518 ExprValueMap.erase(ExprIt);
14519 }
14520
14521 auto ScopeIt = ValuesAtScopes.find(S);
14522 if (ScopeIt != ValuesAtScopes.end()) {
14523 for (const auto &Pair : ScopeIt->second)
14524 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14525 llvm::erase(ValuesAtScopesUsers[Pair.second],
14526 std::make_pair(Pair.first, S));
14527 ValuesAtScopes.erase(ScopeIt);
14528 }
14529
14530 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14531 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14532 for (const auto &Pair : ScopeUserIt->second)
14533 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14534 ValuesAtScopesUsers.erase(ScopeUserIt);
14535 }
14536
14537 auto BEUsersIt = BECountUsers.find(S);
14538 if (BEUsersIt != BECountUsers.end()) {
14539 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14540 auto Copy = BEUsersIt->second;
14541 for (const auto &Pair : Copy)
14542 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14543 BECountUsers.erase(BEUsersIt);
14544 }
14545
14546 auto FoldUser = FoldCacheUser.find(S);
14547 if (FoldUser != FoldCacheUser.end())
14548 for (auto &KV : FoldUser->second)
14549 FoldCache.erase(KV);
14550 FoldCacheUser.erase(S);
14551}
14552
14553void
14554ScalarEvolution::getUsedLoops(const SCEV *S,
14555 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14556 struct FindUsedLoops {
14557 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14558 : LoopsUsed(LoopsUsed) {}
14559 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14560 bool follow(const SCEV *S) {
14561 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14562 LoopsUsed.insert(AR->getLoop());
14563 return true;
14564 }
14565
14566 bool isDone() const { return false; }
14567 };
14568
14569 FindUsedLoops F(LoopsUsed);
14570 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14571}
14572
void ScalarEvolution::getReachableBlocks(
  Worklist.push_back(&F.getEntryBlock());
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    // Skip blocks we have already visited.
    if (!Reachable.insert(BB).second)
      continue;

    Value *Cond;
    BasicBlock *TrueBB, *FalseBB;
    if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
                                        m_BasicBlock(FalseBB)))) {
      // A conditional branch on a constant only reaches the taken side.
      if (auto *C = dyn_cast<ConstantInt>(Cond)) {
        Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
        continue;
      }

      // For a branch on an icmp, use constant-range reasoning to prove the
      // condition always true (or always false) and follow only that side.
      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
        const SCEV *L = getSCEV(Cmp->getOperand(0));
        const SCEV *R = getSCEV(Cmp->getOperand(1));
        if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
          Worklist.push_back(TrueBB);
          continue;
        }
        if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
                                              R)) {
          Worklist.push_back(FalseBB);
          continue;
        }
      }
    }

    // Otherwise conservatively treat every successor as reachable.
    append_range(Worklist, successors(BB));
  }
}
14609
14611 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14612 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14613
14614 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14615
14616 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14617 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14618 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14619
14620 const SCEV *visitConstant(const SCEVConstant *Constant) {
14621 return SE.getConstant(Constant->getAPInt());
14622 }
14623
14624 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14625 return SE.getUnknown(Expr->getValue());
14626 }
14627
14628 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14629 return SE.getCouldNotCompute();
14630 }
14631 };
14632
14633 SCEVMapper SCM(SE2);
14634 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14635 SE2.getReachableBlocks(ReachableBlocks, F);
14636
14637 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14638 if (containsUndefs(Old) || containsUndefs(New)) {
14639 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14640 // not propagate undef aggressively). This means we can (and do) fail
14641 // verification in cases where a transform makes a value go from "undef"
14642 // to "undef+1" (say). The transform is fine, since in both cases the
14643 // result is "undef", but SCEV thinks the value increased by 1.
14644 return nullptr;
14645 }
14646
14647 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14648 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14649 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14650 return nullptr;
14651
14652 return Delta;
14653 };
14654
14655 while (!LoopStack.empty()) {
14656 auto *L = LoopStack.pop_back_val();
14657 llvm::append_range(LoopStack, *L);
14658
14659 // Only verify BECounts in reachable loops. For an unreachable loop,
14660 // any BECount is legal.
14661 if (!ReachableBlocks.contains(L->getHeader()))
14662 continue;
14663
14664 // Only verify cached BECounts. Computing new BECounts may change the
14665 // results of subsequent SCEV uses.
14666 auto It = BackedgeTakenCounts.find(L);
14667 if (It == BackedgeTakenCounts.end())
14668 continue;
14669
14670 auto *CurBECount =
14671 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14672 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14673
14674 if (CurBECount == SE2.getCouldNotCompute() ||
14675 NewBECount == SE2.getCouldNotCompute()) {
14676 // NB! This situation is legal, but is very suspicious -- whatever pass
14677 // change the loop to make a trip count go from could not compute to
14678 // computable or vice-versa *should have* invalidated SCEV. However, we
14679 // choose not to assert here (for now) since we don't want false
14680 // positives.
14681 continue;
14682 }
14683
14684 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14685 SE.getTypeSizeInBits(NewBECount->getType()))
14686 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14687 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14688 SE.getTypeSizeInBits(NewBECount->getType()))
14689 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14690
14691 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14692 if (Delta && !Delta->isZero()) {
14693 dbgs() << "Trip Count for " << *L << " Changed!\n";
14694 dbgs() << "Old: " << *CurBECount << "\n";
14695 dbgs() << "New: " << *NewBECount << "\n";
14696 dbgs() << "Delta: " << *Delta << "\n";
14697 std::abort();
14698 }
14699 }
14700
14701 // Collect all valid loops currently in LoopInfo.
14702 SmallPtrSet<Loop *, 32> ValidLoops;
14703 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14704 while (!Worklist.empty()) {
14705 Loop *L = Worklist.pop_back_val();
14706 if (ValidLoops.insert(L).second)
14707 Worklist.append(L->begin(), L->end());
14708 }
14709 for (const auto &KV : ValueExprMap) {
14710#ifndef NDEBUG
14711 // Check for SCEV expressions referencing invalid/deleted loops.
14712 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14713 assert(ValidLoops.contains(AR->getLoop()) &&
14714 "AddRec references invalid loop");
14715 }
14716#endif
14717
14718 // Check that the value is also part of the reverse map.
14719 auto It = ExprValueMap.find(KV.second);
14720 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14721 dbgs() << "Value " << *KV.first
14722 << " is in ValueExprMap but not in ExprValueMap\n";
14723 std::abort();
14724 }
14725
14726 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14727 if (!ReachableBlocks.contains(I->getParent()))
14728 continue;
14729 const SCEV *OldSCEV = SCM.visit(KV.second);
14730 const SCEV *NewSCEV = SE2.getSCEV(I);
14731 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14732 if (Delta && !Delta->isZero()) {
14733 dbgs() << "SCEV for value " << *I << " changed!\n"
14734 << "Old: " << *OldSCEV << "\n"
14735 << "New: " << *NewSCEV << "\n"
14736 << "Delta: " << *Delta << "\n";
14737 std::abort();
14738 }
14739 }
14740 }
14741
14742 for (const auto &KV : ExprValueMap) {
14743 for (Value *V : KV.second) {
14744 const SCEV *S = ValueExprMap.lookup(V);
14745 if (!S) {
14746 dbgs() << "Value " << *V
14747 << " is in ExprValueMap but not in ValueExprMap\n";
14748 std::abort();
14749 }
14750 if (S != KV.first) {
14751 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14752 << *KV.first << "\n";
14753 std::abort();
14754 }
14755 }
14756 }
14757
14758 // Verify integrity of SCEV users.
14759 for (const auto &S : UniqueSCEVs) {
14760 for (SCEVUse Op : S.operands()) {
14761 // We do not store dependencies of constants.
14762 if (isa<SCEVConstant>(Op))
14763 continue;
14764 auto It = SCEVUsers.find(Op);
14765 if (It != SCEVUsers.end() && It->second.count(&S))
14766 continue;
14767 dbgs() << "Use of operand " << *Op << " by user " << S
14768 << " is not being tracked!\n";
14769 std::abort();
14770 }
14771 }
14772
14773 // Verify integrity of ValuesAtScopes users.
14774 for (const auto &ValueAndVec : ValuesAtScopes) {
14775 const SCEV *Value = ValueAndVec.first;
14776 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14777 const Loop *L = LoopAndValueAtScope.first;
14778 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14779 if (!isa<SCEVConstant>(ValueAtScope)) {
14780 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14781 if (It != ValuesAtScopesUsers.end() &&
14782 is_contained(It->second, std::make_pair(L, Value)))
14783 continue;
14784 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14785 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14786 std::abort();
14787 }
14788 }
14789 }
14790
14791 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14792 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14793 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14794 const Loop *L = LoopAndValue.first;
14795 const SCEV *Value = LoopAndValue.second;
14797 auto It = ValuesAtScopes.find(Value);
14798 if (It != ValuesAtScopes.end() &&
14799 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14800 continue;
14801 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14802 << *ValueAtScope << " missing in ValuesAtScopes\n";
14803 std::abort();
14804 }
14805 }
14806
14807 // Verify integrity of BECountUsers.
14808 auto VerifyBECountUsers = [&](bool Predicated) {
14809 auto &BECounts =
14810 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14811 for (const auto &LoopAndBEInfo : BECounts) {
14812 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14813 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14814 if (!isa<SCEVConstant>(S)) {
14815 auto UserIt = BECountUsers.find(S);
14816 if (UserIt != BECountUsers.end() &&
14817 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14818 continue;
14819 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14820 << " missing from BECountUsers\n";
14821 std::abort();
14822 }
14823 }
14824 }
14825 }
14826 };
14827 VerifyBECountUsers(/* Predicated */ false);
14828 VerifyBECountUsers(/* Predicated */ true);
14829
14830 // Verify intergity of loop disposition cache.
14831 for (auto &[S, Values] : LoopDispositions) {
14832 for (auto [Loop, CachedDisposition] : Values) {
14833 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14834 if (CachedDisposition != RecomputedDisposition) {
14835 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14836 << " is incorrect: cached " << CachedDisposition << ", actual "
14837 << RecomputedDisposition << "\n";
14838 std::abort();
14839 }
14840 }
14841 }
14842
14843 // Verify integrity of the block disposition cache.
14844 for (auto &[S, Values] : BlockDispositions) {
14845 for (auto [BB, CachedDisposition] : Values) {
14846 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14847 if (CachedDisposition != RecomputedDisposition) {
14848 dbgs() << "Cached disposition of " << *S << " for block %"
14849 << BB->getName() << " is incorrect: cached " << CachedDisposition
14850 << ", actual " << RecomputedDisposition << "\n";
14851 std::abort();
14852 }
14853 }
14854 }
14855
14856 // Verify FoldCache/FoldCacheUser caches.
14857 for (auto [FoldID, Expr] : FoldCache) {
14858 auto I = FoldCacheUser.find(Expr);
14859 if (I == FoldCacheUser.end()) {
14860 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14861 << "!\n";
14862 std::abort();
14863 }
14864 if (!is_contained(I->second, FoldID)) {
14865 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14866 std::abort();
14867 }
14868 }
14869 for (auto [Expr, IDs] : FoldCacheUser) {
14870 for (auto &FoldID : IDs) {
14871 const SCEV *S = FoldCache.lookup(FoldID);
14872 if (!S) {
14873 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14874 << "!\n";
14875 std::abort();
14876 }
14877 if (S != Expr) {
14878 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14879 << " != " << *Expr << "!\n";
14880 std::abort();
14881 }
14882 }
14883 }
14884
14885 // Verify that ConstantMultipleCache computations are correct. We check that
14886 // cached multiples and recomputed multiples are multiples of each other to
14887 // verify correctness. It is possible that a recomputed multiple is different
14888 // from the cached multiple due to strengthened no wrap flags or changes in
14889 // KnownBits computations.
14890 for (auto [S, Multiple] : ConstantMultipleCache) {
14891 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14892 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14893 Multiple.urem(RecomputedMultiple) != 0 &&
14894 RecomputedMultiple.urem(Multiple) != 0)) {
14895 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14896 << *S << " : Computed " << RecomputedMultiple
14897 << " but cache contains " << Multiple << "!\n";
14898 std::abort();
14899 }
14900 }
14901}
14902
    Function &F, const PreservedAnalyses &PA,
    FunctionAnalysisManager::Invalidator &Inv) {
  // Invalidate the ScalarEvolution object whenever it isn't preserved or one
  // of its dependencies is invalidated.
  auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
  // The dependencies checked here (assumptions, dominator tree, loop info)
  // mirror the analyses fetched in ScalarEvolutionAnalysis::run below.
  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
         Inv.invalidate<AssumptionAnalysis>(F, PA) ||
         Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
         Inv.invalidate<LoopAnalysis>(F, PA);
}
14914
// Unique identity token for ScalarEvolutionAnalysis in the new pass manager.
AnalysisKey ScalarEvolutionAnalysis::Key;
14916
  // Gather the analyses ScalarEvolution depends on and construct the result
  // object by value (it is movable; see the move constructor above).
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &LI = AM.getResult<LoopAnalysis>(F);
  return ScalarEvolution(F, TLI, AC, DT, LI);
}
14925
14931
14934 // For compatibility with opt's -analyze feature under legacy pass manager
14935 // which was not ported to NPM. This keeps tests using
14936 // update_analyze_test_checks.py working.
14937 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14938 << F.getName() << "':\n";
14940 return PreservedAnalyses::all();
14941}
14942
14944 "Scalar Evolution Analysis", false, true)
14950 "Scalar Evolution Analysis", false, true)
14951
14953
14955
14957 SE.reset(new ScalarEvolution(
14959 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14961 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14962 return false;
14963}
14964
14966
14968 SE->print(OS);
14969}
14970
14972 if (!VerifySCEV)
14973 return;
14974
14975 SE->verify();
14976}
14977
14985
14987 const SCEV *RHS) {
14988 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14989}
14990
14991const SCEVPredicate *
14993 const SCEV *LHS, const SCEV *RHS) {
14995 assert(LHS->getType() == RHS->getType() &&
14996 "Type mismatch between LHS and RHS");
14997 // Unique this node based on the arguments
14998 ID.AddInteger(SCEVPredicate::P_Compare);
14999 ID.AddInteger(Pred);
15000 ID.AddPointer(LHS);
15001 ID.AddPointer(RHS);
15002 void *IP = nullptr;
15003 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15004 return S;
15005 SCEVComparePredicate *Eq = new (SCEVAllocator)
15006 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15007 UniquePreds.InsertNode(Eq, IP);
15008 return Eq;
15009}
15010
15012 const SCEVAddRecExpr *AR,
15015 // Unique this node based on the arguments
15016 ID.AddInteger(SCEVPredicate::P_Wrap);
15017 ID.AddPointer(AR);
15018 ID.AddInteger(AddedFlags);
15019 void *IP = nullptr;
15020 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15021 return S;
15022 auto *OF = new (SCEVAllocator)
15023 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15024 UniquePreds.InsertNode(OF, IP);
15025 return OF;
15026}
15027
15028namespace {
15029
15030class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15031public:
15032
15033 /// Rewrites \p S in the context of a loop L and the SCEV predication
15034 /// infrastructure.
15035 ///
15036 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15037 /// equivalences present in \p Pred.
15038 ///
15039 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15040 /// \p NewPreds such that the result will be an AddRecExpr.
15041 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15043 const SCEVPredicate *Pred) {
15044 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15045 return Rewriter.visit(S);
15046 }
15047
15048 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15049 if (Pred) {
15050 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15051 for (const auto *Pred : U->getPredicates())
15052 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15053 if (IPred->getLHS() == Expr &&
15054 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15055 return IPred->getRHS();
15056 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15057 if (IPred->getLHS() == Expr &&
15058 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15059 return IPred->getRHS();
15060 }
15061 }
15062 return convertToAddRecWithPreds(Expr);
15063 }
15064
15065 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15066 const SCEV *Operand = visit(Expr->getOperand());
15067 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15068 if (AR && AR->getLoop() == L && AR->isAffine()) {
15069 // This couldn't be folded because the operand didn't have the nuw
15070 // flag. Add the nusw flag as an assumption that we could make.
15071 const SCEV *Step = AR->getStepRecurrence(SE);
15072 Type *Ty = Expr->getType();
15073 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15074 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15075 SE.getSignExtendExpr(Step, Ty), L,
15076 AR->getNoWrapFlags());
15077 }
15078 return SE.getZeroExtendExpr(Operand, Expr->getType());
15079 }
15080
15081 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15082 const SCEV *Operand = visit(Expr->getOperand());
15083 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15084 if (AR && AR->getLoop() == L && AR->isAffine()) {
15085 // This couldn't be folded because the operand didn't have the nsw
15086 // flag. Add the nssw flag as an assumption that we could make.
15087 const SCEV *Step = AR->getStepRecurrence(SE);
15088 Type *Ty = Expr->getType();
15089 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15090 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15091 SE.getSignExtendExpr(Step, Ty), L,
15092 AR->getNoWrapFlags());
15093 }
15094 return SE.getSignExtendExpr(Operand, Expr->getType());
15095 }
15096
15097private:
15098 explicit SCEVPredicateRewriter(
15099 const Loop *L, ScalarEvolution &SE,
15100 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15101 const SCEVPredicate *Pred)
15102 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15103
15104 bool addOverflowAssumption(const SCEVPredicate *P) {
15105 if (!NewPreds) {
15106 // Check if we've already made this assumption.
15107 return Pred && Pred->implies(P, SE);
15108 }
15109 NewPreds->push_back(P);
15110 return true;
15111 }
15112
15113 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15115 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15116 return addOverflowAssumption(A);
15117 }
15118
15119 // If \p Expr represents a PHINode, we try to see if it can be represented
15120 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15121 // to add this predicate as a runtime overflow check, we return the AddRec.
15122 // If \p Expr does not meet these conditions (is not a PHI node, or we
15123 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15124 // return \p Expr.
15125 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15126 if (!isa<PHINode>(Expr->getValue()))
15127 return Expr;
15128 std::optional<
15129 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15130 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15131 if (!PredicatedRewrite)
15132 return Expr;
15133 for (const auto *P : PredicatedRewrite->second){
15134 // Wrap predicates from outer loops are not supported.
15135 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15136 if (L != WP->getExpr()->getLoop())
15137 return Expr;
15138 }
15139 if (!addOverflowAssumption(P))
15140 return Expr;
15141 }
15142 return PredicatedRewrite->first;
15143 }
15144
15145 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15146 const SCEVPredicate *Pred;
15147 const Loop *L;
15148};
15149
15150} // end anonymous namespace
15151
15152const SCEV *
15154 const SCEVPredicate &Preds) {
15155 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15156}
15157
15159 const SCEV *S, const Loop *L,
15162 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15163 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15164
15165 if (!AddRec)
15166 return nullptr;
15167
15168 // Check if any of the transformed predicates is known to be false. In that
15169 // case, it doesn't make sense to convert to a predicated AddRec, as the
15170 // versioned loop will never execute.
15171 for (const SCEVPredicate *Pred : TransformPreds) {
15172 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15173 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15174 continue;
15175
15176 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15177 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15178 if (isa<SCEVCouldNotCompute>(ExitCount))
15179 continue;
15180
15181 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15182 if (!Step->isOne())
15183 continue;
15184
15185 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15186 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15187 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15188 return nullptr;
15189 }
15190
15191 // Since the transformation was successful, we can now transfer the SCEV
15192 // predicates.
15193 Preds.append(TransformPreds.begin(), TransformPreds.end());
15194
15195 return AddRec;
15196}
15197
15198/// SCEV predicates
15202
15204 const ICmpInst::Predicate Pred,
15205 const SCEV *LHS, const SCEV *RHS)
15206 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15207 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15208 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15209}
15210
15212 ScalarEvolution &SE) const {
15213 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15214
15215 if (!Op)
15216 return false;
15217
15218 if (Pred != ICmpInst::ICMP_EQ)
15219 return false;
15220
15221 return Op->LHS == LHS && Op->RHS == RHS;
15222}
15223
15224bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15225
15227 if (Pred == ICmpInst::ICMP_EQ)
15228 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15229 else
15230 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15231 << *RHS << "\n";
15232
15233}
15234
15236 const SCEVAddRecExpr *AR,
15237 IncrementWrapFlags Flags)
15238 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15239
15240const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15241
15243 ScalarEvolution &SE) const {
15244 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15245 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15246 return false;
15247
15248 if (Op->AR == AR)
15249 return true;
15250
15251 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15253 return false;
15254
15255 const SCEV *Start = AR->getStart();
15256 const SCEV *OpStart = Op->AR->getStart();
15257 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15258 return false;
15259
15260 // Reject pointers to different address spaces.
15261 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15262 return false;
15263
15264 const SCEV *Step = AR->getStepRecurrence(SE);
15265 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15266 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15267 return false;
15268
15269 // If both steps are positive, this implies N, if N's start and step are
15270 // ULE/SLE (for NSUW/NSSW) than this'.
15271 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15272 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15273 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15274
15275 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15276 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15277 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15278 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15279 : SE.getNoopOrSignExtend(Start, WiderTy);
15281 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15282 SE.isKnownPredicate(Pred, OpStart, Start);
15283}
15284
15286 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15287 IncrementWrapFlags IFlags = Flags;
15288
15289 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15290 IFlags = clearFlags(IFlags, IncrementNSSW);
15291
15292 return IFlags == IncrementAnyWrap;
15293}
15294
15295void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15296 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15298 OS << "<nusw>";
15300 OS << "<nssw>";
15301 OS << "\n";
15302}
15303
15306 ScalarEvolution &SE) {
15307 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15308 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15309
15310 // We can safely transfer the NSW flag as NSSW.
15311 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15312 ImpliedFlags = IncrementNSSW;
15313
15314 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15315 // If the increment is positive, the SCEV NUW flag will also imply the
15316 // WrapPredicate NUSW flag.
15317 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15318 if (Step->getValue()->getValue().isNonNegative())
15319 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15320 }
15321
15322 return ImpliedFlags;
15323}
15324
15325/// Union predicates don't get cached so create a dummy set ID for it.
15327 ScalarEvolution &SE)
15328 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15329 for (const auto *P : Preds)
15330 add(P, SE);
15331}
15332
15334 return all_of(Preds,
15335 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15336}
15337
15339 ScalarEvolution &SE) const {
15340 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15341 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15342 return this->implies(I, SE);
15343 });
15344
15345 return any_of(Preds,
15346 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15347}
15348
15350 for (const auto *Pred : Preds)
15351 Pred->print(OS, Depth);
15352}
15353
15354void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15355 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15356 for (const auto *Pred : Set->Preds)
15357 add(Pred, SE);
15358 return;
15359 }
15360
15361 // Implication checks are quadratic in the number of predicates. Stop doing
15362 // them if there are many predicates, as they should be too expensive to use
15363 // anyway at that point.
15364 bool CheckImplies = Preds.size() < 16;
15365
15366 // Only add predicate if it is not already implied by this union predicate.
15367 if (CheckImplies && implies(N, SE))
15368 return;
15369
15370 // Build a new vector containing the current predicates, except the ones that
15371 // are implied by the new predicate N.
15373 for (auto *P : Preds) {
15374 if (CheckImplies && N->implies(P, SE))
15375 continue;
15376 PrunedPreds.push_back(P);
15377 }
15378 Preds = std::move(PrunedPreds);
15379 Preds.push_back(N);
15380}
15381
15383 Loop &L)
15384 : SE(SE), L(L) {
15386 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15387}
15388
15391 for (const auto *Op : Ops)
15392 // We do not expect that forgetting cached data for SCEVConstants will ever
15393 // open any prospects for sharpening or introduce any correctness issues,
15394 // so we don't bother storing their dependencies.
15395 if (!isa<SCEVConstant>(Op))
15396 SCEVUsers[Op].insert(User);
15397}
15398
15400 for (const SCEV *Op : Ops)
15401 // We do not expect that forgetting cached data for SCEVConstants will ever
15402 // open any prospects for sharpening or introduce any correctness issues,
15403 // so we don't bother storing their dependencies.
15404 if (!isa<SCEVConstant>(Op))
15405 SCEVUsers[Op].insert(User);
15406}
15407
15409 const SCEV *Expr = SE.getSCEV(V);
15410 return getPredicatedSCEV(Expr);
15411}
15412
15414 RewriteEntry &Entry = RewriteMap[Expr];
15415
15416 // If we already have an entry and the version matches, return it.
15417 if (Entry.second && Generation == Entry.first)
15418 return Entry.second;
15419
15420 // We found an entry but it's stale. Rewrite the stale entry
15421 // according to the current predicate.
15422 if (Entry.second)
15423 Expr = Entry.second;
15424
15425 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15426 Entry = {Generation, NewSCEV};
15427
15428 return NewSCEV;
15429}
15430
15432 if (!BackedgeCount) {
15434 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15435 for (const auto *P : Preds)
15436 addPredicate(*P);
15437 }
15438 return BackedgeCount;
15439}
15440
15442 if (!SymbolicMaxBackedgeCount) {
15444 SymbolicMaxBackedgeCount =
15445 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15446 for (const auto *P : Preds)
15447 addPredicate(*P);
15448 }
15449 return SymbolicMaxBackedgeCount;
15450}
15451
15453 if (!SmallConstantMaxTripCount) {
15455 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15456 for (const auto *P : Preds)
15457 addPredicate(*P);
15458 }
15459 return *SmallConstantMaxTripCount;
15460}
15461
15463 if (Preds->implies(&Pred, SE))
15464 return;
15465
15466 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15467 NewPreds.push_back(&Pred);
15468 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15469 updateGeneration();
15470}
15471
15473 return *Preds;
15474}
15475
15476void PredicatedScalarEvolution::updateGeneration() {
15477 // If the generation number wrapped recompute everything.
15478 if (++Generation == 0) {
15479 for (auto &II : RewriteMap) {
15480 const SCEV *Rewritten = II.second.second;
15481 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15482 }
15483 }
15484}
15485
15488 const SCEV *Expr = getSCEV(V);
15489 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15490
15491 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15492
15493 // Clear the statically implied flags.
15494 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15495 addPredicate(*SE.getWrapPredicate(AR, Flags));
15496
15497 auto II = FlagsMap.insert({V, Flags});
15498 if (!II.second)
15499 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15500}
15501
15504 const SCEV *Expr = getSCEV(V);
15505 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15506
15508 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15509
15510 auto II = FlagsMap.find(V);
15511
15512 if (II != FlagsMap.end())
15513 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15514
15516}
15517
15519 const SCEV *Expr = this->getSCEV(V);
15521 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15522
15523 if (!New)
15524 return nullptr;
15525
15526 for (const auto *P : NewPreds)
15527 addPredicate(*P);
15528
15529 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15530 return New;
15531}
15532
15535 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15536 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15537 SE)),
15538 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15539 for (auto I : Init.FlagsMap)
15540 FlagsMap.insert(I);
15541}
15542
15544 // For each block.
15545 for (auto *BB : L.getBlocks())
15546 for (auto &I : *BB) {
15547 if (!SE.isSCEVable(I.getType()))
15548 continue;
15549
15550 auto *Expr = SE.getSCEV(&I);
15551 auto II = RewriteMap.find(Expr);
15552
15553 if (II == RewriteMap.end())
15554 continue;
15555
15556 // Don't print things that are not interesting.
15557 if (II->second.second == Expr)
15558 continue;
15559
15560 OS.indent(Depth) << "[PSE]" << I << ":\n";
15561 OS.indent(Depth + 2) << *Expr << "\n";
15562 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15563 }
15564}
15565
15568 BasicBlock *Header = L->getHeader();
15569 BasicBlock *Pred = L->getLoopPredecessor();
15570 LoopGuards Guards(SE);
15571 if (!Pred)
15572 return Guards;
15574 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15575 return Guards;
15576}
15577
15578void ScalarEvolution::LoopGuards::collectFromPHI(
15582 unsigned Depth) {
15583 if (!SE.isSCEVable(Phi.getType()))
15584 return;
15585
15586 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15587 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15588 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15589 if (!VisitedBlocks.insert(InBlock).second)
15590 return {nullptr, scCouldNotCompute};
15591
15592 // Avoid analyzing unreachable blocks so that we don't get trapped
15593 // traversing cycles with ill-formed dominance or infinite cycles
15594 if (!SE.DT.isReachableFromEntry(InBlock))
15595 return {nullptr, scCouldNotCompute};
15596
15597 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15598 if (Inserted)
15599 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15600 Depth + 1);
15601 auto &RewriteMap = G->second.RewriteMap;
15602 if (RewriteMap.empty())
15603 return {nullptr, scCouldNotCompute};
15604 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15605 if (S == RewriteMap.end())
15606 return {nullptr, scCouldNotCompute};
15607 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15608 if (!SM)
15609 return {nullptr, scCouldNotCompute};
15610 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15611 return {C0, SM->getSCEVType()};
15612 return {nullptr, scCouldNotCompute};
15613 };
15614 auto MergeMinMaxConst = [](MinMaxPattern P1,
15615 MinMaxPattern P2) -> MinMaxPattern {
15616 auto [C1, T1] = P1;
15617 auto [C2, T2] = P2;
15618 if (!C1 || !C2 || T1 != T2)
15619 return {nullptr, scCouldNotCompute};
15620 switch (T1) {
15621 case scUMaxExpr:
15622 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15623 case scSMaxExpr:
15624 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15625 case scUMinExpr:
15626 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15627 case scSMinExpr:
15628 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15629 default:
15630 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15631 }
15632 };
15633 auto P = GetMinMaxConst(0);
15634 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15635 if (!P.first)
15636 break;
15637 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15638 }
15639 if (P.first) {
15640 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15641 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15642 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15643 Guards.RewriteMap.insert({LHS, RHS});
15644 }
15645}
15646
15647// Return a new SCEV that modifies \p Expr to the closest number divides by
15648// \p Divisor and less or equal than Expr. For now, only handle constant
15649// Expr.
15651 const APInt &DivisorVal,
15652 ScalarEvolution &SE) {
15653 const APInt *ExprVal;
15654 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15655 DivisorVal.isNonPositive())
15656 return Expr;
15657 APInt Rem = ExprVal->urem(DivisorVal);
15658 // return the SCEV: Expr - Expr % Divisor
15659 return SE.getConstant(*ExprVal - Rem);
15660}
15661
15662// Return a new SCEV that modifies \p Expr to the closest number divides by
15663// \p Divisor and greater or equal than Expr. For now, only handle constant
15664// Expr.
15665static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15666 const APInt &DivisorVal,
15667 ScalarEvolution &SE) {
15668 const APInt *ExprVal;
15669 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15670 DivisorVal.isNonPositive())
15671 return Expr;
15672 APInt Rem = ExprVal->urem(DivisorVal);
15673 if (Rem.isZero())
15674 return Expr;
15675 // return the SCEV: Expr + Divisor - Expr % Divisor
15676 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15677}
15678
15680 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15683 // If we have LHS == 0, check if LHS is computing a property of some unknown
15684 // SCEV %v which we can rewrite %v to express explicitly.
15686 return false;
15687 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15688 // explicitly express that.
15689 const SCEVUnknown *URemLHS = nullptr;
15690 const SCEV *URemRHS = nullptr;
15691 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15692 return false;
15693
15694 const SCEV *Multiple =
15695 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15696 DivInfo[URemLHS] = Multiple;
15697 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15698 Multiples[URemLHS] = C->getAPInt();
15699 return true;
15700}
15701
15702// Check if the condition is a divisibility guard (A % B == 0).
15703static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15704 ScalarEvolution &SE) {
15705 const SCEV *X, *Y;
15706 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15707}
15708
15709// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15710// recursively. This is done by aligning up/down the constant value to the
15711// Divisor.
15712static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15713 APInt Divisor,
15714 ScalarEvolution &SE) {
15715 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15716 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15717 // the non-constant operand and in \p LHS the constant operand.
15718 auto IsMinMaxSCEVWithNonNegativeConstant =
15719 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15720 const SCEV *&RHS) {
15721 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15722 if (MinMax->getNumOperands() != 2)
15723 return false;
15724 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15725 if (C->getAPInt().isNegative())
15726 return false;
15727 SCTy = MinMax->getSCEVType();
15728 LHS = MinMax->getOperand(0);
15729 RHS = MinMax->getOperand(1);
15730 return true;
15731 }
15732 }
15733 return false;
15734 };
15735
15736 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15737 SCEVTypes SCTy;
15738 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15739 MinMaxRHS))
15740 return MinMaxExpr;
15741 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15742 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15743 auto *DivisibleExpr =
15744 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15745 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15747 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15748 return SE.getMinMaxExpr(SCTy, Ops);
15749}
15750
15751void ScalarEvolution::LoopGuards::collectFromBlock(
15752 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15753 const BasicBlock *Block, const BasicBlock *Pred,
15754 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15755
15757
15758 SmallVector<SCEVUse> ExprsToRewrite;
15759 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15760 const SCEV *RHS,
15761 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15762 const LoopGuards &DivGuards) {
15763 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15764 // replacement SCEV which isn't directly implied by the structure of that
15765 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15766 // legal. See the scoping rules for flags in the header to understand why.
15767
15768 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15769 // create this form when combining two checks of the form (X u< C2 + C1) and
15770 // (X >=u C1).
15771 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15772 &ExprsToRewrite]() {
15773 const SCEVConstant *C1;
15774 const SCEVUnknown *LHSUnknown;
15775 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15776 if (!match(LHS,
15777 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15778 !C2)
15779 return false;
15780
15781 auto ExactRegion =
15782 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15783 .sub(C1->getAPInt());
15784
15785 // Bail out, unless we have a non-wrapping, monotonic range.
15786 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15787 return false;
15788 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15789 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15790 I->second = SE.getUMaxExpr(
15791 SE.getConstant(ExactRegion.getUnsignedMin()),
15792 SE.getUMinExpr(RewrittenLHS,
15793 SE.getConstant(ExactRegion.getUnsignedMax())));
15794 ExprsToRewrite.push_back(LHSUnknown);
15795 return true;
15796 };
15797 if (MatchRangeCheckIdiom())
15798 return;
15799
15800 // Do not apply information for constants or if RHS contains an AddRec.
15802 return;
15803
15804 // If RHS is SCEVUnknown, make sure the information is applied to it.
15806 std::swap(LHS, RHS);
15808 }
15809
15810 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15811 // and \p FromRewritten are the same (i.e. there has been no rewrite
15812 // registered for \p From), then puts this value in the list of rewritten
15813 // expressions.
15814 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15815 const SCEV *To) {
15816 if (From == FromRewritten)
15817 ExprsToRewrite.push_back(From);
15818 RewriteMap[From] = To;
15819 };
15820
15821 // Checks whether \p S has already been rewritten. In that case returns the
15822 // existing rewrite because we want to chain further rewrites onto the
15823 // already rewritten value. Otherwise returns \p S.
15824 auto GetMaybeRewritten = [&](const SCEV *S) {
15825 return RewriteMap.lookup_or(S, S);
15826 };
15827
15828 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15829 // Apply divisibility information when computing the constant multiple.
15830 const APInt &DividesBy =
15831 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15832
15833 // Collect rewrites for LHS and its transitive operands based on the
15834 // condition.
15835 // For min/max expressions, also apply the guard to its operands:
15836 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15837 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15838 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15839 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15840
15841 // We cannot express strict predicates in SCEV, so instead we replace them
15842 // with non-strict ones against plus or minus one of RHS depending on the
15843 // predicate.
15844 const SCEV *One = SE.getOne(RHS->getType());
15845 switch (Predicate) {
15846 case CmpInst::ICMP_ULT:
15847 if (RHS->getType()->isPointerTy())
15848 return;
15849 RHS = SE.getUMaxExpr(RHS, One);
15850 [[fallthrough]];
15851 case CmpInst::ICMP_SLT: {
15852 RHS = SE.getMinusSCEV(RHS, One);
15853 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15854 break;
15855 }
15856 case CmpInst::ICMP_UGT:
15857 case CmpInst::ICMP_SGT:
15858 RHS = SE.getAddExpr(RHS, One);
15859 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15860 break;
15861 case CmpInst::ICMP_ULE:
15862 case CmpInst::ICMP_SLE:
15863 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15864 break;
15865 case CmpInst::ICMP_UGE:
15866 case CmpInst::ICMP_SGE:
15867 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15868 break;
15869 default:
15870 break;
15871 }
15872
15873 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15874 SmallPtrSet<const SCEV *, 16> Visited;
15875
15876 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15877 append_range(Worklist, S->operands());
15878 };
15879
15880 while (!Worklist.empty()) {
15881 const SCEV *From = Worklist.pop_back_val();
15882 if (isa<SCEVConstant>(From))
15883 continue;
15884 if (!Visited.insert(From).second)
15885 continue;
15886 const SCEV *FromRewritten = GetMaybeRewritten(From);
15887 const SCEV *To = nullptr;
15888
15889 switch (Predicate) {
15890 case CmpInst::ICMP_ULT:
15891 case CmpInst::ICMP_ULE:
15892 To = SE.getUMinExpr(FromRewritten, RHS);
15893 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15894 EnqueueOperands(UMax);
15895 break;
15896 case CmpInst::ICMP_SLT:
15897 case CmpInst::ICMP_SLE:
15898 To = SE.getSMinExpr(FromRewritten, RHS);
15899 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15900 EnqueueOperands(SMax);
15901 break;
15902 case CmpInst::ICMP_UGT:
15903 case CmpInst::ICMP_UGE:
15904 To = SE.getUMaxExpr(FromRewritten, RHS);
15905 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15906 EnqueueOperands(UMin);
15907 break;
15908 case CmpInst::ICMP_SGT:
15909 case CmpInst::ICMP_SGE:
15910 To = SE.getSMaxExpr(FromRewritten, RHS);
15911 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15912 EnqueueOperands(SMin);
15913 break;
15914 case CmpInst::ICMP_EQ:
15916 To = RHS;
15917 break;
15918 case CmpInst::ICMP_NE:
15919 if (match(RHS, m_scev_Zero())) {
15920 const SCEV *OneAlignedUp =
15921 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15922 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15923 } else {
15924 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15925 // but creating the subtraction eagerly is expensive. Track the
15926 // inequalities in a separate map, and materialize the rewrite lazily
15927 // when encountering a suitable subtraction while re-writing.
15928 if (LHS->getType()->isPointerTy()) {
15932 break;
15933 }
15934 const SCEVConstant *C;
15935 const SCEV *A, *B;
15938 RHS = A;
15939 LHS = B;
15940 }
15941 if (LHS > RHS)
15942 std::swap(LHS, RHS);
15943 Guards.NotEqual.insert({LHS, RHS});
15944 continue;
15945 }
15946 break;
15947 default:
15948 break;
15949 }
15950
15951 if (To)
15952 AddRewrite(From, FromRewritten, To);
15953 }
15954 };
15955
15957 // First, collect information from assumptions dominating the loop.
15958 for (auto &AssumeVH : SE.AC.assumptions()) {
15959 if (!AssumeVH)
15960 continue;
15961 auto *AssumeI = cast<CallInst>(AssumeVH);
15962 if (!SE.DT.dominates(AssumeI, Block))
15963 continue;
15964 Terms.emplace_back(AssumeI->getOperand(0), true);
15965 }
15966
15967 // Second, collect information from llvm.experimental.guards dominating the loop.
15968 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15969 SE.F.getParent(), Intrinsic::experimental_guard);
15970 if (GuardDecl)
15971 for (const auto *GU : GuardDecl->users())
15972 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15973 if (Guard->getFunction() == Block->getParent() &&
15974 SE.DT.dominates(Guard, Block))
15975 Terms.emplace_back(Guard->getArgOperand(0), true);
15976
15977 // Third, collect conditions from dominating branches. Starting at the loop
15978 // predecessor, climb up the predecessor chain, as long as there are
15979 // predecessors that can be found that have unique successors leading to the
15980 // original header.
15981 // TODO: share this logic with isLoopEntryGuardedByCond.
15982 unsigned NumCollectedConditions = 0;
15984 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15985 for (; Pair.first;
15986 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15987 VisitedBlocks.insert(Pair.second);
15988 const CondBrInst *LoopEntryPredicate =
15989 dyn_cast<CondBrInst>(Pair.first->getTerminator());
15990 if (!LoopEntryPredicate)
15991 continue;
15992
15993 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15994 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15995 NumCollectedConditions++;
15996
15997 // If we are recursively collecting guards stop after 2
15998 // conditions to limit compile-time impact for now.
15999 if (Depth > 0 && NumCollectedConditions == 2)
16000 break;
16001 }
16002 // Finally, if we stopped climbing the predecessor chain because
16003 // there wasn't a unique one to continue, try to collect conditions
16004 // for PHINodes by recursively following all of their incoming
16005 // blocks and try to merge the found conditions to build a new one
16006 // for the Phi.
16007 if (Pair.second->hasNPredecessorsOrMore(2) &&
16009 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16010 for (auto &Phi : Pair.second->phis())
16011 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16012 }
16013
16014 // Now apply the information from the collected conditions to
16015 // Guards.RewriteMap. Conditions are processed in reverse order, so the
 16017 // earliest condition is processed first, except guards with divisibility
16017 // information, which are moved to the back. This ensures the SCEVs with the
16018 // shortest dependency chains are constructed first.
16020 GuardsToProcess;
16021 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16022 SmallVector<Value *, 8> Worklist;
16023 SmallPtrSet<Value *, 8> Visited;
16024 Worklist.push_back(Term);
16025 while (!Worklist.empty()) {
16026 Value *Cond = Worklist.pop_back_val();
16027 if (!Visited.insert(Cond).second)
16028 continue;
16029
16030 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16031 auto Predicate =
16032 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16033 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16034 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16035 // If LHS is a constant, apply information to the other expression.
16036 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16037 // can improve results.
16038 if (isa<SCEVConstant>(LHS)) {
16039 std::swap(LHS, RHS);
16041 }
16042 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16043 continue;
16044 }
16045
16046 Value *L, *R;
16047 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16048 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16049 Worklist.push_back(L);
16050 Worklist.push_back(R);
16051 }
16052 }
16053 }
16054
16055 // Process divisibility guards in reverse order to populate DivGuards early.
16056 DenseMap<const SCEV *, APInt> Multiples;
16057 LoopGuards DivGuards(SE);
16058 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16059 if (!isDivisibilityGuard(LHS, RHS, SE))
16060 continue;
16061 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16062 Multiples, SE);
16063 }
16064
16065 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16066 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16067
16068 // Apply divisibility information last. This ensures it is applied to the
16069 // outermost expression after other rewrites for the given value.
16070 for (const auto &[K, Divisor] : Multiples) {
16071 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16072 Guards.RewriteMap[K] =
16074 Guards.rewrite(K), Divisor, SE),
16075 DivisorSCEV),
16076 DivisorSCEV);
16077 ExprsToRewrite.push_back(K);
16078 }
16079
16080 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16081 // the replacement expressions are contained in the ranges of the replaced
16082 // expressions.
16083 Guards.PreserveNUW = true;
16084 Guards.PreserveNSW = true;
16085 for (const SCEV *Expr : ExprsToRewrite) {
16086 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16087 Guards.PreserveNUW &=
16088 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16089 Guards.PreserveNSW &=
16090 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16091 }
16092
 16093 // Now that all rewrite information is collected, rewrite the collected
16094 // expressions with the information in the map. This applies information to
16095 // sub-expressions.
16096 if (ExprsToRewrite.size() > 1) {
16097 for (const SCEV *Expr : ExprsToRewrite) {
16098 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16099 Guards.RewriteMap.erase(Expr);
16100 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16101 }
16102 }
16103}
16104
16106 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16107 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16108 /// replacement is loop invariant in the loop of the AddRec.
16109 class SCEVLoopGuardRewriter
16110 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16113
16115
16116 public:
16117 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16118 const ScalarEvolution::LoopGuards &Guards)
16119 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16120 NotEqual(Guards.NotEqual) {
16121 if (Guards.PreserveNUW)
16122 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16123 if (Guards.PreserveNSW)
16124 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16125 }
16126
16127 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16128
16129 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16130 return Map.lookup_or(Expr, Expr);
16131 }
16132
16133 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16134 if (const SCEV *S = Map.lookup(Expr))
16135 return S;
16136
 16137 // If we didn't find the exact ZExt expr in the map, check if there's
16138 // an entry for a smaller ZExt we can use instead.
16139 Type *Ty = Expr->getType();
16140 const SCEV *Op = Expr->getOperand(0);
16141 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16142 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16143 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16144 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16145 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16146 if (const SCEV *S = Map.lookup(NarrowExt))
16147 return SE.getZeroExtendExpr(S, Ty);
16148 Bitwidth = Bitwidth / 2;
16149 }
16150
16152 Expr);
16153 }
16154
16155 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16156 if (const SCEV *S = Map.lookup(Expr))
16157 return S;
16159 Expr);
16160 }
16161
16162 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16163 if (const SCEV *S = Map.lookup(Expr))
16164 return S;
16166 }
16167
16168 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16169 if (const SCEV *S = Map.lookup(Expr))
16170 return S;
16172 }
16173
16174 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16175 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16176 // return UMax(S, 1).
16177 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16178 SCEVUse LHS, RHS;
16179 if (MatchBinarySub(S, LHS, RHS)) {
16180 if (LHS > RHS)
16181 std::swap(LHS, RHS);
16182 if (NotEqual.contains({LHS, RHS})) {
16183 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16184 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16185 return SE.getUMaxExpr(OneAlignedUp, S);
16186 }
16187 }
16188 return nullptr;
16189 };
16190
16191 // Check if Expr itself is a subtraction pattern with guard info.
16192 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16193 return Rewritten;
16194
16195 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16196 // (Const + A + B). There may be guard info for A + B, and if so, apply
16197 // it.
16198 // TODO: Could more generally apply guards to Add sub-expressions.
16199 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16200 Expr->getNumOperands() == 3) {
16201 const SCEV *Add =
16202 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16203 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16204 return SE.getAddExpr(
16205 Expr->getOperand(0), Rewritten,
16206 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16207 if (const SCEV *S = Map.lookup(Add))
16208 return SE.getAddExpr(Expr->getOperand(0), S);
16209 }
16210 SmallVector<SCEVUse, 2> Operands;
16211 bool Changed = false;
16212 for (SCEVUse Op : Expr->operands()) {
16213 Operands.push_back(
16215 Changed |= Op != Operands.back();
16216 }
16217 // We are only replacing operands with equivalent values, so transfer the
16218 // flags from the original expression.
16219 return !Changed ? Expr
16220 : SE.getAddExpr(Operands,
16222 Expr->getNoWrapFlags(), FlagMask));
16223 }
16224
16225 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16226 SmallVector<SCEVUse, 2> Operands;
16227 bool Changed = false;
16228 for (SCEVUse Op : Expr->operands()) {
16229 Operands.push_back(
16231 Changed |= Op != Operands.back();
16232 }
16233 // We are only replacing operands with equivalent values, so transfer the
16234 // flags from the original expression.
16235 return !Changed ? Expr
16236 : SE.getMulExpr(Operands,
16238 Expr->getNoWrapFlags(), FlagMask));
16239 }
16240 };
16241
16242 if (RewriteMap.empty() && NotEqual.empty())
16243 return Expr;
16244
16245 SCEVLoopGuardRewriter Rewriter(SE, *this);
16246 return Rewriter.visit(Expr);
16247}
16248
16249const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16250 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16251}
16252
16254 const LoopGuards &Guards) {
16255 return Guards.rewrite(Expr);
16256}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:849
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1982
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1023
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1555
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1406
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1810
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1677
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1503
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1662
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1776
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1285
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:996
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:449
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:472
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1472
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:784
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Check whether the operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:736
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:767
TypeSize getSizeInBits() const
Definition DataLayout.h:747
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:259
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:544
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2263
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2268
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2273
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2823
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2278
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it is even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:368
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range function attribute.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the given range.
Definition STLExtras.h:2012
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one of its successor blocks.
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:317
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:202
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
void dump() const
This method is used for debugging.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.