1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
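// For example, building the SCEV for (%a + 1) twice yields the same SCEV
// object both times, so clients can compare SCEV pointers directly instead
// of performing a structural comparison.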
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
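// For example, the induction variable of "for (i = 0; i < n; ++i)" becomes
// the polynomial recurrence {0,+,1}<loop>, whereas a PHI whose evolution we
// cannot model is wrapped in a SCEVUnknown.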
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
160static cl::opt<bool, true> VerifySCEVOpt(
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163static cl::opt<bool> VerifySCEVStrict(
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
167static cl::opt<bool> VerifyIR(
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
172static cl::opt<unsigned> MulOpsInlineThreshold(
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
177static cl::opt<unsigned> AddOpsInlineThreshold(
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
182static cl::opt<unsigned> MaxSCEVCompareDepth(
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
187static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
192static cl::opt<unsigned> MaxValueCompareDepth(
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
202static cl::opt<unsigned> MaxConstantEvolvingDepth(
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
221static cl::opt<unsigned> RangeIterThreshold(
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
226static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
235static cl::opt<bool> UseExpensiveRangeSharpening(
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
241static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
252static cl::opt<bool> UseContextForNoWrapFlagInference(
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
319 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345 case scSequentialUMinExpr:
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376 case scCouldNotCompute:
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
491const SCEV *ScalarEvolution::getVScale(Type *Ty) {
492 FoldingSetNodeID ID;
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
503const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
504 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
511SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
522SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
529 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
536 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
543 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610 GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
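/// For example, a caller folding a multiplication can pass APInt
/// multiplication as \p Fold, treat a constant one as the identity, and a
/// constant zero as the absorber (x * 1 == x and x * 0 == 0).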
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906// but this formula needs less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
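  // As a small worked example, take K = 3: K! = 6 = 2^1 * 3, so T = 1 and
  // K! / 2^T = 3 (odd). The product It*(It-1)*(It-2) is computed at W+1 bits,
  // divided by 2^1 (a right shift), truncated back to W bits, and finally
  // multiplied by the multiplicative inverse of 3 modulo 2^W to perform the
  // exact division by 3.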
922
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
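/// For example, the recurrence {0,+,1,+,1} evaluates at iteration It to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2 = It*(It+1)/2.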
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1039 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
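 /// For example, a pointer-typed SCEV (4 + %p) is rewritten to
 /// (4 + (ptrtoint %p)): the cast is pushed onto the SCEVUnknown operand so
 /// that the surrounding addition is performed entirely on integers.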
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1077 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 SmallVector<const SCEV *, 2> Operands;
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 SmallVector<const SCEV *, 2> Operands;
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1186 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 SmallVector<const SCEV *, 4> Operands;
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1194 isa<SCEVTruncateExpr>(S))
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and different modification ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
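// For example, for an i8 recurrence whose step is known positive and at most
// 5, the limit is SIGNED_MIN - 5 == 123 (mod 2^8) with predicate SLT: any
// value < 123 can be incremented by up to 5 without signed overflow.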
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
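// For example, for an i8 recurrence whose step is known to be at most 3, the
// limit is 0 - 3 == 253 (mod 2^8) with predicate ULT: any value < 253 can be
// incremented by up to 3 without unsigned wrap.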
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265 SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
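// For example, with C = 21 (0b10101) and x + y + ... having at least three
// trailing zero bits, D is 5 (the low three bits of C): C - D = 16 keeps
// those trailing zeros in (C - D + x + y + ...), and adding back D < 8
// cannot carry out of the low bits, so the top-level addition cannot wrap.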
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1533static void insertFoldCacheEntry(
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1535 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1536 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVZeroExtendExpr>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1606 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1610 ConstantRange CR = getUnsignedRange(X);
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1622 if (auto *AR = dyn_cast<SCEVAddRecExpr>(Op))
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1633 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the later case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as the nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but it's one of
1719 // two possible causes of a change which was reverted. Be conservative for
1720 // the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1731 getSignedRangeMin(Step));
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
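// For instance (illustrative values only): with C = 6 and Step = 4, choosing
// D = 2 maximizes the trailing zeros of (C - D + Step * n) = 4 + 4n, giving
// zext({6,+,4}) --> 2 + zext({4,+,4}).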
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not overflow in the unsigned sense then we can,
1791 // by definition, commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Address arithmetic often contains expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
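// Concretely (an illustrative sketch): for zext(5 + (4 * X)) the low bits of
// the constant give D = 1, so the expression is rewritten as
// 1 + zext(4 + (4 * X)), and zext(5 + 4*X) vs. 1 + zext(4 + 4*X) can then be
// recognized as equal.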
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not overflow in the unsigned sense then we can,
1824 // by definition, commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
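// A hypothetical instance of the rule: with K = 2, N = 8 and M = 16,
// zext(4 * (trunc X to i8)) to i16
// becomes (4 * (zext (trunc X to i6) to i16))<nuw>.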
1843 const APInt *C;
1844 const SCEV *TruncRHS;
1845 if (match(SM,
1846 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1847 C->isPowerOf2()) {
1848 int NewTruncBits =
1849 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1850 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1851 return getMulExpr(
1852 getZeroExtendExpr(SM->getOperand(0), Ty),
1853 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not overflow in the signed sense then we can, by
1962 // definition, commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly cast to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as the nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but it's one of
2090 // two possible causes of a change which was reverted. Be conservative for
2091 // the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext, build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217/// the given scale, update the given map. This is a helper function for
2218/// getAddRecExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddRecExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239 static bool
2240 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
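// Illustrative example (hypothetical values): for an unsigned i8 add
// LHS + 200, OverflowDown is false, so the check below reduces to proving
// LHS <= 255 - 200 = 55 at the context instruction.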
2362
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
2378
2379std::optional<SCEV::NoWrapFlags>
2380 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427 static SCEV::NoWrapFlags
2428 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
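// E.g. (illustrative): if both operands of an <nsw> add are known
// non-negative, the signed sum lies in [0, SINT_MAX], so it cannot wrap in
// the unsigned sense either.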
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
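// (Sketch of the reasoning: (X /u Y) * Y rounds X down to a multiple of Y,
// so the product never exceeds X and cannot wrap.)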
2497 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
2509
2510 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit the recursion depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and small magnitude is NSW, if the
2671 // original AddExpr was NSW.
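// For example (illustrative): folding ((X + 10)<nsw> + -3) to (X + 7) is
// fine, since 7 has the same sign as 10 and a smaller magnitude, so X + 7
// stays between X and X + 10.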
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702 const SCEV *Y;
2703 if (Ops.size() == 2 &&
2704 match(Ops[0],
2706 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2708
2709 // Skip past any other cast SCEVs.
2710 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2711 ++Idx;
2712
2713 // If there are add operands they would be next.
2714 if (Idx < Ops.size()) {
2715 bool DeletedAdd = false;
2716 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2717 // common NUW flag for expression after inlining. Other flags cannot be
2718 // preserved, because they may depend on the original order of operations.
2719 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2720 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2721 if (Ops.size() > AddOpsInlineThreshold ||
2722 Add->getNumOperands() > AddOpsInlineThreshold)
2723 break;
2724 // If we have an add, expand the add operands onto the end of the operands
2725 // list.
2726 Ops.erase(Ops.begin()+Idx);
2727 append_range(Ops, Add->operands());
2728 DeletedAdd = true;
2729 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2730 }
2731
2732 // If we deleted at least one add, we added operands to the end of the list,
2733 // and they are not necessarily sorted. Recurse to resort and resimplify
2734 // any operands we just acquired.
2735 if (DeletedAdd)
2736 return getAddExpr(Ops, CommonFlags, Depth + 1);
2737 }
2738
2739 // Skip over the add expression until we get to a multiply.
2740 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2741 ++Idx;
2742
2743 // Check to see if there are any folding opportunities present with
2744 // operands multiplied by constant values.
2745 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2746 uint64_t BitWidth = getTypeSizeInBits(Ty);
2747 DenseMap<const SCEV *, APInt> M;
2748 SmallVector<const SCEV *, 8> NewOps;
2749 APInt AccumulatedConstant(BitWidth, 0);
2750 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2751 Ops, APInt(BitWidth, 1), *this)) {
2752 struct APIntCompare {
2753 bool operator()(const APInt &LHS, const APInt &RHS) const {
2754 return LHS.ult(RHS);
2755 }
2756 };
2757
2758 // Some interesting folding opportunity is present, so it's worthwhile to
2759 // re-generate the operands list. Group the operands by constant scale,
2760 // to avoid multiplying by the same constant scale multiple times.
2761 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2762 for (const SCEV *NewOp : NewOps)
2763 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2764 // Re-generate the operands list.
2765 Ops.clear();
2766 if (AccumulatedConstant != 0)
2767 Ops.push_back(getConstant(AccumulatedConstant));
2768 for (auto &MulOp : MulOpLists) {
2769 if (MulOp.first == 1) {
2770 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2771 } else if (MulOp.first != 0) {
2772 Ops.push_back(getMulExpr(
2773 getConstant(MulOp.first),
2774 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2775 SCEV::FlagAnyWrap, Depth + 1));
2776 }
2777 }
2778 if (Ops.empty())
2779 return getZero(Ty);
2780 if (Ops.size() == 1)
2781 return Ops[0];
2782 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2783 }
2784 }
2785
2786 // If we are adding something to a multiply expression, make sure the
2787 // something is not already an operand of the multiply. If so, merge it into
2788 // the multiply.
2789 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2790 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2791 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2792 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2793 if (isa<SCEVConstant>(MulOpSCEV))
2794 continue;
2795 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2796 if (MulOpSCEV == Ops[AddOp]) {
2797 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2798 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2799 if (Mul->getNumOperands() != 2) {
2800 // If the multiply has more than two operands, we must get the
2801 // Y*Z term.
2803 Mul->operands().take_front(MulOp));
2804 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2805 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2806 }
2807 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2808 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2809 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2811 if (Ops.size() == 2) return OuterMul;
2812 if (AddOp < Idx) {
2813 Ops.erase(Ops.begin()+AddOp);
2814 Ops.erase(Ops.begin()+Idx-1);
2815 } else {
2816 Ops.erase(Ops.begin()+Idx);
2817 Ops.erase(Ops.begin()+AddOp-1);
2818 }
2819 Ops.push_back(OuterMul);
2820 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2821 }
2822
2823 // Check this multiply against other multiplies being added together.
2824 for (unsigned OtherMulIdx = Idx+1;
2825 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2826 ++OtherMulIdx) {
2827 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 // If MulOp occurs in OtherMul, we can fold the two multiplies
2829 // together.
2830 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2831 OMulOp != e; ++OMulOp)
2832 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2833 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2834 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2835 if (Mul->getNumOperands() != 2) {
2837 Mul->operands().take_front(MulOp));
2838 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2839 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2840 }
2841 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2842 if (OtherMul->getNumOperands() != 2) {
2844 OtherMul->operands().take_front(OMulOp));
2845 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2846 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2847 }
2848 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2849 const SCEV *InnerMulSum =
2850 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2851 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2853 if (Ops.size() == 2) return OuterMul;
2854 Ops.erase(Ops.begin()+Idx);
2855 Ops.erase(Ops.begin()+OtherMulIdx-1);
2856 Ops.push_back(OuterMul);
2857 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2858 }
2859 }
2860 }
2861 }
2862
2863 // If there are any add recurrences in the operands list, see if any other
2864 // added values are loop invariant. If so, we can fold them into the
2865 // recurrence.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2867 ++Idx;
2868
2869 // Scan over all recurrences, trying to fold loop invariants into them.
2870 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2871 // Scan all of the other operands to this add and add them to the vector if
2872 // they are loop invariant w.r.t. the recurrence.
2873 SmallVector<const SCEV *, 8> LIOps;
2874 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2875 const Loop *AddRecLoop = AddRec->getLoop();
2876 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2877 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2878 LIOps.push_back(Ops[i]);
2879 Ops.erase(Ops.begin()+i);
2880 --i; --e;
2881 }
2882
2883 // If we found some loop invariants, fold them into the recurrence.
2884 if (!LIOps.empty()) {
2885 // Compute nowrap flags for the addition of the loop-invariant ops and
2886 // the addrec. Temporarily push it as an operand for that purpose. These
2887 // flags are valid in the scope of the addrec only.
2888 LIOps.push_back(AddRec);
2889 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2890 LIOps.pop_back();
2891
2892 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2893 LIOps.push_back(AddRec->getStart());
2894
2895 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2896
2897 // It is not in general safe to propagate flags valid on an add within
2898 // the addrec scope to one outside it. We must prove that the inner
2899 // scope is guaranteed to execute if the outer one does to be able to
2900 // safely propagate. We know the program is undefined if poison is
2901 // produced on the inner scoped addrec. We also know that *for this use*
2902 // the outer scoped add can't overflow (because of the flags we just
2903 // computed for the inner scoped add) without the program being undefined.
2904 // Proving that entry to the outer scope necessitates entry to the inner
2905 // scope, thus proves the program undefined if the flags would be violated
2906 // in the outer scope.
2907 SCEV::NoWrapFlags AddFlags = Flags;
2908 if (AddFlags != SCEV::FlagAnyWrap) {
2909 auto *DefI = getDefiningScopeBound(LIOps);
2910 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2911 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2912 AddFlags = SCEV::FlagAnyWrap;
2913 }
2914 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2915
2916 // Build the new addrec. Propagate the NUW and NSW flags if both the
2917 // outer add and the inner addrec are guaranteed to have no overflow.
2918 // Always propagate NW.
2919 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2920 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2921
2922 // If all of the other operands were loop invariant, we are done.
2923 if (Ops.size() == 1) return NewRec;
2924
2925 // Otherwise, add the folded AddRec by the non-invariant parts.
2926 for (unsigned i = 0;; ++i)
2927 if (Ops[i] == AddRec) {
2928 Ops[i] = NewRec;
2929 break;
2930 }
2931 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2932 }
2933
2934 // Okay, if there weren't any loop invariants to be folded, check to see if
2935 // there are multiple AddRec's with the same loop induction variable being
2936 // added together. If so, we can fold them.
2937 for (unsigned OtherIdx = Idx+1;
2938 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2939 ++OtherIdx) {
2940 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2941 // so that the 1st found AddRecExpr is dominated by all others.
2942 assert(DT.dominates(
2943 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2944 AddRec->getLoop()->getHeader()) &&
2945 "AddRecExprs are not sorted in reverse dominance order?");
2946 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2947 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2948 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2949 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2950 ++OtherIdx) {
2951 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 if (OtherAddRec->getLoop() == AddRecLoop) {
2953 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2954 i != e; ++i) {
2955 if (i >= AddRecOps.size()) {
2956 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2957 break;
2958 }
2959 SmallVector<const SCEV *, 2> TwoOps = {
2960 AddRecOps[i], OtherAddRec->getOperand(i)};
2961 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2962 }
2963 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2964 }
2965 }
2966 // Step size has changed, so we cannot guarantee no self-wraparound.
2967 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2968 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2969 }
2970 }
2971
2972 // Otherwise couldn't fold anything into this recurrence. Move onto the
2973 // next one.
2974 }
2975
2976 // Okay, it looks like we really DO need an add expr. Check to see if we
2977 // already have one, otherwise create a new one.
2978 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2979}
2980
2981const SCEV *
2982ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2983 SCEV::NoWrapFlags Flags) {
2984 FoldingSetNodeID ID;
2985 ID.AddInteger(scAddExpr);
2986 for (const SCEV *Op : Ops)
2987 ID.AddPointer(Op);
2988 void *IP = nullptr;
2989 SCEVAddExpr *S =
2990 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2991 if (!S) {
2992 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2994 S = new (SCEVAllocator)
2995 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2996 UniqueSCEVs.InsertNode(S, IP);
2997 registerUser(S, Ops);
2998 }
2999 S->setNoWrapFlags(Flags);
3000 return S;
3001}
3002
3003const SCEV *
3004ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3005 const Loop *L, SCEV::NoWrapFlags Flags) {
3006 FoldingSetNodeID ID;
3007 ID.AddInteger(scAddRecExpr);
3008 for (const SCEV *Op : Ops)
3009 ID.AddPointer(Op);
3010 ID.AddPointer(L);
3011 void *IP = nullptr;
3012 SCEVAddRecExpr *S =
3013 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3014 if (!S) {
3015 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3017 S = new (SCEVAllocator)
3018 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3019 UniqueSCEVs.InsertNode(S, IP);
3020 LoopUsers[L].push_back(S);
3021 registerUser(S, Ops);
3022 }
3023 setNoWrapFlags(S, Flags);
3024 return S;
3025}
3026
3027const SCEV *
3028ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3029 SCEV::NoWrapFlags Flags) {
3030 FoldingSetNodeID ID;
3031 ID.AddInteger(scMulExpr);
3032 for (const SCEV *Op : Ops)
3033 ID.AddPointer(Op);
3034 void *IP = nullptr;
3035 SCEVMulExpr *S =
3036 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3037 if (!S) {
3038 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3040 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3041 O, Ops.size());
3042 UniqueSCEVs.InsertNode(S, IP);
3043 registerUser(S, Ops);
3044 }
3045 S->setNoWrapFlags(Flags);
3046 return S;
3047}
3048
3049static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3050 uint64_t k = i*j;
3051 if (j > 1 && k / j != i) Overflow = true;
3052 return k;
3053}
3054
3055/// Compute the result of "n choose k", the binomial coefficient. If an
3056/// intermediate computation overflows, Overflow will be set and the return will
3057/// be garbage. Overflow is not cleared on absence of overflow.
3058static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3059 // We use the multiplicative formula:
3060 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3061 // At each iteration, we take the next term of the numerator and divide by
3062 // the next term of the denominator. This division will always produce an
3063 // integral result, and helps reduce the chance of overflow in the
3064 // intermediate computations. However, we can still overflow even when the
3065 // final result would fit.
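// Worked example (for illustration): Choose(5, 2) runs
//   i = 1: r = 1 * 5 = 5,  r /= 1 -> 5
//   i = 2: r = 5 * 4 = 20, r /= 2 -> 10
// yielding C(5,2) = 10 with no overflow.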
3066
3067 if (n == 0 || n == k) return 1;
3068 if (k > n) return 0;
3069
3070 if (k > n/2)
3071 k = n-k;
3072
3073 uint64_t r = 1;
3074 for (uint64_t i = 1; i <= k; ++i) {
3075 r = umul_ov(r, n-(i-1), Overflow);
3076 r /= i;
3077 }
3078 return r;
3079}
3080
3081/// Determine if any of the operands in this SCEV are a constant or if
3082/// any of the add or multiply expressions in this SCEV contain a constant.
3083static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3084 struct FindConstantInAddMulChain {
3085 bool FoundConstant = false;
3086
3087 bool follow(const SCEV *S) {
3088 FoundConstant |= isa<SCEVConstant>(S);
3089 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3090 }
3091
3092 bool isDone() const {
3093 return FoundConstant;
3094 }
3095 };
3096
3097 FindConstantInAddMulChain F;
3098 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3099 ST.visitAll(StartExpr);
3100 return F.FoundConstant;
3101}
3102
3103/// Get a canonical multiply expression, or something simpler if possible.
3104 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3105 SCEV::NoWrapFlags OrigFlags,
3106 unsigned Depth) {
3107 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3108 "only nuw or nsw allowed");
3109 assert(!Ops.empty() && "Cannot get empty mul!");
3110 if (Ops.size() == 1) return Ops[0];
3111#ifndef NDEBUG
3112 Type *ETy = Ops[0]->getType();
3113 assert(!ETy->isPointerTy());
3114 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3115 assert(Ops[i]->getType() == ETy &&
3116 "SCEVMulExpr operand types don't match!");
3117#endif
3118
3119 const SCEV *Folded = constantFoldAndGroupOps(
3120 *this, LI, DT, Ops,
3121 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3122 [](const APInt &C) { return C.isOne(); }, // identity
3123 [](const APInt &C) { return C.isZero(); }); // absorber
3124 if (Folded)
3125 return Folded;
3126
3127 // Delay expensive flag strengthening until necessary.
3128 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3129 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3130 };
3131
3132 // Limit the recursion depth.
3133 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3134 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3135
3136 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3137 // Don't strengthen flags if we have no new information.
3138 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3139 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3140 Mul->setNoWrapFlags(ComputeFlags(Ops));
3141 return S;
3142 }
3143
3144 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3145 if (Ops.size() == 2) {
3146 // C1*(C2+V) -> C1*C2 + C1*V
3147 // If any of Add's ops are Adds or Muls with a constant, apply this
3148 // transformation as well.
3149 //
3150 // TODO: There are some cases where this transformation is not
3151 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3152 // this transformation should be narrowed down.
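// For instance (illustrative), 3 * (2 + %v) becomes 6 + 3 * %v, which can
// expose further folds when %v is itself an add or an addrec.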
3153 const SCEV *Op0, *Op1;
3154 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3155 containsConstantInAddMulChain(Ops[1])) {
3156 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3157 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3158 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3159 }
3160
3161 if (Ops[0]->isAllOnesValue()) {
3162 // If we have a mul by -1 of an add, try distributing the -1 among the
3163 // add operands.
3164 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3165 SmallVector<const SCEV *, 4> NewOps;
3166 bool AnyFolded = false;
3167 for (const SCEV *AddOp : Add->operands()) {
3168 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3169 Depth + 1);
3170 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3171 NewOps.push_back(Mul);
3172 }
3173 if (AnyFolded)
3174 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3175 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3176 // Negation preserves a recurrence's no self-wrap property.
3177 SmallVector<const SCEV *, 4> Operands;
3178 for (const SCEV *AddRecOp : AddRec->operands())
3179 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3180 Depth + 1));
3181 // Let M be the minimum representable signed value. AddRec with nsw
3182 // multiplied by -1 can have signed overflow if and only if it takes a
3183 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3184 // maximum signed value. In all other cases signed overflow is
3185 // impossible.
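// E.g. (illustrative, i8): negating {-128,+,1}<nsw> would require +128,
// which is not representable, so NSW is dropped; for any addrec whose
// signed minimum is greater than -128, NSW is kept.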
3186 auto FlagsMask = SCEV::FlagNW;
3187 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3188 auto MinInt =
3189 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3190 if (getSignedRangeMin(AddRec) != MinInt)
3191 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3192 }
3193 return getAddRecExpr(Operands, AddRec->getLoop(),
3194 AddRec->getNoWrapFlags(FlagsMask));
3195 }
3196 }
3197
3198 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3199 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3200 const SCEVAddExpr *InnerAdd;
3201 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3202 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3203 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3204 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3205 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3206 SCEV::FlagAnyWrap),
3207 SCEV::FlagNUW)) {
3208 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3209 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3210 };
3211 }
3212
3213 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3214 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3215 // of C1, fold to (D /u (C2 /u C1)).
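// Illustrative example (editorial note, not part of the original source):
// with C1 = 8, C2 = 4 and D known to have at least three trailing zero
// bits, 8 * (D /u 4) folds to 2 * D; with C1 = 2 and C2 = 8 it folds to
// D /u 4 instead.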
3216 const SCEV *D;
3217 APInt C1V = LHSC->getAPInt();
3218 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3219 // as -1 * 1, as it won't enable additional folds.
3220 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3221 C1V = C1V.abs();
3222 const SCEVConstant *C2;
3223 if (C1V.isPowerOf2() &&
3225 C2->getAPInt().isPowerOf2() &&
3226 C1V.logBase2() <= getMinTrailingZeros(D)) {
3227 const SCEV *NewMul = nullptr;
3228 if (C1V.uge(C2->getAPInt())) {
3229 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3230 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3231 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3232 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3233 }
3234 if (NewMul)
3235 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3236 }
3237 }
3238 }
3239
3240 // Skip over the add expression until we get to a multiply.
3241 unsigned Idx = 0;
3242 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3243 ++Idx;
3244
3245 // If there are mul operands inline them all into this expression.
3246 if (Idx < Ops.size()) {
3247 bool DeletedMul = false;
3248 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3249 if (Ops.size() > MulOpsInlineThreshold)
3250 break;
3251 // If we have a mul, expand the mul operands onto the end of the
3252 // operands list.
3253 Ops.erase(Ops.begin()+Idx);
3254 append_range(Ops, Mul->operands());
3255 DeletedMul = true;
3256 }
3257
3258 // If we deleted at least one mul, we added operands to the end of the
3259 // list, and they are not necessarily sorted. Recurse to resort and
3260 // resimplify any operands we just acquired.
3261 if (DeletedMul)
3262 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3263 }
3264
3265 // If there are any add recurrences in the operands list, see if any other
3266 // added values are loop invariant. If so, we can fold them into the
3267 // recurrence.
3268 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3269 ++Idx;
3270
3271 // Scan over all recurrences, trying to fold loop invariants into them.
3272 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3273 // Scan all of the other operands to this mul and add them to the vector
3274 // if they are loop invariant w.r.t. the recurrence.
3275 SmallVector<const SCEV *, 8> LIOps;
3276 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3277 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3278 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3279 LIOps.push_back(Ops[i]);
3280 Ops.erase(Ops.begin()+i);
3281 --i; --e;
3282 }
3283
3284 // If we found some loop invariants, fold them into the recurrence.
3285 if (!LIOps.empty()) {
3286 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3287 SmallVector<const SCEV *, 4> NewOps;
3288 NewOps.reserve(AddRec->getNumOperands());
3289 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3290
3291 // If both the mul and addrec are nuw, we can preserve nuw.
3292 // If both the mul and addrec are nsw, we can only preserve nsw if either
3293 // a) they are also nuw, or
3294 // b) all multiplications of addrec operands with scale are nsw.
3295 SCEV::NoWrapFlags Flags =
3296 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3297
3298 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3299 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3300 SCEV::FlagAnyWrap, Depth + 1));
3301
3302 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3303 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3304 Instruction::Mul, getSignedRange(Scale),
3305 OverflowingBinaryOperator::NoSignedWrap);
3306 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3307 Flags = clearFlags(Flags, SCEV::FlagNSW);
3308 }
3309 }
3310
3311 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3312
3313 // If all of the other operands were loop invariant, we are done.
3314 if (Ops.size() == 1) return NewRec;
3315
3316 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3317 for (unsigned i = 0;; ++i)
3318 if (Ops[i] == AddRec) {
3319 Ops[i] = NewRec;
3320 break;
3321 }
3322 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3323 }
3324
3325 // Okay, if there weren't any loop invariants to be folded, check to see
3326 // if there are multiple AddRec's with the same loop induction variable
3327 // being multiplied together. If so, we can fold them.
3328
3329 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3330 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3331 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3332 // ]]],+,...up to x=2n}.
3333 // Note that the arguments to choose() are always integers with values
3334 // known at compile time, never SCEV objects.
3335 //
3336 // The implementation avoids pointless extra computations when the two
3337 // addrec's are of different length (mathematically, it's equivalent to
3338 // an infinite stream of zeros on the right).
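// Illustrative example (editorial note, not part of the original source):
// for two affine recurrences over the same loop, {1,+,2}<L> * {3,+,4}<L>
// expands to the quadratic recurrence {3,+,18,+,16}<L>, whose successive
// values 3, 21, 55, ... match the products of the two inputs.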
3339 bool OpsModified = false;
3340 for (unsigned OtherIdx = Idx+1;
3341 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3342 ++OtherIdx) {
3343 const SCEVAddRecExpr *OtherAddRec =
3344 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3345 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3346 continue;
3347
3348 // Limit max number of arguments to avoid creation of unreasonably big
3349 // SCEVAddRecs with very complex operands.
3350 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3351 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3352 continue;
3353
3354 bool Overflow = false;
3355 Type *Ty = AddRec->getType();
3356 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3357 SmallVector<const SCEV *, 7> AddRecOps;
3358 for (int x = 0, xe = AddRec->getNumOperands() +
3359 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3360 SmallVector<const SCEV *, 7> SumOps;
3361 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3362 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3363 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3364 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3365 z < ze && !Overflow; ++z) {
3366 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3367 uint64_t Coeff;
3368 if (LargerThan64Bits)
3369 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3370 else
3371 Coeff = Coeff1*Coeff2;
3372 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3373 const SCEV *Term1 = AddRec->getOperand(y-z);
3374 const SCEV *Term2 = OtherAddRec->getOperand(z);
3375 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3376 SCEV::FlagAnyWrap, Depth + 1));
3377 }
3378 }
3379 if (SumOps.empty())
3380 SumOps.push_back(getZero(Ty));
3381 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3382 }
3383 if (!Overflow) {
3384 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3385 SCEV::FlagAnyWrap);
3386 if (Ops.size() == 2) return NewAddRec;
3387 Ops[Idx] = NewAddRec;
3388 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3389 OpsModified = true;
3390 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3391 if (!AddRec)
3392 break;
3393 }
3394 }
3395 if (OpsModified)
3396 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3397
3398 // Otherwise couldn't fold anything into this recurrence. Move onto the
3399 // next one.
3400 }
3401
3402 // Okay, it looks like we really DO need a mul expr. Check to see if we
3403 // already have one, otherwise create a new one.
3404 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3405}
3406
3407/// Represents an unsigned remainder expression based on unsigned division.
3408 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3409 const SCEV *RHS) {
3410 assert(getEffectiveSCEVType(LHS->getType()) ==
3411 getEffectiveSCEVType(RHS->getType()) &&
3412 "SCEVURemExpr operand types don't match!");
3413
3414 // Short-circuit easy cases
3415 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3416 // If constant is one, the result is trivial
3417 if (RHSC->getValue()->isOne())
3418 return getZero(LHS->getType()); // X urem 1 --> 0
3419
3420 // If constant is a power of two, fold into a zext(trunc(LHS)).
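// Illustrative example (editorial note, not part of the original source):
// for a 32-bit %x, (%x urem 8) becomes (zext i3 (trunc %x to i3) to i32),
// i.e. the low three bits of %x.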
3421 if (RHSC->getAPInt().isPowerOf2()) {
3422 Type *FullTy = LHS->getType();
3423 Type *TruncTy =
3424 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3425 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3426 }
3427 }
3428
3429 // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3430 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3431 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3432 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3433}
3434
3435/// Get a canonical unsigned division expression, or something simpler if
3436/// possible.
3437 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3438 const SCEV *RHS) {
3439 assert(!LHS->getType()->isPointerTy() &&
3440 "SCEVUDivExpr operand can't be pointer!");
3441 assert(LHS->getType() == RHS->getType() &&
3442 "SCEVUDivExpr operand types don't match!");
3443
3444 FoldingSetNodeID ID;
3445 ID.AddInteger(scUDivExpr);
3446 ID.AddPointer(LHS);
3447 ID.AddPointer(RHS);
3448 void *IP = nullptr;
3449 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3450 return S;
3451
3452 // 0 udiv Y == 0
3453 if (match(LHS, m_scev_Zero()))
3454 return LHS;
3455
3456 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3457 if (RHSC->getValue()->isOne())
3458 return LHS; // X udiv 1 --> x
3459 // If the denominator is zero, the result of the udiv is undefined. Don't
3460 // try to analyze it, because the resolution chosen here may differ from
3461 // the resolution chosen in other parts of the compiler.
3462 if (!RHSC->getValue()->isZero()) {
3463 // Determine if the division can be folded into the operands of
3464 // its left-hand operand.
3465 // TODO: Generalize this to non-constants by using known-bits information.
3466 Type *Ty = LHS->getType();
3467 unsigned LZ = RHSC->getAPInt().countl_zero();
3468 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3469 // For non-power-of-two values, effectively round the value up to the
3470 // nearest power of two.
3471 if (!RHSC->getAPInt().isPowerOf2())
3472 ++MaxShiftAmt;
3473 IntegerType *ExtTy =
3474 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3475 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3476 if (const SCEVConstant *Step =
3477 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3478 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3479 const APInt &StepInt = Step->getAPInt();
3480 const APInt &DivInt = RHSC->getAPInt();
3481 if (!StepInt.urem(DivInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 SmallVector<const SCEV *, 4> Operands;
3487 for (const SCEV *Op : AR->operands())
3488 Operands.push_back(getUDivExpr(Op, RHS));
3489 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3490 }
3491 /// Get a canonical UDivExpr for a recurrence.
3492 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3493 // We can currently only fold X%N if X is constant.
3494 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3495 if (StartC && !DivInt.urem(StepInt) &&
3496 getZeroExtendExpr(AR, ExtTy) ==
3497 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3498 getZeroExtendExpr(Step, ExtTy),
3499 AR->getLoop(), SCEV::FlagAnyWrap)) {
3500 const APInt &StartInt = StartC->getAPInt();
3501 const APInt &StartRem = StartInt.urem(StepInt);
3502 if (StartRem != 0) {
3503 const SCEV *NewLHS =
3504 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3505 AR->getLoop(), SCEV::FlagNW);
3506 if (LHS != NewLHS) {
3507 LHS = NewLHS;
3508
3509 // Reset the ID to include the new LHS, and check if it is
3510 // already cached.
3511 ID.clear();
3512 ID.AddInteger(scUDivExpr);
3513 ID.AddPointer(LHS);
3514 ID.AddPointer(RHS);
3515 IP = nullptr;
3516 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3517 return S;
3518 }
3519 }
3520 }
3521 }
3522 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3523 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3524 SmallVector<const SCEV *, 4> Operands;
3525 for (const SCEV *Op : M->operands())
3526 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3527 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3528 // Find an operand that's safely divisible.
3529 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3530 const SCEV *Op = M->getOperand(i);
3531 const SCEV *Div = getUDivExpr(Op, RHSC);
3532 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3533 Operands = SmallVector<const SCEV *, 4>(M->operands());
3534 Operands[i] = Div;
3535 return getMulExpr(Operands);
3536 }
3537 }
3538 }
3539
3540 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3541 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3542 if (auto *DivisorConstant =
3543 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3544 bool Overflow = false;
3545 APInt NewRHS =
3546 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3547 if (Overflow) {
3548 return getConstant(RHSC->getType(), 0, false);
3549 }
3550 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3551 }
3552 }
3553
3554 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3555 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3556 SmallVector<const SCEV *, 4> Operands;
3557 for (const SCEV *Op : A->operands())
3558 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3559 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3560 Operands.clear();
3561 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3562 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3563 if (isa<SCEVUDivExpr>(Op) ||
3564 getMulExpr(Op, RHS) != A->getOperand(i))
3565 break;
3566 Operands.push_back(Op);
3567 }
3568 if (Operands.size() == A->getNumOperands())
3569 return getAddExpr(Operands);
3570 }
3571 }
3572
3573 // Fold if both operands are constant.
3574 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3575 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3576 }
3577 }
3578
3579 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3580 const APInt *NegC, *C;
3581 if (match(LHS,
3584 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3585 return getZero(LHS->getType());
3586
3587 // TODO: Generalize to handle any common factors.
3588 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3589 const SCEV *NewLHS, *NewRHS;
3590 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3591 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3592 return getUDivExpr(NewLHS, NewRHS);
3593
3594 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3595 // changes). Make sure we get a new one.
3596 IP = nullptr;
3597 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3598 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3599 LHS, RHS);
3600 UniqueSCEVs.InsertNode(S, IP);
3601 registerUser(S, {LHS, RHS});
3602 return S;
3603}
3604
3605APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3606 APInt A = C1->getAPInt().abs();
3607 APInt B = C2->getAPInt().abs();
3608 uint32_t ABW = A.getBitWidth();
3609 uint32_t BBW = B.getBitWidth();
3610
3611 if (ABW > BBW)
3612 B = B.zext(ABW);
3613 else if (ABW < BBW)
3614 A = A.zext(BBW);
3615
3616 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3617}
3618
3619/// Get a canonical unsigned division expression, or something simpler if
3620/// possible. There is no representation for an exact udiv in SCEV IR, but we
3621/// can attempt to remove factors from the LHS and RHS. We can't do this when
3622/// it's not exact because the udiv may be clearing bits.
3623 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3624 const SCEV *RHS) {
3625 // TODO: we could try to find factors in all sorts of things, but for now we
3626 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3627 // end of this file for inspiration.
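// Illustrative example (editorial note, not part of the original source,
// assuming the multiplies involved are marked <nuw>): (6 * %x) /u 3 can be
// rewritten as (2 * %x), and (3 * %x) /u %x as 3; a plain udiv may clear
// low bits, so these rewrites would be unsound for it.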
3628
3629 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3630 if (!Mul || !Mul->hasNoUnsignedWrap())
3631 return getUDivExpr(LHS, RHS);
3632
3633 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3634 // If the mulexpr multiplies by a constant, then that constant must be the
3635 // first element of the mulexpr.
3636 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3637 if (LHSCst == RHSCst) {
3638 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3639 return getMulExpr(Operands);
3640 }
3641
3642 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3643 // that there's a factor provided by one of the other terms. We need to
3644 // check.
3645 APInt Factor = gcd(LHSCst, RHSCst);
3646 if (!Factor.isIntN(1)) {
3647 LHSCst =
3648 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3649 RHSCst =
3650 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3651 SmallVector<const SCEV *, 2> Operands;
3652 Operands.push_back(LHSCst);
3653 append_range(Operands, Mul->operands().drop_front());
3654 LHS = getMulExpr(Operands);
3655 RHS = RHSCst;
3656 Mul = dyn_cast<SCEVMulExpr>(LHS);
3657 if (!Mul)
3658 return getUDivExactExpr(LHS, RHS);
3659 }
3660 }
3661 }
3662
3663 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3664 if (Mul->getOperand(i) == RHS) {
3665 SmallVector<const SCEV *, 2> Operands;
3666 append_range(Operands, Mul->operands().take_front(i));
3667 append_range(Operands, Mul->operands().drop_front(i + 1));
3668 return getMulExpr(Operands);
3669 }
3670 }
3671
3672 return getUDivExpr(LHS, RHS);
3673}
3674
3675/// Get an add recurrence expression for the specified loop. Simplify the
3676/// expression as much as possible.
3677const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3678 const Loop *L,
3679 SCEV::NoWrapFlags Flags) {
3680 SmallVector<const SCEV *, 4> Operands;
3681 Operands.push_back(Start);
3682 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3683 if (StepChrec->getLoop() == L) {
3684 append_range(Operands, StepChrec->operands());
3685 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3686 }
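// Illustrative example (editorial note, not part of the original source):
// a Step that is itself an addrec over the same loop is flattened, e.g.
// Start = X with Step = {Y,+,Z}<L> produces {X,+,Y,+,Z}<L>.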
3687
3688 Operands.push_back(Step);
3689 return getAddRecExpr(Operands, L, Flags);
3690}
3691
3692/// Get an add recurrence expression for the specified loop. Simplify the
3693/// expression as much as possible.
3694const SCEV *
3696 const Loop *L, SCEV::NoWrapFlags Flags) {
3697 if (Operands.size() == 1) return Operands[0];
3698#ifndef NDEBUG
3699 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3700 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3701 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3702 "SCEVAddRecExpr operand types don't match!");
3703 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3704 }
3705 for (const SCEV *Op : Operands)
3706 assert(isAvailableAtLoopEntry(Op, L) &&
3707 "SCEVAddRecExpr operand is not available at loop entry!");
3708#endif
3709
3710 if (Operands.back()->isZero()) {
3711 Operands.pop_back();
3712 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3713 }
3714
3715 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3716 // use that information to infer NUW and NSW flags. However, computing a
3717 // BE count requires calling getAddRecExpr, so we may not yet have a
3718 // meaningful BE count at this point (and if we don't, we'd be stuck
3719 // with a SCEVCouldNotCompute as the cached BE count).
3720
3721 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3722
3723 // Canonicalize nested AddRecs by nesting them in order of loop depth.
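// Illustrative example (editorial note, not part of the original source):
// if L is the outer loop, {{A,+,B}<Inner>,+,C}<L> is rebuilt as
// {{A,+,C}<L>,+,B}<Inner>, so the recurrence over the outer loop becomes
// the start of the recurrence over the inner loop.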
3724 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3725 const Loop *NestedLoop = NestedAR->getLoop();
3726 if (L->contains(NestedLoop)
3727 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3728 : (!NestedLoop->contains(L) &&
3729 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3730 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3731 Operands[0] = NestedAR->getStart();
3732 // AddRecs require their operands be loop-invariant with respect to their
3733 // loops. Don't perform this transformation if it would break this
3734 // requirement.
3735 bool AllInvariant = all_of(
3736 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3737
3738 if (AllInvariant) {
3739 // Create a recurrence for the outer loop with the same step size.
3740 //
3741 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3742 // inner recurrence has the same property.
3743 SCEV::NoWrapFlags OuterFlags =
3744 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3745
3746 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3747 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3748 return isLoopInvariant(Op, NestedLoop);
3749 });
3750
3751 if (AllInvariant) {
3752 // Ok, both add recurrences are valid after the transformation.
3753 //
3754 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3755 // the outer recurrence has the same property.
3756 SCEV::NoWrapFlags InnerFlags =
3757 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3758 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3759 }
3760 }
3761 // Reset Operands to its original state.
3762 Operands[0] = NestedAR;
3763 }
3764 }
3765
3766 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3767 // already have one, otherwise create a new one.
3768 return getOrCreateAddRecExpr(Operands, L, Flags);
3769}
3770
3771 const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3772 ArrayRef<const SCEV *> IndexExprs) {
3773 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3774 // getSCEV(Base)->getType() has the same address space as Base->getType()
3775 // because SCEV::getType() preserves the address space.
3776 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3777 if (NW != GEPNoWrapFlags::none()) {
3778 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3779 // but to do that, we have to ensure that said flag is valid in the entire
3780 // defined scope of the SCEV.
3781 // TODO: non-instructions have global scope. We might be able to prove
3782 // some global scope cases
3783 auto *GEPI = dyn_cast<Instruction>(GEP);
3784 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3785 NW = GEPNoWrapFlags::none();
3786 }
3787
3788 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3789}
3790
3791 const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3792 ArrayRef<const SCEV *> IndexExprs,
3793 Type *SrcElementTy, GEPNoWrapFlags NW) {
3794 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3795 if (NW.hasNoUnsignedSignedWrap())
3796 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3797 if (NW.hasNoUnsignedWrap())
3798 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3799
3800 Type *CurTy = BaseExpr->getType();
3801 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3802 bool FirstIter = true;
3803 SmallVector<const SCEV *, 4> Offsets;
3804 for (const SCEV *IndexExpr : IndexExprs) {
3805 // Compute the (potentially symbolic) offset in bytes for this index.
3806 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3807 // For a struct, add the member offset.
3808 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3809 unsigned FieldNo = Index->getZExtValue();
3810 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3811 Offsets.push_back(FieldOffset);
3812
3813 // Update CurTy to the type of the field at Index.
3814 CurTy = STy->getTypeAtIndex(Index);
3815 } else {
3816 // Update CurTy to its element type.
3817 if (FirstIter) {
3818 assert(isa<PointerType>(CurTy) &&
3819 "The first index of a GEP indexes a pointer");
3820 CurTy = SrcElementTy;
3821 FirstIter = false;
3822 } else {
3823 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3824 }
3825 // For an array, add the element offset, explicitly scaled.
3826 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3827 // Getelementptr indices are signed.
3828 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3829
3830 // Multiply the index by the element size to compute the element offset.
3831 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3832 Offsets.push_back(LocalOffset);
3833 }
3834 }
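// Illustrative example (editorial note, not part of the original source;
// assumes a 64-bit index type and the usual layout of {i32, i64} with the
// i64 member at byte offset 8): getelementptr {i32, i64}, ptr %p, i64 %i,
// i32 1 contributes the offsets (16 * %i) and 8, so the expression built
// below is, up to operand order, (%p + (16 * %i) + 8).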
3835
3836 // Handle degenerate case of GEP without offsets.
3837 if (Offsets.empty())
3838 return BaseExpr;
3839
3840 // Add the offsets together, assuming nsw if inbounds.
3841 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3842 // Add the base address and the offset. We cannot use the nsw flag, as the
3843 // base address is unsigned. However, if we know that the offset is
3844 // non-negative, we can use nuw.
3845 bool NUW = NW.hasNoUnsignedWrap() ||
3846 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3847 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3848 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3849 assert(BaseExpr->getType() == GEPExpr->getType() &&
3850 "GEP should not change type mid-flight.");
3851 return GEPExpr;
3852}
3853
3854SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3855 ArrayRef<const SCEV *> Ops) {
3856 FoldingSetNodeID ID;
3857 ID.AddInteger(SCEVType);
3858 for (const SCEV *Op : Ops)
3859 ID.AddPointer(Op);
3860 void *IP = nullptr;
3861 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3862}
3863
3864const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3866 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3867}
3868
3869 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3870 SmallVectorImpl<const SCEV *> &Ops) {
3871 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3872 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3873 if (Ops.size() == 1) return Ops[0];
3874#ifndef NDEBUG
3875 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3876 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3877 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3878 "Operand types don't match!");
3879 assert(Ops[0]->getType()->isPointerTy() ==
3880 Ops[i]->getType()->isPointerTy() &&
3881 "min/max should be consistently pointerish");
3882 }
3883#endif
3884
3885 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3886 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3887
3888 const SCEV *Folded = constantFoldAndGroupOps(
3889 *this, LI, DT, Ops,
3890 [&](const APInt &C1, const APInt &C2) {
3891 switch (Kind) {
3892 case scSMaxExpr:
3893 return APIntOps::smax(C1, C2);
3894 case scSMinExpr:
3895 return APIntOps::smin(C1, C2);
3896 case scUMaxExpr:
3897 return APIntOps::umax(C1, C2);
3898 case scUMinExpr:
3899 return APIntOps::umin(C1, C2);
3900 default:
3901 llvm_unreachable("Unknown SCEV min/max opcode");
3902 }
3903 },
3904 [&](const APInt &C) {
3905 // identity
3906 if (IsMax)
3907 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3908 else
3909 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3910 },
3911 [&](const APInt &C) {
3912 // absorber
3913 if (IsMax)
3914 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3915 else
3916 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3917 });
3918 if (Folded)
3919 return Folded;
3920
3921 // Check if we have created the same expression before.
3922 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3923 return S;
3924 }
3925
3926 // Find the first operation of the same kind
3927 unsigned Idx = 0;
3928 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3929 ++Idx;
3930
3931 // Check to see if one of the operands is of the same kind. If so, expand its
3932 // operands onto our operand list, and recurse to simplify.
3933 if (Idx < Ops.size()) {
3934 bool DeletedAny = false;
3935 while (Ops[Idx]->getSCEVType() == Kind) {
3936 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3937 Ops.erase(Ops.begin()+Idx);
3938 append_range(Ops, SMME->operands());
3939 DeletedAny = true;
3940 }
3941
3942 if (DeletedAny)
3943 return getMinMaxExpr(Kind, Ops);
3944 }
3945
3946 // Okay, check to see if the same value occurs in the operand list twice. If
3947 // so, delete one. Since we sorted the list, these values are required to
3948 // be adjacent.
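// Illustrative example (editorial note, not part of the original source):
// after sorting, smax(%x, %x, %y) drops the duplicate to smax(%x, %y), and
// if %x >= %y is already known, smax(%x, %y) folds to just %x.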
3949 llvm::CmpInst::Predicate GEPred =
3950 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3951 llvm::CmpInst::Predicate LEPred =
3952 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3953 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3954 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3955 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3956 if (Ops[i] == Ops[i + 1] ||
3957 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3958 // X op Y op Y --> X op Y
3959 // X op Y --> X, if we know X, Y are ordered appropriately
3960 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3961 --i;
3962 --e;
3963 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3964 Ops[i + 1])) {
3965 // X op Y --> Y, if we know X, Y are ordered appropriately
3966 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3967 --i;
3968 --e;
3969 }
3970 }
3971
3972 if (Ops.size() == 1) return Ops[0];
3973
3974 assert(!Ops.empty() && "Reduced smax down to nothing!");
3975
3976 // Okay, it looks like we really DO need an expr. Check to see if we
3977 // already have one, otherwise create a new one.
3978 FoldingSetNodeID ID;
3979 ID.AddInteger(Kind);
3980 for (const SCEV *Op : Ops)
3981 ID.AddPointer(Op);
3982 void *IP = nullptr;
3983 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3984 if (ExistingSCEV)
3985 return ExistingSCEV;
3986 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3987 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3988 SCEV *S = new (SCEVAllocator)
3989 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3990
3991 UniqueSCEVs.InsertNode(S, IP);
3992 registerUser(S, Ops);
3993 return S;
3994}
3995
3996namespace {
3997
3998class SCEVSequentialMinMaxDeduplicatingVisitor final
3999 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4000 std::optional<const SCEV *>> {
4001 using RetVal = std::optional<const SCEV *>;
4003
4004 ScalarEvolution &SE;
4005 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4006 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4008
4009 bool canRecurseInto(SCEVTypes Kind) const {
4010 // We can only recurse into the SCEV expression of the same effective type
4011 // as the type of our root SCEV expression.
4012 return RootKind == Kind || NonSequentialRootKind == Kind;
4013 };
4014
4015 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4017 "Only for min/max expressions.");
4018 SCEVTypes Kind = S->getSCEVType();
4019
4020 if (!canRecurseInto(Kind))
4021 return S;
4022
4023 auto *NAry = cast<SCEVNAryExpr>(S);
4025 bool Changed = visit(Kind, NAry->operands(), NewOps);
4026
4027 if (!Changed)
4028 return S;
4029 if (NewOps.empty())
4030 return std::nullopt;
4031
4033 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4034 : SE.getMinMaxExpr(Kind, NewOps);
4035 }
4036
4037 RetVal visit(const SCEV *S) {
4038 // Has the whole operand been seen already?
4039 if (!SeenOps.insert(S).second)
4040 return std::nullopt;
4041 return Base::visit(S);
4042 }
4043
4044public:
4045 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4046 SCEVTypes RootKind)
4047 : SE(SE), RootKind(RootKind),
4048 NonSequentialRootKind(
4049 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4050 RootKind)) {}
4051
4052 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4053 SmallVectorImpl<const SCEV *> &NewOps) {
4054 bool Changed = false;
4056 Ops.reserve(OrigOps.size());
4057
4058 for (const SCEV *Op : OrigOps) {
4059 RetVal NewOp = visit(Op);
4060 if (NewOp != Op)
4061 Changed = true;
4062 if (NewOp)
4063 Ops.emplace_back(*NewOp);
4064 }
4065
4066 if (Changed)
4067 NewOps = std::move(Ops);
4068 return Changed;
4069 }
4070
4071 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4072
4073 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4074
4075 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4076
4077 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4078
4079 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4080
4081 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4082
4083 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4084
4085 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4086
4087 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4088
4089 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4090
4091 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4092 return visitAnyMinMaxExpr(Expr);
4093 }
4094
4095 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4096 return visitAnyMinMaxExpr(Expr);
4097 }
4098
4099 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4100 return visitAnyMinMaxExpr(Expr);
4101 }
4102
4103 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4104 return visitAnyMinMaxExpr(Expr);
4105 }
4106
4107 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4108 return visitAnyMinMaxExpr(Expr);
4109 }
4110
4111 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4112
4113 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4114};
4115
4116} // namespace
4117
4118 static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4119 switch (Kind) {
4120 case scConstant:
4121 case scVScale:
4122 case scTruncate:
4123 case scZeroExtend:
4124 case scSignExtend:
4125 case scPtrToInt:
4126 case scAddExpr:
4127 case scMulExpr:
4128 case scUDivExpr:
4129 case scAddRecExpr:
4130 case scUMaxExpr:
4131 case scSMaxExpr:
4132 case scUMinExpr:
4133 case scSMinExpr:
4134 case scUnknown:
4135 // If any operand is poison, the whole expression is poison.
4136 return true;
4137 case scSequentialUMinExpr:
4138 // FIXME: if the *first* operand is poison, the whole expression is poison.
4139 return false; // Pessimistically, say that it does not propagate poison.
4140 case scCouldNotCompute:
4141 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4142 }
4143 llvm_unreachable("Unknown SCEV kind!");
4144}
4145
4146namespace {
4147// The only way poison may be introduced in a SCEV expression is from a
4148// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4149// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4150// introduce poison -- they encode guaranteed, non-speculated knowledge.
4151//
4152// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4153// with the notable exception of umin_seq, where only poison from the first
4154// operand is (unconditionally) propagated.
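// Illustrative example (editorial note, not part of the original source):
// in (%x umin_seq %y), a poison %y does not force the result to be poison,
// because the %x == 0 case short-circuits before %y is evaluated; a poison
// %x, by contrast, always makes the whole expression poison.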
4155struct SCEVPoisonCollector {
4156 bool LookThroughMaybePoisonBlocking;
4157 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4158 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4159 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4160
4161 bool follow(const SCEV *S) {
4162 if (!LookThroughMaybePoisonBlocking &&
4163 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4164 return false;
4165
4166 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4167 if (!isGuaranteedNotToBePoison(SU->getValue()))
4168 MaybePoison.insert(SU);
4169 }
4170 return true;
4171 }
4172 bool isDone() const { return false; }
4173};
4174} // namespace
4175
4176/// Return true if V is poison given that AssumedPoison is already poison.
4177static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4178 // First collect all SCEVs that might result in AssumedPoison to be poison.
4179 // We need to look through potentially poison-blocking operations here,
4180 // because we want to find all SCEVs that *might* result in poison, not only
4181 // those that are *required* to.
4182 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4183 visitAll(AssumedPoison, PC1);
4184
4185 // AssumedPoison is never poison. As the assumption is false, the implication
4186 // is true. Don't bother walking the other SCEV in this case.
4187 if (PC1.MaybePoison.empty())
4188 return true;
4189
4190 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4191 // as well. We cannot look through potentially poison-blocking operations
4192 // here, as their arguments only *may* make the result poison.
4193 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4194 visitAll(S, PC2);
4195
4196 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4197 // it will also make S poison by being part of PC2.MaybePoison.
4198 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4199}
4200
4201 void ScalarEvolution::getPoisonGeneratingValues(
4202 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4203 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4204 visitAll(S, PC);
4205 for (const SCEVUnknown *SU : PC.MaybePoison)
4206 Result.insert(SU->getValue());
4207}
4208
4209 bool ScalarEvolution::canReuseInstruction(
4210 const SCEV *S, Instruction *I,
4211 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4212 // If the instruction cannot be poison, it's always safe to reuse.
4213 if (isGuaranteedNotToBePoison(I))
4214 return true;
4215
4216 // Otherwise, it is possible that I is more poisonous than S. Collect the
4217 // poison-contributors of S, and then check whether I has any additional
4218 // poison-contributors. Poison that is contributed through poison-generating
4219 // flags is handled by dropping those flags instead.
4220 SmallPtrSet<const Value *, 8> PoisonVals;
4221 getPoisonGeneratingValues(PoisonVals, S);
4222
4223 SmallVector<Value *> Worklist;
4224 SmallPtrSet<Value *, 8> Visited;
4225 Worklist.push_back(I);
4226 while (!Worklist.empty()) {
4227 Value *V = Worklist.pop_back_val();
4228 if (!Visited.insert(V).second)
4229 continue;
4230
4231 // Avoid walking large instruction graphs.
4232 if (Visited.size() > 16)
4233 return false;
4234
4235 // Either the value can't be poison, or the S would also be poison if it
4236 // is.
4237 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4238 continue;
4239
4240 auto *I = dyn_cast<Instruction>(V);
4241 if (!I)
4242 return false;
4243
4244 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4245 // can't replace an arbitrary add with disjoint or, even if we drop the
4246 // flag. We would need to convert the or into an add.
4247 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4248 if (PDI->isDisjoint())
4249 return false;
4250
4251 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4252 // because SCEV currently assumes it can't be poison. Remove this special
4253 // case once we properly model when vscale can be poison.
4254 if (auto *II = dyn_cast<IntrinsicInst>(I);
4255 II && II->getIntrinsicID() == Intrinsic::vscale)
4256 continue;
4257
4258 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4259 return false;
4260
4261 // If the instruction can't create poison, we can recurse to its operands.
4262 if (I->hasPoisonGeneratingAnnotations())
4263 DropPoisonGeneratingInsts.push_back(I);
4264
4265 llvm::append_range(Worklist, I->operands());
4266 }
4267 return true;
4268}
4269
4270const SCEV *
4271 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4272 SmallVectorImpl<const SCEV *> &Ops) {
4273 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4274 "Not a SCEVSequentialMinMaxExpr!");
4275 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4276 if (Ops.size() == 1)
4277 return Ops[0];
4278#ifndef NDEBUG
4279 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4280 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4281 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4282 "Operand types don't match!");
4283 assert(Ops[0]->getType()->isPointerTy() ==
4284 Ops[i]->getType()->isPointerTy() &&
4285 "min/max should be consistently pointerish");
4286 }
4287#endif
4288
4289 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4290 // so we can *NOT* do any kind of sorting of the expressions!
4291
4292 // Check if we have created the same expression before.
4293 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4294 return S;
4295
4296 // FIXME: there are *some* simplifications that we can do here.
4297
4298 // Keep only the first instance of an operand.
4299 {
4300 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4301 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4302 if (Changed)
4303 return getSequentialMinMaxExpr(Kind, Ops);
4304 }
4305
4306 // Check to see if one of the operands is of the same kind. If so, expand its
4307 // operands onto our operand list, and recurse to simplify.
4308 {
4309 unsigned Idx = 0;
4310 bool DeletedAny = false;
4311 while (Idx < Ops.size()) {
4312 if (Ops[Idx]->getSCEVType() != Kind) {
4313 ++Idx;
4314 continue;
4315 }
4316 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4317 Ops.erase(Ops.begin() + Idx);
4318 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4319 SMME->operands().end());
4320 DeletedAny = true;
4321 }
4322
4323 if (DeletedAny)
4324 return getSequentialMinMaxExpr(Kind, Ops);
4325 }
4326
4327 const SCEV *SaturationPoint;
4328 ICmpInst::Predicate Pred;
4329 switch (Kind) {
4330 case scSequentialUMinExpr:
4331 SaturationPoint = getZero(Ops[0]->getType());
4332 Pred = ICmpInst::ICMP_ULE;
4333 break;
4334 default:
4335 llvm_unreachable("Not a sequential min/max type.");
4336 }
4337
4338 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4339 if (!isGuaranteedNotToCauseUB(Ops[i]))
4340 continue;
4341 // We can replace %x umin_seq %y with %x umin %y if either:
4342 // * %y being poison implies %x is also poison.
4343 // * %x cannot be the saturating value (e.g. zero for umin).
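// Illustrative example (editorial note, not part of the original source):
// in (%x umin_seq %y), if %x is known to be non-zero (so it can never be
// the saturating value), the pair can be replaced by the plain
// (%x umin %y).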
4344 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4345 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4346 SaturationPoint)) {
4347 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4348 Ops[i - 1] = getMinMaxExpr(
4349 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4350 SeqOps);
4351 Ops.erase(Ops.begin() + i);
4352 return getSequentialMinMaxExpr(Kind, Ops);
4353 }
4354 // Fold %x umin_seq %y to %x if %x ule %y.
4355 // TODO: We might be able to prove the predicate for a later operand.
4356 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4357 Ops.erase(Ops.begin() + i);
4358 return getSequentialMinMaxExpr(Kind, Ops);
4359 }
4360 }
4361
4362 // Okay, it looks like we really DO need an expr. Check to see if we
4363 // already have one, otherwise create a new one.
4364 FoldingSetNodeID ID;
4365 ID.AddInteger(Kind);
4366 for (const SCEV *Op : Ops)
4367 ID.AddPointer(Op);
4368 void *IP = nullptr;
4369 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4370 if (ExistingSCEV)
4371 return ExistingSCEV;
4372
4373 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4374 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4375 SCEV *S = new (SCEVAllocator)
4376 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4377
4378 UniqueSCEVs.InsertNode(S, IP);
4379 registerUser(S, Ops);
4380 return S;
4381}
4382
4383const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4384 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4385 return getSMaxExpr(Ops);
4386}
4387
4391
4392const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4393 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4394 return getUMaxExpr(Ops);
4395}
4396
4400
4402 const SCEV *RHS) {
4403 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4404 return getSMinExpr(Ops);
4405}
4406
4410
4411const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4412 bool Sequential) {
4413 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4414 return getUMinExpr(Ops, Sequential);
4415}
4416
4422
4423const SCEV *
4425 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4426 if (Size.isScalable())
4427 Res = getMulExpr(Res, getVScale(IntTy));
4428 return Res;
4429}
4430
4432 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4433}
4434
4436 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4437}
4438
4440 StructType *STy,
4441 unsigned FieldNo) {
4442 // We can bypass creating a target-independent constant expression and then
4443 // folding it back into a ConstantInt. This is just a compile-time
4444 // optimization.
4445 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4446 assert(!SL->getSizeInBits().isScalable() &&
4447 "Cannot get offset for structure containing scalable vector types");
4448 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4449}
4450
4452 // Don't attempt to do anything other than create a SCEVUnknown object
4453 // here. createSCEV only calls getUnknown after checking for all other
4454 // interesting possibilities, and any other code that calls getUnknown
4455 // is doing so in order to hide a value from SCEV canonicalization.
4456
4458 ID.AddInteger(scUnknown);
4459 ID.AddPointer(V);
4460 void *IP = nullptr;
4461 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4462 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4463 "Stale SCEVUnknown in uniquing map!");
4464 return S;
4465 }
4466 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4467 FirstUnknown);
4468 FirstUnknown = cast<SCEVUnknown>(S);
4469 UniqueSCEVs.InsertNode(S, IP);
4470 return S;
4471}
4472
4473//===----------------------------------------------------------------------===//
4474// Basic SCEV Analysis and PHI Idiom Recognition Code
4475//
4476
4477/// Test if values of the given type are analyzable within the SCEV
4478/// framework. This primarily includes integer types, and it can optionally
4479/// include pointer types if the ScalarEvolution class has access to
4480/// target-specific information.
4482 // Integers and pointers are always SCEVable.
4483 return Ty->isIntOrPtrTy();
4484}
4485
4486/// Return the size in bits of the specified type, for which isSCEVable must
4487/// return true.
4489 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4490 if (Ty->isPointerTy())
4492 return getDataLayout().getTypeSizeInBits(Ty);
4493}
4494
4495/// Return a type with the same bitwidth as the given type and which represents
4496/// how SCEV will treat the given type, for which isSCEVable must return
4497/// true. For pointer types, this is the pointer index sized integer type.
4499 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4500
4501 if (Ty->isIntegerTy())
4502 return Ty;
4503
4504 // The only other supported type is pointer.
4505 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4506 return getDataLayout().getIndexType(Ty);
4507}
4508
4510 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4511}
4512
4514 const SCEV *B) {
4515 /// For a valid use point to exist, the defining scope of one operand
4516 /// must dominate the other.
4517 bool PreciseA, PreciseB;
4518 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4519 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4520 if (!PreciseA || !PreciseB)
4521 // Can't tell.
4522 return false;
4523 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4524 DT.dominates(ScopeB, ScopeA);
4525}
4526
4528 return CouldNotCompute.get();
4529}
4530
4531bool ScalarEvolution::checkValidity(const SCEV *S) const {
4532 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4533 auto *SU = dyn_cast<SCEVUnknown>(S);
4534 return SU && SU->getValue() == nullptr;
4535 });
4536
4537 return !ContainsNulls;
4538}
4539
4541 HasRecMapType::iterator I = HasRecMap.find(S);
4542 if (I != HasRecMap.end())
4543 return I->second;
4544
4545 bool FoundAddRec =
4546 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4547 HasRecMap.insert({S, FoundAddRec});
4548 return FoundAddRec;
4549}
4550
4551/// Return the ValueOffsetPair set for \p S. \p S can be represented
4552/// by the value and offset from any ValueOffsetPair in the set.
4553ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4554 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4555 if (SI == ExprValueMap.end())
4556 return {};
4557 return SI->second.getArrayRef();
4558}
4559
4560/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4561/// cannot be used separately. eraseValueFromMap should be used to remove
4562/// V from ValueExprMap and ExprValueMap at the same time.
4563void ScalarEvolution::eraseValueFromMap(Value *V) {
4564 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4565 if (I != ValueExprMap.end()) {
4566 auto EVIt = ExprValueMap.find(I->second);
4567 bool Removed = EVIt->second.remove(V);
4568 (void) Removed;
4569 assert(Removed && "Value not in ExprValueMap?");
4570 ValueExprMap.erase(I);
4571 }
4572}
4573
4574void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4575 // A recursive query may have already computed the SCEV. It should be
4576 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4577 // inferred nowrap flags.
4578 auto It = ValueExprMap.find_as(V);
4579 if (It == ValueExprMap.end()) {
4580 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4581 ExprValueMap[S].insert(V);
4582 }
4583}
4584
4585/// Return an existing SCEV if it exists, otherwise analyze the expression and
4586/// create a new one.
4588 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4589
4590 if (const SCEV *S = getExistingSCEV(V))
4591 return S;
4592 return createSCEVIter(V);
4593}
4594
4596 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4597
4598 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4599 if (I != ValueExprMap.end()) {
4600 const SCEV *S = I->second;
4601 assert(checkValidity(S) &&
4602 "existing SCEV has not been properly invalidated");
4603 return S;
4604 }
4605 return nullptr;
4606}
4607
4608/// Return a SCEV corresponding to -V = -1*V
4610 SCEV::NoWrapFlags Flags) {
4611 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4612 return getConstant(
4613 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4614
4615 Type *Ty = V->getType();
4616 Ty = getEffectiveSCEVType(Ty);
4617 return getMulExpr(V, getMinusOne(Ty), Flags);
4618}
4619
4620/// If Expr computes ~A, return A else return nullptr
4621static const SCEV *MatchNotExpr(const SCEV *Expr) {
4622 const SCEV *MulOp;
4623 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4624 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4625 return MulOp;
4626 return nullptr;
4627}
4628
4629/// Return a SCEV corresponding to ~V = -1-V
4631 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4632
4633 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4634 return getConstant(
4635 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4636
4637 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4638 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4639 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4640 SmallVector<const SCEV *, 2> MatchedOperands;
4641 for (const SCEV *Operand : MME->operands()) {
4642 const SCEV *Matched = MatchNotExpr(Operand);
4643 if (!Matched)
4644 return (const SCEV *)nullptr;
4645 MatchedOperands.push_back(Matched);
4646 }
4647 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4648 MatchedOperands);
4649 };
4650 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4651 return Replaced;
4652 }
4653
4654 Type *Ty = V->getType();
4655 Ty = getEffectiveSCEVType(Ty);
4656 return getMinusSCEV(getMinusOne(Ty), V);
4657}
4658
4660 assert(P->getType()->isPointerTy());
4661
4662 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4663 // The base of an AddRec is the first operand.
4664 SmallVector<const SCEV *> Ops{AddRec->operands()};
4665 Ops[0] = removePointerBase(Ops[0]);
4666 // Don't try to transfer nowrap flags for now. We could in some cases
4667 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4668 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4669 }
4670 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4671 // The base of an Add is the pointer operand.
4672 SmallVector<const SCEV *> Ops{Add->operands()};
4673 const SCEV **PtrOp = nullptr;
4674 for (const SCEV *&AddOp : Ops) {
4675 if (AddOp->getType()->isPointerTy()) {
4676 assert(!PtrOp && "Cannot have multiple pointer ops");
4677 PtrOp = &AddOp;
4678 }
4679 }
4680 *PtrOp = removePointerBase(*PtrOp);
4681 // Don't try to transfer nowrap flags for now. We could in some cases
4682 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4683 return getAddExpr(Ops);
4684 }
4685 // Any other expression must be a pointer base.
4686 return getZero(P->getType());
4687}
4688
4689const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4690 SCEV::NoWrapFlags Flags,
4691 unsigned Depth) {
4692 // Fast path: X - X --> 0.
4693 if (LHS == RHS)
4694 return getZero(LHS->getType());
4695
4696 // If we subtract two pointers with different pointer bases, bail.
4697 // Eventually, we're going to add an assertion to getMulExpr that we
4698 // can't multiply by a pointer.
4699 if (RHS->getType()->isPointerTy()) {
4700 if (!LHS->getType()->isPointerTy() ||
4701 getPointerBase(LHS) != getPointerBase(RHS))
4702 return getCouldNotCompute();
4703 LHS = removePointerBase(LHS);
4704 RHS = removePointerBase(RHS);
4705 }
4706
4707 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4708 // makes it so that we cannot make much use of NUW.
4709 auto AddFlags = SCEV::FlagAnyWrap;
4710 const bool RHSIsNotMinSigned =
4712 if (hasFlags(Flags, SCEV::FlagNSW)) {
4713 // Let M be the minimum representable signed value. Then (-1)*RHS
4714 // signed-wraps if and only if RHS is M. That can happen even for
4715 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4716 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4717 // (-1)*RHS, we need to prove that RHS != M.
4718 //
4719 // If LHS is non-negative and we know that LHS - RHS does not
4720 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4721 // either by proving that RHS > M or that LHS >= 0.
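// Illustrative example (editorial note, not part of the original source,
// using i8): with RHS = -128 (the minimum signed value M), (-1) * RHS
// wraps back to -128 even if LHS - RHS is NSW; knowing RHS != -128, or
// LHS >= 0, rules this case out.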
4722 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4723 AddFlags = SCEV::FlagNSW;
4724 }
4725 }
4726
4727 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4728 // RHS is NSW and LHS >= 0.
4729 //
4730 // The difficulty here is that the NSW flag may have been proven
4731 // relative to a loop that is to be found in a recurrence in LHS and
4732 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4733 // larger scope than intended.
4734 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4735
4736 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4737}
4738
4740 unsigned Depth) {
4741 Type *SrcTy = V->getType();
4742 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4743 "Cannot truncate or zero extend with non-integer arguments!");
4744 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4745 return V; // No conversion
4746 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4747 return getTruncateExpr(V, Ty, Depth);
4748 return getZeroExtendExpr(V, Ty, Depth);
4749}
4750
4752 unsigned Depth) {
4753 Type *SrcTy = V->getType();
4754 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4755 "Cannot truncate or zero extend with non-integer arguments!");
4756 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4757 return V; // No conversion
4758 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4759 return getTruncateExpr(V, Ty, Depth);
4760 return getSignExtendExpr(V, Ty, Depth);
4761}
4762
4763const SCEV *
4765 Type *SrcTy = V->getType();
4766 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4767 "Cannot noop or zero extend with non-integer arguments!");
4769 "getNoopOrZeroExtend cannot truncate!");
4770 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4771 return V; // No conversion
4772 return getZeroExtendExpr(V, Ty);
4773}
4774
4775const SCEV *
4777 Type *SrcTy = V->getType();
4778 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4779 "Cannot noop or sign extend with non-integer arguments!");
4781 "getNoopOrSignExtend cannot truncate!");
4782 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4783 return V; // No conversion
4784 return getSignExtendExpr(V, Ty);
4785}
4786
4787const SCEV *
4789 Type *SrcTy = V->getType();
4790 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4791 "Cannot noop or any extend with non-integer arguments!");
4793 "getNoopOrAnyExtend cannot truncate!");
4794 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4795 return V; // No conversion
4796 return getAnyExtendExpr(V, Ty);
4797}
4798
4799const SCEV *
4801 Type *SrcTy = V->getType();
4802 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4803 "Cannot truncate or noop with non-integer arguments!");
4805 "getTruncateOrNoop cannot extend!");
4806 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4807 return V; // No conversion
4808 return getTruncateExpr(V, Ty);
4809}
4810
4812 const SCEV *RHS) {
4813 const SCEV *PromotedLHS = LHS;
4814 const SCEV *PromotedRHS = RHS;
4815
4816 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4817 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4818 else
4819 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4820
4821 return getUMaxExpr(PromotedLHS, PromotedRHS);
4822}
4823
4825 const SCEV *RHS,
4826 bool Sequential) {
4827 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4828 return getUMinFromMismatchedTypes(Ops, Sequential);
4829}
4830
4831const SCEV *
4832ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4833 bool Sequential) {
4834 assert(!Ops.empty() && "At least one operand must be!");
4835 // Trivial case.
4836 if (Ops.size() == 1)
4837 return Ops[0];
4838
4839 // Find the max type first.
4840 Type *MaxType = nullptr;
4841 for (const auto *S : Ops)
4842 if (MaxType)
4843 MaxType = getWiderType(MaxType, S->getType());
4844 else
4845 MaxType = S->getType();
4846 assert(MaxType && "Failed to find maximum type!");
4847
4848 // Extend all ops to max type.
4849 SmallVector<const SCEV *, 2> PromotedOps;
4850 for (const auto *S : Ops)
4851 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4852
4853 // Generate umin.
4854 return getUMinExpr(PromotedOps, Sequential);
4855}
4856
4857const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4858 // A pointer operand may evaluate to a nonpointer expression, such as null.
4859 if (!V->getType()->isPointerTy())
4860 return V;
4861
4862 while (true) {
4863 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4864 V = AddRec->getStart();
4865 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4866 const SCEV *PtrOp = nullptr;
4867 for (const SCEV *AddOp : Add->operands()) {
4868 if (AddOp->getType()->isPointerTy()) {
4869 assert(!PtrOp && "Cannot have multiple pointer ops");
4870 PtrOp = AddOp;
4871 }
4872 }
4873 assert(PtrOp && "Must have pointer op");
4874 V = PtrOp;
4875 } else // Not something we can look further into.
4876 return V;
4877 }
4878}
4879
4880/// Push users of the given Instruction onto the given Worklist.
4881static void
4882PushDefUseChildren(Instruction *I, SmallVectorImpl<Instruction *> &Worklist,
4883 SmallPtrSetImpl<Instruction *> &Visited) {
4884 // Push the def-use children onto the Worklist stack.
4885 for (User *U : I->users()) {
4886 auto *UserInsn = cast<Instruction>(U);
4887 if (Visited.insert(UserInsn).second)
4888 Worklist.push_back(UserInsn);
4889 }
4890}
4891
4892namespace {
4893
4894/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4895/// expression in case its Loop is L. If it is not L then
4896/// if IgnoreOtherLoops is true then use the AddRec itself,
4897/// otherwise the rewrite cannot be done.
4898/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
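/// For illustration (hypothetical names): rewriting the SCEV
/// {%start,+,%step}<L> + %n, with %n invariant in L, yields %start + %n,
/// i.e. the value the expression has on entry to L.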
4899class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4900public:
4901 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4902 bool IgnoreOtherLoops = true) {
4903 SCEVInitRewriter Rewriter(L, SE);
4904 const SCEV *Result = Rewriter.visit(S);
4905 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4906 return SE.getCouldNotCompute();
4907 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4908 ? SE.getCouldNotCompute()
4909 : Result;
4910 }
4911
4912 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4913 if (!SE.isLoopInvariant(Expr, L))
4914 SeenLoopVariantSCEVUnknown = true;
4915 return Expr;
4916 }
4917
4918 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4919 // Only re-write AddRecExprs for this loop.
4920 if (Expr->getLoop() == L)
4921 return Expr->getStart();
4922 SeenOtherLoops = true;
4923 return Expr;
4924 }
4925
4926 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4927
4928 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4929
4930private:
4931 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4932 : SCEVRewriteVisitor(SE), L(L) {}
4933
4934 const Loop *L;
4935 bool SeenLoopVariantSCEVUnknown = false;
4936 bool SeenOtherLoops = false;
4937};
4938
4939/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4940/// increment expression in case its Loop is L. If it is not L then
4941/// use AddRec itself.
4942/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
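/// For illustration (hypothetical names): rewriting {%start,+,%step}<L>
/// yields {%start + %step,+,%step}<L>, i.e. the value of the recurrence after
/// the current iteration's increment.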
4943class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4944public:
4945 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4946 SCEVPostIncRewriter Rewriter(L, SE);
4947 const SCEV *Result = Rewriter.visit(S);
4948 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4949 ? SE.getCouldNotCompute()
4950 : Result;
4951 }
4952
4953 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4954 if (!SE.isLoopInvariant(Expr, L))
4955 SeenLoopVariantSCEVUnknown = true;
4956 return Expr;
4957 }
4958
4959 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4960 // Only re-write AddRecExprs for this loop.
4961 if (Expr->getLoop() == L)
4962 return Expr->getPostIncExpr(SE);
4963 SeenOtherLoops = true;
4964 return Expr;
4965 }
4966
4967 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4968
4969 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4970
4971private:
4972 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4973 : SCEVRewriteVisitor(SE), L(L) {}
4974
4975 const Loop *L;
4976 bool SeenLoopVariantSCEVUnknown = false;
4977 bool SeenOtherLoops = false;
4978};
4979
4980/// This class evaluates the compare condition by matching it against the
4981/// condition of loop latch. If there is a match we assume a true value
4982/// for the condition while building SCEV nodes.
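/// For example (illustrative): if the latch ends in "br i1 %c, label %header,
/// label %exit", then a select whose condition is %c is folded to its true
/// operand while rewriting, since %c must be true whenever the backedge is
/// taken.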
4983class SCEVBackedgeConditionFolder
4984 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4985public:
4986 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4987 ScalarEvolution &SE) {
4988 bool IsPosBECond = false;
4989 Value *BECond = nullptr;
4990 if (BasicBlock *Latch = L->getLoopLatch()) {
4991 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4992 if (BI && BI->isConditional()) {
4993 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4994 "Both outgoing branches should not target same header!");
4995 BECond = BI->getCondition();
4996 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4997 } else {
4998 return S;
4999 }
5000 }
5001 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5002 return Rewriter.visit(S);
5003 }
5004
5005 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5006 const SCEV *Result = Expr;
5007 bool InvariantF = SE.isLoopInvariant(Expr, L);
5008
5009 if (!InvariantF) {
5010 Instruction *I = cast<Instruction>(Expr->getValue());
5011 switch (I->getOpcode()) {
5012 case Instruction::Select: {
5013 SelectInst *SI = cast<SelectInst>(I);
5014 std::optional<const SCEV *> Res =
5015 compareWithBackedgeCondition(SI->getCondition());
5016 if (Res) {
5017 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5018 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5019 }
5020 break;
5021 }
5022 default: {
5023 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5024 if (Res)
5025 Result = *Res;
5026 break;
5027 }
5028 }
5029 }
5030 return Result;
5031 }
5032
5033private:
5034 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5035 bool IsPosBECond, ScalarEvolution &SE)
5036 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5037 IsPositiveBECond(IsPosBECond) {}
5038
5039 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5040
5041 const Loop *L;
5042 /// Loop back condition.
5043 Value *BackedgeCond = nullptr;
5044 /// Set to true if loop back is on positive branch condition.
5045 bool IsPositiveBECond;
5046};
5047
5048std::optional<const SCEV *>
5049SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5050
5051 // If the value matches the backedge condition of the loop latch,
5052 // return a constant evolution node based on whether the loopback
5053 // branch is taken.
5054 if (BackedgeCond == IC)
5055 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5056 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5057 return std::nullopt;
5058}
5059
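/// Rewrite SCEV expressions for loop L by replacing each affine AddRec
/// {a,+,b}<L> with {a-b,+,b}<L>, i.e. the same evolution shifted back by one
/// iteration. Any other AddRec, or a SCEVUnknown that is not invariant in L,
/// makes the rewrite invalid.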
5060class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5061public:
5062 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5063 ScalarEvolution &SE) {
5064 SCEVShiftRewriter Rewriter(L, SE);
5065 const SCEV *Result = Rewriter.visit(S);
5066 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5067 }
5068
5069 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5070 // Only allow AddRecExprs for this loop.
5071 if (!SE.isLoopInvariant(Expr, L))
5072 Valid = false;
5073 return Expr;
5074 }
5075
5076 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5077 if (Expr->getLoop() == L && Expr->isAffine())
5078 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5079 Valid = false;
5080 return Expr;
5081 }
5082
5083 bool isValid() { return Valid; }
5084
5085private:
5086 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5087 : SCEVRewriteVisitor(SE), L(L) {}
5088
5089 const Loop *L;
5090 bool Valid = true;
5091};
5092
5093} // end anonymous namespace
5094
5095SCEV::NoWrapFlags
5096ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5097 if (!AR->isAffine())
5098 return SCEV::FlagAnyWrap;
5099
5100 using OBO = OverflowingBinaryOperator;
5101
5102 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5103
5104 if (!AR->hasNoSelfWrap()) {
5105 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5106 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5107 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5108 const APInt &BECountAP = BECountMax->getAPInt();
5109 unsigned NoOverflowBitWidth =
5110 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5111 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5112 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5113 }
5114 }
5115
5116 if (!AR->hasNoSignedWrap()) {
5117 ConstantRange AddRecRange = getSignedRange(AR);
5118 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5119
5120 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5121 Instruction::Add, IncRange, OBO::NoSignedWrap);
5122 if (NSWRegion.contains(AddRecRange))
5123 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5124 }
5125
5126 if (!AR->hasNoUnsignedWrap()) {
5127 ConstantRange AddRecRange = getUnsignedRange(AR);
5128 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5129
5130 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5131 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5132 if (NUWRegion.contains(AddRecRange))
5133 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5134 }
5135
5136 return Result;
5137}
5138
5139SCEV::NoWrapFlags
5140ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5141 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5142
5143 if (AR->hasNoSignedWrap())
5144 return Result;
5145
5146 if (!AR->isAffine())
5147 return Result;
5148
5149 // This function can be expensive, only try to prove NSW once per AddRec.
5150 if (!SignedWrapViaInductionTried.insert(AR).second)
5151 return Result;
5152
5153 const SCEV *Step = AR->getStepRecurrence(*this);
5154 const Loop *L = AR->getLoop();
5155
5156 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5157 // Note that this serves two purposes: It filters out loops that are
5158 // simply not analyzable, and it covers the case where this code is
5159 // being called from within backedge-taken count analysis, such that
5160 // attempting to ask for the backedge-taken count would likely result
5161 // in infinite recursion. In the latter case, the analysis code will
5162 // cope with a conservative value, and it will take care to purge
5163 // that value once it has finished.
5164 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5165
5166 // Normally, in the cases we can prove no-overflow via a
5167 // backedge guarding condition, we can also compute a backedge
5168 // taken count for the loop. The exceptions are assumptions and
5169 // guards present in the loop -- SCEV is not great at exploiting
5170 // these to compute max backedge taken counts, but can still use
5171 // these to prove lack of overflow. Use this fact to avoid
5172 // doing extra work that may not pay off.
5173
5174 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5175 AC.assumptions().empty())
5176 return Result;
5177
5178 // If the backedge is guarded by a comparison with the pre-inc value the
5179 // addrec is safe. Also, if the entry is guarded by a comparison with the
5180 // start value and the backedge is guarded by a comparison with the post-inc
5181 // value, the addrec is safe.
5182 ICmpInst::Predicate Pred;
5183 const SCEV *OverflowLimit =
5184 getSignedOverflowLimitForStep(Step, &Pred, this);
5185 if (OverflowLimit &&
5186 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5187 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5188 Result = setFlags(Result, SCEV::FlagNSW);
5189 }
5190 return Result;
5191}
5192SCEV::NoWrapFlags
5193ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5194 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5195
5196 if (AR->hasNoUnsignedWrap())
5197 return Result;
5198
5199 if (!AR->isAffine())
5200 return Result;
5201
5202 // This function can be expensive, only try to prove NUW once per AddRec.
5203 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5204 return Result;
5205
5206 const SCEV *Step = AR->getStepRecurrence(*this);
5207 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5208 const Loop *L = AR->getLoop();
5209
5210 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5211 // Note that this serves two purposes: It filters out loops that are
5212 // simply not analyzable, and it covers the case where this code is
5213 // being called from within backedge-taken count analysis, such that
5214 // attempting to ask for the backedge-taken count would likely result
5215 // in infinite recursion. In the latter case, the analysis code will
5216 // cope with a conservative value, and it will take care to purge
5217 // that value once it has finished.
5218 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5219
5220 // Normally, in the cases we can prove no-overflow via a
5221 // backedge guarding condition, we can also compute a backedge
5222 // taken count for the loop. The exceptions are assumptions and
5223 // guards present in the loop -- SCEV is not great at exploiting
5224 // these to compute max backedge taken counts, but can still use
5225 // these to prove lack of overflow. Use this fact to avoid
5226 // doing extra work that may not pay off.
5227
5228 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5229 AC.assumptions().empty())
5230 return Result;
5231
5232 // If the backedge is guarded by a comparison with the pre-inc value the
5233 // addrec is safe. Also, if the entry is guarded by a comparison with the
5234 // start value and the backedge is guarded by a comparison with the post-inc
5235 // value, the addrec is safe.
5236 if (isKnownPositive(Step)) {
5237 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5238 getUnsignedRangeMax(Step));
5239 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5240 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5241 Result = setFlags(Result, SCEV::FlagNUW);
5242 }
5243 }
5244
5245 return Result;
5246}
5247
5248namespace {
5249
5250/// Represents an abstract binary operation. This may exist as a
5251/// normal instruction or constant expression, or may have been
5252/// derived from an expression tree.
5253struct BinaryOp {
5254 unsigned Opcode;
5255 Value *LHS;
5256 Value *RHS;
5257 bool IsNSW = false;
5258 bool IsNUW = false;
5259
5260 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5261 /// constant expression.
5262 Operator *Op = nullptr;
5263
5264 explicit BinaryOp(Operator *Op)
5265 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5266 Op(Op) {
5267 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5268 IsNSW = OBO->hasNoSignedWrap();
5269 IsNUW = OBO->hasNoUnsignedWrap();
5270 }
5271 }
5272
5273 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5274 bool IsNUW = false)
5275 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5276};
5277
5278} // end anonymous namespace
5279
5280/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5281static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5282 AssumptionCache &AC,
5283 const DominatorTree &DT,
5284 const Instruction *CxtI) {
5285 auto *Op = dyn_cast<Operator>(V);
5286 if (!Op)
5287 return std::nullopt;
5288
5289 // Implementation detail: all the cleverness here should happen without
5290// creating new SCEV expressions -- our caller knows tricks to avoid creating
5291 // SCEV expressions when possible, and we should not break that.
5292
5293 switch (Op->getOpcode()) {
5294 case Instruction::Add:
5295 case Instruction::Sub:
5296 case Instruction::Mul:
5297 case Instruction::UDiv:
5298 case Instruction::URem:
5299 case Instruction::And:
5300 case Instruction::AShr:
5301 case Instruction::Shl:
5302 return BinaryOp(Op);
5303
5304 case Instruction::Or: {
5305 // Convert or disjoint into add nuw nsw.
5306 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5307 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5308 /*IsNSW=*/true, /*IsNUW=*/true);
5309 return BinaryOp(Op);
5310 }
5311
5312 case Instruction::Xor:
5313 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5314 // If the RHS of the xor is a signmask, then this is just an add.
5315 // Instcombine turns add of signmask into xor as a strength reduction step.
5316 if (RHSC->getValue().isSignMask())
5317 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5318 // For i1 operands, binary `xor` is the same as `add` (addition modulo 2).
5319 if (V->getType()->isIntegerTy(1))
5320 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5321 return BinaryOp(Op);
5322
5323 case Instruction::LShr:
5324 // Turn logical shift right of a constant into an unsigned divide.
5325 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5326 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5327
5328 // If the shift count is not less than the bitwidth, the result of
5329 // the shift is undefined. Don't try to analyze it, because the
5330 // resolution chosen here may differ from the resolution chosen in
5331 // other parts of the compiler.
5332 if (SA->getValue().ult(BitWidth)) {
5333 Constant *X =
5334 ConstantInt::get(SA->getContext(),
5335 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5336 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5337 }
5338 }
5339 return BinaryOp(Op);
5340
5341 case Instruction::ExtractValue: {
5342 auto *EVI = cast<ExtractValueInst>(Op);
5343 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5344 break;
5345
5346 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5347 if (!WO)
5348 break;
5349
5350 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5351 bool Signed = WO->isSigned();
5352 // TODO: Should add nuw/nsw flags for mul as well.
5353 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5354 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5355
5356 // Now that we know that all uses of the arithmetic-result component of
5357 // CI are guarded by the overflow check, we can go ahead and pretend
5358 // that the arithmetic is non-overflowing.
5359 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5360 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5361 }
5362
5363 default:
5364 break;
5365 }
5366
5367 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5368 // semantics as a Sub, return a binary sub expression.
5369 if (auto *II = dyn_cast<IntrinsicInst>(V))
5370 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5371 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5372
5373 return std::nullopt;
5374}
5375
5376/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5377/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5378/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5379/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5380/// follows one of the following patterns:
5381/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5382/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5383/// If the SCEV expression of \p Op conforms with one of the expected patterns
5384/// we return the type of the truncation operation, and indicate whether the
5385/// truncated type should be treated as signed/unsigned by setting
5386/// \p Signed to true/false, respectively.
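/// For example (illustrative): if %SymbolicPHI is an i64 phi and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64), we return the
/// type i32 and set \p Signed to true.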
5387static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5388 bool &Signed, ScalarEvolution &SE) {
5389 // The case where Op == SymbolicPHI (that is, with no type conversions on
5390 // the way) is handled by the regular add recurrence creating logic and
5391 // would have already been triggered in createAddRecForPHI. Reaching it here
5392 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5393 // because one of the other operands of the SCEVAddExpr updating this PHI is
5394 // not invariant).
5395 //
5396 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5397 // this case predicates that allow us to prove that Op == SymbolicPHI will
5398 // be added.
5399 if (Op == SymbolicPHI)
5400 return nullptr;
5401
5402 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5403 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5404 if (SourceBits != NewBits)
5405 return nullptr;
5406
5407 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5408 Signed = true;
5409 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5410 }
5411 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5412 Signed = false;
5413 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5414 }
5415 return nullptr;
5416}
5417
5418static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5419 if (!PN->getType()->isIntegerTy())
5420 return nullptr;
5421 const Loop *L = LI.getLoopFor(PN->getParent());
5422 if (!L || L->getHeader() != PN->getParent())
5423 return nullptr;
5424 return L;
5425}
5426
5427// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5428// computation that updates the phi follows the following pattern:
5429// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5430// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5431// If so, try to see if it can be rewritten as an AddRecExpr under some
5432// Predicates. If successful, return them as a pair. Also cache the results
5433// of the analysis.
5434//
5435// Example usage scenario:
5436// Say the Rewriter is called for the following SCEV:
5437// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5438// where:
5439// %X = phi i64 (%Start, %BEValue)
5440// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5441// and call this function with %SymbolicPHI = %X.
5442//
5443// The analysis will find that the value coming around the backedge has
5444// the following SCEV:
5445// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5446// Upon concluding that this matches the desired pattern, the function
5447// will return the pair {NewAddRec, SmallPredsVec} where:
5448// NewAddRec = {%Start,+,%Step}
5449// SmallPredsVec = {P1, P2, P3} as follows:
5450// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5451// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5452// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5453// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5454// under the predicates {P1,P2,P3}.
5455// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5456// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5457//
5458// TODO's:
5459//
5460// 1) Extend the Induction descriptor to also support inductions that involve
5461// casts: When needed (namely, when we are called in the context of the
5462// vectorizer induction analysis), a Set of cast instructions will be
5463// populated by this method, and provided back to isInductionPHI. This is
5464// needed to allow the vectorizer to properly record them to be ignored by
5465// the cost model and to avoid vectorizing them (otherwise these casts,
5466// which are redundant under the runtime overflow checks, will be
5467// vectorized, which can be costly).
5468//
5469// 2) Support additional induction/PHISCEV patterns: We also want to support
5470// inductions where the sext-trunc / zext-trunc operations (partly) occur
5471// after the induction update operation (the induction increment):
5472//
5473// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5474// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5475//
5476// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5477// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5478//
5479// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5480std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5481ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5482 SmallVector<const SCEVPredicate *, 3> Predicates;
5483
5484 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5485 // return an AddRec expression under some predicate.
5486
5487 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5488 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5489 assert(L && "Expecting an integer loop header phi");
5490
5491 // The loop may have multiple entrances or multiple exits; we can analyze
5492 // this phi as an addrec if it has a unique entry value and a unique
5493 // backedge value.
5494 Value *BEValueV = nullptr, *StartValueV = nullptr;
5495 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5496 Value *V = PN->getIncomingValue(i);
5497 if (L->contains(PN->getIncomingBlock(i))) {
5498 if (!BEValueV) {
5499 BEValueV = V;
5500 } else if (BEValueV != V) {
5501 BEValueV = nullptr;
5502 break;
5503 }
5504 } else if (!StartValueV) {
5505 StartValueV = V;
5506 } else if (StartValueV != V) {
5507 StartValueV = nullptr;
5508 break;
5509 }
5510 }
5511 if (!BEValueV || !StartValueV)
5512 return std::nullopt;
5513
5514 const SCEV *BEValue = getSCEV(BEValueV);
5515
5516 // If the value coming around the backedge is an add with the symbolic
5517 // value we just inserted, possibly with casts that we can ignore under
5518 // an appropriate runtime guard, then we found a simple induction variable!
5519 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5520 if (!Add)
5521 return std::nullopt;
5522
5523 // If there is a single occurrence of the symbolic value, possibly
5524 // casted, replace it with a recurrence.
5525 unsigned FoundIndex = Add->getNumOperands();
5526 Type *TruncTy = nullptr;
5527 bool Signed;
5528 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5529 if ((TruncTy =
5530 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5531 if (FoundIndex == e) {
5532 FoundIndex = i;
5533 break;
5534 }
5535
5536 if (FoundIndex == Add->getNumOperands())
5537 return std::nullopt;
5538
5539 // Create an add with everything but the specified operand.
5540 SmallVector<const SCEV *, 8> Ops;
5541 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5542 if (i != FoundIndex)
5543 Ops.push_back(Add->getOperand(i));
5544 const SCEV *Accum = getAddExpr(Ops);
5545
5546 // The runtime checks will not be valid if the step amount is
5547 // varying inside the loop.
5548 if (!isLoopInvariant(Accum, L))
5549 return std::nullopt;
5550
5551 // *** Part2: Create the predicates
5552
5553 // Analysis was successful: we have a phi-with-cast pattern for which we
5554 // can return an AddRec expression under the following predicates:
5555 //
5556 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5557 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5558 // P2: An Equal predicate that guarantees that
5559 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5560 // P3: An Equal predicate that guarantees that
5561 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5562 //
5563 // As we next prove, the above predicates guarantee that:
5564 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5565 //
5566 //
5567 // More formally, we want to prove that:
5568 // Expr(i+1) = Start + (i+1) * Accum
5569 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5570 //
5571 // Given that:
5572 // 1) Expr(0) = Start
5573 // 2) Expr(1) = Start + Accum
5574 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5575 // 3) Induction hypothesis (step i):
5576 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5577 //
5578 // Proof:
5579 // Expr(i+1) =
5580 // = Start + (i+1)*Accum
5581 // = (Start + i*Accum) + Accum
5582 // = Expr(i) + Accum
5583 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5584 // :: from step i
5585 //
5586 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5587 //
5588 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5589 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5590 // + Accum :: from P3
5591 //
5592 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5593 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5594 //
5595 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5596 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5597 //
5598 // By induction, the same applies to all iterations 1<=i<n:
5599 //
5600
5601 // Create a truncated addrec for which we will add a no overflow check (P1).
5602 const SCEV *StartVal = getSCEV(StartValueV);
5603 const SCEV *PHISCEV =
5604 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5605 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5606
5607 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5608 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5609 // will be constant.
5610 //
5611 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5612 // add P1.
5613 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5614 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5615 Signed ? SCEVWrapPredicate::IncrementNSSW
5616 : SCEVWrapPredicate::IncrementNUSW;
5617 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5618 Predicates.push_back(AddRecPred);
5619 }
5620
5621 // Create the Equal Predicates P2,P3:
5622
5623 // It is possible that the predicates P2 and/or P3 are computable at
5624 // compile time due to StartVal and/or Accum being constants.
5625 // If either one is, then we can check that now and escape if either P2
5626 // or P3 is false.
5627
5628 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5629 // for each of StartVal and Accum
5630 auto getExtendedExpr = [&](const SCEV *Expr,
5631 bool CreateSignExtend) -> const SCEV * {
5632 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5633 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5634 const SCEV *ExtendedExpr =
5635 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5636 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5637 return ExtendedExpr;
5638 };
5639
5640 // Given:
5641 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5642 // = getExtendedExpr(Expr)
5643 // Determine whether the predicate P: Expr == ExtendedExpr
5644 // is known to be false at compile time
5645 auto PredIsKnownFalse = [&](const SCEV *Expr,
5646 const SCEV *ExtendedExpr) -> bool {
5647 return Expr != ExtendedExpr &&
5648 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5649 };
5650
5651 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5652 if (PredIsKnownFalse(StartVal, StartExtended)) {
5653 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5654 return std::nullopt;
5655 }
5656
5657 // The Step is always Signed (because the overflow checks are either
5658 // NSSW or NUSW)
5659 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5660 if (PredIsKnownFalse(Accum, AccumExtended)) {
5661 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5662 return std::nullopt;
5663 }
5664
5665 auto AppendPredicate = [&](const SCEV *Expr,
5666 const SCEV *ExtendedExpr) -> void {
5667 if (Expr != ExtendedExpr &&
5668 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5669 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5670 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5671 Predicates.push_back(Pred);
5672 }
5673 };
5674
5675 AppendPredicate(StartVal, StartExtended);
5676 AppendPredicate(Accum, AccumExtended);
5677
5678 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5679 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5680 // into NewAR if it will also add the runtime overflow checks specified in
5681 // Predicates.
5682 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5683
5684 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5685 std::make_pair(NewAR, Predicates);
5686 // Remember the result of the analysis for this SCEV at this location.
5687 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5688 return PredRewrite;
5689}
5690
5691std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5692ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5693 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5694 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5695 if (!L)
5696 return std::nullopt;
5697
5698 // Check to see if we already analyzed this PHI.
5699 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5700 if (I != PredicatedSCEVRewrites.end()) {
5701 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5702 I->second;
5703 // Analysis was done before and failed to create an AddRec:
5704 if (Rewrite.first == SymbolicPHI)
5705 return std::nullopt;
5706 // Analysis was done before and succeeded to create an AddRec under
5707 // a predicate:
5708 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5709 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5710 return Rewrite;
5711 }
5712
5713 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5714 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5715
5716 // Record in the cache that the analysis failed
5717 if (!Rewrite) {
5718 SmallVector<const SCEVPredicate *, 3> Predicates;
5719 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5720 return std::nullopt;
5721 }
5722
5723 return Rewrite;
5724}
5725
5726// FIXME: This utility is currently required because the Rewriter currently
5727// does not rewrite this expression:
5728// {0, +, (sext ix (trunc iy to ix) to iy)}
5729// into {0, +, %step},
5730// even when the following Equal predicate exists:
5731// "%step == (sext ix (trunc iy to ix) to iy)".
5732bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5733 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5734 if (AR1 == AR2)
5735 return true;
5736
5737 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5738 if (Expr1 != Expr2 &&
5739 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5740 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5741 return false;
5742 return true;
5743 };
5744
5745 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5746 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5747 return false;
5748 return true;
5749}
5750
5751/// A helper function for createAddRecFromPHI to handle simple cases.
5752///
5753/// This function tries to find an AddRec expression for the simplest (yet most
5754/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5755/// If it fails, createAddRecFromPHI will use a more general, but slow,
5756/// technique for finding the AddRec expression.
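/// A typical instance (hypothetical IR) is:
///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, 1
/// for which this returns the AddRec {0,+,1}<nuw><nsw> for %loop.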
5757const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5758 Value *BEValueV,
5759 Value *StartValueV) {
5760 const Loop *L = LI.getLoopFor(PN->getParent());
5761 assert(L && L->getHeader() == PN->getParent());
5762 assert(BEValueV && StartValueV);
5763
5764 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5765 if (!BO)
5766 return nullptr;
5767
5768 if (BO->Opcode != Instruction::Add)
5769 return nullptr;
5770
5771 const SCEV *Accum = nullptr;
5772 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5773 Accum = getSCEV(BO->RHS);
5774 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5775 Accum = getSCEV(BO->LHS);
5776
5777 if (!Accum)
5778 return nullptr;
5779
5780 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5781 if (BO->IsNUW)
5782 Flags = setFlags(Flags, SCEV::FlagNUW);
5783 if (BO->IsNSW)
5784 Flags = setFlags(Flags, SCEV::FlagNSW);
5785
5786 const SCEV *StartVal = getSCEV(StartValueV);
5787 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5788 insertValueToMap(PN, PHISCEV);
5789
5790 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5791 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5792 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5793 proveNoWrapViaConstantRanges(AR)));
5794 }
5795
5796 // We can add Flags to the post-inc expression only if we
5797 // know that it is *undefined behavior* for BEValueV to
5798 // overflow.
5799 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5800 assert(isLoopInvariant(Accum, L) &&
5801 "Accum is defined outside L, but is not invariant?");
5802 if (isAddRecNeverPoison(BEInst, L))
5803 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5804 }
5805
5806 return PHISCEV;
5807}
5808
5809const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5810 const Loop *L = LI.getLoopFor(PN->getParent());
5811 if (!L || L->getHeader() != PN->getParent())
5812 return nullptr;
5813
5814 // The loop may have multiple entrances or multiple exits; we can analyze
5815 // this phi as an addrec if it has a unique entry value and a unique
5816 // backedge value.
5817 Value *BEValueV = nullptr, *StartValueV = nullptr;
5818 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5819 Value *V = PN->getIncomingValue(i);
5820 if (L->contains(PN->getIncomingBlock(i))) {
5821 if (!BEValueV) {
5822 BEValueV = V;
5823 } else if (BEValueV != V) {
5824 BEValueV = nullptr;
5825 break;
5826 }
5827 } else if (!StartValueV) {
5828 StartValueV = V;
5829 } else if (StartValueV != V) {
5830 StartValueV = nullptr;
5831 break;
5832 }
5833 }
5834 if (!BEValueV || !StartValueV)
5835 return nullptr;
5836
5837 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5838 "PHI node already processed?");
5839
5840 // First, try to find an AddRec expression without creating a fictitious symbolic
5841 // value for PN.
5842 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5843 return S;
5844
5845 // Handle PHI node value symbolically.
5846 const SCEV *SymbolicName = getUnknown(PN);
5847 insertValueToMap(PN, SymbolicName);
5848
5849 // Using this symbolic name for the PHI, analyze the value coming around
5850 // the back-edge.
5851 const SCEV *BEValue = getSCEV(BEValueV);
5852
5853 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5854 // has a special value for the first iteration of the loop.
5855
5856 // If the value coming around the backedge is an add with the symbolic
5857 // value we just inserted, then we found a simple induction variable!
5858 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5859 // If there is a single occurrence of the symbolic value, replace it
5860 // with a recurrence.
5861 unsigned FoundIndex = Add->getNumOperands();
5862 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5863 if (Add->getOperand(i) == SymbolicName)
5864 if (FoundIndex == e) {
5865 FoundIndex = i;
5866 break;
5867 }
5868
5869 if (FoundIndex != Add->getNumOperands()) {
5870 // Create an add with everything but the specified operand.
5871 SmallVector<const SCEV *, 8> Ops;
5872 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5873 if (i != FoundIndex)
5874 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5875 L, *this));
5876 const SCEV *Accum = getAddExpr(Ops);
5877
5878 // This is not a valid addrec if the step amount is varying each
5879 // loop iteration, but is not itself an addrec in this loop.
5880 if (isLoopInvariant(Accum, L) ||
5881 (isa<SCEVAddRecExpr>(Accum) &&
5882 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5883 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5884
5885 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5886 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5887 if (BO->IsNUW)
5888 Flags = setFlags(Flags, SCEV::FlagNUW);
5889 if (BO->IsNSW)
5890 Flags = setFlags(Flags, SCEV::FlagNSW);
5891 }
5892 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5893 if (GEP->getOperand(0) == PN) {
5894 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5895 // If the increment has any nowrap flags, then we know the address
5896 // space cannot be wrapped around.
5897 if (NW != GEPNoWrapFlags::none())
5898 Flags = setFlags(Flags, SCEV::FlagNW);
5899 // If the GEP is nuw or nusw with non-negative offset, we know that
5900 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5901 // offset is treated as signed, while the base is unsigned.
5902 if (NW.hasNoUnsignedWrap() ||
5903 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5904 Flags = setFlags(Flags, SCEV::FlagNUW);
5905 }
5906
5907 // We cannot transfer nuw and nsw flags from subtraction
5908 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5909 // for instance.
5910 }
5911
5912 const SCEV *StartVal = getSCEV(StartValueV);
5913 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5914
5915 // Okay, for the entire analysis of this edge we assumed the PHI
5916 // to be symbolic. We now need to go back and purge all of the
5917 // entries for the scalars that use the symbolic expression.
5918 forgetMemoizedResults(SymbolicName);
5919 insertValueToMap(PN, PHISCEV);
5920
5921 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5922 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5923 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5924 proveNoWrapViaConstantRanges(AR)));
5925 }
5926
5927 // We can add Flags to the post-inc expression only if we
5928 // know that it is *undefined behavior* for BEValueV to
5929 // overflow.
5930 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5931 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5932 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5933
5934 return PHISCEV;
5935 }
5936 }
5937 } else {
5938 // Otherwise, this could be a loop like this:
5939 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5940 // In this case, j = {1,+,1} and BEValue is j.
5941 // Because the other in-value of i (0) fits the evolution of BEValue,
5942 // i really is an addrec evolution.
5943 //
5944 // We can generalize this saying that i is the shifted value of BEValue
5945 // by one iteration:
5946 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5947
5948 // Do not allow refinement in rewriting of BEValue.
5949 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5950 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5951 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5952 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5953 const SCEV *StartVal = getSCEV(StartValueV);
5954 if (Start == StartVal) {
5955 // Okay, for the entire analysis of this edge we assumed the PHI
5956 // to be symbolic. We now need to go back and purge all of the
5957 // entries for the scalars that use the symbolic expression.
5958 forgetMemoizedResults(SymbolicName);
5959 insertValueToMap(PN, Shifted);
5960 return Shifted;
5961 }
5962 }
5963 }
5964
5965 // Remove the temporary PHI node SCEV that has been inserted while intending
5966 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5967 // as it would prevent later (possibly simpler) SCEV expressions from being
5968 // added to the ValueExprMap.
5969 eraseValueFromMap(PN);
5970
5971 return nullptr;
5972}
5973
5974// Try to match a control flow sequence that branches out at BI and merges back
5975// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5976// match.
5977static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5978 Value *&C, Value *&LHS, Value *&RHS) {
5979 C = BI->getCondition();
5980
5981 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5982 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5983
5984 if (!LeftEdge.isSingleEdge())
5985 return false;
5986
5987 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5988
5989 Use &LeftUse = Merge->getOperandUse(0);
5990 Use &RightUse = Merge->getOperandUse(1);
5991
5992 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5993 LHS = LeftUse;
5994 RHS = RightUse;
5995 return true;
5996 }
5997
5998 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5999 LHS = RightUse;
6000 RHS = LeftUse;
6001 return true;
6002 }
6003
6004 return false;
6005}
6006
6007const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6008 auto IsReachable =
6009 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6010 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6011 // Try to match
6012 //
6013 // br %cond, label %left, label %right
6014 // left:
6015 // br label %merge
6016 // right:
6017 // br label %merge
6018 // merge:
6019 // V = phi [ %x, %left ], [ %y, %right ]
6020 //
6021 // as "select %cond, %x, %y"
6022
6023 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6024 assert(IDom && "At least the entry block should dominate PN");
6025
6026 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6027 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6028
6029 if (BI && BI->isConditional() &&
6030 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6031 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6032 properlyDominates(getSCEV(RHS), PN->getParent()))
6033 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6034 }
6035
6036 return nullptr;
6037}
6038
6039/// Returns SCEV for the first operand of a phi if all phi operands have
6040/// identical opcodes and operands
6041/// eg.
6042/// a: %add = %a + %b
6043/// br %c
6044/// b: %add1 = %a + %b
6045/// br %c
6046/// c: %phi = phi [%add, a], [%add1, b]
6047/// scev(%phi) => scev(%add)
6048const SCEV *
6049ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6050 BinaryOperator *CommonInst = nullptr;
6051 // Check if instructions are identical.
6052 for (Value *Incoming : PN->incoming_values()) {
6053 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6054 if (!IncomingInst)
6055 return nullptr;
6056 if (CommonInst) {
6057 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6058 return nullptr; // Not identical, give up
6059 } else {
6060 // Remember binary operator
6061 CommonInst = IncomingInst;
6062 }
6063 }
6064 if (!CommonInst)
6065 return nullptr;
6066
6067 // Check if SCEV exprs for instructions are identical.
6068 const SCEV *CommonSCEV = getSCEV(CommonInst);
6069 bool SCEVExprsIdentical =
6070 all_of(drop_begin(PN->incoming_values()),
6071 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6072 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6073}
6074
6075const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6076 if (const SCEV *S = createAddRecFromPHI(PN))
6077 return S;
6078
6079 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6080 // phi node for X.
6081 if (Value *V = simplifyInstruction(
6082 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6083 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6084 return getSCEV(V);
6085
6086 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6087 return S;
6088
6089 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6090 return S;
6091
6092 // If it's not a loop phi, we can't handle it yet.
6093 return getUnknown(PN);
6094}
6095
6096bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6097 SCEVTypes RootKind) {
6098 struct FindClosure {
6099 const SCEV *OperandToFind;
6100 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6101 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6102
6103 bool Found = false;
6104
6105 bool canRecurseInto(SCEVTypes Kind) const {
6106 // We can only recurse into the SCEV expression of the same effective type
6107 // as the type of our root SCEV expression, and into zero-extensions.
6108 return RootKind == Kind || NonSequentialRootKind == Kind ||
6109 scZeroExtend == Kind;
6110 };
6111
6112 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6113 : OperandToFind(OperandToFind), RootKind(RootKind),
6114 NonSequentialRootKind(
6115 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6116 RootKind)) {}
6117
6118 bool follow(const SCEV *S) {
6119 Found = S == OperandToFind;
6120
6121 return !isDone() && canRecurseInto(S->getSCEVType());
6122 }
6123
6124 bool isDone() const { return Found; }
6125 };
6126
6127 FindClosure FC(OperandToFind, RootKind);
6128 visitAll(Root, FC);
6129 return FC.Found;
6130}
6131
6132std::optional<const SCEV *>
6133ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6134 ICmpInst *Cond,
6135 Value *TrueVal,
6136 Value *FalseVal) {
6137 // Try to match some simple smax or umax patterns.
6138 auto *ICI = Cond;
6139
6140 Value *LHS = ICI->getOperand(0);
6141 Value *RHS = ICI->getOperand(1);
6142
6143 switch (ICI->getPredicate()) {
6144 case ICmpInst::ICMP_SLT:
6145 case ICmpInst::ICMP_SLE:
6146 case ICmpInst::ICMP_ULT:
6147 case ICmpInst::ICMP_ULE:
6148 std::swap(LHS, RHS);
6149 [[fallthrough]];
6150 case ICmpInst::ICMP_SGT:
6151 case ICmpInst::ICMP_SGE:
6152 case ICmpInst::ICMP_UGT:
6153 case ICmpInst::ICMP_UGE:
6154 // a > b ? a+x : b+x -> max(a, b)+x
6155 // a > b ? b+x : a+x -> min(a, b)+x
6156 {
6157 bool Signed = ICI->isSigned();
6158 const SCEV *LA = getSCEV(TrueVal);
6159 const SCEV *RA = getSCEV(FalseVal);
6160 const SCEV *LS = getSCEV(LHS);
6161 const SCEV *RS = getSCEV(RHS);
6162 if (LA->getType()->isPointerTy()) {
6163 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6164 // Need to make sure we can't produce weird expressions involving
6165 // negated pointers.
6166 if (LA == LS && RA == RS)
6167 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6168 if (LA == RS && RA == LS)
6169 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6170 }
6171 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6172 if (Op->getType()->isPointerTy()) {
6173 Op = getLosslessPtrToIntExpr(Op);
6174 if (isa<SCEVCouldNotCompute>(Op))
6175 return Op;
6176 }
6177 if (Signed)
6178 Op = getNoopOrSignExtend(Op, Ty);
6179 else
6180 Op = getNoopOrZeroExtend(Op, Ty);
6181 return Op;
6182 };
6183 LS = CoerceOperand(LS);
6184 RS = CoerceOperand(RS);
6185 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6186 break;
6187 const SCEV *LDiff = getMinusSCEV(LA, LS);
6188 const SCEV *RDiff = getMinusSCEV(RA, RS);
6189 if (LDiff == RDiff)
6190 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6191 LDiff);
6192 LDiff = getMinusSCEV(LA, RS);
6193 RDiff = getMinusSCEV(RA, LS);
6194 if (LDiff == RDiff)
6195 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6196 LDiff);
6197 }
6198 break;
6199 case ICmpInst::ICMP_NE:
6200 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6201 std::swap(TrueVal, FalseVal);
6202 [[fallthrough]];
6203 case ICmpInst::ICMP_EQ:
6204 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6205 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6206 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6207 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6208 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6209 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6210 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6211 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6212 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6213 return getAddExpr(getUMaxExpr(X, C), Y);
6214 }
6215 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6216 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6217 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6218 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6219 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6220 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6221 const SCEV *X = getSCEV(LHS);
6222 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6223 X = ZExt->getOperand();
6224 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6225 const SCEV *FalseValExpr = getSCEV(FalseVal);
6226 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6227 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6228 /*Sequential=*/true);
6229 }
6230 }
6231 break;
6232 default:
6233 break;
6234 }
6235
6236 return std::nullopt;
6237}
6238
6239static std::optional<const SCEV *>
6240createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6241 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6242 assert(CondExpr->getType()->isIntegerTy(1) &&
6243 TrueExpr->getType() == FalseExpr->getType() &&
6244 TrueExpr->getType()->isIntegerTy(1) &&
6245 "Unexpected operands of a select.");
6246
6247 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6248 // --> C + (umin_seq cond, x - C)
6249 //
6250 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6251 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6252 // --> C + (umin_seq ~cond, x - C)
6253
6254 // FIXME: while we can't legally model the case where both of the hands
6255 // are fully variable, we only require that the *difference* is constant.
6256 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6257 return std::nullopt;
6258
6259 const SCEV *X, *C;
6260 if (isa<SCEVConstant>(TrueExpr)) {
6261 CondExpr = SE->getNotSCEV(CondExpr);
6262 X = FalseExpr;
6263 C = TrueExpr;
6264 } else {
6265 X = TrueExpr;
6266 C = FalseExpr;
6267 }
6268 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6269 /*Sequential=*/true));
6270}
6271
6272static std::optional<const SCEV *>
6273createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6274 Value *FalseVal) {
6275 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6276 return std::nullopt;
6277
6278 const auto *SECond = SE->getSCEV(Cond);
6279 const auto *SETrue = SE->getSCEV(TrueVal);
6280 const auto *SEFalse = SE->getSCEV(FalseVal);
6281 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6282}
6283
6284const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6285 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6286 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6287 assert(TrueVal->getType() == FalseVal->getType() &&
6288 V->getType() == TrueVal->getType() &&
6289 "Types of select hands and of the result must match.");
6290
6291 // For now, only deal with i1-typed `select`s.
6292 if (!V->getType()->isIntegerTy(1))
6293 return getUnknown(V);
6294
6295 if (std::optional<const SCEV *> S =
6296 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6297 return *S;
6298
6299 return getUnknown(V);
6300}
6301
6302const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6303 Value *TrueVal,
6304 Value *FalseVal) {
6305 // Handle "constant" branch or select. This can occur for instance when a
6306 // loop pass transforms an inner loop and moves on to process the outer loop.
6307 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6308 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6309
6310 if (auto *I = dyn_cast<Instruction>(V)) {
6311 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6312 if (std::optional<const SCEV *> S =
6313 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6314 TrueVal, FalseVal))
6315 return *S;
6316 }
6317 }
6318
6319 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6320}
6321
6322/// Expand GEP instructions into add and multiply operations. This allows them
6323/// to be analyzed by regular SCEV code.
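/// For instance (illustrative, assuming a DataLayout in which i32 occupies 4
/// bytes), "getelementptr i32, ptr %p, i64 %i" is modeled roughly as
/// (%p + 4 * %i).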
6324const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6325 assert(GEP->getSourceElementType()->isSized() &&
6326 "GEP source element type must be sized");
6327
6328 SmallVector<const SCEV *, 4> IndexExprs;
6329 for (Value *Index : GEP->indices())
6330 IndexExprs.push_back(getSCEV(Index));
6331 return getGEPExpr(GEP, IndexExprs);
6332}
6333
6334APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6335 const Instruction *CtxI) {
6336 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6337 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6338 return TrailingZeros >= BitWidth
6339 ? APInt::getZero(BitWidth)
6340 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6341 };
6342 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6343 // The result is GCD of all operands results.
6344 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6345 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6346 Res = APIntOps::GreatestCommonDivisor(
6347 Res, getConstantMultiple(N->getOperand(I), CtxI));
6348 return Res;
6349 };
6350
6351 switch (S->getSCEVType()) {
6352 case scConstant:
6353 return cast<SCEVConstant>(S)->getAPInt();
6354 case scPtrToInt:
6355 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
6356 case scUDivExpr:
6357 case scVScale:
6358 return APInt(BitWidth, 1);
6359 case scTruncate: {
6360 // Only multiples that are a power of 2 will hold after truncation.
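 // For instance, an i64 value known to be a multiple of 12 is only guaranteed
 // to be a multiple of 4 after truncation to i32, so we keep just the
 // power-of-two factor implied by the trailing zeros.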
6361 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6362 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6363 return GetShiftedByZeros(TZ);
6364 }
6365 case scZeroExtend: {
6366 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6367 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6368 }
6369 case scSignExtend: {
6370 // Only multiples that are a power of 2 will hold after sext.
6371 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6372 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6373 return GetShiftedByZeros(TZ);
6374 }
6375 case scMulExpr: {
6376 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6377 if (M->hasNoUnsignedWrap()) {
6378 // The result is the product of all operand results.
6379 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6380 for (const SCEV *Operand : M->operands().drop_front())
6381 Res = Res * getConstantMultiple(Operand, CtxI);
6382 return Res;
6383 }
6384
6385 // If there are no wrap guarantees, find the trailing zeros, which is the
6386 // sum of trailing zeros for all its operands.
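    // E.g. operands with 1 and 2 known trailing zero bits yield a product with
    // at least 3 trailing zero bits (a multiple of 8), even if the multiply
    // wraps.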
6387 uint32_t TZ = 0;
6388 for (const SCEV *Operand : M->operands())
6389 TZ += getMinTrailingZeros(Operand, CtxI);
6390 return GetShiftedByZeros(TZ);
6391 }
6392 case scAddExpr:
6393 case scAddRecExpr: {
6394 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6395 if (N->hasNoUnsignedWrap())
6396 return GetGCDMultiple(N);
6397    // Otherwise, find the minimum number of trailing zeros among the operands.
6398 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6399 for (const SCEV *Operand : N->operands().drop_front())
6400 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6401 return GetShiftedByZeros(TZ);
6402 }
6403 case scUMaxExpr:
6404 case scSMaxExpr:
6405 case scUMinExpr:
6406 case scSMinExpr:
6407  case scSequentialUMinExpr:
6408    return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6409 case scUnknown: {
6410    // Ask ValueTracking for known bits. A SCEVUnknown only becomes available at
6411    // the point where its underlying IR instruction has been defined. If CtxI was
6412 // not provided, use:
6413 // * the first instruction in the entry block if it is an argument
6414 // * the instruction itself otherwise.
6415 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6416 if (!CtxI) {
6417 if (isa<Argument>(U->getValue()))
6418 CtxI = &*F.getEntryBlock().begin();
6419 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6420 CtxI = I;
6421 }
6422 unsigned Known =
6423 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6424 .countMinTrailingZeros();
6425 return GetShiftedByZeros(Known);
6426 }
6427 case scCouldNotCompute:
6428 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6429 }
6430 llvm_unreachable("Unknown SCEV kind!");
6431}
6432
6433APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6434                                           const Instruction *CtxI) {
6435 // Skip looking up and updating the cache if there is a context instruction,
6436 // as the result will only be valid in the specified context.
6437 if (CtxI)
6438 return getConstantMultipleImpl(S, CtxI);
6439
6440 auto I = ConstantMultipleCache.find(S);
6441 if (I != ConstantMultipleCache.end())
6442 return I->second;
6443
6444 APInt Result = getConstantMultipleImpl(S, CtxI);
6445 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6446 assert(InsertPair.second && "Should insert a new key");
6447 return InsertPair.first->second;
6448}
6449
6450APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6451  APInt Multiple = getConstantMultiple(S);
6452 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6453}
6454
6455uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6456                                              const Instruction *CtxI) {
6457 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6458 (unsigned)getTypeSizeInBits(S->getType()));
6459}
6460
6461/// Helper method to assign a range to V from metadata present in the IR.
6462static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6463  if (Instruction *I = dyn_cast<Instruction>(V)) {
6464    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6465 return getConstantRangeFromMetadata(*MD);
6466 if (const auto *CB = dyn_cast<CallBase>(V))
6467 if (std::optional<ConstantRange> Range = CB->getRange())
6468 return Range;
6469 }
6470 if (auto *A = dyn_cast<Argument>(V))
6471 if (std::optional<ConstantRange> Range = A->getRange())
6472 return Range;
6473
6474 return std::nullopt;
6475}
6476
6477void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6478                                     SCEV::NoWrapFlags Flags) {
6479 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6480 AddRec->setNoWrapFlags(Flags);
6481 UnsignedRanges.erase(AddRec);
6482 SignedRanges.erase(AddRec);
6483 ConstantMultipleCache.erase(AddRec);
6484 }
6485}
6486
6487ConstantRange ScalarEvolution::
6488getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6489 const DataLayout &DL = getDataLayout();
6490
6491 unsigned BitWidth = getTypeSizeInBits(U->getType());
6492 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6493
6494 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6495 // use information about the trip count to improve our available range. Note
6496 // that the trip count independent cases are already handled by known bits.
6497  // WARNING: The definition of recurrence used here is subtly different from
6498 // the one used by AddRec (and thus most of this file). Step is allowed to
6499 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6500 // and other addrecs in the same loop (for non-affine addrecs). The code
6501 // below intentionally handles the case where step is not loop invariant.
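  // As an illustration (hypothetical IR), a header PHI such as
  //   %val = phi i8 [ %start, %preheader ], [ %val.next, %latch ]
  //   %val.next = lshr i8 %val, %amt
  // is handled here even though %amt may vary per iteration: it is not an
  // AddRec, but a bound on the trip count still bounds how far %val can move.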
6502 auto *P = dyn_cast<PHINode>(U->getValue());
6503 if (!P)
6504 return FullSet;
6505
6506 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6507 // even the values that are not available in these blocks may come from them,
6508  // and this leads to a false-positive recurrence test.
6509 for (auto *Pred : predecessors(P->getParent()))
6510 if (!DT.isReachableFromEntry(Pred))
6511 return FullSet;
6512
6513 BinaryOperator *BO;
6514 Value *Start, *Step;
6515 if (!matchSimpleRecurrence(P, BO, Start, Step))
6516 return FullSet;
6517
6518 // If we found a recurrence in reachable code, we must be in a loop. Note
6519 // that BO might be in some subloop of L, and that's completely okay.
6520 auto *L = LI.getLoopFor(P->getParent());
6521 assert(L && L->getHeader() == P->getParent());
6522 if (!L->contains(BO->getParent()))
6523 // NOTE: This bailout should be an assert instead. However, asserting
6524 // the condition here exposes a case where LoopFusion is querying SCEV
6525    // with malformed loop information in the midst of the transform.
6526 // There doesn't appear to be an obvious fix, so for the moment bailout
6527 // until the caller issue can be fixed. PR49566 tracks the bug.
6528 return FullSet;
6529
6530  // TODO: Extend to other opcodes, such as mul and div.
6531 switch (BO->getOpcode()) {
6532 default:
6533 return FullSet;
6534 case Instruction::AShr:
6535 case Instruction::LShr:
6536 case Instruction::Shl:
6537 break;
6538 };
6539
6540 if (BO->getOperand(0) != P)
6541 // TODO: Handle the power function forms some day.
6542 return FullSet;
6543
6544 unsigned TC = getSmallConstantMaxTripCount(L);
6545 if (!TC || TC >= BitWidth)
6546 return FullSet;
6547
6548 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6549 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6550 assert(KnownStart.getBitWidth() == BitWidth &&
6551 KnownStep.getBitWidth() == BitWidth);
6552
6553 // Compute total shift amount, being careful of overflow and bitwidths.
6554 auto MaxShiftAmt = KnownStep.getMaxValue();
6555 APInt TCAP(BitWidth, TC-1);
6556 bool Overflow = false;
6557 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6558 if (Overflow)
6559 return FullSet;
6560
6561 switch (BO->getOpcode()) {
6562 default:
6563 llvm_unreachable("filtered out above");
6564 case Instruction::AShr: {
6565 // For each ashr, three cases:
6566 // shift = 0 => unchanged value
6567 // saturation => 0 or -1
6568 // other => a value closer to zero (of the same sign)
6569 // Thus, the end value is closer to zero than the start.
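    // E.g. an i8 start known to lie in [0, 100] keeps every intermediate value
    // in [0, 100], since each ashr moves a non-negative value towards zero.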
6570 auto KnownEnd = KnownBits::ashr(KnownStart,
6571 KnownBits::makeConstant(TotalShift));
6572 if (KnownStart.isNonNegative())
6573 // Analogous to lshr (simply not yet canonicalized)
6574 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6575 KnownStart.getMaxValue() + 1);
6576 if (KnownStart.isNegative())
6577 // End >=u Start && End <=s Start
6578 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6579 KnownEnd.getMaxValue() + 1);
6580 break;
6581 }
6582 case Instruction::LShr: {
6583 // For each lshr, three cases:
6584 // shift = 0 => unchanged value
6585 // saturation => 0
6586 // other => a smaller positive number
6587 // Thus, the low end of the unsigned range is the last value produced.
6588 auto KnownEnd = KnownBits::lshr(KnownStart,
6589 KnownBits::makeConstant(TotalShift));
6590 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6591 KnownStart.getMaxValue() + 1);
6592 }
6593 case Instruction::Shl: {
6594 // Iff no bits are shifted out, value increases on every shift.
6595 auto KnownEnd = KnownBits::shl(KnownStart,
6596 KnownBits::makeConstant(TotalShift));
6597 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6598 return ConstantRange(KnownStart.getMinValue(),
6599 KnownEnd.getMaxValue() + 1);
6600 break;
6601 }
6602 };
6603 return FullSet;
6604}
6605
6606const ConstantRange &
6607ScalarEvolution::getRangeRefIter(const SCEV *S,
6608 ScalarEvolution::RangeSignHint SignHint) {
6609 DenseMap<const SCEV *, ConstantRange> &Cache =
6610 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6611 : SignedRanges;
6612  SmallVector<const SCEV *> WorkList;
6613  SmallPtrSet<const SCEV *, 8> Seen;
6614
6615 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6616 // SCEVUnknown PHI node.
6617 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6618 if (!Seen.insert(Expr).second)
6619 return;
6620 if (Cache.contains(Expr))
6621 return;
6622 switch (Expr->getSCEVType()) {
6623 case scUnknown:
6624 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6625 break;
6626 [[fallthrough]];
6627 case scConstant:
6628 case scVScale:
6629 case scTruncate:
6630 case scZeroExtend:
6631 case scSignExtend:
6632 case scPtrToInt:
6633 case scAddExpr:
6634 case scMulExpr:
6635 case scUDivExpr:
6636 case scAddRecExpr:
6637 case scUMaxExpr:
6638 case scSMaxExpr:
6639 case scUMinExpr:
6640 case scSMinExpr:
6641    case scSequentialUMinExpr:
6642      WorkList.push_back(Expr);
6643 break;
6644 case scCouldNotCompute:
6645 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6646 }
6647 };
6648 AddToWorklist(S);
6649
6650 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6651 for (unsigned I = 0; I != WorkList.size(); ++I) {
6652 const SCEV *P = WorkList[I];
6653 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6654 // If it is not a `SCEVUnknown`, just recurse into operands.
6655 if (!UnknownS) {
6656 for (const SCEV *Op : P->operands())
6657 AddToWorklist(Op);
6658 continue;
6659 }
6660 // `SCEVUnknown`'s require special treatment.
6661 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6662 if (!PendingPhiRangesIter.insert(P).second)
6663 continue;
6664 for (auto &Op : reverse(P->operands()))
6665 AddToWorklist(getSCEV(Op));
6666 }
6667 }
6668
6669 if (!WorkList.empty()) {
6670 // Use getRangeRef to compute ranges for items in the worklist in reverse
6671 // order. This will force ranges for earlier operands to be computed before
6672 // their users in most cases.
6673 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6674 getRangeRef(P, SignHint);
6675
6676 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6677 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6678 PendingPhiRangesIter.erase(P);
6679 }
6680 }
6681
6682 return getRangeRef(S, SignHint, 0);
6683}
6684
6685/// Determine the range for a particular SCEV. If SignHint is
6686/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6687/// with a "cleaner" unsigned (resp. signed) representation.
6688const ConstantRange &ScalarEvolution::getRangeRef(
6689 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6690 DenseMap<const SCEV *, ConstantRange> &Cache =
6691 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6692 : SignedRanges;
6693  ConstantRange::PreferredRangeType RangeType =
6694      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6695                                                       : ConstantRange::Signed;
6696
6697 // See if we've computed this range already.
6698  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6699  if (I != Cache.end())
6700 return I->second;
6701
6702 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6703 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6704
6705 // Switch to iteratively computing the range for S, if it is part of a deeply
6706 // nested expression.
6707  if (Depth > RangeIterThreshold)
6708    return getRangeRefIter(S, SignHint);
6709
6710 unsigned BitWidth = getTypeSizeInBits(S->getType());
6711 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6712 using OBO = OverflowingBinaryOperator;
6713
6714 // If the value has known zeros, the maximum value will have those known zeros
6715 // as well.
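  // E.g. for an unsigned query on an i8 value known to be a multiple of 4, the
  // largest attainable value is 252, so the range below becomes [0, 253).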
6716 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6717 APInt Multiple = getNonZeroConstantMultiple(S);
6718 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6719 if (!Remainder.isZero())
6720 ConservativeResult =
6721 ConstantRange(APInt::getMinValue(BitWidth),
6722 APInt::getMaxValue(BitWidth) - Remainder + 1);
6723 }
6724 else {
6725 uint32_t TZ = getMinTrailingZeros(S);
6726 if (TZ != 0) {
6727 ConservativeResult = ConstantRange(
6728          APInt::getSignedMinValue(BitWidth).ashr(TZ).shl(TZ),
6729          APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6730 }
6731 }
6732
6733 switch (S->getSCEVType()) {
6734 case scConstant:
6735 llvm_unreachable("Already handled above.");
6736 case scVScale:
6737 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6738 case scTruncate: {
6739 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6740 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6741 return setRange(
6742 Trunc, SignHint,
6743 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6744 }
6745 case scZeroExtend: {
6746 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6747 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6748 return setRange(
6749 ZExt, SignHint,
6750 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6751 }
6752 case scSignExtend: {
6753 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6754 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6755 return setRange(
6756 SExt, SignHint,
6757 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6758 }
6759 case scPtrToInt: {
6760 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6761 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6762 return setRange(PtrToInt, SignHint, X);
6763 }
6764 case scAddExpr: {
6765 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6766 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6767 unsigned WrapType = OBO::AnyWrap;
6768 if (Add->hasNoSignedWrap())
6769 WrapType |= OBO::NoSignedWrap;
6770 if (Add->hasNoUnsignedWrap())
6771 WrapType |= OBO::NoUnsignedWrap;
6772 for (const SCEV *Op : drop_begin(Add->operands()))
6773 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6774 RangeType);
6775 return setRange(Add, SignHint,
6776 ConservativeResult.intersectWith(X, RangeType));
6777 }
6778 case scMulExpr: {
6779 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6780 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6781 for (const SCEV *Op : drop_begin(Mul->operands()))
6782 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6783 return setRange(Mul, SignHint,
6784 ConservativeResult.intersectWith(X, RangeType));
6785 }
6786 case scUDivExpr: {
6787 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6788 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6789 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6790 return setRange(UDiv, SignHint,
6791 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6792 }
6793 case scAddRecExpr: {
6794 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6795 // If there's no unsigned wrap, the value will never be less than its
6796 // initial value.
6797 if (AddRec->hasNoUnsignedWrap()) {
6798 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6799 if (!UnsignedMinValue.isZero())
6800 ConservativeResult = ConservativeResult.intersectWith(
6801 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6802 }
6803
6804    // If there's no signed wrap, and all the operands other than the initial
6805    // value have the same sign or are zero, the value won't ever be:
6806    // 1: smaller than the initial value if the operands are non-negative,
6807    // 2: bigger than the initial value if the operands are non-positive.
6808    // In both cases, the value cannot cross the signed min/max boundary.
6809 if (AddRec->hasNoSignedWrap()) {
6810 bool AllNonNeg = true;
6811 bool AllNonPos = true;
6812 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6813 if (!isKnownNonNegative(AddRec->getOperand(i)))
6814 AllNonNeg = false;
6815 if (!isKnownNonPositive(AddRec->getOperand(i)))
6816 AllNonPos = false;
6817 }
6818 if (AllNonNeg)
6819 ConservativeResult = ConservativeResult.intersectWith(
6820            ConstantRange(getSignedRangeMin(AddRec->getStart()),
6821                          APInt::getSignedMinValue(BitWidth)),
6822            RangeType);
6823 else if (AllNonPos)
6824 ConservativeResult = ConservativeResult.intersectWith(
6825            ConstantRange(APInt::getSignedMinValue(BitWidth),
6826                          getSignedRangeMax(AddRec->getStart()) +
6827 1),
6828 RangeType);
6829 }
6830
6831 // TODO: non-affine addrec
6832 if (AddRec->isAffine()) {
6833 const SCEV *MaxBEScev =
6834          getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6835      if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6836 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6837
6838 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6839 // MaxBECount's active bits are all <= AddRec's bit width.
6840 if (MaxBECount.getBitWidth() > BitWidth &&
6841 MaxBECount.getActiveBits() <= BitWidth)
6842 MaxBECount = MaxBECount.trunc(BitWidth);
6843 else if (MaxBECount.getBitWidth() < BitWidth)
6844 MaxBECount = MaxBECount.zext(BitWidth);
6845
6846 if (MaxBECount.getBitWidth() == BitWidth) {
6847 auto RangeFromAffine = getRangeForAffineAR(
6848 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6849 ConservativeResult =
6850 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6851
6852 auto RangeFromFactoring = getRangeViaFactoring(
6853 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6854 ConservativeResult =
6855 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6856 }
6857 }
6858
6859 // Now try symbolic BE count and more powerful methods.
6860      if (UseExpensiveRangeSharpening) {
6861        const SCEV *SymbolicMaxBECount =
6862            getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6863 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6864 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6865 AddRec->hasNoSelfWrap()) {
6866 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6867 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6868 ConservativeResult =
6869 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6870 }
6871 }
6872 }
6873
6874 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6875 }
6876 case scUMaxExpr:
6877 case scSMaxExpr:
6878 case scUMinExpr:
6879 case scSMinExpr:
6880 case scSequentialUMinExpr: {
6881    Intrinsic::ID ID;
6882    switch (S->getSCEVType()) {
6883 case scUMaxExpr:
6884 ID = Intrinsic::umax;
6885 break;
6886 case scSMaxExpr:
6887 ID = Intrinsic::smax;
6888 break;
6889 case scUMinExpr:
6890    case scSequentialUMinExpr:
6891      ID = Intrinsic::umin;
6892 break;
6893 case scSMinExpr:
6894 ID = Intrinsic::smin;
6895 break;
6896 default:
6897 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6898 }
6899
6900 const auto *NAry = cast<SCEVNAryExpr>(S);
6901 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6902 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6903 X = X.intrinsic(
6904 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6905 return setRange(S, SignHint,
6906 ConservativeResult.intersectWith(X, RangeType));
6907 }
6908 case scUnknown: {
6909 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6910 Value *V = U->getValue();
6911
6912 // Check if the IR explicitly contains !range metadata.
6913 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6914 if (MDRange)
6915 ConservativeResult =
6916 ConservativeResult.intersectWith(*MDRange, RangeType);
6917
6918 // Use facts about recurrences in the underlying IR. Note that add
6919 // recurrences are AddRecExprs and thus don't hit this path. This
6920 // primarily handles shift recurrences.
6921 auto CR = getRangeForUnknownRecurrence(U);
6922 ConservativeResult = ConservativeResult.intersectWith(CR);
6923
6924 // See if ValueTracking can give us a useful range.
6925 const DataLayout &DL = getDataLayout();
6926 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6927 if (Known.getBitWidth() != BitWidth)
6928 Known = Known.zextOrTrunc(BitWidth);
6929
6930 // ValueTracking may be able to compute a tighter result for the number of
6931 // sign bits than for the value of those sign bits.
6932 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6933 if (U->getType()->isPointerTy()) {
6934      // If the pointer size is larger than the index type size, this can cause
6935 // NS to be larger than BitWidth. So compensate for this.
6936 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6937 int ptrIdxDiff = ptrSize - BitWidth;
6938 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6939 NS -= ptrIdxDiff;
6940 }
6941
6942 if (NS > 1) {
6943 // If we know any of the sign bits, we know all of the sign bits.
6944 if (!Known.Zero.getHiBits(NS).isZero())
6945 Known.Zero.setHighBits(NS);
6946 if (!Known.One.getHiBits(NS).isZero())
6947 Known.One.setHighBits(NS);
6948 }
6949
6950 if (Known.getMinValue() != Known.getMaxValue() + 1)
6951 ConservativeResult = ConservativeResult.intersectWith(
6952 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6953 RangeType);
6954 if (NS > 1)
6955 ConservativeResult = ConservativeResult.intersectWith(
6956 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6957 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6958 RangeType);
6959
6960 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6961 // Strengthen the range if the underlying IR value is a
6962 // global/alloca/heap allocation using the size of the object.
6963 bool CanBeNull, CanBeFreed;
6964 uint64_t DerefBytes =
6965 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6966 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6967 // The highest address the object can start is DerefBytes bytes before
6968 // the end (unsigned max value). If this value is not a multiple of the
6969 // alignment, the last possible start value is the next lowest multiple
6970 // of the alignment. Note: The computations below cannot overflow,
6971        // because if they did, there would be no possible start address for the
6972 // object.
6973 APInt MaxVal =
6974 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6975 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6976 uint64_t Rem = MaxVal.urem(Align);
6977 MaxVal -= APInt(BitWidth, Rem);
6978 APInt MinVal = APInt::getZero(BitWidth);
6979 if (llvm::isKnownNonZero(V, DL))
6980 MinVal = Align;
6981 ConservativeResult = ConservativeResult.intersectWith(
6982 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6983 }
6984 }
6985
6986 // A range of Phi is a subset of union of all ranges of its input.
6987 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6988 // Make sure that we do not run over cycled Phis.
6989 if (PendingPhiRanges.insert(Phi).second) {
6990 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6991
6992 for (const auto &Op : Phi->operands()) {
6993 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6994 RangeFromOps = RangeFromOps.unionWith(OpRange);
6995          // No point in continuing if we already have a full set.
6996 if (RangeFromOps.isFullSet())
6997 break;
6998 }
6999 ConservativeResult =
7000 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7001 bool Erased = PendingPhiRanges.erase(Phi);
7002 assert(Erased && "Failed to erase Phi properly?");
7003 (void)Erased;
7004 }
7005 }
7006
7007 // vscale can't be equal to zero
7008 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7009 if (II->getIntrinsicID() == Intrinsic::vscale) {
7010 ConstantRange Disallowed = APInt::getZero(BitWidth);
7011 ConservativeResult = ConservativeResult.difference(Disallowed);
7012 }
7013
7014 return setRange(U, SignHint, std::move(ConservativeResult));
7015 }
7016 case scCouldNotCompute:
7017 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7018 }
7019
7020 return setRange(S, SignHint, std::move(ConservativeResult));
7021}
7022
7023// Given a StartRange, Step and MaxBECount for an expression compute a range of
7024// values that the expression can take. Initially, the expression has a value
7025// from StartRange and then is changed by Step up to MaxBECount times. The
7026// Signed argument defines whether we treat Step as signed or unsigned.
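// For example, with an i8 StartRange of [10, 21), Step == 3 treated as
// unsigned, and MaxBECount == 5, the value can grow by at most 15, so the
// result is [10, 36), provided the overflow checks below all pass.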
7027static ConstantRange getRangeForAffineARHelper(APInt Step,
7028                                               const ConstantRange &StartRange,
7029 const APInt &MaxBECount,
7030 bool Signed) {
7031 unsigned BitWidth = Step.getBitWidth();
7032 assert(BitWidth == StartRange.getBitWidth() &&
7033 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7034 // If either Step or MaxBECount is 0, then the expression won't change, and we
7035 // just need to return the initial range.
7036 if (Step == 0 || MaxBECount == 0)
7037 return StartRange;
7038
7039 // If we don't know anything about the initial value (i.e. StartRange is
7040 // FullRange), then we don't know anything about the final range either.
7041 // Return FullRange.
7042 if (StartRange.isFullSet())
7043 return ConstantRange::getFull(BitWidth);
7044
7045 // If Step is signed and negative, then we use its absolute value, but we also
7046 // note that we're moving in the opposite direction.
7047 bool Descending = Signed && Step.isNegative();
7048
7049 if (Signed)
7050 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7051 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7052    // These equations hold true due to the well-defined wrap-around behavior of
7053 // APInt.
7054 Step = Step.abs();
7055
7056  // Check if Offset is more than the full span of BitWidth. If it is, the
7057 // expression is guaranteed to overflow.
7058 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7059 return ConstantRange::getFull(BitWidth);
7060
7061 // Offset is by how much the expression can change. Checks above guarantee no
7062 // overflow here.
7063 APInt Offset = Step * MaxBECount;
7064
7065 // Minimum value of the final range will match the minimal value of StartRange
7066 // if the expression is increasing and will be decreased by Offset otherwise.
7067 // Maximum value of the final range will match the maximal value of StartRange
7068 // if the expression is decreasing and will be increased by Offset otherwise.
7069 APInt StartLower = StartRange.getLower();
7070 APInt StartUpper = StartRange.getUpper() - 1;
7071 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7072 : (StartUpper + std::move(Offset));
7073
7074 // It's possible that the new minimum/maximum value will fall into the initial
7075 // range (due to wrap around). This means that the expression can take any
7076 // value in this bitwidth, and we have to return full range.
7077 if (StartRange.contains(MovedBoundary))
7078 return ConstantRange::getFull(BitWidth);
7079
7080 APInt NewLower =
7081 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7082 APInt NewUpper =
7083 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7084 NewUpper += 1;
7085
7086 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7087 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7088}
7089
7090ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7091 const SCEV *Step,
7092 const APInt &MaxBECount) {
7093 assert(getTypeSizeInBits(Start->getType()) ==
7094 getTypeSizeInBits(Step->getType()) &&
7095 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7096 "mismatched bit widths");
7097
7098 // First, consider step signed.
7099 ConstantRange StartSRange = getSignedRange(Start);
7100 ConstantRange StepSRange = getSignedRange(Step);
7101
7102 // If Step can be both positive and negative, we need to find ranges for the
7103 // maximum absolute step values in both directions and union them.
7104 ConstantRange SR = getRangeForAffineARHelper(
7105 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7106  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7107                                              StartSRange, MaxBECount,
7108 /* Signed = */ true));
7109
7110 // Next, consider step unsigned.
7111 ConstantRange UR = getRangeForAffineARHelper(
7112 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7113 /* Signed = */ false);
7114
7115 // Finally, intersect signed and unsigned ranges.
7116  return SR.intersectWith(UR, ConstantRange::Smallest);
7117}
7118
7119ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7120 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7121 ScalarEvolution::RangeSignHint SignHint) {
7122  assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7123 assert(AddRec->hasNoSelfWrap() &&
7124 "This only works for non-self-wrapping AddRecs!");
7125 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7126 const SCEV *Step = AddRec->getStepRecurrence(*this);
7127 // Only deal with constant step to save compile time.
7128 if (!isa<SCEVConstant>(Step))
7129 return ConstantRange::getFull(BitWidth);
7130 // Let's make sure that we can prove that we do not self-wrap during
7131 // MaxBECount iterations. We need this because MaxBECount is a maximum
7132 // iteration count estimate, and we might infer nw from some exit for which we
7133 // do not know max exit count (or any other side reasoning).
7134 // TODO: Turn into assert at some point.
7135 if (getTypeSizeInBits(MaxBECount->getType()) >
7136 getTypeSizeInBits(AddRec->getType()))
7137 return ConstantRange::getFull(BitWidth);
7138 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7139 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7140 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7141 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
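  // E.g. for an i8 AddRec with step 3, RangeWidth is 255 and StepAbs is 3, so
  // at most 255 /u 3 == 85 iterations can run before a self-wrap could occur.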
7142 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7143 MaxItersWithoutWrap))
7144 return ConstantRange::getFull(BitWidth);
7145
7146  ICmpInst::Predicate LEPred =
7147      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7148  ICmpInst::Predicate GEPred =
7149      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7150 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7151
7152 // We know that there is no self-wrap. Let's take Start and End values and
7153 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7154 // the iteration. They either lie inside the range [Min(Start, End),
7155 // Max(Start, End)] or outside it:
7156 //
7157 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7158 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7159 //
7160 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7161 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7162 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7163 // Start <= End and step is positive, or Start >= End and step is negative.
7164 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7165 ConstantRange StartRange = getRangeRef(Start, SignHint);
7166 ConstantRange EndRange = getRangeRef(End, SignHint);
7167 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7168 // If they already cover full iteration space, we will know nothing useful
7169 // even if we prove what we want to prove.
7170 if (RangeBetween.isFullSet())
7171 return RangeBetween;
7172 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7173 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7174 : RangeBetween.isWrappedSet();
7175 if (IsWrappedSet)
7176 return ConstantRange::getFull(BitWidth);
7177
7178 if (isKnownPositive(Step) &&
7179 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7180 return RangeBetween;
7181 if (isKnownNegative(Step) &&
7182 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7183 return RangeBetween;
7184 return ConstantRange::getFull(BitWidth);
7185}
7186
7187ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7188 const SCEV *Step,
7189 const APInt &MaxBECount) {
7190 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7191 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
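  // For instance, with Start == (C ? 0 : 10) and Step == (C ? 1 : 2), the
  // result is the union of RangeOf({0,+,1}) and RangeOf({10,+,2}).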
7192
7193 unsigned BitWidth = MaxBECount.getBitWidth();
7194 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7195 getTypeSizeInBits(Step->getType()) == BitWidth &&
7196 "mismatched bit widths");
7197
7198 struct SelectPattern {
7199 Value *Condition = nullptr;
7200 APInt TrueValue;
7201 APInt FalseValue;
7202
7203 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7204 const SCEV *S) {
7205 std::optional<unsigned> CastOp;
7206 APInt Offset(BitWidth, 0);
7207
7208      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7209             "Should be!");
7210
7211 // Peel off a constant offset. In the future we could consider being
7212 // smarter here and handle {Start+Step,+,Step} too.
7213 const APInt *Off;
7214 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7215 Offset = *Off;
7216
7217 // Peel off a cast operation
7218 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7219 CastOp = SCast->getSCEVType();
7220 S = SCast->getOperand();
7221 }
7222
7223 using namespace llvm::PatternMatch;
7224
7225 auto *SU = dyn_cast<SCEVUnknown>(S);
7226 const APInt *TrueVal, *FalseVal;
7227 if (!SU ||
7228 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7229 m_APInt(FalseVal)))) {
7230 Condition = nullptr;
7231 return;
7232 }
7233
7234 TrueValue = *TrueVal;
7235 FalseValue = *FalseVal;
7236
7237 // Re-apply the cast we peeled off earlier
7238 if (CastOp)
7239 switch (*CastOp) {
7240 default:
7241 llvm_unreachable("Unknown SCEV cast type!");
7242
7243 case scTruncate:
7244 TrueValue = TrueValue.trunc(BitWidth);
7245 FalseValue = FalseValue.trunc(BitWidth);
7246 break;
7247 case scZeroExtend:
7248 TrueValue = TrueValue.zext(BitWidth);
7249 FalseValue = FalseValue.zext(BitWidth);
7250 break;
7251 case scSignExtend:
7252 TrueValue = TrueValue.sext(BitWidth);
7253 FalseValue = FalseValue.sext(BitWidth);
7254 break;
7255 }
7256
7257 // Re-apply the constant offset we peeled off earlier
7258 TrueValue += Offset;
7259 FalseValue += Offset;
7260 }
7261
7262 bool isRecognized() { return Condition != nullptr; }
7263 };
7264
7265 SelectPattern StartPattern(*this, BitWidth, Start);
7266 if (!StartPattern.isRecognized())
7267 return ConstantRange::getFull(BitWidth);
7268
7269 SelectPattern StepPattern(*this, BitWidth, Step);
7270 if (!StepPattern.isRecognized())
7271 return ConstantRange::getFull(BitWidth);
7272
7273 if (StartPattern.Condition != StepPattern.Condition) {
7274 // We don't handle this case today; but we could, by considering four
7275 // possibilities below instead of two. I'm not sure if there are cases where
7276 // that will help over what getRange already does, though.
7277 return ConstantRange::getFull(BitWidth);
7278 }
7279
7280 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7281 // construct arbitrary general SCEV expressions here. This function is called
7282 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7283 // say) can end up caching a suboptimal value.
7284
7285 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7286 // C2352 and C2512 (otherwise it isn't needed).
7287
7288 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7289 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7290 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7291 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7292
7293 ConstantRange TrueRange =
7294 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7295 ConstantRange FalseRange =
7296 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7297
7298 return TrueRange.unionWith(FalseRange);
7299}
7300
7301SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7302 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7303 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7304
7305 // Return early if there are no flags to propagate to the SCEV.
7306  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7307  if (BinOp->hasNoUnsignedWrap())
7308    Flags = setFlags(Flags, SCEV::FlagNUW);
7309  if (BinOp->hasNoSignedWrap())
7310    Flags = setFlags(Flags, SCEV::FlagNSW);
7311 if (Flags == SCEV::FlagAnyWrap)
7312 return SCEV::FlagAnyWrap;
7313
7314 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7315}
7316
7317const Instruction *
7318ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7319 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7320 return &*AddRec->getLoop()->getHeader()->begin();
7321 if (auto *U = dyn_cast<SCEVUnknown>(S))
7322 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7323 return I;
7324 return nullptr;
7325}
7326
7327const Instruction *
7328ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7329 bool &Precise) {
7330 Precise = true;
7331 // Do a bounded search of the def relation of the requested SCEVs.
7332 SmallPtrSet<const SCEV *, 16> Visited;
7333  SmallVector<const SCEV *> Worklist;
7334  auto pushOp = [&](const SCEV *S) {
7335 if (!Visited.insert(S).second)
7336 return;
7337 // Threshold of 30 here is arbitrary.
7338 if (Visited.size() > 30) {
7339 Precise = false;
7340 return;
7341 }
7342 Worklist.push_back(S);
7343 };
7344
7345 for (const auto *S : Ops)
7346 pushOp(S);
7347
7348 const Instruction *Bound = nullptr;
7349 while (!Worklist.empty()) {
7350 auto *S = Worklist.pop_back_val();
7351 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7352 if (!Bound || DT.dominates(Bound, DefI))
7353 Bound = DefI;
7354 } else {
7355 for (const auto *Op : S->operands())
7356 pushOp(Op);
7357 }
7358 }
7359 return Bound ? Bound : &*F.getEntryBlock().begin();
7360}
7361
7362const Instruction *
7363ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7364 bool Discard;
7365 return getDefiningScopeBound(Ops, Discard);
7366}
7367
7368bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7369 const Instruction *B) {
7370 if (A->getParent() == B->getParent() &&
7371      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7372                                                 B->getIterator()))
7373 return true;
7374
7375 auto *BLoop = LI.getLoopFor(B->getParent());
7376 if (BLoop && BLoop->getHeader() == B->getParent() &&
7377 BLoop->getLoopPreheader() == A->getParent() &&
7378      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7379                                                 A->getParent()->end()) &&
7380 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7381 B->getIterator()))
7382 return true;
7383 return false;
7384}
7385
7386bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7387 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7388 visitAll(Op, PC);
7389 return PC.MaybePoison.empty();
7390}
7391
7392bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7393 return !SCEVExprContains(Op, [this](const SCEV *S) {
7394 const SCEV *Op1;
7395 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7396 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7397 // is a non-zero constant, we have to assume the UDiv may be UB.
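    // E.g. (%x /u %y) must be treated as potentially UB, whereas (%x /u 8)
    // never is.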
7398 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7399 });
7400}
7401
7402bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7403 // Only proceed if we can prove that I does not yield poison.
7404  if (!programUndefinedIfPoison(I))
7405    return false;
7406
7407 // At this point we know that if I is executed, then it does not wrap
7408 // according to at least one of NSW or NUW. If I is not executed, then we do
7409 // not know if the calculation that I represents would wrap. Multiple
7410 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7411 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7412 // derived from other instructions that map to the same SCEV. We cannot make
7413  // that guarantee for cases where I is not executed. So we need to find an
7414 // upper bound on the defining scope for the SCEV, and prove that I is
7415 // executed every time we enter that scope. When the bounding scope is a
7416 // loop (the common case), this is equivalent to proving I executes on every
7417 // iteration of that loop.
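  // For example, an "add nsw" in a conditionally executed block and a
  // flag-free add of the same operands elsewhere both map to the same SCEV;
  // nsw may only be attached if the flagged instruction is known to execute
  // whenever the SCEV's defining scope is entered.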
7418  SmallVector<const SCEV *> SCEVOps;
7419  for (const Use &Op : I->operands()) {
7420 // I could be an extractvalue from a call to an overflow intrinsic.
7421 // TODO: We can do better here in some cases.
7422 if (isSCEVable(Op->getType()))
7423 SCEVOps.push_back(getSCEV(Op));
7424 }
7425 auto *DefI = getDefiningScopeBound(SCEVOps);
7426 return isGuaranteedToTransferExecutionTo(DefI, I);
7427}
7428
7429bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7430 // If we know that \c I can never be poison period, then that's enough.
7431 if (isSCEVExprNeverPoison(I))
7432 return true;
7433
7434 // If the loop only has one exit, then we know that, if the loop is entered,
7435 // any instruction dominating that exit will be executed. If any such
7436 // instruction would result in UB, the addrec cannot be poison.
7437 //
7438 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7439 // also handles uses outside the loop header (they just need to dominate the
7440 // single exit).
7441
7442 auto *ExitingBB = L->getExitingBlock();
7443 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7444 return false;
7445
7446 SmallPtrSet<const Value *, 16> KnownPoison;
7447  SmallVector<const Instruction *, 8> Worklist;
7448
7449 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7450 // things that are known to be poison under that assumption go on the
7451 // Worklist.
7452 KnownPoison.insert(I);
7453 Worklist.push_back(I);
7454
7455 while (!Worklist.empty()) {
7456 const Instruction *Poison = Worklist.pop_back_val();
7457
7458 for (const Use &U : Poison->uses()) {
7459 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7460 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7461 DT.dominates(PoisonUser->getParent(), ExitingBB))
7462 return true;
7463
7464 if (propagatesPoison(U) && L->contains(PoisonUser))
7465 if (KnownPoison.insert(PoisonUser).second)
7466 Worklist.push_back(PoisonUser);
7467 }
7468 }
7469
7470 return false;
7471}
7472
7473ScalarEvolution::LoopProperties
7474ScalarEvolution::getLoopProperties(const Loop *L) {
7475 using LoopProperties = ScalarEvolution::LoopProperties;
7476
7477 auto Itr = LoopPropertiesCache.find(L);
7478 if (Itr == LoopPropertiesCache.end()) {
7479 auto HasSideEffects = [](Instruction *I) {
7480 if (auto *SI = dyn_cast<StoreInst>(I))
7481 return !SI->isSimple();
7482
7483 if (I->mayThrow())
7484 return true;
7485
7486      // Non-volatile memset / memcpy do not count as a side effect for forward
7487 // progress.
7488 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7489 return false;
7490
7491 return I->mayWriteToMemory();
7492 };
7493
7494 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7495 /*HasNoSideEffects*/ true};
7496
7497 for (auto *BB : L->getBlocks())
7498 for (auto &I : *BB) {
7499        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7500          LP.HasNoAbnormalExits = false;
7501 if (HasSideEffects(&I))
7502 LP.HasNoSideEffects = false;
7503 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7504 break; // We're already as pessimistic as we can get.
7505 }
7506
7507 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7508 assert(InsertPair.second && "We just checked!");
7509 Itr = InsertPair.first;
7510 }
7511
7512 return Itr->second;
7513}
7514
7515bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7516  // A mustprogress loop without side effects must be finite.
7517 // TODO: The check used here is very conservative. It's only *specific*
7518 // side effects which are well defined in infinite loops.
7519 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7520}
7521
7522const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7523 // Worklist item with a Value and a bool indicating whether all operands have
7524 // been visited already.
7525  using PointerTy = PointerIntPair<Value *, 1, bool>;
7526  SmallVector<PointerTy> Stack;
7527
7528 Stack.emplace_back(V, true);
7529 Stack.emplace_back(V, false);
7530 while (!Stack.empty()) {
7531 auto E = Stack.pop_back_val();
7532 Value *CurV = E.getPointer();
7533
7534 if (getExistingSCEV(CurV))
7535 continue;
7536
7537    SmallVector<Value *> Ops;
7538    const SCEV *CreatedSCEV = nullptr;
7539 // If all operands have been visited already, create the SCEV.
7540 if (E.getInt()) {
7541 CreatedSCEV = createSCEV(CurV);
7542 } else {
7543      // Otherwise get the operands we need to create SCEVs for before creating
7544 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7545 // just use it.
7546 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7547 }
7548
7549 if (CreatedSCEV) {
7550 insertValueToMap(CurV, CreatedSCEV);
7551 } else {
7552      // Queue CurV for SCEV creation, followed by its operands, which need to
7553 // be constructed first.
7554 Stack.emplace_back(CurV, true);
7555 for (Value *Op : Ops)
7556 Stack.emplace_back(Op, false);
7557 }
7558 }
7559
7560 return getExistingSCEV(V);
7561}
7562
7563const SCEV *
7564ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7565 if (!isSCEVable(V->getType()))
7566 return getUnknown(V);
7567
7568 if (Instruction *I = dyn_cast<Instruction>(V)) {
7569 // Don't attempt to analyze instructions in blocks that aren't
7570 // reachable. Such instructions don't matter, and they aren't required
7571 // to obey basic rules for definitions dominating uses which this
7572 // analysis depends on.
7573 if (!DT.isReachableFromEntry(I->getParent()))
7574 return getUnknown(PoisonValue::get(V->getType()));
7575 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7576 return getConstant(CI);
7577 else if (isa<GlobalAlias>(V))
7578 return getUnknown(V);
7579 else if (!isa<ConstantExpr>(V))
7580 return getUnknown(V);
7581
7582  Operator *U = cast<Operator>(V);
7583  if (auto BO =
7584          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7585    bool IsConstArg = isa<ConstantInt>(BO->RHS);
7586 switch (BO->Opcode) {
7587 case Instruction::Add:
7588 case Instruction::Mul: {
7589 // For additions and multiplications, traverse add/mul chains for which we
7590 // can potentially create a single SCEV, to reduce the number of
7591 // get{Add,Mul}Expr calls.
7592 do {
7593 if (BO->Op) {
7594 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7595 Ops.push_back(BO->Op);
7596 break;
7597 }
7598 }
7599 Ops.push_back(BO->RHS);
7600 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7601                                   dyn_cast<Instruction>(V));
7602        if (!NewBO ||
7603 (BO->Opcode == Instruction::Add &&
7604 (NewBO->Opcode != Instruction::Add &&
7605 NewBO->Opcode != Instruction::Sub)) ||
7606 (BO->Opcode == Instruction::Mul &&
7607 NewBO->Opcode != Instruction::Mul)) {
7608 Ops.push_back(BO->LHS);
7609 break;
7610 }
7611 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7612 // requires a SCEV for the LHS.
7613 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7614 auto *I = dyn_cast<Instruction>(BO->Op);
7615 if (I && programUndefinedIfPoison(I)) {
7616 Ops.push_back(BO->LHS);
7617 break;
7618 }
7619 }
7620 BO = NewBO;
7621 } while (true);
7622 return nullptr;
7623 }
7624 case Instruction::Sub:
7625 case Instruction::UDiv:
7626 case Instruction::URem:
7627 break;
7628 case Instruction::AShr:
7629 case Instruction::Shl:
7630 case Instruction::Xor:
7631 if (!IsConstArg)
7632 return nullptr;
7633 break;
7634 case Instruction::And:
7635 case Instruction::Or:
7636 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7637 return nullptr;
7638 break;
7639 case Instruction::LShr:
7640 return getUnknown(V);
7641 default:
7642 llvm_unreachable("Unhandled binop");
7643 break;
7644 }
7645
7646 Ops.push_back(BO->LHS);
7647 Ops.push_back(BO->RHS);
7648 return nullptr;
7649 }
7650
7651 switch (U->getOpcode()) {
7652 case Instruction::Trunc:
7653 case Instruction::ZExt:
7654 case Instruction::SExt:
7655 case Instruction::PtrToInt:
7656 Ops.push_back(U->getOperand(0));
7657 return nullptr;
7658
7659 case Instruction::BitCast:
7660 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7661 Ops.push_back(U->getOperand(0));
7662 return nullptr;
7663 }
7664 return getUnknown(V);
7665
7666 case Instruction::SDiv:
7667 case Instruction::SRem:
7668 Ops.push_back(U->getOperand(0));
7669 Ops.push_back(U->getOperand(1));
7670 return nullptr;
7671
7672 case Instruction::GetElementPtr:
7673 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7674 "GEP source element type must be sized");
7675 llvm::append_range(Ops, U->operands());
7676 return nullptr;
7677
7678 case Instruction::IntToPtr:
7679 return getUnknown(V);
7680
7681 case Instruction::PHI:
7682    // Keep constructing SCEVs for phis recursively for now.
7683 return nullptr;
7684
7685 case Instruction::Select: {
7686 // Check if U is a select that can be simplified to a SCEVUnknown.
7687 auto CanSimplifyToUnknown = [this, U]() {
7688 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7689 return false;
7690
7691 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7692 if (!ICI)
7693 return false;
7694 Value *LHS = ICI->getOperand(0);
7695 Value *RHS = ICI->getOperand(1);
7696 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7697 ICI->getPredicate() == CmpInst::ICMP_NE) {
7698        if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7699          return true;
7700 } else if (getTypeSizeInBits(LHS->getType()) >
7701 getTypeSizeInBits(U->getType()))
7702 return true;
7703 return false;
7704 };
7705 if (CanSimplifyToUnknown())
7706 return getUnknown(U);
7707
7708 llvm::append_range(Ops, U->operands());
7709 return nullptr;
7710 break;
7711 }
7712 case Instruction::Call:
7713 case Instruction::Invoke:
7714 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7715 Ops.push_back(RV);
7716 return nullptr;
7717 }
7718
7719 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7720 switch (II->getIntrinsicID()) {
7721 case Intrinsic::abs:
7722 Ops.push_back(II->getArgOperand(0));
7723 return nullptr;
7724 case Intrinsic::umax:
7725 case Intrinsic::umin:
7726 case Intrinsic::smax:
7727 case Intrinsic::smin:
7728 case Intrinsic::usub_sat:
7729 case Intrinsic::uadd_sat:
7730 Ops.push_back(II->getArgOperand(0));
7731 Ops.push_back(II->getArgOperand(1));
7732 return nullptr;
7733 case Intrinsic::start_loop_iterations:
7734 case Intrinsic::annotation:
7735 case Intrinsic::ptr_annotation:
7736 Ops.push_back(II->getArgOperand(0));
7737 return nullptr;
7738 default:
7739 break;
7740 }
7741 }
7742 break;
7743 }
7744
7745 return nullptr;
7746}
7747
7748const SCEV *ScalarEvolution::createSCEV(Value *V) {
7749 if (!isSCEVable(V->getType()))
7750 return getUnknown(V);
7751
7752 if (Instruction *I = dyn_cast<Instruction>(V)) {
7753 // Don't attempt to analyze instructions in blocks that aren't
7754 // reachable. Such instructions don't matter, and they aren't required
7755 // to obey basic rules for definitions dominating uses which this
7756 // analysis depends on.
7757 if (!DT.isReachableFromEntry(I->getParent()))
7758 return getUnknown(PoisonValue::get(V->getType()));
7759 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7760 return getConstant(CI);
7761 else if (isa<GlobalAlias>(V))
7762 return getUnknown(V);
7763 else if (!isa<ConstantExpr>(V))
7764 return getUnknown(V);
7765
7766 const SCEV *LHS;
7767 const SCEV *RHS;
7768
7769  Operator *U = cast<Operator>(V);
7770  if (auto BO =
7771          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7772    switch (BO->Opcode) {
7773 case Instruction::Add: {
7774 // The simple thing to do would be to just call getSCEV on both operands
7775 // and call getAddExpr with the result. However if we're looking at a
7776 // bunch of things all added together, this can be quite inefficient,
7777 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7778 // Instead, gather up all the operands and make a single getAddExpr call.
7779 // LLVM IR canonical form means we need only traverse the left operands.
7780      SmallVector<const SCEV *, 4> AddOps;
7781      do {
7782 if (BO->Op) {
7783 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7784 AddOps.push_back(OpSCEV);
7785 break;
7786 }
7787
7788 // If a NUW or NSW flag can be applied to the SCEV for this
7789 // addition, then compute the SCEV for this addition by itself
7790 // with a separate call to getAddExpr. We need to do that
7791 // instead of pushing the operands of the addition onto AddOps,
7792 // since the flags are only known to apply to this particular
7793 // addition - they may not apply to other additions that can be
7794 // formed with operands from AddOps.
7795 const SCEV *RHS = getSCEV(BO->RHS);
7796 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7797 if (Flags != SCEV::FlagAnyWrap) {
7798 const SCEV *LHS = getSCEV(BO->LHS);
7799 if (BO->Opcode == Instruction::Sub)
7800 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7801 else
7802 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7803 break;
7804 }
7805 }
7806
7807 if (BO->Opcode == Instruction::Sub)
7808 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7809 else
7810 AddOps.push_back(getSCEV(BO->RHS));
7811
7812 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7813                                   dyn_cast<Instruction>(V));
7814        if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7815 NewBO->Opcode != Instruction::Sub)) {
7816 AddOps.push_back(getSCEV(BO->LHS));
7817 break;
7818 }
7819 BO = NewBO;
7820 } while (true);
7821
7822 return getAddExpr(AddOps);
7823 }
7824
7825 case Instruction::Mul: {
7826      SmallVector<const SCEV *, 4> MulOps;
7827      do {
7828 if (BO->Op) {
7829 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7830 MulOps.push_back(OpSCEV);
7831 break;
7832 }
7833
7834 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7835 if (Flags != SCEV::FlagAnyWrap) {
7836 LHS = getSCEV(BO->LHS);
7837 RHS = getSCEV(BO->RHS);
7838 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7839 break;
7840 }
7841 }
7842
7843 MulOps.push_back(getSCEV(BO->RHS));
7844 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7845                                   dyn_cast<Instruction>(V));
7846        if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7847 MulOps.push_back(getSCEV(BO->LHS));
7848 break;
7849 }
7850 BO = NewBO;
7851 } while (true);
7852
7853 return getMulExpr(MulOps);
7854 }
7855 case Instruction::UDiv:
7856 LHS = getSCEV(BO->LHS);
7857 RHS = getSCEV(BO->RHS);
7858 return getUDivExpr(LHS, RHS);
7859 case Instruction::URem:
7860 LHS = getSCEV(BO->LHS);
7861 RHS = getSCEV(BO->RHS);
7862 return getURemExpr(LHS, RHS);
7863 case Instruction::Sub: {
7864      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7865      if (BO->Op)
7866 Flags = getNoWrapFlagsFromUB(BO->Op);
7867 LHS = getSCEV(BO->LHS);
7868 RHS = getSCEV(BO->RHS);
7869 return getMinusSCEV(LHS, RHS, Flags);
7870 }
7871 case Instruction::And:
7872 // For an expression like x&255 that merely masks off the high bits,
7873 // use zext(trunc(x)) as the SCEV expression.
7874 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7875 if (CI->isZero())
7876 return getSCEV(BO->RHS);
7877 if (CI->isMinusOne())
7878 return getSCEV(BO->LHS);
7879 const APInt &A = CI->getValue();
7880
7881 // Instcombine's ShrinkDemandedConstant may strip bits out of
7882 // constants, obscuring what would otherwise be a low-bits mask.
7883 // Use computeKnownBits to compute what ShrinkDemandedConstant
7884 // knew about to reconstruct a low-bits mask value.
7885 unsigned LZ = A.countl_zero();
7886 unsigned TZ = A.countr_zero();
7887 unsigned BitWidth = A.getBitWidth();
7888 KnownBits Known(BitWidth);
7889 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7890
7891 APInt EffectiveMask =
7892 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7893 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7894 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7895 const SCEV *LHS = getSCEV(BO->LHS);
7896 const SCEV *ShiftedLHS = nullptr;
7897 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7898 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7899 // For an expression like (x * 8) & 8, simplify the multiply.
7900 unsigned MulZeros = OpC->getAPInt().countr_zero();
7901 unsigned GCD = std::min(MulZeros, TZ);
7902 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7903            SmallVector<const SCEV *, 4> MulOps;
7904            MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7905 append_range(MulOps, LHSMul->operands().drop_front());
7906 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7907 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7908 }
7909 }
7910 if (!ShiftedLHS)
7911 ShiftedLHS = getUDivExpr(LHS, MulCount);
7912 return getMulExpr(
7913            getZeroExtendExpr(
7914                getTruncateExpr(ShiftedLHS,
7915 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7916 BO->LHS->getType()),
7917 MulCount);
7918 }
7919 }
7920 // Binary `and` is a bit-wise `umin`.
7921 if (BO->LHS->getType()->isIntegerTy(1)) {
7922 LHS = getSCEV(BO->LHS);
7923 RHS = getSCEV(BO->RHS);
7924 return getUMinExpr(LHS, RHS);
7925 }
7926 break;
7927
7928 case Instruction::Or:
7929 // Binary `or` is a bit-wise `umax`.
7930 if (BO->LHS->getType()->isIntegerTy(1)) {
7931 LHS = getSCEV(BO->LHS);
7932 RHS = getSCEV(BO->RHS);
7933 return getUMaxExpr(LHS, RHS);
7934 }
7935 break;
7936
7937 case Instruction::Xor:
7938 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7939 // If the RHS of xor is -1, then this is a not operation.
7940 if (CI->isMinusOne())
7941 return getNotSCEV(getSCEV(BO->LHS));
7942
7943 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7944 // This is a variant of the check for xor with -1, and it handles
7945 // the case where instcombine has trimmed non-demanded bits out
7946 // of an xor with -1.
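        // For example, "xor (and %x, 255), 255" flips only the low 8 bits of
        // %x, i.e. it is the complement of the low 8 bits of %x zero-extended
        // back to the original type.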
7947 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7948 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7949 if (LBO->getOpcode() == Instruction::And &&
7950 LCI->getValue() == CI->getValue())
7951 if (const SCEVZeroExtendExpr *Z =
7952                    dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7953              Type *UTy = BO->LHS->getType();
7954 const SCEV *Z0 = Z->getOperand();
7955 Type *Z0Ty = Z0->getType();
7956 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7957
7958 // If C is a low-bits mask, the zero extend is serving to
7959 // mask off the high bits. Complement the operand and
7960 // re-apply the zext.
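              // E.g. for "%a = and i32 %x, 7" modeled as
              // (zext i3 (trunc %x) to i32), "xor i32 %a, 7" becomes
              // (zext i3 (not (trunc %x)) to i32): the xor flips exactly the
              // bits that survive the zext.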
7961 if (CI->getValue().isMask(Z0TySize))
7962 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7963
7964 // If C is a single bit, it may be in the sign-bit position
7965 // before the zero-extend. In this case, represent the xor
7966 // using an add, which is equivalent, and re-apply the zext.
7967 APInt Trunc = CI->getValue().trunc(Z0TySize);
7968 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7969 Trunc.isSignMask())
7970 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7971 UTy);
7972 }
7973 }
7974 break;
7975
7976 case Instruction::Shl:
7977 // Turn shift left of a constant amount into a multiply.
7978 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7979 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7980
7981 // If the shift count is not less than the bitwidth, the result of
7982 // the shift is undefined. Don't try to analyze it, because the
7983 // resolution chosen here may differ from the resolution chosen in
7984 // other parts of the compiler.
7985 if (SA->getValue().uge(BitWidth))
7986 break;
7987
7988 // We can safely preserve the nuw flag in all cases. It's also safe to
7989 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7990 // requires special handling. It can be preserved as long as we're not
7991 // left shifting by bitwidth - 1.
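        // E.g. "shl nuw i32 %x, 3" can be modeled as (8 * %x)<nuw>, and
        // "shl nuw nsw" as (8 * %x)<nuw><nsw>; a plain "shl nsw" keeps <nsw>
        // only for shift amounts up to bitwidth - 2.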
7992 auto Flags = SCEV::FlagAnyWrap;
7993 if (BO->Op) {
7994 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7995 if ((MulFlags & SCEV::FlagNSW) &&
7996             ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7997           Flags = setFlags(Flags, SCEV::FlagNSW);
7998         if (MulFlags & SCEV::FlagNUW)
7999           Flags = setFlags(Flags, SCEV::FlagNUW);
8000 }
8001
8002 ConstantInt *X = ConstantInt::get(
8003 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8004 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8005 }
8006 break;
8007
8008 case Instruction::AShr:
8009 // AShr X, C, where C is a constant.
8010 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8011 if (!CI)
8012 break;
8013
8014 Type *OuterTy = BO->LHS->getType();
8015 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8016 // If the shift count is not less than the bitwidth, the result of
8017 // the shift is undefined. Don't try to analyze it, because the
8018 // resolution chosen here may differ from the resolution chosen in
8019 // other parts of the compiler.
8020 if (CI->getValue().uge(BitWidth))
8021 break;
8022
8023 if (CI->isZero())
8024 return getSCEV(BO->LHS); // shift by zero --> noop
8025
8026 uint64_t AShrAmt = CI->getZExtValue();
8027 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8028
8029 Operator *L = dyn_cast<Operator>(BO->LHS);
8030 const SCEV *AddTruncateExpr = nullptr;
8031 ConstantInt *ShlAmtCI = nullptr;
8032 const SCEV *AddConstant = nullptr;
8033
8034 if (L && L->getOpcode() == Instruction::Add) {
8035 // X = Shl A, n
8036 // Y = Add X, c
8037 // Z = AShr Y, m
8038 // n, c and m are constants.
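    // Illustrative sketch on i32 with n = m = 24 and c = 5 << 24:
    //   %x = shl i32 %a, 24 ; %y = add i32 %x, 83886080 ; %z = ashr i32 %y, 24
    // is modeled below as (sext i8 (5 + trunc i32 %a to i8) to i32).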
8039
8040 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8041 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8042 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8043 if (AddOperandCI) {
8044 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8045 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8046           // Since we truncate to TruncTy, the AddConstant should be of the
8047           // same type, so create a new Constant with the same type as TruncTy.
8048           // Also, the Add constant should be shifted right by the AShr amount.
8049 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8050 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8051           // We model the expression as sext(add(mul(trunc(A), 2^(n-m)), c >> m)).
8052           // Since the sext(trunc) part is already handled below, we create an
8053           // AddExpr(TruncExpr) which will be used later.
8054 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8055 }
8056 }
8057 } else if (L && L->getOpcode() == Instruction::Shl) {
8058 // X = Shl A, n
8059 // Y = AShr X, m
8060 // Both n and m are constant.
8061
8062 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8063 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8064 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8065 }
8066
8067 if (AddTruncateExpr && ShlAmtCI) {
8068 // We can merge the two given cases into a single SCEV statement,
8069       // in case n = m, the mul expression will be 2^0, so it gets resolved to
8070 // a simpler case. The following code handles the two cases:
8071 //
8072 // 1) For a two-shift sext-inreg, i.e. n = m,
8073 // use sext(trunc(x)) as the SCEV expression.
8074 //
8075 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8076 // expression. We already checked that ShlAmt < BitWidth, so
8077 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8078       // ShlAmt - AShrAmt < BitWidth - AShrAmt.
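      // E.g. on i32 with n = 5 and m = 3 (and no add), the result is
      // (sext i29 (4 * (trunc i32 %a to i29)) to i32); with n = m this
      // degenerates to the plain sext(trunc(%a)) form.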
8079 const APInt &ShlAmt = ShlAmtCI->getValue();
8080 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8081 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8082 ShlAmtCI->getZExtValue() - AShrAmt);
8083 const SCEV *CompositeExpr =
8084 getMulExpr(AddTruncateExpr, getConstant(Mul));
8085 if (L->getOpcode() != Instruction::Shl)
8086 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8087
8088 return getSignExtendExpr(CompositeExpr, OuterTy);
8089 }
8090 }
8091 break;
8092 }
8093 }
8094
8095 switch (U->getOpcode()) {
8096 case Instruction::Trunc:
8097 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8098
8099 case Instruction::ZExt:
8100 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8101
8102 case Instruction::SExt:
8103     if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8104                                 dyn_cast<Instruction>(V))) {
8105 // The NSW flag of a subtract does not always survive the conversion to
8106 // A + (-1)*B. By pushing sign extension onto its operands we are much
8107 // more likely to preserve NSW and allow later AddRec optimisations.
8108 //
8109 // NOTE: This is effectively duplicating this logic from getSignExtend:
8110 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8111 // but by that point the NSW information has potentially been lost.
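    // E.g. "%d = sub nsw i32 %a, %b; %s = sext i32 %d to i64" is modeled as
    // (sext %a) - (sext %b) with the subtraction marked <nsw>, instead of
    // sign-extending the A + (-1)*B form after the flag may have been lost.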
8112 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8113 Type *Ty = U->getType();
8114 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8115 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8116 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8117 }
8118 }
8119 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8120
8121 case Instruction::BitCast:
8122 // BitCasts are no-op casts so we just eliminate the cast.
8123 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8124 return getSCEV(U->getOperand(0));
8125 break;
8126
8127 case Instruction::PtrToInt: {
8128 // Pointer to integer cast is straight-forward, so do model it.
8129 const SCEV *Op = getSCEV(U->getOperand(0));
8130 Type *DstIntTy = U->getType();
8131 // But only if effective SCEV (integer) type is wide enough to represent
8132 // all possible pointer values.
8133 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8134 if (isa<SCEVCouldNotCompute>(IntOp))
8135 return getUnknown(V);
8136 return IntOp;
8137 }
8138 case Instruction::IntToPtr:
8139 // Just don't deal with inttoptr casts.
8140 return getUnknown(V);
8141
8142 case Instruction::SDiv:
8143 // If both operands are non-negative, this is just an udiv.
8144 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8145 isKnownNonNegative(getSCEV(U->getOperand(1))))
8146 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8147 break;
8148
8149 case Instruction::SRem:
8150 // If both operands are non-negative, this is just an urem.
8151 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8152 isKnownNonNegative(getSCEV(U->getOperand(1))))
8153 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8154 break;
8155
8156 case Instruction::GetElementPtr:
8157 return createNodeForGEP(cast<GEPOperator>(U));
8158
8159 case Instruction::PHI:
8160 return createNodeForPHI(cast<PHINode>(U));
8161
8162 case Instruction::Select:
8163 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8164 U->getOperand(2));
8165
8166 case Instruction::Call:
8167 case Instruction::Invoke:
8168 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8169 return getSCEV(RV);
8170
8171 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8172 switch (II->getIntrinsicID()) {
8173 case Intrinsic::abs:
8174 return getAbsExpr(
8175 getSCEV(II->getArgOperand(0)),
8176 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8177 case Intrinsic::umax:
8178 LHS = getSCEV(II->getArgOperand(0));
8179 RHS = getSCEV(II->getArgOperand(1));
8180 return getUMaxExpr(LHS, RHS);
8181 case Intrinsic::umin:
8182 LHS = getSCEV(II->getArgOperand(0));
8183 RHS = getSCEV(II->getArgOperand(1));
8184 return getUMinExpr(LHS, RHS);
8185 case Intrinsic::smax:
8186 LHS = getSCEV(II->getArgOperand(0));
8187 RHS = getSCEV(II->getArgOperand(1));
8188 return getSMaxExpr(LHS, RHS);
8189 case Intrinsic::smin:
8190 LHS = getSCEV(II->getArgOperand(0));
8191 RHS = getSCEV(II->getArgOperand(1));
8192 return getSMinExpr(LHS, RHS);
8193 case Intrinsic::usub_sat: {
8194 const SCEV *X = getSCEV(II->getArgOperand(0));
8195 const SCEV *Y = getSCEV(II->getArgOperand(1));
8196 const SCEV *ClampedY = getUMinExpr(X, Y);
8197 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8198 }
8199 case Intrinsic::uadd_sat: {
8200 const SCEV *X = getSCEV(II->getArgOperand(0));
8201 const SCEV *Y = getSCEV(II->getArgOperand(1));
8202 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8203 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8204 }
8205 case Intrinsic::start_loop_iterations:
8206 case Intrinsic::annotation:
8207 case Intrinsic::ptr_annotation:
8208       // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8209       // just equivalent to the first operand for SCEV purposes.
8210 return getSCEV(II->getArgOperand(0));
8211 case Intrinsic::vscale:
8212 return getVScale(II->getType());
8213 default:
8214 break;
8215 }
8216 }
8217 break;
8218 }
8219
8220 return getUnknown(V);
8221}
8222
8223//===----------------------------------------------------------------------===//
8224// Iteration Count Computation Code
8225//
8226
8227 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8228   if (isa<SCEVCouldNotCompute>(ExitCount))
8229 return getCouldNotCompute();
8230
8231 auto *ExitCountType = ExitCount->getType();
8232 assert(ExitCountType->isIntegerTy());
8233 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8234 1 + ExitCountType->getScalarSizeInBits());
8235 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8236}
8237
8238 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8239                                                        Type *EvalTy,
8240 const Loop *L) {
8241 if (isa<SCEVCouldNotCompute>(ExitCount))
8242 return getCouldNotCompute();
8243
8244 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8245 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8246
8247 auto CanAddOneWithoutOverflow = [&]() {
8248 ConstantRange ExitCountRange =
8249 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8250 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8251 return true;
8252
8253 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8254 getMinusOne(ExitCount->getType()));
8255 };
8256
8257 // If we need to zero extend the backedge count, check if we can add one to
8258 // it prior to zero extending without overflow. Provided this is safe, it
8259 // allows better simplification of the +1.
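  // For instance, if ExitCount is an i32 known (or guarded) to be != -1 and
  // EvalTy is i64, the trip count becomes (zext i32 (ExitCount + 1) to i64)
  // rather than (zext i32 ExitCount to i64) + 1, which folds better.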
8260 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8261 return getZeroExtendExpr(
8262 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8263
8264 // Get the total trip count from the count by adding 1. This may wrap.
8265 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8266}
8267
8268static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8269 if (!ExitCount)
8270 return 0;
8271
8272 ConstantInt *ExitConst = ExitCount->getValue();
8273
8274 // Guard against huge trip counts.
8275 if (ExitConst->getValue().getActiveBits() > 32)
8276 return 0;
8277
8278 // In case of integer overflow, this returns 0, which is correct.
8279 return ((unsigned)ExitConst->getZExtValue()) + 1;
8280}
8281
8282 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8283   auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8284 return getConstantTripCount(ExitCount);
8285}
8286
8287unsigned
8288 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8289                                            const BasicBlock *ExitingBlock) {
8290 assert(ExitingBlock && "Must pass a non-null exiting block!");
8291 assert(L->isLoopExiting(ExitingBlock) &&
8292 "Exiting block must actually branch out of the loop!");
8293 const SCEVConstant *ExitCount =
8294 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8295 return getConstantTripCount(ExitCount);
8296}
8297
8298 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8299     const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8300
8301 const auto *MaxExitCount =
8302       Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8303                  : getConstantMaxBackedgeTakenCount(L);
8304 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8305}
8306
8307 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8308   SmallVector<BasicBlock *, 8> ExitingBlocks;
8309 L->getExitingBlocks(ExitingBlocks);
8310
8311 std::optional<unsigned> Res;
8312 for (auto *ExitingBB : ExitingBlocks) {
8313 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8314 if (!Res)
8315 Res = Multiple;
8316 Res = std::gcd(*Res, Multiple);
8317 }
8318 return Res.value_or(1);
8319}
8320
8321 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8322                                                         const SCEV *ExitCount) {
8323 if (isa<SCEVCouldNotCompute>(ExitCount))
8324 return 1;
8325
8326 // Get the trip count
8327 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8328
8329 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8330 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8331 // the greatest power of 2 divisor less than 2^32.
8332 return Multiple.getActiveBits() > 32
8333 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8334 : (unsigned)Multiple.getZExtValue();
8335}
8336
8337/// Returns the largest constant divisor of the trip count of this loop as a
8338/// normal unsigned value, if possible. This means that the actual trip count is
8339/// always a multiple of the returned value (don't forget the trip count could
8340/// very well be zero as well!).
8341///
8342 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8343 /// multiple of a constant (which is also the case if the trip count is simply
8344 /// constant; use getSmallConstantTripCount for that case). It will also return
8345 /// 1 if the trip count is very large (>= 2^32).
8346///
8347/// As explained in the comments for getSmallConstantTripCount, this assumes
8348/// that control exits the loop via ExitingBlock.
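///
/// For example, a loop whose backedge-taken count is (2 * %n + 1) has a trip
/// count of (2 * %n + 2), so this returns 2 even though the exact trip count
/// is unknown.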
8349unsigned
8350 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8351                                               const BasicBlock *ExitingBlock) {
8352 assert(ExitingBlock && "Must pass a non-null exiting block!");
8353 assert(L->isLoopExiting(ExitingBlock) &&
8354 "Exiting block must actually branch out of the loop!");
8355 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8356 return getSmallConstantTripMultiple(L, ExitCount);
8357}
8358
8359 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8360                                           const BasicBlock *ExitingBlock,
8361 ExitCountKind Kind) {
8362 switch (Kind) {
8363 case Exact:
8364 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8365 case SymbolicMaximum:
8366 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8367 case ConstantMaximum:
8368 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8369 };
8370 llvm_unreachable("Invalid ExitCountKind!");
8371}
8372
8373 const SCEV *ScalarEvolution::getPredicatedExitCount(
8374     const Loop *L, const BasicBlock *ExitingBlock,
8375     SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8376   switch (Kind) {
8377 case Exact:
8378 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8379 Predicates);
8380 case SymbolicMaximum:
8381 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8382 Predicates);
8383 case ConstantMaximum:
8384 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8385 Predicates);
8386 };
8387 llvm_unreachable("Invalid ExitCountKind!");
8388}
8389
8390 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8391     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8392   return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8393}
8394
8395 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8396                                                     ExitCountKind Kind) {
8397 switch (Kind) {
8398 case Exact:
8399 return getBackedgeTakenInfo(L).getExact(L, this);
8400 case ConstantMaximum:
8401 return getBackedgeTakenInfo(L).getConstantMax(this);
8402 case SymbolicMaximum:
8403 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8404 };
8405 llvm_unreachable("Invalid ExitCountKind!");
8406}
8407
8410 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8411}
8412
8415 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8416}
8417
8418 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8419   return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8420}
8421
8422/// Push PHI nodes in the header of the given loop onto the given Worklist.
8423static void PushLoopPHIs(const Loop *L,
8424                          SmallVectorImpl<Instruction *> &Worklist,
8425                          SmallPtrSetImpl<Instruction *> &Visited) {
8426   BasicBlock *Header = L->getHeader();
8427
8428 // Push all Loop-header PHIs onto the Worklist stack.
8429 for (PHINode &PN : Header->phis())
8430 if (Visited.insert(&PN).second)
8431 Worklist.push_back(&PN);
8432}
8433
8434ScalarEvolution::BackedgeTakenInfo &
8435ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8436 auto &BTI = getBackedgeTakenInfo(L);
8437 if (BTI.hasFullInfo())
8438 return BTI;
8439
8440 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8441
8442 if (!Pair.second)
8443 return Pair.first->second;
8444
8445 BackedgeTakenInfo Result =
8446 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8447
8448 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8449}
8450
8451ScalarEvolution::BackedgeTakenInfo &
8452ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8453 // Initially insert an invalid entry for this loop. If the insertion
8454 // succeeds, proceed to actually compute a backedge-taken count and
8455 // update the value. The temporary CouldNotCompute value tells SCEV
8456 // code elsewhere that it shouldn't attempt to request a new
8457 // backedge-taken count, which could result in infinite recursion.
8458 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8459 BackedgeTakenCounts.try_emplace(L);
8460 if (!Pair.second)
8461 return Pair.first->second;
8462
8463 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8464 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8465 // must be cleared in this scope.
8466 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8467
8468 // Now that we know more about the trip count for this loop, forget any
8469 // existing SCEV values for PHI nodes in this loop since they are only
8470 // conservative estimates made without the benefit of trip count
8471 // information. This invalidation is not necessary for correctness, and is
8472 // only done to produce more precise results.
8473 if (Result.hasAnyInfo()) {
8474 // Invalidate any expression using an addrec in this loop.
8475     SmallVector<const SCEV *, 8> ToForget;
8476     auto LoopUsersIt = LoopUsers.find(L);
8477 if (LoopUsersIt != LoopUsers.end())
8478 append_range(ToForget, LoopUsersIt->second);
8479 forgetMemoizedResults(ToForget);
8480
8481 // Invalidate constant-evolved loop header phis.
8482 for (PHINode &PN : L->getHeader()->phis())
8483 ConstantEvolutionLoopExitValue.erase(&PN);
8484 }
8485
8486 // Re-lookup the insert position, since the call to
8487 // computeBackedgeTakenCount above could result in a
8488   // recursive call to getBackedgeTakenInfo (on a different
8489 // loop), which would invalidate the iterator computed
8490 // earlier.
8491 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8492}
8493
8494 void ScalarEvolution::forgetAllLoops() {
8495   // This method is intended to forget all info about loops. It should
8496 // invalidate caches as if the following happened:
8497 // - The trip counts of all loops have changed arbitrarily
8498 // - Every llvm::Value has been updated in place to produce a different
8499 // result.
8500 BackedgeTakenCounts.clear();
8501 PredicatedBackedgeTakenCounts.clear();
8502 BECountUsers.clear();
8503 LoopPropertiesCache.clear();
8504 ConstantEvolutionLoopExitValue.clear();
8505 ValueExprMap.clear();
8506 ValuesAtScopes.clear();
8507 ValuesAtScopesUsers.clear();
8508 LoopDispositions.clear();
8509 BlockDispositions.clear();
8510 UnsignedRanges.clear();
8511 SignedRanges.clear();
8512 ExprValueMap.clear();
8513 HasRecMap.clear();
8514 ConstantMultipleCache.clear();
8515 PredicatedSCEVRewrites.clear();
8516 FoldCache.clear();
8517 FoldCacheUser.clear();
8518}
8519void ScalarEvolution::visitAndClearUsers(
8520     SmallVectorImpl<Instruction *> &Worklist,
8521     SmallPtrSetImpl<Instruction *> &Visited,
8522     SmallVectorImpl<const SCEV *> &ToForget) {
8523   while (!Worklist.empty()) {
8524 Instruction *I = Worklist.pop_back_val();
8525 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8526 continue;
8527
8528     ValueExprMapType::iterator It =
8529         ValueExprMap.find_as(static_cast<Value *>(I));
8530 if (It != ValueExprMap.end()) {
8531 eraseValueFromMap(It->first);
8532 ToForget.push_back(It->second);
8533 if (PHINode *PN = dyn_cast<PHINode>(I))
8534 ConstantEvolutionLoopExitValue.erase(PN);
8535 }
8536
8537 PushDefUseChildren(I, Worklist, Visited);
8538 }
8539}
8540
8541 void ScalarEvolution::forgetLoop(const Loop *L) {
8542   SmallVector<const Loop *, 16> LoopWorklist(1, L);
8543   SmallVector<Instruction *, 32> Worklist;
8544   SmallPtrSet<Instruction *, 16> Visited;
8545   SmallVector<const SCEV *, 16> ToForget;
8546
8547 // Iterate over all the loops and sub-loops to drop SCEV information.
8548 while (!LoopWorklist.empty()) {
8549 auto *CurrL = LoopWorklist.pop_back_val();
8550
8551 // Drop any stored trip count value.
8552 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8553 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8554
8555 // Drop information about predicated SCEV rewrites for this loop.
8556 for (auto I = PredicatedSCEVRewrites.begin();
8557 I != PredicatedSCEVRewrites.end();) {
8558 std::pair<const SCEV *, const Loop *> Entry = I->first;
8559 if (Entry.second == CurrL)
8560 PredicatedSCEVRewrites.erase(I++);
8561 else
8562 ++I;
8563 }
8564
8565 auto LoopUsersItr = LoopUsers.find(CurrL);
8566 if (LoopUsersItr != LoopUsers.end())
8567 llvm::append_range(ToForget, LoopUsersItr->second);
8568
8569 // Drop information about expressions based on loop-header PHIs.
8570 PushLoopPHIs(CurrL, Worklist, Visited);
8571 visitAndClearUsers(Worklist, Visited, ToForget);
8572
8573 LoopPropertiesCache.erase(CurrL);
8574 // Forget all contained loops too, to avoid dangling entries in the
8575 // ValuesAtScopes map.
8576 LoopWorklist.append(CurrL->begin(), CurrL->end());
8577 }
8578 forgetMemoizedResults(ToForget);
8579}
8580
8581 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8582   forgetLoop(L->getOutermostLoop());
8583}
8584
8585 void ScalarEvolution::forgetValue(Value *V) {
8586   Instruction *I = dyn_cast<Instruction>(V);
8587   if (!I) return;
8588
8589 // Drop information about expressions based on loop-header PHIs.
8590   SmallVector<Instruction *, 16> Worklist;
8591   SmallPtrSet<Instruction *, 8> Visited;
8592   SmallVector<const SCEV *, 8> ToForget;
8593   Worklist.push_back(I);
8594 Visited.insert(I);
8595 visitAndClearUsers(Worklist, Visited, ToForget);
8596
8597 forgetMemoizedResults(ToForget);
8598}
8599
8601 if (!isSCEVable(V->getType()))
8602 return;
8603
8604 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8605 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8606 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8607 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8608 if (const SCEV *S = getExistingSCEV(V)) {
8609 struct InvalidationRootCollector {
8610       Loop *L;
8611       SmallVector<const SCEV *, 8> Roots;
8612
8613 InvalidationRootCollector(Loop *L) : L(L) {}
8614
8615 bool follow(const SCEV *S) {
8616 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8617 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8618 if (L->contains(I))
8619 Roots.push_back(S);
8620 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8621 if (L->contains(AddRec->getLoop()))
8622 Roots.push_back(S);
8623 }
8624 return true;
8625 }
8626 bool isDone() const { return false; }
8627 };
8628
8629 InvalidationRootCollector C(L);
8630 visitAll(S, C);
8631 forgetMemoizedResults(C.Roots);
8632 }
8633
8634 // Also perform the normal invalidation.
8635 forgetValue(V);
8636}
8637
8638void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8639
8640 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8641   // Unless a specific value is passed to invalidation, completely clear both
8642 // caches.
8643 if (!V) {
8644 BlockDispositions.clear();
8645 LoopDispositions.clear();
8646 return;
8647 }
8648
8649 if (!isSCEVable(V->getType()))
8650 return;
8651
8652 const SCEV *S = getExistingSCEV(V);
8653 if (!S)
8654 return;
8655
8656 // Invalidate the block and loop dispositions cached for S. Dispositions of
8657 // S's users may change if S's disposition changes (i.e. a user may change to
8658 // loop-invariant, if S changes to loop invariant), so also invalidate
8659 // dispositions of S's users recursively.
8660   SmallVector<const SCEV *, 8> Worklist = {S};
8661   SmallPtrSet<const SCEV *, 8> Seen = {S};
8662 while (!Worklist.empty()) {
8663 const SCEV *Curr = Worklist.pop_back_val();
8664 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8665 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8666 if (!LoopDispoRemoved && !BlockDispoRemoved)
8667 continue;
8668 auto Users = SCEVUsers.find(Curr);
8669 if (Users != SCEVUsers.end())
8670 for (const auto *User : Users->second)
8671 if (Seen.insert(User).second)
8672 Worklist.push_back(User);
8673 }
8674}
8675
8676/// Get the exact loop backedge taken count considering all loop exits. A
8677/// computable result can only be returned for loops with all exiting blocks
8678/// dominating the latch. howFarToZero assumes that the limit of each loop test
8679/// is never skipped. This is a valid assumption as long as the loop exits via
8680/// that test. For precise results, it is the caller's responsibility to specify
8681/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8682const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8683 const Loop *L, ScalarEvolution *SE,
8685 // If any exits were not computable, the loop is not computable.
8686 if (!isComplete() || ExitNotTaken.empty())
8687 return SE->getCouldNotCompute();
8688
8689 const BasicBlock *Latch = L->getLoopLatch();
8690 // All exiting blocks we have collected must dominate the only backedge.
8691 if (!Latch)
8692 return SE->getCouldNotCompute();
8693
8694 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8695 // count is simply a minimum out of all these calculated exit counts.
8696   SmallVector<const SCEV *, 2> Ops;
8697   for (const auto &ENT : ExitNotTaken) {
8698 const SCEV *BECount = ENT.ExactNotTaken;
8699 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8700 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8701 "We should only have known counts for exiting blocks that dominate "
8702 "latch!");
8703
8704 Ops.push_back(BECount);
8705
8706 if (Preds)
8707 append_range(*Preds, ENT.Predicates);
8708
8709 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8710 "Predicate should be always true!");
8711 }
8712
8713 // If an earlier exit exits on the first iteration (exit count zero), then
8714   // a later poison exit count should not propagate into the result. These are
8715 // exactly the semantics provided by umin_seq.
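  // For example, if one exit count is 0 and a later one is poison,
  // umin_seq(0, poison) is still 0, whereas a plain umin would propagate
  // the poison.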
8716 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8717}
8718
8719const ScalarEvolution::ExitNotTakenInfo *
8720ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8721 const BasicBlock *ExitingBlock,
8722 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8723 for (const auto &ENT : ExitNotTaken)
8724 if (ENT.ExitingBlock == ExitingBlock) {
8725 if (ENT.hasAlwaysTruePredicate())
8726 return &ENT;
8727 else if (Predicates) {
8728 append_range(*Predicates, ENT.Predicates);
8729 return &ENT;
8730 }
8731 }
8732
8733 return nullptr;
8734}
8735
8736/// getConstantMax - Get the constant max backedge taken count for the loop.
8737const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8738 ScalarEvolution *SE,
8739 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8740 if (!getConstantMax())
8741 return SE->getCouldNotCompute();
8742
8743 for (const auto &ENT : ExitNotTaken)
8744 if (!ENT.hasAlwaysTruePredicate()) {
8745 if (!Predicates)
8746 return SE->getCouldNotCompute();
8747 append_range(*Predicates, ENT.Predicates);
8748 }
8749
8750 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8751 isa<SCEVConstant>(getConstantMax())) &&
8752 "No point in having a non-constant max backedge taken count!");
8753 return getConstantMax();
8754}
8755
8756const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8757 const Loop *L, ScalarEvolution *SE,
8758 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8759 if (!SymbolicMax) {
8760 // Form an expression for the maximum exit count possible for this loop. We
8761 // merge the max and exact information to approximate a version of
8762 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8763     // constants.
8764     SmallVector<const SCEV *, 4> ExitCounts;
8765
8766 for (const auto &ENT : ExitNotTaken) {
8767 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8768 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8769 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8770 "We should only have known counts for exiting blocks that "
8771 "dominate latch!");
8772 ExitCounts.push_back(ExitCount);
8773 if (Predicates)
8774 append_range(*Predicates, ENT.Predicates);
8775
8776 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8777 "Predicate should be always true!");
8778 }
8779 }
8780 if (ExitCounts.empty())
8781 SymbolicMax = SE->getCouldNotCompute();
8782 else
8783 SymbolicMax =
8784 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8785 }
8786 return SymbolicMax;
8787}
8788
8789bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8790 ScalarEvolution *SE) const {
8791 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8792 return !ENT.hasAlwaysTruePredicate();
8793 };
8794 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8795}
8796
8799
8801 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8802 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8806 // If we prove the max count is zero, so is the symbolic bound. This happens
8807 // in practice due to differences in a) how context sensitive we've chosen
8808 // to be and b) how we reason about bounds implied by UB.
8809 if (ConstantMaxNotTaken->isZero()) {
8810 this->ExactNotTaken = E = ConstantMaxNotTaken;
8811 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8812 }
8813
8816 "Exact is not allowed to be less precise than Constant Max");
8819 "Exact is not allowed to be less precise than Symbolic Max");
8822 "Symbolic Max is not allowed to be less precise than Constant Max");
8825 "No point in having a non-constant max backedge taken count!");
8826   SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8827   for (const auto PredList : PredLists)
8828 for (const auto *P : PredList) {
8829 if (SeenPreds.contains(P))
8830 continue;
8831 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8832 SeenPreds.insert(P);
8833 Predicates.push_back(P);
8834 }
8835 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8836 "Backedge count should be int");
8837   assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8838           !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8839 "Max backedge count should be int");
8840}
8841
8849
8850/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8851/// computable exit into a persistent ExitNotTakenInfo array.
8852ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8854 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8855 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8856 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8857
8858 ExitNotTaken.reserve(ExitCounts.size());
8859 std::transform(ExitCounts.begin(), ExitCounts.end(),
8860 std::back_inserter(ExitNotTaken),
8861 [&](const EdgeExitInfo &EEI) {
8862 BasicBlock *ExitBB = EEI.first;
8863 const ExitLimit &EL = EEI.second;
8864 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8865 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8866 EL.Predicates);
8867 });
8868 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8869 isa<SCEVConstant>(ConstantMax)) &&
8870 "No point in having a non-constant max backedge taken count!");
8871}
8872
8873/// Compute the number of times the backedge of the specified loop will execute.
8874ScalarEvolution::BackedgeTakenInfo
8875ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8876 bool AllowPredicates) {
8877 SmallVector<BasicBlock *, 8> ExitingBlocks;
8878 L->getExitingBlocks(ExitingBlocks);
8879
8880 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8881
8882   SmallVector<EdgeExitInfo, 4> ExitCounts;
8883   bool CouldComputeBECount = true;
8884 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8885 const SCEV *MustExitMaxBECount = nullptr;
8886 const SCEV *MayExitMaxBECount = nullptr;
8887 bool MustExitMaxOrZero = false;
8888 bool IsOnlyExit = ExitingBlocks.size() == 1;
8889
8890 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8891 // and compute maxBECount.
8892 // Do a union of all the predicates here.
8893 for (BasicBlock *ExitBB : ExitingBlocks) {
8894 // We canonicalize untaken exits to br (constant), ignore them so that
8895 // proving an exit untaken doesn't negatively impact our ability to reason
8896     // about the loop as a whole.
8897 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8898 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8899 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8900 if (ExitIfTrue == CI->isZero())
8901 continue;
8902 }
8903
8904 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8905
8906 assert((AllowPredicates || EL.Predicates.empty()) &&
8907 "Predicated exit limit when predicates are not allowed!");
8908
8909 // 1. For each exit that can be computed, add an entry to ExitCounts.
8910 // CouldComputeBECount is true only if all exits can be computed.
8911 if (EL.ExactNotTaken != getCouldNotCompute())
8912 ++NumExitCountsComputed;
8913 else
8914 // We couldn't compute an exact value for this exit, so
8915 // we won't be able to compute an exact value for the loop.
8916 CouldComputeBECount = false;
8917 // Remember exit count if either exact or symbolic is known. Because
8918 // Exact always implies symbolic, only check symbolic.
8919 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8920 ExitCounts.emplace_back(ExitBB, EL);
8921 else {
8922 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8923 "Exact is known but symbolic isn't?");
8924 ++NumExitCountsNotComputed;
8925 }
8926
8927 // 2. Derive the loop's MaxBECount from each exit's max number of
8928 // non-exiting iterations. Partition the loop exits into two kinds:
8929 // LoopMustExits and LoopMayExits.
8930 //
8931 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8932 // is a LoopMayExit. If any computable LoopMustExit is found, then
8933 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8934 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8935   //   EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater
8936   //   than any computable EL.ConstantMaxNotTaken.
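  //   Illustrative sketch: with a latch-dominating exit bounded by 10 and a
  //   non-dominating early exit bounded by 100, MaxBECount is 10; if no
  //   must-exit bound is computable, the conservative 100 from the may-exit
  //   is used instead.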
8938 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8939 DT.dominates(ExitBB, Latch)) {
8940 if (!MustExitMaxBECount) {
8941 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8942 MustExitMaxOrZero = EL.MaxOrZero;
8943 } else {
8944 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8945 EL.ConstantMaxNotTaken);
8946 }
8947 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8948 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8949 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8950 else {
8951 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8952 EL.ConstantMaxNotTaken);
8953 }
8954 }
8955 }
8956 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8957 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8958 // The loop backedge will be taken the maximum or zero times if there's
8959 // a single exit that must be taken the maximum or zero times.
8960 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8961
8962 // Remember which SCEVs are used in exit limits for invalidation purposes.
8963 // We only care about non-constant SCEVs here, so we can ignore
8964 // EL.ConstantMaxNotTaken
8965 // and MaxBECount, which must be SCEVConstant.
8966 for (const auto &Pair : ExitCounts) {
8967 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8968 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8969 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8970 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8971 {L, AllowPredicates});
8972 }
8973 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8974 MaxBECount, MaxOrZero);
8975}
8976
8977ScalarEvolution::ExitLimit
8978ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8979 bool IsOnlyExit, bool AllowPredicates) {
8980 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8981 // If our exiting block does not dominate the latch, then its connection with
8982 // loop's exit limit may be far from trivial.
8983 const BasicBlock *Latch = L->getLoopLatch();
8984 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8985 return getCouldNotCompute();
8986
8987 Instruction *Term = ExitingBlock->getTerminator();
8988 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8989 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8990 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8991 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8992 "It should have one successor in loop and one exit block!");
8993 // Proceed to the next level to examine the exit condition expression.
8994 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8995 /*ControlsOnlyExit=*/IsOnlyExit,
8996 AllowPredicates);
8997 }
8998
8999 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9000 // For switch, make sure that there is a single exit from the loop.
9001 BasicBlock *Exit = nullptr;
9002 for (auto *SBB : successors(ExitingBlock))
9003 if (!L->contains(SBB)) {
9004 if (Exit) // Multiple exit successors.
9005 return getCouldNotCompute();
9006 Exit = SBB;
9007 }
9008 assert(Exit && "Exiting block must have at least one exit");
9009 return computeExitLimitFromSingleExitSwitch(
9010 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9011 }
9012
9013 return getCouldNotCompute();
9014}
9015
9016 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9017     const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9018 bool AllowPredicates) {
9019 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9020 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9021 ControlsOnlyExit, AllowPredicates);
9022}
9023
9024std::optional<ScalarEvolution::ExitLimit>
9025ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9026 bool ExitIfTrue, bool ControlsOnlyExit,
9027 bool AllowPredicates) {
9028 (void)this->L;
9029 (void)this->ExitIfTrue;
9030 (void)this->AllowPredicates;
9031
9032 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9033 this->AllowPredicates == AllowPredicates &&
9034 "Variance in assumed invariant key components!");
9035 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9036 if (Itr == TripCountMap.end())
9037 return std::nullopt;
9038 return Itr->second;
9039}
9040
9041void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9042 bool ExitIfTrue,
9043 bool ControlsOnlyExit,
9044 bool AllowPredicates,
9045 const ExitLimit &EL) {
9046 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9047 this->AllowPredicates == AllowPredicates &&
9048 "Variance in assumed invariant key components!");
9049
9050 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9051 assert(InsertResult.second && "Expected successful insertion!");
9052 (void)InsertResult;
9053 (void)ExitIfTrue;
9054}
9055
9056ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9057 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9058 bool ControlsOnlyExit, bool AllowPredicates) {
9059
9060 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9061 AllowPredicates))
9062 return *MaybeEL;
9063
9064 ExitLimit EL = computeExitLimitFromCondImpl(
9065 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9066 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9067 return EL;
9068}
9069
9070ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9071 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9072 bool ControlsOnlyExit, bool AllowPredicates) {
9073 // Handle BinOp conditions (And, Or).
9074 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9075 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9076 return *LimitFromBinOp;
9077
9078 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9079 // Proceed to the next level to examine the icmp.
9080 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9081 ExitLimit EL =
9082 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9083 if (EL.hasFullInfo() || !AllowPredicates)
9084 return EL;
9085
9086 // Try again, but use SCEV predicates this time.
9087 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9088 ControlsOnlyExit,
9089 /*AllowPredicates=*/true);
9090 }
9091
9092 // Check for a constant condition. These are normally stripped out by
9093 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9094 // preserve the CFG and is temporarily leaving constant conditions
9095 // in place.
9096 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9097 if (ExitIfTrue == !CI->getZExtValue())
9098 // The backedge is always taken.
9099 return getCouldNotCompute();
9100 // The backedge is never taken.
9101 return getZero(CI->getType());
9102 }
9103
9104 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9105 // with a constant step, we can form an equivalent icmp predicate and figure
9106 // out how many iterations will be taken before we exit.
9107 const WithOverflowInst *WO;
9108 const APInt *C;
9109 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9110 match(WO->getRHS(), m_APInt(C))) {
9111 ConstantRange NWR =
9112         ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9113                                              WO->getNoWrapKind());
9114 CmpInst::Predicate Pred;
9115 APInt NewRHSC, Offset;
9116 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9117 if (!ExitIfTrue)
9118 Pred = ICmpInst::getInversePredicate(Pred);
9119 auto *LHS = getSCEV(WO->getLHS());
9120     if (Offset != 0)
9121       LHS = getAddExpr(LHS, getConstant(Offset));
9122 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9123 ControlsOnlyExit, AllowPredicates);
9124 if (EL.hasAnyInfo())
9125 return EL;
9126 }
9127
9128 // If it's not an integer or pointer comparison then compute it the hard way.
9129 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9130}
9131
9132std::optional<ScalarEvolution::ExitLimit>
9133ScalarEvolution::computeExitLimitFromCondFromBinOp(
9134 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9135 bool ControlsOnlyExit, bool AllowPredicates) {
9136 // Check if the controlling expression for this loop is an And or Or.
9137 Value *Op0, *Op1;
9138 bool IsAnd = false;
9139 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9140 IsAnd = true;
9141 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9142 IsAnd = false;
9143 else
9144 return std::nullopt;
9145
9146 // EitherMayExit is true in these two cases:
9147 // br (and Op0 Op1), loop, exit
9148 // br (or Op0 Op1), exit, loop
9149 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9150 ExitLimit EL0 = computeExitLimitFromCondCached(
9151 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9152 AllowPredicates);
9153 ExitLimit EL1 = computeExitLimitFromCondCached(
9154 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9155 AllowPredicates);
9156
9157 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9158 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9159 if (isa<ConstantInt>(Op1))
9160 return Op1 == NeutralElement ? EL0 : EL1;
9161 if (isa<ConstantInt>(Op0))
9162 return Op0 == NeutralElement ? EL1 : EL0;
9163
9164 const SCEV *BECount = getCouldNotCompute();
9165 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9166 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9167 if (EitherMayExit) {
9168 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9169     // Both conditions must be the same for the loop to continue executing.
9170 // Choose the less conservative count.
9171 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9172 EL1.ExactNotTaken != getCouldNotCompute()) {
9173 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9174 UseSequentialUMin);
9175 }
9176 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9177 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9178 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9179 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9180 else
9181 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9182 EL1.ConstantMaxNotTaken);
9183 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9184 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9185 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9186 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9187 else
9188 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9189 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9190 } else {
9191     // Both conditions must be the same at the same time for the loop to exit.
9192 // For now, be conservative.
9193 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9194 BECount = EL0.ExactNotTaken;
9195 }
9196
9197 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9198 // to be more aggressive when computing BECount than when computing
9199   // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9200   // and EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9201   // EL1.ConstantMaxNotTaken to not.
9203 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9204 !isa<SCEVCouldNotCompute>(BECount))
9205 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9206 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9207 SymbolicMaxBECount =
9208 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9209 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9210 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9211}
9212
9213ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9214 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9215 bool AllowPredicates) {
9216 // If the condition was exit on true, convert the condition to exit on false
9217 CmpPredicate Pred;
9218 if (!ExitIfTrue)
9219 Pred = ExitCond->getCmpPredicate();
9220 else
9221 Pred = ExitCond->getInverseCmpPredicate();
9222 const ICmpInst::Predicate OriginalPred = Pred;
9223
9224 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9225 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9226
9227 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9228 AllowPredicates);
9229 if (EL.hasAnyInfo())
9230 return EL;
9231
9232 auto *ExhaustiveCount =
9233 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9234
9235 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9236 return ExhaustiveCount;
9237
9238 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9239 ExitCond->getOperand(1), L, OriginalPred);
9240}
9241ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9242 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9243 bool ControlsOnlyExit, bool AllowPredicates) {
9244
9245 // Try to evaluate any dependencies out of the loop.
9246 LHS = getSCEVAtScope(LHS, L);
9247 RHS = getSCEVAtScope(RHS, L);
9248
9249 // At this point, we would like to compute how many iterations of the
9250 // loop the predicate will return true for these inputs.
9251 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9252 // If there is a loop-invariant, force it into the RHS.
9253     std::swap(LHS, RHS);
9254     Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9255 }
9256
9257   bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9258                                loopIsFiniteByAssumption(L);
9259 // Simplify the operands before analyzing them.
9260 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9261
9262 // If we have a comparison of a chrec against a constant, try to use value
9263 // ranges to answer this query.
9264 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9265 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9266 if (AddRec->getLoop() == L) {
9267 // Form the constant range.
9268 ConstantRange CompRange =
9269 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9270
9271 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9272 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9273 }
9274
9275 // If this loop must exit based on this condition (or execute undefined
9276 // behaviour), see if we can improve wrap flags. This is essentially
9277 // a must execute style proof.
9278 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9279 // If we can prove the test sequence produced must repeat the same values
9280 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9281 // because if it did, we'd have an infinite (undefined) loop.
9282 // TODO: We can peel off any functions which are invertible *in L*. Loop
9283 // invariant terms are effectively constants for our purposes here.
9284 auto *InnerLHS = LHS;
9285 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9286 InnerLHS = ZExt->getOperand();
9287 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9288 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9289 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9290 /*OrNegative=*/true)) {
9291 auto Flags = AR->getNoWrapFlags();
9292 Flags = setFlags(Flags, SCEV::FlagNW);
9293 SmallVector<const SCEV *> Operands{AR->operands()};
9294 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9295 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9296 }
9297
9298 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9299 // From no-self-wrap, this follows trivially from the fact that every
9300     // (un)signed-wrapped, but not self-wrapped value must be less than the
9301 // last value before (un)signed wrap. Since we know that last value
9302 // didn't exit, nor will any smaller one.
9303 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9304 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9305 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9306 AR && AR->getLoop() == L && AR->isAffine() &&
9307 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9308 isKnownPositive(AR->getStepRecurrence(*this))) {
9309 auto Flags = AR->getNoWrapFlags();
9310 Flags = setFlags(Flags, WrapType);
9311 SmallVector<const SCEV*> Operands{AR->operands()};
9312 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9313 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9314 }
9315 }
9316 }
9317
9318 switch (Pred) {
9319 case ICmpInst::ICMP_NE: { // while (X != Y)
9320 // Convert to: while (X-Y != 0)
9321     if (LHS->getType()->isPointerTy()) {
9322       LHS = getLosslessPtrToIntExpr(LHS);
9323       if (isa<SCEVCouldNotCompute>(LHS))
9324 return LHS;
9325 }
9326     if (RHS->getType()->isPointerTy()) {
9327       RHS = getLosslessPtrToIntExpr(RHS);
9328       if (isa<SCEVCouldNotCompute>(RHS))
9329 return RHS;
9330 }
9331 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9332 AllowPredicates);
9333 if (EL.hasAnyInfo())
9334 return EL;
9335 break;
9336 }
9337 case ICmpInst::ICMP_EQ: { // while (X == Y)
9338 // Convert to: while (X-Y == 0)
9339     if (LHS->getType()->isPointerTy()) {
9340       LHS = getLosslessPtrToIntExpr(LHS);
9341       if (isa<SCEVCouldNotCompute>(LHS))
9342 return LHS;
9343 }
9344     if (RHS->getType()->isPointerTy()) {
9345       RHS = getLosslessPtrToIntExpr(RHS);
9346       if (isa<SCEVCouldNotCompute>(RHS))
9347 return RHS;
9348 }
9349 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9350 if (EL.hasAnyInfo()) return EL;
9351 break;
9352 }
9353 case ICmpInst::ICMP_SLE:
9354 case ICmpInst::ICMP_ULE:
9355 // Since the loop is finite, an invariant RHS cannot include the boundary
9356 // value, otherwise it would loop forever.
9357 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9358 !isLoopInvariant(RHS, L)) {
9359 // Otherwise, perform the addition in a wider type, to avoid overflow.
9360 // If the LHS is an addrec with the appropriate nowrap flag, the
9361 // extension will be sunk into it and the exit count can be analyzed.
9362 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9363 if (!OldType)
9364 break;
9365 // Prefer doubling the bitwidth over adding a single bit to make it more
9366 // likely that we use a legal type.
9367 auto *NewType =
9368 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9369 if (ICmpInst::isSigned(Pred)) {
9370 LHS = getSignExtendExpr(LHS, NewType);
9371 RHS = getSignExtendExpr(RHS, NewType);
9372 } else {
9373 LHS = getZeroExtendExpr(LHS, NewType);
9374 RHS = getZeroExtendExpr(RHS, NewType);
9375 }
9376 }
9377     RHS = getAddExpr(getOne(RHS->getType()), RHS);
9378     [[fallthrough]];
9379 case ICmpInst::ICMP_SLT:
9380 case ICmpInst::ICMP_ULT: { // while (X < Y)
9381 bool IsSigned = ICmpInst::isSigned(Pred);
9382 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9383 AllowPredicates);
9384 if (EL.hasAnyInfo())
9385 return EL;
9386 break;
9387 }
9388 case ICmpInst::ICMP_SGE:
9389 case ICmpInst::ICMP_UGE:
9390 // Since the loop is finite, an invariant RHS cannot include the boundary
9391 // value, otherwise it would loop forever.
9392 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9393 !isLoopInvariant(RHS, L))
9394 break;
9395     RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9396     [[fallthrough]];
9397 case ICmpInst::ICMP_SGT:
9398 case ICmpInst::ICMP_UGT: { // while (X > Y)
9399 bool IsSigned = ICmpInst::isSigned(Pred);
9400 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9401 AllowPredicates);
9402 if (EL.hasAnyInfo())
9403 return EL;
9404 break;
9405 }
9406 default:
9407 break;
9408 }
9409
9410 return getCouldNotCompute();
9411}
9412
9413ScalarEvolution::ExitLimit
9414ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9415 SwitchInst *Switch,
9416 BasicBlock *ExitingBlock,
9417 bool ControlsOnlyExit) {
9418 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9419
9420 // Give up if the exit is the default dest of a switch.
9421 if (Switch->getDefaultDest() == ExitingBlock)
9422 return getCouldNotCompute();
9423
9424 assert(L->contains(Switch->getDefaultDest()) &&
9425 "Default case must not exit the loop!");
9426 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9427 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9428
9429 // while (X != Y) --> while (X-Y != 0)
9430 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9431 if (EL.hasAnyInfo())
9432 return EL;
9433
9434 return getCouldNotCompute();
9435}
9436
9437static ConstantInt *
9438 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9439                                 ScalarEvolution &SE) {
9440 const SCEV *InVal = SE.getConstant(C);
9441 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9442   assert(isa<SCEVConstant>(Val) &&
9443          "Evaluation of SCEV at constant didn't fold correctly?");
9444 return cast<SCEVConstant>(Val)->getValue();
9445}
9446
9447ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9448 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9449 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9450 if (!RHS)
9451 return getCouldNotCompute();
9452
9453 const BasicBlock *Latch = L->getLoopLatch();
9454 if (!Latch)
9455 return getCouldNotCompute();
9456
9457 const BasicBlock *Predecessor = L->getLoopPredecessor();
9458 if (!Predecessor)
9459 return getCouldNotCompute();
9460
9461 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9462   // Return LHS in OutLHS and shift_op in OutOpCode.
9463 auto MatchPositiveShift =
9464 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9465
9466 using namespace PatternMatch;
9467
9468 ConstantInt *ShiftAmt;
9469 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9470 OutOpCode = Instruction::LShr;
9471 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9472 OutOpCode = Instruction::AShr;
9473 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9474 OutOpCode = Instruction::Shl;
9475 else
9476 return false;
9477
9478 return ShiftAmt->getValue().isStrictlyPositive();
9479 };
9480
9481 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9482 //
9483 // loop:
9484 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9485 // %iv.shifted = lshr i32 %iv, <positive constant>
9486 //
9487 // Return true on a successful match. Return the corresponding PHI node (%iv
9488 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9489 auto MatchShiftRecurrence =
9490 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9491 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9492
9493 {
9494 Instruction::BinaryOps OpC;
9495 Value *V;
9496
9497 // If we encounter a shift instruction, "peel off" the shift operation,
9498 // and remember that we did so. Later when we inspect %iv's backedge
9499 // value, we will make sure that the backedge value uses the same
9500 // operation.
9501 //
9502 // Note: the peeled shift operation does not have to be the same
9503 // instruction as the one feeding into the PHI's backedge value. We only
9504 // really care about it being the same *kind* of shift instruction --
9505 // that's all that is required for our later inferences to hold.
9506 if (MatchPositiveShift(LHS, V, OpC)) {
9507 PostShiftOpCode = OpC;
9508 LHS = V;
9509 }
9510 }
9511
9512 PNOut = dyn_cast<PHINode>(LHS);
9513 if (!PNOut || PNOut->getParent() != L->getHeader())
9514 return false;
9515
9516 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9517 Value *OpLHS;
9518
9519 return
9520 // The backedge value for the PHI node must be a shift by a positive
9521 // amount
9522 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9523
9524 // of the PHI node itself
9525 OpLHS == PNOut &&
9526
9527 // and the kind of shift should match the kind of shift we peeled
9528 // off, if any.
9529 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9530 };
9531
9532 PHINode *PN;
9533 Instruction::BinaryOps OpCode;
9534 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9535 return getCouldNotCompute();
9536
9537 const DataLayout &DL = getDataLayout();
9538
9539 // The key rationale for this optimization is that for some kinds of shift
9540 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9541 // within a finite number of iterations. If the condition guarding the
9542 // backedge (in the sense that the backedge is taken if the condition is true)
9543 // is false for the value the shift recurrence stabilizes to, then we know
9544 // that the backedge is taken only a finite number of times.
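 // Illustrative sketch (not part of the upstream source): assume a
 // hypothetical i8 recurrence %iv = {%k,lshr,1} guarding "br (%iv != 0)".
 // %iv stabilizes to 0 after at most 8 shifts, and 0 != 0 is false, so the
 // backedge can execute at most bitwidth(i8) = 8 times; that bit width is
 // exactly the upper bound constructed below when the folded compare is
 // false.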
9545
9546 ConstantInt *StableValue = nullptr;
9547 switch (OpCode) {
9548 default:
9549 llvm_unreachable("Impossible case!");
9550
9551 case Instruction::AShr: {
9552 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9553 // bitwidth(K) iterations.
9554 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9555 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9556 Predecessor->getTerminator(), &DT);
9557 auto *Ty = cast<IntegerType>(RHS->getType());
9558 if (Known.isNonNegative())
9559 StableValue = ConstantInt::get(Ty, 0);
9560 else if (Known.isNegative())
9561 StableValue = ConstantInt::get(Ty, -1, true);
9562 else
9563 return getCouldNotCompute();
9564
9565 break;
9566 }
9567 case Instruction::LShr:
9568 case Instruction::Shl:
9569 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9570 // stabilize to 0 in at most bitwidth(K) iterations.
9571 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9572 break;
9573 }
9574
9575 auto *Result =
9576 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9577 assert(Result->getType()->isIntegerTy(1) &&
9578 "Otherwise cannot be an operand to a branch instruction");
9579
9580 if (Result->isZeroValue()) {
9581 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9582 const SCEV *UpperBound =
9583 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9584 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9585 }
9586
9587 return getCouldNotCompute();
9588}
9589
9590/// Return true if we can constant fold an instruction of the specified type,
9591/// assuming that all operands were constants.
9592static bool CanConstantFold(const Instruction *I) {
9593 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9594 isa<CastInst>(I) || isa<GetElementPtrInst>(I) || isa<LoadInst>(I) ||
9595 isa<ExtractValueInst>(I))
9596 return true;
9597
9598 if (const CallInst *CI = dyn_cast<CallInst>(I))
9599 if (const Function *F = CI->getCalledFunction())
9600 return canConstantFoldCallTo(CI, F);
9601 return false;
9602}
9603
9604/// Determine whether this instruction can constant evolve within this loop
9605/// assuming its operands can all constant evolve.
9606static bool canConstantEvolve(Instruction *I, const Loop *L) {
9607 // An instruction outside of the loop can't be derived from a loop PHI.
9608 if (!L->contains(I)) return false;
9609
9610 if (isa<PHINode>(I)) {
9611 // We don't currently keep track of the control flow needed to evaluate
9612 // PHIs, so we cannot handle PHIs inside of loops.
9613 return L->getHeader() == I->getParent();
9614 }
9615
9616 // If we won't be able to constant fold this expression even if the operands
9617 // are constants, bail early.
9618 return CanConstantFold(I);
9619}
9620
9621/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9622/// recursing through each instruction operand until reaching a loop header phi.
9623static PHINode *
9624getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9625 DenseMap<Instruction *, PHINode *> &PHIMap,
9626 unsigned Depth) {
9627 if (Depth > MaxConstantEvolvingDepth)
9628 return nullptr;
9629
9630 // Otherwise, we can evaluate this instruction if all of its operands are
9631 // constant or derived from a PHI node themselves.
9632 PHINode *PHI = nullptr;
9633 for (Value *Op : UseInst->operands()) {
9634 if (isa<Constant>(Op)) continue;
9635
9636 Instruction *OpInst = dyn_cast<Instruction>(Op);
9637 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9638
9639 PHINode *P = dyn_cast<PHINode>(OpInst);
9640 if (!P)
9641 // If this operand is already visited, reuse the prior result.
9642 // We may have P != PHI if this is the deepest point at which the
9643 // inconsistent paths meet.
9644 P = PHIMap.lookup(OpInst);
9645 if (!P) {
9646 // Recurse and memoize the results, whether a phi is found or not.
9647 // This recursive call invalidates pointers into PHIMap.
9648 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9649 PHIMap[OpInst] = P;
9650 }
9651 if (!P)
9652 return nullptr; // Not evolving from PHI
9653 if (PHI && PHI != P)
9654 return nullptr; // Evolving from multiple different PHIs.
9655 PHI = P;
9656 }
9657 // This is an expression evolving from a constant PHI!
9658 return PHI;
9659}
9660
9661/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9662/// in the loop that V is derived from. We allow arbitrary operations along the
9663/// way, but the operands of an operation must either be constants or a value
9664/// derived from a constant PHI. If this expression does not fit with these
9665/// constraints, return null.
9666static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9667 Instruction *I = dyn_cast<Instruction>(V);
9668 if (!I || !canConstantEvolve(I, L)) return nullptr;
9669
9670 if (PHINode *PN = dyn_cast<PHINode>(I))
9671 return PN;
9672
9673 // Record non-constant instructions contained by the loop.
9674 DenseMap<Instruction *, PHINode *> PHIMap;
9675 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9676}
9677
9678/// EvaluateExpression - Given an expression that passes the
9679/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9680/// in the loop has the value PHIVal. If we can't fold this expression for some
9681/// reason, return null.
9682static Constant *EvaluateExpression(Value *V, const Loop *L,
9683 DenseMap<Instruction *, Constant *> &Vals,
9684 const DataLayout &DL,
9685 const TargetLibraryInfo *TLI) {
9686 // Convenient constant check, but redundant for recursive calls.
9687 if (Constant *C = dyn_cast<Constant>(V)) return C;
9688 Instruction *I = dyn_cast<Instruction>(V);
9689 if (!I) return nullptr;
9690
9691 if (Constant *C = Vals.lookup(I)) return C;
9692
9693 // An instruction inside the loop depends on a value outside the loop that we
9694 // weren't given a mapping for, or a value such as a call inside the loop.
9695 if (!canConstantEvolve(I, L)) return nullptr;
9696
9697 // An unmapped PHI can be due to a branch or another loop inside this loop,
9698 // or due to this not being the initial iteration through a loop where we
9699 // couldn't compute the evolution of this particular PHI last time.
9700 if (isa<PHINode>(I)) return nullptr;
9701
9702 std::vector<Constant*> Operands(I->getNumOperands());
9703
9704 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9705 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9706 if (!Operand) {
9707 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9708 if (!Operands[i]) return nullptr;
9709 continue;
9710 }
9711 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9712 Vals[Operand] = C;
9713 if (!C) return nullptr;
9714 Operands[i] = C;
9715 }
9716
9717 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9718 /*AllowNonDeterministic=*/false);
9719}
9720
9721
9722// If every incoming value to PN except the one for BB is a specific Constant,
9723// return that, else return nullptr.
9724static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9725 Constant *IncomingVal = nullptr;
9726
9727 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9728 if (PN->getIncomingBlock(i) == BB)
9729 continue;
9730
9731 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9732 if (!CurrentVal)
9733 return nullptr;
9734
9735 if (IncomingVal != CurrentVal) {
9736 if (IncomingVal)
9737 return nullptr;
9738 IncomingVal = CurrentVal;
9739 }
9740 }
9741
9742 return IncomingVal;
9743}
9744
9745/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9746/// in the header of its containing loop, we know the loop executes a
9747/// constant number of times, and the PHI node is just a recurrence
9748/// involving constants, fold it.
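/// Illustrative example (not part of the upstream source): for a hypothetical
/// header PHI %i = phi i32 [ 0, %preheader ], [ %i.next, %latch ] with
/// %i.next = add i32 %i, 3 and a backedge-taken count of 4, the symbolic
/// execution below walks 0, 3, 6, 9 and returns the constant 12 as the exit
/// value.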
9749Constant *
9750ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9751 const APInt &BEs,
9752 const Loop *L) {
9753 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9754 if (!Inserted)
9755 return I->second;
9756
9757 if (BEs.ugt(MaxBruteForceIterations))
9758 return nullptr; // Not going to evaluate it.
9759
9760 Constant *&RetVal = I->second;
9761
9762 DenseMap<Instruction *, Constant *> CurrentIterVals;
9763 BasicBlock *Header = L->getHeader();
9764 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9765
9766 BasicBlock *Latch = L->getLoopLatch();
9767 if (!Latch)
9768 return nullptr;
9769
9770 for (PHINode &PHI : Header->phis()) {
9771 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9772 CurrentIterVals[&PHI] = StartCST;
9773 }
9774 if (!CurrentIterVals.count(PN))
9775 return RetVal = nullptr;
9776
9777 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9778
9779 // Execute the loop symbolically to determine the exit value.
9780 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9781 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9782
9783 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9784 unsigned IterationNum = 0;
9785 const DataLayout &DL = getDataLayout();
9786 for (; ; ++IterationNum) {
9787 if (IterationNum == NumIterations)
9788 return RetVal = CurrentIterVals[PN]; // Got exit value!
9789
9790 // Compute the value of the PHIs for the next iteration.
9791 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9792 DenseMap<Instruction *, Constant *> NextIterVals;
9793 Constant *NextPHI =
9794 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9795 if (!NextPHI)
9796 return nullptr; // Couldn't evaluate!
9797 NextIterVals[PN] = NextPHI;
9798
9799 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9800
9801 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9802 // cease to be able to evaluate one of them or if they stop evolving,
9803 // because that doesn't necessarily prevent us from computing PN.
9804 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9805 for (const auto &I : CurrentIterVals) {
9806 PHINode *PHI = dyn_cast<PHINode>(I.first);
9807 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9808 PHIsToCompute.emplace_back(PHI, I.second);
9809 }
9810 // We use two distinct loops because EvaluateExpression may invalidate any
9811 // iterators into CurrentIterVals.
9812 for (const auto &I : PHIsToCompute) {
9813 PHINode *PHI = I.first;
9814 Constant *&NextPHI = NextIterVals[PHI];
9815 if (!NextPHI) { // Not already computed.
9816 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9817 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9818 }
9819 if (NextPHI != I.second)
9820 StoppedEvolving = false;
9821 }
9822
9823 // If all entries in CurrentIterVals == NextIterVals then we can stop
9824 // iterating, the loop can't continue to change.
9825 if (StoppedEvolving)
9826 return RetVal = CurrentIterVals[PN];
9827
9828 CurrentIterVals.swap(NextIterVals);
9829 }
9830}
9831
9832const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9833 Value *Cond,
9834 bool ExitWhen) {
9835 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9836 if (!PN) return getCouldNotCompute();
9837
9838 // If the loop is canonicalized, the PHI will have exactly two entries.
9839 // That's the only form we support here.
9840 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9841
9842 DenseMap<Instruction *, Constant *> CurrentIterVals;
9843 BasicBlock *Header = L->getHeader();
9844 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9845
9846 BasicBlock *Latch = L->getLoopLatch();
9847 assert(Latch && "Should follow from NumIncomingValues == 2!");
9848
9849 for (PHINode &PHI : Header->phis()) {
9850 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9851 CurrentIterVals[&PHI] = StartCST;
9852 }
9853 if (!CurrentIterVals.count(PN))
9854 return getCouldNotCompute();
9855
9856 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9857 // the loop symbolically to determine when the condition gets a value of
9858 // "ExitWhen".
9859 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9860 const DataLayout &DL = getDataLayout();
9861 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9862 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9863 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9864
9865 // Couldn't symbolically evaluate.
9866 if (!CondVal) return getCouldNotCompute();
9867
9868 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9869 ++NumBruteForceTripCountsComputed;
9870 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9871 }
9872
9873 // Update all the PHI nodes for the next iteration.
9874 DenseMap<Instruction *, Constant *> NextIterVals;
9875
9876 // Create a list of which PHIs we need to compute. We want to do this before
9877 // calling EvaluateExpression on them because that may invalidate iterators
9878 // into CurrentIterVals.
9879 SmallVector<PHINode *, 8> PHIsToCompute;
9880 for (const auto &I : CurrentIterVals) {
9881 PHINode *PHI = dyn_cast<PHINode>(I.first);
9882 if (!PHI || PHI->getParent() != Header) continue;
9883 PHIsToCompute.push_back(PHI);
9884 }
9885 for (PHINode *PHI : PHIsToCompute) {
9886 Constant *&NextPHI = NextIterVals[PHI];
9887 if (NextPHI) continue; // Already computed!
9888
9889 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9890 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9891 }
9892 CurrentIterVals.swap(NextIterVals);
9893 }
9894
9895 // Too many iterations were needed to evaluate.
9896 return getCouldNotCompute();
9897}
9898
9899const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9900 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9901 ValuesAtScopes[V];
9902 // Check to see if we've folded this expression at this loop before.
9903 for (auto &LS : Values)
9904 if (LS.first == L)
9905 return LS.second ? LS.second : V;
9906
9907 Values.emplace_back(L, nullptr);
9908
9909 // Otherwise compute it.
9910 const SCEV *C = computeSCEVAtScope(V, L);
9911 for (auto &LS : reverse(ValuesAtScopes[V]))
9912 if (LS.first == L) {
9913 LS.second = C;
9914 if (!isa<SCEVConstant>(C))
9915 ValuesAtScopesUsers[C].push_back({L, V});
9916 break;
9917 }
9918 return C;
9919}
9920
9921/// This builds up a Constant using the ConstantExpr interface. That way, we
9922/// will return Constants for objects which aren't represented by a
9923/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9924/// Returns NULL if the SCEV isn't representable as a Constant.
9925static Constant *BuildConstantFromSCEV(const SCEV *V) {
9926 switch (V->getSCEVType()) {
9927 case scCouldNotCompute:
9928 case scAddRecExpr:
9929 case scVScale:
9930 return nullptr;
9931 case scConstant:
9932 return cast<SCEVConstant>(V)->getValue();
9933 case scUnknown:
9934 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9935 case scPtrToInt: {
9936 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9937 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9938 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9939
9940 return nullptr;
9941 }
9942 case scTruncate: {
9943 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9944 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9945 return ConstantExpr::getTrunc(CastOp, ST->getType());
9946 return nullptr;
9947 }
9948 case scAddExpr: {
9949 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9950 Constant *C = nullptr;
9951 for (const SCEV *Op : SA->operands()) {
9952 Constant *OpC = BuildConstantFromSCEV(Op);
9953 if (!OpC)
9954 return nullptr;
9955 if (!C) {
9956 C = OpC;
9957 continue;
9958 }
9959 assert(!C->getType()->isPointerTy() &&
9960 "Can only have one pointer, and it must be last");
9961 if (OpC->getType()->isPointerTy()) {
9962 // The offsets have been converted to bytes. We can add bytes using
9963 // an i8 GEP.
9964 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9965 OpC, C);
9966 } else {
9967 C = ConstantExpr::getAdd(C, OpC);
9968 }
9969 }
9970 return C;
9971 }
9972 case scMulExpr:
9973 case scSignExtend:
9974 case scZeroExtend:
9975 case scUDivExpr:
9976 case scSMaxExpr:
9977 case scUMaxExpr:
9978 case scSMinExpr:
9979 case scUMinExpr:
9980 case scSequentialUMinExpr:
9981 return nullptr;
9982 }
9983 llvm_unreachable("Unknown SCEV kind!");
9984}
9985
9986const SCEV *
9987ScalarEvolution::getWithOperands(const SCEV *S,
9988 SmallVectorImpl<const SCEV *> &NewOps) {
9989 switch (S->getSCEVType()) {
9990 case scTruncate:
9991 case scZeroExtend:
9992 case scSignExtend:
9993 case scPtrToInt:
9994 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9995 case scAddRecExpr: {
9996 auto *AddRec = cast<SCEVAddRecExpr>(S);
9997 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9998 }
9999 case scAddExpr:
10000 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10001 case scMulExpr:
10002 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10003 case scUDivExpr:
10004 return getUDivExpr(NewOps[0], NewOps[1]);
10005 case scUMaxExpr:
10006 case scSMaxExpr:
10007 case scUMinExpr:
10008 case scSMinExpr:
10009 return getMinMaxExpr(S->getSCEVType(), NewOps);
10010 case scSequentialUMinExpr:
10011 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10012 case scConstant:
10013 case scVScale:
10014 case scUnknown:
10015 return S;
10016 case scCouldNotCompute:
10017 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10018 }
10019 llvm_unreachable("Unknown SCEV kind!");
10020}
10021
10022const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10023 switch (V->getSCEVType()) {
10024 case scConstant:
10025 case scVScale:
10026 return V;
10027 case scAddRecExpr: {
10028 // If this is a loop recurrence for a loop that does not contain L, then we
10029 // are dealing with the final value computed by the loop.
10030 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10031 // First, attempt to evaluate each operand.
10032 // Avoid performing the look-up in the common case where the specified
10033 // expression has no loop-variant portions.
10034 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10035 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10036 if (OpAtScope == AddRec->getOperand(i))
10037 continue;
10038
10039 // Okay, at least one of these operands is loop variant but might be
10040 // foldable. Build a new instance of the folded commutative expression.
10041 SmallVector<const SCEV *, 8> NewOps;
10042 NewOps.reserve(AddRec->getNumOperands());
10043 append_range(NewOps, AddRec->operands().take_front(i));
10044 NewOps.push_back(OpAtScope);
10045 for (++i; i != e; ++i)
10046 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10047
10048 const SCEV *FoldedRec = getAddRecExpr(
10049 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10050 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10051 // The addrec may be folded to a nonrecurrence, for example, if the
10052 // induction variable is multiplied by zero after constant folding. Go
10053 // ahead and return the folded value.
10054 if (!AddRec)
10055 return FoldedRec;
10056 break;
10057 }
10058
10059 // If the scope is outside the addrec's loop, evaluate it by using the
10060 // loop exit value of the addrec.
10061 if (!AddRec->getLoop()->contains(L)) {
10062 // To evaluate this recurrence, we need to know how many times the AddRec
10063 // loop iterates. Compute this now.
10064 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10065 if (BackedgeTakenCount == getCouldNotCompute())
10066 return AddRec;
10067
10068 // Then, evaluate the AddRec.
10069 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10070 }
10071
10072 return AddRec;
10073 }
10074 case scTruncate:
10075 case scZeroExtend:
10076 case scSignExtend:
10077 case scPtrToInt:
10078 case scAddExpr:
10079 case scMulExpr:
10080 case scUDivExpr:
10081 case scUMaxExpr:
10082 case scSMaxExpr:
10083 case scUMinExpr:
10084 case scSMinExpr:
10085 case scSequentialUMinExpr: {
10086 ArrayRef<const SCEV *> Ops = V->operands();
10087 // Avoid performing the look-up in the common case where the specified
10088 // expression has no loop-variant portions.
10089 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10090 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10091 if (OpAtScope != Ops[i]) {
10092 // Okay, at least one of these operands is loop variant but might be
10093 // foldable. Build a new instance of the folded commutative expression.
10094 SmallVector<const SCEV *, 8> NewOps;
10095 NewOps.reserve(Ops.size());
10096 append_range(NewOps, Ops.take_front(i));
10097 NewOps.push_back(OpAtScope);
10098
10099 for (++i; i != e; ++i) {
10100 OpAtScope = getSCEVAtScope(Ops[i], L);
10101 NewOps.push_back(OpAtScope);
10102 }
10103
10104 return getWithOperands(V, NewOps);
10105 }
10106 }
10107 // If we got here, all operands are loop invariant.
10108 return V;
10109 }
10110 case scUnknown: {
10111 // If this instruction is evolved from a constant-evolving PHI, compute the
10112 // exit value from the loop without using SCEVs.
10113 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10114 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10115 if (!I)
10116 return V; // This is some other type of SCEVUnknown, just return it.
10117
10118 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10119 const Loop *CurrLoop = this->LI[I->getParent()];
10120 // Looking for loop exit value.
10121 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10122 PN->getParent() == CurrLoop->getHeader()) {
10123 // Okay, there is no closed form solution for the PHI node. Check
10124 // to see if the loop that contains it has a known backedge-taken
10125 // count. If so, we may be able to force computation of the exit
10126 // value.
10127 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10128 // This trivial case can show up in some degenerate cases where
10129 // the incoming IR has not yet been fully simplified.
10130 if (BackedgeTakenCount->isZero()) {
10131 Value *InitValue = nullptr;
10132 bool MultipleInitValues = false;
10133 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10134 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10135 if (!InitValue)
10136 InitValue = PN->getIncomingValue(i);
10137 else if (InitValue != PN->getIncomingValue(i)) {
10138 MultipleInitValues = true;
10139 break;
10140 }
10141 }
10142 }
10143 if (!MultipleInitValues && InitValue)
10144 return getSCEV(InitValue);
10145 }
10146 // Do we have a loop invariant value flowing around the backedge
10147 // for a loop which must execute the backedge?
10148 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10149 isKnownNonZero(BackedgeTakenCount) &&
10150 PN->getNumIncomingValues() == 2) {
10151
10152 unsigned InLoopPred =
10153 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10154 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10155 if (CurrLoop->isLoopInvariant(BackedgeVal))
10156 return getSCEV(BackedgeVal);
10157 }
10158 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10159 // Okay, we know how many times the containing loop executes. If
10160 // this is a constant evolving PHI node, get the final value at
10161 // the specified iteration number.
10162 Constant *RV =
10163 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10164 if (RV)
10165 return getSCEV(RV);
10166 }
10167 }
10168 }
10169
10170 // Okay, this is an expression that we cannot symbolically evaluate
10171 // into a SCEV. Check to see if it's possible to symbolically evaluate
10172 // the arguments into constants, and if so, try to constant propagate the
10173 // result. This is particularly useful for computing loop exit values.
10174 if (!CanConstantFold(I))
10175 return V; // This is some other type of SCEVUnknown, just return it.
10176
10177 SmallVector<Constant *, 4> Operands;
10178 Operands.reserve(I->getNumOperands());
10179 bool MadeImprovement = false;
10180 for (Value *Op : I->operands()) {
10181 if (Constant *C = dyn_cast<Constant>(Op)) {
10182 Operands.push_back(C);
10183 continue;
10184 }
10185
10186 // If any of the operands is non-constant and if they are
10187 // non-integer and non-pointer, don't even try to analyze them
10188 // with scev techniques.
10189 if (!isSCEVable(Op->getType()))
10190 return V;
10191
10192 const SCEV *OrigV = getSCEV(Op);
10193 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10194 MadeImprovement |= OrigV != OpV;
10195
10196 Constant *C = BuildConstantFromSCEV(OpV);
10197 if (!C)
10198 return V;
10199 assert(C->getType() == Op->getType() && "Type mismatch");
10200 Operands.push_back(C);
10201 }
10202
10203 // Check to see if getSCEVAtScope actually made an improvement.
10204 if (!MadeImprovement)
10205 return V; // This is some other type of SCEVUnknown, just return it.
10206
10207 Constant *C = nullptr;
10208 const DataLayout &DL = getDataLayout();
10209 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10210 /*AllowNonDeterministic=*/false);
10211 if (!C)
10212 return V;
10213 return getSCEV(C);
10214 }
10215 case scCouldNotCompute:
10216 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10217 }
10218 llvm_unreachable("Unknown SCEV type!");
10219}
10220
10221const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10222 return getSCEVAtScope(getSCEV(V), L);
10223}
10224
10225const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10226 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10227 return stripInjectiveFunctions(ZExt->getOperand());
10228 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10229 return stripInjectiveFunctions(SExt->getOperand());
10230 return S;
10231}
10232
10233/// Finds the minimum unsigned root of the following equation:
10234///
10235/// A * X = B (mod N)
10236///
10237/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10238/// A and B isn't important.
10239///
10240/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
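///
/// Illustrative example (not part of the upstream source): with BW = 8,
/// A = 6 and B = 4, we get D = gcd(6, 2^8) = 2 and B is divisible by D.
/// Then AD = 3, whose multiplicative inverse modulo 2^7 is 43, so the
/// minimum unsigned root is ((43 * 4) mod 256) / 2 = 86; indeed
/// 6 * 86 = 516 = 4 (mod 256). The numbers are hypothetical and only trace
/// the steps implemented below.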
10241static const SCEV *
10242SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10243 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10244 ScalarEvolution &SE, const Loop *L) {
10245 uint32_t BW = A.getBitWidth();
10246 assert(BW == SE.getTypeSizeInBits(B->getType()));
10247 assert(A != 0 && "A must be non-zero.");
10248
10249 // 1. D = gcd(A, N)
10250 //
10251 // The gcd of A and N may have only one prime factor: 2. The number of
10252 // trailing zeros in A is its multiplicity
10253 uint32_t Mult2 = A.countr_zero();
10254 // D = 2^Mult2
10255
10256 // 2. Check if B is divisible by D.
10257 //
10258 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10259 // is not less than multiplicity of this prime factor for D.
10260 unsigned MinTZ = SE.getMinTrailingZeros(B);
10261 // Try again with the terminator of the loop predecessor for a context-specific
10262 // result, if MinTZ is too small.
10263 if (MinTZ < Mult2 && L->getLoopPredecessor())
10264 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10265 if (MinTZ < Mult2) {
10266 // Check if we can prove there's no remainder using URem.
10267 const SCEV *URem =
10268 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10269 const SCEV *Zero = SE.getZero(B->getType());
10270 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10271 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10272 if (!Predicates)
10273 return SE.getCouldNotCompute();
10274
10275 // Avoid adding a predicate that is known to be false.
10276 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10277 return SE.getCouldNotCompute();
10278 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10279 }
10280 }
10281
10282 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10283 // modulo (N / D).
10284 //
10285 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10286 // (N / D) in general. The inverse itself always fits into BW bits, though,
10287 // so we immediately truncate it.
10288 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10289 APInt I = AD.multiplicativeInverse().zext(BW);
10290
10291 // 4. Compute the minimum unsigned root of the equation:
10292 // I * (B / D) mod (N / D)
10293 // To simplify the computation, we factor out the divide by D:
10294 // (I * B mod N) / D
10295 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10296 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10297}
10298
10299/// For a given quadratic addrec, generate coefficients of the corresponding
10300/// quadratic equation, multiplied by a common value to ensure that they are
10301/// integers.
10302/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10303/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10304/// were multiplied by, and BitWidth is the bit width of the original addrec
10305/// coefficients.
10306/// This function returns std::nullopt if the addrec coefficients are not
10307/// compile-time constants.
10308static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10309GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10310 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10311 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10312 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10313 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10314 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10315 << *AddRec << '\n');
10316
10317 // We currently can only solve this if the coefficients are constants.
10318 if (!LC || !MC || !NC) {
10319 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10320 return std::nullopt;
10321 }
10322
10323 APInt L = LC->getAPInt();
10324 APInt M = MC->getAPInt();
10325 APInt N = NC->getAPInt();
10326 assert(!N.isZero() && "This is not a quadratic addrec");
10327
10328 unsigned BitWidth = LC->getAPInt().getBitWidth();
10329 unsigned NewWidth = BitWidth + 1;
10330 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10331 << BitWidth << '\n');
10332 // The sign-extension (as opposed to a zero-extension) here matches the
10333 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10334 N = N.sext(NewWidth);
10335 M = M.sext(NewWidth);
10336 L = L.sext(NewWidth);
10337
10338 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10339 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10340 // L+M, L+2M+N, L+3M+3N, ...
10341 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10342 //
10343 // The equation Acc = 0 is then
10344 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10345 // In a quadratic form it becomes:
10346 // N n^2 + (2M-N) n + 2L = 0.
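 //
 // Illustrative example (not part of the upstream source): for a hypothetical
 // chrec {1,+,2,+,4}, i.e. L = 1, M = 2, N = 4, the accumulated values are
 // 1, 3, 9, 19, ... = 2n^2 + 1, and the equation formed below is
 // 4n^2 + 0n + 2 = 0, with all coefficients already multiplied by T = 2.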
10347
10348 APInt A = N;
10349 APInt B = 2 * M - A;
10350 APInt C = 2 * L;
10351 APInt T = APInt(NewWidth, 2);
10352 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10353 << "x + " << C << ", coeff bw: " << NewWidth
10354 << ", multiplied by " << T << '\n');
10355 return std::make_tuple(A, B, C, T, BitWidth);
10356}
10357
10358/// Helper function to compare optional APInts:
10359/// (a) if X and Y both exist, return min(X, Y),
10360/// (b) if neither X nor Y exist, return std::nullopt,
10361/// (c) if exactly one of X and Y exists, return that value.
10362static std::optional<APInt> MinOptional(std::optional<APInt> X,
10363 std::optional<APInt> Y) {
10364 if (X && Y) {
10365 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10366 APInt XW = X->sext(W);
10367 APInt YW = Y->sext(W);
10368 return XW.slt(YW) ? *X : *Y;
10369 }
10370 if (!X && !Y)
10371 return std::nullopt;
10372 return X ? *X : *Y;
10373}
10374
10375/// Helper function to truncate an optional APInt to a given BitWidth.
10376/// When solving addrec-related equations, it is preferable to return a value
10377/// that has the same bit width as the original addrec's coefficients. If the
10378/// solution fits in the original bit width, truncate it (except for i1).
10379/// Returning a value of a different bit width may inhibit some optimizations.
10380///
10381/// In general, a solution to a quadratic equation generated from an addrec
10382/// may require BW+1 bits, where BW is the bit width of the addrec's
10383/// coefficients. The reason is that the coefficients of the quadratic
10384/// equation are BW+1 bits wide (to avoid truncation when converting from
10385/// the addrec to the equation).
10386static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10387 unsigned BitWidth) {
10388 if (!X)
10389 return std::nullopt;
10390 unsigned W = X->getBitWidth();
10391 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10392 return X->trunc(BitWidth);
10393 return X;
10394}
10395
10396/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10397/// iterations. The values L, M, N are assumed to be signed, and they
10398/// should all have the same bit widths.
10399/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10400/// where BW is the bit width of the addrec's coefficients.
10401/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10402/// returned as such, otherwise the bit width of the returned value may
10403/// be greater than BW.
10404///
10405/// This function returns std::nullopt if
10406/// (a) the addrec coefficients are not constant, or
10407/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10408/// like x^2 = 5, no integer solutions exist, in other cases an integer
10409/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10410static std::optional<APInt>
10411SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10412 APInt A, B, C, M;
10413 unsigned BitWidth;
10414 auto T = GetQuadraticEquation(AddRec);
10415 if (!T)
10416 return std::nullopt;
10417
10418 std::tie(A, B, C, M, BitWidth) = *T;
10419 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10420 std::optional<APInt> X =
10421 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10422 if (!X)
10423 return std::nullopt;
10424
10425 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10426 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10427 if (!V->isZero())
10428 return std::nullopt;
10429
10430 return TruncIfPossible(X, BitWidth);
10431}
10432
10433/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10434/// iterations. The values M, N are assumed to be signed, and they
10435/// should all have the same bit widths.
10436/// Find the least n such that c(n) does not belong to the given range,
10437/// while c(n-1) does.
10438///
10439/// This function returns std::nullopt if
10440/// (a) the addrec coefficients are not constant, or
10441/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10442/// bounds of the range.
10443static std::optional<APInt>
10445 const ConstantRange &Range, ScalarEvolution &SE) {
10446 assert(AddRec->getOperand(0)->isZero() &&
10447 "Starting value of addrec should be 0");
10448 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10449 << Range << ", addrec " << *AddRec << '\n');
10450 // This case is handled in getNumIterationsInRange. Here we can assume that
10451 // we start in the range.
10452 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10453 "Addrec's initial value should be in range");
10454
10455 APInt A, B, C, M;
10456 unsigned BitWidth;
10457 auto T = GetQuadraticEquation(AddRec);
10458 if (!T)
10459 return std::nullopt;
10460
10461 // Be careful about the return value: there can be two reasons for not
10462 // returning an actual number. First, if no solutions to the equations
10463 // were found, and second, if the solutions don't leave the given range.
10464 // The first case means that the actual solution is "unknown", the second
10465 // means that it's known, but not valid. If the solution is unknown, we
10466 // cannot make any conclusions.
10467 // Return a pair: the optional solution and a flag indicating if the
10468 // solution was found.
10469 auto SolveForBoundary =
10470 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10471 // Solve for signed overflow and unsigned overflow, pick the lower
10472 // solution.
10473 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10474 << Bound << " (before multiplying by " << M << ")\n");
10475 Bound *= M; // The quadratic equation multiplier.
10476
10477 std::optional<APInt> SO;
10478 if (BitWidth > 1) {
10479 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10480 "signed overflow\n");
10481 SO = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth);
10482 }
10483 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10484 "unsigned overflow\n");
10485 std::optional<APInt> UO =
10486 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10487
10488 auto LeavesRange = [&] (const APInt &X) {
10489 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10490 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10491 if (Range.contains(V0->getValue()))
10492 return false;
10493 // X should be at least 1, so X-1 is non-negative.
10494 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10495 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10496 if (Range.contains(V1->getValue()))
10497 return true;
10498 return false;
10499 };
10500
10501 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10502 // can be a solution, but the function failed to find it. We cannot treat it
10503 // as "no solution".
10504 if (!SO || !UO)
10505 return {std::nullopt, false};
10506
10507 // Check the smaller value first to see if it leaves the range.
10508 // At this point, both SO and UO must have values.
10509 std::optional<APInt> Min = MinOptional(SO, UO);
10510 if (LeavesRange(*Min))
10511 return { Min, true };
10512 std::optional<APInt> Max = Min == SO ? UO : SO;
10513 if (LeavesRange(*Max))
10514 return { Max, true };
10515
10516 // Solutions were found, but were eliminated, hence the "true".
10517 return {std::nullopt, true};
10518 };
10519
10520 std::tie(A, B, C, M, BitWidth) = *T;
10521 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10522 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10523 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10524 auto SL = SolveForBoundary(Lower);
10525 auto SU = SolveForBoundary(Upper);
10526 // If any of the solutions was unknown, no meaningful conclusions can
10527 // be made.
10528 if (!SL.second || !SU.second)
10529 return std::nullopt;
10530
10531 // Claim: The correct solution is not some value between Min and Max.
10532 //
10533 // Justification: Assuming that Min and Max are different values, one of
10534 // them is when the first signed overflow happens, the other is when the
10535 // first unsigned overflow happens. Crossing the range boundary is only
10536 // possible via an overflow (treating 0 as a special case of it, modeling
10537 // an overflow as crossing k*2^W for some k).
10538 //
10539 // The interesting case here is when Min was eliminated as an invalid
10540 // solution, but Max was not. The argument is that if there was another
10541 // overflow between Min and Max, it would also have been eliminated if
10542 // it was considered.
10543 //
10544 // For a given boundary, it is possible to have two overflows of the same
10545 // type (signed/unsigned) without having the other type in between: this
10546 // can happen when the vertex of the parabola is between the iterations
10547 // corresponding to the overflows. This is only possible when the two
10548 // overflows cross k*2^W for the same k. In such case, if the second one
10549 // left the range (and was the first one to do so), the first overflow
10550 // would have to enter the range, which would mean that either we had left
10551 // the range before or that we started outside of it. Both of these cases
10552 // are contradictions.
10553 //
10554 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10555 // solution is not some value between the Max for this boundary and the
10556 // Min of the other boundary.
10557 //
10558 // Justification: Assume that we had such Max_A and Min_B corresponding
10559 // to range boundaries A and B and such that Max_A < Min_B. If there was
10560 // a solution between Max_A and Min_B, it would have to be caused by an
10561 // overflow corresponding to either A or B. It cannot correspond to B,
10562 // since Min_B is the first occurrence of such an overflow. If it
10563 // corresponded to A, it would have to be either a signed or an unsigned
10564 // overflow that is larger than both eliminated overflows for A. But
10565 // between the eliminated overflows and this overflow, the values would
10566 // cover the entire value space, thus crossing the other boundary, which
10567 // is a contradiction.
10568
10569 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10570}
10571
10572ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10573 const Loop *L,
10574 bool ControlsOnlyExit,
10575 bool AllowPredicates) {
10576
10577 // This is only used for loops with a "x != y" exit test. The exit condition
10578 // is now expressed as a single expression, V = x-y. So the exit test is
10579 // effectively V != 0. We know and take advantage of the fact that this
10580 // expression is only used in a comparison-against-zero context.
10581
10582 SmallVector<const SCEVPredicate *, 2> Predicates;
10583 // If the value is a constant
10584 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10585 // If the value is already zero, the branch will execute zero times.
10586 if (C->getValue()->isZero()) return C;
10587 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10588 }
10589
10590 const SCEVAddRecExpr *AddRec =
10591 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10592
10593 if (!AddRec && AllowPredicates)
10594 // Try to make this an AddRec using runtime tests, in the first X
10595 // iterations of this loop, where X is the SCEV expression found by the
10596 // algorithm below.
10597 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10598
10599 if (!AddRec || AddRec->getLoop() != L)
10600 return getCouldNotCompute();
10601
10602 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10603 // the quadratic equation to solve it.
10604 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10605 // We can only use this value if the chrec ends up with an exact zero
10606 // value at this index. When solving for "X*X != 5", for example, we
10607 // should not accept a root of 2.
10608 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10609 const auto *R = cast<SCEVConstant>(getConstant(*S));
10610 return ExitLimit(R, R, R, false, Predicates);
10611 }
10612 return getCouldNotCompute();
10613 }
10614
10615 // Otherwise we can only handle this if it is affine.
10616 if (!AddRec->isAffine())
10617 return getCouldNotCompute();
10618
10619 // If this is an affine expression, the execution count of this branch is
10620 // the minimum unsigned root of the following equation:
10621 //
10622 // Start + Step*N = 0 (mod 2^BW)
10623 //
10624 // equivalent to:
10625 //
10626 // Step*N = -Start (mod 2^BW)
10627 //
10628 // where BW is the common bit width of Start and Step.
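 //
 // Illustrative example (not part of the upstream source): for a hypothetical
 // i8 addrec {20,+,-4} exiting on "!= 0", the equation is -4 * N = -20
 // (mod 2^8), i.e. 252 * N = 236 (mod 256), whose minimum unsigned root is
 // N = 5, matching the expected trip count 20 / 4.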
10629
10630 // Get the initial value for the loop.
10631 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10632 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10633
10634 if (!isLoopInvariant(Step, L))
10635 return getCouldNotCompute();
10636
10637 LoopGuards Guards = LoopGuards::collect(L, *this);
10638 // Specialize step for this loop so we get context sensitive facts below.
10639 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10640
10641 // For positive steps (counting up until unsigned overflow):
10642 // N = -Start/Step (as unsigned)
10643 // For negative steps (counting down to zero):
10644 // N = Start/-Step
10645 // First compute the unsigned distance from zero in the direction of Step.
10646 bool CountDown = isKnownNegative(StepWLG);
10647 if (!CountDown && !isKnownNonNegative(StepWLG))
10648 return getCouldNotCompute();
10649
10650 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10651 // Handle unitary steps, which cannot wraparound.
10652 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10653 // N = Distance (as unsigned)
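 // For instance (illustrative, not part of the upstream source), a
 // hypothetical i8 addrec {250,+,1} exiting on "!= 0" gives
 // Distance = -250 = 6 (as unsigned): six backedges before the value wraps
 // around to zero.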
10654
10655 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10656 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10657 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10658
10659 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10660 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10661 // case, and see if we can improve the bound.
10662 //
10663 // Explicitly handling this here is necessary because getUnsignedRange
10664 // isn't context-sensitive; it doesn't know that we only care about the
10665 // range inside the loop.
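 // For instance (illustrative, not part of the upstream source), if Distance
 // is the hypothetical i8 value n - 1 and the loop entry is guarded by
 // n != 0, the raw unsigned-range bound of 255 improves to umax(n) - 1 = 254
 // here.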
10666 const SCEV *Zero = getZero(Distance->getType());
10667 const SCEV *One = getOne(Distance->getType());
10668 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10669 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10670 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10671 // as "unsigned_max(Distance + 1) - 1".
10672 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10673 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10674 }
10675 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10676 Predicates);
10677 }
10678
10679 // If the condition controls loop exit (the loop exits only if the expression
10680 // is true) and the addition is no-wrap we can use unsigned divide to
10681 // compute the backedge count. In this case, the step may not divide the
10682 // distance, but we don't care because if the condition is "missed" the loop
10683 // will have undefined behavior due to wrapping.
10684 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10685 loopHasNoAbnormalExits(AddRec->getLoop())) {
10686
10687 // If the stride is zero and the start is non-zero, the loop must be
10688 // infinite. In C++, most loops are finite by assumption, in which case the
10689 // step being zero implies UB must execute if the loop is entered.
10690 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10691 !isKnownNonZero(StepWLG))
10692 return getCouldNotCompute();
10693
10694 const SCEV *Exact =
10695 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10696 const SCEV *ConstantMax = getCouldNotCompute();
10697 if (Exact != getCouldNotCompute()) {
10698 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10699 ConstantMax =
10700 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10701 }
10702 const SCEV *SymbolicMax =
10703 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10704 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10705 }
10706
10707 // Solve the general equation.
10708 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10709 if (!StepC || StepC->getValue()->isZero())
10710 return getCouldNotCompute();
10711 const SCEV *E = SolveLinEquationWithOverflow(
10712 StepC->getAPInt(), getNegativeSCEV(Start),
10713 AllowPredicates ? &Predicates : nullptr, *this, L);
10714
10715 const SCEV *M = E;
10716 if (E != getCouldNotCompute()) {
10717 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10718 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10719 }
10720 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10721 return ExitLimit(E, M, S, false, Predicates);
10722}
10723
10724ScalarEvolution::ExitLimit
10725ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10726 // Loops that look like: while (X == 0) are very strange indeed. We don't
10727 // handle them yet except for the trivial case. This could be expanded in the
10728 // future as needed.
10729
10730 // If the value is a constant, check to see if it is known to be non-zero
10731 // already. If so, the backedge will execute zero times.
10732 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10733 if (!C->getValue()->isZero())
10734 return getZero(C->getType());
10735 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10736 }
10737
10738 // We could implement others, but I really doubt anyone writes loops like
10739 // this, and if they did, they would already be constant folded.
10740 return getCouldNotCompute();
10741}
10742
10743std::pair<const BasicBlock *, const BasicBlock *>
10744ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10745 const {
10746 // If the block has a unique predecessor, then there is no path from the
10747 // predecessor to the block that does not go through the direct edge
10748 // from the predecessor to the block.
10749 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10750 return {Pred, BB};
10751
10752 // A loop's header is defined to be a block that dominates the loop.
10753 // If the header has a unique predecessor outside the loop, it must be
10754 // a block that has exactly one successor that can reach the loop.
10755 if (const Loop *L = LI.getLoopFor(BB))
10756 return {L->getLoopPredecessor(), L->getHeader()};
10757
10758 return {nullptr, BB};
10759}
10760
10761/// SCEV structural equivalence is usually sufficient for testing whether two
10762/// expressions are equal, however for the purposes of looking for a condition
10763/// guarding a loop, it can be useful to be a little more general, since a
10764/// front-end may have replicated the controlling expression.
10765static bool HasSameValue(const SCEV *A, const SCEV *B) {
10766 // Quick check to see if they are the same SCEV.
10767 if (A == B) return true;
10768
10769 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10770 // Not all instructions that are "identical" compute the same value. For
10771 // instance, two distinct alloca instructions allocating the same type are
10772 // identical and do not read memory, but compute distinct values.
10773 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10774 };
10775
10776 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10777 // two different instructions with the same value. Check for this case.
10778 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10779 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10780 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10781 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10782 if (ComputesEqualValues(AI, BI))
10783 return true;
10784
10785 // Otherwise assume they may have a different value.
10786 return false;
10787}
10788
10789static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10790 const SCEV *Op0, *Op1;
10791 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10792 return false;
10793 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10794 LHS = Op1;
10795 return true;
10796 }
10797 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10798 LHS = Op0;
10799 return true;
10800 }
10801 return false;
10802}
10803
10804bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10805 const SCEV *&RHS, unsigned Depth) {
10806 bool Changed = false;
10807 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10808 // '0 != 0'.
10809 auto TrivialCase = [&](bool TriviallyTrue) {
10810 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10811 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10812 return true;
10813 };
10814 // If we hit the max recursion limit bail out.
10815 if (Depth >= 3)
10816 return false;
10817
10818 const SCEV *NewLHS, *NewRHS;
10819 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10820 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10821 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10822 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10823
10824 // (X * vscale) pred (Y * vscale) ==> X pred Y
10825 // when both multiples are NSW.
10826 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10827 // when both multiples are NUW.
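 // For instance (illustrative, not part of the upstream source), comparing
 // (%n * vscale)<nsw> "slt" (8 * vscale)<nsw> reduces to %n slt 8: a common
 // positive vscale factor preserves the signed ordering when neither
 // multiply wraps.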
10828 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10829 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10830 !ICmpInst::isSigned(Pred))) {
10831 LHS = NewLHS;
10832 RHS = NewRHS;
10833 Changed = true;
10834 }
10835 }
10836
10837 // Canonicalize a constant to the right side.
10838 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10839 // Check for both operands constant.
10840 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10841 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10842 return TrivialCase(false);
10843 return TrivialCase(true);
10844 }
10845 // Otherwise swap the operands to put the constant on the right.
10846 std::swap(LHS, RHS);
10847 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10848 Changed = true;
10849 }
10850
10851 // If we're comparing an addrec with a value which is loop-invariant in the
10852 // addrec's loop, put the addrec on the left. Also make a dominance check,
10853 // as both operands could be addrecs loop-invariant in each other's loop.
10854 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10855 const Loop *L = AR->getLoop();
10856 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10857 std::swap(LHS, RHS);
10858 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10859 Changed = true;
10860 }
10861 }
10862
10863 // If there's a constant operand, canonicalize comparisons with boundary
10864 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10865 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10866 const APInt &RA = RC->getAPInt();
10867
10868 bool SimplifiedByConstantRange = false;
10869
10870 if (!ICmpInst::isEquality(Pred)) {
10871 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10872 if (ExactCR.isFullSet())
10873 return TrivialCase(true);
10874 if (ExactCR.isEmptySet())
10875 return TrivialCase(false);
10876
10877 APInt NewRHS;
10878 CmpInst::Predicate NewPred;
10879 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10880 ICmpInst::isEquality(NewPred)) {
10881 // We were able to convert an inequality to an equality.
10882 Pred = NewPred;
10883 RHS = getConstant(NewRHS);
10884 Changed = SimplifiedByConstantRange = true;
10885 }
10886 }
10887
10888 if (!SimplifiedByConstantRange) {
10889 switch (Pred) {
10890 default:
10891 break;
10892 case ICmpInst::ICMP_EQ:
10893 case ICmpInst::ICMP_NE:
10894 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10895 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10896 Changed = true;
10897 break;
10898
10899 // The "Should have been caught earlier!" messages refer to the fact
10900 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10901 // should have fired on the corresponding cases, and canonicalized the
10902 // check to trivial case.
10903
10904 case ICmpInst::ICMP_UGE:
10905 assert(!RA.isMinValue() && "Should have been caught earlier!");
10906 Pred = ICmpInst::ICMP_UGT;
10907 RHS = getConstant(RA - 1);
10908 Changed = true;
10909 break;
10910 case ICmpInst::ICMP_ULE:
10911 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10912 Pred = ICmpInst::ICMP_ULT;
10913 RHS = getConstant(RA + 1);
10914 Changed = true;
10915 break;
10916 case ICmpInst::ICMP_SGE:
10917 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10918 Pred = ICmpInst::ICMP_SGT;
10919 RHS = getConstant(RA - 1);
10920 Changed = true;
10921 break;
10922 case ICmpInst::ICMP_SLE:
10923 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10924 Pred = ICmpInst::ICMP_SLT;
10925 RHS = getConstant(RA + 1);
10926 Changed = true;
10927 break;
10928 }
10929 }
10930 }
10931
10932 // Check for obvious equality.
10933 if (HasSameValue(LHS, RHS)) {
10934 if (ICmpInst::isTrueWhenEqual(Pred))
10935 return TrivialCase(true);
10936 if (ICmpInst::isFalseWhenEqual(Pred))
10937 return TrivialCase(false);
10938 }
10939
10940 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10941 // adding or subtracting 1 from one of the operands.
10942 switch (Pred) {
10943 case ICmpInst::ICMP_SLE:
10944 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10945 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10946 SCEV::FlagNSW);
10947 Pred = ICmpInst::ICMP_SLT;
10948 Changed = true;
10949 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10950 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10951 SCEV::FlagNSW);
10952 Pred = ICmpInst::ICMP_SLT;
10953 Changed = true;
10954 }
10955 break;
10956 case ICmpInst::ICMP_SGE:
10957 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10958 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10959 SCEV::FlagNSW);
10960 Pred = ICmpInst::ICMP_SGT;
10961 Changed = true;
10962 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10963 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10964 SCEV::FlagNSW);
10965 Pred = ICmpInst::ICMP_SGT;
10966 Changed = true;
10967 }
10968 break;
10969 case ICmpInst::ICMP_ULE:
10970 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10971 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10972 SCEV::FlagNUW);
10973 Pred = ICmpInst::ICMP_ULT;
10974 Changed = true;
10975 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10976 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10977 Pred = ICmpInst::ICMP_ULT;
10978 Changed = true;
10979 }
10980 break;
10981 case ICmpInst::ICMP_UGE:
10982 // If RHS is an op we can fold the -1, try that first.
10983 // Otherwise prefer LHS to preserve the nuw flag.
10984 if ((isa<SCEVConstant>(RHS) ||
10985 (isa<SCEVAddExpr>(RHS) &&
10986 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10987 !getUnsignedRangeMin(RHS).isMinValue()) {
10988 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10989 Pred = ICmpInst::ICMP_UGT;
10990 Changed = true;
10991 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10992 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10993 SCEV::FlagNUW);
10994 Pred = ICmpInst::ICMP_UGT;
10995 Changed = true;
10996 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
10997 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10998 Pred = ICmpInst::ICMP_UGT;
10999 Changed = true;
11000 }
11001 break;
11002 default:
11003 break;
11004 }
11005
11006 // TODO: More simplifications are possible here.
11007
11008 // Recursively simplify until we either hit a recursion limit or nothing
11009 // changes.
11010 if (Changed)
11011 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11012
11013 return Changed;
11014}
11015
11016bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11017 return getSignedRangeMax(S).isNegative();
11018}
11019
11020bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11021 return getSignedRangeMin(S).isStrictlyPositive();
11022}
11023
11024bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11025 return !getSignedRangeMin(S).isNegative();
11026}
11027
11028bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11029 return !getSignedRangeMax(S).isStrictlyPositive();
11030}
11031
11032bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11033 // Query push down for cases where the unsigned range is
11034 // less than sufficient.
11035 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11036 return isKnownNonZero(SExt->getOperand(0));
11037 return getUnsignedRangeMin(S) != 0;
11038}
11039
11040bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11041 bool OrNegative) {
11042 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11043 if (auto *C = dyn_cast<SCEVConstant>(S))
11044 return C->getAPInt().isPowerOf2() ||
11045 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11046
11047 // The vscale_range indicates vscale is a power-of-two.
11048 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11049 };
11050
11051 if (NonRecursive(S))
11052 return true;
11053
11054 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11055 if (!Mul)
11056 return false;
11057 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11058}
11059
11060bool ScalarEvolution::isKnownMultipleOf(
11061 const SCEV *S, uint64_t M,
11062 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11063 if (M == 0)
11064 return false;
11065 if (M == 1)
11066 return true;
11067
11068 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11069 // starts with a multiple of M and at every iteration step S only adds
11070 // multiples of M.
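 // For example, {8,+,4}<nuw> is a known multiple of 4: its start (8) is a
 // multiple of 4 and every step adds 4, so the value stays a multiple of 4
 // on all iterations.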
11071 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11072 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11073 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11074
11075 // For a constant, check that "S % M == 0".
11076 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11077 APInt C = Cst->getAPInt();
11078 return C.urem(M) == 0;
11079 }
11080
11081 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11082
11083 // Basic tests have failed.
11084 // Check "S % M == 0" at compile time and record runtime Assumptions.
11085 auto *STy = dyn_cast<IntegerType>(S->getType());
11086 const SCEV *SmodM =
11087 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11088 const SCEV *Zero = getZero(STy);
11089
11090 // Check whether "S % M == 0" is known at compile time.
11091 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11092 return true;
11093
11094 // Check whether "S % M != 0" is known at compile time.
11095 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11096 return false;
11097
11098 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero);
11099
11100 // Detect redundant predicates.
11101 for (auto *A : Assumptions)
11102 if (A->implies(P, *this))
11103 return true;
11104
11105 // Only record non-redundant predicates.
11106 Assumptions.push_back(P);
11107 return true;
11108}
11109
11110std::pair<const SCEV *, const SCEV *>
11111ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11112 // Compute SCEV on entry of loop L.
11113 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11114 if (Start == getCouldNotCompute())
11115 return { Start, Start };
11116 // Compute post increment SCEV for loop L.
11117 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11118 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11119 return { Start, PostInc };
11120}
11121
11122bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11123 const SCEV *RHS) {
11124 // First collect all loops.
11125 SmallPtrSet<const Loop *, 8> LoopsUsed;
11126 getUsedLoops(LHS, LoopsUsed);
11127 getUsedLoops(RHS, LoopsUsed);
11128
11129 if (LoopsUsed.empty())
11130 return false;
11131
11132 // Domination relationship must be a linear order on collected loops.
11133#ifndef NDEBUG
11134 for (const auto *L1 : LoopsUsed)
11135 for (const auto *L2 : LoopsUsed)
11136 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11137 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11138 "Domination relationship is not a linear order");
11139#endif
11140
11141 const Loop *MDL =
11142 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11143 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11144 });
11145
11146 // Get init and post increment value for LHS.
11147 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11148 // if LHS contains unknown non-invariant SCEV then bail out.
11149 if (SplitLHS.first == getCouldNotCompute())
11150 return false;
11151 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11152 // Get init and post increment value for RHS.
11153 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11154 // if RHS contains unknown non-invariant SCEV then bail out.
11155 if (SplitRHS.first == getCouldNotCompute())
11156 return false;
11157 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11158 // It is possible that init SCEV contains an invariant load but it does
11159 // not dominate MDL and is not available at MDL loop entry, so we should
11160 // check it here.
11161 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11162 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11163 return false;
11164
11165 // The backedge guard check seems to be faster than the entry check, so in
11166 // some cases it can short-circuit and speed up the whole estimation.
11167 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11168 SplitRHS.second) &&
11169 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11170}
11171
11172bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11173 const SCEV *RHS) {
11174 // Canonicalize the inputs first.
11175 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11176
11177 if (isKnownViaInduction(Pred, LHS, RHS))
11178 return true;
11179
11180 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11181 return true;
11182
11183 // Otherwise see what can be done with some simple reasoning.
11184 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11185}
11186
11187std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11188 const SCEV *LHS,
11189 const SCEV *RHS) {
11190 if (isKnownPredicate(Pred, LHS, RHS))
11191 return true;
11192 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11193 return false;
11194 return std::nullopt;
11195}
11196
11197bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11198 const SCEV *RHS,
11199 const Instruction *CtxI) {
11200 // TODO: Analyze guards and assumes from Context's block.
11201 return isKnownPredicate(Pred, LHS, RHS) ||
11202 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11203}
11204
11205std::optional<bool>
11206ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11207 const SCEV *RHS, const Instruction *CtxI) {
11208 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11209 if (KnownWithoutContext)
11210 return KnownWithoutContext;
11211
11212 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11213 return true;
11215 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11216 return false;
11217 return std::nullopt;
11218}
11219
11220bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11221 const SCEVAddRecExpr *LHS,
11222 const SCEV *RHS) {
11223 const Loop *L = LHS->getLoop();
11224 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11225 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11226}
11227
11228std::optional<ScalarEvolution::MonotonicPredicateType>
11229ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11230 ICmpInst::Predicate Pred) {
11231 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11232
11233#ifndef NDEBUG
11234 // Verify an invariant: inverting the predicate should turn a monotonically
11235 // increasing change to a monotonically decreasing one, and vice versa.
11236 if (Result) {
11237 auto ResultSwapped =
11238 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11239
11240 assert(*ResultSwapped != *Result &&
11241 "monotonicity should flip as we flip the predicate");
11242 }
11243#endif
11244
11245 return Result;
11246}
11247
11248std::optional<ScalarEvolution::MonotonicPredicateType>
11249ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11250 ICmpInst::Predicate Pred) {
11251 // A zero step value for LHS means the induction variable is essentially a
11252 // loop invariant value. We don't really depend on the predicate actually
11253 // flipping from false to true (for increasing predicates, and the other way
11254 // around for decreasing predicates), all we care about is that *if* the
11255 // predicate changes then it only changes from false to true.
11256 //
11257 // A zero step value in itself is not very useful, but there may be places
11258 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11259 // as general as possible.
11260
11261 // Only handle LE/LT/GE/GT predicates.
11262 if (!ICmpInst::isRelational(Pred))
11263 return std::nullopt;
11264
11265 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11266 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11267 "Should be greater or less!");
11268
11269 // Check that AR does not wrap.
11270 if (ICmpInst::isUnsigned(Pred)) {
11271 if (!LHS->hasNoUnsignedWrap())
11272 return std::nullopt;
11273 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11274 }
11275 assert(ICmpInst::isSigned(Pred) &&
11276 "Relational predicate is either signed or unsigned!");
11277 if (!LHS->hasNoSignedWrap())
11278 return std::nullopt;
11279
11280 const SCEV *Step = LHS->getStepRecurrence(*this);
11281
11282 if (isKnownNonNegative(Step))
11283 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11284
11285 if (isKnownNonPositive(Step))
11286 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11287
11288 return std::nullopt;
11289}
11290
11291std::optional<ScalarEvolution::LoopInvariantPredicate>
11292ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11293 const SCEV *RHS, const Loop *L,
11294 const Instruction *CtxI) {
11295 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11296 if (!isLoopInvariant(RHS, L)) {
11297 if (!isLoopInvariant(LHS, L))
11298 return std::nullopt;
11299
11300 std::swap(LHS, RHS);
11301 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11302 }
11303
11304 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11305 if (!ArLHS || ArLHS->getLoop() != L)
11306 return std::nullopt;
11307
11308 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11309 if (!MonotonicType)
11310 return std::nullopt;
11311 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11312 // true as the loop iterates, and the backedge is control dependent on
11313 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11314 //
11315 // * if the predicate was false in the first iteration then the predicate
11316 // is never evaluated again, since the loop exits without taking the
11317 // backedge.
11318 // * if the predicate was true in the first iteration then it will
11319 // continue to be true for all future iterations since it is
11320 // monotonically increasing.
11321 //
11322 // For both the above possibilities, we can replace the loop varying
11323 // predicate with its value on the first iteration of the loop (which is
11324 // loop invariant).
11325 //
11326 // A similar reasoning applies for a monotonically decreasing predicate, by
11327 // replacing true with false and false with true in the above two bullets.
11328 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11329 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11330
11331 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11332 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11333 RHS);
11334
11335 if (!CtxI)
11336 return std::nullopt;
11337 // Try to prove via context.
11338 // TODO: Support other cases.
11339 switch (Pred) {
11340 default:
11341 break;
11342 case ICmpInst::ICMP_ULE:
11343 case ICmpInst::ICMP_ULT: {
11344 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11345 // Given preconditions
11346 // (1) ArLHS does not cross the border of positive and negative parts of
11347 // range because of:
11348 // - Positive step; (TODO: lift this limitation)
11349 // - nuw - does not cross zero boundary;
11350 // - nsw - does not cross SINT_MAX boundary;
11351 // (2) ArLHS <s RHS
11352 // (3) RHS >=s 0
11353 // we can replace the loop variant ArLHS <u RHS condition with loop
11354 // invariant Start(ArLHS) <u RHS.
11355 //
11356 // Because of (1) there are two options:
11357 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11358 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11359 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11360 // Because of (2) ArLHS <u RHS is trivially true.
11361 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11362 // We can strengthen this to Start(ArLHS) <u RHS.
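 // Illustrative instance: for ArLHS = {0,+,1}<nuw><nsw>, RHS = %n with
 // %n >=s 0 known, and "ArLHS <s %n" known at the context instruction,
 // the loop-varying check "ArLHS <u %n" can be replaced by the invariant
 // check "0 <u %n".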
11363 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11364 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11365 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11366 isKnownNonNegative(RHS) &&
11367 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11368 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11369 RHS);
11370 }
11371 }
11372
11373 return std::nullopt;
11374}
11375
11376std::optional<ScalarEvolution::LoopInvariantPredicate>
11377ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11378 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11379 const Instruction *CtxI, const SCEV *MaxIter) {
11380 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11381 Pred, LHS, RHS, L, CtxI, MaxIter))
11382 return LIP;
11383 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11384 // Number of iterations expressed as UMIN isn't always great for expressing
11385 // the value on the last iteration. If the straightforward approach didn't
11386 // work, try the following trick: if the predicate is invariant for X, it
11387 // is also invariant for umin(X, ...). So try to find something that works
11388 // among subexpressions of MaxIter expressed as umin.
11389 for (auto *Op : UMin->operands())
11390 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11391 Pred, LHS, RHS, L, CtxI, Op))
11392 return LIP;
11393 return std::nullopt;
11394}
11395
11396std::optional<ScalarEvolution::LoopInvariantPredicate>
11397ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11398 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11399 const Instruction *CtxI, const SCEV *MaxIter) {
11400 // Try to prove the following set of facts:
11401 // - The predicate is monotonic in the iteration space.
11402 // - If the check does not fail on the 1st iteration:
11403 // - No overflow will happen during first MaxIter iterations;
11404 // - It will not fail on the MaxIter'th iteration.
11405 // If the check does fail on the 1st iteration, we leave the loop and no
11406 // other checks matter.
11407
11408 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11409 if (!isLoopInvariant(RHS, L)) {
11410 if (!isLoopInvariant(LHS, L))
11411 return std::nullopt;
11412
11413 std::swap(LHS, RHS);
11414 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11415 }
11416
11417 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11418 if (!AR || AR->getLoop() != L)
11419 return std::nullopt;
11420
11421 // The predicate must be relational (i.e. <, <=, >=, >).
11422 if (!ICmpInst::isRelational(Pred))
11423 return std::nullopt;
11424
11425 // TODO: Support steps other than +/- 1.
11426 const SCEV *Step = AR->getStepRecurrence(*this);
11427 auto *One = getOne(Step->getType());
11428 auto *MinusOne = getNegativeSCEV(One);
11429 if (Step != One && Step != MinusOne)
11430 return std::nullopt;
11431
11432 // Type mismatch here means that MaxIter is potentially larger than max
11433 // unsigned value in start type, which means we cannot prove no wrap for the
11434 // indvar.
11435 if (AR->getType() != MaxIter->getType())
11436 return std::nullopt;
11437
11438 // Value of IV on suggested last iteration.
11439 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11440 // Does it still meet the requirement?
11441 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11442 return std::nullopt;
11443 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11444 // not exceed max unsigned value of this type), this effectively proves
11445 // that there is no wrap during the iteration. To prove that there is no
11446 // signed/unsigned wrap, we need to check that
11447 // Start <= Last for step = 1 or Start >= Last for step = -1.
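 // E.g. with an i8 IV, an unsigned predicate, Start = 250, Step = 1 and
 // MaxIter = 10: Last = 250 + 10 wraps to 4, and the no-wrap check
 // "250 u<= 4" fails, so the wrapping case is correctly rejected.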
11448 ICmpInst::Predicate NoOverflowPred =
11449 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11450 if (Step == MinusOne)
11451 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11452 const SCEV *Start = AR->getStart();
11453 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11454 return std::nullopt;
11455
11456 // Everything is fine.
11457 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11458}
11459
11460bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11461 const SCEV *LHS,
11462 const SCEV *RHS) {
11463 if (HasSameValue(LHS, RHS))
11464 return ICmpInst::isTrueWhenEqual(Pred);
11465
11466 auto CheckRange = [&](bool IsSigned) {
11467 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11468 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11469 return RangeLHS.icmp(Pred, RangeRHS);
11470 };
11471
11472 // The check at the top of the function catches the case where the values are
11473 // known to be equal.
11474 if (Pred == CmpInst::ICMP_EQ)
11475 return false;
11476
11477 if (Pred == CmpInst::ICMP_NE) {
11478 if (CheckRange(true) || CheckRange(false))
11479 return true;
11480 auto *Diff = getMinusSCEV(LHS, RHS);
11481 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11482 }
11483
11484 return CheckRange(CmpInst::isSigned(Pred));
11485}
11486
11487bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11488 const SCEV *LHS,
11489 const SCEV *RHS) {
11490 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11491 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11492 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11493 // OutC1 and OutC2.
11494 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11495 APInt &OutC1, APInt &OutC2,
11496 SCEV::NoWrapFlags ExpectedFlags) {
11497 const SCEV *XNonConstOp, *XConstOp;
11498 const SCEV *YNonConstOp, *YConstOp;
11499 SCEV::NoWrapFlags XFlagsPresent;
11500 SCEV::NoWrapFlags YFlagsPresent;
11501
11502 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11503 XConstOp = getZero(X->getType());
11504 XNonConstOp = X;
11505 XFlagsPresent = ExpectedFlags;
11506 }
11507 if (!isa<SCEVConstant>(XConstOp))
11508 return false;
11509
11510 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11511 YConstOp = getZero(Y->getType());
11512 YNonConstOp = Y;
11513 YFlagsPresent = ExpectedFlags;
11514 }
11515
11516 if (YNonConstOp != XNonConstOp)
11517 return false;
11518
11519 if (!isa<SCEVConstant>(YConstOp))
11520 return false;
11521
11522 // When matching ADDs with NUW flags (and unsigned predicates), only the
11523 // second ADD (with the larger constant) requires NUW.
11524 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11525 return false;
11526 if (ExpectedFlags != SCEV::FlagNUW &&
11527 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11528 return false;
11529 }
11530
11531 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11532 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11533
11534 return true;
11535 };
11536
11537 APInt C1;
11538 APInt C2;
11539
11540 switch (Pred) {
11541 default:
11542 break;
11543
11544 case ICmpInst::ICMP_SGE:
11545 std::swap(LHS, RHS);
11546 [[fallthrough]];
11547 case ICmpInst::ICMP_SLE:
11548 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11549 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11550 return true;
11551
11552 break;
11553
11554 case ICmpInst::ICMP_SGT:
11555 std::swap(LHS, RHS);
11556 [[fallthrough]];
11557 case ICmpInst::ICMP_SLT:
11558 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11559 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11560 return true;
11561
11562 break;
11563
11564 case ICmpInst::ICMP_UGE:
11565 std::swap(LHS, RHS);
11566 [[fallthrough]];
11567 case ICmpInst::ICMP_ULE:
11568 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11569 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11570 return true;
11571
11572 break;
11573
11574 case ICmpInst::ICMP_UGT:
11575 std::swap(LHS, RHS);
11576 [[fallthrough]];
11577 case ICmpInst::ICMP_ULT:
11578 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11579 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11580 return true;
11581 break;
11582 }
11583
11584 return false;
11585}
11586
11587bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11588 const SCEV *LHS,
11589 const SCEV *RHS) {
11590 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11591 return false;
11592
11593 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11594 // the stack can result in exponential time complexity.
11595 SaveAndRestore Restore(ProvingSplitPredicate, true);
11596
11597 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11598 //
11599 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11600 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11601 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11602 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11603 // use isKnownPredicate later if needed.
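 // E.g. for i8 values with L = 100 (so L >=s 0): I u< 100 holds exactly when
 // 0 s<= I and I s< 100, which is what the two checks below prove.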
11604 return isKnownNonNegative(RHS) &&
11605 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11606 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11607}
11608
11609bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11610 const SCEV *LHS, const SCEV *RHS) {
11611 // No need to even try if we know the module has no guards.
11612 if (!HasGuards)
11613 return false;
11614
11615 return any_of(*BB, [&](const Instruction &I) {
11616 using namespace llvm::PatternMatch;
11617
11618 Value *Condition;
11619 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11620 m_Value(Condition))) &&
11621 isImpliedCond(Pred, LHS, RHS, Condition, false);
11622 });
11623}
11624
11625/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11626/// protected by a conditional between LHS and RHS. This is used to
11627/// eliminate casts.
11628bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11629 CmpPredicate Pred,
11630 const SCEV *LHS,
11631 const SCEV *RHS) {
11632 // Interpret a null as meaning no loop, where there is obviously no guard
11633 // (interprocedural conditions notwithstanding). Do not bother about
11634 // unreachable loops.
11635 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11636 return true;
11637
11638 if (VerifyIR)
11639 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11640 "This cannot be done on broken IR!");
11641
11642
11643 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11644 return true;
11645
11646 BasicBlock *Latch = L->getLoopLatch();
11647 if (!Latch)
11648 return false;
11649
11650 BranchInst *LoopContinuePredicate =
11651 dyn_cast<BranchInst>(Latch->getTerminator());
11652 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11653 isImpliedCond(Pred, LHS, RHS,
11654 LoopContinuePredicate->getCondition(),
11655 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11656 return true;
11657
11658 // We don't want more than one activation of the following loops on the stack
11659 // -- that can lead to O(n!) time complexity.
11660 if (WalkingBEDominatingConds)
11661 return false;
11662
11663 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11664
11665 // See if we can exploit a trip count to prove the predicate.
11666 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11667 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11668 if (LatchBECount != getCouldNotCompute()) {
11669 // We know that Latch branches back to the loop header exactly
11670 // LatchBECount times. This means the backedge condition at Latch is
11671 // equivalent to "{0,+,1} u< LatchBECount".
11672 Type *Ty = LatchBECount->getType();
11673 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11674 const SCEV *LoopCounter =
11675 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11676 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11677 LatchBECount))
11678 return true;
11679 }
11680
11681 // Check conditions due to any @llvm.assume intrinsics.
11682 for (auto &AssumeVH : AC.assumptions()) {
11683 if (!AssumeVH)
11684 continue;
11685 auto *CI = cast<CallInst>(AssumeVH);
11686 if (!DT.dominates(CI, Latch->getTerminator()))
11687 continue;
11688
11689 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11690 return true;
11691 }
11692
11693 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11694 return true;
11695
11696 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11697 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11698 assert(DTN && "should reach the loop header before reaching the root!");
11699
11700 BasicBlock *BB = DTN->getBlock();
11701 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11702 return true;
11703
11704 BasicBlock *PBB = BB->getSinglePredecessor();
11705 if (!PBB)
11706 continue;
11707
11708 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11709 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11710 continue;
11711
11712 Value *Condition = ContinuePredicate->getCondition();
11713
11714 // If we have an edge `E` within the loop body that dominates the only
11715 // latch, the condition guarding `E` also guards the backedge. This
11716 // reasoning works only for loops with a single latch.
11717
11718 BasicBlockEdge DominatingEdge(PBB, BB);
11719 if (DominatingEdge.isSingleEdge()) {
11720 // We're constructively (and conservatively) enumerating edges within the
11721 // loop body that dominate the latch. The dominator tree better agree
11722 // with us on this:
11723 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11724
11725 if (isImpliedCond(Pred, LHS, RHS, Condition,
11726 BB != ContinuePredicate->getSuccessor(0)))
11727 return true;
11728 }
11729 }
11730
11731 return false;
11732}
11733
11734bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11735 CmpPredicate Pred,
11736 const SCEV *LHS,
11737 const SCEV *RHS) {
11738 // Do not bother proving facts for unreachable code.
11739 if (!DT.isReachableFromEntry(BB))
11740 return true;
11741 if (VerifyIR)
11742 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11743 "This cannot be done on broken IR!");
11744
11745 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11746 // the facts (a >= b && a != b) separately. A typical situation is when the
11747 // non-strict comparison is known from ranges and non-equality is known from
11748 // dominating predicates. If we are proving strict comparison, we always try
11749 // to prove non-equality and non-strict comparison separately.
11750 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11751 const bool ProvingStrictComparison =
11752 Pred != NonStrictPredicate.dropSameSign();
11753 bool ProvedNonStrictComparison = false;
11754 bool ProvedNonEquality = false;
11755
11756 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11757 if (!ProvedNonStrictComparison)
11758 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11759 if (!ProvedNonEquality)
11760 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11761 if (ProvedNonStrictComparison && ProvedNonEquality)
11762 return true;
11763 return false;
11764 };
11765
11766 if (ProvingStrictComparison) {
11767 auto ProofFn = [&](CmpPredicate P) {
11768 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11769 };
11770 if (SplitAndProve(ProofFn))
11771 return true;
11772 }
11773
11774 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11775 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11776 const Instruction *CtxI = &BB->front();
11777 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11778 return true;
11779 if (ProvingStrictComparison) {
11780 auto ProofFn = [&](CmpPredicate P) {
11781 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11782 };
11783 if (SplitAndProve(ProofFn))
11784 return true;
11785 }
11786 return false;
11787 };
11788
11789 // Starting at the block's predecessor, climb up the predecessor chain, as long
11790 // as there are predecessors that can be found that have unique successors
11791 // leading to the original block.
11792 const Loop *ContainingLoop = LI.getLoopFor(BB);
11793 const BasicBlock *PredBB;
11794 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11795 PredBB = ContainingLoop->getLoopPredecessor();
11796 else
11797 PredBB = BB->getSinglePredecessor();
11798 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11799 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11800 const BranchInst *BlockEntryPredicate =
11801 dyn_cast<BranchInst>(Pair.first->getTerminator());
11802 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11803 continue;
11804
11805 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11806 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11807 return true;
11808 }
11809
11810 // Check conditions due to any @llvm.assume intrinsics.
11811 for (auto &AssumeVH : AC.assumptions()) {
11812 if (!AssumeVH)
11813 continue;
11814 auto *CI = cast<CallInst>(AssumeVH);
11815 if (!DT.dominates(CI, BB))
11816 continue;
11817
11818 if (ProveViaCond(CI->getArgOperand(0), false))
11819 return true;
11820 }
11821
11822 // Check conditions due to any @llvm.experimental.guard intrinsics.
11823 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11824 F.getParent(), Intrinsic::experimental_guard);
11825 if (GuardDecl)
11826 for (const auto *GU : GuardDecl->users())
11827 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11828 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11829 if (ProveViaCond(Guard->getArgOperand(0), false))
11830 return true;
11831 return false;
11832}
11833
11834bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11835 const SCEV *LHS,
11836 const SCEV *RHS) {
11837 // Interpret a null as meaning no loop, where there is obviously no guard
11838 // (interprocedural conditions notwithstanding).
11839 if (!L)
11840 return false;
11841
11842 // Both LHS and RHS must be available at loop entry.
11844 "LHS is not available at Loop Entry");
11846 "RHS is not available at Loop Entry");
11847
11848 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11849 return true;
11850
11851 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11852}
11853
11854bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11855 const SCEV *RHS,
11856 const Value *FoundCondValue, bool Inverse,
11857 const Instruction *CtxI) {
11858 // A false condition implies anything. Do not bother analyzing it further.
11859 if (FoundCondValue ==
11860 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11861 return true;
11862
11863 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11864 return false;
11865
11866 auto ClearOnExit =
11867 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11868
11869 // Recursively handle And and Or conditions.
11870 const Value *Op0, *Op1;
11871 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11872 if (!Inverse)
11873 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11874 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11875 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11876 if (Inverse)
11877 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11878 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11879 }
11880
11881 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11882 if (!ICI) return false;
11883
11884 // We have found a conditional branch that dominates the loop or controls
11885 // the loop latch. Check whether it is the comparison we are looking for.
11886 CmpPredicate FoundPred;
11887 if (Inverse)
11888 FoundPred = ICI->getInverseCmpPredicate();
11889 else
11890 FoundPred = ICI->getCmpPredicate();
11891
11892 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11893 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11894
11895 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11896}
11897
11898bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11899 const SCEV *RHS, CmpPredicate FoundPred,
11900 const SCEV *FoundLHS, const SCEV *FoundRHS,
11901 const Instruction *CtxI) {
11902 // Balance the types.
11903 if (getTypeSizeInBits(LHS->getType()) <
11904 getTypeSizeInBits(FoundLHS->getType())) {
11905 // For unsigned and equality predicates, try to prove that both found
11906 // operands fit into narrow unsigned range. If so, try to prove facts in
11907 // narrow types.
11908 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11909 !FoundRHS->getType()->isPointerTy()) {
11910 auto *NarrowType = LHS->getType();
11911 auto *WideType = FoundLHS->getType();
11912 auto BitWidth = getTypeSizeInBits(NarrowType);
11913 const SCEV *MaxValue = getZeroExtendExpr(
11914 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11915 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11916 MaxValue) &&
11917 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11918 MaxValue)) {
11919 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11920 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11921 // We cannot preserve samesign after truncation.
11922 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11923 TruncFoundLHS, TruncFoundRHS, CtxI))
11924 return true;
11925 }
11926 }
11927
11928 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11929 return false;
11930 if (CmpInst::isSigned(Pred)) {
11931 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11932 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11933 } else {
11934 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11935 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11936 }
11937 } else if (getTypeSizeInBits(LHS->getType()) >
11938 getTypeSizeInBits(FoundLHS->getType())) {
11939 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11940 return false;
11941 if (CmpInst::isSigned(FoundPred)) {
11942 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11943 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11944 } else {
11945 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11946 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11947 }
11948 }
11949 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11950 FoundRHS, CtxI);
11951}
11952
11953bool ScalarEvolution::isImpliedCondBalancedTypes(
11954 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11955 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11956 assert(getTypeSizeInBits(LHS->getType()) ==
11957 getTypeSizeInBits(FoundLHS->getType()) &&
11958 "Types should be balanced!");
11959 // Canonicalize the query to match the way instcombine will have
11960 // canonicalized the comparison.
11961 if (SimplifyICmpOperands(Pred, LHS, RHS))
11962 if (LHS == RHS)
11963 return CmpInst::isTrueWhenEqual(Pred);
11964 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11965 if (FoundLHS == FoundRHS)
11966 return CmpInst::isFalseWhenEqual(FoundPred);
11967
11968 // Check to see if we can make the LHS or RHS match.
11969 if (LHS == FoundRHS || RHS == FoundLHS) {
11970 if (isa<SCEVConstant>(RHS)) {
11971 std::swap(FoundLHS, FoundRHS);
11972 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11973 } else {
11974 std::swap(LHS, RHS);
11975 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11976 }
11977 }
11978
11979 // Check whether the found predicate is the same as the desired predicate.
11980 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11981 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11982
11983 // Check whether swapping the found predicate makes it the same as the
11984 // desired predicate.
11985 if (auto P = CmpPredicate::getMatching(
11986 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11987 // We can write the implication
11988 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11989 // using one of the following ways:
11990 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11991 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11992 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11993 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11994 // Forms 1. and 2. require swapping the operands of one condition. Don't
11995 // do this if it would break canonical constant/addrec ordering.
11996 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11997 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
11998 LHS, FoundLHS, FoundRHS, CtxI);
11999 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12000 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12001
12002 // There's no clear preference between forms 3. and 4., try both. Avoid
12003 // forming getNotSCEV of pointer values as the resulting subtract is
12004 // not legal.
12005 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12006 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12007 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12008 FoundRHS, CtxI))
12009 return true;
12010
12011 if (!FoundLHS->getType()->isPointerTy() &&
12012 !FoundRHS->getType()->isPointerTy() &&
12013 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12014 getNotSCEV(FoundRHS), CtxI))
12015 return true;
12016
12017 return false;
12018 }
12019
12020 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12021 CmpInst::Predicate P2) {
12022 assert(P1 != P2 && "Handled earlier!");
12023 return CmpInst::isRelational(P2) &&
12024 ICmpInst::getFlippedSignednessPredicate(P1) == P2;
12025 };
12026 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12027 // Unsigned comparison is the same as signed comparison when both the
12028 // operands are non-negative or negative.
12029 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
12030 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
12031 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12032 // Create local copies that we can freely swap and canonicalize our
12033 // conditions to "le/lt".
12034 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12035 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12036 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12037 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12038 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12039 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12040 std::swap(CanonicalLHS, CanonicalRHS);
12041 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12042 }
12043 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12044 "Must be!");
12045 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12046 ICmpInst::isLE(CanonicalFoundPred)) &&
12047 "Must be!");
12048 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12049 // Use implication:
12050 // x <u y && y >=s 0 --> x <s y.
12051 // If we can prove the left part, the right part is also proven.
12052 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12053 CanonicalRHS, CanonicalFoundLHS,
12054 CanonicalFoundRHS);
12055 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12056 // Use implication:
12057 // x <s y && y <s 0 --> x <u y.
12058 // If we can prove the left part, the right part is also proven.
12059 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12060 CanonicalRHS, CanonicalFoundLHS,
12061 CanonicalFoundRHS);
12062 }
12063
12064 // Check if we can make progress by sharpening ranges.
12065 if (FoundPred == ICmpInst::ICMP_NE &&
12066 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12067
12068 const SCEVConstant *C = nullptr;
12069 const SCEV *V = nullptr;
12070
12071 if (isa<SCEVConstant>(FoundLHS)) {
12072 C = cast<SCEVConstant>(FoundLHS);
12073 V = FoundRHS;
12074 } else {
12075 C = cast<SCEVConstant>(FoundRHS);
12076 V = FoundLHS;
12077 }
12078
12079 // The guarding predicate tells us that C != V. If the known range
12080 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12081 // range we consider has to correspond to same signedness as the
12082 // predicate we're interested in folding.
12083
12084 APInt Min = ICmpInst::isSigned(Pred) ?
12085 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12086
12087 if (Min == C->getAPInt()) {
12088 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12089 // This is true even if (Min + 1) wraps around -- in case of
12090 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
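 // E.g. in i8 with Min = 255 (unsigned): SharperMin wraps to 0, but the
 // implication still holds because "V u>= 255 && V != 255" is unsatisfiable.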
12091
12092 APInt SharperMin = Min + 1;
12093
12094 switch (Pred) {
12095 case ICmpInst::ICMP_SGE:
12096 case ICmpInst::ICMP_UGE:
12097 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12098 // RHS, we're done.
12099 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12100 CtxI))
12101 return true;
12102 [[fallthrough]];
12103
12104 case ICmpInst::ICMP_SGT:
12105 case ICmpInst::ICMP_UGT:
12106 // We know from the range information that (V `Pred` Min ||
12107 // V == Min). We know from the guarding condition that !(V
12108 // == Min). This gives us
12109 //
12110 // V `Pred` Min || V == Min && !(V == Min)
12111 // => V `Pred` Min
12112 //
12113 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12114
12115 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12116 return true;
12117 break;
12118
12119 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12120 case ICmpInst::ICMP_SLE:
12121 case ICmpInst::ICMP_ULE:
12122 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12123 LHS, V, getConstant(SharperMin), CtxI))
12124 return true;
12125 [[fallthrough]];
12126
12127 case ICmpInst::ICMP_SLT:
12128 case ICmpInst::ICMP_ULT:
12129 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12130 LHS, V, getConstant(Min), CtxI))
12131 return true;
12132 break;
12133
12134 default:
12135 // No change
12136 break;
12137 }
12138 }
12139 }
12140
12141 // Check whether the actual condition is beyond sufficient.
12142 if (FoundPred == ICmpInst::ICMP_EQ)
12143 if (ICmpInst::isTrueWhenEqual(Pred))
12144 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12145 return true;
12146 if (Pred == ICmpInst::ICMP_NE)
12147 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12148 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12149 return true;
12150
12151 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12152 return true;
12153
12154 // Otherwise assume the worst.
12155 return false;
12156}
12157
12158bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12159 const SCEV *&L, const SCEV *&R,
12160 SCEV::NoWrapFlags &Flags) {
12161 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12162 return false;
12163
12164 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12165 return true;
12166}
12167
12168std::optional<APInt>
12169ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12170 // We avoid subtracting expressions here because this function is usually
12171 // fairly deep in the call stack (i.e. is called many times).
12172
12173 unsigned BW = getTypeSizeInBits(More->getType());
12174 APInt Diff(BW, 0);
12175 APInt DiffMul(BW, 1);
12176 // Try various simplifications to reduce the difference to a constant. Limit
12177 // the number of allowed simplifications to keep compile-time low.
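 // E.g. More = (%x + 10) and Less = (%x + 4) reduce to the constant
 // difference 6 via the add-decomposition step below.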
12178 for (unsigned I = 0; I < 8; ++I) {
12179 if (More == Less)
12180 return Diff;
12181
12182 // Reduce addrecs with identical steps to their start value.
12184 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12185 const auto *MAR = cast<SCEVAddRecExpr>(More);
12186
12187 if (LAR->getLoop() != MAR->getLoop())
12188 return std::nullopt;
12189
12190 // We look at affine expressions only; not for correctness but to keep
12191 // getStepRecurrence cheap.
12192 if (!LAR->isAffine() || !MAR->isAffine())
12193 return std::nullopt;
12194
12195 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12196 return std::nullopt;
12197
12198 Less = LAR->getStart();
12199 More = MAR->getStart();
12200 continue;
12201 }
12202
12203 // Try to match a common constant multiply.
12204 auto MatchConstMul =
12205 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12206 const APInt *C;
12207 const SCEV *Op;
12208 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12209 return {{Op, *C}};
12210 return std::nullopt;
12211 };
12212 if (auto MatchedMore = MatchConstMul(More)) {
12213 if (auto MatchedLess = MatchConstMul(Less)) {
12214 if (MatchedMore->second == MatchedLess->second) {
12215 More = MatchedMore->first;
12216 Less = MatchedLess->first;
12217 DiffMul *= MatchedMore->second;
12218 continue;
12219 }
12220 }
12221 }
12222
12223 // Try to cancel out common factors in two add expressions.
12224 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12225 auto Add = [&](const SCEV *S, int Mul) {
12226 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12227 if (Mul == 1) {
12228 Diff += C->getAPInt() * DiffMul;
12229 } else {
12230 assert(Mul == -1);
12231 Diff -= C->getAPInt() * DiffMul;
12232 }
12233 } else
12234 Multiplicity[S] += Mul;
12235 };
12236 auto Decompose = [&](const SCEV *S, int Mul) {
12237 if (isa<SCEVAddExpr>(S)) {
12238 for (const SCEV *Op : S->operands())
12239 Add(Op, Mul);
12240 } else
12241 Add(S, Mul);
12242 };
12243 Decompose(More, 1);
12244 Decompose(Less, -1);
12245
12246 // Check whether all the non-constants cancel out, or reduce to new
12247 // More/Less values.
12248 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12249 for (const auto &[S, Mul] : Multiplicity) {
12250 if (Mul == 0)
12251 continue;
12252 if (Mul == 1) {
12253 if (NewMore)
12254 return std::nullopt;
12255 NewMore = S;
12256 } else if (Mul == -1) {
12257 if (NewLess)
12258 return std::nullopt;
12259 NewLess = S;
12260 } else
12261 return std::nullopt;
12262 }
12263
12264 // Values stayed the same, no point in trying further.
12265 if (NewMore == More || NewLess == Less)
12266 return std::nullopt;
12267
12268 More = NewMore;
12269 Less = NewLess;
12270
12271 // Reduced to constant.
12272 if (!More && !Less)
12273 return Diff;
12274
12275 // Left with variable on only one side, bail out.
12276 if (!More || !Less)
12277 return std::nullopt;
12278 }
12279
12280 // Did not reduce to constant.
12281 return std::nullopt;
12282}
12283
12284bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12285 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12286 const SCEV *FoundRHS, const Instruction *CtxI) {
12287 // Try to recognize the following pattern:
12288 //
12289 // FoundRHS = ...
12290 // ...
12291 // loop:
12292 // FoundLHS = {Start,+,W}
12293 // context_bb: // Basic block from the same loop
12294 // known(Pred, FoundLHS, FoundRHS)
12295 //
12296 // If some predicate is known in the context of a loop, it is also known on
12297 // each iteration of this loop, including the first iteration. Therefore, in
12298 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12299 // prove the original pred using this fact.
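 // E.g. if FoundLHS = {0,+,1} and "FoundLHS <s %n" is known in a block that
 // executes on the first iteration, then "0 <s %n" holds at loop entry and
 // can be used to prove the original predicate.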
12300 if (!CtxI)
12301 return false;
12302 const BasicBlock *ContextBB = CtxI->getParent();
12303 // Make sure AR varies in the context block.
12304 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12305 const Loop *L = AR->getLoop();
12306 // Make sure that context belongs to the loop and executes on 1st iteration
12307 // (if it ever executes at all).
12308 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12309 return false;
12310 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12311 return false;
12312 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12313 }
12314
12315 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12316 const Loop *L = AR->getLoop();
12317 // Make sure that context belongs to the loop and executes on 1st iteration
12318 // (if it ever executes at all).
12319 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12320 return false;
12321 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12322 return false;
12323 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12324 }
12325
12326 return false;
12327}
12328
12329bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12330 const SCEV *LHS,
12331 const SCEV *RHS,
12332 const SCEV *FoundLHS,
12333 const SCEV *FoundRHS) {
12334 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12335 return false;
12336
12337 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12338 if (!AddRecLHS)
12339 return false;
12340
12341 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12342 if (!AddRecFoundLHS)
12343 return false;
12344
12345 // We'd like to let SCEV reason about control dependencies, so we constrain
12346 // both the inequalities to be about add recurrences on the same loop. This
12347 // way we can use isLoopEntryGuardedByCond later.
12348
12349 const Loop *L = AddRecFoundLHS->getLoop();
12350 if (L != AddRecLHS->getLoop())
12351 return false;
12352
12353 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12354 //
12355 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12356 // ... (2)
12357 //
12358 // Informal proof for (2), assuming (1) [*]:
12359 //
12360 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12361 //
12362 // Then
12363 //
12364 // FoundLHS s< FoundRHS s< INT_MIN - C
12365 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12366 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12367 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12368 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12369 // <=> FoundLHS + C s< FoundRHS + C
12370 //
12371 // [*]: (1) can be proved by ruling out overflow.
12372 //
12373 // [**]: This can be proved by analyzing all the four possibilities:
12374 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12375 // (A s>= 0, B s>= 0).
12376 //
12377 // Note:
12378 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12379 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12380 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12381 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12382 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12383 // C)".
12384
12385 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12386 if (!LDiff)
12387 return false;
12388 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12389 if (!RDiff || *LDiff != *RDiff)
12390 return false;
12391
12392 if (LDiff->isMinValue())
12393 return true;
12394
12395 APInt FoundRHSLimit;
12396
12397 if (Pred == CmpInst::ICMP_ULT) {
12398 FoundRHSLimit = -(*RDiff);
12399 } else {
12400 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12401 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12402 }
12403
12404 // Try to prove (1) or (2), as needed.
12405 return isAvailableAtLoopEntry(FoundRHS, L) &&
12406 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12407 getConstant(FoundRHSLimit));
12408}
12409
12410bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12411 const SCEV *RHS, const SCEV *FoundLHS,
12412 const SCEV *FoundRHS, unsigned Depth) {
12413 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12414
12415 auto ClearOnExit = make_scope_exit([&]() {
12416 if (LPhi) {
12417 bool Erased = PendingMerges.erase(LPhi);
12418 assert(Erased && "Failed to erase LPhi!");
12419 (void)Erased;
12420 }
12421 if (RPhi) {
12422 bool Erased = PendingMerges.erase(RPhi);
12423 assert(Erased && "Failed to erase RPhi!");
12424 (void)Erased;
12425 }
12426 });
12427
12428 // Find the respective Phis and check that they are not already pending.
12429 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12430 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12431 if (!PendingMerges.insert(Phi).second)
12432 return false;
12433 LPhi = Phi;
12434 }
12435 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12436 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12437 // If we detect a loop of Phi nodes being processed by this method, for
12438 // example:
12439 //
12440 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12441 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12442 //
12443 // we don't want to deal with a case that complex, so return conservative
12444 // answer false.
12445 if (!PendingMerges.insert(Phi).second)
12446 return false;
12447 RPhi = Phi;
12448 }
12449
12450 // If none of LHS, RHS is a Phi, nothing to do here.
12451 if (!LPhi && !RPhi)
12452 return false;
12453
12454 // If there is a SCEVUnknown Phi we are interested in, make it left.
12455 if (!LPhi) {
12456 std::swap(LHS, RHS);
12457 std::swap(FoundLHS, FoundRHS);
12458 std::swap(LPhi, RPhi);
12459 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12460 }
12461
12462 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12463 const BasicBlock *LBB = LPhi->getParent();
12464 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12465
12466 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12467 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12468 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12469 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12470 };
12471
12472 if (RPhi && RPhi->getParent() == LBB) {
12473 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12474 // If we compare two Phis from the same block, and for each entry block
12475 // the predicate is true for incoming values from this block, then the
12476 // predicate is also true for the Phis.
12477 for (const BasicBlock *IncBB : predecessors(LBB)) {
12478 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12479 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12480 if (!ProvedEasily(L, R))
12481 return false;
12482 }
12483 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12484 // Case two: RHS is also a Phi from the same basic block, and it is an
12485 // AddRec. It means that there is a loop which has both AddRec and Unknown
12486 // PHIs; for such a loop we can compare the incoming values of the AddRec
12487 // from above the loop and from the latch with LPhi's respective incoming values.
12488 // TODO: Generalize to handle loops with many inputs in a header.
12489 if (LPhi->getNumIncomingValues() != 2) return false;
12490
12491 auto *RLoop = RAR->getLoop();
12492 auto *Predecessor = RLoop->getLoopPredecessor();
12493 assert(Predecessor && "Loop with AddRec with no predecessor?");
12494 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12495 if (!ProvedEasily(L1, RAR->getStart()))
12496 return false;
12497 auto *Latch = RLoop->getLoopLatch();
12498 assert(Latch && "Loop with AddRec with no latch?");
12499 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12500 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12501 return false;
12502 } else {
12503 // In all other cases go over inputs of LHS and compare each of them to RHS,
12504 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12505 // At this point RHS is either a non-Phi, or it is a Phi from some block
12506 // different from LBB.
12507 for (const BasicBlock *IncBB : predecessors(LBB)) {
12508 // Check that RHS is available in this block.
12509 if (!dominates(RHS, IncBB))
12510 return false;
12511 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12512 // Make sure L does not refer to a value from a potentially previous
12513 // iteration of a loop.
12514 if (!properlyDominates(L, LBB))
12515 return false;
12516 // Addrecs are considered to properly dominate their loop, so are missed
12517 // by the previous check. Discard any values that have computable
12518 // evolution in this loop.
12519 if (auto *Loop = LI.getLoopFor(LBB))
12520 if (hasComputableLoopEvolution(L, Loop))
12521 return false;
12522 if (!ProvedEasily(L, RHS))
12523 return false;
12524 }
12525 }
12526 return true;
12527}
12528
12529bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12530 const SCEV *LHS,
12531 const SCEV *RHS,
12532 const SCEV *FoundLHS,
12533 const SCEV *FoundRHS) {
12534 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12535 // sure that we are dealing with the same LHS.
12536 if (RHS == FoundRHS) {
12537 std::swap(LHS, RHS);
12538 std::swap(FoundLHS, FoundRHS);
12540 }
12541 if (LHS != FoundLHS)
12542 return false;
12543
12544 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12545 if (!SUFoundRHS)
12546 return false;
12547
12548 Value *Shiftee, *ShiftValue;
12549
12550 using namespace PatternMatch;
12551 if (match(SUFoundRHS->getValue(),
12552 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12553 auto *ShifteeS = getSCEV(Shiftee);
12554 // Prove one of the following:
12555 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12556 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12557 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12558 // ---> LHS <s RHS
12559 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12560 // ---> LHS <=s RHS
12561 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12562 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12563 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12564 if (isKnownNonNegative(ShifteeS))
12565 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12566 }
12567
12568 return false;
12569}
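
// Illustrative sketch (not part of the original source): the unsigned rules
// above rely on (Shiftee >> ShiftValue) being no larger than Shiftee in the
// unsigned domain. A minimal standalone check over plain 'unsigned' values;
// the function and parameter names are hypothetical and only mirror the
// reasoning in the comment.
static bool lshrImplicationHolds(unsigned LHS, unsigned Shiftee,
                                 unsigned ShiftValue, unsigned RHS) {
  if (ShiftValue >= sizeof(unsigned) * 8)
    return true; // Out-of-range shift; skip the check in this sketch.
  // Premises: LHS <u (Shiftee >> ShiftValue) and Shiftee <=u RHS.
  if (!(LHS < (Shiftee >> ShiftValue)) || !(Shiftee <= RHS))
    return true; // Premises do not hold; the implication is vacuously true.
  // Conclusion: LHS <u RHS, since (Shiftee >> ShiftValue) <=u Shiftee <=u RHS.
  return LHS < RHS;
}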
12570
12571bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12572 const SCEV *RHS,
12573 const SCEV *FoundLHS,
12574 const SCEV *FoundRHS,
12575 const Instruction *CtxI) {
12576 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12577 FoundRHS) ||
12578 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12579 FoundRHS) ||
12580 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12581 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12582 CtxI) ||
12583 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12584}
12585
12586/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12587template <typename MinMaxExprType>
12588static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12589 const SCEV *Candidate) {
12590 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12591 if (!MinMaxExpr)
12592 return false;
12593
12594 return is_contained(MinMaxExpr->operands(), Candidate);
12595}
12596
12598 CmpPredicate Pred, const SCEV *LHS,
12599 const SCEV *RHS) {
12600 // If both sides are affine addrecs for the same loop, with equal
12601 // steps, and we know the recurrences don't wrap, then we only
12602 // need to check the predicate on the starting values.
12603
12604 if (!ICmpInst::isRelational(Pred))
12605 return false;
12606
12607 const SCEV *LStart, *RStart, *Step;
12608 const Loop *L;
12609 if (!match(LHS,
12610 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12612 m_SpecificLoop(L))))
12613 return false;
12618 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12619 return false;
12620
12621 return SE.isKnownPredicate(Pred, LStart, RStart);
12622}
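
// Illustrative sketch (not part of the original source): with equal steps and
// no wrapping, the relation between two affine recurrences is decided by
// their start values alone. A standalone check over plain 'int' values,
// assuming the inputs are small enough that no iteration overflows; all names
// are hypothetical.
static bool addRecStartRuleHolds(int LStart, int RStart, int Step,
                                 unsigned Iterations) {
  if (!(LStart < RStart))
    return true; // Premise on the start values is absent; nothing to check.
  int L = LStart, R = RStart;
  for (unsigned I = 0; I < Iterations; ++I) {
    if (!(L < R))
      return false; // The start relation must persist on every iteration.
    L += Step; // Both recurrences advance by the same step,
    R += Step; // so their difference never changes.
  }
  return true;
}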
12623
12624 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12625/// expression?
12627 const SCEV *LHS, const SCEV *RHS) {
12628 switch (Pred) {
12629 default:
12630 return false;
12631
12632 case ICmpInst::ICMP_SGE:
12633 std::swap(LHS, RHS);
12634 [[fallthrough]];
12635 case ICmpInst::ICMP_SLE:
12636 return
12637 // min(A, ...) <= A
12639 // A <= max(A, ...)
12641
12642 case ICmpInst::ICMP_UGE:
12643 std::swap(LHS, RHS);
12644 [[fallthrough]];
12645 case ICmpInst::ICMP_ULE:
12646 return
12647 // min(A, ...) <= A
12648 // FIXME: what about umin_seq?
12650 // A <= max(A, ...)
12652 }
12653
12654 llvm_unreachable("covered switch fell through?!");
12655}
12656
12657bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12658 const SCEV *RHS,
12659 const SCEV *FoundLHS,
12660 const SCEV *FoundRHS,
12661 unsigned Depth) {
12664 "LHS and RHS have different sizes?");
12665 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12666 getTypeSizeInBits(FoundRHS->getType()) &&
12667 "FoundLHS and FoundRHS have different sizes?");
12668 // We want to avoid hurting the compile time with analysis of too big trees.
12670 return false;
12671
12672 // We only want to work with GT comparison so far.
12673 if (ICmpInst::isLT(Pred)) {
12675 std::swap(LHS, RHS);
12676 std::swap(FoundLHS, FoundRHS);
12677 }
12678
12680
12681 // For unsigned, try to reduce it to the corresponding signed comparison.
12682 if (P == ICmpInst::ICMP_UGT)
12683 // We can replace unsigned predicate with its signed counterpart if all
12684 // involved values are non-negative.
12685 // TODO: We could have better support for unsigned.
12686 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12687 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12688 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12689 // use this fact to prove that LHS and RHS are non-negative.
12690 const SCEV *MinusOne = getMinusOne(LHS->getType());
12691 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12692 FoundRHS) &&
12693 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12694 FoundRHS))
12696 }
12697
12698 if (P != ICmpInst::ICMP_SGT)
12699 return false;
12700
12701 auto GetOpFromSExt = [&](const SCEV *S) {
12702 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12703 return Ext->getOperand();
12704 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12705 // the constant in some cases.
12706 return S;
12707 };
12708
12709 // Acquire values from extensions.
12710 auto *OrigLHS = LHS;
12711 auto *OrigFoundLHS = FoundLHS;
12712 LHS = GetOpFromSExt(LHS);
12713 FoundLHS = GetOpFromSExt(FoundLHS);
12714
12715 // Can the SGT predicate be proved trivially or using the found context?
12716 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12717 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12718 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12719 FoundRHS, Depth + 1);
12720 };
12721
12722 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12723 // We want to avoid creation of any new non-constant SCEV. Since we are
12724 // going to compare the operands to RHS, we should be certain that we don't
12725 // need any size extensions for this. So let's decline all cases when the
12726 // sizes of types of LHS and RHS do not match.
12727 // TODO: Maybe try to get RHS from sext to catch more cases?
12729 return false;
12730
12731 // Should not overflow.
12732 if (!LHSAddExpr->hasNoSignedWrap())
12733 return false;
12734
12735 auto *LL = LHSAddExpr->getOperand(0);
12736 auto *LR = LHSAddExpr->getOperand(1);
12737 auto *MinusOne = getMinusOne(RHS->getType());
12738
12739 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12740 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12741 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12742 };
12743 // Try to prove the following rule:
12744 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12745 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12746 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12747 return true;
12748 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12749 Value *LL, *LR;
12750 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12751
12752 using namespace llvm::PatternMatch;
12753
12754 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12755 // Rules for division.
12756 // We are going to perform some comparisons with Denominator and its
12757 // derivative expressions. In the general case, creating a SCEV for it may
12758 // lead to a complex analysis of the entire graph, and in particular it
12759 // can request trip count recalculation for the same loop. This would
12760 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12761 // this, we only want to create SCEVs that are constants in this section.
12762 // So we bail if Denominator is not a constant.
12763 if (!isa<ConstantInt>(LR))
12764 return false;
12765
12766 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12767
12768 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12769 // then a SCEV for the numerator already exists and matches with FoundLHS.
12770 auto *Numerator = getExistingSCEV(LL);
12771 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12772 return false;
12773
12774 // Make sure that the numerator matches with FoundLHS and the denominator
12775 // is positive.
12776 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12777 return false;
12778
12779 auto *DTy = Denominator->getType();
12780 auto *FRHSTy = FoundRHS->getType();
12781 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12782 // One of types is a pointer and another one is not. We cannot extend
12783 // them properly to a wider type, so let us just reject this case.
12784 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12785 // to avoid this check.
12786 return false;
12787
12788 // Given that:
12789 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12790 auto *WTy = getWiderType(DTy, FRHSTy);
12791 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12792 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12793
12794 // Try to prove the following rule:
12795 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12796 // Since FoundLHS > FoundRHS > Denominator - 2, we have FoundLHS >= Denominator,
12797 // so LHS = FoundLHS / Denominator >= 1 > 0 >= RHS.
12798 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12799 if (isKnownNonPositive(RHS) &&
12800 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12801 return true;
12802
12803 // Try to prove the following rule:
12804 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12805 // For example, given FoundLHS > -3, FoundLHS is at least -2. If we divide
12806 // it by Denominator > 2, then:
12807 // 1. If FoundLHS is negative, the result is 0.
12808 // 2. If FoundLHS is non-negative, the result is non-negative.
12809 // Either way, the result is non-negative, hence greater than the negative RHS.
12810 auto *MinusOne = getMinusOne(WTy);
12811 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12812 if (isKnownNegative(RHS) &&
12813 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12814 return true;
12815 }
12816 }
12817
12818 // If our expression contained SCEVUnknown Phis, and we split it down and now
12819 // need to prove something for them, try to prove the predicate for every
12820 // possible incoming values of those Phis.
12821 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12822 return true;
12823
12824 return false;
12825}
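
// Illustrative sketch (not part of the original source): a standalone replay
// of the first division rule used above, with plain 'int' values standing in
// for SCEVs and assumed to stay well inside the 'int' range. Given
// FoundLHS >s FoundRHS, FoundRHS >s Denominator - 2, Denominator >s 0 and
// RHS <=s 0, the quotient FoundLHS / Denominator must exceed RHS. All names
// are hypothetical.
static bool divisionRuleHolds(int FoundLHS, int FoundRHS, int Denominator,
                              int RHS) {
  if (!(FoundLHS > FoundRHS) || !(FoundRHS > Denominator - 2) ||
      !(Denominator > 0) || !(RHS <= 0))
    return true; // Premises do not hold; the implication is vacuously true.
  // FoundLHS >= FoundRHS + 1 >= Denominator, so the quotient is at least 1.
  return FoundLHS / Denominator > RHS;
}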
12826
12828 const SCEV *RHS) {
12829 // zext x u<= sext x, sext x s<= zext x
12830 const SCEV *Op;
12831 switch (Pred) {
12832 case ICmpInst::ICMP_SGE:
12833 std::swap(LHS, RHS);
12834 [[fallthrough]];
12835 case ICmpInst::ICMP_SLE: {
12836 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12837 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12839 }
12840 case ICmpInst::ICMP_UGE:
12841 std::swap(LHS, RHS);
12842 [[fallthrough]];
12843 case ICmpInst::ICMP_ULE: {
12844 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12845 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12847 }
12848 default:
12849 return false;
12850 };
12851 llvm_unreachable("unhandled case");
12852}
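
// Illustrative sketch (not part of the original source): the extend idiom
// above, checked on a plain 8-bit value widened to a wider type. For any x,
// zext(x) u<= sext(x) and sext(x) s<= zext(x); the function name is
// hypothetical.
static bool extendIdiomHolds(signed char X) {
  unsigned ZExt = static_cast<unsigned char>(X); // zero-extension of X
  int SExt = X;                                  // sign-extension of X
  bool UnsignedRule = ZExt <= static_cast<unsigned>(SExt); // zext x u<= sext x
  bool SignedRule = SExt <= static_cast<int>(ZExt);        // sext x s<= zext x
  return UnsignedRule && SignedRule;
}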
12853
12854bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12855 const SCEV *LHS,
12856 const SCEV *RHS) {
12857 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12858 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12859 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12860 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12861 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12862}
12863
12864bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12865 const SCEV *LHS,
12866 const SCEV *RHS,
12867 const SCEV *FoundLHS,
12868 const SCEV *FoundRHS) {
12869 switch (Pred) {
12870 default:
12871 llvm_unreachable("Unexpected CmpPredicate value!");
12872 case ICmpInst::ICMP_EQ:
12873 case ICmpInst::ICMP_NE:
12874 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12875 return true;
12876 break;
12877 case ICmpInst::ICMP_SLT:
12878 case ICmpInst::ICMP_SLE:
12879 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12880 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12881 return true;
12882 break;
12883 case ICmpInst::ICMP_SGT:
12884 case ICmpInst::ICMP_SGE:
12885 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12886 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12887 return true;
12888 break;
12889 case ICmpInst::ICMP_ULT:
12890 case ICmpInst::ICMP_ULE:
12891 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12892 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12893 return true;
12894 break;
12895 case ICmpInst::ICMP_UGT:
12896 case ICmpInst::ICMP_UGE:
12897 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12898 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12899 return true;
12900 break;
12901 }
12902
12903 // Maybe it can be proved via operations?
12904 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12905 return true;
12906
12907 return false;
12908}
12909
12910bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12911 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12912 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12913 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12914 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12915 // reduce the compile time impact of this optimization.
12916 return false;
12917
12918 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12919 if (!Addend)
12920 return false;
12921
12922 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12923
12924 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12925 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12926 ConstantRange FoundLHSRange =
12927 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12928
12929 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12930 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12931
12932 // We can also compute the range of values for `LHS` that satisfy the
12933 // consequent, "`LHS` `Pred` `RHS`":
12934 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12935 // The antecedent implies the consequent if every value of `LHS` that
12936 // satisfies the antecedent also satisfies the consequent.
12937 return LHSRange.icmp(Pred, ConstRHS);
12938}
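
// Illustrative sketch (not part of the original source): the range argument
// above, replayed with concrete numbers. If the antecedent is "FoundLHS <u 8"
// and LHS = FoundLHS + 3 (i.e. Addend == 3), then LHS lies in [3, 11), so a
// consequent such as "LHS <u 16" holds for every admissible FoundLHS. The
// helper below brute-forces that containment; all names and constants are
// hypothetical.
static bool rangeImplicationExampleHolds() {
  for (unsigned FoundLHS = 0; FoundLHS < 8; ++FoundLHS) { // FoundLHS <u 8
    unsigned LHS = FoundLHS + 3;                          // Addend == 3
    if (!(LHS < 16))                                      // consequent LHS <u 16
      return false;
  }
  return true;
}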
12939
12940bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12941 bool IsSigned) {
12942 assert(isKnownPositive(Stride) && "Positive stride expected!");
12943
12944 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12945 const SCEV *One = getOne(Stride->getType());
12946
12947 if (IsSigned) {
12948 APInt MaxRHS = getSignedRangeMax(RHS);
12949 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12950 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12951
12952 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12953 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12954 }
12955
12956 APInt MaxRHS = getUnsignedRangeMax(RHS);
12957 APInt MaxValue = APInt::getMaxValue(BitWidth);
12958 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12959
12960 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12961 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12962}
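
// Illustrative sketch (not part of the original source): the unsigned test
// above, spelled out for a hypothetical 8-bit IV. Overflow past UMAX is
// possible when MaxRHS + (MaxStride - 1) would exceed 255, which is the same
// comparison as "MaxValue - MaxStrideMinusOne <u MaxRHS" computed above.
// Assumes 1 <= MaxStride and both inputs fit in 8 bits; names are
// hypothetical.
static bool ivCanOverflowOnLT8Bit(unsigned MaxRHS, unsigned MaxStride) {
  const unsigned MaxValue = 255;              // UMAX for the 8-bit sketch
  unsigned MaxStrideMinusOne = MaxStride - 1;
  return MaxValue - MaxStrideMinusOne < MaxRHS; // wrap past UMAX is possible
}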
12963
12964bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12965 bool IsSigned) {
12966
12967 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12968 const SCEV *One = getOne(Stride->getType());
12969
12970 if (IsSigned) {
12971 APInt MinRHS = getSignedRangeMin(RHS);
12972 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12973 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12974
12975 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12976 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12977 }
12978
12979 APInt MinRHS = getUnsignedRangeMin(RHS);
12980 APInt MinValue = APInt::getMinValue(BitWidth);
12981 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12982
12983 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12984 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12985}
12986
12988 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12989 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12990 // expression fixes the case of N=0.
12991 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12992 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12993 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12994}
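
// Illustrative sketch (not part of the original source): the umin-based form
// above computes a ceiling division without the overflow hazard of the usual
// (N + D - 1) / D rewrite. A standalone check over plain 'unsigned' values,
// assuming D != 0; names are hypothetical.
static unsigned udivCeilSketch(unsigned N, unsigned D) {
  unsigned MinNOne = N < 1 ? N : 1;   // umin(N, 1): 0 if N == 0, otherwise 1
  return MinNOne + (N - MinNOne) / D; // 0 for N == 0, ceil(N / D) otherwise
}
// For example, udivCeilSketch(0, 4) == 0, udivCeilSketch(1, 4) == 1,
// udivCeilSketch(8, 4) == 2 and udivCeilSketch(9, 4) == 3, matching
// "1 + floor((N - 1) / D)" for N != 0.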
12995
12996const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12997 const SCEV *Stride,
12998 const SCEV *End,
12999 unsigned BitWidth,
13000 bool IsSigned) {
13001 // The logic in this function assumes we can represent a positive stride.
13002 // If we can't, the backedge-taken count must be zero.
13003 if (IsSigned && BitWidth == 1)
13004 return getZero(Stride->getType());
13005
13006 // The code below has only been closely audited for negative strides in the
13007 // unsigned comparison case; it may be correct for signed comparisons, but
13008 // that needs to be established.
13009 if (IsSigned && isKnownNegative(Stride))
13010 return getCouldNotCompute();
13011
13012 // Calculate the maximum backedge count based on the range of values
13013 // permitted by Start, End, and Stride.
13014 APInt MinStart =
13015 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13016
13017 APInt MinStride =
13018 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13019
13020 // We assume either the stride is positive, or the backedge-taken count
13021 // is zero. So force StrideForMaxBECount to be at least one.
13022 APInt One(BitWidth, 1);
13023 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13024 : APIntOps::umax(One, MinStride);
13025
13026 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13027 : APInt::getMaxValue(BitWidth);
13028 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13029
13030 // Although End can be a MAX expression we estimate MaxEnd considering only
13031 // the case End = RHS of the loop termination condition. This is safe because
13032 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13033 // taken count.
13034 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13035 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13036
13037 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13038 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13039 : APIntOps::umax(MaxEnd, MinStart);
13040
13041 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13042 getConstant(StrideForMaxBECount) /* Step */);
13043}
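
// Illustrative worked example (not part of the original source): with
// MinStart = 2, MaxEnd = 42 and StrideForMaxBECount = 8, the bound above is
// ceil((42 - 2) / 8) = 5 backedges. The sketch below uses the naive ceiling
// form, which is safe for these small hypothetical constants and yields the
// same value getUDivCeilSCEV would produce here.
static unsigned maxBECountForLTExample() {
  unsigned MinStart = 2, MaxEnd = 42, StrideForMaxBECount = 8;
  unsigned Delta = MaxEnd - MinStart;                             // 40
  return (Delta + StrideForMaxBECount - 1) / StrideForMaxBECount; // ceil == 5
}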
13044
13046ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13047 const Loop *L, bool IsSigned,
13048 bool ControlsOnlyExit, bool AllowPredicates) {
13050
13052 bool PredicatedIV = false;
13053 if (!IV) {
13054 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13055 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13056 if (AR && AR->getLoop() == L && AR->isAffine()) {
13057 auto canProveNUW = [&]() {
13058 // We can use the comparison to infer no-wrap flags only if it fully
13059 // controls the loop exit.
13060 if (!ControlsOnlyExit)
13061 return false;
13062
13063 if (!isLoopInvariant(RHS, L))
13064 return false;
13065
13066 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13067 // We need the sequence defined by AR to strictly increase in the
13068 // unsigned integer domain for the logic below to hold.
13069 return false;
13070
13071 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13072 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13073 // If RHS <=u Limit, then there must exist a value V in the sequence
13074 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13075 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13076 // overflow occurs. This limit also implies that a signed comparison
13077 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13078 // the high bits on both sides must be zero.
13079 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13080 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13081 Limit = Limit.zext(OuterBitWidth);
13082 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13083 };
13084 auto Flags = AR->getNoWrapFlags();
13085 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13086 Flags = setFlags(Flags, SCEV::FlagNUW);
13087
13088 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13089 if (AR->hasNoUnsignedWrap()) {
13090 // Emulate what getZeroExtendExpr would have done during construction
13091 // if we'd been able to infer the fact just above at that time.
13092 const SCEV *Step = AR->getStepRecurrence(*this);
13093 Type *Ty = ZExt->getType();
13094 auto *S = getAddRecExpr(
13096 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13098 }
13099 }
13100 }
13101 }
13102
13103
13104 if (!IV && AllowPredicates) {
13105 // Try to make this an AddRec using runtime tests, in the first X
13106 // iterations of this loop, where X is the SCEV expression found by the
13107 // algorithm below.
13108 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13109 PredicatedIV = true;
13110 }
13111
13112 // Avoid weird loops
13113 if (!IV || IV->getLoop() != L || !IV->isAffine())
13114 return getCouldNotCompute();
13115
13116 // A precondition of this method is that the condition being analyzed
13117 // reaches an exiting branch which dominates the latch. Given that, we can
13118 // assume that an increment which violates the nowrap specification and
13119 // produces poison must cause undefined behavior when the resulting poison
13120 // value is branched upon and thus we can conclude that the backedge is
13121 // taken no more often than would be required to produce that poison value.
13122 // Note that a well defined loop can exit on the iteration which violates
13123 // the nowrap specification if there is another exit (either explicit or
13124 // implicit/exceptional) which causes the loop to execute before the
13125 // exiting instruction we're analyzing would trigger UB.
13126 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13127 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13129
13130 const SCEV *Stride = IV->getStepRecurrence(*this);
13131
13132 bool PositiveStride = isKnownPositive(Stride);
13133
13134 // Avoid negative or zero stride values.
13135 if (!PositiveStride) {
13136 // We can compute the correct backedge taken count for loops with unknown
13137 // strides if we can prove that the loop is not an infinite loop with side
13138 // effects. Here's the loop structure we are trying to handle -
13139 //
13140 // i = start
13141 // do {
13142 // A[i] = i;
13143 // i += s;
13144 // } while (i < end);
13145 //
13146 // The backedge taken count for such loops is evaluated as -
13147 // (max(end, start + stride) - start - 1) /u stride
13148 //
13149 // The additional preconditions that we need to check to prove correctness
13150 // of the above formula are as follows -
13151 //
13152 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13153 // NoWrap flag).
13154 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13155 // no side effects within the loop)
13156 // c) loop has a single static exit (with no abnormal exits)
13157 //
13158 // Precondition a) implies that if the stride is negative, this is a single
13159 // trip loop. The backedge taken count formula reduces to zero in this case.
13160 //
13161 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13162 // then a zero stride means the backedge can't be taken without executing
13163 // undefined behavior.
13164 //
13165 // The positive stride case is the same as isKnownPositive(Stride) returning
13166 // true (original behavior of the function).
13167 //
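// Illustrative worked example (not part of the original source): with
// start = 0, end = 10 and a runtime stride of 3, the formula gives
// (max(10, 0 + 3) - 0 - 1) /u 3 = 9 /u 3 = 3 backedges (the body runs for
// i = 0, 3, 6, 9). With end = 2 it gives (max(2, 3) - 0 - 1) /u 3 = 0,
// i.e. the loop exits after its first iteration.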
13168 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13170 return getCouldNotCompute();
13171
13172 if (!isKnownNonZero(Stride)) {
13173 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13174 // if it might eventually be greater than start and if so, on which
13175 // iteration. We can't even produce a useful upper bound.
13176 if (!isLoopInvariant(RHS, L))
13177 return getCouldNotCompute();
13178
13179 // We allow a potentially zero stride, but we need to divide by stride
13180 // below. Since the loop can't be infinite and this check must control
13181 // the sole exit, we can infer the exit must be taken on the first
13182 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13183 // we know the numerator in the divides below must be zero, so we can
13184 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13185 // and produce the right result.
13186 // FIXME: Handle the case where Stride is poison?
13187 auto wouldZeroStrideBeUB = [&]() {
13188 // Proof by contradiction. Suppose the stride were zero. If we can
13189 // prove that the backedge *is* taken on the first iteration, then since
13190 // we know this condition controls the sole exit, we must have an
13191 // infinite loop. We can't have a (well defined) infinite loop per
13192 // check just above.
13193 // Note: The (Start - Stride) term is used to get the start' term from
13194 // (start' + stride,+,stride). Remember that we only care about the
13195 // result of this expression when stride == 0 at runtime.
13196 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13197 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13198 };
13199 if (!wouldZeroStrideBeUB()) {
13200 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13201 }
13202 }
13203 } else if (!NoWrap) {
13204 // Avoid proven overflow cases: this will ensure that the backedge taken
13205 // count will not generate any unsigned overflow.
13206 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13207 return getCouldNotCompute();
13208 }
13209
13210 // On all paths just preceding, we established the following invariant:
13211 // IV can be assumed not to overflow up to and including the exiting
13212 // iteration. We proved this in one of two ways:
13213 // 1) We can show overflow doesn't occur before the exiting iteration
13214 // 1a) canIVOverflowOnLT, and 1b) step of one
13215 // 2) We can show that if overflow occurs, the loop must execute UB
13216 // before any possible exit.
13217 // Note that we have not yet proved RHS invariant (in general).
13218
13219 const SCEV *Start = IV->getStart();
13220
13221 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13222 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13223 // Use integer-typed versions for actual computation; we can't subtract
13224 // pointers in general.
13225 const SCEV *OrigStart = Start;
13226 const SCEV *OrigRHS = RHS;
13227 if (Start->getType()->isPointerTy()) {
13228 Start = getLosslessPtrToIntExpr(Start);
13229 if (isa<SCEVCouldNotCompute>(Start))
13230 return Start;
13231 }
13232 if (RHS->getType()->isPointerTy()) {
13233 RHS = getLosslessPtrToIntExpr(RHS);
13234 if (isa<SCEVCouldNotCompute>(RHS))
13235 return RHS;
13236 }
13237
13238 const SCEV *End = nullptr, *BECount = nullptr,
13239 *BECountIfBackedgeTaken = nullptr;
13240 if (!isLoopInvariant(RHS, L)) {
13241 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13242 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13243 RHSAddRec->getNoWrapFlags()) {
13244 // The structure of loop we are trying to calculate backedge count of:
13245 //
13246 // left = left_start
13247 // right = right_start
13248 //
13249 // while(left < right){
13250 // ... do something here ...
13251 // left += s1; // stride of left is s1 (s1 > 0)
13252 // right += s2; // stride of right is s2 (s2 < 0)
13253 // }
13254 //
13255
13256 const SCEV *RHSStart = RHSAddRec->getStart();
13257 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13258
13259 // If Stride - RHSStride is positive and does not overflow, we can write
13260 // backedge count as ->
13261 // ceil((End - Start) /u (Stride - RHSStride))
13262 // Where, End = max(RHSStart, Start)
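// Illustrative worked example (not part of the original source): with
// Start = 0, Stride = 2, RHSStart = 10 and RHSStride = -3, the bounds meet
// once 0 + N*2 >= 10 + N*(-3), i.e. N >= 10 / 5. The formula above gives
// End = max(10, 0) = 10 and ceil((10 - 0) /u (2 - (-3))) = 2 backedges.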
13263
13264 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13265 if (isKnownNegative(RHSStride) &&
13266 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13267 RHSStride)) {
13268
13269 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13270 if (isKnownPositive(Denominator)) {
13271 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13272 : getUMaxExpr(RHSStart, Start);
13273
13274 // We can do this because End >= Start, as End = max(RHSStart, Start)
13275 const SCEV *Delta = getMinusSCEV(End, Start);
13276
13277 BECount = getUDivCeilSCEV(Delta, Denominator);
13278 BECountIfBackedgeTaken =
13279 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13280 }
13281 }
13282 }
13283 if (BECount == nullptr) {
13284 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13285 // given the start, stride and max value for the end bound of the
13286 // loop (RHS), and the fact that IV does not overflow (which is
13287 // checked above).
13288 const SCEV *MaxBECount = computeMaxBECountForLT(
13289 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13290 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13291 MaxBECount, false /*MaxOrZero*/, Predicates);
13292 }
13293 } else {
13294 // We use the expression (max(End,Start)-Start)/Stride to describe the
13295 // backedge count: if the backedge is taken at least once,
13296 // max(End,Start) is End and the result is as above; if not,
13297 // max(End,Start) is Start and we get a backedge count of zero.
13298 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13299 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13300 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13301 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13302 // Can we prove max(RHS,Start) > Start - Stride?
13303 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13304 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13305 // In this case, we can use a refined formula for computing backedge
13306 // taken count. The general formula remains:
13307 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13308 // We want to use the alternate formula:
13309 // "((End - 1) - (Start - Stride)) /u Stride"
13310 // Let's do a quick case analysis to show these are equivalent under
13311 // our precondition that max(RHS,Start) > Start - Stride.
13312 // * For RHS <= Start, the backedge-taken count must be zero.
13313 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13314 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13315 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13316 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13317 // reducing this to the stride of 1 case.
13318 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13319 // Stride".
13320 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13321 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13322 // "((RHS - (Start - Stride) - 1) /u Stride".
13323 // Our preconditions trivially imply no overflow in that form.
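// Illustrative worked example (not part of the original source): with
// Start = 5, Stride = 3 and RHS = 12, End = max(12, 5) = 12 and the refined
// form gives ((12 - 1) - (5 - 3)) /u 3 = 9 /u 3 = 3, matching
// ceil((12 - 5) / 3). With RHS = 4 <= Start it gives
// ((5 - 1) - (5 - 3)) /u 3 = 0, the expected zero backedges.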
13324 const SCEV *MinusOne = getMinusOne(Stride->getType());
13325 const SCEV *Numerator =
13326 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13327 BECount = getUDivExpr(Numerator, Stride);
13328 }
13329
13330 if (!BECount) {
13331 auto canProveRHSGreaterThanEqualStart = [&]() {
13332 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13333 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13334 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13335
13336 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13337 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13338 return true;
13339
13340 // (RHS > Start - 1) implies RHS >= Start.
13341 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13342 // "Start - 1" doesn't overflow.
13343 // * For signed comparison, if Start - 1 does overflow, it's equal
13344 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13345 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13346 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13347 //
13348 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13349 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13350 auto *StartMinusOne =
13351 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13352 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13353 };
13354
13355 // If we know that RHS >= Start in the context of loop, then we know
13356 // that max(RHS, Start) = RHS at this point.
13357 if (canProveRHSGreaterThanEqualStart()) {
13358 End = RHS;
13359 } else {
13360 // If RHS < Start, the backedge will be taken zero times. So in
13361 // general, we can write the backedge-taken count as:
13362 //
13363 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13364 //
13365 // We convert it to the following to make it more convenient for SCEV:
13366 //
13367 // ceil(max(RHS, Start) - Start) / Stride
13368 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13369
13370 // See what would happen if we assume the backedge is taken. This is
13371 // used to compute MaxBECount.
13372 BECountIfBackedgeTaken =
13373 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13374 }
13375
13376 // At this point, we know:
13377 //
13378 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13379 // 2. The index variable doesn't overflow.
13380 //
13381 // Therefore, we know N exists such that
13382 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13383 // doesn't overflow.
13384 //
13385 // Using this information, try to prove whether the addition in
13386 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13387 const SCEV *One = getOne(Stride->getType());
13388 bool MayAddOverflow = [&] {
13389 if (isKnownToBeAPowerOfTwo(Stride)) {
13390 // Suppose Stride is a power of two, and Start/End are unsigned
13391 // integers. Let UMAX be the largest representable unsigned
13392 // integer.
13393 //
13394 // By the preconditions of this function, we know
13395 // "(Start + Stride * N) >= End", and this doesn't overflow.
13396 // As a formula:
13397 //
13398 // End <= (Start + Stride * N) <= UMAX
13399 //
13400 // Subtracting Start from all the terms:
13401 //
13402 // End - Start <= Stride * N <= UMAX - Start
13403 //
13404 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13405 //
13406 // End - Start <= Stride * N <= UMAX
13407 //
13408 // Stride * N is a multiple of Stride. Therefore,
13409 //
13410 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13411 //
13412 // Since Stride is a power of two, UMAX + 1 is divisible by
13413 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13414 // write:
13415 //
13416 // End - Start <= Stride * N <= UMAX - Stride + 1
13417 //
13418 // Dropping the middle term:
13419 //
13420 // End - Start <= UMAX - Stride + 1
13421 //
13422 // Adding Stride - 1 to both sides:
13423 //
13424 // (End - Start) + (Stride - 1) <= UMAX
13425 //
13426 // In other words, the addition doesn't have unsigned overflow.
13427 //
13428 // A similar proof works if we treat Start/End as signed values.
13429 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13430 // to use signed max instead of unsigned max. Note that we're
13431 // trying to prove a lack of unsigned overflow in either case.
13432 return false;
13433 }
13434 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13435 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13436 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13437 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13438 // 1 <s End.
13439 //
13440 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13441 // End.
13442 return false;
13443 }
13444 return true;
13445 }();
13446
13447 const SCEV *Delta = getMinusSCEV(End, Start);
13448 if (!MayAddOverflow) {
13449 // floor((D + (S - 1)) / S)
13450 // We prefer this formulation if it's legal because it's fewer
13451 // operations.
13452 BECount =
13453 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13454 } else {
13455 BECount = getUDivCeilSCEV(Delta, Stride);
13456 }
13457 }
13458 }
13459
13460 const SCEV *ConstantMaxBECount;
13461 bool MaxOrZero = false;
13462 if (isa<SCEVConstant>(BECount)) {
13463 ConstantMaxBECount = BECount;
13464 } else if (BECountIfBackedgeTaken &&
13465 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13466 // If we know exactly how many times the backedge will be taken if it's
13467 // taken at least once, then the backedge count will either be that or
13468 // zero.
13469 ConstantMaxBECount = BECountIfBackedgeTaken;
13470 MaxOrZero = true;
13471 } else {
13472 ConstantMaxBECount = computeMaxBECountForLT(
13473 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13474 }
13475
13476 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13477 !isa<SCEVCouldNotCompute>(BECount))
13478 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13479
13480 const SCEV *SymbolicMaxBECount =
13481 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13482 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13483 Predicates);
13484}
13485
13486ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13487 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13488 bool ControlsOnlyExit, bool AllowPredicates) {
13490 // We handle only IV > Invariant
13491 if (!isLoopInvariant(RHS, L))
13492 return getCouldNotCompute();
13493
13494 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13495 if (!IV && AllowPredicates)
13496 // Try to make this an AddRec using runtime tests, in the first X
13497 // iterations of this loop, where X is the SCEV expression found by the
13498 // algorithm below.
13499 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13500
13501 // Avoid weird loops
13502 if (!IV || IV->getLoop() != L || !IV->isAffine())
13503 return getCouldNotCompute();
13504
13505 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13506 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13508
13509 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13510
13511 // Avoid negative or zero stride values
13512 if (!isKnownPositive(Stride))
13513 return getCouldNotCompute();
13514
13515 // Avoid proven overflow cases: this will ensure that the backedge taken count
13516 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13517 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
13518 // behavior, as in the case of the C language.
13519 if (!Stride->isOne() && !NoWrap)
13520 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13521 return getCouldNotCompute();
13522
13523 const SCEV *Start = IV->getStart();
13524 const SCEV *End = RHS;
13525 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13526 // If we know that Start >= RHS in the context of loop, then we know that
13527 // min(RHS, Start) = RHS at this point.
13528 if (isLoopEntryGuardedByCond(
13529 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13530 End = RHS;
13531 else
13532 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13533 }
13534
13535 if (Start->getType()->isPointerTy()) {
13536 Start = getLosslessPtrToIntExpr(Start);
13537 if (isa<SCEVCouldNotCompute>(Start))
13538 return Start;
13539 }
13540 if (End->getType()->isPointerTy()) {
13541 End = getLosslessPtrToIntExpr(End);
13542 if (isa<SCEVCouldNotCompute>(End))
13543 return End;
13544 }
13545
13546 // Compute ((Start - End) + (Stride - 1)) / Stride.
13547 // FIXME: This can overflow. Holding off on fixing this for now;
13548 // howManyGreaterThans will hopefully be gone soon.
13549 const SCEV *One = getOne(Stride->getType());
13550 const SCEV *BECount = getUDivExpr(
13551 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13552
13553 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13554 : getUnsignedRangeMax(Start);
13555
13556 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13557 : getUnsignedRangeMin(Stride);
13558
13559 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13560 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13561 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13562
13563 // Although End can be a MIN expression we estimate MinEnd considering only
13564 // the case End = RHS. This is safe because in the other case (Start - End)
13565 // is zero, leading to a zero maximum backedge taken count.
13566 APInt MinEnd =
13567 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13568 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13569
13570 const SCEV *ConstantMaxBECount =
13571 isa<SCEVConstant>(BECount)
13572 ? BECount
13573 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13574 getConstant(MinStride));
13575
13576 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13577 ConstantMaxBECount = BECount;
13578 const SCEV *SymbolicMaxBECount =
13579 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13580
13581 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13582 Predicates);
13583}
13584
13586 ScalarEvolution &SE) const {
13587 if (Range.isFullSet()) // Infinite loop.
13588 return SE.getCouldNotCompute();
13589
13590 // If the start is a non-zero constant, shift the range to simplify things.
13591 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13592 if (!SC->getValue()->isZero()) {
13594 Operands[0] = SE.getZero(SC->getType());
13595 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13597 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13598 return ShiftedAddRec->getNumIterationsInRange(
13599 Range.subtract(SC->getAPInt()), SE);
13600 // This is strange and shouldn't happen.
13601 return SE.getCouldNotCompute();
13602 }
13603
13604 // The only time we can solve this is when we have all constant indices.
13605 // Otherwise, we cannot determine the overflow conditions.
13606 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13607 return SE.getCouldNotCompute();
13608
13609 // Okay at this point we know that all elements of the chrec are constants and
13610 // that the start element is zero.
13611
13612 // First check to see if the range contains zero. If not, the first
13613 // iteration exits.
13614 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13615 if (!Range.contains(APInt(BitWidth, 0)))
13616 return SE.getZero(getType());
13617
13618 if (isAffine()) {
13619 // If this is an affine expression then we have this situation:
13620 // Solve {0,+,A} in Range === Ax in Range
13621
13622 // We know that zero is in the range. If A is positive then we know that
13623 // the upper value of the range must be the first possible exit value.
13624 // If A is negative then the lower of the range is the last possible loop
13625 // value. Also note that we already checked for a full range.
13626 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13627 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13628
13629 // The exit value should be (End+A)/A.
13630 APInt ExitVal = (End + A).udiv(A);
13631 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13632
13633 // Evaluate at the exit value. If we really did fall out of the valid
13634 // range, then we computed our trip count, otherwise wrap around or other
13635 // things must have happened.
13636 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13637 if (Range.contains(Val->getValue()))
13638 return SE.getCouldNotCompute(); // Something strange happened
13639
13640 // Ensure that the previous value is in the range.
13641 assert(Range.contains(
13643 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13644 "Linear scev computation is off in a bad way!");
13645 return SE.getConstant(ExitValue);
13646 }
13647
13648 if (isQuadratic()) {
13649 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13650 return SE.getConstant(*S);
13651 }
13652
13653 return SE.getCouldNotCompute();
13654}
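
// Illustrative sketch (not part of the original source): the affine case
// above, replayed with concrete numbers. For {0,+,A} with A = 3 and the
// half-open range [0, 10), the recurrence takes the values 0, 3, 6, 9, 12,
// ... and first leaves the range at iteration (End + A) / A with
// End = 10 - 1 = 9, i.e. iteration 4. Names and values are hypothetical.
static unsigned firstIterationOutsideRangeExample() {
  unsigned A = 3, Upper = 10;
  unsigned End = Upper - 1; // largest value still inside [0, 10)
  return (End + A) / A;     // (9 + 3) / 3 == 4; value 4 * 3 == 12 is outside
}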
13655
13656const SCEVAddRecExpr *
13658 assert(getNumOperands() > 1 && "AddRec with zero step?");
13659 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13660 // but in this case we cannot guarantee that the value returned will be an
13661 // AddRec because SCEV does not have a fixed point where it stops
13662 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13663 // may happen if we reach arithmetic depth limit while simplifying. So we
13664 // construct the returned value explicitly.
13666 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13667 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13668 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13669 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13670 // We know that the last operand is not a constant zero (otherwise it would
13671 // have been popped out earlier). This guarantees us that if the result has
13672 // the same last operand, then it will also not be popped out, meaning that
13673 // the returned value will be an AddRec.
13674 const SCEV *Last = getOperand(getNumOperands() - 1);
13675 assert(!Last->isZero() && "Recurrency with zero step?");
13676 Ops.push_back(Last);
13679}
13680
13681 // Return true when S contains at least one undef value.
13683 return SCEVExprContains(S, [](const SCEV *S) {
13684 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13685 return isa<UndefValue>(SU->getValue());
13686 return false;
13687 });
13688}
13689
13690// Return true when S contains a value that is a nullptr.
13692 return SCEVExprContains(S, [](const SCEV *S) {
13693 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13694 return SU->getValue() == nullptr;
13695 return false;
13696 });
13697}
13698
13699/// Return the size of an element read or written by Inst.
13701 Type *Ty;
13702 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13703 Ty = Store->getValueOperand()->getType();
13704 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13705 Ty = Load->getType();
13706 else
13707 return nullptr;
13708
13710 return getSizeOfExpr(ETy, Ty);
13711}
13712
13713//===----------------------------------------------------------------------===//
13714// SCEVCallbackVH Class Implementation
13715//===----------------------------------------------------------------------===//
13716
13718 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13719 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13720 SE->ConstantEvolutionLoopExitValue.erase(PN);
13721 SE->eraseValueFromMap(getValPtr());
13722 // this now dangles!
13723}
13724
13725void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13726 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13727
13728 // Forget all the expressions associated with users of the old value,
13729 // so that future queries will recompute the expressions using the new
13730 // value.
13731 SE->forgetValue(getValPtr());
13732 // this now dangles!
13733}
13734
13735ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13736 : CallbackVH(V), SE(se) {}
13737
13738//===----------------------------------------------------------------------===//
13739// ScalarEvolution Class Implementation
13740//===----------------------------------------------------------------------===//
13741
13744 LoopInfo &LI)
13745 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13746 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13747 LoopDispositions(64), BlockDispositions(64) {
13748 // To use guards for proving predicates, we need to scan every instruction in
13749 // relevant basic blocks, and not just terminators. Doing this is a waste of
13750 // time if the IR does not actually contain any calls to
13751 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13752 //
13753 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13754 // to _add_ guards to the module when there weren't any before, and wants
13755 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13756 // efficient in lieu of being smart in that rather obscure case.
13757
13758 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13759 F.getParent(), Intrinsic::experimental_guard);
13760 HasGuards = GuardDecl && !GuardDecl->use_empty();
13761}
13762
13764 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13765 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13766 ValueExprMap(std::move(Arg.ValueExprMap)),
13767 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13768 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13769 PendingMerges(std::move(Arg.PendingMerges)),
13770 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13771 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13772 PredicatedBackedgeTakenCounts(
13773 std::move(Arg.PredicatedBackedgeTakenCounts)),
13774 BECountUsers(std::move(Arg.BECountUsers)),
13775 ConstantEvolutionLoopExitValue(
13776 std::move(Arg.ConstantEvolutionLoopExitValue)),
13777 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13778 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13779 LoopDispositions(std::move(Arg.LoopDispositions)),
13780 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13781 BlockDispositions(std::move(Arg.BlockDispositions)),
13782 SCEVUsers(std::move(Arg.SCEVUsers)),
13783 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13784 SignedRanges(std::move(Arg.SignedRanges)),
13785 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13786 UniquePreds(std::move(Arg.UniquePreds)),
13787 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13788 LoopUsers(std::move(Arg.LoopUsers)),
13789 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13790 FirstUnknown(Arg.FirstUnknown) {
13791 Arg.FirstUnknown = nullptr;
13792}
13793
13795 // Iterate through all the SCEVUnknown instances and call their
13796 // destructors, so that they release their references to their values.
13797 for (SCEVUnknown *U = FirstUnknown; U;) {
13798 SCEVUnknown *Tmp = U;
13799 U = U->Next;
13800 Tmp->~SCEVUnknown();
13801 }
13802 FirstUnknown = nullptr;
13803
13804 ExprValueMap.clear();
13805 ValueExprMap.clear();
13806 HasRecMap.clear();
13807 BackedgeTakenCounts.clear();
13808 PredicatedBackedgeTakenCounts.clear();
13809
13810 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13811 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13812 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13813 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13814 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13815}
13816
13820
13821/// When printing a top-level SCEV for trip counts, it's helpful to include
13822/// a type for constants which are otherwise hard to disambiguate.
13823static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13824 if (isa<SCEVConstant>(S))
13825 OS << *S->getType() << " ";
13826 OS << *S;
13827}
13828
13830 const Loop *L) {
13831 // Print all inner loops first
13832 for (Loop *I : *L)
13833 PrintLoopInfo(OS, SE, I);
13834
13835 OS << "Loop ";
13836 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13837 OS << ": ";
13838
13839 SmallVector<BasicBlock *, 8> ExitingBlocks;
13840 L->getExitingBlocks(ExitingBlocks);
13841 if (ExitingBlocks.size() != 1)
13842 OS << "<multiple exits> ";
13843
13844 auto *BTC = SE->getBackedgeTakenCount(L);
13845 if (!isa<SCEVCouldNotCompute>(BTC)) {
13846 OS << "backedge-taken count is ";
13847 PrintSCEVWithTypeHint(OS, BTC);
13848 } else
13849 OS << "Unpredictable backedge-taken count.";
13850 OS << "\n";
13851
13852 if (ExitingBlocks.size() > 1)
13853 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13854 OS << " exit count for " << ExitingBlock->getName() << ": ";
13855 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13856 PrintSCEVWithTypeHint(OS, EC);
13857 if (isa<SCEVCouldNotCompute>(EC)) {
13858 // Retry with predicates.
13860 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13861 if (!isa<SCEVCouldNotCompute>(EC)) {
13862 OS << "\n predicated exit count for " << ExitingBlock->getName()
13863 << ": ";
13864 PrintSCEVWithTypeHint(OS, EC);
13865 OS << "\n Predicates:\n";
13866 for (const auto *P : Predicates)
13867 P->print(OS, 4);
13868 }
13869 }
13870 OS << "\n";
13871 }
13872
13873 OS << "Loop ";
13874 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13875 OS << ": ";
13876
13877 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13878 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13879 OS << "constant max backedge-taken count is ";
13880 PrintSCEVWithTypeHint(OS, ConstantBTC);
13882 OS << ", actual taken count either this or zero.";
13883 } else {
13884 OS << "Unpredictable constant max backedge-taken count. ";
13885 }
13886
13887 OS << "\n"
13888 "Loop ";
13889 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13890 OS << ": ";
13891
13892 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13893 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13894 OS << "symbolic max backedge-taken count is ";
13895 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13897 OS << ", actual taken count either this or zero.";
13898 } else {
13899 OS << "Unpredictable symbolic max backedge-taken count. ";
13900 }
13901 OS << "\n";
13902
13903 if (ExitingBlocks.size() > 1)
13904 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13905 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13906 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13908 PrintSCEVWithTypeHint(OS, ExitBTC);
13909 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13910 // Retry with predicates.
13912 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13914 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13915 OS << "\n predicated symbolic max exit count for "
13916 << ExitingBlock->getName() << ": ";
13917 PrintSCEVWithTypeHint(OS, ExitBTC);
13918 OS << "\n Predicates:\n";
13919 for (const auto *P : Predicates)
13920 P->print(OS, 4);
13921 }
13922 }
13923 OS << "\n";
13924 }
13925
13927 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13928 if (PBT != BTC) {
13929 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13930 OS << "Loop ";
13931 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13932 OS << ": ";
13933 if (!isa<SCEVCouldNotCompute>(PBT)) {
13934 OS << "Predicated backedge-taken count is ";
13935 PrintSCEVWithTypeHint(OS, PBT);
13936 } else
13937 OS << "Unpredictable predicated backedge-taken count.";
13938 OS << "\n";
13939 OS << " Predicates:\n";
13940 for (const auto *P : Preds)
13941 P->print(OS, 4);
13942 }
13943 Preds.clear();
13944
13945 auto *PredConstantMax =
13947 if (PredConstantMax != ConstantBTC) {
13948 assert(!Preds.empty() &&
13949 "different predicated constant max BTC but no predicates");
13950 OS << "Loop ";
13951 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13952 OS << ": ";
13953 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13954 OS << "Predicated constant max backedge-taken count is ";
13955 PrintSCEVWithTypeHint(OS, PredConstantMax);
13956 } else
13957 OS << "Unpredictable predicated constant max backedge-taken count.";
13958 OS << "\n";
13959 OS << " Predicates:\n";
13960 for (const auto *P : Preds)
13961 P->print(OS, 4);
13962 }
13963 Preds.clear();
13964
13965 auto *PredSymbolicMax =
13966 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13967 if (SymbolicBTC != PredSymbolicMax) {
13968 assert(!Preds.empty() &&
13969 "Different predicated symbolic max BTC, but no predicates");
13970 OS << "Loop ";
13971 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13972 OS << ": ";
13973 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13974 OS << "Predicated symbolic max backedge-taken count is ";
13975 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13976 } else
13977 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13978 OS << "\n";
13979 OS << " Predicates:\n";
13980 for (const auto *P : Preds)
13981 P->print(OS, 4);
13982 }
13983
13984 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13985 OS << "Loop ";
13986 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13987 OS << ": ";
13988 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13989 }
13990}
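// For illustration, the loop printing above emits lines of roughly this shape
// (the exact SCEVs depend on the analyzed IR):
//   backedge-taken count is (-1 + %n)
//     exit count for loop.exiting: (-1 + %n)
//   Loop %header: constant max backedge-taken count is 2147483646
//   Loop %header: symbolic max backedge-taken count is (-1 + %n)
//   Loop %header: Trip multiple is 1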
13991
13992namespace llvm {
13993raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13994 switch (LD) {
13995 case ScalarEvolution::LoopVariant:
13996 OS << "Variant";
13997 break;
13998 case ScalarEvolution::LoopInvariant:
13999 OS << "Invariant";
14000 break;
14001 case ScalarEvolution::LoopComputable:
14002 OS << "Computable";
14003 break;
14004 }
14005 return OS;
14006}
14007
14008raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
14009 switch (BD) {
14010 case ScalarEvolution::DoesNotDominateBlock:
14011 OS << "DoesNotDominate";
14012 break;
14013 case ScalarEvolution::DominatesBlock:
14014 OS << "Dominates";
14015 break;
14016 case ScalarEvolution::ProperlyDominatesBlock:
14017 OS << "ProperlyDominates";
14018 break;
14019 }
14020 return OS;
14021}
14022} // namespace llvm
14023
14024void ScalarEvolution::print(raw_ostream &OS) const {
14025 // ScalarEvolution's implementation of the print method is to print
14026 // out SCEV values of all instructions that are interesting. Doing
14027 // this potentially causes it to create new SCEV objects though,
14028 // which technically conflicts with the const qualifier. This isn't
14029 // observable from outside the class though, so casting away the
14030 // const isn't dangerous.
14031 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14032
14033 if (ClassifyExpressions) {
14034 OS << "Classifying expressions for: ";
14035 F.printAsOperand(OS, /*PrintType=*/false);
14036 OS << "\n";
14037 for (Instruction &I : instructions(F))
14038 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14039 OS << I << '\n';
14040 OS << " --> ";
14041 const SCEV *SV = SE.getSCEV(&I);
14042 SV->print(OS);
14043 if (!isa<SCEVCouldNotCompute>(SV)) {
14044 OS << " U: ";
14045 SE.getUnsignedRange(SV).print(OS);
14046 OS << " S: ";
14047 SE.getSignedRange(SV).print(OS);
14048 }
14049
14050 const Loop *L = LI.getLoopFor(I.getParent());
14051
14052 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14053 if (AtUse != SV) {
14054 OS << " --> ";
14055 AtUse->print(OS);
14056 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14057 OS << " U: ";
14058 SE.getUnsignedRange(AtUse).print(OS);
14059 OS << " S: ";
14060 SE.getSignedRange(AtUse).print(OS);
14061 }
14062 }
14063
14064 if (L) {
14065 OS << "\t\t" "Exits: ";
14066 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14067 if (!SE.isLoopInvariant(ExitValue, L)) {
14068 OS << "<<Unknown>>";
14069 } else {
14070 OS << *ExitValue;
14071 }
14072
14073 bool First = true;
14074 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14075 if (First) {
14076 OS << "\t\t" "LoopDispositions: { ";
14077 First = false;
14078 } else {
14079 OS << ", ";
14080 }
14081
14082 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14083 OS << ": " << SE.getLoopDisposition(SV, Iter);
14084 }
14085
14086 for (const auto *InnerL : depth_first(L)) {
14087 if (InnerL == L)
14088 continue;
14089 if (First) {
14090 OS << "\t\t" "LoopDispositions: { ";
14091 First = false;
14092 } else {
14093 OS << ", ";
14094 }
14095
14096 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14097 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14098 }
14099
14100 OS << " }";
14101 }
14102
14103 OS << "\n";
14104 }
14105 }
14106
14107 OS << "Determining loop execution counts for: ";
14108 F.printAsOperand(OS, /*PrintType=*/false);
14109 OS << "\n";
14110 for (Loop *I : LI)
14111 PrintLoopInfo(OS, &SE, I);
14112}
14113
14114ScalarEvolution::LoopDisposition
14115ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14116 auto &Values = LoopDispositions[S];
14117 for (auto &V : Values) {
14118 if (V.getPointer() == L)
14119 return V.getInt();
14120 }
14121 Values.emplace_back(L, LoopVariant);
14122 LoopDisposition D = computeLoopDisposition(S, L);
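 // The recursive computeLoopDisposition call above may add entries to
 // LoopDispositions and invalidate the reference obtained before it, so the
 // bucket for S is looked up again before the cached entry is updated.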
14123 auto &Values2 = LoopDispositions[S];
14124 for (auto &V : llvm::reverse(Values2)) {
14125 if (V.getPointer() == L) {
14126 V.setInt(D);
14127 break;
14128 }
14129 }
14130 return D;
14131}
14132
14133ScalarEvolution::LoopDisposition
14134ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14135 switch (S->getSCEVType()) {
14136 case scConstant:
14137 case scVScale:
14138 return LoopInvariant;
14139 case scAddRecExpr: {
14140 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14141
14142 // If L is the addrec's loop, it's computable.
14143 if (AR->getLoop() == L)
14144 return LoopComputable;
14145
14146 // Add recurrences are never invariant in the function-body (null loop).
14147 if (!L)
14148 return LoopVariant;
14149
14150 // Everything that is not defined at loop entry is variant.
14151 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14152 return LoopVariant;
14153 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14154 " dominate the contained loop's header?");
14155
14156 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14157 if (AR->getLoop()->contains(L))
14158 return LoopInvariant;
14159
14160 // This recurrence is variant w.r.t. L if any of its operands
14161 // are variant.
14162 for (const auto *Op : AR->operands())
14163 if (!isLoopInvariant(Op, L))
14164 return LoopVariant;
14165
14166 // Otherwise it's loop-invariant.
14167 return LoopInvariant;
14168 }
14169 case scTruncate:
14170 case scZeroExtend:
14171 case scSignExtend:
14172 case scPtrToInt:
14173 case scAddExpr:
14174 case scMulExpr:
14175 case scUDivExpr:
14176 case scUMaxExpr:
14177 case scSMaxExpr:
14178 case scUMinExpr:
14179 case scSMinExpr:
14180 case scSequentialUMinExpr: {
14181 bool HasVarying = false;
14182 for (const auto *Op : S->operands()) {
14183 LoopDisposition D = getLoopDisposition(Op, L);
14184 if (D == LoopVariant)
14185 return LoopVariant;
14186 if (D == LoopComputable)
14187 HasVarying = true;
14188 }
14189 return HasVarying ? LoopComputable : LoopInvariant;
14190 }
14191 case scUnknown:
14192 // All non-instruction values are loop invariant. All instructions are loop
14193 // invariant if they are not contained in the specified loop.
14194 // Instructions are never considered invariant in the function body
14195 // (null loop) because they are defined within the "loop".
14196 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14197 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14198 return LoopInvariant;
14199 case scCouldNotCompute:
14200 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14201 }
14202 llvm_unreachable("Unknown SCEV kind!");
14203}
14204
14206 return getLoopDisposition(S, L) == LoopInvariant;
14207}
14208
14210 return getLoopDisposition(S, L) == LoopComputable;
14211}
14212
14213ScalarEvolution::BlockDisposition
14214ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14215 auto &Values = BlockDispositions[S];
14216 for (auto &V : Values) {
14217 if (V.getPointer() == BB)
14218 return V.getInt();
14219 }
14220 Values.emplace_back(BB, DoesNotDominateBlock);
14221 BlockDisposition D = computeBlockDisposition(S, BB);
14222 auto &Values2 = BlockDispositions[S];
14223 for (auto &V : llvm::reverse(Values2)) {
14224 if (V.getPointer() == BB) {
14225 V.setInt(D);
14226 break;
14227 }
14228 }
14229 return D;
14230}
14231
14232ScalarEvolution::BlockDisposition
14233ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14234 switch (S->getSCEVType()) {
14235 case scConstant:
14236 case scVScale:
14237 return ProperlyDominatesBlock;
14238 case scAddRecExpr: {
14239 // This uses a "dominates" query instead of "properly dominates" query
14240 // to test for proper dominance too, because the instruction which
14241 // produces the addrec's value is a PHI, and a PHI effectively properly
14242 // dominates its entire containing block.
14243 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14244 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14245 return DoesNotDominateBlock;
14246
14247 // Fall through into SCEVNAryExpr handling.
14248 [[fallthrough]];
14249 }
14250 case scTruncate:
14251 case scZeroExtend:
14252 case scSignExtend:
14253 case scPtrToInt:
14254 case scAddExpr:
14255 case scMulExpr:
14256 case scUDivExpr:
14257 case scUMaxExpr:
14258 case scSMaxExpr:
14259 case scUMinExpr:
14260 case scSMinExpr:
14261 case scSequentialUMinExpr: {
14262 bool Proper = true;
14263 for (const SCEV *NAryOp : S->operands()) {
14264 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14265 if (D == DoesNotDominateBlock)
14266 return DoesNotDominateBlock;
14267 if (D == DominatesBlock)
14268 Proper = false;
14269 }
14270 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14271 }
14272 case scUnknown:
14273 if (Instruction *I =
14274 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14275 if (I->getParent() == BB)
14276 return DominatesBlock;
14277 if (DT.properlyDominates(I->getParent(), BB))
14278 return ProperlyDominatesBlock;
14279 return DoesNotDominateBlock;
14280 }
14281 return ProperlyDominatesBlock;
14282 case scCouldNotCompute:
14283 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14284 }
14285 llvm_unreachable("Unknown SCEV kind!");
14286}
14287
14288bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14289 return getBlockDisposition(S, BB) >= DominatesBlock;
14290}
14291
14292bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14293 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14294}
14295
14296bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14297 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14298}
14299
14300void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14301 bool Predicated) {
14302 auto &BECounts =
14303 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14304 auto It = BECounts.find(L);
14305 if (It != BECounts.end()) {
14306 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14307 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14308 if (!isa<SCEVConstant>(S)) {
14309 auto UserIt = BECountUsers.find(S);
14310 assert(UserIt != BECountUsers.end());
14311 UserIt->second.erase({L, Predicated});
14312 }
14313 }
14314 }
14315 BECounts.erase(It);
14316 }
14317}
14318
14319void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14320 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14321 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14322
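 // Starting from the SCEVs to forget, walk the SCEVUsers graph below so that
 // every expression transitively built on top of a forgotten SCEV is
 // forgotten as well.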
14323 while (!Worklist.empty()) {
14324 const SCEV *Curr = Worklist.pop_back_val();
14325 auto Users = SCEVUsers.find(Curr);
14326 if (Users != SCEVUsers.end())
14327 for (const auto *User : Users->second)
14328 if (ToForget.insert(User).second)
14329 Worklist.push_back(User);
14330 }
14331
14332 for (const auto *S : ToForget)
14333 forgetMemoizedResultsImpl(S);
14334
14335 for (auto I = PredicatedSCEVRewrites.begin();
14336 I != PredicatedSCEVRewrites.end();) {
14337 std::pair<const SCEV *, const Loop *> Entry = I->first;
14338 if (ToForget.count(Entry.first))
14339 PredicatedSCEVRewrites.erase(I++);
14340 else
14341 ++I;
14342 }
14343}
14344
14345void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14346 LoopDispositions.erase(S);
14347 BlockDispositions.erase(S);
14348 UnsignedRanges.erase(S);
14349 SignedRanges.erase(S);
14350 HasRecMap.erase(S);
14351 ConstantMultipleCache.erase(S);
14352
14353 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14354 UnsignedWrapViaInductionTried.erase(AR);
14355 SignedWrapViaInductionTried.erase(AR);
14356 }
14357
14358 auto ExprIt = ExprValueMap.find(S);
14359 if (ExprIt != ExprValueMap.end()) {
14360 for (Value *V : ExprIt->second) {
14361 auto ValueIt = ValueExprMap.find_as(V);
14362 if (ValueIt != ValueExprMap.end())
14363 ValueExprMap.erase(ValueIt);
14364 }
14365 ExprValueMap.erase(ExprIt);
14366 }
14367
14368 auto ScopeIt = ValuesAtScopes.find(S);
14369 if (ScopeIt != ValuesAtScopes.end()) {
14370 for (const auto &Pair : ScopeIt->second)
14371 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14372 llvm::erase(ValuesAtScopesUsers[Pair.second],
14373 std::make_pair(Pair.first, S));
14374 ValuesAtScopes.erase(ScopeIt);
14375 }
14376
14377 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14378 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14379 for (const auto &Pair : ScopeUserIt->second)
14380 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14381 ValuesAtScopesUsers.erase(ScopeUserIt);
14382 }
14383
14384 auto BEUsersIt = BECountUsers.find(S);
14385 if (BEUsersIt != BECountUsers.end()) {
14386 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14387 auto Copy = BEUsersIt->second;
14388 for (const auto &Pair : Copy)
14389 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14390 BECountUsers.erase(BEUsersIt);
14391 }
14392
14393 auto FoldUser = FoldCacheUser.find(S);
14394 if (FoldUser != FoldCacheUser.end())
14395 for (auto &KV : FoldUser->second)
14396 FoldCache.erase(KV);
14397 FoldCacheUser.erase(S);
14398}
14399
14400void
14401ScalarEvolution::getUsedLoops(const SCEV *S,
14402 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14403 struct FindUsedLoops {
14404 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14405 : LoopsUsed(LoopsUsed) {}
14406 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14407 bool follow(const SCEV *S) {
14408 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14409 LoopsUsed.insert(AR->getLoop());
14410 return true;
14411 }
14412
14413 bool isDone() const { return false; }
14414 };
14415
14416 FindUsedLoops F(LoopsUsed);
14417 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14418}
14419
14420void ScalarEvolution::getReachableBlocks(
14421 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14422 SmallVector<BasicBlock *> Worklist;
14423 Worklist.push_back(&F.getEntryBlock());
14424 while (!Worklist.empty()) {
14425 BasicBlock *BB = Worklist.pop_back_val();
14426 if (!Reachable.insert(BB).second)
14427 continue;
14428
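 // For a conditional branch, follow only the successor that is statically
 // known to be taken when the condition is a constant, or a compare whose
 // outcome is implied by constant ranges; otherwise all successors are added
 // below.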
14429 Value *Cond;
14430 BasicBlock *TrueBB, *FalseBB;
14431 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14432 m_BasicBlock(FalseBB)))) {
14433 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14434 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14435 continue;
14436 }
14437
14438 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14439 const SCEV *L = getSCEV(Cmp->getOperand(0));
14440 const SCEV *R = getSCEV(Cmp->getOperand(1));
14441 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14442 Worklist.push_back(TrueBB);
14443 continue;
14444 }
14445 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14446 R)) {
14447 Worklist.push_back(FalseBB);
14448 continue;
14449 }
14450 }
14451 }
14452
14453 append_range(Worklist, successors(BB));
14454 }
14455}
14456
14457void ScalarEvolution::verify() const {
14458 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14459 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14460
14461 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14462
14463 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14464 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14465 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14466
14467 const SCEV *visitConstant(const SCEVConstant *Constant) {
14468 return SE.getConstant(Constant->getAPInt());
14469 }
14470
14471 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14472 return SE.getUnknown(Expr->getValue());
14473 }
14474
14475 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14476 return SE.getCouldNotCompute();
14477 }
14478 };
14479
14480 SCEVMapper SCM(SE2);
14481 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14482 SE2.getReachableBlocks(ReachableBlocks, F);
14483
14484 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14485 if (containsUndefs(Old) || containsUndefs(New)) {
14486 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14487 // not propagate undef aggressively). This means we can (and do) fail
14488 // verification in cases where a transform makes a value go from "undef"
14489 // to "undef+1" (say). The transform is fine, since in both cases the
14490 // result is "undef", but SCEV thinks the value increased by 1.
14491 return nullptr;
14492 }
14493
14494 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14495 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14496 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14497 return nullptr;
14498
14499 return Delta;
14500 };
14501
14502 while (!LoopStack.empty()) {
14503 auto *L = LoopStack.pop_back_val();
14504 llvm::append_range(LoopStack, *L);
14505
14506 // Only verify BECounts in reachable loops. For an unreachable loop,
14507 // any BECount is legal.
14508 if (!ReachableBlocks.contains(L->getHeader()))
14509 continue;
14510
14511 // Only verify cached BECounts. Computing new BECounts may change the
14512 // results of subsequent SCEV uses.
14513 auto It = BackedgeTakenCounts.find(L);
14514 if (It == BackedgeTakenCounts.end())
14515 continue;
14516
14517 auto *CurBECount =
14518 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14519 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14520
14521 if (CurBECount == SE2.getCouldNotCompute() ||
14522 NewBECount == SE2.getCouldNotCompute()) {
14523 // NB! This situation is legal, but is very suspicious -- whatever pass
14524 // changed the loop to make a trip count go from could not compute to
14525 // computable or vice-versa *should have* invalidated SCEV. However, we
14526 // choose not to assert here (for now) since we don't want false
14527 // positives.
14528 continue;
14529 }
14530
14531 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14532 SE.getTypeSizeInBits(NewBECount->getType()))
14533 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14534 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14535 SE.getTypeSizeInBits(NewBECount->getType()))
14536 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14537
14538 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14539 if (Delta && !Delta->isZero()) {
14540 dbgs() << "Trip Count for " << *L << " Changed!\n";
14541 dbgs() << "Old: " << *CurBECount << "\n";
14542 dbgs() << "New: " << *NewBECount << "\n";
14543 dbgs() << "Delta: " << *Delta << "\n";
14544 std::abort();
14545 }
14546 }
14547
14548 // Collect all valid loops currently in LoopInfo.
14549 SmallPtrSet<Loop *, 32> ValidLoops;
14550 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14551 while (!Worklist.empty()) {
14552 Loop *L = Worklist.pop_back_val();
14553 if (ValidLoops.insert(L).second)
14554 Worklist.append(L->begin(), L->end());
14555 }
14556 for (const auto &KV : ValueExprMap) {
14557#ifndef NDEBUG
14558 // Check for SCEV expressions referencing invalid/deleted loops.
14559 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14560 assert(ValidLoops.contains(AR->getLoop()) &&
14561 "AddRec references invalid loop");
14562 }
14563#endif
14564
14565 // Check that the value is also part of the reverse map.
14566 auto It = ExprValueMap.find(KV.second);
14567 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14568 dbgs() << "Value " << *KV.first
14569 << " is in ValueExprMap but not in ExprValueMap\n";
14570 std::abort();
14571 }
14572
14573 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14574 if (!ReachableBlocks.contains(I->getParent()))
14575 continue;
14576 const SCEV *OldSCEV = SCM.visit(KV.second);
14577 const SCEV *NewSCEV = SE2.getSCEV(I);
14578 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14579 if (Delta && !Delta->isZero()) {
14580 dbgs() << "SCEV for value " << *I << " changed!\n"
14581 << "Old: " << *OldSCEV << "\n"
14582 << "New: " << *NewSCEV << "\n"
14583 << "Delta: " << *Delta << "\n";
14584 std::abort();
14585 }
14586 }
14587 }
14588
14589 for (const auto &KV : ExprValueMap) {
14590 for (Value *V : KV.second) {
14591 const SCEV *S = ValueExprMap.lookup(V);
14592 if (!S) {
14593 dbgs() << "Value " << *V
14594 << " is in ExprValueMap but not in ValueExprMap\n";
14595 std::abort();
14596 }
14597 if (S != KV.first) {
14598 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14599 << *KV.first << "\n";
14600 std::abort();
14601 }
14602 }
14603 }
14604
14605 // Verify integrity of SCEV users.
14606 for (const auto &S : UniqueSCEVs) {
14607 for (const auto *Op : S.operands()) {
14608 // We do not store dependencies of constants.
14609 if (isa<SCEVConstant>(Op))
14610 continue;
14611 auto It = SCEVUsers.find(Op);
14612 if (It != SCEVUsers.end() && It->second.count(&S))
14613 continue;
14614 dbgs() << "Use of operand " << *Op << " by user " << S
14615 << " is not being tracked!\n";
14616 std::abort();
14617 }
14618 }
14619
14620 // Verify integrity of ValuesAtScopes users.
14621 for (const auto &ValueAndVec : ValuesAtScopes) {
14622 const SCEV *Value = ValueAndVec.first;
14623 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14624 const Loop *L = LoopAndValueAtScope.first;
14625 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14626 if (!isa<SCEVConstant>(ValueAtScope)) {
14627 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14628 if (It != ValuesAtScopesUsers.end() &&
14629 is_contained(It->second, std::make_pair(L, Value)))
14630 continue;
14631 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14632 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14633 std::abort();
14634 }
14635 }
14636 }
14637
14638 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14639 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14640 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14641 const Loop *L = LoopAndValue.first;
14642 const SCEV *Value = LoopAndValue.second;
14644 auto It = ValuesAtScopes.find(Value);
14645 if (It != ValuesAtScopes.end() &&
14646 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14647 continue;
14648 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14649 << *ValueAtScope << " missing in ValuesAtScopes\n";
14650 std::abort();
14651 }
14652 }
14653
14654 // Verify integrity of BECountUsers.
14655 auto VerifyBECountUsers = [&](bool Predicated) {
14656 auto &BECounts =
14657 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14658 for (const auto &LoopAndBEInfo : BECounts) {
14659 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14660 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14661 if (!isa<SCEVConstant>(S)) {
14662 auto UserIt = BECountUsers.find(S);
14663 if (UserIt != BECountUsers.end() &&
14664 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14665 continue;
14666 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14667 << " missing from BECountUsers\n";
14668 std::abort();
14669 }
14670 }
14671 }
14672 }
14673 };
14674 VerifyBECountUsers(/* Predicated */ false);
14675 VerifyBECountUsers(/* Predicated */ true);
14676
14677 // Verify integrity of the loop disposition cache.
14678 for (auto &[S, Values] : LoopDispositions) {
14679 for (auto [Loop, CachedDisposition] : Values) {
14680 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14681 if (CachedDisposition != RecomputedDisposition) {
14682 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14683 << " is incorrect: cached " << CachedDisposition << ", actual "
14684 << RecomputedDisposition << "\n";
14685 std::abort();
14686 }
14687 }
14688 }
14689
14690 // Verify integrity of the block disposition cache.
14691 for (auto &[S, Values] : BlockDispositions) {
14692 for (auto [BB, CachedDisposition] : Values) {
14693 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14694 if (CachedDisposition != RecomputedDisposition) {
14695 dbgs() << "Cached disposition of " << *S << " for block %"
14696 << BB->getName() << " is incorrect: cached " << CachedDisposition
14697 << ", actual " << RecomputedDisposition << "\n";
14698 std::abort();
14699 }
14700 }
14701 }
14702
14703 // Verify FoldCache/FoldCacheUser caches.
14704 for (auto [FoldID, Expr] : FoldCache) {
14705 auto I = FoldCacheUser.find(Expr);
14706 if (I == FoldCacheUser.end()) {
14707 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14708 << "!\n";
14709 std::abort();
14710 }
14711 if (!is_contained(I->second, FoldID)) {
14712 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14713 std::abort();
14714 }
14715 }
14716 for (auto [Expr, IDs] : FoldCacheUser) {
14717 for (auto &FoldID : IDs) {
14718 const SCEV *S = FoldCache.lookup(FoldID);
14719 if (!S) {
14720 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14721 << "!\n";
14722 std::abort();
14723 }
14724 if (S != Expr) {
14725 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14726 << " != " << *Expr << "!\n";
14727 std::abort();
14728 }
14729 }
14730 }
14731
14732 // Verify that ConstantMultipleCache computations are correct. We check that
14733 // cached multiples and recomputed multiples are multiples of each other to
14734 // verify correctness. It is possible that a recomputed multiple is different
14735 // from the cached multiple due to strengthened no wrap flags or changes in
14736 // KnownBits computations.
14737 for (auto [S, Multiple] : ConstantMultipleCache) {
14738 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14739 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14740 Multiple.urem(RecomputedMultiple) != 0 &&
14741 RecomputedMultiple.urem(Multiple) != 0)) {
14742 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14743 << *S << " : Computed " << RecomputedMultiple
14744 << " but cache contains " << Multiple << "!\n";
14745 std::abort();
14746 }
14747 }
14748}
14749
14750bool ScalarEvolution::invalidate(
14751 Function &F, const PreservedAnalyses &PA,
14752 FunctionAnalysisManager::Invalidator &Inv) {
14753 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14754 // of its dependencies is invalidated.
14755 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14756 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14757 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14758 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14759 Inv.invalidate<LoopAnalysis>(F, PA);
14760}
14761
14762AnalysisKey ScalarEvolutionAnalysis::Key;
14763
14764ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14765 FunctionAnalysisManager &AM) {
14766 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14767 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14768 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14769 auto &LI = AM.getResult<LoopAnalysis>(F);
14770 return ScalarEvolution(F, TLI, AC, DT, LI);
14771}
14772
14778
14779PreservedAnalyses
14780ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14781 // For compatibility with opt's -analyze feature under legacy pass manager
14782 // which was not ported to NPM. This keeps tests using
14783 // update_analyze_test_checks.py working.
14784 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14785 << F.getName() << "':\n";
14786 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14787 return PreservedAnalyses::all();
14788}
14789
14791 "Scalar Evolution Analysis", false, true)
14797 "Scalar Evolution Analysis", false, true)
14798
14800
14802
14804 SE.reset(new ScalarEvolution(
14805 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14806 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14807 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14808 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14809 return false;
14810}
14811
14812void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14813
14814void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14815 SE->print(OS);
14816}
14817
14818void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14819 if (!VerifySCEV)
14820 return;
14821
14822 SE->verify();
14823}
14824
14832
14833const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14834 const SCEV *RHS) {
14835 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14836}
14837
14838const SCEVPredicate *
14839ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14840 const SCEV *LHS, const SCEV *RHS) {
14841 FoldingSetNodeID ID;
14842 assert(LHS->getType() == RHS->getType() &&
14843 "Type mismatch between LHS and RHS");
14844 // Unique this node based on the arguments
14845 ID.AddInteger(SCEVPredicate::P_Compare);
14846 ID.AddInteger(Pred);
14847 ID.AddPointer(LHS);
14848 ID.AddPointer(RHS);
14849 void *IP = nullptr;
14850 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14851 return S;
14852 SCEVComparePredicate *Eq = new (SCEVAllocator)
14853 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14854 UniquePreds.InsertNode(Eq, IP);
14855 return Eq;
14856}
14857
14858const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14859 const SCEVAddRecExpr *AR,
14860 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14861 FoldingSetNodeID ID;
14862 // Unique this node based on the arguments
14863 ID.AddInteger(SCEVPredicate::P_Wrap);
14864 ID.AddPointer(AR);
14865 ID.AddInteger(AddedFlags);
14866 void *IP = nullptr;
14867 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14868 return S;
14869 auto *OF = new (SCEVAllocator)
14870 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14871 UniquePreds.InsertNode(OF, IP);
14872 return OF;
14873}
14874
14875namespace {
14876
14877class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14878public:
14879
14880 /// Rewrites \p S in the context of a loop L and the SCEV predication
14881 /// infrastructure.
14882 ///
14883 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14884 /// equivalences present in \p Pred.
14885 ///
14886 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14887 /// \p NewPreds such that the result will be an AddRecExpr.
14888 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14889 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14890 const SCEVPredicate *Pred) {
14891 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14892 return Rewriter.visit(S);
14893 }
14894
14895 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14896 if (Pred) {
14897 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14898 for (const auto *Pred : U->getPredicates())
14899 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14900 if (IPred->getLHS() == Expr &&
14901 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14902 return IPred->getRHS();
14903 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14904 if (IPred->getLHS() == Expr &&
14905 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14906 return IPred->getRHS();
14907 }
14908 }
14909 return convertToAddRecWithPreds(Expr);
14910 }
14911
14912 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14913 const SCEV *Operand = visit(Expr->getOperand());
14914 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14915 if (AR && AR->getLoop() == L && AR->isAffine()) {
14916 // This couldn't be folded because the operand didn't have the nuw
14917 // flag. Add the nusw flag as an assumption that we could make.
14918 const SCEV *Step = AR->getStepRecurrence(SE);
14919 Type *Ty = Expr->getType();
14920 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14921 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14922 SE.getSignExtendExpr(Step, Ty), L,
14923 AR->getNoWrapFlags());
14924 }
14925 return SE.getZeroExtendExpr(Operand, Expr->getType());
14926 }
14927
14928 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14929 const SCEV *Operand = visit(Expr->getOperand());
14930 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14931 if (AR && AR->getLoop() == L && AR->isAffine()) {
14932 // This couldn't be folded because the operand didn't have the nsw
14933 // flag. Add the nssw flag as an assumption that we could make.
14934 const SCEV *Step = AR->getStepRecurrence(SE);
14935 Type *Ty = Expr->getType();
14936 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14937 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14938 SE.getSignExtendExpr(Step, Ty), L,
14939 AR->getNoWrapFlags());
14940 }
14941 return SE.getSignExtendExpr(Operand, Expr->getType());
14942 }
14943
14944private:
14945 explicit SCEVPredicateRewriter(
14946 const Loop *L, ScalarEvolution &SE,
14947 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14948 const SCEVPredicate *Pred)
14949 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14950
14951 bool addOverflowAssumption(const SCEVPredicate *P) {
14952 if (!NewPreds) {
14953 // Check if we've already made this assumption.
14954 return Pred && Pred->implies(P, SE);
14955 }
14956 NewPreds->push_back(P);
14957 return true;
14958 }
14959
14960 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14961 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14962 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14963 return addOverflowAssumption(A);
14964 }
14965
14966 // If \p Expr represents a PHINode, we try to see if it can be represented
14967 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14968 // to add this predicate as a runtime overflow check, we return the AddRec.
14969 // If \p Expr does not meet these conditions (is not a PHI node, or we
14970 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14971 // return \p Expr.
14972 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14973 if (!isa<PHINode>(Expr->getValue()))
14974 return Expr;
14975 std::optional<
14976 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14977 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14978 if (!PredicatedRewrite)
14979 return Expr;
14980 for (const auto *P : PredicatedRewrite->second){
14981 // Wrap predicates from outer loops are not supported.
14982 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14983 if (L != WP->getExpr()->getLoop())
14984 return Expr;
14985 }
14986 if (!addOverflowAssumption(P))
14987 return Expr;
14988 }
14989 return PredicatedRewrite->first;
14990 }
14991
14992 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14993 const SCEVPredicate *Pred;
14994 const Loop *L;
14995};
14996
14997} // end anonymous namespace
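// The SCEVPredicateRewriter above is driven by the two entry points that
// follow: rewriteUsingPredicate applies an already-collected predicate to an
// expression, while convertSCEVToAddRecWithPredicates gathers the predicates
// needed to turn an expression into an AddRec.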
14998
14999const SCEV *
15000ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
15001 const SCEVPredicate &Preds) {
15002 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15003}
15004
15005const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15006 const SCEV *S, const Loop *L,
15007 SmallVectorImpl<const SCEVPredicate *> &Preds) {
15008 SmallVector<const SCEVPredicate *, 4> TransformPreds;
15009 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15010 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15011
15012 if (!AddRec)
15013 return nullptr;
15014
15015 // Check if any of the transformed predicates is known to be false. In that
15016 // case, it doesn't make sense to convert to a predicated AddRec, as the
15017 // versioned loop will never execute.
15018 for (const SCEVPredicate *Pred : TransformPreds) {
15019 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15020 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15021 continue;
15022
15023 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15024 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15025 if (isa<SCEVCouldNotCompute>(ExitCount))
15026 continue;
15027
15028 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15029 if (!Step->isOne())
15030 continue;
15031
15032 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15033 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15034 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15035 return nullptr;
15036 }
15037
15038 // Since the transformation was successful, we can now transfer the SCEV
15039 // predicates.
15040 Preds.append(TransformPreds.begin(), TransformPreds.end());
15041
15042 return AddRec;
15043}
15044
15045/// SCEV predicates
15046SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15047 SCEVPredicateKind Kind)
15048 : FastID(ID), Kind(Kind) {}
15049
15050SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15051 const ICmpInst::Predicate Pred,
15052 const SCEV *LHS, const SCEV *RHS)
15053 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15054 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15055 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15056}
15057
15058bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15059 ScalarEvolution &SE) const {
15060 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15061
15062 if (!Op)
15063 return false;
15064
15065 if (Pred != ICmpInst::ICMP_EQ)
15066 return false;
15067
15068 return Op->LHS == LHS && Op->RHS == RHS;
15069}
15070
15071bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15072
15073void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15074 if (Pred == ICmpInst::ICMP_EQ)
15075 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15076 else
15077 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15078 << *RHS << "\n";
15079
15080}
15081
15082SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15083 const SCEVAddRecExpr *AR,
15084 IncrementWrapFlags Flags)
15085 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15086
15087const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15088
15089bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15090 ScalarEvolution &SE) const {
15091 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15092 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15093 return false;
15094
15095 if (Op->AR == AR)
15096 return true;
15097
15098 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15099 Flags != SCEVWrapPredicate::IncrementNUSW)
15100 return false;
15101
15102 const SCEV *Start = AR->getStart();
15103 const SCEV *OpStart = Op->AR->getStart();
15104 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15105 return false;
15106
15107 // Reject pointers to different address spaces.
15108 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15109 return false;
15110
15111 const SCEV *Step = AR->getStepRecurrence(SE);
15112 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15113 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15114 return false;
15115
15116 // If both steps are positive, this predicate implies N if N's start and
15117 // step are ULE/SLE (for NUSW/NSSW) than this predicate's start and step.
15118 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15119 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15120 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15121
15122 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15123 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15124 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15125 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15126 : SE.getNoopOrSignExtend(Start, WiderTy);
15127 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15128 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15129 SE.isKnownPredicate(Pred, OpStart, Start);
15130}
15131
15132bool SCEVWrapPredicate::isAlwaysTrue() const {
15133 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15134 IncrementWrapFlags IFlags = Flags;
15135
15136 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15137 IFlags = clearFlags(IFlags, IncrementNSSW);
15138
15139 return IFlags == IncrementAnyWrap;
15140}
15141
15142void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15143 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15144 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
15145 OS << "<nusw>";
15146 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
15147 OS << "<nssw>";
15148 OS << "\n";
15149}
15150
15151SCEVWrapPredicate::IncrementWrapFlags
15152SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15153 ScalarEvolution &SE) {
15154 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15155 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15156
15157 // We can safely transfer the NSW flag as NSSW.
15158 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15159 ImpliedFlags = IncrementNSSW;
15160
15161 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15162 // If the increment is positive, the SCEV NUW flag will also imply the
15163 // WrapPredicate NUSW flag.
15164 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15165 if (Step->getValue()->getValue().isNonNegative())
15166 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15167 }
15168
15169 return ImpliedFlags;
15170}
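// In effect, getImpliedFlags above maps <nsw> on the AddRec to IncrementNSSW,
// and <nuw> with a constant non-negative step additionally to IncrementNUSW.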
15171
15172/// Union predicates don't get cached so create a dummy set ID for it.
15173SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15174 ScalarEvolution &SE)
15175 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15176 for (const auto *P : Preds)
15177 add(P, SE);
15178}
15179
15180bool SCEVUnionPredicate::isAlwaysTrue() const {
15181 return all_of(Preds,
15182 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15183}
15184
15185bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15186 ScalarEvolution &SE) const {
15187 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15188 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15189 return this->implies(I, SE);
15190 });
15191
15192 return any_of(Preds,
15193 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15194}
15195
15196void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15197 for (const auto *Pred : Preds)
15198 Pred->print(OS, Depth);
15199}
15200
15201void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15202 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15203 for (const auto *Pred : Set->Preds)
15204 add(Pred, SE);
15205 return;
15206 }
15207
15208 // Implication checks are quadratic in the number of predicates. Stop doing
15209 // them if there are many predicates, as they would be too expensive to use
15210 // anyway at that point.
15211 bool CheckImplies = Preds.size() < 16;
15212
15213 // Only add predicate if it is not already implied by this union predicate.
15214 if (CheckImplies && implies(N, SE))
15215 return;
15216
15217 // Build a new vector containing the current predicates, except the ones that
15218 // are implied by the new predicate N.
15219 SmallVector<const SCEVPredicate *, 4> PrunedPreds;
15220 for (auto *P : Preds) {
15221 if (CheckImplies && N->implies(P, SE))
15222 continue;
15223 PrunedPreds.push_back(P);
15224 }
15225 Preds = std::move(PrunedPreds);
15226 Preds.push_back(N);
15227}
15228
15229PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15230 Loop &L)
15231 : SE(SE), L(L) {
15232 SmallVector<const SCEVPredicate *, 4> Empty;
15233 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15234}
15235
15236void ScalarEvolution::registerUser(const SCEV *User,
15237 ArrayRef<const SCEV *> Ops) {
15238 for (const auto *Op : Ops)
15239 // We do not expect that forgetting cached data for SCEVConstants will ever
15240 // open any prospects for sharpening or introduce any correctness issues,
15241 // so we don't bother storing their dependencies.
15242 if (!isa<SCEVConstant>(Op))
15243 SCEVUsers[Op].insert(User);
15244}
15245
15246const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15247 const SCEV *Expr = SE.getSCEV(V);
15248 RewriteEntry &Entry = RewriteMap[Expr];
15249
15250 // If we already have an entry and the version matches, return it.
15251 if (Entry.second && Generation == Entry.first)
15252 return Entry.second;
15253
15254 // We found an entry but it's stale. Rewrite the stale entry
15255 // according to the current predicate.
15256 if (Entry.second)
15257 Expr = Entry.second;
15258
15259 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15260 Entry = {Generation, NewSCEV};
15261
15262 return NewSCEV;
15263}
15264
15265const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15266 if (!BackedgeCount) {
15267 SmallVector<const SCEVPredicate *, 4> Preds;
15268 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15269 for (const auto *P : Preds)
15270 addPredicate(*P);
15271 }
15272 return BackedgeCount;
15273}
15274
15275const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15276 if (!SymbolicMaxBackedgeCount) {
15277 SmallVector<const SCEVPredicate *, 4> Preds;
15278 SymbolicMaxBackedgeCount =
15279 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15280 for (const auto *P : Preds)
15281 addPredicate(*P);
15282 }
15283 return SymbolicMaxBackedgeCount;
15284}
15285
15286unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15287 if (!SmallConstantMaxTripCount) {
15288 SmallVector<const SCEVPredicate *, 4> Preds;
15289 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15290 for (const auto *P : Preds)
15291 addPredicate(*P);
15292 }
15293 return *SmallConstantMaxTripCount;
15294}
15295
15296void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15297 if (Preds->implies(&Pred, SE))
15298 return;
15299
15300 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15301 NewPreds.push_back(&Pred);
15302 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15303 updateGeneration();
15304}
15305
15306const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15307 return *Preds;
15308}
15309
15310void PredicatedScalarEvolution::updateGeneration() {
15311 // If the generation number wrapped recompute everything.
15312 if (++Generation == 0) {
15313 for (auto &II : RewriteMap) {
15314 const SCEV *Rewritten = II.second.second;
15315 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15316 }
15317 }
15318}
15319
15320void PredicatedScalarEvolution::setNoOverflow(
15321 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15322 const SCEV *Expr = getSCEV(V);
15323 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15324
15325 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15326
15327 // Clear the statically implied flags.
15328 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15329 addPredicate(*SE.getWrapPredicate(AR, Flags));
15330
15331 auto II = FlagsMap.insert({V, Flags});
15332 if (!II.second)
15333 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15334}
15335
15336bool PredicatedScalarEvolution::hasNoOverflow(
15337 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15338 const SCEV *Expr = getSCEV(V);
15339 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15340
15342 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15343
15344 auto II = FlagsMap.find(V);
15345
15346 if (II != FlagsMap.end())
15347 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15348
15349 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15350}
15351
15352const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15353 const SCEV *Expr = this->getSCEV(V);
15354 SmallVector<const SCEVPredicate *, 4> NewPreds;
15355 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15356
15357 if (!New)
15358 return nullptr;
15359
15360 for (const auto *P : NewPreds)
15361 addPredicate(*P);
15362
15363 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15364 return New;
15365}
15366
15367PredicatedScalarEvolution::PredicatedScalarEvolution(
15368 const PredicatedScalarEvolution &Init)
15369 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15370 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15371 SE)),
15372 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15373 for (auto I : Init.FlagsMap)
15374 FlagsMap.insert(I);
15375}
15376
15377void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15378 // For each block.
15379 for (auto *BB : L.getBlocks())
15380 for (auto &I : *BB) {
15381 if (!SE.isSCEVable(I.getType()))
15382 continue;
15383
15384 auto *Expr = SE.getSCEV(&I);
15385 auto II = RewriteMap.find(Expr);
15386
15387 if (II == RewriteMap.end())
15388 continue;
15389
15390 // Don't print things that are not interesting.
15391 if (II->second.second == Expr)
15392 continue;
15393
15394 OS.indent(Depth) << "[PSE]" << I << ":\n";
15395 OS.indent(Depth + 2) << *Expr << "\n";
15396 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15397 }
15398}
15399
15400ScalarEvolution::LoopGuards
15401ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15402 BasicBlock *Header = L->getHeader();
15403 BasicBlock *Pred = L->getLoopPredecessor();
15404 LoopGuards Guards(SE);
15405 if (!Pred)
15406 return Guards;
15407 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15408 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15409 return Guards;
15410}
15411
15412void ScalarEvolution::LoopGuards::collectFromPHI(
15413 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15414 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15415 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15416 unsigned Depth) {
15417 if (!SE.isSCEVable(Phi.getType()))
15418 return;
15419
15420 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15421 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15422 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15423 if (!VisitedBlocks.insert(InBlock).second)
15424 return {nullptr, scCouldNotCompute};
15425
15426 // Avoid analyzing unreachable blocks so that we don't get trapped
15428 // traversing cycles with ill-formed dominance or infinite cycles.
15428 if (!SE.DT.isReachableFromEntry(InBlock))
15429 return {nullptr, scCouldNotCompute};
15430
15431 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15432 if (Inserted)
15433 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15434 Depth + 1);
15435 auto &RewriteMap = G->second.RewriteMap;
15436 if (RewriteMap.empty())
15437 return {nullptr, scCouldNotCompute};
15438 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15439 if (S == RewriteMap.end())
15440 return {nullptr, scCouldNotCompute};
15441 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15442 if (!SM)
15443 return {nullptr, scCouldNotCompute};
15444 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15445 return {C0, SM->getSCEVType()};
15446 return {nullptr, scCouldNotCompute};
15447 };
15448 auto MergeMinMaxConst = [](MinMaxPattern P1,
15449 MinMaxPattern P2) -> MinMaxPattern {
15450 auto [C1, T1] = P1;
15451 auto [C2, T2] = P2;
15452 if (!C1 || !C2 || T1 != T2)
15453 return {nullptr, scCouldNotCompute};
15454 switch (T1) {
15455 case scUMaxExpr:
15456 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15457 case scSMaxExpr:
15458 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15459 case scUMinExpr:
15460 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15461 case scSMinExpr:
15462 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15463 default:
15464 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15465 }
15466 };
15467 auto P = GetMinMaxConst(0);
15468 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15469 if (!P.first)
15470 break;
15471 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15472 }
15473 if (P.first) {
15474 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15475 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15476 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15477 Guards.RewriteMap.insert({LHS, RHS});
15478 }
15479}
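// In short: when every incoming value of the PHI has been rewritten to a
// min/max against a constant of the same kind, the merged constant bound is
// recorded for the PHI itself in the guards' rewrite map.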
15480
15481// Return a new SCEV that modifies \p Expr to the closest number divisible by
15482// \p Divisor and less than or equal to \p Expr. For now, only constant
15483// \p Expr is handled.
15484static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15485 const APInt &DivisorVal,
15486 ScalarEvolution &SE) {
15487 const APInt *ExprVal;
15488 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15489 DivisorVal.isNonPositive())
15490 return Expr;
15491 APInt Rem = ExprVal->urem(DivisorVal);
15492 // return the SCEV: Expr - Expr % Divisor
15493 return SE.getConstant(*ExprVal - Rem);
15494}
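// For example, with Expr == 17 and Divisor == 8 the result above is 16
// (17 - 17 % 8); non-constant Exprs, negative Exprs and non-positive divisors
// are returned unchanged.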
15495
15496// Return a new SCEV that modifies \p Expr to the closest number divisible by
15497// \p Divisor and greater than or equal to \p Expr. For now, only constant
15498// \p Expr is handled.
15499static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15500 const APInt &DivisorVal,
15501 ScalarEvolution &SE) {
15502 const APInt *ExprVal;
15503 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15504 DivisorVal.isNonPositive())
15505 return Expr;
15506 APInt Rem = ExprVal->urem(DivisorVal);
15507 if (Rem.isZero())
15508 return Expr;
15509 // return the SCEV: Expr + Divisor - Expr % Divisor
15510 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15511}
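// For example, with Expr == 17 and Divisor == 8 the result above is 24
// (17 + 8 - 17 % 8); values already divisible by the divisor are unchanged.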
15512
15513void ScalarEvolution::LoopGuards::collectFromBlock(
15514 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15515 const BasicBlock *Block, const BasicBlock *Pred,
15516 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15517
15519
15520 SmallVector<const SCEV *> ExprsToRewrite;
15521 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15522 const SCEV *RHS,
15523 DenseMap<const SCEV *, const SCEV *>
15524 &RewriteMap) {
15525 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15526 // replacement SCEV which isn't directly implied by the structure of that
15527 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15528 // legal. See the scoping rules for flags in the header to understand why.
15529
15530 // If LHS is a constant, apply information to the other expression.
15531 if (isa<SCEVConstant>(LHS)) {
15532 std::swap(LHS, RHS);
15533 Predicate = CmpInst::getSwappedPredicate(Predicate);
15534 }
15535
15536 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15537 // create this form when combining two checks of the form (X u< C2 + C1) and
15538 // (X >=u C1).
15539 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15540 &ExprsToRewrite]() {
15541 const SCEVConstant *C1;
15542 const SCEVUnknown *LHSUnknown;
15543 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15544 if (!match(LHS,
15545 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15546 !C2)
15547 return false;
15548
15549 auto ExactRegion =
15550 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15551 .sub(C1->getAPInt());
15552
15553 // Bail out, unless we have a non-wrapping, monotonic range.
15554 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15555 return false;
15556 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15557 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15558 I->second = SE.getUMaxExpr(
15559 SE.getConstant(ExactRegion.getUnsignedMin()),
15560 SE.getUMinExpr(RewrittenLHS,
15561 SE.getConstant(ExactRegion.getUnsignedMax())));
15562 ExprsToRewrite.push_back(LHSUnknown);
15563 return true;
15564 };
15565 if (MatchRangeCheckIdiom())
15566 return;
15567
15568 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15569 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15570 // the non-constant operand and in \p LHS the constant operand.
15571 auto IsMinMaxSCEVWithNonNegativeConstant =
15572 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15573 const SCEV *&RHS) {
15574 const APInt *C;
15575 SCTy = Expr->getSCEVType();
15576 return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15577 match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15578 };
15579
15580 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15581 // recursively. This is done by aligning the constant value up/down to the
15582 // Divisor.
15583 std::function<const SCEV *(const SCEV *, const SCEV *)>
15584 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15585 const SCEV *Divisor) {
15586 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15587 if (!ConstDivisor)
15588 return MinMaxExpr;
15589 const APInt &DivisorVal = ConstDivisor->getAPInt();
15590
15591 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15592 SCEVTypes SCTy;
15593 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15594 MinMaxRHS))
15595 return MinMaxExpr;
15596 auto IsMin =
15597 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15598 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15599 "Expected non-negative operand!");
15600 auto *DivisibleExpr =
15601 IsMin
15602 ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15603 : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15604 SmallVector<const SCEV *> Ops = {
15605 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15606 return SE.getMinMaxExpr(SCTy, Ops);
15607 };
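 // For example, given umin(13, %x) and a divisor of 4, the constant operand is
 // rounded down to 12 (rounded up for max expressions) and the rewrite then
 // recurses into %x.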
15608
15609 // If we have LHS == 0, check if LHS is computing a property of some unknown
15610 // SCEV %v which we can rewrite %v to express explicitly.
15611 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15612 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15613 // explicitly express that.
15614 const SCEVUnknown *URemLHS = nullptr;
15615 const SCEV *URemRHS = nullptr;
15616 if (match(LHS,
15617 m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15618 auto I = RewriteMap.find(URemLHS);
15619 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15620 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15621 const auto *Multiple =
15622 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15623 RewriteMap[URemLHS] = Multiple;
15624 ExprsToRewrite.push_back(URemLHS);
15625 return;
15626 }
15627 }
15628
15629 // Do not apply information for constants or if RHS contains an AddRec.
15630 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15631 return;
15632
15633 // If RHS is SCEVUnknown, make sure the information is applied to it.
15634 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15635 std::swap(LHS, RHS);
15636 Predicate = CmpInst::getSwappedPredicate(Predicate);
15637 }
15638
15639 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15640 // and \p FromRewritten are the same (i.e. there has been no rewrite
15641 // registered for \p From), then puts this value in the list of rewritten
15642 // expressions.
15643 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15644 const SCEV *To) {
15645 if (From == FromRewritten)
15646 ExprsToRewrite.push_back(From);
15647 RewriteMap[From] = To;
15648 };
15649
15650 // Checks whether \p S has already been rewritten. In that case returns the
15651 // existing rewrite because we want to chain further rewrites onto the
15652 // already rewritten value. Otherwise returns \p S.
15653 auto GetMaybeRewritten = [&](const SCEV *S) {
15654 return RewriteMap.lookup_or(S, S);
15655 };
15656
15657 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15658 const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15659
15660 // Collect rewrites for LHS and its transitive operands based on the
15661 // condition.
15662 // For min/max expressions, also apply the guard to its operands:
15663 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15664 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15665 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15666 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15667
15668 // We cannot express strict predicates in SCEV, so instead we replace them
15669 // with non-strict ones against plus or minus one of RHS depending on the
15670 // predicate.
15671 const SCEV *One = SE.getOne(RHS->getType());
15672 switch (Predicate) {
15673 case CmpInst::ICMP_ULT:
15674 if (RHS->getType()->isPointerTy())
15675 return;
15676 RHS = SE.getUMaxExpr(RHS, One);
15677 [[fallthrough]];
15678 case CmpInst::ICMP_SLT: {
15679 RHS = SE.getMinusSCEV(RHS, One);
15680 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15681 break;
15682 }
15683 case CmpInst::ICMP_UGT:
15684 case CmpInst::ICMP_SGT:
15685 RHS = SE.getAddExpr(RHS, One);
15686 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15687 break;
15688 case CmpInst::ICMP_ULE:
15689 case CmpInst::ICMP_SLE:
15690 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15691 break;
15692 case CmpInst::ICMP_UGE:
15693 case CmpInst::ICMP_SGE:
15694 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15695 break;
15696 default:
15697 break;
15698 }
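// Worked example (illustrative, not in the original source): for a guard
// '%x >u 17' where %x is known to be a multiple of 4, RHS becomes
// 17 + 1 = 18 and is then rounded up to 20, so the rewrites below encode
// '%x >=u 20'.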
15699
15700 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15701 SmallPtrSet<const SCEV *, 16> Visited;
15702
15703 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15704 append_range(Worklist, S->operands());
15705 };
15706
15707 while (!Worklist.empty()) {
15708 const SCEV *From = Worklist.pop_back_val();
15709 if (isa<SCEVConstant>(From))
15710 continue;
15711 if (!Visited.insert(From).second)
15712 continue;
15713 const SCEV *FromRewritten = GetMaybeRewritten(From);
15714 const SCEV *To = nullptr;
15715
15716 switch (Predicate) {
15717 case CmpInst::ICMP_ULT:
15718 case CmpInst::ICMP_ULE:
15719 To = SE.getUMinExpr(FromRewritten, RHS);
15720 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15721 EnqueueOperands(UMax);
15722 break;
15723 case CmpInst::ICMP_SLT:
15724 case CmpInst::ICMP_SLE:
15725 To = SE.getSMinExpr(FromRewritten, RHS);
15726 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15727 EnqueueOperands(SMax);
15728 break;
15729 case CmpInst::ICMP_UGT:
15730 case CmpInst::ICMP_UGE:
15731 To = SE.getUMaxExpr(FromRewritten, RHS);
15732 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15733 EnqueueOperands(UMin);
15734 break;
15735 case CmpInst::ICMP_SGT:
15736 case CmpInst::ICMP_SGE:
15737 To = SE.getSMaxExpr(FromRewritten, RHS);
15738 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15739 EnqueueOperands(SMin);
15740 break;
15741 case CmpInst::ICMP_EQ:
15742 if (isa<SCEVConstant>(RHS))
15743 To = RHS;
15744 break;
15745 case CmpInst::ICMP_NE:
15746 if (match(RHS, m_scev_Zero())) {
15747 const SCEV *OneAlignedUp =
15748 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15749 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15750 } else {
15751 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15752 // but creating the subtraction eagerly is expensive. Track the
15753 // inequalities in a separate map, and materialize the rewrite lazily
15754 // when encountering a suitable subtraction while re-writing.
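// Illustrative example (not in the original source): a guard '%a != %b'
// only records the canonically ordered pair in Guards.NotEqual here; the
// umax(1, %a - %b) form is materialized later, when the rewriter actually
// visits the subtraction (%a - %b).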
15755 if (LHS->getType()->isPointerTy()) {
15756 LHS = SE.getLosslessPtrToIntExpr(LHS);
15757 RHS = SE.getLosslessPtrToIntExpr(RHS);
15758 if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
15759 break;
15760 }
15761 const SCEVConstant *C;
15762 const SCEV *A, *B;
15763 if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15764 match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15765 RHS = A;
15766 LHS = B;
15767 }
15768 if (LHS > RHS)
15769 std::swap(LHS, RHS);
15770 Guards.NotEqual.insert({LHS, RHS});
15771 continue;
15772 }
15773 break;
15774 default:
15775 break;
15776 }
15777
15778 if (To)
15779 AddRewrite(From, FromRewritten, To);
15780 }
15781 };
15782
15783 SmallVector<std::pair<Value *, bool>> Terms;
15784 // First, collect information from assumptions dominating the loop.
15785 for (auto &AssumeVH : SE.AC.assumptions()) {
15786 if (!AssumeVH)
15787 continue;
15788 auto *AssumeI = cast<CallInst>(AssumeVH);
15789 if (!SE.DT.dominates(AssumeI, Block))
15790 continue;
15791 Terms.emplace_back(AssumeI->getOperand(0), true);
15792 }
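// Illustrative example (not in the original source): a dominating
// 'call void @llvm.assume(i1 %cmp)' with '%cmp = icmp ult i64 %i, %n'
// contributes the term (%cmp, /*EnterIfTrue=*/true) and is later folded
// into the rewrite map like any dominating branch condition.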
15793
15794 // Second, collect information from llvm.experimental.guards dominating the loop.
15795 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15796 SE.F.getParent(), Intrinsic::experimental_guard);
15797 if (GuardDecl)
15798 for (const auto *GU : GuardDecl->users())
15799 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15800 if (Guard->getFunction() == Block->getParent() &&
15801 SE.DT.dominates(Guard, Block))
15802 Terms.emplace_back(Guard->getArgOperand(0), true);
15803
15804 // Third, collect conditions from dominating branches. Starting at the loop
15805 // predecessor, climb up the predecessor chain, as long as we can find
15806 // predecessors that have a unique successor leading to the
15807 // original header.
15808 // TODO: share this logic with isLoopEntryGuardedByCond.
15809 unsigned NumCollectedConditions = 0;
15810 VisitedBlocks.insert(Block);
15811 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15812 for (; Pair.first;
15813 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15814 VisitedBlocks.insert(Pair.second);
15815 const BranchInst *LoopEntryPredicate =
15816 dyn_cast<BranchInst>(Pair.first->getTerminator());
15817 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15818 continue;
15819
15820 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15821 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15822 NumCollectedConditions++;
15823
15824 // If we are recursively collecting guards, stop after 2
15825 // conditions to limit compile-time impact for now.
15826 if (Depth > 0 && NumCollectedConditions == 2)
15827 break;
15828 }
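// Illustrative example (not in the original source): for a header reached
// through two nested guard blocks, the innermost branch condition is
// recorded first and the outermost last; Terms is processed in reverse
// below so the earliest condition is applied first.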
15829 // Finally, if we stopped climbing the predecessor chain because
15830 // there wasn't a unique one to continue, try to collect conditions
15831 // for PHINodes by recursively following all of their incoming
15832 // blocks and try to merge the found conditions to build a new one
15833 // for the Phi.
15834 if (Pair.second->hasNPredecessorsOrMore(2) &&
15835 Depth < MaxLoopGuardCollectionDepth) {
15836 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15837 for (auto &Phi : Pair.second->phis())
15838 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15839 }
15840
15841 // Now apply the information from the collected conditions to
15842 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15843 // earliest condition is processed first. This ensures the SCEVs with the
15844 // shortest dependency chains are constructed first.
15845 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15846 SmallVector<Value *, 8> Worklist;
15847 SmallPtrSet<Value *, 8> Visited;
15848 Worklist.push_back(Term);
15849 while (!Worklist.empty()) {
15850 Value *Cond = Worklist.pop_back_val();
15851 if (!Visited.insert(Cond).second)
15852 continue;
15853
15854 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15855 auto Predicate =
15856 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15857 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15858 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15859 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15860 continue;
15861 }
15862
15863 Value *L, *R;
15864 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15865 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15866 Worklist.push_back(L);
15867 Worklist.push_back(R);
15868 }
15869 }
15870 }
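// Illustrative example (not in the original source): a dominating branch on
// '%c = and i1 %c1, %c2' whose taken edge leads to the loop splits into the
// conditions %c1 and %c2; when the loop is only reached on the false edge,
// an 'or' splits the same way and the compare predicates are inverted.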
15871
15872 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15873 // the replacement expressions are contained in the ranges of the replaced
15874 // expressions.
15875 Guards.PreserveNUW = true;
15876 Guards.PreserveNSW = true;
15877 for (const SCEV *Expr : ExprsToRewrite) {
15878 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15879 Guards.PreserveNUW &=
15880 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15881 Guards.PreserveNSW &=
15882 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15883 }
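// Illustrative example (not in the original source): if %x has unsigned
// range [0, 100] and is rewritten to umin(%x, 42), the new range [0, 42] is
// contained in the old one and NUW can still be preserved; a replacement
// whose range is not contained clears the flag.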
15884
15885 // Now that all rewrite information is collected, rewrite the collected
15886 // expressions with the information in the map. This applies information to
15887 // sub-expressions.
15888 if (ExprsToRewrite.size() > 1) {
15889 for (const SCEV *Expr : ExprsToRewrite) {
15890 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15891 Guards.RewriteMap.erase(Expr);
15892 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15893 }
15894 }
15895}
15896
15897 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15898 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15899 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15900 /// replacement is loop invariant in the loop of the AddRec.
15901 class SCEVLoopGuardRewriter
15902 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15903 const DenseMap<const SCEV *, const SCEV *> &Map;
15904 const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
15905 
15906 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15907 
15908 public:
15909 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15910 const ScalarEvolution::LoopGuards &Guards)
15911 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15912 NotEqual(Guards.NotEqual) {
15913 if (Guards.PreserveNUW)
15914 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15915 if (Guards.PreserveNSW)
15916 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15917 }
15918
15919 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15920
15921 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15922 return Map.lookup_or(Expr, Expr);
15923 }
15924
15925 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15926 if (const SCEV *S = Map.lookup(Expr))
15927 return S;
15928
15929 // If we didn't find the exact ZExt expr in the map, check if there's
15930 // an entry for a smaller ZExt we can use instead.
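// Illustrative example (not in the original source): if the map only has an
// entry for (zext i8 %v to i16), a query for (zext i8 %v to i32) finds that
// narrower entry below and zero-extends the mapped value back to i32.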
15931 Type *Ty = Expr->getType();
15932 const SCEV *Op = Expr->getOperand(0);
15933 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15934 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15935 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15936 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15937 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15938 if (const SCEV *S = Map.lookup(NarrowExt))
15939 return SE.getZeroExtendExpr(S, Ty);
15940 Bitwidth = Bitwidth / 2;
15941 }
15942
15943 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15944 Expr);
15945 }
15946
15947 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15948 if (const SCEV *S = Map.lookup(Expr))
15949 return S;
15950 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15951 Expr);
15952 }
15953
15954 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15955 if (const SCEV *S = Map.lookup(Expr))
15956 return S;
15957 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15958 }
15959
15960 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15961 if (const SCEV *S = Map.lookup(Expr))
15962 return S;
15963 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15964 }
15965
15966 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15967 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
15968 // return UMax(S, 1).
15969 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
15970 const SCEV *LHS, *RHS;
15971 if (MatchBinarySub(S, LHS, RHS)) {
15972 if (LHS > RHS)
15973 std::swap(LHS, RHS);
15974 if (NotEqual.contains({LHS, RHS})) {
15975 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
15976 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
15977 return SE.getUMaxExpr(OneAlignedUp, S);
15978 }
15979 }
15980 return nullptr;
15981 };
15982
15983 // Check if Expr itself is a subtraction pattern with guard info.
15984 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
15985 return Rewritten;
15986
15987 // Trip count expressions sometimes consist of adding 3 operands, i.e.
15988 // (Const + A + B). There may be guard info for A + B, and if so, apply
15989 // it.
15990 // TODO: Could more generally apply guards to Add sub-expressions.
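// Illustrative example (not in the original source): for a trip-count-style
// expression (-1 + %n + (-1 * %i)), the constant is peeled off, guard info
// for the remaining (%n - %i) is applied (e.g. umax(1, %n - %i) when a
// '%n != %i' guard was recorded), and the constant is re-added.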
15991 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
15992 Expr->getNumOperands() == 3) {
15993 const SCEV *Add =
15994 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
15995 if (const SCEV *Rewritten = RewriteSubtraction(Add))
15996 return SE.getAddExpr(
15997 Expr->getOperand(0), Rewritten,
15998 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
15999 if (const SCEV *S = Map.lookup(Add))
16000 return SE.getAddExpr(Expr->getOperand(0), S);
16001 }
16002 SmallVector<const SCEV *, 2> Operands;
16003 bool Changed = false;
16004 for (const auto *Op : Expr->operands()) {
16005 Operands.push_back(
16006 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16007 Changed |= Op != Operands.back();
16008 }
16009 // We are only replacing operands with equivalent values, so transfer the
16010 // flags from the original expression.
16011 return !Changed ? Expr
16012 : SE.getAddExpr(Operands,
16013 ScalarEvolution::maskFlags(
16014 Expr->getNoWrapFlags(), FlagMask));
16015 }
16016
16017 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16018 SmallVector<const SCEV *, 2> Operands;
16019 bool Changed = false;
16020 for (const auto *Op : Expr->operands()) {
16021 Operands.push_back(
16022 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16023 Changed |= Op != Operands.back();
16024 }
16025 // We are only replacing operands with equivalent values, so transfer the
16026 // flags from the original expression.
16027 return !Changed ? Expr
16028 : SE.getMulExpr(Operands,
16029 ScalarEvolution::maskFlags(
16030 Expr->getNoWrapFlags(), FlagMask));
16031 }
16032 };
16033
16034 if (RewriteMap.empty() && NotEqual.empty())
16035 return Expr;
16036
16037 SCEVLoopGuardRewriter Rewriter(SE, *this);
16038 return Rewriter.visit(Expr);
16039}
16040
16041const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16042 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16043}
16044
16045 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16046 const LoopGuards &Guards) {
16047 return Guards.rewrite(Expr);
16048}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:206
APInt abs() const
Get the absolute value.
Definition APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1201
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:371
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:219
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:361
unsigned countTrailingZeros() const
Definition APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:356
unsigned logBase2() const
Definition APInt.h:1761
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:440
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1221
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
iterator end() const
Definition ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator begin() const
Definition ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:760
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:167
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:237
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:180
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:163
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:158
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:222
void swap(DenseMap &RHS)
Definition DenseMap.h:747
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:330
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1078
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
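A usage sketch with placeholder variables L, SE, and Expr: collect the guard conditions dominating a loop once, then reuse them to tighten several expressions.
ScalarEvolution::LoopGuards Guards =
    ScalarEvolution::LoopGuards::collect(L, SE);
const SCEV *Refined = Guards.rewrite(Expr);  // Expr, possibly simplified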
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
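A hedged usage sketch (SE and L are placeholders): callers are expected to test the result for SCEVCouldNotCompute before relying on it.
const SCEV *BTC = SE.getBackedgeTakenCount(L);
if (isa<SCEVCouldNotCompute>(BTC))
  return;  // no analyzable backedge-taken count for this loop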
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Determine whether operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
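A small sketch (V and SE are placeholders): only values of SCEVable types can be analyzed, so a typical caller guards the query with isSCEVable.
const SCEV *S = nullptr;
if (SE.isSCEVable(V->getType()))
  S = SE.getSCEV(V);  // otherwise leave S as nullptr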
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
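Sketch: building the canonical induction variable {0,+,1}<L> over a hypothetical loop L, with conservative FlagAnyWrap flags.
Type *Int64 = Type::getInt64Ty(SE.getContext());
const SCEV *IV = SE.getAddRecExpr(SE.getZero(Int64), SE.getOne(Int64), L,
                                  SCEV::FlagAnyWrap);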
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
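Sketch (L and SE are placeholders): both trip-count queries return 0 to mean "unknown", so an exact count can fall back to the constant maximum.
unsigned TripCount = SE.getSmallConstantTripCount(L);
if (TripCount == 0)
  TripCount = SE.getSmallConstantMaxTripCount(L);  // still 0 if unknown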
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
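Sketch (Expr, L, SE are placeholders): a one-shot refinement before a range query, as opposed to keeping a LoopGuards object around for repeated use.
const SCEV *Guarded = SE.applyLoopGuards(Expr, L);
APInt UMax = SE.getUnsignedRangeMax(Guarded);  // range under the loop guards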
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
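Sketch (LHS, RHS, SE are placeholders): the optional result distinguishes "provably true", "provably false", and "unknown".
if (std::optional<bool> Known =
        SE.evaluatePredicate(ICmpInst::ICMP_ULT, LHS, RHS)) {
  // *Known is true iff LHS u< RHS holds on all executions.
}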
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
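Sketch (placeholders again): a cheap check for "LHS and RHS differ by a known constant", here used to test equality.
if (std::optional<APInt> Diff = SE.computeConstantDifference(LHS, RHS))
  if (Diff->isZero()) {
    // LHS and RHS are provably equal.
  }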
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
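Sketch (Start and Step are placeholders): composing expressions with the two-operand convenience overloads; flags default to FlagAnyWrap, i.e. no overflow assumptions.
const SCEV *Four = SE.getConstant(Step->getType(), 4);
const SCEV *End = SE.getAddExpr(Start, SE.getMulExpr(Four, Step));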
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimized out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:712
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:743
TypeSize getSizeInBits() const
Definition DataLayout.h:723
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:295
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:294
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:301
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1099
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:169
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows using arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
SCEVBinaryExpr_match< SCEVMinMaxExpr, Op0_t, Op1_t > m_scev_MinMax(const Op0_t &Op0, const Op1_t &Op1)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
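Sketch: recognizing the unsigned-remainder idiom with the SCEVPatternMatch helpers listed above; the capturing m_SCEV(const SCEV *&) overload and the variables S and SE are assumed.
using namespace llvm::SCEVPatternMatch;
const SCEV *A = nullptr, *B = nullptr;
if (match(S, m_scev_URem(m_SCEV(A), m_SCEV(B), SE))) {
  // S has the shape A - (A / B) * B, i.e. A urem B.
}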
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2058
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it is even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2136
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2053
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:95
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2128
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1732
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:339
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:71
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1954
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2030
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1867
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:1961
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1897
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
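Sketch (S is a placeholder): a typical predicate-based query over a whole expression tree.
bool HasUnknown =
    SCEVExprContains(S, [](const SCEV *N) { return isa<SCEVUnknown>(N); });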
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:869
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:301
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:196
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:145
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:129
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.