1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
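//
// For example, given a loop
//
//   for (i = 0; i != n; ++i)
//     ... a[i] ...
//
// the PHI for 'i' is represented by the add recurrence {0,+,1}<%loop> (names
// here are illustrative), and an expression such as 4*i folds to
// {0,+,4}<%loop>. Because SCEVs are uniqued, two syntactically different
// computations of 4*i yield the same SCEV object and compare equal by pointer.
//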
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
319 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345 case scSequentialUMinExpr:
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376 case scCouldNotCompute:
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
491const SCEV *ScalarEvolution::getVScale(Type *Ty) {
492 FoldingSetNodeID ID;
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
504 SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
512 const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
523 SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
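///
/// For example (illustrative operand names), given Ops = {(%a + %b), 42,
/// (%a + %b)}, the constant sorts first and the two identical add expressions
/// end up adjacent: {42, (%a + %b), (%a + %b)}, so callers can fold duplicates
/// with a single linear scan.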
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768 LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
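///
/// As a sketch of how callers are expected to instantiate this (the parameter
/// choices below are illustrative, not quoted from getAddExpr/getMulExpr):
///   - addition:       Fold = APInt "+", IsIdentity(C) = C.isZero(), no absorber
///   - multiplication: Fold = APInt "*", IsIdentity(C) = C.isOne(),
///                     IsAbsorber(C) = C.isZero()
/// so (2 + %x + 3) folds its constants into 5, and (0 * %x) collapses to 0.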
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906// but this formula uses less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
922
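  // A small worked example (illustrative): for K = 3 and W = 32,
  //   BC(It, 3) = It*(It-1)*(It-2) / 6,  and 3! = 6 = 2^1 * 3, so T = 1.
  // The product It*(It-1)*(It-2) is formed at W+T = 33 bits, divided by
  // 2^T = 2 (a lossless shift at that width), truncated back to 32 bits, and
  // finally multiplied by the 32-bit multiplicative inverse of the odd part 3.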
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
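///
/// For example (illustrative), {0,+,3,+,2} evaluated at iteration It is
///   0*BC(It,0) + 3*BC(It,1) + 2*BC(It,2) = 3*It + It*(It-1) = It*It + 2*It,
/// matching the sequence 0, 3, 8, 15, 24, ... that the recurrence produces.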
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981 ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987 const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008 unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operand during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1018 FoldingSetNodeID ID;
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
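  ///
  /// For instance (illustrative), (ptrtoint (%p + (4 * %i))) is rewritten to
  /// ((ptrtoint %p) + (4 * %i)): all the arithmetic stays integer-typed and
  /// only the leaf SCEVUnknown %p remains behind a ptrtoint.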
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142 unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1150 FoldingSetNodeID ID;
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1163 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1167 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1171 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and different modification ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 SmallVector<const SCEV *, 4> Operands;
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
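// For example (illustrative), for an 8-bit recurrence whose step is known to
// lie in [1, 3], the limit below is SINT_MIN - 3 == 125 with predicate SLT:
// as long as the recurrence stays strictly below 125, adding a step of at
// most 3 cannot overflow past SINT_MAX == 127.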
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1245 SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1250 SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259 ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265 SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note, that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1366 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
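//
// Concretely (restating the motivating example): with Start = 1, Step = 4 and
// Delta = 1, the nearby recurrence is {0,+,4}. If {0,+,4} is already known
// <nuw> (condition (2)) and we can prove {0,+,4} ult -1, so that adding the
// constant 1 to it cannot wrap (condition (1)), then {1,+,4} == {0,+,4} + 1
// is nuw as well.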
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
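//
// For example (illustrative): if C = 37 (0b100101) and every other operand of
// the add has at least three trailing zero bits, then D = 37 mod 8 = 5:
// (C - D + x + y + ...) = (32 + x + y + ...) keeps those three trailing zeros,
// and adding the small constant D = 5 back on top cannot wrap.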
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502 const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1523 const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1534 const ScalarEvolution::FoldID &ID, const SCEV *S,
1537 &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575 unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1591 FoldingSetNodeID ID;
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1622 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the latter case, the analysis code will
1644 // cope with a conservative value, and it will take care to purge
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1660 SCEV::FlagAnyWrap, Depth + 1);
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1662 SCEV::FlagAnyWrap,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1672 SCEV::FlagAnyWrap, Depth + 1),
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but it's
1719 // one of two possible causes of an issue with a change which was
1720 // reverted. Be conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
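// For example, with C = 5 and Step = 4, choosing D = 1 makes every value of
// (C - D + Step * n) a multiple of 4, so zext({5,+,4}) can be rewritten as
// zext(1) + zext({4,+,4}); adding 1 back to a multiple of 4 never wraps.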
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1759 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (matchURem(Op, LHS, RHS))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsign overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetics contain expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
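// E.g., (zext (5 + (4 * X))) is rewritten as (1 + (zext (4 + (4 * X)))).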
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsign overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
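// For instance, with K = 2, N = 8, M = 32:
// zext(4 * (trunc X to i8)) to i32 == 4 * (zext(trunc X to i6) to i32),
// because multiplying by 4 shifts the top two bits of the i8 value away.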
1843 if (SM->getNumOperands() == 2)
1844 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845 if (MulLHS->getAPInt().isPowerOf2())
1846 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848 MulLHS->getAPInt().logBase2();
1849 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850 return getMulExpr(
1851 getZeroExtendExpr(MulLHS, Ty),
1852 getZeroExtendExpr(
1853 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(MinMax))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (auto *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1983 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1986 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (auto *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 // in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2034 SCEV::FlagAnyWrap, Depth + 1);
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2036 SCEV::FlagAnyWrap,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1),
2047 SCEV::FlagAnyWrap, Depth + 1);
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1);
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but it's
2090 // one of two possible causes of an issue with a change which was
2091 // reverted. Be conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2110 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(MinMax))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217/// the given scale, and update the given map. This is a helper function for
2218/// getAddExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
2362
2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
2378
2379std::optional<SCEV::NoWrapFlags>
2380ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427static SCEV::NoWrapFlags
2428StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2436 Type == scAddExpr || Type == scMulExpr || Type == scAddRecExpr;
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
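// ((udiv X, Y) * Y) rounds X down to a multiple of Y, so the product can
// never exceed X and therefore cannot unsigned-wrap.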
2497 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
2509
2510bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit recursion calls depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and small magnitude is NSW, if the
2671 // original AddExpr was NSW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
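// E.g., 16 + (zext i8 (-16 + %x) to i32) folds to (zext i8 %x to i32) when
// the narrow sum 16 + (-16 + %x) is known not to unsigned-wrap.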
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2694 SCEV::FlagAnyWrap),
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2702 if (Ops.size() == 2) {
2703 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2704 if (Mul && Mul->getNumOperands() == 2 &&
2705 Mul->getOperand(0)->isAllOnesValue()) {
2706 const SCEV *X;
2707 const SCEV *Y;
2708 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2709 return getMulExpr(Y, getUDivExpr(X, Y));
2710 }
2711 }
2712 }
2713
2714 // Skip past any other cast SCEVs.
2715 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2716 ++Idx;
2717
2718 // If there are add operands they would be next.
2719 if (Idx < Ops.size()) {
2720 bool DeletedAdd = false;
2721 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2722 // common NUW flag for expression after inlining. Other flags cannot be
2723 // preserved, because they may depend on the original order of operations.
2724 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2725 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2726 if (Ops.size() > AddOpsInlineThreshold ||
2727 Add->getNumOperands() > AddOpsInlineThreshold)
2728 break;
2729 // If we have an add, expand the add operands onto the end of the operands
2730 // list.
2731 Ops.erase(Ops.begin()+Idx);
2732 append_range(Ops, Add->operands());
2733 DeletedAdd = true;
2734 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2735 }
2736
2737 // If we deleted at least one add, we added operands to the end of the list,
2738 // and they are not necessarily sorted. Recurse to resort and resimplify
2739 // any operands we just acquired.
2740 if (DeletedAdd)
2741 return getAddExpr(Ops, CommonFlags, Depth + 1);
2742 }
2743
2744 // Skip over the add expression until we get to a multiply.
2745 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2746 ++Idx;
2747
2748 // Check to see if there are any folding opportunities present with
2749 // operands multiplied by constant values.
2750 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 uint64_t BitWidth = getTypeSizeInBits(Ty);
2752 SmallDenseMap<const SCEV *, APInt, 16> M;
2753 SmallVector<const SCEV *, 8> NewOps;
2754 APInt AccumulatedConstant(BitWidth, 0);
2755 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2756 Ops, APInt(BitWidth, 1), *this)) {
2757 struct APIntCompare {
2758 bool operator()(const APInt &LHS, const APInt &RHS) const {
2759 return LHS.ult(RHS);
2760 }
2761 };
2762
2763 // Some interesting folding opportunity is present, so it's worthwhile to
2764 // re-generate the operands list. Group the operands by constant scale,
2765 // to avoid multiplying by the same constant scale multiple times.
2766 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2767 for (const SCEV *NewOp : NewOps)
2768 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2769 // Re-generate the operands list.
2770 Ops.clear();
2771 if (AccumulatedConstant != 0)
2772 Ops.push_back(getConstant(AccumulatedConstant));
2773 for (auto &MulOp : MulOpLists) {
2774 if (MulOp.first == 1) {
2775 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2776 } else if (MulOp.first != 0) {
2777 Ops.push_back(getMulExpr(
2778 getConstant(MulOp.first),
2779 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2780 SCEV::FlagAnyWrap, Depth + 1));
2781 }
2782 }
2783 if (Ops.empty())
2784 return getZero(Ty);
2785 if (Ops.size() == 1)
2786 return Ops[0];
2787 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2788 }
2789 }
2790
2791 // If we are adding something to a multiply expression, make sure the
2792 // something is not already an operand of the multiply. If so, merge it into
2793 // the multiply.
2794 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2795 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2796 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2797 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2798 if (isa<SCEVConstant>(MulOpSCEV))
2799 continue;
2800 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2801 if (MulOpSCEV == Ops[AddOp]) {
2802 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2803 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2804 if (Mul->getNumOperands() != 2) {
2805 // If the multiply has more than two operands, we must get the
2806 // Y*Z term.
2807 SmallVector<const SCEV *, 4> MulOps(
2808 Mul->operands().take_front(MulOp));
2809 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2810 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2811 }
2812 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2813 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2814 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2815 SCEV::FlagAnyWrap, Depth + 1);
2816 if (Ops.size() == 2) return OuterMul;
2817 if (AddOp < Idx) {
2818 Ops.erase(Ops.begin()+AddOp);
2819 Ops.erase(Ops.begin()+Idx-1);
2820 } else {
2821 Ops.erase(Ops.begin()+Idx);
2822 Ops.erase(Ops.begin()+AddOp-1);
2823 }
2824 Ops.push_back(OuterMul);
2825 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2826 }
2827
2828 // Check this multiply against other multiplies being added together.
2829 for (unsigned OtherMulIdx = Idx+1;
2830 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 ++OtherMulIdx) {
2832 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2833 // If MulOp occurs in OtherMul, we can fold the two multiplies
2834 // together.
2835 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2836 OMulOp != e; ++OMulOp)
2837 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2838 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2839 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2840 if (Mul->getNumOperands() != 2) {
2841 SmallVector<const SCEV *, 4> MulOps(
2842 Mul->operands().take_front(MulOp));
2843 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2844 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2845 }
2846 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2847 if (OtherMul->getNumOperands() != 2) {
2848 SmallVector<const SCEV *, 4> MulOps(
2849 OtherMul->operands().take_front(OMulOp));
2850 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2851 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2852 }
2853 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2854 const SCEV *InnerMulSum =
2855 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2856 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2857 SCEV::FlagAnyWrap, Depth + 1);
2858 if (Ops.size() == 2) return OuterMul;
2859 Ops.erase(Ops.begin()+Idx);
2860 Ops.erase(Ops.begin()+OtherMulIdx-1);
2861 Ops.push_back(OuterMul);
2862 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2863 }
2864 }
2865 }
2866 }
2867
2868 // If there are any add recurrences in the operands list, see if any other
2869 // added values are loop invariant. If so, we can fold them into the
2870 // recurrence.
2871 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2872 ++Idx;
2873
2874 // Scan over all recurrences, trying to fold loop invariants into them.
2875 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2876 // Scan all of the other operands to this add and add them to the vector if
2877 // they are loop invariant w.r.t. the recurrence.
2878 SmallVector<const SCEV *, 8> LIOps;
2879 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2880 const Loop *AddRecLoop = AddRec->getLoop();
2881 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2882 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2883 LIOps.push_back(Ops[i]);
2884 Ops.erase(Ops.begin()+i);
2885 --i; --e;
2886 }
2887
2888 // If we found some loop invariants, fold them into the recurrence.
2889 if (!LIOps.empty()) {
2890 // Compute nowrap flags for the addition of the loop-invariant ops and
2891 // the addrec. Temporarily push it as an operand for that purpose. These
2892 // flags are valid in the scope of the addrec only.
2893 LIOps.push_back(AddRec);
2894 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2895 LIOps.pop_back();
2896
2897 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2898 LIOps.push_back(AddRec->getStart());
2899
2900 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2901
2902 // It is not in general safe to propagate flags valid on an add within
2903 // the addrec scope to one outside it. We must prove that the inner
2904 // scope is guaranteed to execute if the outer one does to be able to
2905 // safely propagate. We know the program is undefined if poison is
2906 // produced on the inner scoped addrec. We also know that *for this use*
2907 // the outer scoped add can't overflow (because of the flags we just
2908 // computed for the inner scoped add) without the program being undefined.
2909 // Proving that entry to the outer scope necessitates entry to the inner
2910 // scope, thus proves the program undefined if the flags would be violated
2911 // in the outer scope.
2912 SCEV::NoWrapFlags AddFlags = Flags;
2913 if (AddFlags != SCEV::FlagAnyWrap) {
2914 auto *DefI = getDefiningScopeBound(LIOps);
2915 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2916 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2917 AddFlags = SCEV::FlagAnyWrap;
2918 }
2919 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2920
2921 // Build the new addrec. Propagate the NUW and NSW flags if both the
2922 // outer add and the inner addrec are guaranteed to have no overflow.
2923 // Always propagate NW.
2924 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2925 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2926
2927 // If all of the other operands were loop invariant, we are done.
2928 if (Ops.size() == 1) return NewRec;
2929
2930 // Otherwise, add the folded AddRec by the non-invariant parts.
2931 for (unsigned i = 0;; ++i)
2932 if (Ops[i] == AddRec) {
2933 Ops[i] = NewRec;
2934 break;
2935 }
2936 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2937 }
2938
2939 // Okay, if there weren't any loop invariants to be folded, check to see if
2940 // there are multiple AddRec's with the same loop induction variable being
2941 // added together. If so, we can fold them.
2942 for (unsigned OtherIdx = Idx+1;
2943 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2944 ++OtherIdx) {
2945 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2946 // so that the 1st found AddRecExpr is dominated by all others.
2947 assert(DT.dominates(
2948 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2949 AddRec->getLoop()->getHeader()) &&
2950 "AddRecExprs are not sorted in reverse dominance order?");
2951 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2952 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2953 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2954 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 ++OtherIdx) {
2956 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2957 if (OtherAddRec->getLoop() == AddRecLoop) {
2958 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2959 i != e; ++i) {
2960 if (i >= AddRecOps.size()) {
2961 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2962 break;
2963 }
2964 SmallVector<const SCEV *, 2> TwoOps = {
2965 AddRecOps[i], OtherAddRec->getOperand(i)};
2966 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2969 }
2970 }
2971 // Step size has changed, so we cannot guarantee no self-wraparound.
2972 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2973 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2974 }
2975 }
2976
2977 // Otherwise couldn't fold anything into this recurrence. Move onto the
2978 // next one.
2979 }
2980
2981 // Okay, it looks like we really DO need an add expr. Check to see if we
2982 // already have one, otherwise create a new one.
2983 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2984}
2985
2986const SCEV *
2987ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2988 SCEV::NoWrapFlags Flags) {
2989 FoldingSetNodeID ID;
2990 ID.AddInteger(scAddExpr);
2991 for (const SCEV *Op : Ops)
2992 ID.AddPointer(Op);
2993 void *IP = nullptr;
2994 SCEVAddExpr *S =
2995 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2996 if (!S) {
2997 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2998 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2999 S = new (SCEVAllocator)
3000 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3001 UniqueSCEVs.InsertNode(S, IP);
3002 registerUser(S, Ops);
3003 }
3004 S->setNoWrapFlags(Flags);
3005 return S;
3006}
3007
3008const SCEV *
3009ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3010 const Loop *L, SCEV::NoWrapFlags Flags) {
3011 FoldingSetNodeID ID;
3012 ID.AddInteger(scAddRecExpr);
3013 for (const SCEV *Op : Ops)
3014 ID.AddPointer(Op);
3015 ID.AddPointer(L);
3016 void *IP = nullptr;
3017 SCEVAddRecExpr *S =
3018 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3019 if (!S) {
3020 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3021 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3022 S = new (SCEVAllocator)
3023 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3024 UniqueSCEVs.InsertNode(S, IP);
3025 LoopUsers[L].push_back(S);
3026 registerUser(S, Ops);
3027 }
3028 setNoWrapFlags(S, Flags);
3029 return S;
3030}
3031
3032const SCEV *
3033ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3034 SCEV::NoWrapFlags Flags) {
3035 FoldingSetNodeID ID;
3036 ID.AddInteger(scMulExpr);
3037 for (const SCEV *Op : Ops)
3038 ID.AddPointer(Op);
3039 void *IP = nullptr;
3040 SCEVMulExpr *S =
3041 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3042 if (!S) {
3043 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3044 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3045 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3046 O, Ops.size());
3047 UniqueSCEVs.InsertNode(S, IP);
3048 registerUser(S, Ops);
3049 }
3050 S->setNoWrapFlags(Flags);
3051 return S;
3052}
3053
3054static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3055 uint64_t k = i*j;
3056 if (j > 1 && k / j != i) Overflow = true;
3057 return k;
3058}
3059
3060/// Compute the result of "n choose k", the binomial coefficient. If an
3061/// intermediate computation overflows, Overflow will be set and the return will
3062/// be garbage. Overflow is not cleared on absence of overflow.
3063static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3064 // We use the multiplicative formula:
3065 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3066 // At each iteration, we multiply by the next term of the numerator and
3067 // divide by the next term of the denominator. This division will always produce an
3068 // integral result, and helps reduce the chance of overflow in the
3069 // intermediate computations. However, we can still overflow even when the
3070 // final result would fit.
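// For example, Choose(6, 2): after i = 1, r = 6; after i = 2, r = 6 * 5 / 2 = 15.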
3071
3072 if (n == 0 || n == k) return 1;
3073 if (k > n) return 0;
3074
3075 if (k > n/2)
3076 k = n-k;
3077
3078 uint64_t r = 1;
3079 for (uint64_t i = 1; i <= k; ++i) {
3080 r = umul_ov(r, n-(i-1), Overflow);
3081 r /= i;
3082 }
3083 return r;
3084}
3085
3086/// Determine if any of the operands in this SCEV are a constant or if
3087/// any of the add or multiply expressions in this SCEV contain a constant.
3088static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3089 struct FindConstantInAddMulChain {
3090 bool FoundConstant = false;
3091
3092 bool follow(const SCEV *S) {
3093 FoundConstant |= isa<SCEVConstant>(S);
3094 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3095 }
3096
3097 bool isDone() const {
3098 return FoundConstant;
3099 }
3100 };
3101
3102 FindConstantInAddMulChain F;
3103 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3104 ST.visitAll(StartExpr);
3105 return F.FoundConstant;
3106}
3107
3108/// Get a canonical multiply expression, or something simpler if possible.
3109const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3110 SCEV::NoWrapFlags OrigFlags,
3111 unsigned Depth) {
3112 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3113 "only nuw or nsw allowed");
3114 assert(!Ops.empty() && "Cannot get empty mul!");
3115 if (Ops.size() == 1) return Ops[0];
3116#ifndef NDEBUG
3117 Type *ETy = Ops[0]->getType();
3118 assert(!ETy->isPointerTy());
3119 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3120 assert(Ops[i]->getType() == ETy &&
3121 "SCEVMulExpr operand types don't match!");
3122#endif
3123
3124 const SCEV *Folded = constantFoldAndGroupOps(
3125 *this, LI, DT, Ops,
3126 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3127 [](const APInt &C) { return C.isOne(); }, // identity
3128 [](const APInt &C) { return C.isZero(); }); // absorber
3129 if (Folded)
3130 return Folded;
3131
3132 // Delay expensive flag strengthening until necessary.
3133 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3134 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3135 };
3136
3137 // Limit recursion calls depth.
3138 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3139 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3140
3141 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3142 // Don't strengthen flags if we have no new information.
3143 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3144 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3145 Mul->setNoWrapFlags(ComputeFlags(Ops));
3146 return S;
3147 }
3148
3149 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3150 if (Ops.size() == 2) {
3151 // C1*(C2+V) -> C1*C2 + C1*V
3152 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3153 // If any of Add's ops are Adds or Muls with a constant, apply this
3154 // transformation as well.
3155 //
3156 // TODO: There are some cases where this transformation is not
3157 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3158 // this transformation should be narrowed down.
3159 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3160 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3161 SCEV::FlagAnyWrap, Depth + 1);
3162 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3163 SCEV::FlagAnyWrap, Depth + 1);
3164 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3165 }
3166
3167 if (Ops[0]->isAllOnesValue()) {
3168 // If we have a mul by -1 of an add, try distributing the -1 among the
3169 // add operands.
3170 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3171 SmallVector<const SCEV *, 4> NewOps;
3172 bool AnyFolded = false;
3173 for (const SCEV *AddOp : Add->operands()) {
3174 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3175 Depth + 1);
3176 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3177 NewOps.push_back(Mul);
3178 }
3179 if (AnyFolded)
3180 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3181 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3182 // Negation preserves a recurrence's no self-wrap property.
3183 SmallVector<const SCEV *, 4> Operands;
3184 for (const SCEV *AddRecOp : AddRec->operands())
3185 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3186 Depth + 1));
3187 // Let M be the minimum representable signed value. AddRec with nsw
3188 // multiplied by -1 can have signed overflow if and only if it takes a
3189 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3190 // maximum signed value. In all other cases signed overflow is
3191 // impossible.
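// E.g., for an i8 recurrence M is -128, and -(-128) wraps back to -128;
// negating any other i8 value is exact.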
3192 auto FlagsMask = SCEV::FlagNW;
3193 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3194 auto MinInt =
3195 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3196 if (getSignedRangeMin(AddRec) != MinInt)
3197 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3198 }
3199 return getAddRecExpr(Operands, AddRec->getLoop(),
3200 AddRec->getNoWrapFlags(FlagsMask));
3201 }
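// As an illustration: (-1) * {2,+,3}<nsw> becomes {-2,+,-3}; NSW is kept only
// when the recurrence provably never takes the signed minimum value (e.g.
// -128 for i8), because -128 * -1 wraps back to -128.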
3202 }
3203
3204 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3205 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3206 const SCEVAddExpr *InnerAdd;
3207 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3208 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3209 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3210 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3211 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3212 SCEV::FlagAnyWrap),
3213 SCEV::FlagNUW)) {
3214 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3215 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3216 };
3217 }
3218
3219 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3220 // D is a multiple of C1, and C1 is a multiple of C2.
3221 const SCEV *D;
3222 const SCEVConstant *C2;
3223 const APInt &LHSV = LHSC->getAPInt();
3224 if (LHSV.isPowerOf2() &&
3226 C2->getAPInt().isPowerOf2() && LHSV.uge(C2->getAPInt()) &&
3227 LHSV.logBase2() <= getMinTrailingZeros(D)) {
3228 return getMulExpr(getUDivExpr(LHSC, C2), D);
3229 }
3230 }
3231 }
3232
3233 // Skip over the add expression until we get to a multiply.
3234 unsigned Idx = 0;
3235 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3236 ++Idx;
3237
3238 // If there are mul operands inline them all into this expression.
3239 if (Idx < Ops.size()) {
3240 bool DeletedMul = false;
3241 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3242 if (Ops.size() > MulOpsInlineThreshold)
3243 break;
3244 // If we have a mul, expand the mul operands onto the end of the
3245 // operands list.
3246 Ops.erase(Ops.begin()+Idx);
3247 append_range(Ops, Mul->operands());
3248 DeletedMul = true;
3249 }
3250
3251 // If we deleted at least one mul, we added operands to the end of the
3252 // list, and they are not necessarily sorted. Recurse to resort and
3253 // resimplify any operands we just acquired.
3254 if (DeletedMul)
3255 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3256 }
3257
3258 // If there are any add recurrences in the operands list, see if any other
3259 // added values are loop invariant. If so, we can fold them into the
3260 // recurrence.
3261 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3262 ++Idx;
3263
3264 // Scan over all recurrences, trying to fold loop invariants into them.
3265 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3266 // Scan all of the other operands to this mul and add them to the vector
3267 // if they are loop invariant w.r.t. the recurrence.
3269 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3270 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3271 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3272 LIOps.push_back(Ops[i]);
3273 Ops.erase(Ops.begin()+i);
3274 --i; --e;
3275 }
3276
3277 // If we found some loop invariants, fold them into the recurrence.
3278 if (!LIOps.empty()) {
3279 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3281 NewOps.reserve(AddRec->getNumOperands());
3282 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3283
3284 // If both the mul and addrec are nuw, we can preserve nuw.
3285 // If both the mul and addrec are nsw, we can only preserve nsw if either
3286 // a) they are also nuw, or
3287 // b) all multiplications of addrec operands with scale are nsw.
3288 SCEV::NoWrapFlags Flags =
3289 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3290
3291 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3292 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3293 SCEV::FlagAnyWrap, Depth + 1));
3294
3295 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3297 Instruction::Mul, getSignedRange(Scale),
3299 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3300 Flags = clearFlags(Flags, SCEV::FlagNSW);
3301 }
3302 }
3303
3304 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3305
3306 // If all of the other operands were loop invariant, we are done.
3307 if (Ops.size() == 1) return NewRec;
3308
3309 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3310 for (unsigned i = 0;; ++i)
3311 if (Ops[i] == AddRec) {
3312 Ops[i] = NewRec;
3313 break;
3314 }
3315 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3316 }
3317
3318 // Okay, if there weren't any loop invariants to be folded, check to see
3319 // if there are multiple AddRec's with the same loop induction variable
3320 // being multiplied together. If so, we can fold them.
3321
3322 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3323 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3324 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3325 // ]]],+,...up to x=2n}.
3326 // Note that the arguments to choose() are always integers with values
3327 // known at compile time, never SCEV objects.
3328 //
3329 // The implementation avoids pointless extra computations when the two
3330 // addrec's are of different length (mathematically, it's equivalent to
3331 // an infinite stream of zeros on the right).
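// As a concrete instance for two affine recurrences:
//   {A1,+,A2}<L> * {B1,+,B2}<L> = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>
// since (A1 + A2*i) * (B1 + B2*i) = A1*B1 + (A1*B2 + A2*B1)*i + A2*B2*i^2
// and i^2 = choose(i,1) + 2*choose(i,2).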
3332 bool OpsModified = false;
3333 for (unsigned OtherIdx = Idx+1;
3334 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3335 ++OtherIdx) {
3336 const SCEVAddRecExpr *OtherAddRec =
3337 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3338 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3339 continue;
3340
3341 // Limit max number of arguments to avoid creation of unreasonably big
3342 // SCEVAddRecs with very complex operands.
3343 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3344 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3345 continue;
3346
3347 bool Overflow = false;
3348 Type *Ty = AddRec->getType();
3349 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3351 for (int x = 0, xe = AddRec->getNumOperands() +
3352 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3354 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3355 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3356 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3357 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3358 z < ze && !Overflow; ++z) {
3359 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3360 uint64_t Coeff;
3361 if (LargerThan64Bits)
3362 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3363 else
3364 Coeff = Coeff1*Coeff2;
3365 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3366 const SCEV *Term1 = AddRec->getOperand(y-z);
3367 const SCEV *Term2 = OtherAddRec->getOperand(z);
3368 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3369 SCEV::FlagAnyWrap, Depth + 1));
3370 }
3371 }
3372 if (SumOps.empty())
3373 SumOps.push_back(getZero(Ty));
3374 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3375 }
3376 if (!Overflow) {
3377 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3379 if (Ops.size() == 2) return NewAddRec;
3380 Ops[Idx] = NewAddRec;
3381 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3382 OpsModified = true;
3383 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3384 if (!AddRec)
3385 break;
3386 }
3387 }
3388 if (OpsModified)
3389 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3390
3391 // Otherwise couldn't fold anything into this recurrence. Move onto the
3392 // next one.
3393 }
3394
3395 // Okay, it looks like we really DO need a mul expr. Check to see if we
3396 // already have one, otherwise create a new one.
3397 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3398}
3399
3400/// Represents an unsigned remainder expression based on unsigned division.
3401const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3402 const SCEV *RHS) {
3403 assert(getEffectiveSCEVType(LHS->getType()) ==
3404 getEffectiveSCEVType(RHS->getType()) &&
3405 "SCEVURemExpr operand types don't match!");
3406
3407 // Short-circuit easy cases
3408 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3409 // If constant is one, the result is trivial
3410 if (RHSC->getValue()->isOne())
3411 return getZero(LHS->getType()); // X urem 1 --> 0
3412
3413 // If constant is a power of two, fold into a zext(trunc(LHS)).
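// E.g. %x urem 8 becomes (zext i3 (trunc %x to i3)), i.e. just the low three
// bits of %x.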
3414 if (RHSC->getAPInt().isPowerOf2()) {
3415 Type *FullTy = LHS->getType();
3416 Type *TruncTy =
3417 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3418 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3419 }
3420 }
3421
3422 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3423 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3424 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3425 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3426}
3427
3428/// Get a canonical unsigned division expression, or something simpler if
3429/// possible.
3430const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3431 const SCEV *RHS) {
3432 assert(!LHS->getType()->isPointerTy() &&
3433 "SCEVUDivExpr operand can't be pointer!");
3434 assert(LHS->getType() == RHS->getType() &&
3435 "SCEVUDivExpr operand types don't match!");
3436
3437 FoldingSetNodeID ID;
3438 ID.AddInteger(scUDivExpr);
3439 ID.AddPointer(LHS);
3440 ID.AddPointer(RHS);
3441 void *IP = nullptr;
3442 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3443 return S;
3444
3445 // 0 udiv Y == 0
3446 if (match(LHS, m_scev_Zero()))
3447 return LHS;
3448
3449 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3450 if (RHSC->getValue()->isOne())
3451 return LHS; // X udiv 1 --> x
3452 // If the denominator is zero, the result of the udiv is undefined. Don't
3453 // try to analyze it, because the resolution chosen here may differ from
3454 // the resolution chosen in other parts of the compiler.
3455 if (!RHSC->getValue()->isZero()) {
3456 // Determine if the division can be folded into the operands of
3457 // its operands.
3458 // TODO: Generalize this to non-constants by using known-bits information.
3459 Type *Ty = LHS->getType();
3460 unsigned LZ = RHSC->getAPInt().countl_zero();
3461 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3462 // For non-power-of-two values, effectively round the value up to the
3463 // nearest power of two.
3464 if (!RHSC->getAPInt().isPowerOf2())
3465 ++MaxShiftAmt;
3466 IntegerType *ExtTy =
3467 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3468 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3469 if (const SCEVConstant *Step =
3470 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3471 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3472 const APInt &StepInt = Step->getAPInt();
3473 const APInt &DivInt = RHSC->getAPInt();
3474 if (!StepInt.urem(DivInt) &&
3475 getZeroExtendExpr(AR, ExtTy) ==
3476 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3477 getZeroExtendExpr(Step, ExtTy),
3478 AR->getLoop(), SCEV::FlagAnyWrap)) {
3480 for (const SCEV *Op : AR->operands())
3481 Operands.push_back(getUDivExpr(Op, RHS));
3482 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3483 }
3484 /// Get a canonical UDivExpr for a recurrence.
3485 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3486 // We can currently only fold X%N if X is constant.
3487 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3488 if (StartC && !DivInt.urem(StepInt) &&
3489 getZeroExtendExpr(AR, ExtTy) ==
3490 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3491 getZeroExtendExpr(Step, ExtTy),
3492 AR->getLoop(), SCEV::FlagAnyWrap)) {
3493 const APInt &StartInt = StartC->getAPInt();
3494 const APInt &StartRem = StartInt.urem(StepInt);
3495 if (StartRem != 0) {
3496 const SCEV *NewLHS =
3497 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3498 AR->getLoop(), SCEV::FlagNW);
3499 if (LHS != NewLHS) {
3500 LHS = NewLHS;
3501
3502 // Reset the ID to include the new LHS, and check if it is
3503 // already cached.
3504 ID.clear();
3505 ID.AddInteger(scUDivExpr);
3506 ID.AddPointer(LHS);
3507 ID.AddPointer(RHS);
3508 IP = nullptr;
3509 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3510 return S;
3511 }
3512 }
3513 }
3514 }
3515 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3516 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3518 for (const SCEV *Op : M->operands())
3519 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3520 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3521 // Find an operand that's safely divisible.
3522 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3523 const SCEV *Op = M->getOperand(i);
3524 const SCEV *Div = getUDivExpr(Op, RHSC);
3525 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3526 Operands = SmallVector<const SCEV *, 4>(M->operands());
3527 Operands[i] = Div;
3528 return getMulExpr(Operands);
3529 }
3530 }
3531 }
3532
3533 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3534 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3535 if (auto *DivisorConstant =
3536 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3537 bool Overflow = false;
3538 APInt NewRHS =
3539 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3540 if (Overflow) {
3541 return getConstant(RHSC->getType(), 0, false);
3542 }
3543 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3544 }
3545 }
3546
3547 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3548 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3550 for (const SCEV *Op : A->operands())
3551 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3552 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3553 Operands.clear();
3554 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3555 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3556 if (isa<SCEVUDivExpr>(Op) ||
3557 getMulExpr(Op, RHS) != A->getOperand(i))
3558 break;
3559 Operands.push_back(Op);
3560 }
3561 if (Operands.size() == A->getNumOperands())
3562 return getAddExpr(Operands);
3563 }
3564 }
3565
3566 // Fold if both operands are constant.
3567 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3568 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3569 }
3570 }
3571
3572 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3573 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3574 AE && AE->getNumOperands() == 2) {
3575 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3576 const APInt &NegC = VC->getAPInt();
3577 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3578 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3579 if (MME && MME->getNumOperands() == 2 &&
3580 isa<SCEVConstant>(MME->getOperand(0)) &&
3581 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3582 MME->getOperand(1) == RHS)
3583 return getZero(LHS->getType());
3584 }
3585 }
3586 }
3587
3588 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3589 // changes). Make sure we get a new one.
3590 IP = nullptr;
3591 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3592 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3593 LHS, RHS);
3594 UniqueSCEVs.InsertNode(S, IP);
3595 registerUser(S, {LHS, RHS});
3596 return S;
3597}
3598
3599APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3600 APInt A = C1->getAPInt().abs();
3601 APInt B = C2->getAPInt().abs();
3602 uint32_t ABW = A.getBitWidth();
3603 uint32_t BBW = B.getBitWidth();
3604
3605 if (ABW > BBW)
3606 B = B.zext(ABW);
3607 else if (ABW < BBW)
3608 A = A.zext(BBW);
3609
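 // E.g. for an i32 12 and an i64 18, the narrower value is zero-extended to
 // 64 bits first, and the result is the 64-bit constant 6.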
3610 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3611}
3612
3613/// Get a canonical unsigned division expression, or something simpler if
3614/// possible. There is no representation for an exact udiv in SCEV IR, but we
3615/// can attempt to remove factors from the LHS and RHS. We can't do this when
3616/// it's not exact because the udiv may be clearing bits.
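/// E.g. an exact (8 * %x)<nuw> /u 4 simplifies to (2 * %x) by cancelling the
/// shared power-of-two factor, something the plain udiv form must not assume.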
3617const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3618 const SCEV *RHS) {
3619 // TODO: we could try to find factors in all sorts of things, but for now we
3620 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3621 // end of this file for inspiration.
3622
3623 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3624 if (!Mul || !Mul->hasNoUnsignedWrap())
3625 return getUDivExpr(LHS, RHS);
3626
3627 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3628 // If the mulexpr multiplies by a constant, then that constant must be the
3629 // first element of the mulexpr.
3630 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3631 if (LHSCst == RHSCst) {
3633 return getMulExpr(Operands);
3634 }
3635
3636 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3637 // that there's a factor provided by one of the other terms. We need to
3638 // check.
3639 APInt Factor = gcd(LHSCst, RHSCst);
3640 if (!Factor.isIntN(1)) {
3641 LHSCst =
3642 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3643 RHSCst =
3644 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3646 Operands.push_back(LHSCst);
3647 append_range(Operands, Mul->operands().drop_front());
3648 LHS = getMulExpr(Operands);
3649 RHS = RHSCst;
3650 Mul = dyn_cast<SCEVMulExpr>(LHS);
3651 if (!Mul)
3652 return getUDivExactExpr(LHS, RHS);
3653 }
3654 }
3655 }
3656
3657 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3658 if (Mul->getOperand(i) == RHS) {
3660 append_range(Operands, Mul->operands().take_front(i));
3661 append_range(Operands, Mul->operands().drop_front(i + 1));
3662 return getMulExpr(Operands);
3663 }
3664 }
3665
3666 return getUDivExpr(LHS, RHS);
3667}
3668
3669/// Get an add recurrence expression for the specified loop. Simplify the
3670/// expression as much as possible.
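/// E.g. a step that is itself an addrec on the same loop is flattened:
/// {X,+,{1,+,1}<L>}<L> becomes the single recurrence {X,+,1,+,1}<L>.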
3671const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3672 const Loop *L,
3673 SCEV::NoWrapFlags Flags) {
3675 Operands.push_back(Start);
3676 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3677 if (StepChrec->getLoop() == L) {
3678 append_range(Operands, StepChrec->operands());
3679 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3680 }
3681
3682 Operands.push_back(Step);
3683 return getAddRecExpr(Operands, L, Flags);
3684}
3685
3686/// Get an add recurrence expression for the specified loop. Simplify the
3687/// expression as much as possible.
3688const SCEV *
3689ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3690 const Loop *L, SCEV::NoWrapFlags Flags) {
3691 if (Operands.size() == 1) return Operands[0];
3692#ifndef NDEBUG
3693 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3694 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3695 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3696 "SCEVAddRecExpr operand types don't match!");
3697 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3698 }
3699 for (const SCEV *Op : Operands)
3700 assert(isAvailableAtLoopEntry(Op, L) &&
3701 "SCEVAddRecExpr operand is not available at loop entry!");
3702#endif
3703
3704 if (Operands.back()->isZero()) {
3705 Operands.pop_back();
3706 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3707 }
3708
3709 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3710 // use that information to infer NUW and NSW flags. However, computing a
3711 // BE count requires calling getAddRecExpr, so we may not yet have a
3712 // meaningful BE count at this point (and if we don't, we'd be stuck
3713 // with a SCEVCouldNotCompute as the cached BE count).
3714
3715 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3716
3717 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3718 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3719 const Loop *NestedLoop = NestedAR->getLoop();
3720 if (L->contains(NestedLoop)
3721 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3722 : (!NestedLoop->contains(L) &&
3723 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3724 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3725 Operands[0] = NestedAR->getStart();
3726 // AddRecs require their operands be loop-invariant with respect to their
3727 // loops. Don't perform this transformation if it would break this
3728 // requirement.
3729 bool AllInvariant = all_of(
3730 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3731
3732 if (AllInvariant) {
3733 // Create a recurrence for the outer loop with the same step size.
3734 //
3735 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3736 // inner recurrence has the same property.
3737 SCEV::NoWrapFlags OuterFlags =
3738 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3739
3740 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3741 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3742 return isLoopInvariant(Op, NestedLoop);
3743 });
3744
3745 if (AllInvariant) {
3746 // Ok, both add recurrences are valid after the transformation.
3747 //
3748 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3749 // the outer recurrence has the same property.
3750 SCEV::NoWrapFlags InnerFlags =
3751 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3752 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3753 }
3754 }
3755 // Reset Operands to its original state.
3756 Operands[0] = NestedAR;
3757 }
3758 }
3759
3760 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3761 // already have one, otherwise create a new one.
3762 return getOrCreateAddRecExpr(Operands, L, Flags);
3763}
3764
3765const SCEV *
3766ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3767 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3768 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3769 // getSCEV(Base)->getType() has the same address space as Base->getType()
3770 // because SCEV::getType() preserves the address space.
3771 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3772 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3773 if (NW != GEPNoWrapFlags::none()) {
3774 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3775 // but to do that, we have to ensure that said flag is valid in the entire
3776 // defined scope of the SCEV.
3777 // TODO: non-instructions have global scope. We might be able to prove
3778 // some global scope cases
3779 auto *GEPI = dyn_cast<Instruction>(GEP);
3780 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3781 NW = GEPNoWrapFlags::none();
3782 }
3783
3784 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3785 if (NW.hasNoUnsignedSignedWrap())
3786 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3787 if (NW.hasNoUnsignedWrap())
3788 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3789
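// E.g. "getelementptr inbounds i32, ptr %p, i64 %i" (with transferable flags)
// lowers to %p + (4 * %i), where the multiply inherits NSW from inbounds and
// the final add below is NUW only if the offset is known non-negative or the
// GEP itself carries nuw.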
3790 Type *CurTy = GEP->getType();
3791 bool FirstIter = true;
3793 for (const SCEV *IndexExpr : IndexExprs) {
3794 // Compute the (potentially symbolic) offset in bytes for this index.
3795 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3796 // For a struct, add the member offset.
3797 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3798 unsigned FieldNo = Index->getZExtValue();
3799 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3800 Offsets.push_back(FieldOffset);
3801
3802 // Update CurTy to the type of the field at Index.
3803 CurTy = STy->getTypeAtIndex(Index);
3804 } else {
3805 // Update CurTy to its element type.
3806 if (FirstIter) {
3807 assert(isa<PointerType>(CurTy) &&
3808 "The first index of a GEP indexes a pointer");
3809 CurTy = GEP->getSourceElementType();
3810 FirstIter = false;
3811 } else {
3813 }
3814 // For an array, add the element offset, explicitly scaled.
3815 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3816 // Getelementptr indices are signed.
3817 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3818
3819 // Multiply the index by the element size to compute the element offset.
3820 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3821 Offsets.push_back(LocalOffset);
3822 }
3823 }
3824
3825 // Handle degenerate case of GEP without offsets.
3826 if (Offsets.empty())
3827 return BaseExpr;
3828
3829 // Add the offsets together, assuming nsw if inbounds.
3830 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3831 // Add the base address and the offset. We cannot use the nsw flag, as the
3832 // base address is unsigned. However, if we know that the offset is
3833 // non-negative, we can use nuw.
3834 bool NUW = NW.hasNoUnsignedWrap() ||
3837 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3838 assert(BaseExpr->getType() == GEPExpr->getType() &&
3839 "GEP should not change type mid-flight.");
3840 return GEPExpr;
3841}
3842
3843SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3844 ArrayRef<const SCEV *> Ops) {
3845 FoldingSetNodeID ID;
3846 ID.AddInteger(SCEVType);
3847 for (const SCEV *Op : Ops)
3848 ID.AddPointer(Op);
3849 void *IP = nullptr;
3850 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3851}
3852
3853const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3855 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3856}
3857
3858const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3859 SmallVectorImpl<const SCEV *> &Ops) {
3860 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3861 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3862 if (Ops.size() == 1) return Ops[0];
3863#ifndef NDEBUG
3864 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3865 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3866 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3867 "Operand types don't match!");
3868 assert(Ops[0]->getType()->isPointerTy() ==
3869 Ops[i]->getType()->isPointerTy() &&
3870 "min/max should be consistently pointerish");
3871 }
3872#endif
3873
3874 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3875 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3876
3877 const SCEV *Folded = constantFoldAndGroupOps(
3878 *this, LI, DT, Ops,
3879 [&](const APInt &C1, const APInt &C2) {
3880 switch (Kind) {
3881 case scSMaxExpr:
3882 return APIntOps::smax(C1, C2);
3883 case scSMinExpr:
3884 return APIntOps::smin(C1, C2);
3885 case scUMaxExpr:
3886 return APIntOps::umax(C1, C2);
3887 case scUMinExpr:
3888 return APIntOps::umin(C1, C2);
3889 default:
3890 llvm_unreachable("Unknown SCEV min/max opcode");
3891 }
3892 },
3893 [&](const APInt &C) {
3894 // identity
3895 if (IsMax)
3896 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3897 else
3898 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3899 },
3900 [&](const APInt &C) {
3901 // absorber
3902 if (IsMax)
3903 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3904 else
3905 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3906 });
3907 if (Folded)
3908 return Folded;
3909
3910 // Check if we have created the same expression before.
3911 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3912 return S;
3913 }
3914
3915 // Find the first operation of the same kind
3916 unsigned Idx = 0;
3917 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3918 ++Idx;
3919
3920 // Check to see if one of the operands is of the same kind. If so, expand its
3921 // operands onto our operand list, and recurse to simplify.
3922 if (Idx < Ops.size()) {
3923 bool DeletedAny = false;
3924 while (Ops[Idx]->getSCEVType() == Kind) {
3925 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3926 Ops.erase(Ops.begin()+Idx);
3927 append_range(Ops, SMME->operands());
3928 DeletedAny = true;
3929 }
3930
3931 if (DeletedAny)
3932 return getMinMaxExpr(Kind, Ops);
3933 }
3934
3935 // Okay, check to see if the same value occurs in the operand list twice. If
3936 // so, delete one. Since we sorted the list, these values are required to
3937 // be adjacent.
3942 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3943 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3944 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3945 if (Ops[i] == Ops[i + 1] ||
3946 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3947 // X op Y op Y --> X op Y
3948 // X op Y --> X, if we know X, Y are ordered appropriately
3949 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3950 --i;
3951 --e;
3952 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3953 Ops[i + 1])) {
3954 // X op Y --> Y, if we know X, Y are ordered appropriately
3955 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3956 --i;
3957 --e;
3958 }
3959 }
3960
3961 if (Ops.size() == 1) return Ops[0];
3962
3963 assert(!Ops.empty() && "Reduced smax down to nothing!");
3964
3965 // Okay, it looks like we really DO need an expr. Check to see if we
3966 // already have one, otherwise create a new one.
3967 FoldingSetNodeID ID;
3968 ID.AddInteger(Kind);
3969 for (const SCEV *Op : Ops)
3970 ID.AddPointer(Op);
3971 void *IP = nullptr;
3972 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3973 if (ExistingSCEV)
3974 return ExistingSCEV;
3975 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3976 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3977 SCEV *S = new (SCEVAllocator)
3978 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3979
3980 UniqueSCEVs.InsertNode(S, IP);
3981 registerUser(S, Ops);
3982 return S;
3983}
3984
3985namespace {
3986
3987class SCEVSequentialMinMaxDeduplicatingVisitor final
3988 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3989 std::optional<const SCEV *>> {
3990 using RetVal = std::optional<const SCEV *>;
3992
3993 ScalarEvolution &SE;
3994 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3995 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3997
3998 bool canRecurseInto(SCEVTypes Kind) const {
3999 // We can only recurse into the SCEV expression of the same effective type
4000 // as the type of our root SCEV expression.
4001 return RootKind == Kind || NonSequentialRootKind == Kind;
4002 };
4003
4004 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4006 "Only for min/max expressions.");
4007 SCEVTypes Kind = S->getSCEVType();
4008
4009 if (!canRecurseInto(Kind))
4010 return S;
4011
4012 auto *NAry = cast<SCEVNAryExpr>(S);
4013 SmallVector<const SCEV *> NewOps;
4014 bool Changed = visit(Kind, NAry->operands(), NewOps);
4015
4016 if (!Changed)
4017 return S;
4018 if (NewOps.empty())
4019 return std::nullopt;
4020
4022 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4023 : SE.getMinMaxExpr(Kind, NewOps);
4024 }
4025
4026 RetVal visit(const SCEV *S) {
4027 // Has the whole operand been seen already?
4028 if (!SeenOps.insert(S).second)
4029 return std::nullopt;
4030 return Base::visit(S);
4031 }
4032
4033public:
4034 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4035 SCEVTypes RootKind)
4036 : SE(SE), RootKind(RootKind),
4037 NonSequentialRootKind(
4038 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4039 RootKind)) {}
4040
4041 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4042 SmallVectorImpl<const SCEV *> &NewOps) {
4043 bool Changed = false;
4044 SmallVector<const SCEV *> Ops;
4045 Ops.reserve(OrigOps.size());
4046
4047 for (const SCEV *Op : OrigOps) {
4048 RetVal NewOp = visit(Op);
4049 if (NewOp != Op)
4050 Changed = true;
4051 if (NewOp)
4052 Ops.emplace_back(*NewOp);
4053 }
4054
4055 if (Changed)
4056 NewOps = std::move(Ops);
4057 return Changed;
4058 }
4059
4060 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4061
4062 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4063
4064 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4065
4066 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4067
4068 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4069
4070 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4071
4072 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4073
4074 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4075
4076 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4077
4078 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4079
4080 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4081 return visitAnyMinMaxExpr(Expr);
4082 }
4083
4084 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4085 return visitAnyMinMaxExpr(Expr);
4086 }
4087
4088 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4089 return visitAnyMinMaxExpr(Expr);
4090 }
4091
4092 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4093 return visitAnyMinMaxExpr(Expr);
4094 }
4095
4096 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4097 return visitAnyMinMaxExpr(Expr);
4098 }
4099
4100 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4101
4102 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4103};
4104
4105} // namespace
4106
4108 switch (Kind) {
4109 case scConstant:
4110 case scVScale:
4111 case scTruncate:
4112 case scZeroExtend:
4113 case scSignExtend:
4114 case scPtrToInt:
4115 case scAddExpr:
4116 case scMulExpr:
4117 case scUDivExpr:
4118 case scAddRecExpr:
4119 case scUMaxExpr:
4120 case scSMaxExpr:
4121 case scUMinExpr:
4122 case scSMinExpr:
4123 case scUnknown:
4124 // If any operand is poison, the whole expression is poison.
4125 return true;
4126 case scSequentialUMinExpr:
4127 // FIXME: if the *first* operand is poison, the whole expression is poison.
4128 return false; // Pessimistically, say that it does not propagate poison.
4129 case scCouldNotCompute:
4130 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4131 }
4132 llvm_unreachable("Unknown SCEV kind!");
4133}
4134
4135namespace {
4136// The only way poison may be introduced in a SCEV expression is from a
4137// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4138// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4139// introduce poison -- they encode guaranteed, non-speculated knowledge.
4140//
4141// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4142// with the notable exception of umin_seq, where only poison from the first
4143// operand is (unconditionally) propagated.
4144struct SCEVPoisonCollector {
4145 bool LookThroughMaybePoisonBlocking;
4146 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4147 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4148 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4149
4150 bool follow(const SCEV *S) {
4151 if (!LookThroughMaybePoisonBlocking &&
4152 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4153 return false;
4154
4155 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4156 if (!isGuaranteedNotToBePoison(SU->getValue()))
4157 MaybePoison.insert(SU);
4158 }
4159 return true;
4160 }
4161 bool isDone() const { return false; }
4162};
4163} // namespace
4164
4165/// Return true if V is poison given that AssumedPoison is already poison.
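/// E.g. this holds for AssumedPoison = %x and S = (%x + %y), since poison in
/// %x always propagates into the sum, but not for AssumedPoison = (%x + %y)
/// and S = %x, where the sum may be poison solely because of %y.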
4166static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4167 // First collect all SCEVs that might result in AssumedPoison to be poison.
4168 // We need to look through potentially poison-blocking operations here,
4169 // because we want to find all SCEVs that *might* result in poison, not only
4170 // those that are *required* to.
4171 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4172 visitAll(AssumedPoison, PC1);
4173
4174 // AssumedPoison is never poison. As the assumption is false, the implication
4175 // is true. Don't bother walking the other SCEV in this case.
4176 if (PC1.MaybePoison.empty())
4177 return true;
4178
4179 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4180 // as well. We cannot look through potentially poison-blocking operations
4181 // here, as their arguments only *may* make the result poison.
4182 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4183 visitAll(S, PC2);
4184
4185 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4186 // it will also make S poison by being part of PC2.MaybePoison.
4187 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4188}
4189
4190void ScalarEvolution::getPoisonGeneratingValues(
4191 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4192 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4193 visitAll(S, PC);
4194 for (const SCEVUnknown *SU : PC.MaybePoison)
4195 Result.insert(SU->getValue());
4196}
4197
4198bool ScalarEvolution::canReuseInstruction(
4199 const SCEV *S, Instruction *I,
4200 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4201 // If the instruction cannot be poison, it's always safe to reuse.
4202 if (isGuaranteedNotToBePoison(I))
4203 return true;
4204
4205 // Otherwise, it is possible that I is more poisonous that S. Collect the
4206 // poison-contributors of S, and then check whether I has any additional
4207 // poison-contributors. Poison that is contributed through poison-generating
4208 // flags is handled by dropping those flags instead.
4210 getPoisonGeneratingValues(PoisonVals, S);
4211
4212 SmallVector<Value *> Worklist;
4214 Worklist.push_back(I);
4215 while (!Worklist.empty()) {
4216 Value *V = Worklist.pop_back_val();
4217 if (!Visited.insert(V).second)
4218 continue;
4219
4220 // Avoid walking large instruction graphs.
4221 if (Visited.size() > 16)
4222 return false;
4223
4224 // Either the value can't be poison, or the S would also be poison if it
4225 // is.
4226 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4227 continue;
4228
4229 auto *I = dyn_cast<Instruction>(V);
4230 if (!I)
4231 return false;
4232
4233 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4234 // can't replace an arbitrary add with disjoint or, even if we drop the
4235 // flag. We would need to convert the or into an add.
4236 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4237 if (PDI->isDisjoint())
4238 return false;
4239
4240 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4241 // because SCEV currently assumes it can't be poison. Remove this special
4242 // case once we proper model when vscale can be poison.
4243 if (auto *II = dyn_cast<IntrinsicInst>(I);
4244 II && II->getIntrinsicID() == Intrinsic::vscale)
4245 continue;
4246
4247 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4248 return false;
4249
4250 // If the instruction can't create poison, we can recurse to its operands.
4251 if (I->hasPoisonGeneratingAnnotations())
4252 DropPoisonGeneratingInsts.push_back(I);
4253
4254 llvm::append_range(Worklist, I->operands());
4255 }
4256 return true;
4257}
4258
4259const SCEV *
4260ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4261 SmallVectorImpl<const SCEV *> &Ops) {
4262 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4263 "Not a SCEVSequentialMinMaxExpr!");
4264 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4265 if (Ops.size() == 1)
4266 return Ops[0];
4267#ifndef NDEBUG
4268 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4269 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4270 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4271 "Operand types don't match!");
4272 assert(Ops[0]->getType()->isPointerTy() ==
4273 Ops[i]->getType()->isPointerTy() &&
4274 "min/max should be consistently pointerish");
4275 }
4276#endif
4277
4278 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4279 // so we can *NOT* do any kind of sorting of the expressions!
4280
4281 // Check if we have created the same expression before.
4282 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4283 return S;
4284
4285 // FIXME: there are *some* simplifications that we can do here.
4286
4287 // Keep only the first instance of an operand.
4288 {
4289 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4290 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4291 if (Changed)
4292 return getSequentialMinMaxExpr(Kind, Ops);
4293 }
4294
4295 // Check to see if one of the operands is of the same kind. If so, expand its
4296 // operands onto our operand list, and recurse to simplify.
4297 {
4298 unsigned Idx = 0;
4299 bool DeletedAny = false;
4300 while (Idx < Ops.size()) {
4301 if (Ops[Idx]->getSCEVType() != Kind) {
4302 ++Idx;
4303 continue;
4304 }
4305 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4306 Ops.erase(Ops.begin() + Idx);
4307 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4308 SMME->operands().end());
4309 DeletedAny = true;
4310 }
4311
4312 if (DeletedAny)
4313 return getSequentialMinMaxExpr(Kind, Ops);
4314 }
4315
4316 const SCEV *SaturationPoint;
4317 ICmpInst::Predicate Pred;
4318 switch (Kind) {
4319 case scSequentialUMinExpr:
4320 SaturationPoint = getZero(Ops[0]->getType());
4321 Pred = ICmpInst::ICMP_ULE;
4322 break;
4323 default:
4324 llvm_unreachable("Not a sequential min/max type.");
4325 }
4326
4327 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4328 if (!isGuaranteedNotToCauseUB(Ops[i]))
4329 continue;
4330 // We can replace %x umin_seq %y with %x umin %y if either:
4331 // * %y being poison implies %x is also poison.
4332 // * %x cannot be the saturating value (e.g. zero for umin).
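// E.g. if %x is known non-zero (and %y cannot trigger UB), %x umin_seq %y
// becomes %x umin %y, since the sequential evaluation can never stop before
// %y is reached.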
4333 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4334 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4335 SaturationPoint)) {
4336 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4337 Ops[i - 1] = getMinMaxExpr(
4339 SeqOps);
4340 Ops.erase(Ops.begin() + i);
4341 return getSequentialMinMaxExpr(Kind, Ops);
4342 }
4343 // Fold %x umin_seq %y to %x if %x ule %y.
4344 // TODO: We might be able to prove the predicate for a later operand.
4345 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4346 Ops.erase(Ops.begin() + i);
4347 return getSequentialMinMaxExpr(Kind, Ops);
4348 }
4349 }
4350
4351 // Okay, it looks like we really DO need an expr. Check to see if we
4352 // already have one, otherwise create a new one.
4353 FoldingSetNodeID ID;
4354 ID.AddInteger(Kind);
4355 for (const SCEV *Op : Ops)
4356 ID.AddPointer(Op);
4357 void *IP = nullptr;
4358 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4359 if (ExistingSCEV)
4360 return ExistingSCEV;
4361
4362 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4363 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4364 SCEV *S = new (SCEVAllocator)
4365 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4366
4367 UniqueSCEVs.InsertNode(S, IP);
4368 registerUser(S, Ops);
4369 return S;
4370}
4371
4372const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4373 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4374 return getSMaxExpr(Ops);
4375}
4376
4380
4381const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4382 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4383 return getUMaxExpr(Ops);
4384}
4385
4389
4391 const SCEV *RHS) {
4392 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4393 return getSMinExpr(Ops);
4394}
4395
4399
4400const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4401 bool Sequential) {
4402 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4403 return getUMinExpr(Ops, Sequential);
4404}
4405
4411
4412const SCEV *
4413ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4414 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4415 if (Size.isScalable())
4416 Res = getMulExpr(Res, getVScale(IntTy));
4417 return Res;
4418}
4419
4420const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4421 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4422}
4423
4424const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4425 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4426}
4427
4428const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4429 StructType *STy,
4430 unsigned FieldNo) {
4431 // We can bypass creating a target-independent constant expression and then
4432 // folding it back into a ConstantInt. This is just a compile-time
4433 // optimization.
4434 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4435 assert(!SL->getSizeInBits().isScalable() &&
4436 "Cannot get offset for structure containing scalable vector types");
4437 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4438}
4439
4440const SCEV *ScalarEvolution::getUnknown(Value *V) {
4441 // Don't attempt to do anything other than create a SCEVUnknown object
4442 // here. createSCEV only calls getUnknown after checking for all other
4443 // interesting possibilities, and any other code that calls getUnknown
4444 // is doing so in order to hide a value from SCEV canonicalization.
4445
4446 FoldingSetNodeID ID;
4447 ID.AddInteger(scUnknown);
4448 ID.AddPointer(V);
4449 void *IP = nullptr;
4450 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4451 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4452 "Stale SCEVUnknown in uniquing map!");
4453 return S;
4454 }
4455 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4456 FirstUnknown);
4457 FirstUnknown = cast<SCEVUnknown>(S);
4458 UniqueSCEVs.InsertNode(S, IP);
4459 return S;
4460}
4461
4462//===----------------------------------------------------------------------===//
4463// Basic SCEV Analysis and PHI Idiom Recognition Code
4464//
4465
4466/// Test if values of the given type are analyzable within the SCEV
4467/// framework. This primarily includes integer types, and it can optionally
4468/// include pointer types if the ScalarEvolution class has access to
4469/// target-specific information.
4470bool ScalarEvolution::isSCEVable(Type *Ty) const {
4471 // Integers and pointers are always SCEVable.
4472 return Ty->isIntOrPtrTy();
4473}
4474
4475/// Return the size in bits of the specified type, for which isSCEVable must
4476/// return true.
4477uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4478 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4479 if (Ty->isPointerTy())
4481 return getDataLayout().getTypeSizeInBits(Ty);
4482}
4483
4484/// Return a type with the same bitwidth as the given type and which represents
4485/// how SCEV will treat the given type, for which isSCEVable must return
4486/// true. For pointer types, this is the pointer index sized integer type.
4487Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4488 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4489
4490 if (Ty->isIntegerTy())
4491 return Ty;
4492
4493 // The only other support type is pointer.
4494 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4495 return getDataLayout().getIndexType(Ty);
4496}
4497
4498Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4499 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4500}
4501
4503 const SCEV *B) {
4504 /// For a valid use point to exist, the defining scope of one operand
4505 /// must dominate the other.
4506 bool PreciseA, PreciseB;
4507 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4508 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4509 if (!PreciseA || !PreciseB)
4510 // Can't tell.
4511 return false;
4512 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4513 DT.dominates(ScopeB, ScopeA);
4514}
4515
4516const SCEV *ScalarEvolution::getCouldNotCompute() {
4517 return CouldNotCompute.get();
4518}
4519
4520bool ScalarEvolution::checkValidity(const SCEV *S) const {
4521 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4522 auto *SU = dyn_cast<SCEVUnknown>(S);
4523 return SU && SU->getValue() == nullptr;
4524 });
4525
4526 return !ContainsNulls;
4527}
4528
4530 HasRecMapType::iterator I = HasRecMap.find(S);
4531 if (I != HasRecMap.end())
4532 return I->second;
4533
4534 bool FoundAddRec =
4535 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4536 HasRecMap.insert({S, FoundAddRec});
4537 return FoundAddRec;
4538}
4539
4540/// Return the ValueOffsetPair set for \p S. \p S can be represented
4541/// by the value and offset from any ValueOffsetPair in the set.
4542ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4543 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4544 if (SI == ExprValueMap.end())
4545 return {};
4546 return SI->second.getArrayRef();
4547}
4548
4549/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4550/// cannot be used separately. eraseValueFromMap should be used to remove
4551/// V from ValueExprMap and ExprValueMap at the same time.
4552void ScalarEvolution::eraseValueFromMap(Value *V) {
4553 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4554 if (I != ValueExprMap.end()) {
4555 auto EVIt = ExprValueMap.find(I->second);
4556 bool Removed = EVIt->second.remove(V);
4557 (void) Removed;
4558 assert(Removed && "Value not in ExprValueMap?");
4559 ValueExprMap.erase(I);
4560 }
4561}
4562
4563void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4564 // A recursive query may have already computed the SCEV. It should be
4565 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4566 // inferred nowrap flags.
4567 auto It = ValueExprMap.find_as(V);
4568 if (It == ValueExprMap.end()) {
4569 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4570 ExprValueMap[S].insert(V);
4571 }
4572}
4573
4574/// Return an existing SCEV if it exists, otherwise analyze the expression and
4575/// create a new one.
4576const SCEV *ScalarEvolution::getSCEV(Value *V) {
4577 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4578
4579 if (const SCEV *S = getExistingSCEV(V))
4580 return S;
4581 return createSCEVIter(V);
4582}
4583
4584const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4585 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4586
4587 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4588 if (I != ValueExprMap.end()) {
4589 const SCEV *S = I->second;
4590 assert(checkValidity(S) &&
4591 "existing SCEV has not been properly invalidated");
4592 return S;
4593 }
4594 return nullptr;
4595}
4596
4597/// Return a SCEV corresponding to -V = -1*V
4598const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4599 SCEV::NoWrapFlags Flags) {
4600 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4601 return getConstant(
4602 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4603
4604 Type *Ty = V->getType();
4605 Ty = getEffectiveSCEVType(Ty);
4606 return getMulExpr(V, getMinusOne(Ty), Flags);
4607}
4608
4609/// If Expr computes ~A, return A else return nullptr
4610static const SCEV *MatchNotExpr(const SCEV *Expr) {
4611 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4612 if (!Add || Add->getNumOperands() != 2 ||
4613 !Add->getOperand(0)->isAllOnesValue())
4614 return nullptr;
4615
4616 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4617 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4618 !AddRHS->getOperand(0)->isAllOnesValue())
4619 return nullptr;
4620
4621 return AddRHS->getOperand(1);
4622}
4623
4624/// Return a SCEV corresponding to ~V = -1-V
4625const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4626 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4627
4628 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4629 return getConstant(
4630 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4631
4632 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4633 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4634 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4635 SmallVector<const SCEV *, 2> MatchedOperands;
4636 for (const SCEV *Operand : MME->operands()) {
4637 const SCEV *Matched = MatchNotExpr(Operand);
4638 if (!Matched)
4639 return (const SCEV *)nullptr;
4640 MatchedOperands.push_back(Matched);
4641 }
4642 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4643 MatchedOperands);
4644 };
4645 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4646 return Replaced;
4647 }
4648
4649 Type *Ty = V->getType();
4650 Ty = getEffectiveSCEVType(Ty);
4651 return getMinusSCEV(getMinusOne(Ty), V);
4652}
4653
4654const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4655 assert(P->getType()->isPointerTy());
4656
4657 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4658 // The base of an AddRec is the first operand.
4659 SmallVector<const SCEV *> Ops{AddRec->operands()};
4660 Ops[0] = removePointerBase(Ops[0]);
4661 // Don't try to transfer nowrap flags for now. We could in some cases
4662 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4663 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4664 }
4665 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4666 // The base of an Add is the pointer operand.
4667 SmallVector<const SCEV *> Ops{Add->operands()};
4668 const SCEV **PtrOp = nullptr;
4669 for (const SCEV *&AddOp : Ops) {
4670 if (AddOp->getType()->isPointerTy()) {
4671 assert(!PtrOp && "Cannot have multiple pointer ops");
4672 PtrOp = &AddOp;
4673 }
4674 }
4675 *PtrOp = removePointerBase(*PtrOp);
4676 // Don't try to transfer nowrap flags for now. We could in some cases
4677 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4678 return getAddExpr(Ops);
4679 }
4680 // Any other expression must be a pointer base.
4681 return getZero(P->getType());
4682}
4683
4684const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4685 SCEV::NoWrapFlags Flags,
4686 unsigned Depth) {
4687 // Fast path: X - X --> 0.
4688 if (LHS == RHS)
4689 return getZero(LHS->getType());
4690
4691 // If we subtract two pointers with different pointer bases, bail.
4692 // Eventually, we're going to add an assertion to getMulExpr that we
4693 // can't multiply by a pointer.
4694 if (RHS->getType()->isPointerTy()) {
4695 if (!LHS->getType()->isPointerTy() ||
4696 getPointerBase(LHS) != getPointerBase(RHS))
4697 return getCouldNotCompute();
4698 LHS = removePointerBase(LHS);
4699 RHS = removePointerBase(RHS);
4700 }
4701
4702 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4703 // makes it so that we cannot make much use of NUW.
4704 auto AddFlags = SCEV::FlagAnyWrap;
4705 const bool RHSIsNotMinSigned =
4706 !getSignedRangeMin(RHS).isMinSignedValue();
4707 if (hasFlags(Flags, SCEV::FlagNSW)) {
4708 // Let M be the minimum representable signed value. Then (-1)*RHS
4709 // signed-wraps if and only if RHS is M. That can happen even for
4710 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4711 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4712 // (-1)*RHS, we need to prove that RHS != M.
4713 //
4714 // If LHS is non-negative and we know that LHS - RHS does not
4715 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4716 // either by proving that RHS > M or that LHS >= 0.
4717 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4718 AddFlags = SCEV::FlagNSW;
4719 }
4720 }
4721
4722 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4723 // RHS is NSW and LHS >= 0.
4724 //
4725 // The difficulty here is that the NSW flag may have been proven
4726 // relative to a loop that is to be found in a recurrence in LHS and
4727 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4728 // larger scope than intended.
4729 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4730
4731 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4732}
4733
4734const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4735 unsigned Depth) {
4736 Type *SrcTy = V->getType();
4737 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4738 "Cannot truncate or zero extend with non-integer arguments!");
4739 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4740 return V; // No conversion
4741 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4742 return getTruncateExpr(V, Ty, Depth);
4743 return getZeroExtendExpr(V, Ty, Depth);
4744}
4745
4746const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4747 unsigned Depth) {
4748 Type *SrcTy = V->getType();
4749 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4750 "Cannot truncate or zero extend with non-integer arguments!");
4751 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4752 return V; // No conversion
4753 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4754 return getTruncateExpr(V, Ty, Depth);
4755 return getSignExtendExpr(V, Ty, Depth);
4756}
4757
4758const SCEV *
4759ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4760 Type *SrcTy = V->getType();
4761 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4762 "Cannot noop or zero extend with non-integer arguments!");
4763 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4764 "getNoopOrZeroExtend cannot truncate!");
4765 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4766 return V; // No conversion
4767 return getZeroExtendExpr(V, Ty);
4768}
4769
4770const SCEV *
4771ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4772 Type *SrcTy = V->getType();
4773 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4774 "Cannot noop or sign extend with non-integer arguments!");
4775 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4776 "getNoopOrSignExtend cannot truncate!");
4777 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4778 return V; // No conversion
4779 return getSignExtendExpr(V, Ty);
4780}
4781
4782const SCEV *
4783ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4784 Type *SrcTy = V->getType();
4785 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4786 "Cannot noop or any extend with non-integer arguments!");
4787 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4788 "getNoopOrAnyExtend cannot truncate!");
4789 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4790 return V; // No conversion
4791 return getAnyExtendExpr(V, Ty);
4792}
4793
4794const SCEV *
4795ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4796 Type *SrcTy = V->getType();
4797 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4798 "Cannot truncate or noop with non-integer arguments!");
4799 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4800 "getTruncateOrNoop cannot extend!");
4801 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4802 return V; // No conversion
4803 return getTruncateExpr(V, Ty);
4804}
4805
4806const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4807 const SCEV *RHS) {
4808 const SCEV *PromotedLHS = LHS;
4809 const SCEV *PromotedRHS = RHS;
4810
4811 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4812 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4813 else
4814 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4815
4816 return getUMaxExpr(PromotedLHS, PromotedRHS);
4817}
4818
4819const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4820 const SCEV *RHS,
4821 bool Sequential) {
4822 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4823 return getUMinFromMismatchedTypes(Ops, Sequential);
4824}
4825
4826const SCEV *
4827ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4828 bool Sequential) {
4829 assert(!Ops.empty() && "At least one operand must be!");
4830 // Trivial case.
4831 if (Ops.size() == 1)
4832 return Ops[0];
4833
4834 // Find the max type first.
4835 Type *MaxType = nullptr;
4836 for (const auto *S : Ops)
4837 if (MaxType)
4838 MaxType = getWiderType(MaxType, S->getType());
4839 else
4840 MaxType = S->getType();
4841 assert(MaxType && "Failed to find maximum type!");
4842
4843 // Extend all ops to max type.
4844 SmallVector<const SCEV *, 2> PromotedOps;
4845 for (const auto *S : Ops)
4846 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4847
4848 // Generate umin.
4849 return getUMinExpr(PromotedOps, Sequential);
4850}
4851
4852const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4853 // A pointer operand may evaluate to a nonpointer expression, such as null.
4854 if (!V->getType()->isPointerTy())
4855 return V;
4856
4857 while (true) {
4858 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4859 V = AddRec->getStart();
4860 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4861 const SCEV *PtrOp = nullptr;
4862 for (const SCEV *AddOp : Add->operands()) {
4863 if (AddOp->getType()->isPointerTy()) {
4864 assert(!PtrOp && "Cannot have multiple pointer ops");
4865 PtrOp = AddOp;
4866 }
4867 }
4868 assert(PtrOp && "Must have pointer op");
4869 V = PtrOp;
4870 } else // Not something we can look further into.
4871 return V;
4872 }
4873}
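// Example for getPointerBase above (an illustrative sketch, hypothetical
// names): for a SCEV such as (%p + 4 * {0,+,1}<%loop>), the loop strips the
// add and addrec layers and returns the pointer operand %p.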
4874
4875/// Push users of the given Instruction onto the given Worklist.
4876static void PushDefUseChildren(Instruction *I,
4877 SmallVectorImpl<Instruction *> &Worklist,
4878 SmallPtrSetImpl<Instruction *> &Visited) {
4879 // Push the def-use children onto the Worklist stack.
4880 for (User *U : I->users()) {
4881 auto *UserInsn = cast<Instruction>(U);
4882 if (Visited.insert(UserInsn).second)
4883 Worklist.push_back(UserInsn);
4884 }
4885}
4886
4887namespace {
4888
4889/// Takes a SCEV S and a Loop L. For each AddRec sub-expression whose loop
4890/// is L, use its start expression. If the sub-expression's loop is not L,
4891/// keep the AddRec itself when IgnoreOtherLoops is true; otherwise the
4892/// rewrite cannot be done.
4893/// The rewrite also fails if S contains a SCEVUnknown that is not invariant in L.
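///
/// Illustrative sketch (names %start, %step, %inv and %L are hypothetical):
/// rewriting ({%start,+,%step}<%L> + %inv) with respect to %L yields
/// (%start + %inv), because the AddRec for %L is replaced by its start value.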
4894class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4895public:
4896 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4897 bool IgnoreOtherLoops = true) {
4898 SCEVInitRewriter Rewriter(L, SE);
4899 const SCEV *Result = Rewriter.visit(S);
4900 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4901 return SE.getCouldNotCompute();
4902 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4903 ? SE.getCouldNotCompute()
4904 : Result;
4905 }
4906
4907 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4908 if (!SE.isLoopInvariant(Expr, L))
4909 SeenLoopVariantSCEVUnknown = true;
4910 return Expr;
4911 }
4912
4913 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4914 // Only re-write AddRecExprs for this loop.
4915 if (Expr->getLoop() == L)
4916 return Expr->getStart();
4917 SeenOtherLoops = true;
4918 return Expr;
4919 }
4920
4921 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4922
4923 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4924
4925private:
4926 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4927 : SCEVRewriteVisitor(SE), L(L) {}
4928
4929 const Loop *L;
4930 bool SeenLoopVariantSCEVUnknown = false;
4931 bool SeenOtherLoops = false;
4932};
4933
4934/// Takes a SCEV S and a Loop L. For each AddRec sub-expression whose loop
4935/// is L, use its post-increment expression; for AddRecs belonging to other
4936/// loops, keep the AddRec itself.
4937/// The rewrite fails if S contains a SCEVUnknown that is not invariant in L.
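///
/// Illustrative sketch (hypothetical names): rewriting {%start,+,%step}<%L>
/// with respect to %L yields {(%start + %step),+,%step}<%L>, i.e. the value
/// of the recurrence after one increment.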
4938class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4939public:
4940 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4941 SCEVPostIncRewriter Rewriter(L, SE);
4942 const SCEV *Result = Rewriter.visit(S);
4943 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4944 ? SE.getCouldNotCompute()
4945 : Result;
4946 }
4947
4948 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4949 if (!SE.isLoopInvariant(Expr, L))
4950 SeenLoopVariantSCEVUnknown = true;
4951 return Expr;
4952 }
4953
4954 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4955 // Only re-write AddRecExprs for this loop.
4956 if (Expr->getLoop() == L)
4957 return Expr->getPostIncExpr(SE);
4958 SeenOtherLoops = true;
4959 return Expr;
4960 }
4961
4962 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4963
4964 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4965
4966private:
4967 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4968 : SCEVRewriteVisitor(SE), L(L) {}
4969
4970 const Loop *L;
4971 bool SeenLoopVariantSCEVUnknown = false;
4972 bool SeenOtherLoops = false;
4973};
4974
4975/// This class evaluates the compare condition by matching it against the
4976/// branch condition of the loop latch. If there is a match, we assume a
4977/// true value for the condition while building SCEV nodes.
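///
/// Illustrative sketch (hypothetical IR names): if the latch branches on
/// %cmp = icmp ult i64 %iv, %n and the backedge is taken when %cmp is true,
/// then a select guarded by %cmp is folded to its true operand while this
/// rewriter visits the expression.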
4978class SCEVBackedgeConditionFolder
4979 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4980public:
4981 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4982 ScalarEvolution &SE) {
4983 bool IsPosBECond = false;
4984 Value *BECond = nullptr;
4985 if (BasicBlock *Latch = L->getLoopLatch()) {
4986 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4987 if (BI && BI->isConditional()) {
4988 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4989 "Both outgoing branches should not target same header!");
4990 BECond = BI->getCondition();
4991 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4992 } else {
4993 return S;
4994 }
4995 }
4996 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4997 return Rewriter.visit(S);
4998 }
4999
5000 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5001 const SCEV *Result = Expr;
5002 bool InvariantF = SE.isLoopInvariant(Expr, L);
5003
5004 if (!InvariantF) {
5005 Instruction *I = cast<Instruction>(Expr->getValue());
5006 switch (I->getOpcode()) {
5007 case Instruction::Select: {
5008 SelectInst *SI = cast<SelectInst>(I);
5009 std::optional<const SCEV *> Res =
5010 compareWithBackedgeCondition(SI->getCondition());
5011 if (Res) {
5012 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5013 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5014 }
5015 break;
5016 }
5017 default: {
5018 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5019 if (Res)
5020 Result = *Res;
5021 break;
5022 }
5023 }
5024 }
5025 return Result;
5026 }
5027
5028private:
5029 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5030 bool IsPosBECond, ScalarEvolution &SE)
5031 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5032 IsPositiveBECond(IsPosBECond) {}
5033
5034 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5035
5036 const Loop *L;
5037 /// Branch condition of the loop latch.
5038 Value *BackedgeCond = nullptr;
5039 /// True if the backedge is taken when the branch condition is true.
5040 bool IsPositiveBECond;
5041};
5042
5043std::optional<const SCEV *>
5044SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5045
5046 // If the value matches the backedge condition of the loop latch, return a
5047 // constant i1 SCEV reflecting whether the backedge is taken on the true
5048 // or the false branch.
5049 if (BackedgeCond == IC)
5050 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5051 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5052 return std::nullopt;
5053}
5054
5055class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5056public:
5057 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5058 ScalarEvolution &SE) {
5059 SCEVShiftRewriter Rewriter(L, SE);
5060 const SCEV *Result = Rewriter.visit(S);
5061 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5062 }
5063
5064 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5065 // Only allow AddRecExprs for this loop.
5066 if (!SE.isLoopInvariant(Expr, L))
5067 Valid = false;
5068 return Expr;
5069 }
5070
5071 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5072 if (Expr->getLoop() == L && Expr->isAffine())
5073 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5074 Valid = false;
5075 return Expr;
5076 }
5077
5078 bool isValid() { return Valid; }
5079
5080private:
5081 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5082 : SCEVRewriteVisitor(SE), L(L) {}
5083
5084 const Loop *L;
5085 bool Valid = true;
5086};
5087
5088} // end anonymous namespace
5089
5090SCEV::NoWrapFlags
5091ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5092 if (!AR->isAffine())
5093 return SCEV::FlagAnyWrap;
5094
5095 using OBO = OverflowingBinaryOperator;
5096
5097 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5098
5099 if (!AR->hasNoSelfWrap()) {
5100 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5101 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5102 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5103 const APInt &BECountAP = BECountMax->getAPInt();
5104 unsigned NoOverflowBitWidth =
5105 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5106 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5107 Result = setFlags(Result, SCEV::FlagNW);
5108 }
5109 }
5110
5111 if (!AR->hasNoSignedWrap()) {
5112 ConstantRange AddRecRange = getSignedRange(AR);
5113 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5114
5115 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5116 Instruction::Add, IncRange, OBO::NoSignedWrap);
5117 if (NSWRegion.contains(AddRecRange))
5118 Result = setFlags(Result, SCEV::FlagNSW);
5119 }
5120
5121 if (!AR->hasNoUnsignedWrap()) {
5122 ConstantRange AddRecRange = getUnsignedRange(AR);
5123 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5124
5125 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5126 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5127 if (NUWRegion.contains(AddRecRange))
5128 Result = setFlags(Result, SCEV::FlagNUW);
5129 }
5130
5131 return Result;
5132}
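// Worked example for the self-wrap check above (illustrative numbers only,
// not taken from this file): for an i8 addrec whose constant max
// backedge-taken count is 7 (3 active bits) and whose step has signed range
// [1, 4] (4 signed bits), 3 + 4 = 7 <= 8, so the accumulated change cannot
// wrap the 8-bit type back to the start value and FlagNW can be proven.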
5133
5134SCEV::NoWrapFlags
5135ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5136 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5137
5138 if (AR->hasNoSignedWrap())
5139 return Result;
5140
5141 if (!AR->isAffine())
5142 return Result;
5143
5144 // This function can be expensive, only try to prove NSW once per AddRec.
5145 if (!SignedWrapViaInductionTried.insert(AR).second)
5146 return Result;
5147
5148 const SCEV *Step = AR->getStepRecurrence(*this);
5149 const Loop *L = AR->getLoop();
5150
5151 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5152 // Note that this serves two purposes: It filters out loops that are
5153 // simply not analyzable, and it covers the case where this code is
5154 // being called from within backedge-taken count analysis, such that
5155 // attempting to ask for the backedge-taken count would likely result
5156 // in infinite recursion. In the latter case, the analysis code will
5157 // cope with a conservative value, and it will take care to purge
5158 // that value once it has finished.
5159 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5160
5161 // Normally, in the cases we can prove no-overflow via a
5162 // backedge guarding condition, we can also compute a backedge
5163 // taken count for the loop. The exceptions are assumptions and
5164 // guards present in the loop -- SCEV is not great at exploiting
5165 // these to compute max backedge taken counts, but can still use
5166 // these to prove lack of overflow. Use this fact to avoid
5167 // doing extra work that may not pay off.
5168
5169 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5170 AC.assumptions().empty())
5171 return Result;
5172
5173 // If the backedge is guarded by a comparison with the pre-inc value the
5174 // addrec is safe. Also, if the entry is guarded by a comparison with the
5175 // start value and the backedge is guarded by a comparison with the post-inc
5176 // value, the addrec is safe.
5177 ICmpInst::Predicate Pred;
5178 const SCEV *OverflowLimit =
5179 getSignedOverflowLimitForStep(Step, &Pred, this);
5180 if (OverflowLimit &&
5181 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5182 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5183 Result = setFlags(Result, SCEV::FlagNSW);
5184 }
5185 return Result;
5186}
5187SCEV::NoWrapFlags
5188ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5189 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5190
5191 if (AR->hasNoUnsignedWrap())
5192 return Result;
5193
5194 if (!AR->isAffine())
5195 return Result;
5196
5197 // This function can be expensive, only try to prove NUW once per AddRec.
5198 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5199 return Result;
5200
5201 const SCEV *Step = AR->getStepRecurrence(*this);
5202 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5203 const Loop *L = AR->getLoop();
5204
5205 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5206 // Note that this serves two purposes: It filters out loops that are
5207 // simply not analyzable, and it covers the case where this code is
5208 // being called from within backedge-taken count analysis, such that
5209 // attempting to ask for the backedge-taken count would likely result
5210 // in infinite recursion. In the latter case, the analysis code will
5211 // cope with a conservative value, and it will take care to purge
5212 // that value once it has finished.
5213 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5214
5215 // Normally, in the cases we can prove no-overflow via a
5216 // backedge guarding condition, we can also compute a backedge
5217 // taken count for the loop. The exceptions are assumptions and
5218 // guards present in the loop -- SCEV is not great at exploiting
5219 // these to compute max backedge taken counts, but can still use
5220 // these to prove lack of overflow. Use this fact to avoid
5221 // doing extra work that may not pay off.
5222
5223 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5224 AC.assumptions().empty())
5225 return Result;
5226
5227 // If the backedge is guarded by a comparison with the pre-inc value the
5228 // addrec is safe. Also, if the entry is guarded by a comparison with the
5229 // start value and the backedge is guarded by a comparison with the post-inc
5230 // value, the addrec is safe.
5231 if (isKnownPositive(Step)) {
5232 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5233 getUnsignedRangeMax(Step));
5234 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5235 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5236 Result = setFlags(Result, SCEV::FlagNUW);
5237 }
5238 }
5239
5240 return Result;
5241}
5242
5243namespace {
5244
5245/// Represents an abstract binary operation. This may exist as a
5246/// normal instruction or constant expression, or may have been
5247/// derived from an expression tree.
5248struct BinaryOp {
5249 unsigned Opcode;
5250 Value *LHS;
5251 Value *RHS;
5252 bool IsNSW = false;
5253 bool IsNUW = false;
5254
5255 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5256 /// constant expression.
5257 Operator *Op = nullptr;
5258
5259 explicit BinaryOp(Operator *Op)
5260 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5261 Op(Op) {
5262 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5263 IsNSW = OBO->hasNoSignedWrap();
5264 IsNUW = OBO->hasNoUnsignedWrap();
5265 }
5266 }
5267
5268 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5269 bool IsNUW = false)
5270 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5271};
5272
5273} // end anonymous namespace
5274
5275/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5276static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5277 AssumptionCache &AC,
5278 const DominatorTree &DT,
5279 const Instruction *CxtI) {
5280 auto *Op = dyn_cast<Operator>(V);
5281 if (!Op)
5282 return std::nullopt;
5283
5284 // Implementation detail: all the cleverness here should happen without
5285 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5286 // SCEV expressions when possible, and we should not break that.
5287
5288 switch (Op->getOpcode()) {
5289 case Instruction::Add:
5290 case Instruction::Sub:
5291 case Instruction::Mul:
5292 case Instruction::UDiv:
5293 case Instruction::URem:
5294 case Instruction::And:
5295 case Instruction::AShr:
5296 case Instruction::Shl:
5297 return BinaryOp(Op);
5298
5299 case Instruction::Or: {
5300 // Convert or disjoint into add nuw nsw.
5301 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5302 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5303 /*IsNSW=*/true, /*IsNUW=*/true);
5304 return BinaryOp(Op);
5305 }
5306
5307 case Instruction::Xor:
5308 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5309 // If the RHS of the xor is a signmask, then this is just an add.
5310 // Instcombine turns add of signmask into xor as a strength reduction step.
5311 if (RHSC->getValue().isSignMask())
5312 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5313 // Binary `xor` is a bit-wise `add`.
5314 if (V->getType()->isIntegerTy(1))
5315 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5316 return BinaryOp(Op);
5317
5318 case Instruction::LShr:
5319 // Turn logical shift right of a constant into an unsigned divide.
5320 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5321 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5322
5323 // If the shift count is not less than the bitwidth, the result of
5324 // the shift is undefined. Don't try to analyze it, because the
5325 // resolution chosen here may differ from the resolution chosen in
5326 // other parts of the compiler.
5327 if (SA->getValue().ult(BitWidth)) {
5328 Constant *X =
5329 ConstantInt::get(SA->getContext(),
5330 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5331 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5332 }
5333 }
5334 return BinaryOp(Op);
5335
5336 case Instruction::ExtractValue: {
5337 auto *EVI = cast<ExtractValueInst>(Op);
5338 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5339 break;
5340
5341 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5342 if (!WO)
5343 break;
5344
5345 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5346 bool Signed = WO->isSigned();
5347 // TODO: Should add nuw/nsw flags for mul as well.
5348 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5349 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5350
5351 // Now that we know that all uses of the arithmetic-result component of
5352 // CI are guarded by the overflow check, we can go ahead and pretend
5353 // that the arithmetic is non-overflowing.
5354 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5355 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5356 }
5357
5358 default:
5359 break;
5360 }
5361
5362 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5363 // semantics as a Sub, return a binary sub expression.
5364 if (auto *II = dyn_cast<IntrinsicInst>(V))
5365 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5366 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5367
5368 return std::nullopt;
5369}
5370
5371/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5372/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5373/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5374/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5375/// follows one of the following patterns:
5376/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5377/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5378/// If the SCEV expression of \p Op conforms with one of the expected patterns
5379/// we return the type of the truncation operation, and indicate whether the
5380/// truncated type should be treated as signed/unsigned by setting
5381/// \p Signed to true/false, respectively.
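///
/// Illustrative sketch: with %SymbolicPHI of type i64 and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64), this helper
/// returns the i32 type and sets \p Signed to true.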
5382static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5383 bool &Signed, ScalarEvolution &SE) {
5384 // The case where Op == SymbolicPHI (that is, with no type conversions on
5385 // the way) is handled by the regular add recurrence creating logic and
5386 // would have already been triggered in createAddRecForPHI. Reaching it here
5387 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5388 // because one of the other operands of the SCEVAddExpr updating this PHI is
5389 // not invariant).
5390 //
5391 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5392 // this case predicates that allow us to prove that Op == SymbolicPHI will
5393 // be added.
5394 if (Op == SymbolicPHI)
5395 return nullptr;
5396
5397 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5398 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5399 if (SourceBits != NewBits)
5400 return nullptr;
5401
5402 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5403 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5404 if (!SExt && !ZExt)
5405 return nullptr;
5406 const SCEVTruncateExpr *Trunc =
5407 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5408 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5409 if (!Trunc)
5410 return nullptr;
5411 const SCEV *X = Trunc->getOperand();
5412 if (X != SymbolicPHI)
5413 return nullptr;
5414 Signed = SExt != nullptr;
5415 return Trunc->getType();
5416}
5417
5418static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5419 if (!PN->getType()->isIntegerTy())
5420 return nullptr;
5421 const Loop *L = LI.getLoopFor(PN->getParent());
5422 if (!L || L->getHeader() != PN->getParent())
5423 return nullptr;
5424 return L;
5425}
5426
5427// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5428// computation that updates the phi follows the following pattern:
5429// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5430// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5431// If so, try to see if it can be rewritten as an AddRecExpr under some
5432// Predicates. If successful, return them as a pair. Also cache the results
5433// of the analysis.
5434//
5435// Example usage scenario:
5436// Say the Rewriter is called for the following SCEV:
5437// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5438// where:
5439// %X = phi i64 (%Start, %BEValue)
5440// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5441// and call this function with %SymbolicPHI = %X.
5442//
5443// The analysis will find that the value coming around the backedge has
5444// the following SCEV:
5445// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5446// Upon concluding that this matches the desired pattern, the function
5447// will return the pair {NewAddRec, SmallPredsVec} where:
5448// NewAddRec = {%Start,+,%Step}
5449// SmallPredsVec = {P1, P2, P3} as follows:
5450// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5451// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5452// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5453// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5454// under the predicates {P1,P2,P3}.
5455// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5456// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5457//
5458// TODO's:
5459//
5460// 1) Extend the Induction descriptor to also support inductions that involve
5461// casts: When needed (namely, when we are called in the context of the
5462// vectorizer induction analysis), a Set of cast instructions will be
5463// populated by this method, and provided back to isInductionPHI. This is
5464// needed to allow the vectorizer to properly record them to be ignored by
5465// the cost model and to avoid vectorizing them (otherwise these casts,
5466// which are redundant under the runtime overflow checks, will be
5467// vectorized, which can be costly).
5468//
5469// 2) Support additional induction/PHISCEV patterns: We also want to support
5470// inductions where the sext-trunc / zext-trunc operations (partly) occur
5471// after the induction update operation (the induction increment):
5472//
5473// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5474// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5475//
5476// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5477// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5478//
5479// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5480std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5481ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5482 SmallVector<const SCEVPredicate *, 3> Predicates;
5483
5484 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5485 // return an AddRec expression under some predicate.
5486
5487 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5488 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5489 assert(L && "Expecting an integer loop header phi");
5490
5491 // The loop may have multiple entrances or multiple exits; we can analyze
5492 // this phi as an addrec if it has a unique entry value and a unique
5493 // backedge value.
5494 Value *BEValueV = nullptr, *StartValueV = nullptr;
5495 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5496 Value *V = PN->getIncomingValue(i);
5497 if (L->contains(PN->getIncomingBlock(i))) {
5498 if (!BEValueV) {
5499 BEValueV = V;
5500 } else if (BEValueV != V) {
5501 BEValueV = nullptr;
5502 break;
5503 }
5504 } else if (!StartValueV) {
5505 StartValueV = V;
5506 } else if (StartValueV != V) {
5507 StartValueV = nullptr;
5508 break;
5509 }
5510 }
5511 if (!BEValueV || !StartValueV)
5512 return std::nullopt;
5513
5514 const SCEV *BEValue = getSCEV(BEValueV);
5515
5516 // If the value coming around the backedge is an add with the symbolic
5517 // value we just inserted, possibly with casts that we can ignore under
5518 // an appropriate runtime guard, then we found a simple induction variable!
5519 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5520 if (!Add)
5521 return std::nullopt;
5522
5523 // If there is a single occurrence of the symbolic value, possibly
5524 // casted, replace it with a recurrence.
5525 unsigned FoundIndex = Add->getNumOperands();
5526 Type *TruncTy = nullptr;
5527 bool Signed;
5528 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5529 if ((TruncTy =
5530 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5531 if (FoundIndex == e) {
5532 FoundIndex = i;
5533 break;
5534 }
5535
5536 if (FoundIndex == Add->getNumOperands())
5537 return std::nullopt;
5538
5539 // Create an add with everything but the specified operand.
5540 SmallVector<const SCEV *, 8> Ops;
5541 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5542 if (i != FoundIndex)
5543 Ops.push_back(Add->getOperand(i));
5544 const SCEV *Accum = getAddExpr(Ops);
5545
5546 // The runtime checks will not be valid if the step amount is
5547 // varying inside the loop.
5548 if (!isLoopInvariant(Accum, L))
5549 return std::nullopt;
5550
5551 // *** Part2: Create the predicates
5552
5553 // Analysis was successful: we have a phi-with-cast pattern for which we
5554 // can return an AddRec expression under the following predicates:
5555 //
5556 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5557 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5558 // P2: An Equal predicate that guarantees that
5559 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5560 // P3: An Equal predicate that guarantees that
5561 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5562 //
5563 // As we next prove, the above predicates guarantee that:
5564 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5565 //
5566 //
5567 // More formally, we want to prove that:
5568 // Expr(i+1) = Start + (i+1) * Accum
5569 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5570 //
5571 // Given that:
5572 // 1) Expr(0) = Start
5573 // 2) Expr(1) = Start + Accum
5574 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5575 // 3) Induction hypothesis (step i):
5576 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5577 //
5578 // Proof:
5579 // Expr(i+1) =
5580 // = Start + (i+1)*Accum
5581 // = (Start + i*Accum) + Accum
5582 // = Expr(i) + Accum
5583 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5584 // :: from step i
5585 //
5586 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5587 //
5588 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5589 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5590 // + Accum :: from P3
5591 //
5592 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5593 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5594 //
5595 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5596 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5597 //
5598 // By induction, the same applies to all iterations 1<=i<n:
5599 //
5600
5601 // Create a truncated addrec for which we will add a no overflow check (P1).
5602 const SCEV *StartVal = getSCEV(StartValueV);
5603 const SCEV *PHISCEV =
5604 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5605 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5606
5607 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5608 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5609 // will be constant.
5610 //
5611 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5612 // add P1.
5613 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5614 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5615 Signed ? SCEVWrapPredicate::IncrementNSSW
5616 : SCEVWrapPredicate::IncrementNUSW;
5617 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5618 Predicates.push_back(AddRecPred);
5619 }
5620
5621 // Create the Equal Predicates P2,P3:
5622
5623 // It is possible that the predicates P2 and/or P3 are computable at
5624 // compile time due to StartVal and/or Accum being constants.
5625 // If either one is, then we can check that now and escape if either P2
5626 // or P3 is false.
5627
5628 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5629 // for each of StartVal and Accum
5630 auto getExtendedExpr = [&](const SCEV *Expr,
5631 bool CreateSignExtend) -> const SCEV * {
5632 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5633 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5634 const SCEV *ExtendedExpr =
5635 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5636 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5637 return ExtendedExpr;
5638 };
5639
5640 // Given:
5641 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5642 // = getExtendedExpr(Expr)
5643 // Determine whether the predicate P: Expr == ExtendedExpr
5644 // is known to be false at compile time
5645 auto PredIsKnownFalse = [&](const SCEV *Expr,
5646 const SCEV *ExtendedExpr) -> bool {
5647 return Expr != ExtendedExpr &&
5648 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5649 };
5650
5651 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5652 if (PredIsKnownFalse(StartVal, StartExtended)) {
5653 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5654 return std::nullopt;
5655 }
5656
5657 // The Step is always Signed (because the overflow checks are either
5658 // NSSW or NUSW)
5659 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5660 if (PredIsKnownFalse(Accum, AccumExtended)) {
5661 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5662 return std::nullopt;
5663 }
5664
5665 auto AppendPredicate = [&](const SCEV *Expr,
5666 const SCEV *ExtendedExpr) -> void {
5667 if (Expr != ExtendedExpr &&
5668 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5669 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5670 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5671 Predicates.push_back(Pred);
5672 }
5673 };
5674
5675 AppendPredicate(StartVal, StartExtended);
5676 AppendPredicate(Accum, AccumExtended);
5677
5678 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5679 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5680 // into NewAR if it will also add the runtime overflow checks specified in
5681 // Predicates.
5682 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5683
5684 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5685 std::make_pair(NewAR, Predicates);
5686 // Remember the result of the analysis for this SCEV at this location.
5687 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5688 return PredRewrite;
5689}
5690
5691std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5692ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5693 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5694 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5695 if (!L)
5696 return std::nullopt;
5697
5698 // Check to see if we already analyzed this PHI.
5699 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5700 if (I != PredicatedSCEVRewrites.end()) {
5701 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5702 I->second;
5703 // Analysis was done before and failed to create an AddRec:
5704 if (Rewrite.first == SymbolicPHI)
5705 return std::nullopt;
5706 // Analysis was done before and succeeded to create an AddRec under
5707 // a predicate:
5708 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5709 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5710 return Rewrite;
5711 }
5712
5713 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5714 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5715
5716 // Record in the cache that the analysis failed
5717 if (!Rewrite) {
5718 SmallVector<const SCEVPredicate *, 3> Predicates;
5719 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5720 return std::nullopt;
5721 }
5722
5723 return Rewrite;
5724}
5725
5726// FIXME: This utility is currently required because the Rewriter currently
5727// does not rewrite this expression:
5728// {0, +, (sext ix (trunc iy to ix) to iy)}
5729// into {0, +, %step},
5730// even when the following Equal predicate exists:
5731// "%step == (sext ix (trunc iy to ix) to iy)".
5732bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5733 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5734 if (AR1 == AR2)
5735 return true;
5736
5737 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5738 if (Expr1 != Expr2 &&
5739 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5740 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5741 return false;
5742 return true;
5743 };
5744
5745 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5746 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5747 return false;
5748 return true;
5749}
5750
5751/// A helper function for createAddRecFromPHI to handle simple cases.
5752///
5753/// This function tries to find an AddRec expression for the simplest (yet most
5754/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5755/// If it fails, createAddRecFromPHI will use a more general, but slow,
5756/// technique for finding the AddRec expression.
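///
/// Illustrative sketch (hypothetical IR names):
///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, 1
/// is recognized here as the affine AddRec {0,+,1}<nuw><nsw><%loop>.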
5757const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5758 Value *BEValueV,
5759 Value *StartValueV) {
5760 const Loop *L = LI.getLoopFor(PN->getParent());
5761 assert(L && L->getHeader() == PN->getParent());
5762 assert(BEValueV && StartValueV);
5763
5764 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5765 if (!BO)
5766 return nullptr;
5767
5768 if (BO->Opcode != Instruction::Add)
5769 return nullptr;
5770
5771 const SCEV *Accum = nullptr;
5772 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5773 Accum = getSCEV(BO->RHS);
5774 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5775 Accum = getSCEV(BO->LHS);
5776
5777 if (!Accum)
5778 return nullptr;
5779
5780 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5781 if (BO->IsNUW)
5782 Flags = setFlags(Flags, SCEV::FlagNUW);
5783 if (BO->IsNSW)
5784 Flags = setFlags(Flags, SCEV::FlagNSW);
5785
5786 const SCEV *StartVal = getSCEV(StartValueV);
5787 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5788 insertValueToMap(PN, PHISCEV);
5789
5790 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5791 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5792 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5793 proveNoWrapViaConstantRanges(AR)));
5794 }
5795
5796 // We can add Flags to the post-inc expression only if we
5797 // know that it is *undefined behavior* for BEValueV to
5798 // overflow.
5799 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5800 assert(isLoopInvariant(Accum, L) &&
5801 "Accum is defined outside L, but is not invariant?");
5802 if (isAddRecNeverPoison(BEInst, L))
5803 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5804 }
5805
5806 return PHISCEV;
5807}
5808
5809const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5810 const Loop *L = LI.getLoopFor(PN->getParent());
5811 if (!L || L->getHeader() != PN->getParent())
5812 return nullptr;
5813
5814 // The loop may have multiple entrances or multiple exits; we can analyze
5815 // this phi as an addrec if it has a unique entry value and a unique
5816 // backedge value.
5817 Value *BEValueV = nullptr, *StartValueV = nullptr;
5818 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5819 Value *V = PN->getIncomingValue(i);
5820 if (L->contains(PN->getIncomingBlock(i))) {
5821 if (!BEValueV) {
5822 BEValueV = V;
5823 } else if (BEValueV != V) {
5824 BEValueV = nullptr;
5825 break;
5826 }
5827 } else if (!StartValueV) {
5828 StartValueV = V;
5829 } else if (StartValueV != V) {
5830 StartValueV = nullptr;
5831 break;
5832 }
5833 }
5834 if (!BEValueV || !StartValueV)
5835 return nullptr;
5836
5837 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5838 "PHI node already processed?");
5839
5840 // First, try to find an AddRec expression without creating a fictitious symbolic
5841 // value for PN.
5842 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5843 return S;
5844
5845 // Handle PHI node value symbolically.
5846 const SCEV *SymbolicName = getUnknown(PN);
5847 insertValueToMap(PN, SymbolicName);
5848
5849 // Using this symbolic name for the PHI, analyze the value coming around
5850 // the back-edge.
5851 const SCEV *BEValue = getSCEV(BEValueV);
5852
5853 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5854 // has a special value for the first iteration of the loop.
5855
5856 // If the value coming around the backedge is an add with the symbolic
5857 // value we just inserted, then we found a simple induction variable!
5858 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5859 // If there is a single occurrence of the symbolic value, replace it
5860 // with a recurrence.
5861 unsigned FoundIndex = Add->getNumOperands();
5862 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5863 if (Add->getOperand(i) == SymbolicName)
5864 if (FoundIndex == e) {
5865 FoundIndex = i;
5866 break;
5867 }
5868
5869 if (FoundIndex != Add->getNumOperands()) {
5870 // Create an add with everything but the specified operand.
5871 SmallVector<const SCEV *, 8> Ops;
5872 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5873 if (i != FoundIndex)
5874 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5875 L, *this));
5876 const SCEV *Accum = getAddExpr(Ops);
5877
5878 // This is not a valid addrec if the step amount is varying each
5879 // loop iteration, but is not itself an addrec in this loop.
5880 if (isLoopInvariant(Accum, L) ||
5881 (isa<SCEVAddRecExpr>(Accum) &&
5882 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5883 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5884
5885 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5886 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5887 if (BO->IsNUW)
5888 Flags = setFlags(Flags, SCEV::FlagNUW);
5889 if (BO->IsNSW)
5890 Flags = setFlags(Flags, SCEV::FlagNSW);
5891 }
5892 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5893 if (GEP->getOperand(0) == PN) {
5894 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5895 // If the increment has any nowrap flags, then we know the address
5896 // space cannot be wrapped around.
5897 if (NW != GEPNoWrapFlags::none())
5898 Flags = setFlags(Flags, SCEV::FlagNW);
5899 // If the GEP is nuw or nusw with non-negative offset, we know that
5900 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5901 // offset is treated as signed, while the base is unsigned.
5902 if (NW.hasNoUnsignedWrap() ||
5903 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5904 Flags = setFlags(Flags, SCEV::FlagNUW);
5905 }
5906
5907 // We cannot transfer nuw and nsw flags from subtraction
5908 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5909 // for instance.
5910 }
5911
5912 const SCEV *StartVal = getSCEV(StartValueV);
5913 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5914
5915 // Okay, for the entire analysis of this edge we assumed the PHI
5916 // to be symbolic. We now need to go back and purge all of the
5917 // entries for the scalars that use the symbolic expression.
5918 forgetMemoizedResults(SymbolicName);
5919 insertValueToMap(PN, PHISCEV);
5920
5921 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5922 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5923 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5924 proveNoWrapViaConstantRanges(AR)));
5925 }
5926
5927 // We can add Flags to the post-inc expression only if we
5928 // know that it is *undefined behavior* for BEValueV to
5929 // overflow.
5930 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5931 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5932 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5933
5934 return PHISCEV;
5935 }
5936 }
5937 } else {
5938 // Otherwise, this could be a loop like this:
5939 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5940 // In this case, j = {1,+,1} and BEValue is j.
5941 // Because the other in-value of i (0) fits the evolution of BEValue
5942 // i really is an addrec evolution.
5943 //
5944 // We can generalize this by saying that i is the shifted value of BEValue
5945 // by one iteration:
5946 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5947
5948 // Do not allow refinement in rewriting of BEValue.
5949 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5950 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5951 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5952 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5953 const SCEV *StartVal = getSCEV(StartValueV);
5954 if (Start == StartVal) {
5955 // Okay, for the entire analysis of this edge we assumed the PHI
5956 // to be symbolic. We now need to go back and purge all of the
5957 // entries for the scalars that use the symbolic expression.
5958 forgetMemoizedResults(SymbolicName);
5959 insertValueToMap(PN, Shifted);
5960 return Shifted;
5961 }
5962 }
5963 }
5964
5965 // Remove the temporary PHI node SCEV that has been inserted while intending
5966 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5967 // as it would prevent later (possibly simpler) SCEV expressions from being
5968 // added to the ValueExprMap.
5969 eraseValueFromMap(PN);
5970
5971 return nullptr;
5972}
5973
5974// Try to match a control flow sequence that branches out at BI and merges back
5975// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5976// match.
5977static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5978 Value *&C, Value *&LHS, Value *&RHS) {
5979 C = BI->getCondition();
5980
5981 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5982 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5983
5984 if (!LeftEdge.isSingleEdge())
5985 return false;
5986
5987 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5988
5989 Use &LeftUse = Merge->getOperandUse(0);
5990 Use &RightUse = Merge->getOperandUse(1);
5991
5992 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5993 LHS = LeftUse;
5994 RHS = RightUse;
5995 return true;
5996 }
5997
5998 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5999 LHS = RightUse;
6000 RHS = LeftUse;
6001 return true;
6002 }
6003
6004 return false;
6005}
6006
6007const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6008 auto IsReachable =
6009 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6010 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6011 // Try to match
6012 //
6013 // br %cond, label %left, label %right
6014 // left:
6015 // br label %merge
6016 // right:
6017 // br label %merge
6018 // merge:
6019 // V = phi [ %x, %left ], [ %y, %right ]
6020 //
6021 // as "select %cond, %x, %y"
6022
6023 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6024 assert(IDom && "At least the entry block should dominate PN");
6025
6026 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6027 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6028
6029 if (BI && BI->isConditional() &&
6030 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6031 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6032 properlyDominates(getSCEV(RHS), PN->getParent()))
6033 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6034 }
6035
6036 return nullptr;
6037}
6038
6039/// Returns SCEV for the first operand of a phi if all phi operands have
6040/// identical opcodes and operands
6041/// eg.
6042/// a: %add = %a + %b
6043/// br %c
6044/// b: %add1 = %a + %b
6045/// br %c
6046/// c: %phi = phi [%add, a], [%add1, b]
6047/// scev(%phi) => scev(%add)
6048const SCEV *
6049ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6050 BinaryOperator *CommonInst = nullptr;
6051 // Check if instructions are identical.
6052 for (Value *Incoming : PN->incoming_values()) {
6053 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6054 if (!IncomingInst)
6055 return nullptr;
6056 if (CommonInst) {
6057 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6058 return nullptr; // Not identical, give up
6059 } else {
6060 // Remember binary operator
6061 CommonInst = IncomingInst;
6062 }
6063 }
6064 if (!CommonInst)
6065 return nullptr;
6066
6067 // Check if SCEV exprs for instructions are identical.
6068 const SCEV *CommonSCEV = getSCEV(CommonInst);
6069 bool SCEVExprsIdentical =
6070 all_of(drop_begin(PN->incoming_values()),
6071 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6072 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6073}
6074
6075const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6076 if (const SCEV *S = createAddRecFromPHI(PN))
6077 return S;
6078
6079 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6080 // phi node for X.
6081 if (Value *V = simplifyInstruction(
6082 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6083 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6084 return getSCEV(V);
6085
6086 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6087 return S;
6088
6089 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6090 return S;
6091
6092 // If it's not a loop phi, we can't handle it yet.
6093 return getUnknown(PN);
6094}
6095
6096bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6097 SCEVTypes RootKind) {
6098 struct FindClosure {
6099 const SCEV *OperandToFind;
6100 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6101 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6102
6103 bool Found = false;
6104
6105 bool canRecurseInto(SCEVTypes Kind) const {
6106 // We can only recurse into the SCEV expression of the same effective type
6107 // as the type of our root SCEV expression, and into zero-extensions.
6108 return RootKind == Kind || NonSequentialRootKind == Kind ||
6109 scZeroExtend == Kind;
6110 };
6111
6112 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6113 : OperandToFind(OperandToFind), RootKind(RootKind),
6114 NonSequentialRootKind(
6115 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6116 RootKind)) {}
6117
6118 bool follow(const SCEV *S) {
6119 Found = S == OperandToFind;
6120
6121 return !isDone() && canRecurseInto(S->getSCEVType());
6122 }
6123
6124 bool isDone() const { return Found; }
6125 };
6126
6127 FindClosure FC(OperandToFind, RootKind);
6128 visitAll(Root, FC);
6129 return FC.Found;
6130}
6131
6132std::optional<const SCEV *>
6133ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6134 ICmpInst *Cond,
6135 Value *TrueVal,
6136 Value *FalseVal) {
6137 // Try to match some simple smax or umax patterns.
6138 auto *ICI = Cond;
6139
6140 Value *LHS = ICI->getOperand(0);
6141 Value *RHS = ICI->getOperand(1);
6142
6143 switch (ICI->getPredicate()) {
6144 case ICmpInst::ICMP_SLT:
6145 case ICmpInst::ICMP_SLE:
6146 case ICmpInst::ICMP_ULT:
6147 case ICmpInst::ICMP_ULE:
6148 std::swap(LHS, RHS);
6149 [[fallthrough]];
6150 case ICmpInst::ICMP_SGT:
6151 case ICmpInst::ICMP_SGE:
6152 case ICmpInst::ICMP_UGT:
6153 case ICmpInst::ICMP_UGE:
6154 // a > b ? a+x : b+x -> max(a, b)+x
6155 // a > b ? b+x : a+x -> min(a, b)+x
6156 {
6157 bool Signed = ICI->isSigned();
6158 const SCEV *LA = getSCEV(TrueVal);
6159 const SCEV *RA = getSCEV(FalseVal);
6160 const SCEV *LS = getSCEV(LHS);
6161 const SCEV *RS = getSCEV(RHS);
6162 if (LA->getType()->isPointerTy()) {
6163 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6164 // Need to make sure we can't produce weird expressions involving
6165 // negated pointers.
6166 if (LA == LS && RA == RS)
6167 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6168 if (LA == RS && RA == LS)
6169 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6170 }
6171 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6172 if (Op->getType()->isPointerTy()) {
6173 Op = getLosslessPtrToIntExpr(Op);
6174 if (isa<SCEVCouldNotCompute>(Op))
6175 return Op;
6176 }
6177 if (Signed)
6178 Op = getNoopOrSignExtend(Op, Ty);
6179 else
6180 Op = getNoopOrZeroExtend(Op, Ty);
6181 return Op;
6182 };
6183 LS = CoerceOperand(LS);
6184 RS = CoerceOperand(RS);
6185 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6186 break;
6187 const SCEV *LDiff = getMinusSCEV(LA, LS);
6188 const SCEV *RDiff = getMinusSCEV(RA, RS);
6189 if (LDiff == RDiff)
6190 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6191 LDiff);
6192 LDiff = getMinusSCEV(LA, RS);
6193 RDiff = getMinusSCEV(RA, LS);
6194 if (LDiff == RDiff)
6195 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6196 LDiff);
6197 }
6198 break;
6199 case ICmpInst::ICMP_NE:
6200 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6201 std::swap(TrueVal, FalseVal);
6202 [[fallthrough]];
6203 case ICmpInst::ICMP_EQ:
6204 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6205 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6206 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6207 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6208 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6209 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6210 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6211 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6212 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6213 return getAddExpr(getUMaxExpr(X, C), Y);
6214 }
6215 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6216 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6217 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6218 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6219 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6220 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6221 const SCEV *X = getSCEV(LHS);
6222 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6223 X = ZExt->getOperand();
6224 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6225 const SCEV *FalseValExpr = getSCEV(FalseVal);
6226 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6227 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6228 /*Sequential=*/true);
6229 }
6230 }
6231 break;
6232 default:
6233 break;
6234 }
6235
6236 return std::nullopt;
6237}
6238
6239static std::optional<const SCEV *>
6240createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6241 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6242 assert(CondExpr->getType()->isIntegerTy(1) &&
6243 TrueExpr->getType() == FalseExpr->getType() &&
6244 TrueExpr->getType()->isIntegerTy(1) &&
6245 "Unexpected operands of a select.");
6246
6247 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6248 // --> C + (umin_seq cond, x - C)
6249 //
6250 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6251 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6252 // --> C + (umin_seq ~cond, x - C)
6253
6254 // FIXME: while we can't legally model the case where both of the hands
6255 // are fully variable, we only require that the *difference* is constant.
6256 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6257 return std::nullopt;
6258
6259 const SCEV *X, *C;
6260 if (isa<SCEVConstant>(TrueExpr)) {
6261 CondExpr = SE->getNotSCEV(CondExpr);
6262 X = FalseExpr;
6263 C = TrueExpr;
6264 } else {
6265 X = TrueExpr;
6266 C = FalseExpr;
6267 }
6268 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6269 /*Sequential=*/true));
6270}
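// Worked instance of the rewrite above (illustrative, hypothetical names):
// for the i1 select "%cond ? %x : false", C is the constant false hand (0),
// so the result is 0 + umin_seq(%cond, %x - 0) = umin_seq(%cond, %x).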
6271
6272static std::optional<const SCEV *>
6273createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6274 Value *FalseVal) {
6275 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6276 return std::nullopt;
6277
6278 const auto *SECond = SE->getSCEV(Cond);
6279 const auto *SETrue = SE->getSCEV(TrueVal);
6280 const auto *SEFalse = SE->getSCEV(FalseVal);
6281 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6282}
6283
6284const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6285 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6286 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6287 assert(TrueVal->getType() == FalseVal->getType() &&
6288 V->getType() == TrueVal->getType() &&
6289 "Types of select hands and of the result must match.");
6290
6291 // For now, only deal with i1-typed `select`s.
6292 if (!V->getType()->isIntegerTy(1))
6293 return getUnknown(V);
6294
6295 if (std::optional<const SCEV *> S =
6296 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6297 return *S;
6298
6299 return getUnknown(V);
6300}
6301
6302const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6303 Value *TrueVal,
6304 Value *FalseVal) {
6305 // Handle "constant" branch or select. This can occur for instance when a
6306 // loop pass transforms an inner loop and moves on to process the outer loop.
6307 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6308 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6309
6310 if (auto *I = dyn_cast<Instruction>(V)) {
6311 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6312 if (std::optional<const SCEV *> S =
6313 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6314 TrueVal, FalseVal))
6315 return *S;
6316 }
6317 }
6318
6319 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6320}
6321
6322/// Expand GEP instructions into add and multiply operations. This allows them
6323/// to be analyzed by regular SCEV code.
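///
/// Illustrative sketch (assuming 4-byte i32 and hypothetical names): the SCEV
/// for "getelementptr i32, ptr %p, i64 %i" becomes (%p + 4 * %i), i.e. the
/// base pointer plus the scaled index.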
6324const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6325 assert(GEP->getSourceElementType()->isSized() &&
6326 "GEP source element type must be sized");
6327
6328 SmallVector<const SCEV *, 4> IndexExprs;
6329 for (Value *Index : GEP->indices())
6330 IndexExprs.push_back(getSCEV(Index));
6331 return getGEPExpr(GEP, IndexExprs);
6332}
6333
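// Note on getConstantMultipleImpl below (an illustrative sketch, not part of
// the original documentation): the "constant multiple" of S is a constant
// known to divide every value S can take. For example, for (8 * %x + 16)
// with <nuw> and nothing known about %x, the operands' multiples are 8 and
// 16, and the add's multiple is their GCD, 8.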
6334APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6335 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6336 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6337 return TrailingZeros >= BitWidth
6338 ? APInt::getZero(BitWidth)
6339 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6340 };
6341 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6342 // The result is GCD of all operands results.
6343 APInt Res = getConstantMultiple(N->getOperand(0));
6344 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6345 Res = APIntOps::GreatestCommonDivisor(
6346 Res, getConstantMultiple(N->getOperand(I)));
6347 return Res;
6348 };
6349
6350 switch (S->getSCEVType()) {
6351 case scConstant:
6352 return cast<SCEVConstant>(S)->getAPInt();
6353 case scPtrToInt:
6354 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6355 case scUDivExpr:
6356 case scVScale:
6357 return APInt(BitWidth, 1);
6358 case scTruncate: {
6359 // Only multiples that are a power of 2 will hold after truncation.
6360 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6361 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6362 return GetShiftedByZeros(TZ);
6363 }
6364 case scZeroExtend: {
6365 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6366 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6367 }
6368 case scSignExtend: {
6369 // Only multiples that are a power of 2 will hold after sext.
6370 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6371 uint32_t TZ = getMinTrailingZeros(E->getOperand());
6372 return GetShiftedByZeros(TZ);
6373 }
6374 case scMulExpr: {
6375 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6376 if (M->hasNoUnsignedWrap()) {
6377 // The result is the product of all operand results.
6378 APInt Res = getConstantMultiple(M->getOperand(0));
6379 for (const SCEV *Operand : M->operands().drop_front())
6380 Res = Res * getConstantMultiple(Operand);
6381 return Res;
6382 }
6383
6384 // If there are no wrap guarantees, find the trailing zeros, which is the
6385 // sum of trailing zeros of all its operands.
6386 uint32_t TZ = 0;
6387 for (const SCEV *Operand : M->operands())
6388 TZ += getMinTrailingZeros(Operand);
6389 return GetShiftedByZeros(TZ);
6390 }
6391 case scAddExpr:
6392 case scAddRecExpr: {
6393 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6394 if (N->hasNoUnsignedWrap())
6395 return GetGCDMultiple(N);
6396 // Otherwise, the number of trailing zero bits is the minimum across its operands.
6397 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6398 for (const SCEV *Operand : N->operands().drop_front())
6399 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6400 return GetShiftedByZeros(TZ);
6401 }
6402 case scUMaxExpr:
6403 case scSMaxExpr:
6404 case scUMinExpr:
6405 case scSMinExpr:
6406 case scSequentialUMinExpr:
6407 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6408 case scUnknown: {
6409 // ask ValueTracking for known bits
6410 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6411 unsigned Known =
6412 computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6413 .countMinTrailingZeros();
6414 return GetShiftedByZeros(Known);
6415 }
6416 case scCouldNotCompute:
6417 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6418 }
6419 llvm_unreachable("Unknown SCEV kind!");
6420}
6421
6422 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6423 auto I = ConstantMultipleCache.find(S);
6424 if (I != ConstantMultipleCache.end())
6425 return I->second;
6426
6427 APInt Result = getConstantMultipleImpl(S);
6428 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6429 assert(InsertPair.second && "Should insert a new key");
6430 return InsertPair.first->second;
6431}
6432
6433 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6434 APInt Multiple = getConstantMultiple(S);
6435 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6436}
6437
6438 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6439 return std::min(getConstantMultiple(S).countTrailingZeros(),
6440 (unsigned)getTypeSizeInBits(S->getType()));
6441}
6442
6443/// Helper method to assign a range to V from metadata present in the IR.
6444static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6445 if (Instruction *I = dyn_cast<Instruction>(V)) {
6446 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6447 return getConstantRangeFromMetadata(*MD);
6448 if (const auto *CB = dyn_cast<CallBase>(V))
6449 if (std::optional<ConstantRange> Range = CB->getRange())
6450 return Range;
6451 }
6452 if (auto *A = dyn_cast<Argument>(V))
6453 if (std::optional<ConstantRange> Range = A->getRange())
6454 return Range;
6455
6456 return std::nullopt;
6457}
6458
6459 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6460 SCEV::NoWrapFlags Flags) {
6461 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6462 AddRec->setNoWrapFlags(Flags);
6463 UnsignedRanges.erase(AddRec);
6464 SignedRanges.erase(AddRec);
6465 ConstantMultipleCache.erase(AddRec);
6466 }
6467}
6468
6469ConstantRange ScalarEvolution::
6470getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6471 const DataLayout &DL = getDataLayout();
6472
6473 unsigned BitWidth = getTypeSizeInBits(U->getType());
6474 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6475
6476 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6477 // use information about the trip count to improve our available range. Note
6478 // that the trip count independent cases are already handled by known bits.
6479 // WARNING: The definition of recurrence used here is subtly different than
6480 // the one used by AddRec (and thus most of this file). Step is allowed to
6481 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6482 // and other addrecs in the same loop (for non-affine addrecs). The code
6483 // below intentionally handles the case where step is not loop invariant.
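// For example (illustrative IR names), this matches an i8 PHI %iv whose loop
// update is "%iv.next = lshr i8 %iv, %amt" even when %amt is itself
// recomputed on every iteration.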
6484 auto *P = dyn_cast<PHINode>(U->getValue());
6485 if (!P)
6486 return FullSet;
6487
6488 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6489 // even the values that are not available in these blocks may come from them,
6490 // and this leads to a false-positive recurrence test.
6491 for (auto *Pred : predecessors(P->getParent()))
6492 if (!DT.isReachableFromEntry(Pred))
6493 return FullSet;
6494
6495 BinaryOperator *BO;
6496 Value *Start, *Step;
6497 if (!matchSimpleRecurrence(P, BO, Start, Step))
6498 return FullSet;
6499
6500 // If we found a recurrence in reachable code, we must be in a loop. Note
6501 // that BO might be in some subloop of L, and that's completely okay.
6502 auto *L = LI.getLoopFor(P->getParent());
6503 assert(L && L->getHeader() == P->getParent());
6504 if (!L->contains(BO->getParent()))
6505 // NOTE: This bailout should be an assert instead. However, asserting
6506 // the condition here exposes a case where LoopFusion is querying SCEV
6507 // with malformed loop information during the midst of the transform.
6508 // There doesn't appear to be an obvious fix, so for the moment bailout
6509 // until the caller issue can be fixed. PR49566 tracks the bug.
6510 return FullSet;
6511
6512 // TODO: Extend to other opcodes such as mul, and div
6513 switch (BO->getOpcode()) {
6514 default:
6515 return FullSet;
6516 case Instruction::AShr:
6517 case Instruction::LShr:
6518 case Instruction::Shl:
6519 break;
6520 };
6521
6522 if (BO->getOperand(0) != P)
6523 // TODO: Handle the power function forms some day.
6524 return FullSet;
6525
6526 unsigned TC = getSmallConstantMaxTripCount(L);
6527 if (!TC || TC >= BitWidth)
6528 return FullSet;
6529
6530 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6531 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6532 assert(KnownStart.getBitWidth() == BitWidth &&
6533 KnownStep.getBitWidth() == BitWidth);
6534
6535 // Compute total shift amount, being careful of overflow and bitwidths.
6536 auto MaxShiftAmt = KnownStep.getMaxValue();
6537 APInt TCAP(BitWidth, TC-1);
6538 bool Overflow = false;
6539 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6540 if (Overflow)
6541 return FullSet;
6542
6543 switch (BO->getOpcode()) {
6544 default:
6545 llvm_unreachable("filtered out above");
6546 case Instruction::AShr: {
6547 // For each ashr, three cases:
6548 // shift = 0 => unchanged value
6549 // saturation => 0 or -1
6550 // other => a value closer to zero (of the same sign)
6551 // Thus, the end value is closer to zero than the start.
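// For example (illustrative values): an i8 value known to start at -100 that
// is shifted right arithmetically by at most 3 bits in total can only take
// the values -100, -50, -25 or -13, all of which lie in [-100, -13].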
6552 auto KnownEnd = KnownBits::ashr(KnownStart,
6553 KnownBits::makeConstant(TotalShift));
6554 if (KnownStart.isNonNegative())
6555 // Analogous to lshr (simply not yet canonicalized)
6556 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6557 KnownStart.getMaxValue() + 1);
6558 if (KnownStart.isNegative())
6559 // End >=u Start && End <=s Start
6560 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6561 KnownEnd.getMaxValue() + 1);
6562 break;
6563 }
6564 case Instruction::LShr: {
6565 // For each lshr, three cases:
6566 // shift = 0 => unchanged value
6567 // saturation => 0
6568 // other => a smaller positive number
6569 // Thus, the low end of the unsigned range is the last value produced.
6570 auto KnownEnd = KnownBits::lshr(KnownStart,
6571 KnownBits::makeConstant(TotalShift));
6572 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6573 KnownStart.getMaxValue() + 1);
6574 }
6575 case Instruction::Shl: {
6576 // Iff no bits are shifted out, value increases on every shift.
6577 auto KnownEnd = KnownBits::shl(KnownStart,
6578 KnownBits::makeConstant(TotalShift));
6579 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6580 return ConstantRange(KnownStart.getMinValue(),
6581 KnownEnd.getMaxValue() + 1);
6582 break;
6583 }
6584 };
6585 return FullSet;
6586}
6587
6588const ConstantRange &
6589ScalarEvolution::getRangeRefIter(const SCEV *S,
6590 ScalarEvolution::RangeSignHint SignHint) {
6591 DenseMap<const SCEV *, ConstantRange> &Cache =
6592 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6593 : SignedRanges;
6594 SmallVector<const SCEV *> WorkList;
6595 SmallPtrSet<const SCEV *, 8> Seen;
6596
6597 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6598 // SCEVUnknown PHI node.
6599 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6600 if (!Seen.insert(Expr).second)
6601 return;
6602 if (Cache.contains(Expr))
6603 return;
6604 switch (Expr->getSCEVType()) {
6605 case scUnknown:
6606 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6607 break;
6608 [[fallthrough]];
6609 case scConstant:
6610 case scVScale:
6611 case scTruncate:
6612 case scZeroExtend:
6613 case scSignExtend:
6614 case scPtrToInt:
6615 case scAddExpr:
6616 case scMulExpr:
6617 case scUDivExpr:
6618 case scAddRecExpr:
6619 case scUMaxExpr:
6620 case scSMaxExpr:
6621 case scUMinExpr:
6622 case scSMinExpr:
6623 case scSequentialUMinExpr:
6624 WorkList.push_back(Expr);
6625 break;
6626 case scCouldNotCompute:
6627 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6628 }
6629 };
6630 AddToWorklist(S);
6631
6632 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6633 for (unsigned I = 0; I != WorkList.size(); ++I) {
6634 const SCEV *P = WorkList[I];
6635 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6636 // If it is not a `SCEVUnknown`, just recurse into operands.
6637 if (!UnknownS) {
6638 for (const SCEV *Op : P->operands())
6639 AddToWorklist(Op);
6640 continue;
6641 }
6642 // `SCEVUnknown`'s require special treatment.
6643 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6644 if (!PendingPhiRangesIter.insert(P).second)
6645 continue;
6646 for (auto &Op : reverse(P->operands()))
6647 AddToWorklist(getSCEV(Op));
6648 }
6649 }
6650
6651 if (!WorkList.empty()) {
6652 // Use getRangeRef to compute ranges for items in the worklist in reverse
6653 // order. This will force ranges for earlier operands to be computed before
6654 // their users in most cases.
6655 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6656 getRangeRef(P, SignHint);
6657
6658 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6659 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6660 PendingPhiRangesIter.erase(P);
6661 }
6662 }
6663
6664 return getRangeRef(S, SignHint, 0);
6665}
6666
6667/// Determine the range for a particular SCEV. If SignHint is
6668/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6669/// with a "cleaner" unsigned (resp. signed) representation.
6670const ConstantRange &ScalarEvolution::getRangeRef(
6671 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6672 DenseMap<const SCEV *, ConstantRange> &Cache =
6673 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6674 : SignedRanges;
6675 ConstantRange::PreferredRangeType RangeType =
6676 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6677 : ConstantRange::Signed;
6678
6679 // See if we've computed this range already.
6680 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6681 if (I != Cache.end())
6682 return I->second;
6683
6684 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6685 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6686
6687 // Switch to iteratively computing the range for S, if it is part of a deeply
6688 // nested expression.
6689 if (Depth > RangeIterThreshold)
6690 return getRangeRefIter(S, SignHint);
6691
6692 unsigned BitWidth = getTypeSizeInBits(S->getType());
6693 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6694 using OBO = OverflowingBinaryOperator;
6695
6696 // If the value has known zeros, the maximum value will have those known zeros
6697 // as well.
6698 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6699 APInt Multiple = getNonZeroConstantMultiple(S);
6700 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6701 if (!Remainder.isZero())
6702 ConservativeResult =
6703 ConstantRange(APInt::getMinValue(BitWidth),
6704 APInt::getMaxValue(BitWidth) - Remainder + 1);
6705 }
6706 else {
6707 uint32_t TZ = getMinTrailingZeros(S);
6708 if (TZ != 0) {
6709 ConservativeResult = ConstantRange(
6710 APInt::getSignedMinValue(BitWidth),
6711 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6712 }
6713 }
6714
6715 switch (S->getSCEVType()) {
6716 case scConstant:
6717 llvm_unreachable("Already handled above.");
6718 case scVScale:
6719 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6720 case scTruncate: {
6721 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6722 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6723 return setRange(
6724 Trunc, SignHint,
6725 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6726 }
6727 case scZeroExtend: {
6728 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6729 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6730 return setRange(
6731 ZExt, SignHint,
6732 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6733 }
6734 case scSignExtend: {
6735 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6736 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6737 return setRange(
6738 SExt, SignHint,
6739 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6740 }
6741 case scPtrToInt: {
6742 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6743 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6744 return setRange(PtrToInt, SignHint, X);
6745 }
6746 case scAddExpr: {
6747 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6748 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6749 unsigned WrapType = OBO::AnyWrap;
6750 if (Add->hasNoSignedWrap())
6751 WrapType |= OBO::NoSignedWrap;
6752 if (Add->hasNoUnsignedWrap())
6753 WrapType |= OBO::NoUnsignedWrap;
6754 for (const SCEV *Op : drop_begin(Add->operands()))
6755 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6756 RangeType);
6757 return setRange(Add, SignHint,
6758 ConservativeResult.intersectWith(X, RangeType));
6759 }
6760 case scMulExpr: {
6761 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6762 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6763 for (const SCEV *Op : drop_begin(Mul->operands()))
6764 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6765 return setRange(Mul, SignHint,
6766 ConservativeResult.intersectWith(X, RangeType));
6767 }
6768 case scUDivExpr: {
6769 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6770 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6771 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6772 return setRange(UDiv, SignHint,
6773 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6774 }
6775 case scAddRecExpr: {
6776 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6777 // If there's no unsigned wrap, the value will never be less than its
6778 // initial value.
6779 if (AddRec->hasNoUnsignedWrap()) {
6780 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6781 if (!UnsignedMinValue.isZero())
6782 ConservativeResult = ConservativeResult.intersectWith(
6783 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6784 }
6785
6786 // If there's no signed wrap, and all the operands except the initial value
6787 // have the same sign or are zero, the value won't ever be:
6788 // 1: smaller than the initial value if the operands are non-negative,
6789 // 2: bigger than the initial value if the operands are non-positive.
6790 // In both cases, the value cannot cross the signed min/max boundary.
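// For example (illustrative values), {5,+,3}<nsw> is never smaller than 5 and
// never crosses SINT_MAX, so its signed range is contained in [5, SINT_MAX].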
6791 if (AddRec->hasNoSignedWrap()) {
6792 bool AllNonNeg = true;
6793 bool AllNonPos = true;
6794 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6795 if (!isKnownNonNegative(AddRec->getOperand(i)))
6796 AllNonNeg = false;
6797 if (!isKnownNonPositive(AddRec->getOperand(i)))
6798 AllNonPos = false;
6799 }
6800 if (AllNonNeg)
6801 ConservativeResult = ConservativeResult.intersectWith(
6802 ConstantRange(getSignedRangeMin(AddRec->getStart()),
6803 APInt::getSignedMinValue(BitWidth)),
6804 RangeType);
6805 else if (AllNonPos)
6806 ConservativeResult = ConservativeResult.intersectWith(
6807 ConstantRange(APInt::getSignedMinValue(BitWidth),
6808 getSignedRangeMax(AddRec->getStart()) +
6809 1),
6810 RangeType);
6811 }
6812
6813 // TODO: non-affine addrec
6814 if (AddRec->isAffine()) {
6815 const SCEV *MaxBEScev =
6816 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6817 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6818 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6819
6820 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6821 // MaxBECount's active bits are all <= AddRec's bit width.
6822 if (MaxBECount.getBitWidth() > BitWidth &&
6823 MaxBECount.getActiveBits() <= BitWidth)
6824 MaxBECount = MaxBECount.trunc(BitWidth);
6825 else if (MaxBECount.getBitWidth() < BitWidth)
6826 MaxBECount = MaxBECount.zext(BitWidth);
6827
6828 if (MaxBECount.getBitWidth() == BitWidth) {
6829 auto RangeFromAffine = getRangeForAffineAR(
6830 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6831 ConservativeResult =
6832 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6833
6834 auto RangeFromFactoring = getRangeViaFactoring(
6835 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6836 ConservativeResult =
6837 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6838 }
6839 }
6840
6841 // Now try symbolic BE count and more powerful methods.
6842 if (UseExpensiveRangeSharpening) {
6843 const SCEV *SymbolicMaxBECount =
6844 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6845 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6846 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6847 AddRec->hasNoSelfWrap()) {
6848 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6849 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6850 ConservativeResult =
6851 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6852 }
6853 }
6854 }
6855
6856 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6857 }
6858 case scUMaxExpr:
6859 case scSMaxExpr:
6860 case scUMinExpr:
6861 case scSMinExpr:
6862 case scSequentialUMinExpr: {
6863 Intrinsic::ID ID;
6864 switch (S->getSCEVType()) {
6865 case scUMaxExpr:
6866 ID = Intrinsic::umax;
6867 break;
6868 case scSMaxExpr:
6869 ID = Intrinsic::smax;
6870 break;
6871 case scUMinExpr:
6872 case scSequentialUMinExpr:
6873 ID = Intrinsic::umin;
6874 break;
6875 case scSMinExpr:
6876 ID = Intrinsic::smin;
6877 break;
6878 default:
6879 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6880 }
6881
6882 const auto *NAry = cast<SCEVNAryExpr>(S);
6883 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6884 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6885 X = X.intrinsic(
6886 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6887 return setRange(S, SignHint,
6888 ConservativeResult.intersectWith(X, RangeType));
6889 }
6890 case scUnknown: {
6891 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6892 Value *V = U->getValue();
6893
6894 // Check if the IR explicitly contains !range metadata.
6895 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6896 if (MDRange)
6897 ConservativeResult =
6898 ConservativeResult.intersectWith(*MDRange, RangeType);
6899
6900 // Use facts about recurrences in the underlying IR. Note that add
6901 // recurrences are AddRecExprs and thus don't hit this path. This
6902 // primarily handles shift recurrences.
6903 auto CR = getRangeForUnknownRecurrence(U);
6904 ConservativeResult = ConservativeResult.intersectWith(CR);
6905
6906 // See if ValueTracking can give us a useful range.
6907 const DataLayout &DL = getDataLayout();
6908 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6909 if (Known.getBitWidth() != BitWidth)
6910 Known = Known.zextOrTrunc(BitWidth);
6911
6912 // ValueTracking may be able to compute a tighter result for the number of
6913 // sign bits than for the value of those sign bits.
6914 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6915 if (U->getType()->isPointerTy()) {
6916 // If the pointer size is larger than the index size type, this can cause
6917 // NS to be larger than BitWidth. So compensate for this.
6918 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6919 int ptrIdxDiff = ptrSize - BitWidth;
6920 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6921 NS -= ptrIdxDiff;
6922 }
6923
6924 if (NS > 1) {
6925 // If we know any of the sign bits, we know all of the sign bits.
6926 if (!Known.Zero.getHiBits(NS).isZero())
6927 Known.Zero.setHighBits(NS);
6928 if (!Known.One.getHiBits(NS).isZero())
6929 Known.One.setHighBits(NS);
6930 }
6931
6932 if (Known.getMinValue() != Known.getMaxValue() + 1)
6933 ConservativeResult = ConservativeResult.intersectWith(
6934 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6935 RangeType);
6936 if (NS > 1)
6937 ConservativeResult = ConservativeResult.intersectWith(
6938 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6939 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6940 RangeType);
6941
6942 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6943 // Strengthen the range if the underlying IR value is a
6944 // global/alloca/heap allocation using the size of the object.
6945 bool CanBeNull, CanBeFreed;
6946 uint64_t DerefBytes =
6947 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6948 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6949 // The highest address the object can start is DerefBytes bytes before
6950 // the end (unsigned max value). If this value is not a multiple of the
6951 // alignment, the last possible start value is the next lowest multiple
6952 // of the alignment. Note: The computations below cannot overflow,
6953 // because if they would there's no possible start address for the
6954 // object.
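// For example (illustrative values): with BitWidth = 8, DerefBytes = 17 and a
// pointer alignment of 16, MaxVal is 255 - 17 = 238, rounded down to 224, so
// the range becomes [0, 225), or [16, 225) if the pointer is known non-null.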
6955 APInt MaxVal =
6956 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6957 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6958 uint64_t Rem = MaxVal.urem(Align);
6959 MaxVal -= APInt(BitWidth, Rem);
6960 APInt MinVal = APInt::getZero(BitWidth);
6961 if (llvm::isKnownNonZero(V, DL))
6962 MinVal = Align;
6963 ConservativeResult = ConservativeResult.intersectWith(
6964 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6965 }
6966 }
6967
6968 // The range of a Phi is a subset of the union of the ranges of its inputs.
6969 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6970 // Make sure that we do not run over cycled Phis.
6971 if (PendingPhiRanges.insert(Phi).second) {
6972 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6973
6974 for (const auto &Op : Phi->operands()) {
6975 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6976 RangeFromOps = RangeFromOps.unionWith(OpRange);
6977 // No point to continue if we already have a full set.
6978 if (RangeFromOps.isFullSet())
6979 break;
6980 }
6981 ConservativeResult =
6982 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6983 bool Erased = PendingPhiRanges.erase(Phi);
6984 assert(Erased && "Failed to erase Phi properly?");
6985 (void)Erased;
6986 }
6987 }
6988
6989 // vscale can't be equal to zero
6990 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6991 if (II->getIntrinsicID() == Intrinsic::vscale) {
6992 ConstantRange Disallowed = APInt::getZero(BitWidth);
6993 ConservativeResult = ConservativeResult.difference(Disallowed);
6994 }
6995
6996 return setRange(U, SignHint, std::move(ConservativeResult));
6997 }
6998 case scCouldNotCompute:
6999 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7000 }
7001
7002 return setRange(S, SignHint, std::move(ConservativeResult));
7003}
7004
7005// Given a StartRange, Step and MaxBECount for an expression compute a range of
7006// values that the expression can take. Initially, the expression has a value
7007// from StartRange and then is changed by Step up to MaxBECount times. Signed
7008// argument defines if we treat Step as signed or unsigned.
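// For example (illustrative values): with an i16 StartRange = [100, 110),
// Step = 4, MaxBECount = 5 and Signed = false, the value can change by at
// most Offset = 4 * 5 = 20, so the computed range is [100, 130).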
7009 static ConstantRange getRangeForAffineARHelper(APInt Step,
7010 const ConstantRange &StartRange,
7011 const APInt &MaxBECount,
7012 bool Signed) {
7013 unsigned BitWidth = Step.getBitWidth();
7014 assert(BitWidth == StartRange.getBitWidth() &&
7015 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7016 // If either Step or MaxBECount is 0, then the expression won't change, and we
7017 // just need to return the initial range.
7018 if (Step == 0 || MaxBECount == 0)
7019 return StartRange;
7020
7021 // If we don't know anything about the initial value (i.e. StartRange is
7022 // FullRange), then we don't know anything about the final range either.
7023 // Return FullRange.
7024 if (StartRange.isFullSet())
7025 return ConstantRange::getFull(BitWidth);
7026
7027 // If Step is signed and negative, then we use its absolute value, but we also
7028 // note that we're moving in the opposite direction.
7029 bool Descending = Signed && Step.isNegative();
7030
7031 if (Signed)
7032 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7033 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7034 // This equation holds true due to the well-defined wrap-around behavior of
7035 // APInt.
7036 Step = Step.abs();
7037
7038 // Check if Offset is more than the full span of BitWidth. If it is, the
7039 // expression is guaranteed to overflow.
7040 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7041 return ConstantRange::getFull(BitWidth);
7042
7043 // Offset is by how much the expression can change. Checks above guarantee no
7044 // overflow here.
7045 APInt Offset = Step * MaxBECount;
7046
7047 // Minimum value of the final range will match the minimal value of StartRange
7048 // if the expression is increasing and will be decreased by Offset otherwise.
7049 // Maximum value of the final range will match the maximal value of StartRange
7050 // if the expression is decreasing and will be increased by Offset otherwise.
7051 APInt StartLower = StartRange.getLower();
7052 APInt StartUpper = StartRange.getUpper() - 1;
7053 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7054 : (StartUpper + std::move(Offset));
7055
7056 // It's possible that the new minimum/maximum value will fall into the initial
7057 // range (due to wrap around). This means that the expression can take any
7058 // value in this bitwidth, and we have to return full range.
7059 if (StartRange.contains(MovedBoundary))
7060 return ConstantRange::getFull(BitWidth);
7061
7062 APInt NewLower =
7063 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7064 APInt NewUpper =
7065 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7066 NewUpper += 1;
7067
7068 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7069 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7070}
7071
7072ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7073 const SCEV *Step,
7074 const APInt &MaxBECount) {
7075 assert(getTypeSizeInBits(Start->getType()) ==
7076 getTypeSizeInBits(Step->getType()) &&
7077 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7078 "mismatched bit widths");
7079
7080 // First, consider step signed.
7081 ConstantRange StartSRange = getSignedRange(Start);
7082 ConstantRange StepSRange = getSignedRange(Step);
7083
7084 // If Step can be both positive and negative, we need to find ranges for the
7085 // maximum absolute step values in both directions and union them.
7086 ConstantRange SR = getRangeForAffineARHelper(
7087 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7088 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7089 StartSRange, MaxBECount,
7090 /* Signed = */ true));
7091
7092 // Next, consider step unsigned.
7093 ConstantRange UR = getRangeForAffineARHelper(
7094 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7095 /* Signed = */ false);
7096
7097 // Finally, intersect signed and unsigned ranges.
7099}
7100
7101ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7102 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7103 ScalarEvolution::RangeSignHint SignHint) {
7104 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7105 assert(AddRec->hasNoSelfWrap() &&
7106 "This only works for non-self-wrapping AddRecs!");
7107 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7108 const SCEV *Step = AddRec->getStepRecurrence(*this);
7109 // Only deal with constant step to save compile time.
7110 if (!isa<SCEVConstant>(Step))
7111 return ConstantRange::getFull(BitWidth);
7112 // Let's make sure that we can prove that we do not self-wrap during
7113 // MaxBECount iterations. We need this because MaxBECount is a maximum
7114 // iteration count estimate, and we might infer nw from some exit for which we
7115 // do not know max exit count (or any other side reasoning).
7116 // TODO: Turn into assert at some point.
7117 if (getTypeSizeInBits(MaxBECount->getType()) >
7118 getTypeSizeInBits(AddRec->getType()))
7119 return ConstantRange::getFull(BitWidth);
7120 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7121 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7122 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7123 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7124 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7125 MaxItersWithoutWrap))
7126 return ConstantRange::getFull(BitWidth);
7127
7128 ICmpInst::Predicate LEPred =
7129 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7130 ICmpInst::Predicate GEPred =
7131 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7132 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7133
7134 // We know that there is no self-wrap. Let's take Start and End values and
7135 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7136 // the iteration. They either lie inside the range [Min(Start, End),
7137 // Max(Start, End)] or outside it:
7138 //
7139 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7140 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7141 //
7142 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7143 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7144 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7145 // Start <= End and step is positive, or Start >= End and step is negative.
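// For example (illustrative values): for {0,+,2}<nw> with MaxBECount = 10,
// Start = 0 and End = 20; the step is positive and Start <= End, so every
// intermediate value lies in [0, 20] (Case 1).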
7146 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7147 ConstantRange StartRange = getRangeRef(Start, SignHint);
7148 ConstantRange EndRange = getRangeRef(End, SignHint);
7149 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7150 // If they already cover full iteration space, we will know nothing useful
7151 // even if we prove what we want to prove.
7152 if (RangeBetween.isFullSet())
7153 return RangeBetween;
7154 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7155 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7156 : RangeBetween.isWrappedSet();
7157 if (IsWrappedSet)
7158 return ConstantRange::getFull(BitWidth);
7159
7160 if (isKnownPositive(Step) &&
7161 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7162 return RangeBetween;
7163 if (isKnownNegative(Step) &&
7164 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7165 return RangeBetween;
7166 return ConstantRange::getFull(BitWidth);
7167}
7168
7169ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7170 const SCEV *Step,
7171 const APInt &MaxBECount) {
7172 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7173 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
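// For example (illustrative values): {C ? 0 : 10,+,C ? 1 : 2} is handled as
// the union of the ranges of {0,+,1} and {10,+,2}.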
7174
7175 unsigned BitWidth = MaxBECount.getBitWidth();
7176 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7177 getTypeSizeInBits(Step->getType()) == BitWidth &&
7178 "mismatched bit widths");
7179
7180 struct SelectPattern {
7181 Value *Condition = nullptr;
7182 APInt TrueValue;
7183 APInt FalseValue;
7184
7185 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7186 const SCEV *S) {
7187 std::optional<unsigned> CastOp;
7188 APInt Offset(BitWidth, 0);
7189
7191 "Should be!");
7192
7193 // Peel off a constant offset. In the future we could consider being
7194 // smarter here and handle {Start+Step,+,Step} too.
7195 const APInt *Off;
7196 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7197 Offset = *Off;
7198
7199 // Peel off a cast operation
7200 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7201 CastOp = SCast->getSCEVType();
7202 S = SCast->getOperand();
7203 }
7204
7205 using namespace llvm::PatternMatch;
7206
7207 auto *SU = dyn_cast<SCEVUnknown>(S);
7208 const APInt *TrueVal, *FalseVal;
7209 if (!SU ||
7210 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7211 m_APInt(FalseVal)))) {
7212 Condition = nullptr;
7213 return;
7214 }
7215
7216 TrueValue = *TrueVal;
7217 FalseValue = *FalseVal;
7218
7219 // Re-apply the cast we peeled off earlier
7220 if (CastOp)
7221 switch (*CastOp) {
7222 default:
7223 llvm_unreachable("Unknown SCEV cast type!");
7224
7225 case scTruncate:
7226 TrueValue = TrueValue.trunc(BitWidth);
7227 FalseValue = FalseValue.trunc(BitWidth);
7228 break;
7229 case scZeroExtend:
7230 TrueValue = TrueValue.zext(BitWidth);
7231 FalseValue = FalseValue.zext(BitWidth);
7232 break;
7233 case scSignExtend:
7234 TrueValue = TrueValue.sext(BitWidth);
7235 FalseValue = FalseValue.sext(BitWidth);
7236 break;
7237 }
7238
7239 // Re-apply the constant offset we peeled off earlier
7240 TrueValue += Offset;
7241 FalseValue += Offset;
7242 }
7243
7244 bool isRecognized() { return Condition != nullptr; }
7245 };
7246
7247 SelectPattern StartPattern(*this, BitWidth, Start);
7248 if (!StartPattern.isRecognized())
7249 return ConstantRange::getFull(BitWidth);
7250
7251 SelectPattern StepPattern(*this, BitWidth, Step);
7252 if (!StepPattern.isRecognized())
7253 return ConstantRange::getFull(BitWidth);
7254
7255 if (StartPattern.Condition != StepPattern.Condition) {
7256 // We don't handle this case today; but we could, by considering four
7257 // possibilities below instead of two. I'm not sure if there are cases where
7258 // that will help over what getRange already does, though.
7259 return ConstantRange::getFull(BitWidth);
7260 }
7261
7262 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7263 // construct arbitrary general SCEV expressions here. This function is called
7264 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7265 // say) can end up caching a suboptimal value.
7266
7267 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7268 // C2352 and C2512 (otherwise it isn't needed).
7269
7270 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7271 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7272 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7273 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7274
7275 ConstantRange TrueRange =
7276 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7277 ConstantRange FalseRange =
7278 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7279
7280 return TrueRange.unionWith(FalseRange);
7281}
7282
7283SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7284 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7285 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7286
7287 // Return early if there are no flags to propagate to the SCEV.
7288 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7289 if (BinOp->hasNoUnsignedWrap())
7290 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7291 if (BinOp->hasNoSignedWrap())
7292 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7293 if (Flags == SCEV::FlagAnyWrap)
7294 return SCEV::FlagAnyWrap;
7295
7296 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7297}
7298
7299const Instruction *
7300ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7301 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7302 return &*AddRec->getLoop()->getHeader()->begin();
7303 if (auto *U = dyn_cast<SCEVUnknown>(S))
7304 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7305 return I;
7306 return nullptr;
7307}
7308
7309const Instruction *
7310ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7311 bool &Precise) {
7312 Precise = true;
7313 // Do a bounded search of the def relation of the requested SCEVs.
7314 SmallPtrSet<const SCEV *, 16> Visited;
7315 SmallVector<const SCEV *> Worklist;
7316 auto pushOp = [&](const SCEV *S) {
7317 if (!Visited.insert(S).second)
7318 return;
7319 // Threshold of 30 here is arbitrary.
7320 if (Visited.size() > 30) {
7321 Precise = false;
7322 return;
7323 }
7324 Worklist.push_back(S);
7325 };
7326
7327 for (const auto *S : Ops)
7328 pushOp(S);
7329
7330 const Instruction *Bound = nullptr;
7331 while (!Worklist.empty()) {
7332 auto *S = Worklist.pop_back_val();
7333 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7334 if (!Bound || DT.dominates(Bound, DefI))
7335 Bound = DefI;
7336 } else {
7337 for (const auto *Op : S->operands())
7338 pushOp(Op);
7339 }
7340 }
7341 return Bound ? Bound : &*F.getEntryBlock().begin();
7342}
7343
7344const Instruction *
7345ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7346 bool Discard;
7347 return getDefiningScopeBound(Ops, Discard);
7348}
7349
7350bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7351 const Instruction *B) {
7352 if (A->getParent() == B->getParent() &&
7353 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7354 B->getIterator()))
7355 return true;
7356
7357 auto *BLoop = LI.getLoopFor(B->getParent());
7358 if (BLoop && BLoop->getHeader() == B->getParent() &&
7359 BLoop->getLoopPreheader() == A->getParent() &&
7360 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7361 A->getParent()->end()) &&
7362 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7363 B->getIterator()))
7364 return true;
7365 return false;
7366}
7367
7368bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7369 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7370 visitAll(Op, PC);
7371 return PC.MaybePoison.empty();
7372}
7373
7374bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7375 return !SCEVExprContains(Op, [this](const SCEV *S) {
7376 const SCEV *Op1;
7377 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7378 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7379 // is a non-zero constant, we have to assume the UDiv may be UB.
7380 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7381 });
7382}
7383
7384bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7385 // Only proceed if we can prove that I does not yield poison.
7386 if (!programUndefinedIfPoison(I))
7387 return false;
7388
7389 // At this point we know that if I is executed, then it does not wrap
7390 // according to at least one of NSW or NUW. If I is not executed, then we do
7391 // not know if the calculation that I represents would wrap. Multiple
7392 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7393 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7394 // derived from other instructions that map to the same SCEV. We cannot make
7395 // that guarantee for cases where I is not executed. So we need to find a
7396 // upper bound on the defining scope for the SCEV, and prove that I is
7397 // executed every time we enter that scope. When the bounding scope is a
7398 // loop (the common case), this is equivalent to proving I executes on every
7399 // iteration of that loop.
7400 SmallVector<const SCEV *> SCEVOps;
7401 for (const Use &Op : I->operands()) {
7402 // I could be an extractvalue from a call to an overflow intrinsic.
7403 // TODO: We can do better here in some cases.
7404 if (isSCEVable(Op->getType()))
7405 SCEVOps.push_back(getSCEV(Op));
7406 }
7407 auto *DefI = getDefiningScopeBound(SCEVOps);
7408 return isGuaranteedToTransferExecutionTo(DefI, I);
7409}
7410
7411bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7412 // If we know that \c I can never be poison period, then that's enough.
7413 if (isSCEVExprNeverPoison(I))
7414 return true;
7415
7416 // If the loop only has one exit, then we know that, if the loop is entered,
7417 // any instruction dominating that exit will be executed. If any such
7418 // instruction would result in UB, the addrec cannot be poison.
7419 //
7420 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7421 // also handles uses outside the loop header (they just need to dominate the
7422 // single exit).
7423
7424 auto *ExitingBB = L->getExitingBlock();
7425 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7426 return false;
7427
7428 SmallPtrSet<const Value *, 16> KnownPoison;
7429 SmallVector<const Instruction *, 8> Worklist;
7430
7431 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7432 // things that are known to be poison under that assumption go on the
7433 // Worklist.
7434 KnownPoison.insert(I);
7435 Worklist.push_back(I);
7436
7437 while (!Worklist.empty()) {
7438 const Instruction *Poison = Worklist.pop_back_val();
7439
7440 for (const Use &U : Poison->uses()) {
7441 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7442 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7443 DT.dominates(PoisonUser->getParent(), ExitingBB))
7444 return true;
7445
7446 if (propagatesPoison(U) && L->contains(PoisonUser))
7447 if (KnownPoison.insert(PoisonUser).second)
7448 Worklist.push_back(PoisonUser);
7449 }
7450 }
7451
7452 return false;
7453}
7454
7455ScalarEvolution::LoopProperties
7456ScalarEvolution::getLoopProperties(const Loop *L) {
7457 using LoopProperties = ScalarEvolution::LoopProperties;
7458
7459 auto Itr = LoopPropertiesCache.find(L);
7460 if (Itr == LoopPropertiesCache.end()) {
7461 auto HasSideEffects = [](Instruction *I) {
7462 if (auto *SI = dyn_cast<StoreInst>(I))
7463 return !SI->isSimple();
7464
7465 if (I->mayThrow())
7466 return true;
7467
7468 // Non-volatile memset / memcpy do not count as side-effect for forward
7469 // progress.
7470 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7471 return false;
7472
7473 return I->mayWriteToMemory();
7474 };
7475
7476 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7477 /*HasNoSideEffects*/ true};
7478
7479 for (auto *BB : L->getBlocks())
7480 for (auto &I : *BB) {
7481 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7482 LP.HasNoAbnormalExits = false;
7483 if (HasSideEffects(&I))
7484 LP.HasNoSideEffects = false;
7485 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7486 break; // We're already as pessimistic as we can get.
7487 }
7488
7489 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7490 assert(InsertPair.second && "We just checked!");
7491 Itr = InsertPair.first;
7492 }
7493
7494 return Itr->second;
7495}
7496
7497 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7498 // A mustprogress loop without side effects must be finite.
7499 // TODO: The check used here is very conservative. It's only *specific*
7500 // side effects which are well defined in infinite loops.
7501 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7502}
7503
7504const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7505 // Worklist item with a Value and a bool indicating whether all operands have
7506 // been visited already.
7507 using PointerTy = PointerIntPair<Value *, 1, bool>;
7508 SmallVector<PointerTy> Stack;
7509
7510 Stack.emplace_back(V, true);
7511 Stack.emplace_back(V, false);
7512 while (!Stack.empty()) {
7513 auto E = Stack.pop_back_val();
7514 Value *CurV = E.getPointer();
7515
7516 if (getExistingSCEV(CurV))
7517 continue;
7518
7519 SmallVector<Value *> Ops;
7520 const SCEV *CreatedSCEV = nullptr;
7521 // If all operands have been visited already, create the SCEV.
7522 if (E.getInt()) {
7523 CreatedSCEV = createSCEV(CurV);
7524 } else {
7525 // Otherwise get the operands we need to create SCEVs for before creating
7526 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7527 // just use it.
7528 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7529 }
7530
7531 if (CreatedSCEV) {
7532 insertValueToMap(CurV, CreatedSCEV);
7533 } else {
7534 // Queue CurV for SCEV creation, followed by its operands which need to
7535 // be constructed first.
7536 Stack.emplace_back(CurV, true);
7537 for (Value *Op : Ops)
7538 Stack.emplace_back(Op, false);
7539 }
7540 }
7541
7542 return getExistingSCEV(V);
7543}
7544
7545const SCEV *
7546ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7547 if (!isSCEVable(V->getType()))
7548 return getUnknown(V);
7549
7550 if (Instruction *I = dyn_cast<Instruction>(V)) {
7551 // Don't attempt to analyze instructions in blocks that aren't
7552 // reachable. Such instructions don't matter, and they aren't required
7553 // to obey basic rules for definitions dominating uses which this
7554 // analysis depends on.
7555 if (!DT.isReachableFromEntry(I->getParent()))
7556 return getUnknown(PoisonValue::get(V->getType()));
7557 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7558 return getConstant(CI);
7559 else if (isa<GlobalAlias>(V))
7560 return getUnknown(V);
7561 else if (!isa<ConstantExpr>(V))
7562 return getUnknown(V);
7563
7564 Operator *U = cast<Operator>(V);
7565 if (auto BO =
7566 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7567 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7568 switch (BO->Opcode) {
7569 case Instruction::Add:
7570 case Instruction::Mul: {
7571 // For additions and multiplications, traverse add/mul chains for which we
7572 // can potentially create a single SCEV, to reduce the number of
7573 // get{Add,Mul}Expr calls.
7574 do {
7575 if (BO->Op) {
7576 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7577 Ops.push_back(BO->Op);
7578 break;
7579 }
7580 }
7581 Ops.push_back(BO->RHS);
7582 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7583 dyn_cast<Instruction>(V));
7584 if (!NewBO ||
7585 (BO->Opcode == Instruction::Add &&
7586 (NewBO->Opcode != Instruction::Add &&
7587 NewBO->Opcode != Instruction::Sub)) ||
7588 (BO->Opcode == Instruction::Mul &&
7589 NewBO->Opcode != Instruction::Mul)) {
7590 Ops.push_back(BO->LHS);
7591 break;
7592 }
7593 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7594 // requires a SCEV for the LHS.
7595 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7596 auto *I = dyn_cast<Instruction>(BO->Op);
7597 if (I && programUndefinedIfPoison(I)) {
7598 Ops.push_back(BO->LHS);
7599 break;
7600 }
7601 }
7602 BO = NewBO;
7603 } while (true);
7604 return nullptr;
7605 }
7606 case Instruction::Sub:
7607 case Instruction::UDiv:
7608 case Instruction::URem:
7609 break;
7610 case Instruction::AShr:
7611 case Instruction::Shl:
7612 case Instruction::Xor:
7613 if (!IsConstArg)
7614 return nullptr;
7615 break;
7616 case Instruction::And:
7617 case Instruction::Or:
7618 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7619 return nullptr;
7620 break;
7621 case Instruction::LShr:
7622 return getUnknown(V);
7623 default:
7624 llvm_unreachable("Unhandled binop");
7625 break;
7626 }
7627
7628 Ops.push_back(BO->LHS);
7629 Ops.push_back(BO->RHS);
7630 return nullptr;
7631 }
7632
7633 switch (U->getOpcode()) {
7634 case Instruction::Trunc:
7635 case Instruction::ZExt:
7636 case Instruction::SExt:
7637 case Instruction::PtrToInt:
7638 Ops.push_back(U->getOperand(0));
7639 return nullptr;
7640
7641 case Instruction::BitCast:
7642 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7643 Ops.push_back(U->getOperand(0));
7644 return nullptr;
7645 }
7646 return getUnknown(V);
7647
7648 case Instruction::SDiv:
7649 case Instruction::SRem:
7650 Ops.push_back(U->getOperand(0));
7651 Ops.push_back(U->getOperand(1));
7652 return nullptr;
7653
7654 case Instruction::GetElementPtr:
7655 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7656 "GEP source element type must be sized");
7657 llvm::append_range(Ops, U->operands());
7658 return nullptr;
7659
7660 case Instruction::IntToPtr:
7661 return getUnknown(V);
7662
7663 case Instruction::PHI:
7664 // Keep constructing SCEVs for phis recursively for now.
7665 return nullptr;
7666
7667 case Instruction::Select: {
7668 // Check if U is a select that can be simplified to a SCEVUnknown.
7669 auto CanSimplifyToUnknown = [this, U]() {
7670 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7671 return false;
7672
7673 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7674 if (!ICI)
7675 return false;
7676 Value *LHS = ICI->getOperand(0);
7677 Value *RHS = ICI->getOperand(1);
7678 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7679 ICI->getPredicate() == CmpInst::ICMP_NE) {
7680 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7681 return true;
7682 } else if (getTypeSizeInBits(LHS->getType()) >
7683 getTypeSizeInBits(U->getType()))
7684 return true;
7685 return false;
7686 };
7687 if (CanSimplifyToUnknown())
7688 return getUnknown(U);
7689
7690 llvm::append_range(Ops, U->operands());
7691 return nullptr;
7692 break;
7693 }
7694 case Instruction::Call:
7695 case Instruction::Invoke:
7696 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7697 Ops.push_back(RV);
7698 return nullptr;
7699 }
7700
7701 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7702 switch (II->getIntrinsicID()) {
7703 case Intrinsic::abs:
7704 Ops.push_back(II->getArgOperand(0));
7705 return nullptr;
7706 case Intrinsic::umax:
7707 case Intrinsic::umin:
7708 case Intrinsic::smax:
7709 case Intrinsic::smin:
7710 case Intrinsic::usub_sat:
7711 case Intrinsic::uadd_sat:
7712 Ops.push_back(II->getArgOperand(0));
7713 Ops.push_back(II->getArgOperand(1));
7714 return nullptr;
7715 case Intrinsic::start_loop_iterations:
7716 case Intrinsic::annotation:
7717 case Intrinsic::ptr_annotation:
7718 Ops.push_back(II->getArgOperand(0));
7719 return nullptr;
7720 default:
7721 break;
7722 }
7723 }
7724 break;
7725 }
7726
7727 return nullptr;
7728}
7729
7730const SCEV *ScalarEvolution::createSCEV(Value *V) {
7731 if (!isSCEVable(V->getType()))
7732 return getUnknown(V);
7733
7734 if (Instruction *I = dyn_cast<Instruction>(V)) {
7735 // Don't attempt to analyze instructions in blocks that aren't
7736 // reachable. Such instructions don't matter, and they aren't required
7737 // to obey basic rules for definitions dominating uses which this
7738 // analysis depends on.
7739 if (!DT.isReachableFromEntry(I->getParent()))
7740 return getUnknown(PoisonValue::get(V->getType()));
7741 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7742 return getConstant(CI);
7743 else if (isa<GlobalAlias>(V))
7744 return getUnknown(V);
7745 else if (!isa<ConstantExpr>(V))
7746 return getUnknown(V);
7747
7748 const SCEV *LHS;
7749 const SCEV *RHS;
7750
7751 Operator *U = cast<Operator>(V);
7752 if (auto BO =
7753 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7754 switch (BO->Opcode) {
7755 case Instruction::Add: {
7756 // The simple thing to do would be to just call getSCEV on both operands
7757 // and call getAddExpr with the result. However if we're looking at a
7758 // bunch of things all added together, this can be quite inefficient,
7759 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7760 // Instead, gather up all the operands and make a single getAddExpr call.
7761 // LLVM IR canonical form means we need only traverse the left operands.
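// For example (illustrative names), ((a + b) + c) + d is gathered into a
// single getAddExpr(a, b, c, d) call instead of three nested getAddExpr calls.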
7762 SmallVector<const SCEV *, 4> AddOps;
7763 do {
7764 if (BO->Op) {
7765 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7766 AddOps.push_back(OpSCEV);
7767 break;
7768 }
7769
7770 // If a NUW or NSW flag can be applied to the SCEV for this
7771 // addition, then compute the SCEV for this addition by itself
7772 // with a separate call to getAddExpr. We need to do that
7773 // instead of pushing the operands of the addition onto AddOps,
7774 // since the flags are only known to apply to this particular
7775 // addition - they may not apply to other additions that can be
7776 // formed with operands from AddOps.
7777 const SCEV *RHS = getSCEV(BO->RHS);
7778 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7779 if (Flags != SCEV::FlagAnyWrap) {
7780 const SCEV *LHS = getSCEV(BO->LHS);
7781 if (BO->Opcode == Instruction::Sub)
7782 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7783 else
7784 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7785 break;
7786 }
7787 }
7788
7789 if (BO->Opcode == Instruction::Sub)
7790 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7791 else
7792 AddOps.push_back(getSCEV(BO->RHS));
7793
7794 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7795 dyn_cast<Instruction>(V));
7796 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7797 NewBO->Opcode != Instruction::Sub)) {
7798 AddOps.push_back(getSCEV(BO->LHS));
7799 break;
7800 }
7801 BO = NewBO;
7802 } while (true);
7803
7804 return getAddExpr(AddOps);
7805 }
7806
7807 case Instruction::Mul: {
7808 SmallVector<const SCEV *, 4> MulOps;
7809 do {
7810 if (BO->Op) {
7811 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7812 MulOps.push_back(OpSCEV);
7813 break;
7814 }
7815
7816 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7817 if (Flags != SCEV::FlagAnyWrap) {
7818 LHS = getSCEV(BO->LHS);
7819 RHS = getSCEV(BO->RHS);
7820 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7821 break;
7822 }
7823 }
7824
7825 MulOps.push_back(getSCEV(BO->RHS));
7826 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7827 dyn_cast<Instruction>(V));
7828 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7829 MulOps.push_back(getSCEV(BO->LHS));
7830 break;
7831 }
7832 BO = NewBO;
7833 } while (true);
7834
7835 return getMulExpr(MulOps);
7836 }
7837 case Instruction::UDiv:
7838 LHS = getSCEV(BO->LHS);
7839 RHS = getSCEV(BO->RHS);
7840 return getUDivExpr(LHS, RHS);
7841 case Instruction::URem:
7842 LHS = getSCEV(BO->LHS);
7843 RHS = getSCEV(BO->RHS);
7844 return getURemExpr(LHS, RHS);
7845 case Instruction::Sub: {
7846 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7847 if (BO->Op)
7848 Flags = getNoWrapFlagsFromUB(BO->Op);
7849 LHS = getSCEV(BO->LHS);
7850 RHS = getSCEV(BO->RHS);
7851 return getMinusSCEV(LHS, RHS, Flags);
7852 }
7853 case Instruction::And:
7854 // For an expression like x&255 that merely masks off the high bits,
7855 // use zext(trunc(x)) as the SCEV expression.
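// For example (illustrative types): "(i32 %x) & 255" becomes
// (zext i8 (trunc i32 %x to i8) to i32).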
7856 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7857 if (CI->isZero())
7858 return getSCEV(BO->RHS);
7859 if (CI->isMinusOne())
7860 return getSCEV(BO->LHS);
7861 const APInt &A = CI->getValue();
7862
7863 // Instcombine's ShrinkDemandedConstant may strip bits out of
7864 // constants, obscuring what would otherwise be a low-bits mask.
7865 // Use computeKnownBits to compute what ShrinkDemandedConstant
7866 // knew about to reconstruct a low-bits mask value.
7867 unsigned LZ = A.countl_zero();
7868 unsigned TZ = A.countr_zero();
7869 unsigned BitWidth = A.getBitWidth();
7870 KnownBits Known(BitWidth);
7871 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7872
7873 APInt EffectiveMask =
7874 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7875 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7876 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7877 const SCEV *LHS = getSCEV(BO->LHS);
7878 const SCEV *ShiftedLHS = nullptr;
7879 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7880 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7881 // For an expression like (x * 8) & 8, simplify the multiply.
7882 unsigned MulZeros = OpC->getAPInt().countr_zero();
7883 unsigned GCD = std::min(MulZeros, TZ);
7884 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7885 SmallVector<const SCEV *, 4> MulOps;
7886 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7887 append_range(MulOps, LHSMul->operands().drop_front());
7888 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7889 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7890 }
7891 }
7892 if (!ShiftedLHS)
7893 ShiftedLHS = getUDivExpr(LHS, MulCount);
7894 return getMulExpr(
7895 getZeroExtendExpr(
7896 getTruncateExpr(ShiftedLHS,
7897 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7898 BO->LHS->getType()),
7899 MulCount);
7900 }
7901 }
7902 // Binary `and` is a bit-wise `umin`.
7903 if (BO->LHS->getType()->isIntegerTy(1)) {
7904 LHS = getSCEV(BO->LHS);
7905 RHS = getSCEV(BO->RHS);
7906 return getUMinExpr(LHS, RHS);
7907 }
7908 break;
7909
7910 case Instruction::Or:
7911 // Binary `or` is a bit-wise `umax`.
7912 if (BO->LHS->getType()->isIntegerTy(1)) {
7913 LHS = getSCEV(BO->LHS);
7914 RHS = getSCEV(BO->RHS);
7915 return getUMaxExpr(LHS, RHS);
7916 }
7917 break;
7918
7919 case Instruction::Xor:
7920 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7921 // If the RHS of xor is -1, then this is a not operation.
7922 if (CI->isMinusOne())
7923 return getNotSCEV(getSCEV(BO->LHS));
7924
7925 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7926 // This is a variant of the check for xor with -1, and it handles
7927 // the case where instcombine has trimmed non-demanded bits out
7928 // of an xor with -1.
7929 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7930 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7931 if (LBO->getOpcode() == Instruction::And &&
7932 LCI->getValue() == CI->getValue())
7933 if (const SCEVZeroExtendExpr *Z =
7934 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7935 Type *UTy = BO->LHS->getType();
7936 const SCEV *Z0 = Z->getOperand();
7937 Type *Z0Ty = Z0->getType();
7938 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7939
7940 // If C is a low-bits mask, the zero extend is serving to
7941 // mask off the high bits. Complement the operand and
7942 // re-apply the zext.
7943 if (CI->getValue().isMask(Z0TySize))
7944 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7945
7946 // If C is a single bit, it may be in the sign-bit position
7947 // before the zero-extend. In this case, represent the xor
7948 // using an add, which is equivalent, and re-apply the zext.
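// For example (illustrative values): if Z0 has type i8 and C is 0x0080 in the
// zero-extended i16, xor'ing with C flips what was the i8 sign bit, which is
// the same as adding 0x80 in i8 and then re-applying the zero extension.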
7949 APInt Trunc = CI->getValue().trunc(Z0TySize);
7950 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7951 Trunc.isSignMask())
7952 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7953 UTy);
7954 }
7955 }
7956 break;
7957
7958 case Instruction::Shl:
7959 // Turn shift left of a constant amount into a multiply.
7960 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7961 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7962
7963 // If the shift count is not less than the bitwidth, the result of
7964 // the shift is undefined. Don't try to analyze it, because the
7965 // resolution chosen here may differ from the resolution chosen in
7966 // other parts of the compiler.
7967 if (SA->getValue().uge(BitWidth))
7968 break;
7969
7970 // We can safely preserve the nuw flag in all cases. It's also safe to
7971 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7972 // requires special handling. It can be preserved as long as we're not
7973 // left shifting by bitwidth - 1.
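// For example, on i8 `shl nsw i8 -1, 7` yields -128 with no signed overflow,
// but the corresponding multiply constant 1 << 7 is -128 in i8 and
// -1 * -128 = +128 overflows; hence nsw is only carried over when the shift
// amount is at most bitwidth - 2 (or nuw also holds).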
7974 auto Flags = SCEV::FlagAnyWrap;
7975 if (BO->Op) {
7976 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7977 if ((MulFlags & SCEV::FlagNSW) &&
7978 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7979 Flags = setFlags(Flags, SCEV::FlagNSW);
7980 if (MulFlags & SCEV::FlagNUW)
7981 Flags = setFlags(Flags, SCEV::FlagNUW);
7982 }
7983
7984 ConstantInt *X = ConstantInt::get(
7985 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7986 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7987 }
7988 break;
7989
7990 case Instruction::AShr:
7991 // AShr X, C, where C is a constant.
7992 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7993 if (!CI)
7994 break;
7995
7996 Type *OuterTy = BO->LHS->getType();
7997 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
7998 // If the shift count is not less than the bitwidth, the result of
7999 // the shift is undefined. Don't try to analyze it, because the
8000 // resolution chosen here may differ from the resolution chosen in
8001 // other parts of the compiler.
8002 if (CI->getValue().uge(BitWidth))
8003 break;
8004
8005 if (CI->isZero())
8006 return getSCEV(BO->LHS); // shift by zero --> noop
8007
8008 uint64_t AShrAmt = CI->getZExtValue();
8009 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8010
8011 Operator *L = dyn_cast<Operator>(BO->LHS);
8012 const SCEV *AddTruncateExpr = nullptr;
8013 ConstantInt *ShlAmtCI = nullptr;
8014 const SCEV *AddConstant = nullptr;
8015
8016 if (L && L->getOpcode() == Instruction::Add) {
8017 // X = Shl A, n
8018 // Y = Add X, c
8019 // Z = AShr Y, m
8020 // n, c and m are constants.
8021
8022 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8023 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8024 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8025 if (AddOperandCI) {
8026 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8027 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8028 // Since we truncate to TruncTy, the AddConstant should be of the
8029 // same type, so create a new constant with a type matching TruncTy.
8030 // Also, the add constant should be shifted right by the AShr amount.
8031 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8032 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8033 // We model the expression as sext(add(trunc(A) * 2^(n - m), c >> m)). Since
8034 // the sext(trunc) part is already handled below, we create an
8035 // AddExpr(TruncExp) which will be used later.
8036 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8037 }
8038 }
8039 } else if (L && L->getOpcode() == Instruction::Shl) {
8040 // X = Shl A, n
8041 // Y = AShr X, m
8042 // Both n and m are constant.
8043
8044 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8045 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8046 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8047 }
8048
8049 if (AddTruncateExpr && ShlAmtCI) {
8050 // We can merge the two given cases into a single SCEV statement;
8051 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8052 // a simpler case. The following code handles the two cases:
8053 //
8054 // 1) For a two-shift sext-inreg, i.e. n = m,
8055 // use sext(trunc(x)) as the SCEV expression.
8056 //
8057 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
8058 // expression. We already checked that ShlAmt < BitWidth, so
8059 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8060 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
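// For example, %x = shl i64 %a, 5; %y = add i64 %x, 96; %z = ashr i64 %y, 3
// gives n = 5, m = 3, TruncTy = i61, AddConstant = 96 >> 3 = 12 and
// Mul = 1 << (5 - 3) = 4, so %z is modeled as
// (sext i61 (4 * (trunc i64 %a to i61) + 12) to i64).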
8061 const APInt &ShlAmt = ShlAmtCI->getValue();
8062 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8063 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8064 ShlAmtCI->getZExtValue() - AShrAmt);
8065 const SCEV *CompositeExpr =
8066 getMulExpr(AddTruncateExpr, getConstant(Mul));
8067 if (L->getOpcode() != Instruction::Shl)
8068 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8069
8070 return getSignExtendExpr(CompositeExpr, OuterTy);
8071 }
8072 }
8073 break;
8074 }
8075 }
8076
8077 switch (U->getOpcode()) {
8078 case Instruction::Trunc:
8079 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8080
8081 case Instruction::ZExt:
8082 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8083
8084 case Instruction::SExt:
8085 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8086 dyn_cast<Instruction>(V))) {
8087 // The NSW flag of a subtract does not always survive the conversion to
8088 // A + (-1)*B. By pushing sign extension onto its operands we are much
8089 // more likely to preserve NSW and allow later AddRec optimisations.
8090 //
8091 // NOTE: This is effectively duplicating this logic from getSignExtend:
8092 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8093 // but by that point the NSW information has potentially been lost.
8094 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8095 Type *Ty = U->getType();
8096 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8097 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8098 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8099 }
8100 }
8101 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8102
8103 case Instruction::BitCast:
8104 // BitCasts are no-op casts so we just eliminate the cast.
8105 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8106 return getSCEV(U->getOperand(0));
8107 break;
8108
8109 case Instruction::PtrToInt: {
8110 // Pointer to integer cast is straight-forward, so do model it.
8111 const SCEV *Op = getSCEV(U->getOperand(0));
8112 Type *DstIntTy = U->getType();
8113 // But only if effective SCEV (integer) type is wide enough to represent
8114 // all possible pointer values.
8115 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8116 if (isa<SCEVCouldNotCompute>(IntOp))
8117 return getUnknown(V);
8118 return IntOp;
8119 }
8120 case Instruction::IntToPtr:
8121 // Just don't deal with inttoptr casts.
8122 return getUnknown(V);
8123
8124 case Instruction::SDiv:
8125 // If both operands are non-negative, this is just an udiv.
8126 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8127 isKnownNonNegative(getSCEV(U->getOperand(1))))
8128 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8129 break;
8130
8131 case Instruction::SRem:
8132 // If both operands are non-negative, this is just an urem.
8133 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8134 isKnownNonNegative(getSCEV(U->getOperand(1))))
8135 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8136 break;
8137
8138 case Instruction::GetElementPtr:
8139 return createNodeForGEP(cast<GEPOperator>(U));
8140
8141 case Instruction::PHI:
8142 return createNodeForPHI(cast<PHINode>(U));
8143
8144 case Instruction::Select:
8145 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8146 U->getOperand(2));
8147
8148 case Instruction::Call:
8149 case Instruction::Invoke:
8150 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8151 return getSCEV(RV);
8152
8153 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8154 switch (II->getIntrinsicID()) {
8155 case Intrinsic::abs:
8156 return getAbsExpr(
8157 getSCEV(II->getArgOperand(0)),
8158 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8159 case Intrinsic::umax:
8160 LHS = getSCEV(II->getArgOperand(0));
8161 RHS = getSCEV(II->getArgOperand(1));
8162 return getUMaxExpr(LHS, RHS);
8163 case Intrinsic::umin:
8164 LHS = getSCEV(II->getArgOperand(0));
8165 RHS = getSCEV(II->getArgOperand(1));
8166 return getUMinExpr(LHS, RHS);
8167 case Intrinsic::smax:
8168 LHS = getSCEV(II->getArgOperand(0));
8169 RHS = getSCEV(II->getArgOperand(1));
8170 return getSMaxExpr(LHS, RHS);
8171 case Intrinsic::smin:
8172 LHS = getSCEV(II->getArgOperand(0));
8173 RHS = getSCEV(II->getArgOperand(1));
8174 return getSMinExpr(LHS, RHS);
8175 case Intrinsic::usub_sat: {
8176 const SCEV *X = getSCEV(II->getArgOperand(0));
8177 const SCEV *Y = getSCEV(II->getArgOperand(1));
8178 const SCEV *ClampedY = getUMinExpr(X, Y);
8179 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8180 }
8181 case Intrinsic::uadd_sat: {
8182 const SCEV *X = getSCEV(II->getArgOperand(0));
8183 const SCEV *Y = getSCEV(II->getArgOperand(1));
8184 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8185 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8186 }
8187 case Intrinsic::start_loop_iterations:
8188 case Intrinsic::annotation:
8189 case Intrinsic::ptr_annotation:
8190 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8191 // just equivalent to the first operand for SCEV purposes.
8192 return getSCEV(II->getArgOperand(0));
8193 case Intrinsic::vscale:
8194 return getVScale(II->getType());
8195 default:
8196 break;
8197 }
8198 }
8199 break;
8200 }
8201
8202 return getUnknown(V);
8203}
8204
8205//===----------------------------------------------------------------------===//
8206// Iteration Count Computation Code
8207//
8208
8209 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8210 if (isa<SCEVCouldNotCompute>(ExitCount))
8211 return getCouldNotCompute();
8212
8213 auto *ExitCountType = ExitCount->getType();
8214 assert(ExitCountType->isIntegerTy());
8215 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8216 1 + ExitCountType->getScalarSizeInBits());
8217 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8218}
8219
8220 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8221 Type *EvalTy,
8222 const Loop *L) {
8223 if (isa<SCEVCouldNotCompute>(ExitCount))
8224 return getCouldNotCompute();
8225
8226 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8227 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8228
8229 auto CanAddOneWithoutOverflow = [&]() {
8230 ConstantRange ExitCountRange =
8231 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8232 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8233 return true;
8234
8235 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8236 getMinusOne(ExitCount->getType()));
8237 };
8238
8239 // If we need to zero extend the backedge count, check if we can add one to
8240 // it prior to zero extending without overflow. Provided this is safe, it
8241 // allows better simplification of the +1.
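// For example, with an i8 exit count and an i9 EvalTy: if the count could be
// 255, adding one in i8 first would wrap to 0, so the fallback below extends
// to i9 first and only then adds one, giving the correct trip count of 256.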
8242 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8243 return getZeroExtendExpr(
8244 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8245
8246 // Get the total trip count from the count by adding 1. This may wrap.
8247 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8248}
8249
8250static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8251 if (!ExitCount)
8252 return 0;
8253
8254 ConstantInt *ExitConst = ExitCount->getValue();
8255
8256 // Guard against huge trip counts.
8257 if (ExitConst->getValue().getActiveBits() > 32)
8258 return 0;
8259
8260 // In case of integer overflow, this returns 0, which is correct.
8261 return ((unsigned)ExitConst->getZExtValue()) + 1;
8262}
8263
8264 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8265 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8266 return getConstantTripCount(ExitCount);
8267}
8268
8269unsigned
8270 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8271 const BasicBlock *ExitingBlock) {
8272 assert(ExitingBlock && "Must pass a non-null exiting block!");
8273 assert(L->isLoopExiting(ExitingBlock) &&
8274 "Exiting block must actually branch out of the loop!");
8275 const SCEVConstant *ExitCount =
8276 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8277 return getConstantTripCount(ExitCount);
8278}
8279
8280 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8281 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8282
8283 const auto *MaxExitCount =
8284 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8285 : getConstantMaxBackedgeTakenCount(L);
8286 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8287}
8288
8289 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8290 SmallVector<BasicBlock *, 8> ExitingBlocks;
8291 L->getExitingBlocks(ExitingBlocks);
8292
8293 std::optional<unsigned> Res;
8294 for (auto *ExitingBB : ExitingBlocks) {
8295 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8296 if (!Res)
8297 Res = Multiple;
8298 Res = std::gcd(*Res, Multiple);
8299 }
8300 return Res.value_or(1);
8301}
8302
8303 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8304 const SCEV *ExitCount) {
8305 if (isa<SCEVCouldNotCompute>(ExitCount))
8306 return 1;
8307
8308 // Get the trip count
8309 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8310
8311 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8312 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8313 // the greatest power of 2 divisor less than 2^32.
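// For example, a known constant multiple of 0x300000000 (3 * 2^32) does not
// fit in 32 bits; the clamp 1U << min(31, trailing zeros) yields 2^31, which
// still evenly divides the trip count.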
8314 return Multiple.getActiveBits() > 32
8315 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8316 : (unsigned)Multiple.getZExtValue();
8317}
8318
8319/// Returns the largest constant divisor of the trip count of this loop as a
8320/// normal unsigned value, if possible. This means that the actual trip count is
8321/// always a multiple of the returned value (don't forget the trip count could
8322/// very well be zero as well!).
8323///
8324 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8325 /// multiple of a constant (which is also the case if the trip count is simply
8326 /// constant; use getSmallConstantTripCount for that case). Also returns 1
8327 /// if the trip count is very large (>= 2^32).
8328///
8329/// As explained in the comments for getSmallConstantTripCount, this assumes
8330/// that control exits the loop via ExitingBlock.
8331unsigned
8333 const BasicBlock *ExitingBlock) {
8334 assert(ExitingBlock && "Must pass a non-null exiting block!");
8335 assert(L->isLoopExiting(ExitingBlock) &&
8336 "Exiting block must actually branch out of the loop!");
8337 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8338 return getSmallConstantTripMultiple(L, ExitCount);
8339}
8340
8341 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8342 const BasicBlock *ExitingBlock,
8343 ExitCountKind Kind) {
8344 switch (Kind) {
8345 case Exact:
8346 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8347 case SymbolicMaximum:
8348 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8349 case ConstantMaximum:
8350 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8351 };
8352 llvm_unreachable("Invalid ExitCountKind!");
8353}
8354
8355 const SCEV *ScalarEvolution::getPredicatedExitCount(
8356 const Loop *L, const BasicBlock *ExitingBlock,
8357 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8358 switch (Kind) {
8359 case Exact:
8360 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8361 Predicates);
8362 case SymbolicMaximum:
8363 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8364 Predicates);
8365 case ConstantMaximum:
8366 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8367 Predicates);
8368 };
8369 llvm_unreachable("Invalid ExitCountKind!");
8370}
8371
8372 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8373 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8374 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8375}
8376
8377 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8378 ExitCountKind Kind) {
8379 switch (Kind) {
8380 case Exact:
8381 return getBackedgeTakenInfo(L).getExact(L, this);
8382 case ConstantMaximum:
8383 return getBackedgeTakenInfo(L).getConstantMax(this);
8384 case SymbolicMaximum:
8385 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8386 };
8387 llvm_unreachable("Invalid ExitCountKind!");
8388}
8389
8390 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8391 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8392 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8393}
8394
8395 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8396 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8397 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8398}
8399
8400 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8401 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8402}
8403
8404/// Push PHI nodes in the header of the given loop onto the given Worklist.
8405static void PushLoopPHIs(const Loop *L,
8406 SmallVectorImpl<Instruction *> &Worklist,
8407 SmallPtrSetImpl<Instruction *> &Visited) {
8408 BasicBlock *Header = L->getHeader();
8409
8410 // Push all Loop-header PHIs onto the Worklist stack.
8411 for (PHINode &PN : Header->phis())
8412 if (Visited.insert(&PN).second)
8413 Worklist.push_back(&PN);
8414}
8415
8416ScalarEvolution::BackedgeTakenInfo &
8417ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8418 auto &BTI = getBackedgeTakenInfo(L);
8419 if (BTI.hasFullInfo())
8420 return BTI;
8421
8422 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8423
8424 if (!Pair.second)
8425 return Pair.first->second;
8426
8427 BackedgeTakenInfo Result =
8428 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8429
8430 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8431}
8432
8433ScalarEvolution::BackedgeTakenInfo &
8434ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8435 // Initially insert an invalid entry for this loop. If the insertion
8436 // succeeds, proceed to actually compute a backedge-taken count and
8437 // update the value. The temporary CouldNotCompute value tells SCEV
8438 // code elsewhere that it shouldn't attempt to request a new
8439 // backedge-taken count, which could result in infinite recursion.
8440 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8441 BackedgeTakenCounts.try_emplace(L);
8442 if (!Pair.second)
8443 return Pair.first->second;
8444
8445 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8446 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8447 // must be cleared in this scope.
8448 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8449
8450 // Now that we know more about the trip count for this loop, forget any
8451 // existing SCEV values for PHI nodes in this loop since they are only
8452 // conservative estimates made without the benefit of trip count
8453 // information. This invalidation is not necessary for correctness, and is
8454 // only done to produce more precise results.
8455 if (Result.hasAnyInfo()) {
8456 // Invalidate any expression using an addrec in this loop.
8458 auto LoopUsersIt = LoopUsers.find(L);
8459 if (LoopUsersIt != LoopUsers.end())
8460 append_range(ToForget, LoopUsersIt->second);
8461 forgetMemoizedResults(ToForget);
8462
8463 // Invalidate constant-evolved loop header phis.
8464 for (PHINode &PN : L->getHeader()->phis())
8465 ConstantEvolutionLoopExitValue.erase(&PN);
8466 }
8467
8468 // Re-lookup the insert position, since the call to
8469 // computeBackedgeTakenCount above could result in a
8470 // recursive call to getBackedgeTakenInfo (on a different
8471 // loop), which would invalidate the iterator computed
8472 // earlier.
8473 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8474}
8475
8476 void ScalarEvolution::forgetAllLoops() {
8477 // This method is intended to forget all info about loops. It should
8478 // invalidate caches as if the following happened:
8479 // - The trip counts of all loops have changed arbitrarily
8480 // - Every llvm::Value has been updated in place to produce a different
8481 // result.
8482 BackedgeTakenCounts.clear();
8483 PredicatedBackedgeTakenCounts.clear();
8484 BECountUsers.clear();
8485 LoopPropertiesCache.clear();
8486 ConstantEvolutionLoopExitValue.clear();
8487 ValueExprMap.clear();
8488 ValuesAtScopes.clear();
8489 ValuesAtScopesUsers.clear();
8490 LoopDispositions.clear();
8491 BlockDispositions.clear();
8492 UnsignedRanges.clear();
8493 SignedRanges.clear();
8494 ExprValueMap.clear();
8495 HasRecMap.clear();
8496 ConstantMultipleCache.clear();
8497 PredicatedSCEVRewrites.clear();
8498 FoldCache.clear();
8499 FoldCacheUser.clear();
8500}
8501void ScalarEvolution::visitAndClearUsers(
8502 SmallVectorImpl<Instruction *> &Worklist,
8503 SmallPtrSetImpl<Instruction *> &Visited,
8504 SmallVectorImpl<const SCEV *> &ToForget) {
8505 while (!Worklist.empty()) {
8506 Instruction *I = Worklist.pop_back_val();
8507 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8508 continue;
8509
8511 ValueExprMap.find_as(static_cast<Value *>(I));
8512 if (It != ValueExprMap.end()) {
8513 eraseValueFromMap(It->first);
8514 ToForget.push_back(It->second);
8515 if (PHINode *PN = dyn_cast<PHINode>(I))
8516 ConstantEvolutionLoopExitValue.erase(PN);
8517 }
8518
8519 PushDefUseChildren(I, Worklist, Visited);
8520 }
8521}
8522
8523 void ScalarEvolution::forgetLoop(const Loop *L) {
8524 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8528
8529 // Iterate over all the loops and sub-loops to drop SCEV information.
8530 while (!LoopWorklist.empty()) {
8531 auto *CurrL = LoopWorklist.pop_back_val();
8532
8533 // Drop any stored trip count value.
8534 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8535 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8536
8537 // Drop information about predicated SCEV rewrites for this loop.
8538 for (auto I = PredicatedSCEVRewrites.begin();
8539 I != PredicatedSCEVRewrites.end();) {
8540 std::pair<const SCEV *, const Loop *> Entry = I->first;
8541 if (Entry.second == CurrL)
8542 PredicatedSCEVRewrites.erase(I++);
8543 else
8544 ++I;
8545 }
8546
8547 auto LoopUsersItr = LoopUsers.find(CurrL);
8548 if (LoopUsersItr != LoopUsers.end())
8549 llvm::append_range(ToForget, LoopUsersItr->second);
8550
8551 // Drop information about expressions based on loop-header PHIs.
8552 PushLoopPHIs(CurrL, Worklist, Visited);
8553 visitAndClearUsers(Worklist, Visited, ToForget);
8554
8555 LoopPropertiesCache.erase(CurrL);
8556 // Forget all contained loops too, to avoid dangling entries in the
8557 // ValuesAtScopes map.
8558 LoopWorklist.append(CurrL->begin(), CurrL->end());
8559 }
8560 forgetMemoizedResults(ToForget);
8561}
8562
8563 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8564 forgetLoop(L->getOutermostLoop());
8565}
8566
8567 void ScalarEvolution::forgetValue(Value *V) {
8568 Instruction *I = dyn_cast<Instruction>(V);
8569 if (!I) return;
8570
8571 // Drop information about expressions based on loop-header PHIs.
8575 Worklist.push_back(I);
8576 Visited.insert(I);
8577 visitAndClearUsers(Worklist, Visited, ToForget);
8578
8579 forgetMemoizedResults(ToForget);
8580}
8581
8582 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8583 if (!isSCEVable(V->getType()))
8584 return;
8585
8586 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8587 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8588 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8589 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8590 if (const SCEV *S = getExistingSCEV(V)) {
8591 struct InvalidationRootCollector {
8592 Loop *L;
8594
8595 InvalidationRootCollector(Loop *L) : L(L) {}
8596
8597 bool follow(const SCEV *S) {
8598 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8599 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8600 if (L->contains(I))
8601 Roots.push_back(S);
8602 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8603 if (L->contains(AddRec->getLoop()))
8604 Roots.push_back(S);
8605 }
8606 return true;
8607 }
8608 bool isDone() const { return false; }
8609 };
8610
8611 InvalidationRootCollector C(L);
8612 visitAll(S, C);
8613 forgetMemoizedResults(C.Roots);
8614 }
8615
8616 // Also perform the normal invalidation.
8617 forgetValue(V);
8618}
8619
8620void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8621
8622 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8623 // Unless a specific value is passed to invalidation, completely clear both
8624 // caches.
8625 if (!V) {
8626 BlockDispositions.clear();
8627 LoopDispositions.clear();
8628 return;
8629 }
8630
8631 if (!isSCEVable(V->getType()))
8632 return;
8633
8634 const SCEV *S = getExistingSCEV(V);
8635 if (!S)
8636 return;
8637
8638 // Invalidate the block and loop dispositions cached for S. Dispositions of
8639 // S's users may change if S's disposition changes (i.e. a user may change to
8640 // loop-invariant, if S changes to loop invariant), so also invalidate
8641 // dispositions of S's users recursively.
8642 SmallVector<const SCEV *, 8> Worklist = {S};
8644 while (!Worklist.empty()) {
8645 const SCEV *Curr = Worklist.pop_back_val();
8646 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8647 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8648 if (!LoopDispoRemoved && !BlockDispoRemoved)
8649 continue;
8650 auto Users = SCEVUsers.find(Curr);
8651 if (Users != SCEVUsers.end())
8652 for (const auto *User : Users->second)
8653 if (Seen.insert(User).second)
8654 Worklist.push_back(User);
8655 }
8656}
8657
8658/// Get the exact loop backedge taken count considering all loop exits. A
8659/// computable result can only be returned for loops with all exiting blocks
8660/// dominating the latch. howFarToZero assumes that the limit of each loop test
8661/// is never skipped. This is a valid assumption as long as the loop exits via
8662/// that test. For precise results, it is the caller's responsibility to specify
8663/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8664const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8665 const Loop *L, ScalarEvolution *SE,
8666 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8667 // If any exits were not computable, the loop is not computable.
8668 if (!isComplete() || ExitNotTaken.empty())
8669 return SE->getCouldNotCompute();
8670
8671 const BasicBlock *Latch = L->getLoopLatch();
8672 // All exiting blocks we have collected must dominate the only backedge.
8673 if (!Latch)
8674 return SE->getCouldNotCompute();
8675
8676 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8677 // count is simply a minimum out of all these calculated exit counts.
8679 for (const auto &ENT : ExitNotTaken) {
8680 const SCEV *BECount = ENT.ExactNotTaken;
8681 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8682 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8683 "We should only have known counts for exiting blocks that dominate "
8684 "latch!");
8685
8686 Ops.push_back(BECount);
8687
8688 if (Preds)
8689 append_range(*Preds, ENT.Predicates);
8690
8691 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8692 "Predicate should be always true!");
8693 }
8694
8695 // If an earlier exit exits on the first iteration (exit count zero), then
8696 // a later poison exit count should not propagate into the result. These are
8697 // exactly the semantics provided by umin_seq.
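// For example, if one exit has an exit count of 0 and a later one has a
// poison count, umin_seq(0, poison) is 0 because the second operand is never
// evaluated, whereas a plain umin of the two would itself be poison.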
8698 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8699}
8700
8701const ScalarEvolution::ExitNotTakenInfo *
8702ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8703 const BasicBlock *ExitingBlock,
8704 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8705 for (const auto &ENT : ExitNotTaken)
8706 if (ENT.ExitingBlock == ExitingBlock) {
8707 if (ENT.hasAlwaysTruePredicate())
8708 return &ENT;
8709 else if (Predicates) {
8710 append_range(*Predicates, ENT.Predicates);
8711 return &ENT;
8712 }
8713 }
8714
8715 return nullptr;
8716}
8717
8718/// getConstantMax - Get the constant max backedge taken count for the loop.
8719const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8720 ScalarEvolution *SE,
8721 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8722 if (!getConstantMax())
8723 return SE->getCouldNotCompute();
8724
8725 for (const auto &ENT : ExitNotTaken)
8726 if (!ENT.hasAlwaysTruePredicate()) {
8727 if (!Predicates)
8728 return SE->getCouldNotCompute();
8729 append_range(*Predicates, ENT.Predicates);
8730 }
8731
8732 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8733 isa<SCEVConstant>(getConstantMax())) &&
8734 "No point in having a non-constant max backedge taken count!");
8735 return getConstantMax();
8736}
8737
8738const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8739 const Loop *L, ScalarEvolution *SE,
8740 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8741 if (!SymbolicMax) {
8742 // Form an expression for the maximum exit count possible for this loop. We
8743 // merge the max and exact information to approximate a version of
8744 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8745 // constants.
8747
8748 for (const auto &ENT : ExitNotTaken) {
8749 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8750 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8751 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8752 "We should only have known counts for exiting blocks that "
8753 "dominate latch!");
8754 ExitCounts.push_back(ExitCount);
8755 if (Predicates)
8756 append_range(*Predicates, ENT.Predicates);
8757
8758 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8759 "Predicate should be always true!");
8760 }
8761 }
8762 if (ExitCounts.empty())
8763 SymbolicMax = SE->getCouldNotCompute();
8764 else
8765 SymbolicMax =
8766 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8767 }
8768 return SymbolicMax;
8769}
8770
8771bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8772 ScalarEvolution *SE) const {
8773 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8774 return !ENT.hasAlwaysTruePredicate();
8775 };
8776 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8777}
8778
8781
8783 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8784 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8788 // If we prove the max count is zero, so is the symbolic bound. This happens
8789 // in practice due to differences in a) how context sensitive we've chosen
8790 // to be and b) how we reason about bounds implied by UB.
8791 if (ConstantMaxNotTaken->isZero()) {
8792 this->ExactNotTaken = E = ConstantMaxNotTaken;
8793 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8794 }
8795
8796 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8797 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8798 "Exact is not allowed to be less precise than Constant Max");
8799 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8800 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8801 "Exact is not allowed to be less precise than Symbolic Max");
8802 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8803 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8804 "Symbolic Max is not allowed to be less precise than Constant Max");
8805 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8806 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8807 "No point in having a non-constant max backedge taken count!");
8808 SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8809 for (const auto PredList : PredLists)
8810 for (const auto *P : PredList) {
8811 if (SeenPreds.contains(P))
8812 continue;
8813 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8814 SeenPreds.insert(P);
8815 Predicates.push_back(P);
8816 }
8817 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8818 "Backedge count should be int");
8819 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8820 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8821 "Max backedge count should be int");
8822}
8823
8831
8832/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8833/// computable exit into a persistent ExitNotTakenInfo array.
8834ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8836 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8837 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8838 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8839
8840 ExitNotTaken.reserve(ExitCounts.size());
8841 std::transform(ExitCounts.begin(), ExitCounts.end(),
8842 std::back_inserter(ExitNotTaken),
8843 [&](const EdgeExitInfo &EEI) {
8844 BasicBlock *ExitBB = EEI.first;
8845 const ExitLimit &EL = EEI.second;
8846 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8847 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8848 EL.Predicates);
8849 });
8850 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8851 isa<SCEVConstant>(ConstantMax)) &&
8852 "No point in having a non-constant max backedge taken count!");
8853}
8854
8855/// Compute the number of times the backedge of the specified loop will execute.
8856ScalarEvolution::BackedgeTakenInfo
8857ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8858 bool AllowPredicates) {
8859 SmallVector<BasicBlock *, 8> ExitingBlocks;
8860 L->getExitingBlocks(ExitingBlocks);
8861
8862 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8863
8865 bool CouldComputeBECount = true;
8866 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8867 const SCEV *MustExitMaxBECount = nullptr;
8868 const SCEV *MayExitMaxBECount = nullptr;
8869 bool MustExitMaxOrZero = false;
8870 bool IsOnlyExit = ExitingBlocks.size() == 1;
8871
8872 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8873 // and compute maxBECount.
8874 // Do a union of all the predicates here.
8875 for (BasicBlock *ExitBB : ExitingBlocks) {
8876 // We canonicalize untaken exits to br (constant); ignore them so that
8877 // proving an exit untaken doesn't negatively impact our ability to reason
8878 // about the loop as a whole.
8879 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8880 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8881 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8882 if (ExitIfTrue == CI->isZero())
8883 continue;
8884 }
8885
8886 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8887
8888 assert((AllowPredicates || EL.Predicates.empty()) &&
8889 "Predicated exit limit when predicates are not allowed!");
8890
8891 // 1. For each exit that can be computed, add an entry to ExitCounts.
8892 // CouldComputeBECount is true only if all exits can be computed.
8893 if (EL.ExactNotTaken != getCouldNotCompute())
8894 ++NumExitCountsComputed;
8895 else
8896 // We couldn't compute an exact value for this exit, so
8897 // we won't be able to compute an exact value for the loop.
8898 CouldComputeBECount = false;
8899 // Remember exit count if either exact or symbolic is known. Because
8900 // Exact always implies symbolic, only check symbolic.
8901 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8902 ExitCounts.emplace_back(ExitBB, EL);
8903 else {
8904 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8905 "Exact is known but symbolic isn't?");
8906 ++NumExitCountsNotComputed;
8907 }
8908
8909 // 2. Derive the loop's MaxBECount from each exit's max number of
8910 // non-exiting iterations. Partition the loop exits into two kinds:
8911 // LoopMustExits and LoopMayExits.
8912 //
8913 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8914 // is a LoopMayExit. If any computable LoopMustExit is found, then
8915 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8916 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8917 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8918 // any
8919 // computable EL.ConstantMaxNotTaken.
8920 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8921 DT.dominates(ExitBB, Latch)) {
8922 if (!MustExitMaxBECount) {
8923 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8924 MustExitMaxOrZero = EL.MaxOrZero;
8925 } else {
8926 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8927 EL.ConstantMaxNotTaken);
8928 }
8929 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8930 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8931 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8932 else {
8933 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8934 EL.ConstantMaxNotTaken);
8935 }
8936 }
8937 }
8938 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8939 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8940 // The loop backedge will be taken the maximum or zero times if there's
8941 // a single exit that must be taken the maximum or zero times.
8942 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8943
8944 // Remember which SCEVs are used in exit limits for invalidation purposes.
8945 // We only care about non-constant SCEVs here, so we can ignore
8946 // EL.ConstantMaxNotTaken
8947 // and MaxBECount, which must be SCEVConstant.
8948 for (const auto &Pair : ExitCounts) {
8949 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8950 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8951 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8952 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8953 {L, AllowPredicates});
8954 }
8955 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8956 MaxBECount, MaxOrZero);
8957}
8958
8959ScalarEvolution::ExitLimit
8960ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8961 bool IsOnlyExit, bool AllowPredicates) {
8962 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8963 // If our exiting block does not dominate the latch, then its connection with
8964 // loop's exit limit may be far from trivial.
8965 const BasicBlock *Latch = L->getLoopLatch();
8966 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8967 return getCouldNotCompute();
8968
8969 Instruction *Term = ExitingBlock->getTerminator();
8970 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8971 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8972 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8973 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8974 "It should have one successor in loop and one exit block!");
8975 // Proceed to the next level to examine the exit condition expression.
8976 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8977 /*ControlsOnlyExit=*/IsOnlyExit,
8978 AllowPredicates);
8979 }
8980
8981 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8982 // For switch, make sure that there is a single exit from the loop.
8983 BasicBlock *Exit = nullptr;
8984 for (auto *SBB : successors(ExitingBlock))
8985 if (!L->contains(SBB)) {
8986 if (Exit) // Multiple exit successors.
8987 return getCouldNotCompute();
8988 Exit = SBB;
8989 }
8990 assert(Exit && "Exiting block must have at least one exit");
8991 return computeExitLimitFromSingleExitSwitch(
8992 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8993 }
8994
8995 return getCouldNotCompute();
8996}
8997
8998 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8999 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9000 bool AllowPredicates) {
9001 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9002 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9003 ControlsOnlyExit, AllowPredicates);
9004}
9005
9006std::optional<ScalarEvolution::ExitLimit>
9007ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9008 bool ExitIfTrue, bool ControlsOnlyExit,
9009 bool AllowPredicates) {
9010 (void)this->L;
9011 (void)this->ExitIfTrue;
9012 (void)this->AllowPredicates;
9013
9014 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9015 this->AllowPredicates == AllowPredicates &&
9016 "Variance in assumed invariant key components!");
9017 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9018 if (Itr == TripCountMap.end())
9019 return std::nullopt;
9020 return Itr->second;
9021}
9022
9023void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9024 bool ExitIfTrue,
9025 bool ControlsOnlyExit,
9026 bool AllowPredicates,
9027 const ExitLimit &EL) {
9028 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9029 this->AllowPredicates == AllowPredicates &&
9030 "Variance in assumed invariant key components!");
9031
9032 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9033 assert(InsertResult.second && "Expected successful insertion!");
9034 (void)InsertResult;
9035 (void)ExitIfTrue;
9036}
9037
9038ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9039 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9040 bool ControlsOnlyExit, bool AllowPredicates) {
9041
9042 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9043 AllowPredicates))
9044 return *MaybeEL;
9045
9046 ExitLimit EL = computeExitLimitFromCondImpl(
9047 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9048 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9049 return EL;
9050}
9051
9052ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9053 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9054 bool ControlsOnlyExit, bool AllowPredicates) {
9055 // Handle BinOp conditions (And, Or).
9056 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9057 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9058 return *LimitFromBinOp;
9059
9060 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9061 // Proceed to the next level to examine the icmp.
9062 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9063 ExitLimit EL =
9064 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9065 if (EL.hasFullInfo() || !AllowPredicates)
9066 return EL;
9067
9068 // Try again, but use SCEV predicates this time.
9069 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9070 ControlsOnlyExit,
9071 /*AllowPredicates=*/true);
9072 }
9073
9074 // Check for a constant condition. These are normally stripped out by
9075 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9076 // preserve the CFG and is temporarily leaving constant conditions
9077 // in place.
9078 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9079 if (ExitIfTrue == !CI->getZExtValue())
9080 // The backedge is always taken.
9081 return getCouldNotCompute();
9082 // The backedge is never taken.
9083 return getZero(CI->getType());
9084 }
9085
9086 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9087 // with a constant step, we can form an equivalent icmp predicate and figure
9088 // out how many iterations will be taken before we exit.
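// For example, a loop exiting when the overflow bit of
// llvm.uadd.with.overflow.i8(%iv, 1) is set exits exactly when %iv == 255,
// so the condition is rewritten into that icmp form and analyzed below.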
9089 const WithOverflowInst *WO;
9090 const APInt *C;
9091 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9092 match(WO->getRHS(), m_APInt(C))) {
9093 ConstantRange NWR =
9094 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9095 WO->getNoWrapKind());
9096 CmpInst::Predicate Pred;
9097 APInt NewRHSC, Offset;
9098 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9099 if (!ExitIfTrue)
9100 Pred = ICmpInst::getInversePredicate(Pred);
9101 auto *LHS = getSCEV(WO->getLHS());
9102 if (Offset != 0)
9103 LHS = getAddExpr(LHS, getConstant(Offset));
9104 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9105 ControlsOnlyExit, AllowPredicates);
9106 if (EL.hasAnyInfo())
9107 return EL;
9108 }
9109
9110 // If it's not an integer or pointer comparison then compute it the hard way.
9111 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9112}
9113
9114std::optional<ScalarEvolution::ExitLimit>
9115ScalarEvolution::computeExitLimitFromCondFromBinOp(
9116 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9117 bool ControlsOnlyExit, bool AllowPredicates) {
9118 // Check if the controlling expression for this loop is an And or Or.
9119 Value *Op0, *Op1;
9120 bool IsAnd = false;
9121 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9122 IsAnd = true;
9123 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9124 IsAnd = false;
9125 else
9126 return std::nullopt;
9127
9128 // EitherMayExit is true in these two cases:
9129 // br (and Op0 Op1), loop, exit
9130 // br (or Op0 Op1), exit, loop
9131 bool EitherMayExit = IsAnd ^ ExitIfTrue;
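// For example, for `br i1 (and i1 %c0, %c1), label %loop, label %exit` the
// loop keeps running only while both conditions are true, so the exact
// backedge-taken count is the (sequential) umin of the two counts below.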
9132 ExitLimit EL0 = computeExitLimitFromCondCached(
9133 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9134 AllowPredicates);
9135 ExitLimit EL1 = computeExitLimitFromCondCached(
9136 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9137 AllowPredicates);
9138
9139 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9140 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9141 if (isa<ConstantInt>(Op1))
9142 return Op1 == NeutralElement ? EL0 : EL1;
9143 if (isa<ConstantInt>(Op0))
9144 return Op0 == NeutralElement ? EL1 : EL0;
9145
9146 const SCEV *BECount = getCouldNotCompute();
9147 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9148 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9149 if (EitherMayExit) {
9150 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9151 // Both conditions must hold for the loop to continue executing.
9152 // Choose the less conservative count.
9153 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9154 EL1.ExactNotTaken != getCouldNotCompute()) {
9155 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9156 UseSequentialUMin);
9157 }
9158 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9159 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9160 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9161 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9162 else
9163 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9164 EL1.ConstantMaxNotTaken);
9165 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9166 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9167 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9168 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9169 else
9170 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9171 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9172 } else {
9173 // Both conditions must hold at the same time for the loop to exit.
9174 // For now, be conservative.
9175 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9176 BECount = EL0.ExactNotTaken;
9177 }
9178
9179 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9180 // to be more aggressive when computing BECount than when computing
9181 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9182 // and
9183 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9184 // EL1.ConstantMaxNotTaken to not.
9185 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9186 !isa<SCEVCouldNotCompute>(BECount))
9187 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9188 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9189 SymbolicMaxBECount =
9190 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9191 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9192 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9193}
9194
9195ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9196 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9197 bool AllowPredicates) {
9198 // If the condition was exit on true, convert the condition to exit on false
9199 CmpPredicate Pred;
9200 if (!ExitIfTrue)
9201 Pred = ExitCond->getCmpPredicate();
9202 else
9203 Pred = ExitCond->getInverseCmpPredicate();
9204 const ICmpInst::Predicate OriginalPred = Pred;
9205
9206 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9207 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9208
9209 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9210 AllowPredicates);
9211 if (EL.hasAnyInfo())
9212 return EL;
9213
9214 auto *ExhaustiveCount =
9215 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9216
9217 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9218 return ExhaustiveCount;
9219
9220 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9221 ExitCond->getOperand(1), L, OriginalPred);
9222}
9223ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9224 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9225 bool ControlsOnlyExit, bool AllowPredicates) {
9226
9227 // Try to evaluate any dependencies out of the loop.
9228 LHS = getSCEVAtScope(LHS, L);
9229 RHS = getSCEVAtScope(RHS, L);
9230
9231 // At this point, we would like to compute how many iterations of the
9232 // loop the predicate will return true for these inputs.
9233 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9234 // If there is a loop-invariant, force it into the RHS.
9235 std::swap(LHS, RHS);
9236 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9237 }
9238
9239 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9240 loopIsFiniteByAssumption(L);
9241 // Simplify the operands before analyzing them.
9242 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9243
9244 // If we have a comparison of a chrec against a constant, try to use value
9245 // ranges to answer this query.
9246 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9247 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9248 if (AddRec->getLoop() == L) {
9249 // Form the constant range.
9250 ConstantRange CompRange =
9251 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9252
9253 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9254 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9255 }
9256
9257 // If this loop must exit based on this condition (or execute undefined
9258 // behaviour), see if we can improve wrap flags. This is essentially
9259 // a must execute style proof.
9260 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9261 // If we can prove the test sequence produced must repeat the same values
9262 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9263 // because if it did, we'd have an infinite (undefined) loop.
9264 // TODO: We can peel off any functions which are invertible *in L*. Loop
9265 // invariant terms are effectively constants for our purposes here.
9266 auto *InnerLHS = LHS;
9267 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9268 InnerLHS = ZExt->getOperand();
9269 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9270 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9271 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9272 /*OrNegative=*/true)) {
9273 auto Flags = AR->getNoWrapFlags();
9274 Flags = setFlags(Flags, SCEV::FlagNW);
9277 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9278 }
9279
9280 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9281 // From no-self-wrap, this follows trivially from the fact that every
9282 // (un)signed-wrapped, but not self-wrapped value must be less than the
9283 // last value before the (un)signed wrap. Since we know that last value
9284 // didn't cause an exit, neither will any smaller one.
9285 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9286 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9287 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9288 AR && AR->getLoop() == L && AR->isAffine() &&
9289 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9290 isKnownPositive(AR->getStepRecurrence(*this))) {
9291 auto Flags = AR->getNoWrapFlags();
9292 Flags = setFlags(Flags, WrapType);
9295 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9296 }
9297 }
9298 }
9299
9300 switch (Pred) {
9301 case ICmpInst::ICMP_NE: { // while (X != Y)
9302 // Convert to: while (X-Y != 0)
9303 if (LHS->getType()->isPointerTy()) {
9304 LHS = getLosslessPtrToIntExpr(LHS);
9305 if (isa<SCEVCouldNotCompute>(LHS))
9306 return LHS;
9307 }
9308 if (RHS->getType()->isPointerTy()) {
9309 RHS = getLosslessPtrToIntExpr(RHS);
9310 if (isa<SCEVCouldNotCompute>(RHS))
9311 return RHS;
9312 }
9313 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9314 AllowPredicates);
9315 if (EL.hasAnyInfo())
9316 return EL;
9317 break;
9318 }
9319 case ICmpInst::ICMP_EQ: { // while (X == Y)
9320 // Convert to: while (X-Y == 0)
9321 if (LHS->getType()->isPointerTy()) {
9322 LHS = getLosslessPtrToIntExpr(LHS);
9323 if (isa<SCEVCouldNotCompute>(LHS))
9324 return LHS;
9325 }
9326 if (RHS->getType()->isPointerTy()) {
9327 RHS = getLosslessPtrToIntExpr(RHS);
9328 if (isa<SCEVCouldNotCompute>(RHS))
9329 return RHS;
9330 }
9331 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9332 if (EL.hasAnyInfo()) return EL;
9333 break;
9334 }
9335 case ICmpInst::ICMP_SLE:
9336 case ICmpInst::ICMP_ULE:
9337 // Since the loop is finite, an invariant RHS cannot include the boundary
9338 // value, otherwise it would loop forever.
9339 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9340 !isLoopInvariant(RHS, L)) {
9341 // Otherwise, perform the addition in a wider type, to avoid overflow.
9342 // If the LHS is an addrec with the appropriate nowrap flag, the
9343 // extension will be sunk into it and the exit count can be analyzed.
9344 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9345 if (!OldType)
9346 break;
9347 // Prefer doubling the bitwidth over adding a single bit to make it more
9348 // likely that we use a legal type.
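// For example, for an i32 `i <= N` exit test the comparison is handled as a
// strict `<` against N + 1 via the fallthrough below; if N could be the
// maximum i32 value that addition would wrap in i32, but it cannot wrap once
// both sides have been sign-/zero-extended to i64.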
9349 auto *NewType =
9350 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9351 if (ICmpInst::isSigned(Pred)) {
9352 LHS = getSignExtendExpr(LHS, NewType);
9353 RHS = getSignExtendExpr(RHS, NewType);
9354 } else {
9355 LHS = getZeroExtendExpr(LHS, NewType);
9356 RHS = getZeroExtendExpr(RHS, NewType);
9357 }
9358 }
9359 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9360 [[fallthrough]];
9361 case ICmpInst::ICMP_SLT:
9362 case ICmpInst::ICMP_ULT: { // while (X < Y)
9363 bool IsSigned = ICmpInst::isSigned(Pred);
9364 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9365 AllowPredicates);
9366 if (EL.hasAnyInfo())
9367 return EL;
9368 break;
9369 }
9370 case ICmpInst::ICMP_SGE:
9371 case ICmpInst::ICMP_UGE:
9372 // Since the loop is finite, an invariant RHS cannot include the boundary
9373 // value, otherwise it would loop forever.
9374 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9375 !isLoopInvariant(RHS, L))
9376 break;
9377 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9378 [[fallthrough]];
9379 case ICmpInst::ICMP_SGT:
9380 case ICmpInst::ICMP_UGT: { // while (X > Y)
9381 bool IsSigned = ICmpInst::isSigned(Pred);
9382 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9383 AllowPredicates);
9384 if (EL.hasAnyInfo())
9385 return EL;
9386 break;
9387 }
9388 default:
9389 break;
9390 }
9391
9392 return getCouldNotCompute();
9393}
9394
9395ScalarEvolution::ExitLimit
9396ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9397 SwitchInst *Switch,
9398 BasicBlock *ExitingBlock,
9399 bool ControlsOnlyExit) {
9400 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9401
9402 // Give up if the exit is the default dest of a switch.
9403 if (Switch->getDefaultDest() == ExitingBlock)
9404 return getCouldNotCompute();
9405
9406 assert(L->contains(Switch->getDefaultDest()) &&
9407 "Default case must not exit the loop!");
9408 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9409 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9410
9411 // while (X != Y) --> while (X-Y != 0)
9412 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9413 if (EL.hasAnyInfo())
9414 return EL;
9415
9416 return getCouldNotCompute();
9417}
9418
9419static ConstantInt *
9420 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9421 ScalarEvolution &SE) {
9422 const SCEV *InVal = SE.getConstant(C);
9423 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9424 assert(isa<SCEVConstant>(Val) &&
9425 "Evaluation of SCEV at constant didn't fold correctly?");
9426 return cast<SCEVConstant>(Val)->getValue();
9427}
9428
9429ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9430 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9431 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9432 if (!RHS)
9433 return getCouldNotCompute();
9434
9435 const BasicBlock *Latch = L->getLoopLatch();
9436 if (!Latch)
9437 return getCouldNotCompute();
9438
9439 const BasicBlock *Predecessor = L->getLoopPredecessor();
9440 if (!Predecessor)
9441 return getCouldNotCompute();
9442
9443 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9444 // Return LHS in OutLHS and the shift opcode in OutOpCode.
9445 auto MatchPositiveShift =
9446 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9447
9448 using namespace PatternMatch;
9449
9450 ConstantInt *ShiftAmt;
9451 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9452 OutOpCode = Instruction::LShr;
9453 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9454 OutOpCode = Instruction::AShr;
9455 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9456 OutOpCode = Instruction::Shl;
9457 else
9458 return false;
9459
9460 return ShiftAmt->getValue().isStrictlyPositive();
9461 };
9462
9463 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9464 //
9465 // loop:
9466 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9467 // %iv.shifted = lshr i32 %iv, <positive constant>
9468 //
9469 // Return true on a successful match. Return the corresponding PHI node (%iv
9470 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9471 auto MatchShiftRecurrence =
9472 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9473 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9474
9475 {
9476 Instruction::BinaryOps OpC;
9477 Value *V;
9478
9479 // If we encounter a shift instruction, "peel off" the shift operation,
9480 // and remember that we did so. Later when we inspect %iv's backedge
9481 // value, we will make sure that the backedge value uses the same
9482 // operation.
9483 //
9484 // Note: the peeled shift operation does not have to be the same
9485 // instruction as the one feeding into the PHI's backedge value. We only
9486 // really care about it being the same *kind* of shift instruction --
9487 // that's all that is required for our later inferences to hold.
9488 if (MatchPositiveShift(LHS, V, OpC)) {
9489 PostShiftOpCode = OpC;
9490 LHS = V;
9491 }
9492 }
9493
9494 PNOut = dyn_cast<PHINode>(LHS);
9495 if (!PNOut || PNOut->getParent() != L->getHeader())
9496 return false;
9497
9498 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9499 Value *OpLHS;
9500
9501 return
9502 // The backedge value for the PHI node must be a shift by a positive
9503 // amount
9504 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9505
9506 // of the PHI node itself
9507 OpLHS == PNOut &&
9508
9509 // and the kind of shift should match the kind of shift we peeled
9510 // off, if any.
9511 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9512 };
9513
9514 PHINode *PN;
9515 Instruction::BinaryOps OpCode;
9516 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9517 return getCouldNotCompute();
9518
9519 const DataLayout &DL = getDataLayout();
9520
9521 // The key rationale for this optimization is that for some kinds of shift
9522 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9523 // within a finite number of iterations. If the condition guarding the
9524 // backedge (in the sense that the backedge is taken if the condition is true)
9525 // is false for the value the shift recurrence stabilizes to, then we know
9526 // that the backedge is taken only a finite number of times.
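 // For illustration (not from the original source): given
 //   %iv = phi i32 [ 8, %preheader ], [ %iv.shr, %loop ]
 //   %iv.shr = lshr i32 %iv, 1
 // the recurrence produces 8, 4, 2, 1, 0, 0, ... and stabilizes to 0 within
 // bitwidth(i32) = 32 iterations. If the backedge is guarded by
 // "icmp ne i32 %iv, 0", that condition is false once the value reaches 0,
 // so the backedge can only be taken a bounded number of times.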
9527
9528 ConstantInt *StableValue = nullptr;
9529 switch (OpCode) {
9530 default:
9531 llvm_unreachable("Impossible case!");
9532
9533 case Instruction::AShr: {
9534 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9535 // bitwidth(K) iterations.
9536 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9537 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9538 Predecessor->getTerminator(), &DT);
9539 auto *Ty = cast<IntegerType>(RHS->getType());
9540 if (Known.isNonNegative())
9541 StableValue = ConstantInt::get(Ty, 0);
9542 else if (Known.isNegative())
9543 StableValue = ConstantInt::get(Ty, -1, true);
9544 else
9545 return getCouldNotCompute();
9546
9547 break;
9548 }
9549 case Instruction::LShr:
9550 case Instruction::Shl:
9551 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9552 // stabilize to 0 in at most bitwidth(K) iterations.
9553 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9554 break;
9555 }
9556
9557 auto *Result =
9558 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9559 assert(Result->getType()->isIntegerTy(1) &&
9560 "Otherwise cannot be an operand to a branch instruction");
9561
9562 if (Result->isZeroValue()) {
9563 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9564 const SCEV *UpperBound =
9565 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9566 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9567 }
9568
9569 return getCouldNotCompute();
9570}
9571
9572/// Return true if we can constant fold an instruction of the specified type,
9573/// assuming that all operands were constants.
9574static bool CanConstantFold(const Instruction *I) {
9575 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9576 isa<CastInst>(I) || isa<GetElementPtrInst>(I) || isa<LoadInst>(I) ||
9577 isa<ExtractValueInst>(I))
9578 return true;
9579
9580 if (const CallInst *CI = dyn_cast<CallInst>(I))
9581 if (const Function *F = CI->getCalledFunction())
9582 return canConstantFoldCallTo(CI, F);
9583 return false;
9584}
9585
9586/// Determine whether this instruction can constant evolve within this loop
9587/// assuming its operands can all constant evolve.
9588static bool canConstantEvolve(Instruction *I, const Loop *L) {
9589 // An instruction outside of the loop can't be derived from a loop PHI.
9590 if (!L->contains(I)) return false;
9591
9592 if (isa<PHINode>(I)) {
9593 // We don't currently keep track of the control flow needed to evaluate
9594 // PHIs, so we cannot handle PHIs inside of loops.
9595 return L->getHeader() == I->getParent();
9596 }
9597
9598 // If we won't be able to constant fold this expression even if the operands
9599 // are constants, bail early.
9600 return CanConstantFold(I);
9601}
9602
9603/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9604/// recursing through each instruction operand until reaching a loop header phi.
9605static PHINode *
9606getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9607 DenseMap<Instruction *, PHINode *> &PHIMap,
9608 unsigned Depth) {
9609 if (Depth > MaxConstantEvolvingDepth)
9610 return nullptr;
9611
9612 // Otherwise, we can evaluate this instruction if all of its operands are
9613 // constant or derived from a PHI node themselves.
9614 PHINode *PHI = nullptr;
9615 for (Value *Op : UseInst->operands()) {
9616 if (isa<Constant>(Op)) continue;
9617
9617
9618 Instruction *OpInst = dyn_cast<Instruction>(Op);
9619 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9620
9621 PHINode *P = dyn_cast<PHINode>(OpInst);
9622 if (!P)
9623 // If this operand is already visited, reuse the prior result.
9624 // We may have P != PHI if this is the deepest point at which the
9625 // inconsistent paths meet.
9626 P = PHIMap.lookup(OpInst);
9627 if (!P) {
9628 // Recurse and memoize the results, whether a phi is found or not.
9629 // This recursive call invalidates pointers into PHIMap.
9630 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9631 PHIMap[OpInst] = P;
9632 }
9633 if (!P)
9634 return nullptr; // Not evolving from PHI
9635 if (PHI && PHI != P)
9636 return nullptr; // Evolving from multiple different PHIs.
9637 PHI = P;
9638 }
9639 // This is an expression evolving from a constant PHI!
9640 return PHI;
9641}
9642
9643/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9644/// in the loop that V is derived from. We allow arbitrary operations along the
9645/// way, but the operands of an operation must either be constants or a value
9646/// derived from a constant PHI. If this expression does not fit with these
9647/// constraints, return null.
9648static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9649 Instruction *I = dyn_cast<Instruction>(V);
9650 if (!I || !canConstantEvolve(I, L)) return nullptr;
9651
9652 if (PHINode *PN = dyn_cast<PHINode>(I))
9653 return PN;
9654
9655 // Record non-constant instructions contained by the loop.
9656 DenseMap<Instruction *, PHINode *> PHIMap;
9657 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9658}
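// For illustration (not from the original source): in a loop such as
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %t  = mul i32 %iv, 3
//   %c  = icmp eq i32 %t, 33
// getConstantEvolvingPHI(%c, L) returns %iv, because %c is computed from %iv
// using only constant operands and foldable operations. If %c also depended
// on a second header PHI, nullptr would be returned instead.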
9659
9660/// EvaluateExpression - Given an expression that passes the
9661/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9662/// in the loop has the value PHIVal. If we can't fold this expression for some
9663/// reason, return null.
9664static Constant *EvaluateExpression(Value *V, const Loop *L,
9665 DenseMap<Instruction *, Constant *> &Vals,
9666 const DataLayout &DL,
9667 const TargetLibraryInfo *TLI) {
9668 // Convenient constant check, but redundant for recursive calls.
9669 if (Constant *C = dyn_cast<Constant>(V)) return C;
9670 Instruction *I = dyn_cast<Instruction>(V);
9671 if (!I) return nullptr;
9672
9673 if (Constant *C = Vals.lookup(I)) return C;
9674
9675 // An instruction inside the loop depends on a value outside the loop that we
9676 // weren't given a mapping for, or a value such as a call inside the loop.
9677 if (!canConstantEvolve(I, L)) return nullptr;
9678
9679 // An unmapped PHI can be due to a branch or another loop inside this loop,
9680 // or due to this not being the initial iteration through a loop where we
9681 // couldn't compute the evolution of this particular PHI last time.
9682 if (isa<PHINode>(I)) return nullptr;
9683
9684 std::vector<Constant*> Operands(I->getNumOperands());
9685
9686 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9687 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9688 if (!Operand) {
9689 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9690 if (!Operands[i]) return nullptr;
9691 continue;
9692 }
9693 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9694 Vals[Operand] = C;
9695 if (!C) return nullptr;
9696 Operands[i] = C;
9697 }
9698
9699 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9700 /*AllowNonDeterministic=*/false);
9701}
9702
9703
9704// If every incoming value to PN except the one for BB is a specific Constant,
9705// return that, else return nullptr.
9706static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9707 Constant *IncomingVal = nullptr;
9708
9709 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9710 if (PN->getIncomingBlock(i) == BB)
9711 continue;
9712
9713 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9714 if (!CurrentVal)
9715 return nullptr;
9716
9717 if (IncomingVal != CurrentVal) {
9718 if (IncomingVal)
9719 return nullptr;
9720 IncomingVal = CurrentVal;
9721 }
9722 }
9723
9724 return IncomingVal;
9725}
9726
9727/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9728/// in the header of its containing loop, we know the loop executes a
9729/// constant number of times, and the PHI node is just a recurrence
9730/// involving constants, fold it.
9731Constant *
9732ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9733 const APInt &BEs,
9734 const Loop *L) {
9735 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9736 if (!Inserted)
9737 return I->second;
9738
9739 if (BEs.ugt(MaxBruteForceIterations))
9740 return nullptr; // Not going to evaluate it.
9741
9742 Constant *&RetVal = I->second;
9743
9744 DenseMap<Instruction *, Constant *> CurrentIterVals;
9745 BasicBlock *Header = L->getHeader();
9746 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9747
9748 BasicBlock *Latch = L->getLoopLatch();
9749 if (!Latch)
9750 return nullptr;
9751
9752 for (PHINode &PHI : Header->phis()) {
9753 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9754 CurrentIterVals[&PHI] = StartCST;
9755 }
9756 if (!CurrentIterVals.count(PN))
9757 return RetVal = nullptr;
9758
9759 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9760
9761 // Execute the loop symbolically to determine the exit value.
9762 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9763 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9764
9765 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9766 unsigned IterationNum = 0;
9767 const DataLayout &DL = getDataLayout();
9768 for (; ; ++IterationNum) {
9769 if (IterationNum == NumIterations)
9770 return RetVal = CurrentIterVals[PN]; // Got exit value!
9771
9772 // Compute the value of the PHIs for the next iteration.
9773 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9774 DenseMap<Instruction *, Constant *> NextIterVals;
9775 Constant *NextPHI =
9776 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9777 if (!NextPHI)
9778 return nullptr; // Couldn't evaluate!
9779 NextIterVals[PN] = NextPHI;
9780
9781 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9782
9783 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9784 // cease to be able to evaluate one of them or if they stop evolving,
9785 // because that doesn't necessarily prevent us from computing PN.
9786 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9787 for (const auto &I : CurrentIterVals) {
9788 PHINode *PHI = dyn_cast<PHINode>(I.first);
9789 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9790 PHIsToCompute.emplace_back(PHI, I.second);
9791 }
9792 // We use two distinct loops because EvaluateExpression may invalidate any
9793 // iterators into CurrentIterVals.
9794 for (const auto &I : PHIsToCompute) {
9795 PHINode *PHI = I.first;
9796 Constant *&NextPHI = NextIterVals[PHI];
9797 if (!NextPHI) { // Not already computed.
9798 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9799 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9800 }
9801 if (NextPHI != I.second)
9802 StoppedEvolving = false;
9803 }
9804
9805 // If all entries in CurrentIterVals == NextIterVals then we can stop
9806 // iterating, the loop can't continue to change.
9807 if (StoppedEvolving)
9808 return RetVal = CurrentIterVals[PN];
9809
9810 CurrentIterVals.swap(NextIterVals);
9811 }
9812}
9813
9814const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9815 Value *Cond,
9816 bool ExitWhen) {
9817 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9818 if (!PN) return getCouldNotCompute();
9819
9820 // If the loop is canonicalized, the PHI will have exactly two entries.
9821 // That's the only form we support here.
9822 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9823
9824 DenseMap<Instruction *, Constant *> CurrentIterVals;
9825 BasicBlock *Header = L->getHeader();
9826 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9827
9828 BasicBlock *Latch = L->getLoopLatch();
9829 assert(Latch && "Should follow from NumIncomingValues == 2!");
9830
9831 for (PHINode &PHI : Header->phis()) {
9832 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9833 CurrentIterVals[&PHI] = StartCST;
9834 }
9835 if (!CurrentIterVals.count(PN))
9836 return getCouldNotCompute();
9837
9838 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9839 // the loop symbolically to determine when the condition gets a value of
9840 // "ExitWhen".
9841 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9842 const DataLayout &DL = getDataLayout();
9843 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9844 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9845 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9846
9847 // Couldn't symbolically evaluate.
9848 if (!CondVal) return getCouldNotCompute();
9849
9850 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9851 ++NumBruteForceTripCountsComputed;
9852 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9853 }
9854
9855 // Update all the PHI nodes for the next iteration.
9856 DenseMap<Instruction *, Constant *> NextIterVals;
9857
9858 // Create a list of which PHIs we need to compute. We want to do this before
9859 // calling EvaluateExpression on them because that may invalidate iterators
9860 // into CurrentIterVals.
9861 SmallVector<PHINode *, 8> PHIsToCompute;
9862 for (const auto &I : CurrentIterVals) {
9863 PHINode *PHI = dyn_cast<PHINode>(I.first);
9864 if (!PHI || PHI->getParent() != Header) continue;
9865 PHIsToCompute.push_back(PHI);
9866 }
9867 for (PHINode *PHI : PHIsToCompute) {
9868 Constant *&NextPHI = NextIterVals[PHI];
9869 if (NextPHI) continue; // Already computed!
9870
9871 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9872 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9873 }
9874 CurrentIterVals.swap(NextIterVals);
9875 }
9876
9877 // Too many iterations were needed to evaluate.
9878 return getCouldNotCompute();
9879}
9880
9881const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9882 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9883 ValuesAtScopes[V];
9884 // Check to see if we've folded this expression at this loop before.
9885 for (auto &LS : Values)
9886 if (LS.first == L)
9887 return LS.second ? LS.second : V;
9888
9889 Values.emplace_back(L, nullptr);
9890
9891 // Otherwise compute it.
9892 const SCEV *C = computeSCEVAtScope(V, L);
9893 for (auto &LS : reverse(ValuesAtScopes[V]))
9894 if (LS.first == L) {
9895 LS.second = C;
9896 if (!isa<SCEVConstant>(C))
9897 ValuesAtScopesUsers[C].push_back({L, V});
9898 break;
9899 }
9900 return C;
9901}
9902
9903/// This builds up a Constant using the ConstantExpr interface. That way, we
9904/// will return Constants for objects which aren't represented by a
9905/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9906/// Returns NULL if the SCEV isn't representable as a Constant.
9907static Constant *BuildConstantFromSCEV(const SCEV *V) {
9908 switch (V->getSCEVType()) {
9909 case scCouldNotCompute:
9910 case scAddRecExpr:
9911 case scVScale:
9912 return nullptr;
9913 case scConstant:
9914 return cast<SCEVConstant>(V)->getValue();
9915 case scUnknown:
9916 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9917 case scPtrToInt: {
9918 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9919 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9920 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9921
9922 return nullptr;
9923 }
9924 case scTruncate: {
9925 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9926 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9927 return ConstantExpr::getTrunc(CastOp, ST->getType());
9928 return nullptr;
9929 }
9930 case scAddExpr: {
9931 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9932 Constant *C = nullptr;
9933 for (const SCEV *Op : SA->operands()) {
9934 Constant *OpC = BuildConstantFromSCEV(Op);
9935 if (!OpC)
9936 return nullptr;
9937 if (!C) {
9938 C = OpC;
9939 continue;
9940 }
9941 assert(!C->getType()->isPointerTy() &&
9942 "Can only have one pointer, and it must be last");
9943 if (OpC->getType()->isPointerTy()) {
9944 // The offsets have been converted to bytes. We can add bytes using
9945 // an i8 GEP.
9946 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9947 OpC, C);
9948 } else {
9949 C = ConstantExpr::getAdd(C, OpC);
9950 }
9951 }
9952 return C;
9953 }
9954 case scMulExpr:
9955 case scSignExtend:
9956 case scZeroExtend:
9957 case scUDivExpr:
9958 case scSMaxExpr:
9959 case scUMaxExpr:
9960 case scSMinExpr:
9961 case scUMinExpr:
9962 case scSequentialUMinExpr:
9963 return nullptr;
9964 }
9965 llvm_unreachable("Unknown SCEV kind!");
9966}
9967
9968const SCEV *
9969ScalarEvolution::getWithOperands(const SCEV *S,
9970 SmallVectorImpl<const SCEV *> &NewOps) {
9971 switch (S->getSCEVType()) {
9972 case scTruncate:
9973 case scZeroExtend:
9974 case scSignExtend:
9975 case scPtrToInt:
9976 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9977 case scAddRecExpr: {
9978 auto *AddRec = cast<SCEVAddRecExpr>(S);
9979 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9980 }
9981 case scAddExpr:
9982 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9983 case scMulExpr:
9984 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9985 case scUDivExpr:
9986 return getUDivExpr(NewOps[0], NewOps[1]);
9987 case scUMaxExpr:
9988 case scSMaxExpr:
9989 case scUMinExpr:
9990 case scSMinExpr:
9991 return getMinMaxExpr(S->getSCEVType(), NewOps);
9992 case scSequentialUMinExpr:
9993 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9994 case scConstant:
9995 case scVScale:
9996 case scUnknown:
9997 return S;
9998 case scCouldNotCompute:
9999 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10000 }
10001 llvm_unreachable("Unknown SCEV kind!");
10002}
10003
10004const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10005 switch (V->getSCEVType()) {
10006 case scConstant:
10007 case scVScale:
10008 return V;
10009 case scAddRecExpr: {
10010 // If this is a loop recurrence for a loop that does not contain L, then we
10011 // are dealing with the final value computed by the loop.
10012 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10013 // First, attempt to evaluate each operand.
10014 // Avoid performing the look-up in the common case where the specified
10015 // expression has no loop-variant portions.
10016 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10017 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10018 if (OpAtScope == AddRec->getOperand(i))
10019 continue;
10020
10021 // Okay, at least one of these operands is loop variant but might be
10022 // foldable. Build a new instance of the folded commutative expression.
10023 SmallVector<const SCEV *, 8> NewOps;
10024 NewOps.reserve(AddRec->getNumOperands());
10025 append_range(NewOps, AddRec->operands().take_front(i));
10026 NewOps.push_back(OpAtScope);
10027 for (++i; i != e; ++i)
10028 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10029
10030 const SCEV *FoldedRec = getAddRecExpr(
10031 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10032 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10033 // The addrec may be folded to a nonrecurrence, for example, if the
10034 // induction variable is multiplied by zero after constant folding. Go
10035 // ahead and return the folded value.
10036 if (!AddRec)
10037 return FoldedRec;
10038 break;
10039 }
10040
10041 // If the scope is outside the addrec's loop, evaluate it by using the
10042 // loop exit value of the addrec.
10043 if (!AddRec->getLoop()->contains(L)) {
10044 // To evaluate this recurrence, we need to know how many times the AddRec
10045 // loop iterates. Compute this now.
10046 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10047 if (BackedgeTakenCount == getCouldNotCompute())
10048 return AddRec;
10049
10050 // Then, evaluate the AddRec.
10051 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10052 }
10053
10054 return AddRec;
10055 }
10056 case scTruncate:
10057 case scZeroExtend:
10058 case scSignExtend:
10059 case scPtrToInt:
10060 case scAddExpr:
10061 case scMulExpr:
10062 case scUDivExpr:
10063 case scUMaxExpr:
10064 case scSMaxExpr:
10065 case scUMinExpr:
10066 case scSMinExpr:
10067 case scSequentialUMinExpr: {
10068 ArrayRef<const SCEV *> Ops = V->operands();
10069 // Avoid performing the look-up in the common case where the specified
10070 // expression has no loop-variant portions.
10071 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10072 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10073 if (OpAtScope != Ops[i]) {
10074 // Okay, at least one of these operands is loop variant but might be
10075 // foldable. Build a new instance of the folded commutative expression.
10076 SmallVector<const SCEV *, 8> NewOps;
10077 NewOps.reserve(Ops.size());
10078 append_range(NewOps, Ops.take_front(i));
10079 NewOps.push_back(OpAtScope);
10080
10081 for (++i; i != e; ++i) {
10082 OpAtScope = getSCEVAtScope(Ops[i], L);
10083 NewOps.push_back(OpAtScope);
10084 }
10085
10086 return getWithOperands(V, NewOps);
10087 }
10088 }
10089 // If we got here, all operands are loop invariant.
10090 return V;
10091 }
10092 case scUnknown: {
10093 // If this instruction is evolved from a constant-evolving PHI, compute the
10094 // exit value from the loop without using SCEVs.
10095 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10096 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10097 if (!I)
10098 return V; // This is some other type of SCEVUnknown, just return it.
10099
10100 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10101 const Loop *CurrLoop = this->LI[I->getParent()];
10102 // Looking for loop exit value.
10103 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10104 PN->getParent() == CurrLoop->getHeader()) {
10105 // Okay, there is no closed form solution for the PHI node. Check
10106 // to see if the loop that contains it has a known backedge-taken
10107 // count. If so, we may be able to force computation of the exit
10108 // value.
10109 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10110 // This trivial case can show up in some degenerate cases where
10111 // the incoming IR has not yet been fully simplified.
10112 if (BackedgeTakenCount->isZero()) {
10113 Value *InitValue = nullptr;
10114 bool MultipleInitValues = false;
10115 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10116 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10117 if (!InitValue)
10118 InitValue = PN->getIncomingValue(i);
10119 else if (InitValue != PN->getIncomingValue(i)) {
10120 MultipleInitValues = true;
10121 break;
10122 }
10123 }
10124 }
10125 if (!MultipleInitValues && InitValue)
10126 return getSCEV(InitValue);
10127 }
10128 // Do we have a loop invariant value flowing around the backedge
10129 // for a loop which must execute the backedge?
10130 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10131 isKnownNonZero(BackedgeTakenCount) &&
10132 PN->getNumIncomingValues() == 2) {
10133
10134 unsigned InLoopPred =
10135 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10136 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10137 if (CurrLoop->isLoopInvariant(BackedgeVal))
10138 return getSCEV(BackedgeVal);
10139 }
10140 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10141 // Okay, we know how many times the containing loop executes. If
10142 // this is a constant evolving PHI node, get the final value at
10143 // the specified iteration number.
10144 Constant *RV =
10145 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10146 if (RV)
10147 return getSCEV(RV);
10148 }
10149 }
10150 }
10151
10152 // Okay, this is an expression that we cannot symbolically evaluate
10153 // into a SCEV. Check to see if it's possible to symbolically evaluate
10154 // the arguments into constants, and if so, try to constant propagate the
10155 // result. This is particularly useful for computing loop exit values.
10156 if (!CanConstantFold(I))
10157 return V; // This is some other type of SCEVUnknown, just return it.
10158
10159 SmallVector<Constant *, 4> Operands;
10160 Operands.reserve(I->getNumOperands());
10161 bool MadeImprovement = false;
10162 for (Value *Op : I->operands()) {
10163 if (Constant *C = dyn_cast<Constant>(Op)) {
10164 Operands.push_back(C);
10165 continue;
10166 }
10167
10168 // If any of the operands is non-constant and if they are
10169 // non-integer and non-pointer, don't even try to analyze them
10170 // with scev techniques.
10171 if (!isSCEVable(Op->getType()))
10172 return V;
10173
10174 const SCEV *OrigV = getSCEV(Op);
10175 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10176 MadeImprovement |= OrigV != OpV;
10177
10178 Constant *C = BuildConstantFromSCEV(OpV);
10179 if (!C)
10180 return V;
10181 assert(C->getType() == Op->getType() && "Type mismatch");
10182 Operands.push_back(C);
10183 }
10184
10185 // Check to see if getSCEVAtScope actually made an improvement.
10186 if (!MadeImprovement)
10187 return V; // This is some other type of SCEVUnknown, just return it.
10188
10189 Constant *C = nullptr;
10190 const DataLayout &DL = getDataLayout();
10191 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10192 /*AllowNonDeterministic=*/false);
10193 if (!C)
10194 return V;
10195 return getSCEV(C);
10196 }
10197 case scCouldNotCompute:
10198 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10199 }
10200 llvm_unreachable("Unknown SCEV type!");
10201}
10202
10203const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10204 return getSCEVAtScope(getSCEV(V), L);
10205}
10206
10207const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10208 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10209 return stripInjectiveFunctions(ZExt->getOperand());
10210 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10211 return stripInjectiveFunctions(SExt->getOperand());
10212 return S;
10213}
10214
10215/// Finds the minimum unsigned root of the following equation:
10216///
10217/// A * X = B (mod N)
10218///
10219/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10220/// A and B isn't important.
10221///
10222/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10223static const SCEV *
10224SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10225 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10226
10227 ScalarEvolution &SE) {
10228 uint32_t BW = A.getBitWidth();
10229 assert(BW == SE.getTypeSizeInBits(B->getType()));
10230 assert(A != 0 && "A must be non-zero.");
10231
10232 // 1. D = gcd(A, N)
10233 //
10234 // The gcd of A and N may have only one prime factor: 2. The number of
10235 // trailing zeros in A is its multiplicity
10236 uint32_t Mult2 = A.countr_zero();
10237 // D = 2^Mult2
10238
10239 // 2. Check if B is divisible by D.
10240 //
10241 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10242 // is not less than multiplicity of this prime factor for D.
10243 if (SE.getMinTrailingZeros(B) < Mult2) {
10244 // Check if we can prove there's no remainder using URem.
10245 const SCEV *URem =
10246 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10247 const SCEV *Zero = SE.getZero(B->getType());
10248 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10249 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10250 if (!Predicates)
10251 return SE.getCouldNotCompute();
10252
10253 // Avoid adding a predicate that is known to be false.
10254 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10255 return SE.getCouldNotCompute();
10256 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10257 }
10258 }
10259
10260 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10261 // modulo (N / D).
10262 //
10263 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10264 // (N / D) in general. The inverse itself always fits into BW bits, though,
10265 // so we immediately truncate it.
10266 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10267 APInt I = AD.multiplicativeInverse().zext(BW);
10268
10269 // 4. Compute the minimum unsigned root of the equation:
10270 // I * (B / D) mod (N / D)
10271 // To simplify the computation, we factor out the divide by D:
10272 // (I * B mod N) / D
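 // For illustration (not from the original source): with BW = 4 (N = 16),
 // A = 6 and B = 4, we get Mult2 = 1 and D = 2, and B is divisible by D.
 // Then AD = 3, whose multiplicative inverse modulo 8 is 3, so I = 3 and
 // the result is (3 * 4 mod 16) / 2 = 6. Indeed 6 * 6 = 36 == 4 (mod 16),
 // and 6 is the smallest unsigned solution.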
10273 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10274 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10275}
10276
10277/// For a given quadratic addrec, generate coefficients of the corresponding
10278/// quadratic equation, multiplied by a common value to ensure that they are
10279/// integers.
10280/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10281/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10282/// were multiplied by, and BitWidth is the bit width of the original addrec
10283/// coefficients.
10284/// This function returns std::nullopt if the addrec coefficients are not
10285/// compile-time constants.
10286static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10287GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10288 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10289 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10290 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10291 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10292 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10293 << *AddRec << '\n');
10294
10295 // We currently can only solve this if the coefficients are constants.
10296 if (!LC || !MC || !NC) {
10297 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10298 return std::nullopt;
10299 }
10300
10301 APInt L = LC->getAPInt();
10302 APInt M = MC->getAPInt();
10303 APInt N = NC->getAPInt();
10304 assert(!N.isZero() && "This is not a quadratic addrec");
10305
10306 unsigned BitWidth = LC->getAPInt().getBitWidth();
10307 unsigned NewWidth = BitWidth + 1;
10308 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10309 << BitWidth << '\n');
10310 // The sign-extension (as opposed to a zero-extension) here matches the
10311 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10312 N = N.sext(NewWidth);
10313 M = M.sext(NewWidth);
10314 L = L.sext(NewWidth);
10315
10316 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10317 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10318 // L+M, L+2M+N, L+3M+3N, ...
10319 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10320 //
10321 // The equation Acc = 0 is then
10322 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10323 // In a quadratic form it becomes:
10324 // N n^2 + (2M-N) n + 2L = 0.
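 // For illustration (not from the original source): the addrec {0,+,1,+,2}
 // produces 0, 1, 4, 9, ... (Acc = n^2 after n iterations), and the
 // resulting equation is 2n^2 + 0n + 0 = 0, i.e. Acc = 0 multiplied
 // through by T = 2.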
10325
10326 APInt A = N;
10327 APInt B = 2 * M - A;
10328 APInt C = 2 * L;
10329 APInt T = APInt(NewWidth, 2);
10330 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10331 << "x + " << C << ", coeff bw: " << NewWidth
10332 << ", multiplied by " << T << '\n');
10333 return std::make_tuple(A, B, C, T, BitWidth);
10334}
10335
10336/// Helper function to compare optional APInts:
10337/// (a) if X and Y both exist, return min(X, Y),
10338/// (b) if neither X nor Y exist, return std::nullopt,
10339/// (c) if exactly one of X and Y exists, return that value.
10340static std::optional<APInt> MinOptional(std::optional<APInt> X,
10341 std::optional<APInt> Y) {
10342 if (X && Y) {
10343 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10344 APInt XW = X->sext(W);
10345 APInt YW = Y->sext(W);
10346 return XW.slt(YW) ? *X : *Y;
10347 }
10348 if (!X && !Y)
10349 return std::nullopt;
10350 return X ? *X : *Y;
10351}
10352
10353/// Helper function to truncate an optional APInt to a given BitWidth.
10354/// When solving addrec-related equations, it is preferable to return a value
10355/// that has the same bit width as the original addrec's coefficients. If the
10356/// solution fits in the original bit width, truncate it (except for i1).
10357/// Returning a value of a different bit width may inhibit some optimizations.
10358///
10359/// In general, a solution to a quadratic equation generated from an addrec
10360/// may require BW+1 bits, where BW is the bit width of the addrec's
10361/// coefficients. The reason is that the coefficients of the quadratic
10362/// equation are BW+1 bits wide (to avoid truncation when converting from
10363/// the addrec to the equation).
10364static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10365 unsigned BitWidth) {
10366 if (!X)
10367 return std::nullopt;
10368 unsigned W = X->getBitWidth();
10369 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10370 return X->trunc(BitWidth);
10371 return X;
10372}
10373
10374/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10375/// iterations. The values L, M, N are assumed to be signed, and they
10376/// should all have the same bit widths.
10377/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10378/// where BW is the bit width of the addrec's coefficients.
10379/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10380/// returned as such, otherwise the bit width of the returned value may
10381/// be greater than BW.
10382///
10383/// This function returns std::nullopt if
10384/// (a) the addrec coefficients are not constant, or
10385/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10386/// like x^2 = 5, no integer solutions exist, in other cases an integer
10387/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10388static std::optional<APInt>
10390SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10391 APInt A, B, C, M;
10391 unsigned BitWidth;
10392 auto T = GetQuadraticEquation(AddRec);
10393 if (!T)
10394 return std::nullopt;
10395
10396 std::tie(A, B, C, M, BitWidth) = *T;
10397 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10398 std::optional<APInt> X =
10399 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10400 if (!X)
10401 return std::nullopt;
10402
10403 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10404 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10405 if (!V->isZero())
10406 return std::nullopt;
10407
10408 return TruncIfPossible(X, BitWidth);
10409}
10410
10411/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10412/// iterations. The values M, N are assumed to be signed, and they
10413/// should all have the same bit widths.
10414/// Find the least n such that c(n) does not belong to the given range,
10415/// while c(n-1) does.
10416///
10417/// This function returns std::nullopt if
10418/// (a) the addrec coefficients are not constant, or
10419/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10420/// bounds of the range.
10421static std::optional<APInt>
10422SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10423 const ConstantRange &Range, ScalarEvolution &SE) {
10424 assert(AddRec->getOperand(0)->isZero() &&
10425 "Starting value of addrec should be 0");
10426 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10427 << Range << ", addrec " << *AddRec << '\n');
10428 // This case is handled in getNumIterationsInRange. Here we can assume that
10429 // we start in the range.
10430 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10431 "Addrec's initial value should be in range");
10432
10433 APInt A, B, C, M;
10434 unsigned BitWidth;
10435 auto T = GetQuadraticEquation(AddRec);
10436 if (!T)
10437 return std::nullopt;
10438
10439 // Be careful about the return value: there can be two reasons for not
10440 // returning an actual number. First, if no solutions to the equations
10441 // were found, and second, if the solutions don't leave the given range.
10442 // The first case means that the actual solution is "unknown", the second
10443 // means that it's known, but not valid. If the solution is unknown, we
10444 // cannot make any conclusions.
10445 // Return a pair: the optional solution and a flag indicating if the
10446 // solution was found.
10447 auto SolveForBoundary =
10448 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10449 // Solve for signed overflow and unsigned overflow, pick the lower
10450 // solution.
10451 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10452 << Bound << " (before multiplying by " << M << ")\n");
10453 Bound *= M; // The quadratic equation multiplier.
10454
10455 std::optional<APInt> SO;
10456 if (BitWidth > 1) {
10457 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10458 "signed overflow\n");
10459 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10460 }
10461 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10462 "unsigned overflow\n");
10463 std::optional<APInt> UO =
10464 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10465
10466 auto LeavesRange = [&] (const APInt &X) {
10467 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10468 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10469 if (Range.contains(V0->getValue()))
10470 return false;
10471 // X should be at least 1, so X-1 is non-negative.
10472 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10473 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10474 if (Range.contains(V1->getValue()))
10475 return true;
10476 return false;
10477 };
10478
10479 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10480 // can be a solution, but the function failed to find it. We cannot treat it
10481 // as "no solution".
10482 if (!SO || !UO)
10483 return {std::nullopt, false};
10484
10485 // Check the smaller value first to see if it leaves the range.
10486 // At this point, both SO and UO must have values.
10487 std::optional<APInt> Min = MinOptional(SO, UO);
10488 if (LeavesRange(*Min))
10489 return { Min, true };
10490 std::optional<APInt> Max = Min == SO ? UO : SO;
10491 if (LeavesRange(*Max))
10492 return { Max, true };
10493
10494 // Solutions were found, but were eliminated, hence the "true".
10495 return {std::nullopt, true};
10496 };
10497
10498 std::tie(A, B, C, M, BitWidth) = *T;
10499 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10500 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10501 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10502 auto SL = SolveForBoundary(Lower);
10503 auto SU = SolveForBoundary(Upper);
10504 // If any of the solutions was unknown, no meaningful conclusions can
10505 // be made.
10506 if (!SL.second || !SU.second)
10507 return std::nullopt;
10508
10509 // Claim: The correct solution is not some value between Min and Max.
10510 //
10511 // Justification: Assuming that Min and Max are different values, one of
10512 // them is when the first signed overflow happens, the other is when the
10513 // first unsigned overflow happens. Crossing the range boundary is only
10514 // possible via an overflow (treating 0 as a special case of it, modeling
10515 // an overflow as crossing k*2^W for some k).
10516 //
10517 // The interesting case here is when Min was eliminated as an invalid
10518 // solution, but Max was not. The argument is that if there was another
10519 // overflow between Min and Max, it would also have been eliminated if
10520 // it was considered.
10521 //
10522 // For a given boundary, it is possible to have two overflows of the same
10523 // type (signed/unsigned) without having the other type in between: this
10524 // can happen when the vertex of the parabola is between the iterations
10525 // corresponding to the overflows. This is only possible when the two
10526 // overflows cross k*2^W for the same k. In such case, if the second one
10527 // left the range (and was the first one to do so), the first overflow
10528 // would have to enter the range, which would mean that either we had left
10529 // the range before or that we started outside of it. Both of these cases
10530 // are contradictions.
10531 //
10532 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10533 // solution is not some value between the Max for this boundary and the
10534 // Min of the other boundary.
10535 //
10536 // Justification: Assume that we had such Max_A and Min_B corresponding
10537 // to range boundaries A and B and such that Max_A < Min_B. If there was
10538 // a solution between Max_A and Min_B, it would have to be caused by an
10539 // overflow corresponding to either A or B. It cannot correspond to B,
10540 // since Min_B is the first occurrence of such an overflow. If it
10541 // corresponded to A, it would have to be either a signed or an unsigned
10542 // overflow that is larger than both eliminated overflows for A. But
10543 // between the eliminated overflows and this overflow, the values would
10544 // cover the entire value space, thus crossing the other boundary, which
10545 // is a contradiction.
10546
10547 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10548}
10549
10550ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10551 const Loop *L,
10552 bool ControlsOnlyExit,
10553 bool AllowPredicates) {
10554
10555 // This is only used for loops with a "x != y" exit test. The exit condition
10556 // is now expressed as a single expression, V = x-y. So the exit test is
10557 // effectively V != 0. We know and take advantage of the fact that this
10558 // expression is only used in a comparison-by-zero context.
10559
10560 SmallVector<const SCEVPredicate *> Predicates;
10561 // If the value is a constant
10562 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10563 // If the value is already zero, the branch will execute zero times.
10564 if (C->getValue()->isZero()) return C;
10565 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10566 }
10567
10568 const SCEVAddRecExpr *AddRec =
10569 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10570
10571 if (!AddRec && AllowPredicates)
10572 // Try to make this an AddRec using runtime tests, in the first X
10573 // iterations of this loop, where X is the SCEV expression found by the
10574 // algorithm below.
10575 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10576
10577 if (!AddRec || AddRec->getLoop() != L)
10578 return getCouldNotCompute();
10579
10580 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10581 // the quadratic equation to solve it.
10582 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10583 // We can only use this value if the chrec ends up with an exact zero
10584 // value at this index. When solving for "X*X != 5", for example, we
10585 // should not accept a root of 2.
10586 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10587 const auto *R = cast<SCEVConstant>(getConstant(*S));
10588 return ExitLimit(R, R, R, false, Predicates);
10589 }
10590 return getCouldNotCompute();
10591 }
10592
10593 // Otherwise we can only handle this if it is affine.
10594 if (!AddRec->isAffine())
10595 return getCouldNotCompute();
10596
10597 // If this is an affine expression, the execution count of this branch is
10598 // the minimum unsigned root of the following equation:
10599 //
10600 // Start + Step*N = 0 (mod 2^BW)
10601 //
10602 // equivalent to:
10603 //
10604 // Step*N = -Start (mod 2^BW)
10605 //
10606 // where BW is the common bit width of Start and Step.
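 // For illustration (not from the original source): for {10,+,254} over i8
 // (BW = 8) the equation is 254*N = -10 (mod 256), i.e. -2*N = -10, giving
 // N = 5: starting at 10 and stepping by -2, the value reaches 0 after
 // exactly 5 backedges.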
10607
10608 // Get the initial value for the loop.
10609 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10610 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10611
10612 if (!isLoopInvariant(Step, L))
10613 return getCouldNotCompute();
10614
10615 LoopGuards Guards = LoopGuards::collect(L, *this);
10616 // Specialize step for this loop so we get context sensitive facts below.
10617 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10618
10619 // For positive steps (counting up until unsigned overflow):
10620 // N = -Start/Step (as unsigned)
10621 // For negative steps (counting down to zero):
10622 // N = Start/-Step
10623 // First compute the unsigned distance from zero in the direction of Step.
10624 bool CountDown = isKnownNegative(StepWLG);
10625 if (!CountDown && !isKnownNonNegative(StepWLG))
10626 return getCouldNotCompute();
10627
10628 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10629 // Handle unitary steps, which cannot wraparound.
10630 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10631 // N = Distance (as unsigned)
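 // For illustration (not from the original source): {7,+,-1} with an
 // "i != 0" exit test reaches 0 after exactly Distance = 7 backedges.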
10632
10633 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10634 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10635 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10636
10637 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10638 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10639 // case, and see if we can improve the bound.
10640 //
10641 // Explicitly handling this here is necessary because getUnsignedRange
10642 // isn't context-sensitive; it doesn't know that we only care about the
10643 // range inside the loop.
10644 const SCEV *Zero = getZero(Distance->getType());
10645 const SCEV *One = getOne(Distance->getType());
10646 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10647 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10648 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10649 // as "unsigned_max(Distance + 1) - 1".
10650 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10651 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10652 }
10653 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10654 Predicates);
10655 }
10656
10657 // If the condition controls loop exit (the loop exits only if the expression
10658 // is true) and the addition is no-wrap we can use unsigned divide to
10659 // compute the backedge count. In this case, the step may not divide the
10660 // distance, but we don't care because if the condition is "missed" the loop
10661 // will have undefined behavior due to wrapping.
10662 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10663 loopHasNoAbnormalExits(AddRec->getLoop())) {
10664
10665 // If the stride is zero and the start is non-zero, the loop must be
10666 // infinite. In C++, most loops are finite by assumption, in which case the
10667 // step being zero implies UB must execute if the loop is entered.
10668 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10669 !isKnownNonZero(StepWLG))
10670 return getCouldNotCompute();
10671
10672 const SCEV *Exact =
10673 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10674 const SCEV *ConstantMax = getCouldNotCompute();
10675 if (Exact != getCouldNotCompute()) {
10676 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10677 ConstantMax =
10678 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10679 }
10680 const SCEV *SymbolicMax =
10681 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10682 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10683 }
10684
10685 // Solve the general equation.
10686 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10687 if (!StepC || StepC->getValue()->isZero())
10688 return getCouldNotCompute();
10689 const SCEV *E = SolveLinEquationWithOverflow(
10690 StepC->getAPInt(), getNegativeSCEV(Start),
10691 AllowPredicates ? &Predicates : nullptr, *this);
10692
10693 const SCEV *M = E;
10694 if (E != getCouldNotCompute()) {
10695 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10696 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10697 }
10698 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10699 return ExitLimit(E, M, S, false, Predicates);
10700}
10701
10702ScalarEvolution::ExitLimit
10703ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10704 // Loops that look like: while (X == 0) are very strange indeed. We don't
10705 // handle them yet except for the trivial case. This could be expanded in the
10706 // future as needed.
10707
10708 // If the value is a constant, check to see if it is known to be non-zero
10709 // already. If so, the backedge will execute zero times.
10710 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10711 if (!C->getValue()->isZero())
10712 return getZero(C->getType());
10713 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10714 }
10715
10716 // We could implement others, but I really doubt anyone writes loops like
10717 // this, and if they did, they would already be constant folded.
10718 return getCouldNotCompute();
10719}
10720
10721std::pair<const BasicBlock *, const BasicBlock *>
10722ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10723 const {
10724 // If the block has a unique predecessor, then there is no path from the
10725 // predecessor to the block that does not go through the direct edge
10726 // from the predecessor to the block.
10727 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10728 return {Pred, BB};
10729
10730 // A loop's header is defined to be a block that dominates the loop.
10731 // If the header has a unique predecessor outside the loop, it must be
10732 // a block that has exactly one successor that can reach the loop.
10733 if (const Loop *L = LI.getLoopFor(BB))
10734 return {L->getLoopPredecessor(), L->getHeader()};
10735
10736 return {nullptr, BB};
10737}
10738
10739/// SCEV structural equivalence is usually sufficient for testing whether two
10740/// expressions are equal, however for the purposes of looking for a condition
10741/// guarding a loop, it can be useful to be a little more general, since a
10742/// front-end may have replicated the controlling expression.
10743static bool HasSameValue(const SCEV *A, const SCEV *B) {
10744 // Quick check to see if they are the same SCEV.
10745 if (A == B) return true;
10746
10747 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10748 // Not all instructions that are "identical" compute the same value. For
10749 // instance, two distinct alloca instructions allocating the same type are
10750 // identical and do not read memory; but compute distinct values.
10751 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10752 };
10753
10754 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10755 // two different instructions with the same value. Check for this case.
10756 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10757 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10758 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10759 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10760 if (ComputesEqualValues(AI, BI))
10761 return true;
10762
10763 // Otherwise assume they may have a different value.
10764 return false;
10765}
10766
10767static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10768 const auto *Add = dyn_cast<SCEVAddExpr>(S);
10769 if (!Add || Add->getNumOperands() != 2)
10770 return false;
10771 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10772 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10773 LHS = Add->getOperand(1);
10774 RHS = ME->getOperand(1);
10775 return true;
10776 }
10777 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10778 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10779 LHS = Add->getOperand(0);
10780 RHS = ME->getOperand(1);
10781 return true;
10782 }
10783 return false;
10784}
10785
10786bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10787 const SCEV *&RHS, unsigned Depth) {
10788 bool Changed = false;
10789 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10790 // '0 != 0'.
10791 auto TrivialCase = [&](bool TriviallyTrue) {
10792 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10793 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10794 return true;
10795 };
10796 // If we hit the max recursion limit bail out.
10797 if (Depth >= 3)
10798 return false;
10799
10800 // Canonicalize a constant to the right side.
10801 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10802 // Check for both operands constant.
10803 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10804 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10805 return TrivialCase(false);
10806 return TrivialCase(true);
10807 }
10808 // Otherwise swap the operands to put the constant on the right.
10809 std::swap(LHS, RHS);
10810 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10811 Changed = true;
10812 }
10813
10814 // If we're comparing an addrec with a value which is loop-invariant in the
10815 // addrec's loop, put the addrec on the left. Also make a dominance check,
10816 // as both operands could be addrecs loop-invariant in each other's loop.
10817 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10818 const Loop *L = AR->getLoop();
10819 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10820 std::swap(LHS, RHS);
10821 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10822 Changed = true;
10823 }
10824 }
10825
10826 // If there's a constant operand, canonicalize comparisons with boundary
10827 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10828 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10829 const APInt &RA = RC->getAPInt();
10830
10831 bool SimplifiedByConstantRange = false;
10832
10833 if (!ICmpInst::isEquality(Pred)) {
10834 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10835 if (ExactCR.isFullSet())
10836 return TrivialCase(true);
10837 if (ExactCR.isEmptySet())
10838 return TrivialCase(false);
10839
10840 APInt NewRHS;
10841 CmpInst::Predicate NewPred;
10842 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10843 ICmpInst::isEquality(NewPred)) {
10844 // We were able to convert an inequality to an equality.
10845 Pred = NewPred;
10846 RHS = getConstant(NewRHS);
10847 Changed = SimplifiedByConstantRange = true;
10848 }
10849 }
10850
10851 if (!SimplifiedByConstantRange) {
10852 switch (Pred) {
10853 default:
10854 break;
10855 case ICmpInst::ICMP_EQ:
10856 case ICmpInst::ICMP_NE:
10857 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10858 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10859 Changed = true;
10860 break;
10861
10862 // The "Should have been caught earlier!" messages refer to the fact
10863 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10864 // should have fired on the corresponding cases, and canonicalized the
10865 // check to trivial case.
10866
10867 case ICmpInst::ICMP_UGE:
10868 assert(!RA.isMinValue() && "Should have been caught earlier!");
10869 Pred = ICmpInst::ICMP_UGT;
10870 RHS = getConstant(RA - 1);
10871 Changed = true;
10872 break;
10873 case ICmpInst::ICMP_ULE:
10874 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10875 Pred = ICmpInst::ICMP_ULT;
10876 RHS = getConstant(RA + 1);
10877 Changed = true;
10878 break;
10879 case ICmpInst::ICMP_SGE:
10880 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10881 Pred = ICmpInst::ICMP_SGT;
10882 RHS = getConstant(RA - 1);
10883 Changed = true;
10884 break;
10885 case ICmpInst::ICMP_SLE:
10886 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10887 Pred = ICmpInst::ICMP_SLT;
10888 RHS = getConstant(RA + 1);
10889 Changed = true;
10890 break;
10891 }
10892 }
10893 }
10894
10895 // Check for obvious equality.
10896 if (HasSameValue(LHS, RHS)) {
10897 if (ICmpInst::isTrueWhenEqual(Pred))
10898 return TrivialCase(true);
10899 if (ICmpInst::isFalseWhenEqual(Pred))
10900 return TrivialCase(false);
10901 }
10902
10903 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10904 // adding or subtracting 1 from one of the operands.
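 // For example (not from the original source), "X s<= 5" becomes "X s< 6"
 // when adding 1 to the right-hand side cannot overflow, and "X u>= 7"
 // becomes "X u> 6".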
10905 switch (Pred) {
10906 case ICmpInst::ICMP_SLE:
10907 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10908 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10909 SCEV::FlagNSW);
10910 Pred = ICmpInst::ICMP_SLT;
10911 Changed = true;
10912 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10913 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10914 SCEV::FlagNSW);
10915 Pred = ICmpInst::ICMP_SLT;
10916 Changed = true;
10917 }
10918 break;
10919 case ICmpInst::ICMP_SGE:
10920 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10921 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10922 SCEV::FlagNSW);
10923 Pred = ICmpInst::ICMP_SGT;
10924 Changed = true;
10925 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10926 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10927 SCEV::FlagNSW);
10928 Pred = ICmpInst::ICMP_SGT;
10929 Changed = true;
10930 }
10931 break;
10932 case ICmpInst::ICMP_ULE:
10933 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10934 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10935 SCEV::FlagNUW);
10936 Pred = ICmpInst::ICMP_ULT;
10937 Changed = true;
10938 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10939 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10940 Pred = ICmpInst::ICMP_ULT;
10941 Changed = true;
10942 }
10943 break;
10944 case ICmpInst::ICMP_UGE:
10945 // If RHS is an op we can fold the -1, try that first.
10946 // Otherwise prefer LHS to preserve the nuw flag.
10947 if ((isa<SCEVConstant>(RHS) ||
10948 (isa<SCEVAddExpr>(RHS) &&
10949 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10950 !getUnsignedRangeMin(RHS).isMinValue()) {
10951 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10952 Pred = ICmpInst::ICMP_UGT;
10953 Changed = true;
10954 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10955 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10956 SCEV::FlagNUW);
10957 Pred = ICmpInst::ICMP_UGT;
10958 Changed = true;
10959 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
10960 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10961 Pred = ICmpInst::ICMP_UGT;
10962 Changed = true;
10963 }
10964 break;
10965 default:
10966 break;
10967 }
10968
10969 // TODO: More simplifications are possible here.
10970
10971 // Recursively simplify until we either hit a recursion limit or nothing
10972 // changes.
10973 if (Changed)
10974 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10975
10976 return Changed;
10977}
10978
10979bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10980 return getSignedRangeMax(S).isNegative();
10981}
10982
10983bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10984 return getSignedRangeMin(S).isStrictlyPositive();
10985}
10986
10987bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10988 return !getSignedRangeMin(S).isNegative();
10989}
10990
10991bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10992 return !getSignedRangeMax(S).isStrictlyPositive();
10993}
10994
10995bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10996 // Query push down for cases where the unsigned range is
10997 // less than sufficient.
10998 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10999 return isKnownNonZero(SExt->getOperand(0));
11000 return getUnsignedRangeMin(S) != 0;
11001}
11002
11003bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11004 bool OrNegative) {
11005 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11006 if (auto *C = dyn_cast<SCEVConstant>(S))
11007 return C->getAPInt().isPowerOf2() ||
11008 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11009
11010 // The vscale_range indicates vscale is a power-of-two.
11011 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11012 };
11013
11014 if (NonRecursive(S))
11015 return true;
11016
11017 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11018 if (!Mul)
11019 return false;
11020 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11021}
11022
11023bool ScalarEvolution::isKnownMultipleOf(
11024 const SCEV *S, uint64_t M,
11025 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11026 if (M == 0)
11027 return false;
11028 if (M == 1)
11029 return true;
11030
11031 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11032 // starts with a multiple of M and at every iteration step S only adds
11033 // multiples of M.
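// For example, {8,+,4}<nuw> is a known multiple of 4: its start (8) and its
// step (4) are both multiples of 4, so every value of the recurrence is.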
11034 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11035 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11036 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11037
11038 // For a constant, check that "S % M == 0".
11039 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11040 APInt C = Cst->getAPInt();
11041 return C.urem(M) == 0;
11042 }
11043
11044 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11045
11046 // Basic tests have failed.
11047 // Check "S % M == 0" at compile time and record runtime Assumptions.
11048 auto *STy = dyn_cast<IntegerType>(S->getType());
11049 const SCEV *SmodM =
11050 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11051 const SCEV *Zero = getZero(STy);
11052
11053 // Check whether "S % M == 0" is known at compile time.
11054 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11055 return true;
11056
11057 // Check whether "S % M != 0" is known at compile time.
11058 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11059 return false;
11060
11061 const SCEVPredicate *P = getEqualPredicate(SmodM, Zero);
11062
11063 // Detect redundant predicates.
11064 for (auto *A : Assumptions)
11065 if (A->implies(P, *this))
11066 return true;
11067
11068 // Only record non-redundant predicates.
11069 Assumptions.push_back(P);
11070 return true;
11071}
11072
11073std::pair<const SCEV *, const SCEV *>
11074ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11075 // Compute SCEV on entry of loop L.
11076 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11077 if (Start == getCouldNotCompute())
11078 return { Start, Start };
11079 // Compute post increment SCEV for loop L.
11080 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11081 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11082 return { Start, PostInc };
11083}
11084
11085bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11086 const SCEV *RHS) {
11087 // First collect all loops.
11088 SmallPtrSet<const Loop *, 8> LoopsUsed;
11089 getUsedLoops(LHS, LoopsUsed);
11090 getUsedLoops(RHS, LoopsUsed);
11091
11092 if (LoopsUsed.empty())
11093 return false;
11094
11095 // Domination relationship must be a linear order on collected loops.
11096#ifndef NDEBUG
11097 for (const auto *L1 : LoopsUsed)
11098 for (const auto *L2 : LoopsUsed)
11099 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11100 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11101 "Domination relationship is not a linear order");
11102#endif
11103
11104 const Loop *MDL =
11105 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11106 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11107 });
11108
11109 // Get init and post increment value for LHS.
11110 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11111 // if LHS contains unknown non-invariant SCEV then bail out.
11112 if (SplitLHS.first == getCouldNotCompute())
11113 return false;
11114 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11115 // Get init and post increment value for RHS.
11116 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11117 // if RHS contains unknown non-invariant SCEV then bail out.
11118 if (SplitRHS.first == getCouldNotCompute())
11119 return false;
11120 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11121 // It is possible that init SCEV contains an invariant load but it does
11122 // not dominate MDL and is not available at MDL loop entry, so we should
11123 // check it here.
11124 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11125 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11126 return false;
11127
11128 // The backedge guard check tends to be faster than the entry check, so
11129 // trying it first can short-circuit the whole estimation in some cases.
11130 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11131 SplitRHS.second) &&
11132 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11133}
11134
11135bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11136 const SCEV *RHS) {
11137 // Canonicalize the inputs first.
11138 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11139
11140 if (isKnownViaInduction(Pred, LHS, RHS))
11141 return true;
11142
11143 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11144 return true;
11145
11146 // Otherwise see what can be done with some simple reasoning.
11147 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11148}
11149
11150std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11151 const SCEV *LHS,
11152 const SCEV *RHS) {
11153 if (isKnownPredicate(Pred, LHS, RHS))
11154 return true;
11155 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11156 return false;
11157 return std::nullopt;
11158}
11159
11160bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11161 const SCEV *RHS,
11162 const Instruction *CtxI) {
11163 // TODO: Analyze guards and assumes from Context's block.
11164 return isKnownPredicate(Pred, LHS, RHS) ||
11165 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11166}
11167
11168std::optional<bool>
11169ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11170 const SCEV *RHS, const Instruction *CtxI) {
11171 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11172 if (KnownWithoutContext)
11173 return KnownWithoutContext;
11174
11175 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11176 return true;
11177 if (isBasicBlockEntryGuardedByCond(
11178 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11179 return false;
11180 return std::nullopt;
11181}
11182
11183bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11184 const SCEVAddRecExpr *LHS,
11185 const SCEV *RHS) {
11186 const Loop *L = LHS->getLoop();
11187 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11188 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11189}
11190
11191std::optional<ScalarEvolution::MonotonicPredicateType>
11192ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11193 ICmpInst::Predicate Pred) {
11194 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11195
11196#ifndef NDEBUG
11197 // Verify an invariant: inverting the predicate should turn a monotonically
11198 // increasing change to a monotonically decreasing one, and vice versa.
11199 if (Result) {
11200 auto ResultSwapped =
11201 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11202
11203 assert(*ResultSwapped != *Result &&
11204 "monotonicity should flip as we flip the predicate");
11205 }
11206#endif
11207
11208 return Result;
11209}
11210
11211std::optional<ScalarEvolution::MonotonicPredicateType>
11212ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11213 ICmpInst::Predicate Pred) {
11214 // A zero step value for LHS means the induction variable is essentially a
11215 // loop invariant value. We don't really depend on the predicate actually
11216 // flipping from false to true (for increasing predicates, and the other way
11217 // around for decreasing predicates), all we care about is that *if* the
11218 // predicate changes then it only changes from false to true.
11219 //
11220 // A zero step value in itself is not very useful, but there may be places
11221 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11222 // as general as possible.
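// As an illustration, for LHS = {0,+,1}<nuw> the predicate "{0,+,1} u> N"
// can only change from false to true as the IV grows (monotonically
// increasing), while "{0,+,1} u< N" can only change from true to false
// (monotonically decreasing).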
11223
11224 // Only handle LE/LT/GE/GT predicates.
11225 if (!ICmpInst::isRelational(Pred))
11226 return std::nullopt;
11227
11228 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11229 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11230 "Should be greater or less!");
11231
11232 // Check that AR does not wrap.
11233 if (ICmpInst::isUnsigned(Pred)) {
11234 if (!LHS->hasNoUnsignedWrap())
11235 return std::nullopt;
11236 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11237 }
11238 assert(ICmpInst::isSigned(Pred) &&
11239 "Relational predicate is either signed or unsigned!");
11240 if (!LHS->hasNoSignedWrap())
11241 return std::nullopt;
11242
11243 const SCEV *Step = LHS->getStepRecurrence(*this);
11244
11245 if (isKnownNonNegative(Step))
11246 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11247
11248 if (isKnownNonPositive(Step))
11249 return IsGreater ? MonotonicallyDecreasing : MonotonicallyIncreasing;
11250
11251 return std::nullopt;
11252}
11253
11254std::optional<ScalarEvolution::LoopInvariantPredicate>
11255ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11256 const SCEV *RHS, const Loop *L,
11257 const Instruction *CtxI) {
11258 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11259 if (!isLoopInvariant(RHS, L)) {
11260 if (!isLoopInvariant(LHS, L))
11261 return std::nullopt;
11262
11263 std::swap(LHS, RHS);
11264 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11265 }
11266
11267 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11268 if (!ArLHS || ArLHS->getLoop() != L)
11269 return std::nullopt;
11270
11271 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11272 if (!MonotonicType)
11273 return std::nullopt;
11274 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11275 // true as the loop iterates, and the backedge is control dependent on
11276 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11277 //
11278 // * if the predicate was false in the first iteration then the predicate
11279 // is never evaluated again, since the loop exits without taking the
11280 // backedge.
11281 // * if the predicate was true in the first iteration then it will
11282 // continue to be true for all future iterations since it is
11283 // monotonically increasing.
11284 //
11285 // For both the above possibilities, we can replace the loop varying
11286 // predicate with its value on the first iteration of the loop (which is
11287 // loop invariant).
11288 //
11289 // A similar reasoning applies for a monotonically decreasing predicate, by
11290 // replacing true with false and false with true in the above two bullets.
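// As a concrete instance: take ArLHS = {0,+,1}<nsw>, RHS = N and Pred = s>=.
// "{0,+,1} s>= N" can only switch from false to true; if the backedge is
// taken only while it is true, it is either false on the first iteration
// (and the loop exits at once) or true on every iteration, so it can be
// replaced by the loop-invariant "0 s>= N".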
11291 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11292 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11293
11294 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11295 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11296 RHS);
11297
11298 if (!CtxI)
11299 return std::nullopt;
11300 // Try to prove via context.
11301 // TODO: Support other cases.
11302 switch (Pred) {
11303 default:
11304 break;
11305 case ICmpInst::ICMP_ULE:
11306 case ICmpInst::ICMP_ULT: {
11307 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11308 // Given preconditions
11309 // (1) ArLHS does not cross the border of positive and negative parts of
11310 // range because of:
11311 // - Positive step; (TODO: lift this limitation)
11312 // - nuw - does not cross zero boundary;
11313 // - nsw - does not cross SINT_MAX boundary;
11314 // (2) ArLHS <s RHS
11315 // (3) RHS >=s 0
11316 // we can replace the loop variant ArLHS <u RHS condition with loop
11317 // invariant Start(ArLHS) <u RHS.
11318 //
11319 // Because of (1) there are two options:
11320 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11321 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11322 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11323 // Because of (2) ArLHS <u RHS is trivially true.
11324 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11325 // We can strengthen this to Start(ArLHS) <u RHS.
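// Worked instance: ArLHS = {0,+,1}<nuw><nsw>, RHS = N with N >=s 0 known,
// and the context proves "ArLHS s< N". Then ArLHS never leaves the
// non-negative range, signed and unsigned comparison agree on it, and the
// loop-varying "ArLHS u< N" collapses to the invariant "0 u< N".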
11326 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11327 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11328 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11329 isKnownNonNegative(RHS) &&
11330 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11331 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11332 RHS);
11333 }
11334 }
11335
11336 return std::nullopt;
11337}
11338
11339std::optional<ScalarEvolution::LoopInvariantPredicate>
11340ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11341 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11342 const Instruction *CtxI, const SCEV *MaxIter) {
11343 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11344 Pred, LHS, RHS, L, CtxI, MaxIter))
11345 return LIP;
11346 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11347 // Number of iterations expressed as UMIN isn't always great for expressing
11348 // the value on the last iteration. If the straightforward approach didn't
11349 // work, try the following trick: if a predicate is invariant for X, it
11350 // is also invariant for umin(X, ...). So try to find something that works
11351 // among subexpressions of MaxIter expressed as umin.
11352 for (auto *Op : UMin->operands())
11353 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11354 Pred, LHS, RHS, L, CtxI, Op))
11355 return LIP;
11356 return std::nullopt;
11357}
11358
11359std::optional<ScalarEvolution::LoopInvariantPredicate>
11360ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11361 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11362 const Instruction *CtxI, const SCEV *MaxIter) {
11363 // Try to prove the following set of facts:
11364 // - The predicate is monotonic in the iteration space.
11365 // - If the check does not fail on the 1st iteration:
11366 // - No overflow will happen during first MaxIter iterations;
11367 // - It will not fail on the MaxIter'th iteration.
11368 // If the check does fail on the 1st iteration, we leave the loop and no
11369 // other checks matter.
11370
11371 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11372 if (!isLoopInvariant(RHS, L)) {
11373 if (!isLoopInvariant(LHS, L))
11374 return std::nullopt;
11375
11376 std::swap(LHS, RHS);
11377 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11378 }
11379
11380 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11381 if (!AR || AR->getLoop() != L)
11382 return std::nullopt;
11383
11384 // The predicate must be relational (i.e. <, <=, >=, >).
11385 if (!ICmpInst::isRelational(Pred))
11386 return std::nullopt;
11387
11388 // TODO: Support steps other than +/- 1.
11389 const SCEV *Step = AR->getStepRecurrence(*this);
11390 auto *One = getOne(Step->getType());
11391 auto *MinusOne = getNegativeSCEV(One);
11392 if (Step != One && Step != MinusOne)
11393 return std::nullopt;
11394
11395 // Type mismatch here means that MaxIter is potentially larger than max
11396 // unsigned value in start type, which means we cannot prove no wrap for the
11397 // indvar.
11398 if (AR->getType() != MaxIter->getType())
11399 return std::nullopt;
11400
11401 // Value of IV on suggested last iteration.
11402 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11403 // Does it still meet the requirement?
11404 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11405 return std::nullopt;
11406 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11407 // not exceed max unsigned value of this type), this effectively proves
11408 // that there is no wrap during the iteration. To prove that there is no
11409 // signed/unsigned wrap, we need to check that
11410 // Start <= Last for step = 1 or Start >= Last for step = -1.
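// Numeric illustration: with Step = 1, Start = 10 and MaxIter = 20 (all i8),
// Last evaluates to 30, and establishing "10 <= 30" in the signedness of the
// predicate is exactly what rules out the IV wrapping during the first
// MaxIter iterations.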
11411 ICmpInst::Predicate NoOverflowPred =
11412 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11413 if (Step == MinusOne)
11414 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11415 const SCEV *Start = AR->getStart();
11416 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11417 return std::nullopt;
11418
11419 // Everything is fine.
11420 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11421}
11422
11423bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11424 const SCEV *LHS,
11425 const SCEV *RHS) {
11426 if (HasSameValue(LHS, RHS))
11427 return ICmpInst::isTrueWhenEqual(Pred);
11428
11429 auto CheckRange = [&](bool IsSigned) {
11430 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11431 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11432 return RangeLHS.icmp(Pred, RangeRHS);
11433 };
11434
11435 // The check at the top of the function catches the case where the values are
11436 // known to be equal.
11437 if (Pred == CmpInst::ICMP_EQ)
11438 return false;
11439
11440 if (Pred == CmpInst::ICMP_NE) {
11441 if (CheckRange(true) || CheckRange(false))
11442 return true;
11443 auto *Diff = getMinusSCEV(LHS, RHS);
11444 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11445 }
11446
11447 return CheckRange(CmpInst::isSigned(Pred));
11448}
11449
11450bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11451 const SCEV *LHS,
11452 const SCEV *RHS) {
11453 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11454 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11455 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11456 // OutC1 and OutC2.
11457 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11458 APInt &OutC1, APInt &OutC2,
11459 SCEV::NoWrapFlags ExpectedFlags) {
11460 const SCEV *XNonConstOp, *XConstOp;
11461 const SCEV *YNonConstOp, *YConstOp;
11462 SCEV::NoWrapFlags XFlagsPresent;
11463 SCEV::NoWrapFlags YFlagsPresent;
11464
11465 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11466 XConstOp = getZero(X->getType());
11467 XNonConstOp = X;
11468 XFlagsPresent = ExpectedFlags;
11469 }
11470 if (!isa<SCEVConstant>(XConstOp))
11471 return false;
11472
11473 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11474 YConstOp = getZero(Y->getType());
11475 YNonConstOp = Y;
11476 YFlagsPresent = ExpectedFlags;
11477 }
11478
11479 if (YNonConstOp != XNonConstOp)
11480 return false;
11481
11482 if (!isa<SCEVConstant>(YConstOp))
11483 return false;
11484
11485 // When matching ADDs with NUW flags (and unsigned predicates), only the
11486 // second ADD (with the larger constant) requires NUW.
11487 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11488 return false;
11489 if (ExpectedFlags != SCEV::FlagNUW &&
11490 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11491 return false;
11492 }
11493
11494 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11495 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11496
11497 return true;
11498 };
11499
11500 APInt C1;
11501 APInt C2;
11502
11503 switch (Pred) {
11504 default:
11505 break;
11506
11507 case ICmpInst::ICMP_SGE:
11508 std::swap(LHS, RHS);
11509 [[fallthrough]];
11510 case ICmpInst::ICMP_SLE:
11511 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11512 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11513 return true;
11514
11515 break;
11516
11517 case ICmpInst::ICMP_SGT:
11518 std::swap(LHS, RHS);
11519 [[fallthrough]];
11520 case ICmpInst::ICMP_SLT:
11521 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11522 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11523 return true;
11524
11525 break;
11526
11527 case ICmpInst::ICMP_UGE:
11528 std::swap(LHS, RHS);
11529 [[fallthrough]];
11530 case ICmpInst::ICMP_ULE:
11531 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11532 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11533 return true;
11534
11535 break;
11536
11537 case ICmpInst::ICMP_UGT:
11538 std::swap(LHS, RHS);
11539 [[fallthrough]];
11540 case ICmpInst::ICMP_ULT:
11541 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11542 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11543 return true;
11544 break;
11545 }
11546
11547 return false;
11548}
11549
11550bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11551 const SCEV *LHS,
11552 const SCEV *RHS) {
11553 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11554 return false;
11555
11556 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11557 // on the stack can result in exponential time complexity.
11558 SaveAndRestore Restore(ProvingSplitPredicate, true);
11559
11560 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11561 //
11562 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11563 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11564 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11565 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11566 // use isKnownPredicate later if needed.
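// For example, to show "I u< Len" when Len is known non-negative it is
// enough to show "I s>= 0" and "I s< Len": on [0, Len) the signed and
// unsigned orderings coincide for a non-negative Len.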
11567 return isKnownNonNegative(RHS) &&
11568 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11569 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11570}
11571
11572bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11573 const SCEV *LHS, const SCEV *RHS) {
11574 // No need to even try if we know the module has no guards.
11575 if (!HasGuards)
11576 return false;
11577
11578 return any_of(*BB, [&](const Instruction &I) {
11579 using namespace llvm::PatternMatch;
11580
11581 Value *Condition;
11582 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11583 m_Value(Condition))) &&
11584 isImpliedCond(Pred, LHS, RHS, Condition, false);
11585 });
11586}
11587
11588/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11589/// protected by a conditional between LHS and RHS. This is used to
11590/// eliminate casts.
11591bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11592 CmpPredicate Pred,
11593 const SCEV *LHS,
11594 const SCEV *RHS) {
11595 // Interpret a null as meaning no loop, where there is obviously no guard
11596 // (interprocedural conditions notwithstanding). Do not bother about
11597 // unreachable loops.
11598 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11599 return true;
11600
11601 if (VerifyIR)
11602 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11603 "This cannot be done on broken IR!");
11604
11605
11606 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11607 return true;
11608
11609 BasicBlock *Latch = L->getLoopLatch();
11610 if (!Latch)
11611 return false;
11612
11613 BranchInst *LoopContinuePredicate =
11614 dyn_cast<BranchInst>(Latch->getTerminator());
11615 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11616 isImpliedCond(Pred, LHS, RHS,
11617 LoopContinuePredicate->getCondition(),
11618 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11619 return true;
11620
11621 // We don't want more than one activation of the following loops on the stack
11622 // -- that can lead to O(n!) time complexity.
11623 if (WalkingBEDominatingConds)
11624 return false;
11625
11626 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11627
11628 // See if we can exploit a trip count to prove the predicate.
11629 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11630 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11631 if (LatchBECount != getCouldNotCompute()) {
11632 // We know that Latch branches back to the loop header exactly
11633 // LatchBECount times. This means the backedge condition at Latch is
11634 // equivalent to "{0,+,1} u< LatchBECount".
11635 Type *Ty = LatchBECount->getType();
11636 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11637 const SCEV *LoopCounter =
11638 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11639 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11640 LatchBECount))
11641 return true;
11642 }
11643
11644 // Check conditions due to any @llvm.assume intrinsics.
11645 for (auto &AssumeVH : AC.assumptions()) {
11646 if (!AssumeVH)
11647 continue;
11648 auto *CI = cast<CallInst>(AssumeVH);
11649 if (!DT.dominates(CI, Latch->getTerminator()))
11650 continue;
11651
11652 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11653 return true;
11654 }
11655
11656 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11657 return true;
11658
11659 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11660 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11661 assert(DTN && "should reach the loop header before reaching the root!");
11662
11663 BasicBlock *BB = DTN->getBlock();
11664 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11665 return true;
11666
11667 BasicBlock *PBB = BB->getSinglePredecessor();
11668 if (!PBB)
11669 continue;
11670
11671 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11672 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11673 continue;
11674
11675 Value *Condition = ContinuePredicate->getCondition();
11676
11677 // If we have an edge `E` within the loop body that dominates the only
11678 // latch, the condition guarding `E` also guards the backedge. This
11679 // reasoning works only for loops with a single latch.
11680
11681 BasicBlockEdge DominatingEdge(PBB, BB);
11682 if (DominatingEdge.isSingleEdge()) {
11683 // We're constructively (and conservatively) enumerating edges within the
11684 // loop body that dominate the latch. The dominator tree better agree
11685 // with us on this:
11686 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11687
11688 if (isImpliedCond(Pred, LHS, RHS, Condition,
11689 BB != ContinuePredicate->getSuccessor(0)))
11690 return true;
11691 }
11692 }
11693
11694 return false;
11695}
11696
11697bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11698 CmpPredicate Pred,
11699 const SCEV *LHS,
11700 const SCEV *RHS) {
11701 // Do not bother proving facts for unreachable code.
11702 if (!DT.isReachableFromEntry(BB))
11703 return true;
11704 if (VerifyIR)
11705 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11706 "This cannot be done on broken IR!");
11707
11708 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11709 // the facts (a >= b && a != b) separately. A typical situation is when the
11710 // non-strict comparison is known from ranges and non-equality is known from
11711 // dominating predicates. If we are proving strict comparison, we always try
11712 // to prove non-equality and non-strict comparison separately.
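// For example, "X u< 8" follows once ranges give "X u<= 8" and a dominating
// condition gives "X != 8", even though neither fact alone is sufficient.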
11713 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11714 const bool ProvingStrictComparison =
11715 Pred != NonStrictPredicate.dropSameSign();
11716 bool ProvedNonStrictComparison = false;
11717 bool ProvedNonEquality = false;
11718
11719 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11720 if (!ProvedNonStrictComparison)
11721 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11722 if (!ProvedNonEquality)
11723 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11724 if (ProvedNonStrictComparison && ProvedNonEquality)
11725 return true;
11726 return false;
11727 };
11728
11729 if (ProvingStrictComparison) {
11730 auto ProofFn = [&](CmpPredicate P) {
11731 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11732 };
11733 if (SplitAndProve(ProofFn))
11734 return true;
11735 }
11736
11737 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11738 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11739 const Instruction *CtxI = &BB->front();
11740 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11741 return true;
11742 if (ProvingStrictComparison) {
11743 auto ProofFn = [&](CmpPredicate P) {
11744 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11745 };
11746 if (SplitAndProve(ProofFn))
11747 return true;
11748 }
11749 return false;
11750 };
11751
11752 // Starting at the block's predecessor, climb up the predecessor chain, as long
11753 // as there are predecessors that can be found that have unique successors
11754 // leading to the original block.
11755 const Loop *ContainingLoop = LI.getLoopFor(BB);
11756 const BasicBlock *PredBB;
11757 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11758 PredBB = ContainingLoop->getLoopPredecessor();
11759 else
11760 PredBB = BB->getSinglePredecessor();
11761 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11762 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11763 const BranchInst *BlockEntryPredicate =
11764 dyn_cast<BranchInst>(Pair.first->getTerminator());
11765 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11766 continue;
11767
11768 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11769 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11770 return true;
11771 }
11772
11773 // Check conditions due to any @llvm.assume intrinsics.
11774 for (auto &AssumeVH : AC.assumptions()) {
11775 if (!AssumeVH)
11776 continue;
11777 auto *CI = cast<CallInst>(AssumeVH);
11778 if (!DT.dominates(CI, BB))
11779 continue;
11780
11781 if (ProveViaCond(CI->getArgOperand(0), false))
11782 return true;
11783 }
11784
11785 // Check conditions due to any @llvm.experimental.guard intrinsics.
11786 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11787 F.getParent(), Intrinsic::experimental_guard);
11788 if (GuardDecl)
11789 for (const auto *GU : GuardDecl->users())
11790 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11791 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11792 if (ProveViaCond(Guard->getArgOperand(0), false))
11793 return true;
11794 return false;
11795}
11796
11797bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11798 const SCEV *LHS,
11799 const SCEV *RHS) {
11800 // Interpret a null as meaning no loop, where there is obviously no guard
11801 // (interprocedural conditions notwithstanding).
11802 if (!L)
11803 return false;
11804
11805 // Both LHS and RHS must be available at loop entry.
11806 assert(isAvailableAtLoopEntry(LHS, L) &&
11807 "LHS is not available at Loop Entry");
11808 assert(isAvailableAtLoopEntry(RHS, L) &&
11809 "RHS is not available at Loop Entry");
11810
11811 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11812 return true;
11813
11814 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11815}
11816
11817bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11818 const SCEV *RHS,
11819 const Value *FoundCondValue, bool Inverse,
11820 const Instruction *CtxI) {
11821 // A false condition implies anything. Do not bother analyzing it further.
11822 if (FoundCondValue ==
11823 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11824 return true;
11825
11826 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11827 return false;
11828
11829 auto ClearOnExit =
11830 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11831
11832 // Recursively handle And and Or conditions.
11833 const Value *Op0, *Op1;
11834 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11835 if (!Inverse)
11836 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11837 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11838 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11839 if (Inverse)
11840 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11841 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11842 }
11843
11844 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11845 if (!ICI) return false;
11846
11847 // We found a conditional branch that dominates the loop or controls the
11848 // loop latch. Check to see if it is the comparison we are looking for.
11849 CmpPredicate FoundPred;
11850 if (Inverse)
11851 FoundPred = ICI->getInverseCmpPredicate();
11852 else
11853 FoundPred = ICI->getCmpPredicate();
11854
11855 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11856 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11857
11858 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11859}
11860
11861bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11862 const SCEV *RHS, CmpPredicate FoundPred,
11863 const SCEV *FoundLHS, const SCEV *FoundRHS,
11864 const Instruction *CtxI) {
11865 // Balance the types.
11866 if (getTypeSizeInBits(LHS->getType()) <
11867 getTypeSizeInBits(FoundLHS->getType())) {
11868 // For unsigned and equality predicates, try to prove that both found
11869 // operands fit into narrow unsigned range. If so, try to prove facts in
11870 // narrow types.
11871 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11872 !FoundRHS->getType()->isPointerTy()) {
11873 auto *NarrowType = LHS->getType();
11874 auto *WideType = FoundLHS->getType();
11875 auto BitWidth = getTypeSizeInBits(NarrowType);
11876 const SCEV *MaxValue = getZeroExtendExpr(
11877 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11878 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11879 MaxValue) &&
11880 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11881 MaxValue)) {
11882 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11883 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11884 // We cannot preserve samesign after truncation.
11885 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11886 TruncFoundLHS, TruncFoundRHS, CtxI))
11887 return true;
11888 }
11889 }
11890
11891 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11892 return false;
11893 if (CmpInst::isSigned(Pred)) {
11894 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11895 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11896 } else {
11897 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11898 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11899 }
11900 } else if (getTypeSizeInBits(LHS->getType()) >
11901 getTypeSizeInBits(FoundLHS->getType())) {
11902 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11903 return false;
11904 if (CmpInst::isSigned(FoundPred)) {
11905 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11906 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11907 } else {
11908 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11909 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11910 }
11911 }
11912 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11913 FoundRHS, CtxI);
11914}
11915
11916bool ScalarEvolution::isImpliedCondBalancedTypes(
11917 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11918 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11919 assert(getTypeSizeInBits(LHS->getType()) ==
11920 getTypeSizeInBits(FoundLHS->getType()) &&
11921 "Types should be balanced!");
11922 // Canonicalize the query to match the way instcombine will have
11923 // canonicalized the comparison.
11924 if (SimplifyICmpOperands(Pred, LHS, RHS))
11925 if (LHS == RHS)
11926 return CmpInst::isTrueWhenEqual(Pred);
11927 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11928 if (FoundLHS == FoundRHS)
11929 return CmpInst::isFalseWhenEqual(FoundPred);
11930
11931 // Check to see if we can make the LHS or RHS match.
11932 if (LHS == FoundRHS || RHS == FoundLHS) {
11933 if (isa<SCEVConstant>(RHS)) {
11934 std::swap(FoundLHS, FoundRHS);
11935 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11936 } else {
11937 std::swap(LHS, RHS);
11938 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11939 }
11940 }
11941
11942 // Check whether the found predicate is the same as the desired predicate.
11943 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11944 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11945
11946 // Check whether swapping the found predicate makes it the same as the
11947 // desired predicate.
11948 if (auto P = CmpPredicate::getMatching(
11949 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
11950 // We can write the implication
11951 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11952 // using one of the following ways:
11953 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11954 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11955 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11956 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11957 // Forms 1. and 2. require swapping the operands of one condition. Don't
11958 // do this if it would break canonical constant/addrec ordering.
11959 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11960 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
11961 LHS, FoundLHS, FoundRHS, CtxI);
11962 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11963 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11964
11965 // There's no clear preference between forms 3. and 4., try both. Avoid
11966 // forming getNotSCEV of pointer values as the resulting subtract is
11967 // not legal.
11968 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11969 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
11970 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
11971 FoundRHS, CtxI))
11972 return true;
11973
11974 if (!FoundLHS->getType()->isPointerTy() &&
11975 !FoundRHS->getType()->isPointerTy() &&
11976 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
11977 getNotSCEV(FoundRHS), CtxI))
11978 return true;
11979
11980 return false;
11981 }
11982
11983 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11984 CmpInst::Predicate P2) {
11985 assert(P1 != P2 && "Handled earlier!");
11986 return CmpInst::isRelational(P2) &&
11987 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
11988 };
11989 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11990 // Unsigned comparison is the same as signed comparison when both the
11991 // operands are non-negative or negative.
11992 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11993 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11994 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11995 // Create local copies that we can freely swap and canonicalize our
11996 // conditions to "le/lt".
11997 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11998 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11999 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12000 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12001 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12002 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12003 std::swap(CanonicalLHS, CanonicalRHS);
12004 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12005 }
12006 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12007 "Must be!");
12008 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12009 ICmpInst::isLE(CanonicalFoundPred)) &&
12010 "Must be!");
12011 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12012 // Use implication:
12013 // x <u y && y >=s 0 --> x <s y.
12014 // If we can prove the left part, the right part is also proven.
12015 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12016 CanonicalRHS, CanonicalFoundLHS,
12017 CanonicalFoundRHS);
12018 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12019 // Use implication:
12020 // x <s y && y <s 0 --> x <u y.
12021 // If we can prove the left part, the right part is also proven.
12022 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12023 CanonicalRHS, CanonicalFoundLHS,
12024 CanonicalFoundRHS);
12025 }
12026
12027 // Check if we can make progress by sharpening ranges.
12028 if (FoundPred == ICmpInst::ICMP_NE &&
12029 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12030
12031 const SCEVConstant *C = nullptr;
12032 const SCEV *V = nullptr;
12033
12034 if (isa<SCEVConstant>(FoundLHS)) {
12035 C = cast<SCEVConstant>(FoundLHS);
12036 V = FoundRHS;
12037 } else {
12038 C = cast<SCEVConstant>(FoundRHS);
12039 V = FoundLHS;
12040 }
12041
12042 // The guarding predicate tells us that C != V. If the known range
12043 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12044 // range we consider has to correspond to same signedness as the
12045 // predicate we're interested in folding.
12046
12047 APInt Min = ICmpInst::isSigned(Pred) ?
12048 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12049
12050 if (Min == C->getAPInt()) {
12051 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12052 // This is true even if (Min + 1) wraps around -- in case of
12053 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12054
12055 APInt SharperMin = Min + 1;
12056
12057 switch (Pred) {
12058 case ICmpInst::ICMP_SGE:
12059 case ICmpInst::ICMP_UGE:
12060 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12061 // RHS, we're done.
12062 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12063 CtxI))
12064 return true;
12065 [[fallthrough]];
12066
12067 case ICmpInst::ICMP_SGT:
12068 case ICmpInst::ICMP_UGT:
12069 // We know from the range information that (V `Pred` Min ||
12070 // V == Min). We know from the guarding condition that !(V
12071 // == Min). This gives us
12072 //
12073 // V `Pred` Min || V == Min && !(V == Min)
12074 // => V `Pred` Min
12075 //
12076 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12077
12078 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12079 return true;
12080 break;
12081
12082 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12083 case ICmpInst::ICMP_SLE:
12084 case ICmpInst::ICMP_ULE:
12085 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12086 LHS, V, getConstant(SharperMin), CtxI))
12087 return true;
12088 [[fallthrough]];
12089
12090 case ICmpInst::ICMP_SLT:
12091 case ICmpInst::ICMP_ULT:
12092 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12093 LHS, V, getConstant(Min), CtxI))
12094 return true;
12095 break;
12096
12097 default:
12098 // No change
12099 break;
12100 }
12101 }
12102 }
12103
12104 // Check whether the actual condition is beyond sufficient.
12105 if (FoundPred == ICmpInst::ICMP_EQ)
12106 if (ICmpInst::isTrueWhenEqual(Pred))
12107 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12108 return true;
12109 if (Pred == ICmpInst::ICMP_NE)
12110 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12111 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12112 return true;
12113
12114 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12115 return true;
12116
12117 // Otherwise assume the worst.
12118 return false;
12119}
12120
12121bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12122 const SCEV *&L, const SCEV *&R,
12123 SCEV::NoWrapFlags &Flags) {
12124 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12125 if (!AE || AE->getNumOperands() != 2)
12126 return false;
12127
12128 L = AE->getOperand(0);
12129 R = AE->getOperand(1);
12130 Flags = AE->getNoWrapFlags();
12131 return true;
12132}
12133
12134std::optional<APInt>
12135ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12136 // We avoid subtracting expressions here because this function is usually
12137 // fairly deep in the call stack (i.e. is called many times).
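// Illustrative behaviour: for More = (%x + 7) and Less = %x the result is 7;
// for {5,+,1}<L> and {2,+,1}<L> the addrecs reduce to their starts and the
// result is 3; when no constant difference can be derived the result is
// std::nullopt.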
12138
12139 unsigned BW = getTypeSizeInBits(More->getType());
12140 APInt Diff(BW, 0);
12141 APInt DiffMul(BW, 1);
12142 // Try various simplifications to reduce the difference to a constant. Limit
12143 // the number of allowed simplifications to keep compile-time low.
12144 for (unsigned I = 0; I < 8; ++I) {
12145 if (More == Less)
12146 return Diff;
12147
12148 // Reduce addrecs with identical steps to their start value.
12149 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12150 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12151 const auto *MAR = cast<SCEVAddRecExpr>(More);
12152
12153 if (LAR->getLoop() != MAR->getLoop())
12154 return std::nullopt;
12155
12156 // We look at affine expressions only; not for correctness but to keep
12157 // getStepRecurrence cheap.
12158 if (!LAR->isAffine() || !MAR->isAffine())
12159 return std::nullopt;
12160
12161 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12162 return std::nullopt;
12163
12164 Less = LAR->getStart();
12165 More = MAR->getStart();
12166 continue;
12167 }
12168
12169 // Try to match a common constant multiply.
12170 auto MatchConstMul =
12171 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12172 auto *M = dyn_cast<SCEVMulExpr>(S);
12173 if (!M || M->getNumOperands() != 2 ||
12174 !isa<SCEVConstant>(M->getOperand(0)))
12175 return std::nullopt;
12176 return {
12177 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12178 };
12179 if (auto MatchedMore = MatchConstMul(More)) {
12180 if (auto MatchedLess = MatchConstMul(Less)) {
12181 if (MatchedMore->second == MatchedLess->second) {
12182 More = MatchedMore->first;
12183 Less = MatchedLess->first;
12184 DiffMul *= MatchedMore->second;
12185 continue;
12186 }
12187 }
12188 }
12189
12190 // Try to cancel out common factors in two add expressions.
12191 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12192 auto Add = [&](const SCEV *S, int Mul) {
12193 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12194 if (Mul == 1) {
12195 Diff += C->getAPInt() * DiffMul;
12196 } else {
12197 assert(Mul == -1);
12198 Diff -= C->getAPInt() * DiffMul;
12199 }
12200 } else
12201 Multiplicity[S] += Mul;
12202 };
12203 auto Decompose = [&](const SCEV *S, int Mul) {
12204 if (isa<SCEVAddExpr>(S)) {
12205 for (const SCEV *Op : S->operands())
12206 Add(Op, Mul);
12207 } else
12208 Add(S, Mul);
12209 };
12210 Decompose(More, 1);
12211 Decompose(Less, -1);
12212
12213 // Check whether all the non-constants cancel out, or reduce to new
12214 // More/Less values.
12215 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12216 for (const auto &[S, Mul] : Multiplicity) {
12217 if (Mul == 0)
12218 continue;
12219 if (Mul == 1) {
12220 if (NewMore)
12221 return std::nullopt;
12222 NewMore = S;
12223 } else if (Mul == -1) {
12224 if (NewLess)
12225 return std::nullopt;
12226 NewLess = S;
12227 } else
12228 return std::nullopt;
12229 }
12230
12231 // Values stayed the same, no point in trying further.
12232 if (NewMore == More || NewLess == Less)
12233 return std::nullopt;
12234
12235 More = NewMore;
12236 Less = NewLess;
12237
12238 // Reduced to constant.
12239 if (!More && !Less)
12240 return Diff;
12241
12242 // Left with variable on only one side, bail out.
12243 if (!More || !Less)
12244 return std::nullopt;
12245 }
12246
12247 // Did not reduce to constant.
12248 return std::nullopt;
12249}
12250
12251bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12252 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12253 const SCEV *FoundRHS, const Instruction *CtxI) {
12254 // Try to recognize the following pattern:
12255 //
12256 // FoundRHS = ...
12257 // ...
12258 // loop:
12259 // FoundLHS = {Start,+,W}
12260 // context_bb: // Basic block from the same loop
12261 // known(Pred, FoundLHS, FoundRHS)
12262 //
12263 // If some predicate is known in the context of a loop, it is also known on
12264 // each iteration of this loop, including the first iteration. Therefore, in
12265 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12266 // prove the original pred using this fact.
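// For instance, knowing "{0,+,1} u< FoundRHS" in a block that executes on
// the first iteration gives "0 u< FoundRHS" at loop entry, and that
// loop-invariant fact is what gets fed into isImpliedCondOperands below.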
12267 if (!CtxI)
12268 return false;
12269 const BasicBlock *ContextBB = CtxI->getParent();
12270 // Make sure AR varies in the context block.
12271 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12272 const Loop *L = AR->getLoop();
12273 // Make sure that context belongs to the loop and executes on 1st iteration
12274 // (if it ever executes at all).
12275 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12276 return false;
12277 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12278 return false;
12279 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12280 }
12281
12282 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12283 const Loop *L = AR->getLoop();
12284 // Make sure that context belongs to the loop and executes on 1st iteration
12285 // (if it ever executes at all).
12286 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12287 return false;
12288 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12289 return false;
12290 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12291 }
12292
12293 return false;
12294}
12295
12296bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12297 const SCEV *LHS,
12298 const SCEV *RHS,
12299 const SCEV *FoundLHS,
12300 const SCEV *FoundRHS) {
12301 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12302 return false;
12303
12304 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12305 if (!AddRecLHS)
12306 return false;
12307
12308 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12309 if (!AddRecFoundLHS)
12310 return false;
12311
12312 // We'd like to let SCEV reason about control dependencies, so we constrain
12313 // both the inequalities to be about add recurrences on the same loop. This
12314 // way we can use isLoopEntryGuardedByCond later.
12315
12316 const Loop *L = AddRecFoundLHS->getLoop();
12317 if (L != AddRecLHS->getLoop())
12318 return false;
12319
12320 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12321 //
12322 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12323 // ... (2)
12324 //
12325 // Informal proof for (2), assuming (1) [*]:
12326 //
12327 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12328 //
12329 // Then
12330 //
12331 // FoundLHS s< FoundRHS s< INT_MIN - C
12332 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12333 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12334 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12335 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12336 // <=> FoundLHS + C s< FoundRHS + C
12337 //
12338 // [*]: (1) can be proved by ruling out overflow.
12339 //
12340 // [**]: This can be proved by analyzing all the four possibilities:
12341 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12342 // (A s>= 0, B s>= 0).
12343 //
12344 // Note:
12345 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12346 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12347 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12348 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12349 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12350 // C)".
12351
12352 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12353 if (!LDiff)
12354 return false;
12355 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12356 if (!RDiff || *LDiff != *RDiff)
12357 return false;
12358
12359 if (LDiff->isMinValue())
12360 return true;
12361
12362 APInt FoundRHSLimit;
12363
12364 if (Pred == CmpInst::ICMP_ULT) {
12365 FoundRHSLimit = -(*RDiff);
12366 } else {
12367 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12368 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12369 }
12370
12371 // Try to prove (1) or (2), as needed.
12372 return isAvailableAtLoopEntry(FoundRHS, L) &&
12373 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12374 getConstant(FoundRHSLimit));
12375}
12376
12377bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12378 const SCEV *RHS, const SCEV *FoundLHS,
12379 const SCEV *FoundRHS, unsigned Depth) {
12380 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12381
12382 auto ClearOnExit = make_scope_exit([&]() {
12383 if (LPhi) {
12384 bool Erased = PendingMerges.erase(LPhi);
12385 assert(Erased && "Failed to erase LPhi!");
12386 (void)Erased;
12387 }
12388 if (RPhi) {
12389 bool Erased = PendingMerges.erase(RPhi);
12390 assert(Erased && "Failed to erase RPhi!");
12391 (void)Erased;
12392 }
12393 });
12394
12395 // Find the respective Phis and check that they are not already pending.
12396 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12397 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12398 if (!PendingMerges.insert(Phi).second)
12399 return false;
12400 LPhi = Phi;
12401 }
12402 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12403 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12404 // If we detect a loop of Phi nodes being processed by this method, for
12405 // example:
12406 //
12407 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12408 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12409 //
12410 // we don't want to deal with a case that complex, so return conservative
12411 // answer false.
12412 if (!PendingMerges.insert(Phi).second)
12413 return false;
12414 RPhi = Phi;
12415 }
12416
12417 // If none of LHS, RHS is a Phi, nothing to do here.
12418 if (!LPhi && !RPhi)
12419 return false;
12420
12421 // If there is a SCEVUnknown Phi we are interested in, make it left.
12422 if (!LPhi) {
12423 std::swap(LHS, RHS);
12424 std::swap(FoundLHS, FoundRHS);
12425 std::swap(LPhi, RPhi);
12426 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12427 }
12428
12429 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12430 const BasicBlock *LBB = LPhi->getParent();
12431 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12432
12433 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12434 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12435 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12436 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12437 };
12438
12439 if (RPhi && RPhi->getParent() == LBB) {
12440 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12441 // If we compare two Phis from the same block, and for each entry block
12442 // the predicate is true for incoming values from this block, then the
12443 // predicate is also true for the Phis.
12444 for (const BasicBlock *IncBB : predecessors(LBB)) {
12445 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12446 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12447 if (!ProvedEasily(L, R))
12448 return false;
12449 }
12450 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12451 // Case two: RHS is also a Phi from the same basic block, and it is an
12452 // AddRec. It means that there is a loop which has both AddRec and Unknown
12453 // PHIs; for it we can compare the incoming values of the AddRec from above
12454 // the loop and from the latch with the respective incoming values of LPhi.
12455 // TODO: Generalize to handle loops with many inputs in a header.
12456 if (LPhi->getNumIncomingValues() != 2) return false;
12457
12458 auto *RLoop = RAR->getLoop();
12459 auto *Predecessor = RLoop->getLoopPredecessor();
12460 assert(Predecessor && "Loop with AddRec with no predecessor?");
12461 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12462 if (!ProvedEasily(L1, RAR->getStart()))
12463 return false;
12464 auto *Latch = RLoop->getLoopLatch();
12465 assert(Latch && "Loop with AddRec with no latch?");
12466 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12467 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12468 return false;
12469 } else {
12470 // In all other cases go over inputs of LHS and compare each of them to RHS,
12471 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12472 // At this point RHS is either a non-Phi, or it is a Phi from some block
12473 // different from LBB.
12474 for (const BasicBlock *IncBB : predecessors(LBB)) {
12475 // Check that RHS is available in this block.
12476 if (!dominates(RHS, IncBB))
12477 return false;
12478 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12479 // Make sure L does not refer to a value from a potentially previous
12480 // iteration of a loop.
12481 if (!properlyDominates(L, LBB))
12482 return false;
12483 // Addrecs are considered to properly dominate their loop, so are missed
12484 // by the previous check. Discard any values that have computable
12485 // evolution in this loop.
12486 if (auto *Loop = LI.getLoopFor(LBB))
12487 if (hasComputableLoopEvolution(L, Loop))
12488 return false;
12489 if (!ProvedEasily(L, RHS))
12490 return false;
12491 }
12492 }
12493 return true;
12494}
12495
12496bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12497 const SCEV *LHS,
12498 const SCEV *RHS,
12499 const SCEV *FoundLHS,
12500 const SCEV *FoundRHS) {
12501 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12502 // sure that we are dealing with the same LHS.
12503 if (RHS == FoundRHS) {
12504 std::swap(LHS, RHS);
12505 std::swap(FoundLHS, FoundRHS);
12506 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12507 }
12508 if (LHS != FoundLHS)
12509 return false;
12510
12511 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12512 if (!SUFoundRHS)
12513 return false;
12514
12515 Value *Shiftee, *ShiftValue;
12516
12517 using namespace PatternMatch;
12518 if (match(SUFoundRHS->getValue(),
12519 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12520 auto *ShifteeS = getSCEV(Shiftee);
12521 // Prove one of the following:
12522 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12523 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12524 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12525 // ---> LHS <s RHS
12526 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12527 // ---> LHS <=s RHS
12528 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12529 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12530 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12531 if (isKnownNonNegative(ShifteeS))
12532 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12533 }
12534
12535 return false;
12536}
12537
12538bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12539 const SCEV *RHS,
12540 const SCEV *FoundLHS,
12541 const SCEV *FoundRHS,
12542 const Instruction *CtxI) {
12543 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12544 FoundRHS) ||
12545 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12546 FoundRHS) ||
12547 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12548 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12549 CtxI) ||
12550 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12551}
12552
12553/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12554template <typename MinMaxExprType>
12555static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12556 const SCEV *Candidate) {
12557 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12558 if (!MinMaxExpr)
12559 return false;
12560
12561 return is_contained(MinMaxExpr->operands(), Candidate);
12562}
12563
12564 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12565 CmpPredicate Pred, const SCEV *LHS,
12566 const SCEV *RHS) {
12567 // If both sides are affine addrecs for the same loop, with equal
12568 // steps, and we know the recurrences don't wrap, then we only
12569 // need to check the predicate on the starting values.
12570
12571 if (!ICmpInst::isRelational(Pred))
12572 return false;
12573
12574 const SCEV *LStart, *RStart, *Step;
12575 const Loop *L;
12576 if (!match(LHS,
12577 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12578 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12579 m_SpecificLoop(L))))
12580 return false;
12581 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12582 SCEV::FlagNSW : SCEV::FlagNUW;
12583 const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
12584 const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
12585 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12586 return false;
12587
12588 return SE.isKnownPredicate(Pred, LStart, RStart);
12589}
12590
12591 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12592 /// expression?
12593 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12594 const SCEV *LHS, const SCEV *RHS) {
12595 switch (Pred) {
12596 default:
12597 return false;
12598
12599 case ICmpInst::ICMP_SGE:
12600 std::swap(LHS, RHS);
12601 [[fallthrough]];
12602 case ICmpInst::ICMP_SLE:
12603 return
12604 // min(A, ...) <= A
12605 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12606 // A <= max(A, ...)
12607 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12608
12609 case ICmpInst::ICMP_UGE:
12610 std::swap(LHS, RHS);
12611 [[fallthrough]];
12612 case ICmpInst::ICMP_ULE:
12613 return
12614 // min(A, ...) <= A
12615 // FIXME: what about umin_seq?
12616 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12617 // A <= max(A, ...)
12618 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12619
12620
12621 llvm_unreachable("covered switch fell through?!");
12622}
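// For example (illustrative SCEVs): with LHS = (umin %a, %b) and RHS = %a, the
// ICMP_ULE case holds because %a is one of the umin's operands, so
// umin(%a, %b) <=u %a unconditionally; the UGE/SGE cases reduce to the same
// check after swapping LHS and RHS.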
12623
12624bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12625 const SCEV *RHS,
12626 const SCEV *FoundLHS,
12627 const SCEV *FoundRHS,
12628 unsigned Depth) {
12631 "LHS and RHS have different sizes?");
12632 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12633 getTypeSizeInBits(FoundRHS->getType()) &&
12634 "FoundLHS and FoundRHS have different sizes?");
12635 // We want to avoid hurting the compile time with analysis of too big trees.
12636 if (Depth > MaxSCEVOperationsImplicationDepth)
12637 return false;
12638
12639 // We only want to work with GT comparison so far.
12640 if (ICmpInst::isLT(Pred)) {
12642 std::swap(LHS, RHS);
12643 std::swap(FoundLHS, FoundRHS);
12644 }
12645
12647
12648 // For unsigned, try to reduce it to corresponding signed comparison.
12649 if (P == ICmpInst::ICMP_UGT)
12650 // We can replace unsigned predicate with its signed counterpart if all
12651 // involved values are non-negative.
12652 // TODO: We could have better support for unsigned.
12653 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12654 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12655 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12656 // use this fact to prove that LHS and RHS are non-negative.
12657 const SCEV *MinusOne = getMinusOne(LHS->getType());
12658 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12659 FoundRHS) &&
12660 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12661 FoundRHS))
12662 P = ICmpInst::ICMP_SGT;
12663 }
12664
12665 if (P != ICmpInst::ICMP_SGT)
12666 return false;
12667
12668 auto GetOpFromSExt = [&](const SCEV *S) {
12669 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12670 return Ext->getOperand();
12671 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12672 // the constant in some cases.
12673 return S;
12674 };
12675
12676 // Acquire values from extensions.
12677 auto *OrigLHS = LHS;
12678 auto *OrigFoundLHS = FoundLHS;
12679 LHS = GetOpFromSExt(LHS);
12680 FoundLHS = GetOpFromSExt(FoundLHS);
12681
12682 // Check whether the SGT predicate can be proved trivially or using the found context.
12683 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12684 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12685 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12686 FoundRHS, Depth + 1);
12687 };
12688
12689 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12690 // We want to avoid creation of any new non-constant SCEV. Since we are
12691 // going to compare the operands to RHS, we should be certain that we don't
12692 // need any size extensions for this. So let's decline all cases when the
12693 // sizes of types of LHS and RHS do not match.
12694 // TODO: Maybe try to get RHS from sext to catch more cases?
12695 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12696 return false;
12697
12698 // Should not overflow.
12699 if (!LHSAddExpr->hasNoSignedWrap())
12700 return false;
12701
12702 auto *LL = LHSAddExpr->getOperand(0);
12703 auto *LR = LHSAddExpr->getOperand(1);
12704 auto *MinusOne = getMinusOne(RHS->getType());
12705
12706 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12707 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12708 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12709 };
12710 // Try to prove the following rule:
12711 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12712 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12713 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12714 return true;
12715 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12716 Value *LL, *LR;
12717 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12718
12719 using namespace llvm::PatternMatch;
12720
12721 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12722 // Rules for division.
12723 // We are going to perform some comparisons with Denominator and its
12724 // derivative expressions. In the general case, creating a SCEV for it may
12725 // lead to a complex analysis of the entire graph, and in particular it
12726 // can request trip count recalculation for the same loop. That would be
12727 // cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12728 // this, we only want to create SCEVs that are constants in this section.
12729 // So we bail if Denominator is not a constant.
12730 if (!isa<ConstantInt>(LR))
12731 return false;
12732
12733 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12734
12735 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12736 // then a SCEV for the numerator already exists and matches with FoundLHS.
12737 auto *Numerator = getExistingSCEV(LL);
12738 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12739 return false;
12740
12741 // Make sure that the numerator matches with FoundLHS and the denominator
12742 // is positive.
12743 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12744 return false;
12745
12746 auto *DTy = Denominator->getType();
12747 auto *FRHSTy = FoundRHS->getType();
12748 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12749 // One of types is a pointer and another one is not. We cannot extend
12750 // them properly to a wider type, so let us just reject this case.
12751 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12752 // to avoid this check.
12753 return false;
12754
12755 // Given that:
12756 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12757 auto *WTy = getWiderType(DTy, FRHSTy);
12758 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12759 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12760
12761 // Try to prove the following rule:
12762 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12763 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12764 // divide it by Denominator < 4, we will have at least 1.
12765 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12766 if (isKnownNonPositive(RHS) &&
12767 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12768 return true;
12769
12770 // Try to prove the following rule:
12771 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12772 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12773 // If we divide it by Denominator > 2, then:
12774 // 1. If FoundLHS is negative, then the result is 0.
12775 // 2. If FoundLHS is non-negative, then the result is non-negative.
12776 // Either way, the result is non-negative.
12777 auto *MinusOne = getMinusOne(WTy);
12778 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12779 if (isKnownNegative(RHS) &&
12780 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12781 return true;
12782 }
12783 }
12784
12785 // If our expression contained SCEVUnknown Phis, and we split it down and now
12786 // need to prove something for them, try to prove the predicate for every
12787 // possible incoming value of those Phis.
12788 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12789 return true;
12790
12791 return false;
12792}
12793
12794 bool ScalarEvolution::isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12795 const SCEV *RHS) {
12796 // zext x u<= sext x, sext x s<= zext x
12797 const SCEV *Op;
12798 switch (Pred) {
12799 case ICmpInst::ICMP_SGE:
12800 std::swap(LHS, RHS);
12801 [[fallthrough]];
12802 case ICmpInst::ICMP_SLE: {
12803 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12804 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12805 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12806 }
12807 case ICmpInst::ICMP_UGE:
12808 std::swap(LHS, RHS);
12809 [[fallthrough]];
12810 case ICmpInst::ICMP_ULE: {
12811 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12812 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12813 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12814 }
12815 default:
12816 return false;
12817 };
12818 llvm_unreachable("unhandled case");
12819}
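// Illustrative instance of the idiom: for an i8 value x widened to i32,
// zext(x) reads x as 0..255 while sext(x) reads it as -128..127 with the same
// low bits. If x >= 0 both agree; if x < 0, sext(x) is a very large unsigned
// number but a negative signed one, so zext(x) u<= sext(x) and
// sext(x) s<= zext(x) hold for every x.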
12820
12821bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12822 const SCEV *LHS,
12823 const SCEV *RHS) {
12824 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12825 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12826 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12827 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12828 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12829}
12830
12831bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12832 const SCEV *LHS,
12833 const SCEV *RHS,
12834 const SCEV *FoundLHS,
12835 const SCEV *FoundRHS) {
12836 switch (Pred) {
12837 default:
12838 llvm_unreachable("Unexpected CmpPredicate value!");
12839 case ICmpInst::ICMP_EQ:
12840 case ICmpInst::ICMP_NE:
12841 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12842 return true;
12843 break;
12844 case ICmpInst::ICMP_SLT:
12845 case ICmpInst::ICMP_SLE:
12846 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12847 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12848 return true;
12849 break;
12850 case ICmpInst::ICMP_SGT:
12851 case ICmpInst::ICMP_SGE:
12852 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12853 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12854 return true;
12855 break;
12856 case ICmpInst::ICMP_ULT:
12857 case ICmpInst::ICMP_ULE:
12858 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12859 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12860 return true;
12861 break;
12862 case ICmpInst::ICMP_UGT:
12863 case ICmpInst::ICMP_UGE:
12864 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12865 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12866 return true;
12867 break;
12868 }
12869
12870 // Maybe it can be proved via operations?
12871 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12872 return true;
12873
12874 return false;
12875}
12876
12877bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12878 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12879 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12880 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12881 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12882 // reduce the compile time impact of this optimization.
12883 return false;
12884
12885 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12886 if (!Addend)
12887 return false;
12888
12889 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12890
12891 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12892 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12893 ConstantRange FoundLHSRange =
12894 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12895
12896 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12897 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12898
12899 // We can also compute the range of values for `LHS` that satisfy the
12900 // consequent, "`LHS` `Pred` `RHS`":
12901 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12902 // The antecedent implies the consequent if every value of `LHS` that
12903 // satisfies the antecedent also satisfies the consequent.
12904 return LHSRange.icmp(Pred, ConstRHS);
12905}
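// Worked example (made-up constants): if the antecedent is "FoundLHS u< 8",
// FoundLHSRange is [0, 8), and computeConstantDifference reports
// LHS - FoundLHS = 2, then LHSRange is [2, 10). A consequent such as
// "LHS u< 16" is implied because every value in [2, 10) satisfies it, which is
// what LHSRange.icmp(Pred, ConstRHS) reports.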
12906
12907bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12908 bool IsSigned) {
12909 assert(isKnownPositive(Stride) && "Positive stride expected!");
12910
12911 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12912 const SCEV *One = getOne(Stride->getType());
12913
12914 if (IsSigned) {
12915 APInt MaxRHS = getSignedRangeMax(RHS);
12916 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12917 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12918
12919 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12920 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12921 }
12922
12923 APInt MaxRHS = getUnsignedRangeMax(RHS);
12924 APInt MaxValue = APInt::getMaxValue(BitWidth);
12925 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12926
12927 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12928 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12929}
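// Small illustration with made-up i8 values: if getUnsignedRangeMax(RHS) is
// 250 and the stride may be as large as 10, then MaxStrideMinusOne is 9 and
// 255 - 9 = 246 < 250, so an IV just below RHS (say 249) could step to 259 and
// wrap; the function conservatively reports that overflow is possible.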
12930
12931bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12932 bool IsSigned) {
12933
12934 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12935 const SCEV *One = getOne(Stride->getType());
12936
12937 if (IsSigned) {
12938 APInt MinRHS = getSignedRangeMin(RHS);
12939 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12940 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12941
12942 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12943 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12944 }
12945
12946 APInt MinRHS = getUnsignedRangeMin(RHS);
12947 APInt MinValue = APInt::getMinValue(BitWidth);
12948 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12949
12950 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12951 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12952}
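// The mirror image of canIVOverflowOnLT for decrementing IVs (made-up i8
// values): with getUnsignedRangeMin(RHS) = 5 and a stride of up to 10,
// 0 + 9 = 9 > 5, so an IV just above RHS (say 6) could step below zero and
// wrap, and possible overflow is reported.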
12953
12954 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12955 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12956 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12957 // expression fixes the case of N=0.
12958 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12959 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12960 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12961}
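// For instance (illustrative values): with N = 7 and D = 3, umin(7, 1) = 1 and
// floor((7 - 1) / 3) = 2, giving 3 = ceil(7 / 3); with N = 0 the umin term is
// 0 and the whole expression is 0, avoiding the wrap the naive
// "1 + floor((N - 1) / D)" would hit.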
12962
12963const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12964 const SCEV *Stride,
12965 const SCEV *End,
12966 unsigned BitWidth,
12967 bool IsSigned) {
12968 // The logic in this function assumes we can represent a positive stride.
12969 // If we can't, the backedge-taken count must be zero.
12970 if (IsSigned && BitWidth == 1)
12971 return getZero(Stride->getType());
12972
12973 // The code below has only been closely audited for negative strides in the
12974 // unsigned comparison case; it may be correct for signed comparisons, but
12975 // that needs to be established.
12976 if (IsSigned && isKnownNegative(Stride))
12977 return getCouldNotCompute();
12978
12979 // Calculate the maximum backedge count based on the range of values
12980 // permitted by Start, End, and Stride.
12981 APInt MinStart =
12982 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12983
12984 APInt MinStride =
12985 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12986
12987 // We assume either the stride is positive, or the backedge-taken count
12988 // is zero. So force StrideForMaxBECount to be at least one.
12989 APInt One(BitWidth, 1);
12990 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12991 : APIntOps::umax(One, MinStride);
12992
12993 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12994 : APInt::getMaxValue(BitWidth);
12995 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12996
12997 // Although End can be a MAX expression we estimate MaxEnd considering only
12998 // the case End = RHS of the loop termination condition. This is safe because
12999 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13000 // taken count.
13001 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13002 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13003
13004 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13005 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13006 : APIntOps::umax(MaxEnd, MinStart);
13007
13008 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13009 getConstant(StrideForMaxBECount) /* Step */);
13010}
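// A worked instance with made-up unsigned i8 ranges: MinStart = 10,
// MaxEnd = 100, MinStride = 4. Then StrideForMaxBECount = 4,
// Limit = 255 - 3 = 252 (so MaxEnd is not clamped), and the result is
// ceil((100 - 10) / 4) = 23, an upper bound on the backedge-taken count of
// "for (i = Start; i < End; i += 4)".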
13011
13012 ScalarEvolution::ExitLimit
13013 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13014 const Loop *L, bool IsSigned,
13015 bool ControlsOnlyExit, bool AllowPredicates) {
13016 SmallVector<const SCEVPredicate *> Predicates;
13017
13018 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13019 bool PredicatedIV = false;
13020 if (!IV) {
13021 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13022 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13023 if (AR && AR->getLoop() == L && AR->isAffine()) {
13024 auto canProveNUW = [&]() {
13025 // We can use the comparison to infer no-wrap flags only if it fully
13026 // controls the loop exit.
13027 if (!ControlsOnlyExit)
13028 return false;
13029
13030 if (!isLoopInvariant(RHS, L))
13031 return false;
13032
13033 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13034 // We need the sequence defined by AR to strictly increase in the
13035 // unsigned integer domain for the logic below to hold.
13036 return false;
13037
13038 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13039 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13040 // If RHS <=u Limit, then there must exist a value V in the sequence
13041 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13042 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13043 // overflow occurs. This limit also implies that a signed comparison
13044 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13045 // the high bits on both sides must be zero.
13046 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13047 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13048 Limit = Limit.zext(OuterBitWidth);
13049 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13050 };
13051 auto Flags = AR->getNoWrapFlags();
13052 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13053 Flags = setFlags(Flags, SCEV::FlagNUW);
13054
13055 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13056 if (AR->hasNoUnsignedWrap()) {
13057 // Emulate what getZeroExtendExpr would have done during construction
13058 // if we'd been able to infer the fact just above at that time.
13059 const SCEV *Step = AR->getStepRecurrence(*this);
13060 Type *Ty = ZExt->getType();
13061 auto *S = getAddRecExpr(
13063 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13065 }
13066 }
13067 }
13068 }
13069
13070
13071 if (!IV && AllowPredicates) {
13072 // Try to make this an AddRec using runtime tests, in the first X
13073 // iterations of this loop, where X is the SCEV expression found by the
13074 // algorithm below.
13075 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13076 PredicatedIV = true;
13077 }
13078
13079 // Avoid weird loops
13080 if (!IV || IV->getLoop() != L || !IV->isAffine())
13081 return getCouldNotCompute();
13082
13083 // A precondition of this method is that the condition being analyzed
13084 // reaches an exiting branch which dominates the latch. Given that, we can
13085 // assume that an increment which violates the nowrap specification and
13086 // produces poison must cause undefined behavior when the resulting poison
13087 // value is branched upon and thus we can conclude that the backedge is
13088 // taken no more often than would be required to produce that poison value.
13089 // Note that a well defined loop can exit on the iteration which violates
13090 // the nowrap specification if there is another exit (either explicit or
13091 // implicit/exceptional) which causes the loop to execute before the
13092 // exiting instruction we're analyzing would trigger UB.
13093 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13094 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13095 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13096
13097 const SCEV *Stride = IV->getStepRecurrence(*this);
13098
13099 bool PositiveStride = isKnownPositive(Stride);
13100
13101 // Avoid negative or zero stride values.
13102 if (!PositiveStride) {
13103 // We can compute the correct backedge taken count for loops with unknown
13104 // strides if we can prove that the loop is not an infinite loop with side
13105 // effects. Here's the loop structure we are trying to handle -
13106 //
13107 // i = start
13108 // do {
13109 // A[i] = i;
13110 // i += s;
13111 // } while (i < end);
13112 //
13113 // The backedge taken count for such loops is evaluated as -
13114 // (max(end, start + stride) - start - 1) /u stride
13115 //
13116 // The additional preconditions that we need to check to prove correctness
13117 // of the above formula are as follows -
13118 //
13119 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13120 // NoWrap flag).
13121 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13122 // no side effects within the loop)
13123 // c) loop has a single static exit (with no abnormal exits)
13124 //
13125 // Precondition a) implies that if the stride is negative, this is a single
13126 // trip loop. The backedge taken count formula reduces to zero in this case.
13127 //
13128 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13129 // then a zero stride means the backedge can't be taken without executing
13130 // undefined behavior.
13131 //
13132 // The positive stride case is the same as isKnownPositive(Stride) returning
13133 // true (original behavior of the function).
13134 //
13135 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13137 return getCouldNotCompute();
13138
13139 if (!isKnownNonZero(Stride)) {
13140 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13141 // if it might eventually be greater than start and if so, on which
13142 // iteration. We can't even produce a useful upper bound.
13143 if (!isLoopInvariant(RHS, L))
13144 return getCouldNotCompute();
13145
13146 // We allow a potentially zero stride, but we need to divide by stride
13147 // below. Since the loop can't be infinite and this check must control
13148 // the sole exit, we can infer the exit must be taken on the first
13149 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13150 // we know the numerator in the divides below must be zero, so we can
13151 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13152 // and produce the right result.
13153 // FIXME: Handle the case where Stride is poison?
13154 auto wouldZeroStrideBeUB = [&]() {
13155 // Proof by contradiction. Suppose the stride were zero. If we can
13156 // prove that the backedge *is* taken on the first iteration, then since
13157 // we know this condition controls the sole exit, we must have an
13158 // infinite loop. We can't have a (well defined) infinite loop per
13159 // check just above.
13160 // Note: The (Start - Stride) term is used to get the start' term from
13161 // (start' + stride,+,stride). Remember that we only care about the
13162 // result of this expression when stride == 0 at runtime.
13163 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13164 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13165 };
13166 if (!wouldZeroStrideBeUB()) {
13167 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13168 }
13169 }
13170 } else if (!NoWrap) {
13171 // Avoid proven overflow cases: this will ensure that the backedge taken
13172 // count will not generate any unsigned overflow.
13173 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13174 return getCouldNotCompute();
13175 }
13176
13177 // On all paths just preceding, we established the following invariant:
13178 // IV can be assumed not to overflow up to and including the exiting
13179 // iteration. We proved this in one of two ways:
13180 // 1) We can show overflow doesn't occur before the exiting iteration
13181 // 1a) canIVOverflowOnLT, and b) step of one
13182 // 2) We can show that if overflow occurs, the loop must execute UB
13183 // before any possible exit.
13184 // Note that we have not yet proved RHS invariant (in general).
13185
13186 const SCEV *Start = IV->getStart();
13187
13188 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13189 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13190 // Use integer-typed versions for actual computation; we can't subtract
13191 // pointers in general.
13192 const SCEV *OrigStart = Start;
13193 const SCEV *OrigRHS = RHS;
13194 if (Start->getType()->isPointerTy()) {
13195 Start = getLosslessPtrToIntExpr(Start);
13196 if (isa<SCEVCouldNotCompute>(Start))
13197 return Start;
13198 }
13199 if (RHS->getType()->isPointerTy()) {
13200 RHS = getLosslessPtrToIntExpr(RHS);
13201 if (isa<SCEVCouldNotCompute>(RHS))
13202 return RHS;
13203 }
13204
13205 const SCEV *End = nullptr, *BECount = nullptr,
13206 *BECountIfBackedgeTaken = nullptr;
13207 if (!isLoopInvariant(RHS, L)) {
13208 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13209 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13210 RHSAddRec->getNoWrapFlags()) {
13211 // The structure of loop we are trying to calculate backedge count of:
13212 //
13213 // left = left_start
13214 // right = right_start
13215 //
13216 // while(left < right){
13217 // ... do something here ...
13218 // left += s1; // stride of left is s1 (s1 > 0)
13219 // right += s2; // stride of right is s2 (s2 < 0)
13220 // }
13221 //
13222
13223 const SCEV *RHSStart = RHSAddRec->getStart();
13224 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13225
13226 // If Stride - RHSStride is positive and does not overflow, we can write
13227 // backedge count as ->
13228 // ceil((End - Start) /u (Stride - RHSStride))
13229 // Where, End = max(RHSStart, Start)
13230
13231 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13232 if (isKnownNegative(RHSStride) &&
13233 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13234 RHSStride)) {
13235
13236 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13237 if (isKnownPositive(Denominator)) {
13238 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13239 : getUMaxExpr(RHSStart, Start);
13240
13241 // We can do this because End >= Start, as End = max(RHSStart, Start)
13242 const SCEV *Delta = getMinusSCEV(End, Start);
13243
13244 BECount = getUDivCeilSCEV(Delta, Denominator);
13245 BECountIfBackedgeTaken =
13246 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13247 }
13248 }
13249 }
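// Putting made-up numbers on the sketch above: left_start = 0 with s1 = 2,
// right_start = 10 with s2 = -3. Then Denominator = 2 - (-3) = 5,
// End = max(10, 0) = 10 and BECount = ceil((10 - 0) / 5) = 2: the gap of 10
// shrinks by 5 per iteration, so the backedge is taken twice before
// "left < right" fails.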
13250 if (BECount == nullptr) {
13251 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13252 // given the start, stride and max value for the end bound of the
13253 // loop (RHS), and the fact that IV does not overflow (which is
13254 // checked above).
13255 const SCEV *MaxBECount = computeMaxBECountForLT(
13256 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13257 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13258 MaxBECount, false /*MaxOrZero*/, Predicates);
13259 }
13260 } else {
13261 // We use the expression (max(End,Start)-Start)/Stride to describe the
13262 // backedge count: if the backedge is taken at least once,
13263 // max(End,Start) is End and the result is as above, and if not,
13264 // max(End,Start) is Start and we get a backedge count of zero.
13265 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13266 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13267 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13268 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13269 // Can we prove max(RHS,Start) > Start - Stride?
13270 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13271 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13272 // In this case, we can use a refined formula for computing backedge
13273 // taken count. The general formula remains:
13274 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13275 // We want to use the alternate formula:
13276 // "((End - 1) - (Start - Stride)) /u Stride"
13277 // Let's do a quick case analysis to show these are equivalent under
13278 // our precondition that max(RHS,Start) > Start - Stride.
13279 // * For RHS <= Start, the backedge-taken count must be zero.
13280 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13281 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13282 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13283 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13284 // reducing this to the stride of 1 case.
13285 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13286 // Stride".
13287 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13288 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13289 // "((RHS - (Start - Stride) - 1) /u Stride".
13290 // Our preconditions trivially imply no overflow in that form.
13291 const SCEV *MinusOne = getMinusOne(Stride->getType());
13292 const SCEV *Numerator =
13293 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13294 BECount = getUDivExpr(Numerator, Stride);
13295 }
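// Plugging illustrative numbers into the refined formula: Start = 5,
// Stride = 3, RHS = 12. The guard max(RHS, Start) > Start - Stride = 2 holds,
// and ((12 - 1) - (5 - 3)) /u 3 = 9 /u 3 = 3, matching the three backedges
// taken by "for (i = 5; i < 12; i += 3)".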
13296
13297 if (!BECount) {
13298 auto canProveRHSGreaterThanEqualStart = [&]() {
13299 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13300 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13301 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13302
13303 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13304 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13305 return true;
13306
13307 // (RHS > Start - 1) implies RHS >= Start.
13308 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13309 // "Start - 1" doesn't overflow.
13310 // * For signed comparison, if Start - 1 does overflow, it's equal
13311 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13312 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13313 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13314 //
13315 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13316 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13317 auto *StartMinusOne =
13318 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13319 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13320 };
13321
13322 // If we know that RHS >= Start in the context of loop, then we know
13323 // that max(RHS, Start) = RHS at this point.
13324 if (canProveRHSGreaterThanEqualStart()) {
13325 End = RHS;
13326 } else {
13327 // If RHS < Start, the backedge will be taken zero times. So in
13328 // general, we can write the backedge-taken count as:
13329 //
13330 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13331 //
13332 // We convert it to the following to make it more convenient for SCEV:
13333 //
13334 // ceil(max(RHS, Start) - Start) / Stride
13335 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13336
13337 // See what would happen if we assume the backedge is taken. This is
13338 // used to compute MaxBECount.
13339 BECountIfBackedgeTaken =
13340 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13341 }
13342
13343 // At this point, we know:
13344 //
13345 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13346 // 2. The index variable doesn't overflow.
13347 //
13348 // Therefore, we know N exists such that
13349 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13350 // doesn't overflow.
13351 //
13352 // Using this information, try to prove whether the addition in
13353 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13354 const SCEV *One = getOne(Stride->getType());
13355 bool MayAddOverflow = [&] {
13356 if (isKnownToBeAPowerOfTwo(Stride)) {
13357 // Suppose Stride is a power of two, and Start/End are unsigned
13358 // integers. Let UMAX be the largest representable unsigned
13359 // integer.
13360 //
13361 // By the preconditions of this function, we know
13362 // "(Start + Stride * N) >= End", and this doesn't overflow.
13363 // As a formula:
13364 //
13365 // End <= (Start + Stride * N) <= UMAX
13366 //
13367 // Subtracting Start from all the terms:
13368 //
13369 // End - Start <= Stride * N <= UMAX - Start
13370 //
13371 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13372 //
13373 // End - Start <= Stride * N <= UMAX
13374 //
13375 // Stride * N is a multiple of Stride. Therefore,
13376 //
13377 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13378 //
13379 // Since Stride is a power of two, UMAX + 1 is divisible by
13380 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13381 // write:
13382 //
13383 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13384 //
13385 // Dropping the middle term:
13386 //
13387 // End - Start <= UMAX - (Stride - 1)
13388 //
13389 // Adding Stride - 1 to both sides:
13390 //
13391 // (End - Start) + (Stride - 1) <= UMAX
13392 //
13393 // In other words, the addition doesn't have unsigned overflow.
13394 //
13395 // A similar proof works if we treat Start/End as signed values.
13396 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13397 // to use signed max instead of unsigned max. Note that we're
13398 // trying to prove a lack of unsigned overflow in either case.
13399 return false;
13400 }
13401 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13402 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13403 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13404 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13405 // 1 <s End.
13406 //
13407 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13408 // End.
13409 return false;
13410 }
13411 return true;
13412 }();
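// Sanity-checking the power-of-two argument with concrete i8 values: for
// Stride = 4, Stride * N is a multiple of 4 no larger than 252, so
// End - Start <= 252 and adding Stride - 1 = 3 stays within 255; the addition
// cannot wrap and MayAddOverflow is false.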
13413
13414 const SCEV *Delta = getMinusSCEV(End, Start);
13415 if (!MayAddOverflow) {
13416 // floor((D + (S - 1)) / S)
13417 // We prefer this formulation if it's legal because it's fewer
13418 // operations.
13419 BECount =
13420 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13421 } else {
13422 BECount = getUDivCeilSCEV(Delta, Stride);
13423 }
13424 }
13425 }
13426
13427 const SCEV *ConstantMaxBECount;
13428 bool MaxOrZero = false;
13429 if (isa<SCEVConstant>(BECount)) {
13430 ConstantMaxBECount = BECount;
13431 } else if (BECountIfBackedgeTaken &&
13432 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13433 // If we know exactly how many times the backedge will be taken if it's
13434 // taken at least once, then the backedge count will either be that or
13435 // zero.
13436 ConstantMaxBECount = BECountIfBackedgeTaken;
13437 MaxOrZero = true;
13438 } else {
13439 ConstantMaxBECount = computeMaxBECountForLT(
13440 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13441 }
13442
13443 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13444 !isa<SCEVCouldNotCompute>(BECount))
13445 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13446
13447 const SCEV *SymbolicMaxBECount =
13448 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13449 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13450 Predicates);
13451}
13452
13453ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13454 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13455 bool ControlsOnlyExit, bool AllowPredicates) {
13456 SmallVector<const SCEVPredicate *> Predicates;
13457 // We handle only IV > Invariant
13458 if (!isLoopInvariant(RHS, L))
13459 return getCouldNotCompute();
13460
13461 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13462 if (!IV && AllowPredicates)
13463 // Try to make this an AddRec using runtime tests, in the first X
13464 // iterations of this loop, where X is the SCEV expression found by the
13465 // algorithm below.
13466 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13467
13468 // Avoid weird loops
13469 if (!IV || IV->getLoop() != L || !IV->isAffine())
13470 return getCouldNotCompute();
13471
13472 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13473 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13474 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13475
13476 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13477
13478 // Avoid negative or zero stride values
13479 if (!isKnownPositive(Stride))
13480 return getCouldNotCompute();
13481
13482 // Avoid proven overflow cases: this will ensure that the backedge taken count
13483 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13484 // exploit NoWrapFlags, allowing optimization in the presence of undefined
13485 // behavior, as in the C language.
13486 if (!Stride->isOne() && !NoWrap)
13487 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13488 return getCouldNotCompute();
13489
13490 const SCEV *Start = IV->getStart();
13491 const SCEV *End = RHS;
13492 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13493 // If we know that Start >= RHS in the context of loop, then we know that
13494 // min(RHS, Start) = RHS at this point.
13495 if (isLoopEntryGuardedByCond(
13496 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13497 End = RHS;
13498 else
13499 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13500 }
13501
13502 if (Start->getType()->isPointerTy()) {
13503 Start = getLosslessPtrToIntExpr(Start);
13504 if (isa<SCEVCouldNotCompute>(Start))
13505 return Start;
13506 }
13507 if (End->getType()->isPointerTy()) {
13508 End = getLosslessPtrToIntExpr(End);
13509 if (isa<SCEVCouldNotCompute>(End))
13510 return End;
13511 }
13512
13513 // Compute ((Start - End) + (Stride - 1)) / Stride.
13514 // FIXME: This can overflow. Holding off on fixing this for now;
13515 // howManyGreaterThans will hopefully be gone soon.
13516 const SCEV *One = getOne(Stride->getType());
13517 const SCEV *BECount = getUDivExpr(
13518 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13519
13520 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13521 : getUnsignedRangeMax(Start);
13522
13523 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13524 : getUnsignedRangeMin(Stride);
13525
13526 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13527 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13528 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13529
13530 // Although End can be a MIN expression we estimate MinEnd considering only
13531 // the case End = RHS. This is safe because in the other case (Start - End)
13532 // is zero, leading to a zero maximum backedge taken count.
13533 APInt MinEnd =
13534 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13535 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13536
13537 const SCEV *ConstantMaxBECount =
13538 isa<SCEVConstant>(BECount)
13539 ? BECount
13540 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13541 getConstant(MinStride));
13542
13543 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13544 ConstantMaxBECount = BECount;
13545 const SCEV *SymbolicMaxBECount =
13546 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13547
13548 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13549 Predicates);
13550}
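// Example with made-up values for the formula above: "for (i = 20; i > 5;
// i -= 4)" has Start = 20, End = 5 and Stride = 4 after negation, so
// BECount = ((20 - 5) + 3) /u 4 = 18 /u 4 = 4; the IV visits 20, 16, 12, 8, 4
// and takes the backedge four times before "i > 5" fails.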
13551
13552 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13553 ScalarEvolution &SE) const {
13554 if (Range.isFullSet()) // Infinite loop.
13555 return SE.getCouldNotCompute();
13556
13557 // If the start is a non-zero constant, shift the range to simplify things.
13558 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13559 if (!SC->getValue()->isZero()) {
13560 SmallVector<const SCEV *, 4> Operands(operands());
13561 Operands[0] = SE.getZero(SC->getType());
13562 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13564 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13565 return ShiftedAddRec->getNumIterationsInRange(
13566 Range.subtract(SC->getAPInt()), SE);
13567 // This is strange and shouldn't happen.
13568 return SE.getCouldNotCompute();
13569 }
13570
13571 // The only time we can solve this is when we have all constant indices.
13572 // Otherwise, we cannot determine the overflow conditions.
13573 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13574 return SE.getCouldNotCompute();
13575
13576 // Okay at this point we know that all elements of the chrec are constants and
13577 // that the start element is zero.
13578
13579 // First check to see if the range contains zero. If not, the first
13580 // iteration exits.
13581 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13582 if (!Range.contains(APInt(BitWidth, 0)))
13583 return SE.getZero(getType());
13584
13585 if (isAffine()) {
13586 // If this is an affine expression then we have this situation:
13587 // Solve {0,+,A} in Range === Ax in Range
13588
13589 // We know that zero is in the range. If A is positive then we know that
13590 // the upper value of the range must be the first possible exit value.
13591 // If A is negative then the lower of the range is the last possible loop
13592 // value. Also note that we already checked for a full range.
13593 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13594 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13595
13596 // The exit value should be (End+A)/A.
13597 APInt ExitVal = (End + A).udiv(A);
13598 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13599
13600 // Evaluate at the exit value. If we really did fall out of the valid
13601 // range, then we computed our trip count, otherwise wrap around or other
13602 // things must have happened.
13603 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13604 if (Range.contains(Val->getValue()))
13605 return SE.getCouldNotCompute(); // Something strange happened
13606
13607 // Ensure that the previous value is in the range.
13608 assert(Range.contains(
13609 EvaluateConstantChrecAtConstant(this,
13610 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13611 "Linear scev computation is off in a bad way!");
13612 return SE.getConstant(ExitValue);
13613 }
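// For example (illustrative chrec): {0,+,5} with Range = [0, 17) has A = 5 and
// End = 16, so ExitVal = (16 + 5) /u 5 = 4. Evaluating the chrec at 4 gives
// 20, which is outside the range, while the value at 3 (namely 15) is still
// inside, so the loop leaves the range after 4 iterations.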
13614
13615 if (isQuadratic()) {
13616 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13617 return SE.getConstant(*S);
13618 }
13619
13620 return SE.getCouldNotCompute();
13621}
13622
13623const SCEVAddRecExpr *
13624 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13625 assert(getNumOperands() > 1 && "AddRec with zero step?");
13626 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13627 // but in this case we cannot guarantee that the value returned will be an
13628 // AddRec because SCEV does not have a fixed point where it stops
13629 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13630 // may happen if we reach arithmetic depth limit while simplifying. So we
13631 // construct the returned value explicitly.
13633 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13634 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13635 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13636 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13637 // We know that the last operand is not a constant zero (otherwise it would
13638 // have been popped out earlier). This guarantees us that if the result has
13639 // the same last operand, then it will also not be popped out, meaning that
13640 // the returned value will be an AddRec.
13641 const SCEV *Last = getOperand(getNumOperands() - 1);
13642 assert(!Last->isZero() && "Recurrency with zero step?");
13643 Ops.push_back(Last);
13646}
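// For instance (illustrative recurrence): if this AddRec is {0,+,1,+,1}
// (values 0, 1, 3, 6, ...), its step is {1,+,1} and the value built here is
// {0+1,+,1+1,+,1} = {1,+,2,+,1} (values 1, 3, 6, 10, ...), i.e. the original
// recurrence evaluated one iteration later, as expected for a post-increment
// form.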
13647
13648// Return true when S contains at least an undef value.
13650 return SCEVExprContains(S, [](const SCEV *S) {
13651 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13652 return isa<UndefValue>(SU->getValue());
13653 return false;
13654 });
13655}
13656
13657// Return true when S contains a value that is a nullptr.
13659 return SCEVExprContains(S, [](const SCEV *S) {
13660 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13661 return SU->getValue() == nullptr;
13662 return false;
13663 });
13664}
13665
13666/// Return the size of an element read or written by Inst.
13667 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13668 Type *Ty;
13669 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13670 Ty = Store->getValueOperand()->getType();
13671 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13672 Ty = Load->getType();
13673 else
13674 return nullptr;
13675
13676 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13677 return getSizeOfExpr(ETy, Ty);
13678}
13679
13680//===----------------------------------------------------------------------===//
13681// SCEVCallbackVH Class Implementation
13682//===----------------------------------------------------------------------===//
13683
13685 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13686 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13687 SE->ConstantEvolutionLoopExitValue.erase(PN);
13688 SE->eraseValueFromMap(getValPtr());
13689 // this now dangles!
13690}
13691
13692void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13693 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13694
13695 // Forget all the expressions associated with users of the old value,
13696 // so that future queries will recompute the expressions using the new
13697 // value.
13698 SE->forgetValue(getValPtr());
13699 // this now dangles!
13700}
13701
13702ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13703 : CallbackVH(V), SE(se) {}
13704
13705//===----------------------------------------------------------------------===//
13706// ScalarEvolution Class Implementation
13707//===----------------------------------------------------------------------===//
13708
13709 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13710 AssumptionCache &AC, DominatorTree &DT,
13711 LoopInfo &LI)
13712 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13713 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13714 LoopDispositions(64), BlockDispositions(64) {
13715 // To use guards for proving predicates, we need to scan every instruction in
13716 // relevant basic blocks, and not just terminators. Doing this is a waste of
13717 // time if the IR does not actually contain any calls to
13718 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13719 //
13720 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13721 // to _add_ guards to the module when there weren't any before, and wants
13722 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13723 // efficient in lieu of being smart in that rather obscure case.
13724
13725 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13726 F.getParent(), Intrinsic::experimental_guard);
13727 HasGuards = GuardDecl && !GuardDecl->use_empty();
13728}
13729
13730 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13731 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13732 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13733 ValueExprMap(std::move(Arg.ValueExprMap)),
13734 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13735 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13736 PendingMerges(std::move(Arg.PendingMerges)),
13737 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13738 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13739 PredicatedBackedgeTakenCounts(
13740 std::move(Arg.PredicatedBackedgeTakenCounts)),
13741 BECountUsers(std::move(Arg.BECountUsers)),
13742 ConstantEvolutionLoopExitValue(
13743 std::move(Arg.ConstantEvolutionLoopExitValue)),
13744 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13745 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13746 LoopDispositions(std::move(Arg.LoopDispositions)),
13747 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13748 BlockDispositions(std::move(Arg.BlockDispositions)),
13749 SCEVUsers(std::move(Arg.SCEVUsers)),
13750 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13751 SignedRanges(std::move(Arg.SignedRanges)),
13752 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13753 UniquePreds(std::move(Arg.UniquePreds)),
13754 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13755 LoopUsers(std::move(Arg.LoopUsers)),
13756 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13757 FirstUnknown(Arg.FirstUnknown) {
13758 Arg.FirstUnknown = nullptr;
13759}
13760
13761 ScalarEvolution::~ScalarEvolution() {
13762 // Iterate through all the SCEVUnknown instances and call their
13763 // destructors, so that they release their references to their values.
13764 for (SCEVUnknown *U = FirstUnknown; U;) {
13765 SCEVUnknown *Tmp = U;
13766 U = U->Next;
13767 Tmp->~SCEVUnknown();
13768 }
13769 FirstUnknown = nullptr;
13770
13771 ExprValueMap.clear();
13772 ValueExprMap.clear();
13773 HasRecMap.clear();
13774 BackedgeTakenCounts.clear();
13775 PredicatedBackedgeTakenCounts.clear();
13776
13777 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13778 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13779 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13780 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13781 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13782}
13783
13787
13788/// When printing a top-level SCEV for trip counts, it's helpful to include
13789/// a type for constants which are otherwise hard to disambiguate.
13790static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13791 if (isa<SCEVConstant>(S))
13792 OS << *S->getType() << " ";
13793 OS << *S;
13794}
13795
13796 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13797 const Loop *L) {
13798 // Print all inner loops first
13799 for (Loop *I : *L)
13800 PrintLoopInfo(OS, SE, I);
13801
13802 OS << "Loop ";
13803 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13804 OS << ": ";
13805
13806 SmallVector<BasicBlock *, 8> ExitingBlocks;
13807 L->getExitingBlocks(ExitingBlocks);
13808 if (ExitingBlocks.size() != 1)
13809 OS << "<multiple exits> ";
13810
13811 auto *BTC = SE->getBackedgeTakenCount(L);
13812 if (!isa<SCEVCouldNotCompute>(BTC)) {
13813 OS << "backedge-taken count is ";
13814 PrintSCEVWithTypeHint(OS, BTC);
13815 } else
13816 OS << "Unpredictable backedge-taken count.";
13817 OS << "\n";
13818
13819 if (ExitingBlocks.size() > 1)
13820 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13821 OS << " exit count for " << ExitingBlock->getName() << ": ";
13822 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13823 PrintSCEVWithTypeHint(OS, EC);
13824 if (isa<SCEVCouldNotCompute>(EC)) {
13825 // Retry with predicates.
13827 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13828 if (!isa<SCEVCouldNotCompute>(EC)) {
13829 OS << "\n predicated exit count for " << ExitingBlock->getName()
13830 << ": ";
13831 PrintSCEVWithTypeHint(OS, EC);
13832 OS << "\n Predicates:\n";
13833 for (const auto *P : Predicates)
13834 P->print(OS, 4);
13835 }
13836 }
13837 OS << "\n";
13838 }
13839
13840 OS << "Loop ";
13841 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13842 OS << ": ";
13843
13844 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13845 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13846 OS << "constant max backedge-taken count is ";
13847 PrintSCEVWithTypeHint(OS, ConstantBTC);
13849 OS << ", actual taken count either this or zero.";
13850 } else {
13851 OS << "Unpredictable constant max backedge-taken count. ";
13852 }
13853
13854 OS << "\n"
13855 "Loop ";
13856 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13857 OS << ": ";
13858
13859 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13860 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13861 OS << "symbolic max backedge-taken count is ";
13862 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13864 OS << ", actual taken count either this or zero.";
13865 } else {
13866 OS << "Unpredictable symbolic max backedge-taken count. ";
13867 }
13868 OS << "\n";
13869
13870 if (ExitingBlocks.size() > 1)
13871 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13872 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13873 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13874 ScalarEvolution::SymbolicMaximum);
13875 PrintSCEVWithTypeHint(OS, ExitBTC);
13876 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13877 // Retry with predicates.
13879 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13880 ScalarEvolution::SymbolicMaximum);
13881 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13882 OS << "\n predicated symbolic max exit count for "
13883 << ExitingBlock->getName() << ": ";
13884 PrintSCEVWithTypeHint(OS, ExitBTC);
13885 OS << "\n Predicates:\n";
13886 for (const auto *P : Predicates)
13887 P->print(OS, 4);
13888 }
13889 }
13890 OS << "\n";
13891 }
13892
13894 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13895 if (PBT != BTC) {
13896 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13897 OS << "Loop ";
13898 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13899 OS << ": ";
13900 if (!isa<SCEVCouldNotCompute>(PBT)) {
13901 OS << "Predicated backedge-taken count is ";
13902 PrintSCEVWithTypeHint(OS, PBT);
13903 } else
13904 OS << "Unpredictable predicated backedge-taken count.";
13905 OS << "\n";
13906 OS << " Predicates:\n";
13907 for (const auto *P : Preds)
13908 P->print(OS, 4);
13909 }
13910 Preds.clear();
13911
13912 auto *PredConstantMax =
13913 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13914 if (PredConstantMax != ConstantBTC) {
13915 assert(!Preds.empty() &&
13916 "different predicated constant max BTC but no predicates");
13917 OS << "Loop ";
13918 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13919 OS << ": ";
13920 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13921 OS << "Predicated constant max backedge-taken count is ";
13922 PrintSCEVWithTypeHint(OS, PredConstantMax);
13923 } else
13924 OS << "Unpredictable predicated constant max backedge-taken count.";
13925 OS << "\n";
13926 OS << " Predicates:\n";
13927 for (const auto *P : Preds)
13928 P->print(OS, 4);
13929 }
13930 Preds.clear();
13931
13932 auto *PredSymbolicMax =
13933 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13934 if (SymbolicBTC != PredSymbolicMax) {
13935 assert(!Preds.empty() &&
13936 "Different predicated symbolic max BTC, but no predicates");
13937 OS << "Loop ";
13938 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13939 OS << ": ";
13940 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13941 OS << "Predicated symbolic max backedge-taken count is ";
13942 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13943 } else
13944 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13945 OS << "\n";
13946 OS << " Predicates:\n";
13947 for (const auto *P : Preds)
13948 P->print(OS, 4);
13949 }
13950
13952 OS << "Loop ";
13953 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13954 OS << ": ";
13955 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13956 }
13957}
13958
13959namespace llvm {
13960 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13961 switch (LD) {
13962 case ScalarEvolution::LoopVariant:
13963 OS << "Variant";
13964 break;
13965 case ScalarEvolution::LoopInvariant:
13966 OS << "Invariant";
13967 break;
13968 case ScalarEvolution::LoopComputable:
13969 OS << "Computable";
13970 break;
13971 }
13972 return OS;
13973}
13974
13975 raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13976 switch (BD) {
13977 case ScalarEvolution::DoesNotDominate:
13978 OS << "DoesNotDominate";
13979 break;
13980 case ScalarEvolution::Dominates:
13981 OS << "Dominates";
13982 break;
13983 case ScalarEvolution::ProperlyDominates:
13984 OS << "ProperlyDominates";
13985 break;
13986 }
13987 return OS;
13988}
13989} // namespace llvm
13990
13991 void ScalarEvolution::print(raw_ostream &OS) const {
13992 // ScalarEvolution's implementation of the print method is to print
13993 // out SCEV values of all instructions that are interesting. Doing
13994 // this potentially causes it to create new SCEV objects though,
13995 // which technically conflicts with the const qualifier. This isn't
13996 // observable from outside the class though, so casting away the
13997 // const isn't dangerous.
13998 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13999
14000 if (ClassifyExpressions) {
14001 OS << "Classifying expressions for: ";
14002 F.printAsOperand(OS, /*PrintType=*/false);
14003 OS << "\n";
14004 for (Instruction &I : instructions(F))
14005 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14006 OS << I << '\n';
14007 OS << " --> ";
14008 const SCEV *SV = SE.getSCEV(&I);
14009 SV->print(OS);
14010 if (!isa<SCEVCouldNotCompute>(SV)) {
14011 OS << " U: ";
14012 SE.getUnsignedRange(SV).print(OS);
14013 OS << " S: ";
14014 SE.getSignedRange(SV).print(OS);
14015 }
14016
14017 const Loop *L = LI.getLoopFor(I.getParent());
14018
14019 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14020 if (AtUse != SV) {
14021 OS << " --> ";
14022 AtUse->print(OS);
14023 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14024 OS << " U: ";
14025 SE.getUnsignedRange(AtUse).print(OS);
14026 OS << " S: ";
14027 SE.getSignedRange(AtUse).print(OS);
14028 }
14029 }
14030
14031 if (L) {
14032 OS << "\t\t" "Exits: ";
14033 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14034 if (!SE.isLoopInvariant(ExitValue, L)) {
14035 OS << "<<Unknown>>";
14036 } else {
14037 OS << *ExitValue;
14038 }
14039
14040 bool First = true;
14041 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14042 if (First) {
14043 OS << "\t\t" "LoopDispositions: { ";
14044 First = false;
14045 } else {
14046 OS << ", ";
14047 }
14048
14049 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14050 OS << ": " << SE.getLoopDisposition(SV, Iter);
14051 }
14052
14053 for (const auto *InnerL : depth_first(L)) {
14054 if (InnerL == L)
14055 continue;
14056 if (First) {
14057 OS << "\t\t" "LoopDispositions: { ";
14058 First = false;
14059 } else {
14060 OS << ", ";
14061 }
14062
14063 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14064 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14065 }
14066
14067 OS << " }";
14068 }
14069
14070 OS << "\n";
14071 }
14072 }
14073
14074 OS << "Determining loop execution counts for: ";
14075 F.printAsOperand(OS, /*PrintType=*/false);
14076 OS << "\n";
14077 for (Loop *I : LI)
14078 PrintLoopInfo(OS, &SE, I);
14079}
14080
14081ScalarEvolution::LoopDisposition
14082ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14083 auto &Values = LoopDispositions[S];
14084 for (auto &V : Values) {
14085 if (V.getPointer() == L)
14086 return V.getInt();
14087 }
14088 Values.emplace_back(L, LoopVariant);
14089 LoopDisposition D = computeLoopDisposition(S, L);
14090 auto &Values2 = LoopDispositions[S];
14091 for (auto &V : llvm::reverse(Values2)) {
14092 if (V.getPointer() == L) {
14093 V.setInt(D);
14094 break;
14095 }
14096 }
14097 return D;
14098}
14099
14100ScalarEvolution::LoopDisposition
14101ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14102 switch (S->getSCEVType()) {
14103 case scConstant:
14104 case scVScale:
14105 return LoopInvariant;
14106 case scAddRecExpr: {
14107 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14108
14109 // If L is the addrec's loop, it's computable.
14110 if (AR->getLoop() == L)
14111 return LoopComputable;
14112
14113 // Add recurrences are never invariant in the function-body (null loop).
14114 if (!L)
14115 return LoopVariant;
14116
14117 // Everything that is not defined at loop entry is variant.
14118 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14119 return LoopVariant;
14120 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14121 " dominate the contained loop's header?");
14122
14123 // This recurrence is invariant w.r.t. L if AR's loop contains L.
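 // For example, an addrec defined for an outer loop is invariant with respect
 // to any loop nested inside it: its value does not change while the inner
 // loop iterates.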
14124 if (AR->getLoop()->contains(L))
14125 return LoopInvariant;
14126
14127 // This recurrence is variant w.r.t. L if any of its operands
14128 // are variant.
14129 for (const auto *Op : AR->operands())
14130 if (!isLoopInvariant(Op, L))
14131 return LoopVariant;
14132
14133 // Otherwise it's loop-invariant.
14134 return LoopInvariant;
14135 }
14136 case scTruncate:
14137 case scZeroExtend:
14138 case scSignExtend:
14139 case scPtrToInt:
14140 case scAddExpr:
14141 case scMulExpr:
14142 case scUDivExpr:
14143 case scUMaxExpr:
14144 case scSMaxExpr:
14145 case scUMinExpr:
14146 case scSMinExpr:
14147 case scSequentialUMinExpr: {
14148 bool HasVarying = false;
14149 for (const auto *Op : S->operands()) {
14150 LoopDisposition D = getLoopDisposition(Op, L);
14151 if (D == LoopVariant)
14152 return LoopVariant;
14153 if (D == LoopComputable)
14154 HasVarying = true;
14155 }
14156 return HasVarying ? LoopComputable : LoopInvariant;
14157 }
14158 case scUnknown:
14159 // All non-instruction values are loop invariant. All instructions are loop
14160 // invariant if they are not contained in the specified loop.
14161 // Instructions are never considered invariant in the function body
14162 // (null loop) because they are defined within the "loop".
14163 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14164 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14165 return LoopInvariant;
14166 case scCouldNotCompute:
14167 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14168 }
14169 llvm_unreachable("Unknown SCEV kind!");
14170}
14171
14172bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14173 return getLoopDisposition(S, L) == LoopInvariant;
14174}
14175
14176bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14177 return getLoopDisposition(S, L) == LoopComputable;
14178}
14179
14180ScalarEvolution::BlockDisposition
14181ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14182 auto &Values = BlockDispositions[S];
14183 for (auto &V : Values) {
14184 if (V.getPointer() == BB)
14185 return V.getInt();
14186 }
14187 Values.emplace_back(BB, DoesNotDominateBlock);
14188 BlockDisposition D = computeBlockDisposition(S, BB);
14189 auto &Values2 = BlockDispositions[S];
14190 for (auto &V : llvm::reverse(Values2)) {
14191 if (V.getPointer() == BB) {
14192 V.setInt(D);
14193 break;
14194 }
14195 }
14196 return D;
14197}
14198
14200ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14201 switch (S->getSCEVType()) {
14202 case scConstant:
14203 case scVScale:
14204 return ProperlyDominatesBlock;
14205 case scAddRecExpr: {
14206 // This uses a "dominates" query instead of "properly dominates" query
14207 // to test for proper dominance too, because the instruction which
14208 // produces the addrec's value is a PHI, and a PHI effectively properly
14209 // dominates its entire containing block.
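 // For example, when BB is the addrec's own loop header, "dominates" already
 // holds, and treating the header PHI as properly dominating all of BB is what
 // makes the dominance query below sufficient.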
14210 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14211 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14212 return DoesNotDominateBlock;
14213
14214 // Fall through into SCEVNAryExpr handling.
14215 [[fallthrough]];
14216 }
14217 case scTruncate:
14218 case scZeroExtend:
14219 case scSignExtend:
14220 case scPtrToInt:
14221 case scAddExpr:
14222 case scMulExpr:
14223 case scUDivExpr:
14224 case scUMaxExpr:
14225 case scSMaxExpr:
14226 case scUMinExpr:
14227 case scSMinExpr:
14228 case scSequentialUMinExpr: {
14229 bool Proper = true;
14230 for (const SCEV *NAryOp : S->operands()) {
14231 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14232 if (D == DoesNotDominateBlock)
14233 return DoesNotDominateBlock;
14234 if (D == DominatesBlock)
14235 Proper = false;
14236 }
14237 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14238 }
14239 case scUnknown:
14240 if (Instruction *I =
14241 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14242 if (I->getParent() == BB)
14243 return DominatesBlock;
14244 if (DT.properlyDominates(I->getParent(), BB))
14245 return ProperlyDominatesBlock;
14246 return DoesNotDominateBlock;
14247 }
14248 return ProperlyDominatesBlock;
14249 case scCouldNotCompute:
14250 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14251 }
14252 llvm_unreachable("Unknown SCEV kind!");
14253}
14254
14255bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14256 return getBlockDisposition(S, BB) >= DominatesBlock;
14257}
14258
14259bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14260 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14261}
14262
14263bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14264 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14265}
14266
14267void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14268 bool Predicated) {
14269 auto &BECounts =
14270 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14271 auto It = BECounts.find(L);
14272 if (It != BECounts.end()) {
14273 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14274 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14275 if (!isa<SCEVConstant>(S)) {
14276 auto UserIt = BECountUsers.find(S);
14277 assert(UserIt != BECountUsers.end());
14278 UserIt->second.erase({L, Predicated});
14279 }
14280 }
14281 }
14282 BECounts.erase(It);
14283 }
14284}
14285
14286void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14287 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14288 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14289
14290 while (!Worklist.empty()) {
14291 const SCEV *Curr = Worklist.pop_back_val();
14292 auto Users = SCEVUsers.find(Curr);
14293 if (Users != SCEVUsers.end())
14294 for (const auto *User : Users->second)
14295 if (ToForget.insert(User).second)
14296 Worklist.push_back(User);
14297 }
14298
14299 for (const auto *S : ToForget)
14300 forgetMemoizedResultsImpl(S);
14301
14302 for (auto I = PredicatedSCEVRewrites.begin();
14303 I != PredicatedSCEVRewrites.end();) {
14304 std::pair<const SCEV *, const Loop *> Entry = I->first;
14305 if (ToForget.count(Entry.first))
14306 PredicatedSCEVRewrites.erase(I++);
14307 else
14308 ++I;
14309 }
14310}
14311
14312void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14313 LoopDispositions.erase(S);
14314 BlockDispositions.erase(S);
14315 UnsignedRanges.erase(S);
14316 SignedRanges.erase(S);
14317 HasRecMap.erase(S);
14318 ConstantMultipleCache.erase(S);
14319
14320 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14321 UnsignedWrapViaInductionTried.erase(AR);
14322 SignedWrapViaInductionTried.erase(AR);
14323 }
14324
14325 auto ExprIt = ExprValueMap.find(S);
14326 if (ExprIt != ExprValueMap.end()) {
14327 for (Value *V : ExprIt->second) {
14328 auto ValueIt = ValueExprMap.find_as(V);
14329 if (ValueIt != ValueExprMap.end())
14330 ValueExprMap.erase(ValueIt);
14331 }
14332 ExprValueMap.erase(ExprIt);
14333 }
14334
14335 auto ScopeIt = ValuesAtScopes.find(S);
14336 if (ScopeIt != ValuesAtScopes.end()) {
14337 for (const auto &Pair : ScopeIt->second)
14338 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14339 llvm::erase(ValuesAtScopesUsers[Pair.second],
14340 std::make_pair(Pair.first, S));
14341 ValuesAtScopes.erase(ScopeIt);
14342 }
14343
14344 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14345 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14346 for (const auto &Pair : ScopeUserIt->second)
14347 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14348 ValuesAtScopesUsers.erase(ScopeUserIt);
14349 }
14350
14351 auto BEUsersIt = BECountUsers.find(S);
14352 if (BEUsersIt != BECountUsers.end()) {
14353 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14354 auto Copy = BEUsersIt->second;
14355 for (const auto &Pair : Copy)
14356 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14357 BECountUsers.erase(BEUsersIt);
14358 }
14359
14360 auto FoldUser = FoldCacheUser.find(S);
14361 if (FoldUser != FoldCacheUser.end())
14362 for (auto &KV : FoldUser->second)
14363 FoldCache.erase(KV);
14364 FoldCacheUser.erase(S);
14365}
14366
14367void
14368ScalarEvolution::getUsedLoops(const SCEV *S,
14369 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14370 struct FindUsedLoops {
14371 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14372 : LoopsUsed(LoopsUsed) {}
14373 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14374 bool follow(const SCEV *S) {
14375 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14376 LoopsUsed.insert(AR->getLoop());
14377 return true;
14378 }
14379
14380 bool isDone() const { return false; }
14381 };
14382
14383 FindUsedLoops F(LoopsUsed);
14384 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14385}
14386
14387void ScalarEvolution::getReachableBlocks(
14388 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14389 SmallVector<BasicBlock *> Worklist;
14390 Worklist.push_back(&F.getEntryBlock());
14391 while (!Worklist.empty()) {
14392 BasicBlock *BB = Worklist.pop_back_val();
14393 if (!Reachable.insert(BB).second)
14394 continue;
14395
14396 Value *Cond;
14397 BasicBlock *TrueBB, *FalseBB;
14398 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14399 m_BasicBlock(FalseBB)))) {
14400 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14401 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14402 continue;
14403 }
14404
14405 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14406 const SCEV *L = getSCEV(Cmp->getOperand(0));
14407 const SCEV *R = getSCEV(Cmp->getOperand(1));
14408 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14409 Worklist.push_back(TrueBB);
14410 continue;
14411 }
14412 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14413 R)) {
14414 Worklist.push_back(FalseBB);
14415 continue;
14416 }
14417 }
14418 }
14419
14420 append_range(Worklist, successors(BB));
14421 }
14422}
14423
14424void ScalarEvolution::verify() const {
14425 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14426 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14427
14428 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14429
14430 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14431 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14432 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14433
14434 const SCEV *visitConstant(const SCEVConstant *Constant) {
14435 return SE.getConstant(Constant->getAPInt());
14436 }
14437
14438 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14439 return SE.getUnknown(Expr->getValue());
14440 }
14441
14442 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14443 return SE.getCouldNotCompute();
14444 }
14445 };
14446
14447 SCEVMapper SCM(SE2);
14448 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14449 SE2.getReachableBlocks(ReachableBlocks, F);
14450
14451 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14452 if (containsUndefs(Old) || containsUndefs(New)) {
14453 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14454 // not propagate undef aggressively). This means we can (and do) fail
14455 // verification in cases where a transform makes a value go from "undef"
14456 // to "undef+1" (say). The transform is fine, since in both cases the
14457 // result is "undef", but SCEV thinks the value increased by 1.
14458 return nullptr;
14459 }
14460
14461 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14462 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14463 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14464 return nullptr;
14465
14466 return Delta;
14467 };
14468
14469 while (!LoopStack.empty()) {
14470 auto *L = LoopStack.pop_back_val();
14471 llvm::append_range(LoopStack, *L);
14472
14473 // Only verify BECounts in reachable loops. For an unreachable loop,
14474 // any BECount is legal.
14475 if (!ReachableBlocks.contains(L->getHeader()))
14476 continue;
14477
14478 // Only verify cached BECounts. Computing new BECounts may change the
14479 // results of subsequent SCEV uses.
14480 auto It = BackedgeTakenCounts.find(L);
14481 if (It == BackedgeTakenCounts.end())
14482 continue;
14483
14484 auto *CurBECount =
14485 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14486 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14487
14488 if (CurBECount == SE2.getCouldNotCompute() ||
14489 NewBECount == SE2.getCouldNotCompute()) {
14490 // NB! This situation is legal, but is very suspicious -- whatever pass
14491 // changed the loop to make a trip count go from could not compute to
14492 // computable or vice-versa *should have* invalidated SCEV. However, we
14493 // choose not to assert here (for now) since we don't want false
14494 // positives.
14495 continue;
14496 }
14497
14498 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14499 SE.getTypeSizeInBits(NewBECount->getType()))
14500 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14501 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14502 SE.getTypeSizeInBits(NewBECount->getType()))
14503 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14504
14505 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14506 if (Delta && !Delta->isZero()) {
14507 dbgs() << "Trip Count for " << *L << " Changed!\n";
14508 dbgs() << "Old: " << *CurBECount << "\n";
14509 dbgs() << "New: " << *NewBECount << "\n";
14510 dbgs() << "Delta: " << *Delta << "\n";
14511 std::abort();
14512 }
14513 }
14514
14515 // Collect all valid loops currently in LoopInfo.
14516 SmallPtrSet<Loop *, 32> ValidLoops;
14517 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14518 while (!Worklist.empty()) {
14519 Loop *L = Worklist.pop_back_val();
14520 if (ValidLoops.insert(L).second)
14521 Worklist.append(L->begin(), L->end());
14522 }
14523 for (const auto &KV : ValueExprMap) {
14524#ifndef NDEBUG
14525 // Check for SCEV expressions referencing invalid/deleted loops.
14526 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14527 assert(ValidLoops.contains(AR->getLoop()) &&
14528 "AddRec references invalid loop");
14529 }
14530#endif
14531
14532 // Check that the value is also part of the reverse map.
14533 auto It = ExprValueMap.find(KV.second);
14534 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14535 dbgs() << "Value " << *KV.first
14536 << " is in ValueExprMap but not in ExprValueMap\n";
14537 std::abort();
14538 }
14539
14540 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14541 if (!ReachableBlocks.contains(I->getParent()))
14542 continue;
14543 const SCEV *OldSCEV = SCM.visit(KV.second);
14544 const SCEV *NewSCEV = SE2.getSCEV(I);
14545 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14546 if (Delta && !Delta->isZero()) {
14547 dbgs() << "SCEV for value " << *I << " changed!\n"
14548 << "Old: " << *OldSCEV << "\n"
14549 << "New: " << *NewSCEV << "\n"
14550 << "Delta: " << *Delta << "\n";
14551 std::abort();
14552 }
14553 }
14554 }
14555
14556 for (const auto &KV : ExprValueMap) {
14557 for (Value *V : KV.second) {
14558 const SCEV *S = ValueExprMap.lookup(V);
14559 if (!S) {
14560 dbgs() << "Value " << *V
14561 << " is in ExprValueMap but not in ValueExprMap\n";
14562 std::abort();
14563 }
14564 if (S != KV.first) {
14565 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14566 << *KV.first << "\n";
14567 std::abort();
14568 }
14569 }
14570 }
14571
14572 // Verify integrity of SCEV users.
14573 for (const auto &S : UniqueSCEVs) {
14574 for (const auto *Op : S.operands()) {
14575 // We do not store dependencies of constants.
14576 if (isa<SCEVConstant>(Op))
14577 continue;
14578 auto It = SCEVUsers.find(Op);
14579 if (It != SCEVUsers.end() && It->second.count(&S))
14580 continue;
14581 dbgs() << "Use of operand " << *Op << " by user " << S
14582 << " is not being tracked!\n";
14583 std::abort();
14584 }
14585 }
14586
14587 // Verify integrity of ValuesAtScopes users.
14588 for (const auto &ValueAndVec : ValuesAtScopes) {
14589 const SCEV *Value = ValueAndVec.first;
14590 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14591 const Loop *L = LoopAndValueAtScope.first;
14592 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14593 if (!isa<SCEVConstant>(ValueAtScope)) {
14594 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14595 if (It != ValuesAtScopesUsers.end() &&
14596 is_contained(It->second, std::make_pair(L, Value)))
14597 continue;
14598 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14599 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14600 std::abort();
14601 }
14602 }
14603 }
14604
14605 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14606 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14607 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14608 const Loop *L = LoopAndValue.first;
14609 const SCEV *Value = LoopAndValue.second;
14610 assert(!isa<SCEVConstant>(Value));
14611 auto It = ValuesAtScopes.find(Value);
14612 if (It != ValuesAtScopes.end() &&
14613 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14614 continue;
14615 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14616 << *ValueAtScope << " missing in ValuesAtScopes\n";
14617 std::abort();
14618 }
14619 }
14620
14621 // Verify integrity of BECountUsers.
14622 auto VerifyBECountUsers = [&](bool Predicated) {
14623 auto &BECounts =
14624 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14625 for (const auto &LoopAndBEInfo : BECounts) {
14626 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14627 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14628 if (!isa<SCEVConstant>(S)) {
14629 auto UserIt = BECountUsers.find(S);
14630 if (UserIt != BECountUsers.end() &&
14631 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14632 continue;
14633 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14634 << " missing from BECountUsers\n";
14635 std::abort();
14636 }
14637 }
14638 }
14639 }
14640 };
14641 VerifyBECountUsers(/* Predicated */ false);
14642 VerifyBECountUsers(/* Predicated */ true);
14643
14644 // Verify integrity of the loop disposition cache.
14645 for (auto &[S, Values] : LoopDispositions) {
14646 for (auto [Loop, CachedDisposition] : Values) {
14647 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14648 if (CachedDisposition != RecomputedDisposition) {
14649 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14650 << " is incorrect: cached " << CachedDisposition << ", actual "
14651 << RecomputedDisposition << "\n";
14652 std::abort();
14653 }
14654 }
14655 }
14656
14657 // Verify integrity of the block disposition cache.
14658 for (auto &[S, Values] : BlockDispositions) {
14659 for (auto [BB, CachedDisposition] : Values) {
14660 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14661 if (CachedDisposition != RecomputedDisposition) {
14662 dbgs() << "Cached disposition of " << *S << " for block %"
14663 << BB->getName() << " is incorrect: cached " << CachedDisposition
14664 << ", actual " << RecomputedDisposition << "\n";
14665 std::abort();
14666 }
14667 }
14668 }
14669
14670 // Verify FoldCache/FoldCacheUser caches.
14671 for (auto [FoldID, Expr] : FoldCache) {
14672 auto I = FoldCacheUser.find(Expr);
14673 if (I == FoldCacheUser.end()) {
14674 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14675 << "!\n";
14676 std::abort();
14677 }
14678 if (!is_contained(I->second, FoldID)) {
14679 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14680 std::abort();
14681 }
14682 }
14683 for (auto [Expr, IDs] : FoldCacheUser) {
14684 for (auto &FoldID : IDs) {
14685 const SCEV *S = FoldCache.lookup(FoldID);
14686 if (!S) {
14687 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14688 << "!\n";
14689 std::abort();
14690 }
14691 if (S != Expr) {
14692 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14693 << " != " << *Expr << "!\n";
14694 std::abort();
14695 }
14696 }
14697 }
14698
14699 // Verify that ConstantMultipleCache computations are correct. We check that
14700 // cached multiples and recomputed multiples are multiples of each other to
14701 // verify correctness. It is possible that a recomputed multiple is different
14702 // from the cached multiple due to strengthened no wrap flags or changes in
14703 // KnownBits computations.
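 // For example, a cached multiple of 8 and a recomputed multiple of 16 are
 // accepted because one divides the other; a cached 8 against a recomputed 12
 // would trip the check below.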
14704 for (auto [S, Multiple] : ConstantMultipleCache) {
14705 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14706 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14707 Multiple.urem(RecomputedMultiple) != 0 &&
14708 RecomputedMultiple.urem(Multiple) != 0)) {
14709 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14710 << *S << " : Computed " << RecomputedMultiple
14711 << " but cache contains " << Multiple << "!\n";
14712 std::abort();
14713 }
14714 }
14715}
14716
14717bool ScalarEvolution::invalidate(
14718 Function &F, const PreservedAnalyses &PA,
14719 FunctionAnalysisManager::Invalidator &Inv) {
14720 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14721 // of its dependencies is invalidated.
14722 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14723 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14724 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14725 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14726 Inv.invalidate<LoopAnalysis>(F, PA);
14727}
14728
14729AnalysisKey ScalarEvolutionAnalysis::Key;
14730
14731ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14732 FunctionAnalysisManager &AM) {
14733 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14734 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14735 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14736 auto &LI = AM.getResult<LoopAnalysis>(F);
14737 return ScalarEvolution(F, TLI, AC, DT, LI);
14738}
14739
14745
14746PreservedAnalyses
14747ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14748 // For compatibility with opt's -analyze feature under legacy pass manager
14749 // which was not ported to NPM. This keeps tests using
14750 // update_analyze_test_checks.py working.
14751 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14752 << F.getName() << "':\n";
14753 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14754 return PreservedAnalyses::all();
14755}
14756
14757INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14758 "Scalar Evolution Analysis", false, true)
14759INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14760INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14761INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14762INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14763INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14764 "Scalar Evolution Analysis", false, true)
14765
14766char ScalarEvolutionWrapperPass::ID = 0;
14767
14768ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {}
14769
14770bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14771 SE.reset(new ScalarEvolution(
14772 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14773 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14774 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14775 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14776 return false;
14777}
14778
14779void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14780
14781void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14782 SE->print(OS);
14783}
14784
14785void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14786 if (!VerifySCEV)
14787 return;
14788
14789 SE->verify();
14790}
14791
14799
14800const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14801 const SCEV *RHS) {
14802 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14803}
14804
14805const SCEVPredicate *
14806ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14807 const SCEV *LHS, const SCEV *RHS) {
14808 FoldingSetNodeID ID;
14809 assert(LHS->getType() == RHS->getType() &&
14810 "Type mismatch between LHS and RHS");
14811 // Unique this node based on the arguments
14812 ID.AddInteger(SCEVPredicate::P_Compare);
14813 ID.AddInteger(Pred);
14814 ID.AddPointer(LHS);
14815 ID.AddPointer(RHS);
14816 void *IP = nullptr;
14817 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14818 return S;
14819 SCEVComparePredicate *Eq = new (SCEVAllocator)
14820 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14821 UniquePreds.InsertNode(Eq, IP);
14822 return Eq;
14823}
14824
14825const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14826 const SCEVAddRecExpr *AR,
14827 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14828 FoldingSetNodeID ID;
14829 // Unique this node based on the arguments
14830 ID.AddInteger(SCEVPredicate::P_Wrap);
14831 ID.AddPointer(AR);
14832 ID.AddInteger(AddedFlags);
14833 void *IP = nullptr;
14834 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14835 return S;
14836 auto *OF = new (SCEVAllocator)
14837 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14838 UniquePreds.InsertNode(OF, IP);
14839 return OF;
14840}
14841
14842namespace {
14843
14844class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14845public:
14846
14847 /// Rewrites \p S in the context of a loop L and the SCEV predication
14848 /// infrastructure.
14849 ///
14850 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14851 /// equivalences present in \p Pred.
14852 ///
14853 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14854 /// \p NewPreds such that the result will be an AddRecExpr.
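 /// For example, rewriting zext(%iv) where %iv is an affine AddRec of \p L can
 /// add an IncrementNUSW wrap predicate to \p NewPreds and return the AddRec
 /// with its start and step extended to the wider type (see
 /// visitZeroExtendExpr below).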
14855 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14856 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14857 const SCEVPredicate *Pred) {
14858 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14859 return Rewriter.visit(S);
14860 }
14861
14862 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14863 if (Pred) {
14864 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14865 for (const auto *Pred : U->getPredicates())
14866 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14867 if (IPred->getLHS() == Expr &&
14868 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14869 return IPred->getRHS();
14870 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14871 if (IPred->getLHS() == Expr &&
14872 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14873 return IPred->getRHS();
14874 }
14875 }
14876 return convertToAddRecWithPreds(Expr);
14877 }
14878
14879 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14880 const SCEV *Operand = visit(Expr->getOperand());
14881 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14882 if (AR && AR->getLoop() == L && AR->isAffine()) {
14883 // This couldn't be folded because the operand didn't have the nuw
14884 // flag. Add the nusw flag as an assumption that we could make.
14885 const SCEV *Step = AR->getStepRecurrence(SE);
14886 Type *Ty = Expr->getType();
14887 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14888 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14889 SE.getSignExtendExpr(Step, Ty), L,
14890 AR->getNoWrapFlags());
14891 }
14892 return SE.getZeroExtendExpr(Operand, Expr->getType());
14893 }
14894
14895 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14896 const SCEV *Operand = visit(Expr->getOperand());
14897 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14898 if (AR && AR->getLoop() == L && AR->isAffine()) {
14899 // This couldn't be folded because the operand didn't have the nsw
14900 // flag. Add the nssw flag as an assumption that we could make.
14901 const SCEV *Step = AR->getStepRecurrence(SE);
14902 Type *Ty = Expr->getType();
14903 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14904 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14905 SE.getSignExtendExpr(Step, Ty), L,
14906 AR->getNoWrapFlags());
14907 }
14908 return SE.getSignExtendExpr(Operand, Expr->getType());
14909 }
14910
14911private:
14912 explicit SCEVPredicateRewriter(
14913 const Loop *L, ScalarEvolution &SE,
14914 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14915 const SCEVPredicate *Pred)
14916 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14917
14918 bool addOverflowAssumption(const SCEVPredicate *P) {
14919 if (!NewPreds) {
14920 // Check if we've already made this assumption.
14921 return Pred && Pred->implies(P, SE);
14922 }
14923 NewPreds->push_back(P);
14924 return true;
14925 }
14926
14927 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14928 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14929 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14930 return addOverflowAssumption(A);
14931 }
14932
14933 // If \p Expr represents a PHINode, we try to see if it can be represented
14934 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14935 // to add this predicate as a runtime overflow check, we return the AddRec.
14936 // If \p Expr does not meet these conditions (is not a PHI node, or we
14937 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14938 // return \p Expr.
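 // For example, a PHI that can only be expressed as {%start,+,%step} under a
 // no-self-wrap assumption is returned as that AddRec once the corresponding
 // wrap predicate has been recorded via addOverflowAssumption.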
14939 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14940 if (!isa<PHINode>(Expr->getValue()))
14941 return Expr;
14942 std::optional<
14943 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14944 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14945 if (!PredicatedRewrite)
14946 return Expr;
14947 for (const auto *P : PredicatedRewrite->second){
14948 // Wrap predicates from outer loops are not supported.
14949 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14950 if (L != WP->getExpr()->getLoop())
14951 return Expr;
14952 }
14953 if (!addOverflowAssumption(P))
14954 return Expr;
14955 }
14956 return PredicatedRewrite->first;
14957 }
14958
14959 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14960 const SCEVPredicate *Pred;
14961 const Loop *L;
14962};
14963
14964} // end anonymous namespace
14965
14966const SCEV *
14967ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14968 const SCEVPredicate &Preds) {
14969 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14970}
14971
14972const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14973 const SCEV *S, const Loop *L,
14974 SmallVectorImpl<const SCEVPredicate *> &Preds) {
14975 SmallVector<const SCEVPredicate *, 4> TransformPreds;
14976 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14977 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14978
14979 if (!AddRec)
14980 return nullptr;
14981
14982 // Check if any of the transformed predicates is known to be false. In that
14983 // case, it doesn't make sense to convert to a predicated AddRec, as the
14984 // versioned loop will never execute.
14985 for (const SCEVPredicate *Pred : TransformPreds) {
14986 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
14987 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
14988 continue;
14989
14990 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
14991 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
14992 if (isa<SCEVCouldNotCompute>(ExitCount))
14993 continue;
14994
14995 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
14996 if (!Step->isOne())
14997 continue;
14998
14999 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15000 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15001 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15002 return nullptr;
15003 }
15004
15005 // Since the transformation was successful, we can now transfer the SCEV
15006 // predicates.
15007 Preds.append(TransformPreds.begin(), TransformPreds.end());
15008
15009 return AddRec;
15010}
15011
15012/// SCEV predicates
15016
15017SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15018 const ICmpInst::Predicate Pred,
15019 const SCEV *LHS, const SCEV *RHS)
15020 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15021 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15022 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15023}
15024
15025bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15026 ScalarEvolution &SE) const {
15027 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15028
15029 if (!Op)
15030 return false;
15031
15032 if (Pred != ICmpInst::ICMP_EQ)
15033 return false;
15034
15035 return Op->LHS == LHS && Op->RHS == RHS;
15036}
15037
15038bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15039
15040void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15041 if (Pred == ICmpInst::ICMP_EQ)
15042 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15043 else
15044 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15045 << *RHS << "\n";
15046
15047}
15048
15049SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15050 const SCEVAddRecExpr *AR,
15051 IncrementWrapFlags Flags)
15052 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15053
15054const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15055
15057 ScalarEvolution &SE) const {
15058 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15059 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15060 return false;
15061
15062 if (Op->AR == AR)
15063 return true;
15064
15065 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15067 return false;
15068
15069 const SCEV *Start = AR->getStart();
15070 const SCEV *OpStart = Op->AR->getStart();
15071 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15072 return false;
15073
15074 // Reject pointers to different address spaces.
15075 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15076 return false;
15077
15078 const SCEV *Step = AR->getStepRecurrence(SE);
15079 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15080 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15081 return false;
15082
15083 // If both steps are positive, this implies N if N's start and step are
15084 // ULE/SLE (for NUSW/NSSW) to this one's.
15085 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15086 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15087 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15088
15089 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15090 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15091 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15092 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15093 : SE.getNoopOrSignExtend(Start, WiderTy);
15095 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15096 SE.isKnownPredicate(Pred, OpStart, Start);
15097}
15098
15100 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15101 IncrementWrapFlags IFlags = Flags;
15102
15103 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15104 IFlags = clearFlags(IFlags, IncrementNSSW);
15105
15106 return IFlags == IncrementAnyWrap;
15107}
15108
15109void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15110 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15112 OS << "<nusw>";
15114 OS << "<nssw>";
15115 OS << "\n";
15116}
15117
15120 ScalarEvolution &SE) {
15121 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15122 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15123
15124 // We can safely transfer the NSW flag as NSSW.
15125 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15126 ImpliedFlags = IncrementNSSW;
15127
15128 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15129 // If the increment is positive, the SCEV NUW flag will also imply the
15130 // WrapPredicate NUSW flag.
15131 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15132 if (Step->getValue()->getValue().isNonNegative())
15133 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15134 }
15135
15136 return ImpliedFlags;
15137}
15138
15139/// Union predicates don't get cached so create a dummy set ID for it.
15140SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15141 ScalarEvolution &SE)
15142 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15143 for (const auto *P : Preds)
15144 add(P, SE);
15145}
15146
15147bool SCEVUnionPredicate::isAlwaysTrue() const {
15148 return all_of(Preds,
15149 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15150}
15151
15152bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15153 ScalarEvolution &SE) const {
15154 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15155 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15156 return this->implies(I, SE);
15157 });
15158
15159 return any_of(Preds,
15160 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15161}
15162
15163void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15164 for (const auto *Pred : Preds)
15165 Pred->print(OS, Depth);
15166}
15167
15168void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15169 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15170 for (const auto *Pred : Set->Preds)
15171 add(Pred, SE);
15172 return;
15173 }
15174
15175 // Only add predicate if it is not already implied by this union predicate.
15176 if (implies(N, SE))
15177 return;
15178
15179 // Build a new vector containing the current predicates, except the ones that
15180 // are implied by the new predicate N.
15181 SmallVector<const SCEVPredicate *> PrunedPreds;
15182 for (auto *P : Preds) {
15183 if (N->implies(P, SE))
15184 continue;
15185 PrunedPreds.push_back(P);
15186 }
15187 Preds = std::move(PrunedPreds);
15188 Preds.push_back(N);
15189}
15190
15191PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15192 Loop &L)
15193 : SE(SE), L(L) {
15194 SmallVector<const SCEVPredicate *, 4> Empty;
15195 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15196}
15197
15200 for (const auto *Op : Ops)
15201 // We do not expect that forgetting cached data for SCEVConstants will ever
15202 // open any prospects for sharpening or introduce any correctness issues,
15203 // so we don't bother storing their dependencies.
15204 if (!isa<SCEVConstant>(Op))
15205 SCEVUsers[Op].insert(User);
15206}
15207
15208const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15209 const SCEV *Expr = SE.getSCEV(V);
15210 RewriteEntry &Entry = RewriteMap[Expr];
15211
15212 // If we already have an entry and the version matches, return it.
15213 if (Entry.second && Generation == Entry.first)
15214 return Entry.second;
15215
15216 // We found an entry but it's stale. Rewrite the stale entry
15217 // according to the current predicate.
15218 if (Entry.second)
15219 Expr = Entry.second;
15220
15221 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15222 Entry = {Generation, NewSCEV};
15223
15224 return NewSCEV;
15225}
15226
15227const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15228 if (!BackedgeCount) {
15229 SmallVector<const SCEVPredicate *, 4> Preds;
15230 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15231 for (const auto *P : Preds)
15232 addPredicate(*P);
15233 }
15234 return BackedgeCount;
15235}
15236
15237const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15238 if (!SymbolicMaxBackedgeCount) {
15239 SmallVector<const SCEVPredicate *, 4> Preds;
15240 SymbolicMaxBackedgeCount =
15241 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15242 for (const auto *P : Preds)
15243 addPredicate(*P);
15244 }
15245 return SymbolicMaxBackedgeCount;
15246}
15247
15248unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15249 if (!SmallConstantMaxTripCount) {
15250 SmallVector<const SCEVPredicate *> Preds;
15251 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15252 for (const auto *P : Preds)
15253 addPredicate(*P);
15254 }
15255 return *SmallConstantMaxTripCount;
15256}
15257
15258void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15259 if (Preds->implies(&Pred, SE))
15260 return;
15261
15262 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15263 NewPreds.push_back(&Pred);
15264 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15265 updateGeneration();
15266}
15267
15268const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15269 return *Preds;
15270}
15271
15272void PredicatedScalarEvolution::updateGeneration() {
15273 // If the generation number wrapped recompute everything.
15274 if (++Generation == 0) {
15275 for (auto &II : RewriteMap) {
15276 const SCEV *Rewritten = II.second.second;
15277 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15278 }
15279 }
15280}
15281
15282void PredicatedScalarEvolution::setNoOverflow(
15283 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15284 const SCEV *Expr = getSCEV(V);
15285 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15286
15287 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15288
15289 // Clear the statically implied flags.
15290 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15291 addPredicate(*SE.getWrapPredicate(AR, Flags));
15292
15293 auto II = FlagsMap.insert({V, Flags});
15294 if (!II.second)
15295 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15296}
15297
15298bool PredicatedScalarEvolution::hasNoOverflow(
15299 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15300 const SCEV *Expr = getSCEV(V);
15301 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15302
15303 Flags = SCEVWrapPredicate::clearFlags(
15304 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15305
15306 auto II = FlagsMap.find(V);
15307
15308 if (II != FlagsMap.end())
15309 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15310
15311 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15312}
15313
15314const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15315 const SCEV *Expr = this->getSCEV(V);
15316 SmallVector<const SCEVPredicate *, 4> NewPreds;
15317 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15318
15319 if (!New)
15320 return nullptr;
15321
15322 for (const auto *P : NewPreds)
15323 addPredicate(*P);
15324
15325 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15326 return New;
15327}
15328
15329PredicatedScalarEvolution::PredicatedScalarEvolution(
15330 const PredicatedScalarEvolution &Init)
15331 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15332 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15333 SE)),
15334 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15335 for (auto I : Init.FlagsMap)
15336 FlagsMap.insert(I);
15337}
15338
15339void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15340 // For each block.
15341 for (auto *BB : L.getBlocks())
15342 for (auto &I : *BB) {
15343 if (!SE.isSCEVable(I.getType()))
15344 continue;
15345
15346 auto *Expr = SE.getSCEV(&I);
15347 auto II = RewriteMap.find(Expr);
15348
15349 if (II == RewriteMap.end())
15350 continue;
15351
15352 // Don't print things that are not interesting.
15353 if (II->second.second == Expr)
15354 continue;
15355
15356 OS.indent(Depth) << "[PSE]" << I << ":\n";
15357 OS.indent(Depth + 2) << *Expr << "\n";
15358 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15359 }
15360}
15361
15362// Match the mathematical pattern A - (A / B) * B, where A and B can be
15363// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15364// for URem with constant power-of-2 second operands.
15365// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15366// 4, A / B becomes X / 8).
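// For example, for an i8 value A, A urem 4 may appear as zext (trunc A to i2)
// to i8, which the first pattern below recognizes with RHS = 4; other divisors
// are matched through the add/mul patterns further down.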
15367bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15368 const SCEV *&RHS) {
15369 if (Expr->getType()->isPointerTy())
15370 return false;
15371
15372 // Try to match 'zext (trunc A to iB) to iY', which is used
15373 // for URem with constant power-of-2 second operands. Make sure the size of
15374 // the operand A matches the size of the whole expressions.
15375 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15376 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15377 LHS = Trunc->getOperand();
15378 // Bail out if the type of the LHS is larger than the type of the
15379 // expression for now.
15380 if (getTypeSizeInBits(LHS->getType()) >
15381 getTypeSizeInBits(Expr->getType()))
15382 return false;
15383 if (LHS->getType() != Expr->getType())
15384 LHS = getZeroExtendExpr(LHS, Expr->getType());
15385 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15386 << getTypeSizeInBits(Trunc->getType()));
15387 return true;
15388 }
15389 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15390 if (Add == nullptr || Add->getNumOperands() != 2)
15391 return false;
15392
15393 const SCEV *A = Add->getOperand(1);
15394 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15395
15396 if (Mul == nullptr)
15397 return false;
15398
15399 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15400 // (SomeExpr + (-(SomeExpr / B) * B)).
15401 if (Expr == getURemExpr(A, B)) {
15402 LHS = A;
15403 RHS = B;
15404 return true;
15405 }
15406 return false;
15407 };
15408
15409 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15410 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15411 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15412 MatchURemWithDivisor(Mul->getOperand(2));
15413
15414 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15415 if (Mul->getNumOperands() == 2)
15416 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15417 MatchURemWithDivisor(Mul->getOperand(0)) ||
15418 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15419 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15420 return false;
15421}
15422
15423ScalarEvolution::LoopGuards
15424ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15425 BasicBlock *Header = L->getHeader();
15426 BasicBlock *Pred = L->getLoopPredecessor();
15427 LoopGuards Guards(SE);
15428 if (!Pred)
15429 return Guards;
15430 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15431 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15432 return Guards;
15433}
15434
15435void ScalarEvolution::LoopGuards::collectFromPHI(
15436 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15437 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15438 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15439 unsigned Depth) {
15440 if (!SE.isSCEVable(Phi.getType()))
15441 return;
15442
15443 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15444 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15445 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15446 if (!VisitedBlocks.insert(InBlock).second)
15447 return {nullptr, scCouldNotCompute};
15448 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15449 if (Inserted)
15450 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15451 Depth + 1);
15452 auto &RewriteMap = G->second.RewriteMap;
15453 if (RewriteMap.empty())
15454 return {nullptr, scCouldNotCompute};
15455 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15456 if (S == RewriteMap.end())
15457 return {nullptr, scCouldNotCompute};
15458 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15459 if (!SM)
15460 return {nullptr, scCouldNotCompute};
15461 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15462 return {C0, SM->getSCEVType()};
15463 return {nullptr, scCouldNotCompute};
15464 };
15465 auto MergeMinMaxConst = [](MinMaxPattern P1,
15466 MinMaxPattern P2) -> MinMaxPattern {
15467 auto [C1, T1] = P1;
15468 auto [C2, T2] = P2;
15469 if (!C1 || !C2 || T1 != T2)
15470 return {nullptr, scCouldNotCompute};
15471 switch (T1) {
15472 case scUMaxExpr:
15473 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15474 case scSMaxExpr:
15475 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15476 case scUMinExpr:
15477 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15478 case scSMinExpr:
15479 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15480 default:
15481 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15482 }
15483 };
15484 auto P = GetMinMaxConst(0);
15485 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15486 if (!P.first)
15487 break;
15488 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15489 }
15490 if (P.first) {
15491 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15492 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15493 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15494 Guards.RewriteMap.insert({LHS, RHS});
15495 }
15496}
15497
15498void ScalarEvolution::LoopGuards::collectFromBlock(
15499 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15500 const BasicBlock *Block, const BasicBlock *Pred,
15501 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15502 SmallVector<const SCEV *> ExprsToRewrite;
15503 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15504 const SCEV *RHS,
15505 DenseMap<const SCEV *, const SCEV *>
15506 &RewriteMap) {
15507 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15508 // replacement SCEV which isn't directly implied by the structure of that
15509 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15510 // legal. See the scoping rules for flags in the header to understand why.
15511
15512 // If LHS is a constant, apply information to the other expression.
15513 if (isa<SCEVConstant>(LHS)) {
15514 std::swap(LHS, RHS);
15515 Predicate = ICmpInst::getSwappedPredicate(Predicate);
15516 }
15517
15518 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15519 // create this form when combining two checks of the form (X u< C2 + C1) and
15520 // (X >=u C1).
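 // For example, combining %x u>= 5 and %x u< 13 yields (-5 + %x) u< 8; the
 // exact region computed below is [5, 13), and %x is rewritten to
 // umax(5, umin(%x, 12)).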
15521 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15522 &ExprsToRewrite]() {
15523 const SCEVConstant *C1;
15524 const SCEVUnknown *LHSUnknown;
15525 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15526 if (!match(LHS,
15527 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15528 !C2)
15529 return false;
15530
15531 auto ExactRegion =
15532 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15533 .sub(C1->getAPInt());
15534
15535 // Bail out, unless we have a non-wrapping, monotonic range.
15536 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15537 return false;
15538 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15539 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15540 I->second = SE.getUMaxExpr(
15541 SE.getConstant(ExactRegion.getUnsignedMin()),
15542 SE.getUMinExpr(RewrittenLHS,
15543 SE.getConstant(ExactRegion.getUnsignedMax())));
15544 ExprsToRewrite.push_back(LHSUnknown);
15545 return true;
15546 };
15547 if (MatchRangeCheckIdiom())
15548 return;
15549
15550 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15551 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15552 // the non-constant operand and in \p LHS the constant operand.
15553 auto IsMinMaxSCEVWithNonNegativeConstant =
15554 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15555 const SCEV *&RHS) {
15556 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15557 if (MinMax->getNumOperands() != 2)
15558 return false;
15559 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15560 if (C->getAPInt().isNegative())
15561 return false;
15562 SCTy = MinMax->getSCEVType();
15563 LHS = MinMax->getOperand(0);
15564 RHS = MinMax->getOperand(1);
15565 return true;
15566 }
15567 }
15568 return false;
15569 };
15570
15571 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15572 // constant, and returns their APInt in ExprVal and in DivisorVal.
15573 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15574 APInt &ExprVal, APInt &DivisorVal) {
15575 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15576 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15577 if (!ConstExpr || !ConstDivisor)
15578 return false;
15579 ExprVal = ConstExpr->getAPInt();
15580 DivisorVal = ConstDivisor->getAPInt();
15581 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15582 };
15583
15584 // Return a new SCEV that modifies \p Expr to the closest number that is
15585 // divisible by \p Divisor and greater than or equal to Expr.
15586 // For now, only handle constant Expr and Divisor.
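 // For example, with Expr = 10 and Divisor = 4 this returns 12, while the
 // "previous" variant below returns 8.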
15587 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15588 const SCEV *Divisor) {
15589 APInt ExprVal;
15590 APInt DivisorVal;
15591 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15592 return Expr;
15593 APInt Rem = ExprVal.urem(DivisorVal);
15594 if (!Rem.isZero())
15595 // return the SCEV: Expr + Divisor - Expr % Divisor
15596 return SE.getConstant(ExprVal + DivisorVal - Rem);
15597 return Expr;
15598 };
15599
15600 // Return a new SCEV that modifies \p Expr to the closest number that is
15601 // divisible by \p Divisor and less than or equal to Expr.
15602 // For now, only handle constant Expr and Divisor.
15603 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15604 const SCEV *Divisor) {
15605 APInt ExprVal;
15606 APInt DivisorVal;
15607 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15608 return Expr;
15609 APInt Rem = ExprVal.urem(DivisorVal);
15610 // return the SCEV: Expr - Expr % Divisor
15611 return SE.getConstant(ExprVal - Rem);
15612 };
15613
15614 // Apply divisibility by \p Divisor to a MinMaxExpr with constant values,
15615 // recursively. This is done by aligning up/down the constant value to the
15616 // Divisor.
15617 std::function<const SCEV *(const SCEV *, const SCEV *)>
15618 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15619 const SCEV *Divisor) {
15620 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15621 SCEVTypes SCTy;
15622 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15623 MinMaxRHS))
15624 return MinMaxExpr;
15625 auto IsMin =
15626 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15627 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15628 "Expected non-negative operand!");
15629 auto *DivisibleExpr =
15630 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15631 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15632 SmallVector<const SCEV *> Ops = {
15633 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15634 return SE.getMinMaxExpr(SCTy, Ops);
15635 };
15636
15637 // If we have LHS == 0, check whether LHS computes a property of some unknown
15638 // SCEV %v that we can use to rewrite %v explicitly.
15639 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15640 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15641 // explicitly express that.
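 // For example, a dominating guard '%n urem 8 == 0' lets %n be rewritten to
 // (%n /u 8) * 8, which makes the divisibility visible to later queries on
 // expressions built from %n.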
15642 const SCEV *URemLHS = nullptr;
15643 const SCEV *URemRHS = nullptr;
15644 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15645 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15646 auto I = RewriteMap.find(LHSUnknown);
15647 const SCEV *RewrittenLHS =
15648 I != RewriteMap.end() ? I->second : LHSUnknown;
15649 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15650 const auto *Multiple =
15651 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15652 RewriteMap[LHSUnknown] = Multiple;
15653 ExprsToRewrite.push_back(LHSUnknown);
15654 return;
15655 }
15656 }
15657 }
15658
15659 // Do not apply information for constants or if RHS contains an AddRec.
15660 if (isa<SCEVConstant>(RHS) || SE.containsAddRecurrence(RHS))
15661 return;
15662
15663 // If RHS is SCEVUnknown, make sure the information is applied to it.
15664 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15665 std::swap(LHS, RHS);
15666 Predicate = ICmpInst::getSwappedPredicate(Predicate);
15667 }
15668
15669 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15670 // and \p FromRewritten are the same (i.e. there has been no rewrite
15671 // registered for \p From), then puts this value in the list of rewritten
15672 // expressions.
15673 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15674 const SCEV *To) {
15675 if (From == FromRewritten)
15676 ExprsToRewrite.push_back(From);
15677 RewriteMap[From] = To;
15678 };
15679
15680 // Checks whether \p S has already been rewritten. In that case returns the
15681 // existing rewrite because we want to chain further rewrites onto the
15682 // already rewritten value. Otherwise returns \p S.
15683 auto GetMaybeRewritten = [&](const SCEV *S) {
15684 return RewriteMap.lookup_or(S, S);
15685 };
15686
15687 // Check for the SCEV expression (A /u B) * B where B is a constant, inside
15688 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15689 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15690 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15691 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15692 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15693 // DividesBy.
15694 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15695 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15696 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15697 if (Mul->getNumOperands() != 2)
15698 return false;
15699 auto *MulLHS = Mul->getOperand(0);
15700 auto *MulRHS = Mul->getOperand(1);
15701 if (isa<SCEVConstant>(MulLHS))
15702 std::swap(MulLHS, MulRHS);
15703 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15704 if (Div->getOperand(1) == MulRHS) {
15705 DividesBy = MulRHS;
15706 return true;
15707 }
15708 }
15709 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15710 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15711 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15712 return false;
15713 };
15714
15715 // Return true if \p Expr is known to divide by \p DividesBy.
15716 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15717 [&](const SCEV *Expr, const SCEV *DividesBy) {
15718 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15719 return true;
15720 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15721 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15722 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15723 return false;
15724 };
15725
15726 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15727 const SCEV *DividesBy = nullptr;
15728 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15729 // Check that the whole expression is divisible by DividesBy
15730 DividesBy =
15731 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15732
15733 // Collect rewrites for LHS and its transitive operands based on the
15734 // condition.
15735 // For min/max expressions, also apply the guard to its operands:
15736 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15737 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15738 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15739 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15740
15741 // We cannot express strict predicates in SCEV, so instead we replace them
15742 // with non-strict ones against plus or minus one of RHS depending on the
15743 // predicate.
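// For example, for a hypothetical unknown %x, the guard '%x <u 10' is
// handled as '%x <=u 9': RHS is first clamped to umax(RHS, 1) so that
// subtracting one cannot wrap in the unsigned case, and the worklist loop
// below then rewrites %x to umin(%x, 9).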
15744 const SCEV *One = SE.getOne(RHS->getType());
15745 switch (Predicate) {
15746 case CmpInst::ICMP_ULT:
15747 if (RHS->getType()->isPointerTy())
15748 return;
15749 RHS = SE.getUMaxExpr(RHS, One);
15750 [[fallthrough]];
15751 case CmpInst::ICMP_SLT: {
15752 RHS = SE.getMinusSCEV(RHS, One);
15753 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15754 break;
15755 }
15756 case CmpInst::ICMP_UGT:
15757 case CmpInst::ICMP_SGT:
15758 RHS = SE.getAddExpr(RHS, One);
15759 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15760 break;
15761 case CmpInst::ICMP_ULE:
15762 case CmpInst::ICMP_SLE:
15763 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15764 break;
15765 case CmpInst::ICMP_UGE:
15766 case CmpInst::ICMP_SGE:
15767 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15768 break;
15769 default:
15770 break;
15771 }
15772
15773 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15774 SmallPtrSet<const SCEV *, 16> Visited;
15775
15776 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15777 append_range(Worklist, S->operands());
15778 };
15779
15780 while (!Worklist.empty()) {
15781 const SCEV *From = Worklist.pop_back_val();
15782 if (isa<SCEVConstant>(From))
15783 continue;
15784 if (!Visited.insert(From).second)
15785 continue;
15786 const SCEV *FromRewritten = GetMaybeRewritten(From);
15787 const SCEV *To = nullptr;
15788
15789 switch (Predicate) {
15790 case CmpInst::ICMP_ULT:
15791 case CmpInst::ICMP_ULE:
15792 To = SE.getUMinExpr(FromRewritten, RHS);
15793 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15794 EnqueueOperands(UMax);
15795 break;
15796 case CmpInst::ICMP_SLT:
15797 case CmpInst::ICMP_SLE:
15798 To = SE.getSMinExpr(FromRewritten, RHS);
15799 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15800 EnqueueOperands(SMax);
15801 break;
15802 case CmpInst::ICMP_UGT:
15803 case CmpInst::ICMP_UGE:
15804 To = SE.getUMaxExpr(FromRewritten, RHS);
15805 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15806 EnqueueOperands(UMin);
15807 break;
15808 case CmpInst::ICMP_SGT:
15809 case CmpInst::ICMP_SGE:
15810 To = SE.getSMaxExpr(FromRewritten, RHS);
15811 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15812 EnqueueOperands(SMin);
15813 break;
15814 case CmpInst::ICMP_EQ:
15815 if (isa<SCEVConstant>(RHS))
15816 To = RHS;
15817 break;
15818 case CmpInst::ICMP_NE:
15819 if (match(RHS, m_scev_Zero())) {
15820 const SCEV *OneAlignedUp =
15821 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15822 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15823 }
15824 break;
15825 default:
15826 break;
15827 }
15828
15829 if (To)
15830 AddRewrite(From, FromRewritten, To);
15831 }
15832 };
15833
15834 SmallVector<std::pair<Value *, bool>> Terms;
15835 // First, collect information from assumptions dominating the loop.
15836 for (auto &AssumeVH : SE.AC.assumptions()) {
15837 if (!AssumeVH)
15838 continue;
15839 auto *AssumeI = cast<CallInst>(AssumeVH);
15840 if (!SE.DT.dominates(AssumeI, Block))
15841 continue;
15842 Terms.emplace_back(AssumeI->getOperand(0), true);
15843 }
15844
15845 // Second, collect information from llvm.experimental.guards dominating the loop.
15846 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15847 SE.F.getParent(), Intrinsic::experimental_guard);
15848 if (GuardDecl)
15849 for (const auto *GU : GuardDecl->users())
15850 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15851 if (Guard->getFunction() == Block->getParent() &&
15852 SE.DT.dominates(Guard, Block))
15853 Terms.emplace_back(Guard->getArgOperand(0), true);
15854
15855 // Third, collect conditions from dominating branches. Starting at the loop
15856 // predecessor, climb up the predecessor chain, as long as we can find
15857 // predecessors that have a unique successor leading to the original
15858 // header.
15859 // TODO: share this logic with isLoopEntryGuardedByCond.
15860 unsigned NumCollectedConditions = 0;
15861 VisitedBlocks.insert(Block);
15862 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15863 for (; Pair.first;
15864 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15865 VisitedBlocks.insert(Pair.second);
15866 const BranchInst *LoopEntryPredicate =
15867 dyn_cast<BranchInst>(Pair.first->getTerminator());
15868 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15869 continue;
15870
15871 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15872 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15873 NumCollectedConditions++;
15874
15875 // If we are recursively collecting guards, stop after 2
15876 // conditions to limit compile-time impact for now.
15877 if (Depth > 0 && NumCollectedConditions == 2)
15878 break;
15879 }
15880 // Finally, if we stopped climbing the predecessor chain because
15881 // there was no unique predecessor to continue with, try to collect
15882 // conditions for PHI nodes by recursively following all of their
15883 // incoming blocks and merging the conditions found there into a new
15884 // condition for the Phi.
15885 if (Pair.second->hasNPredecessorsOrMore(2) &&
15886 Depth < MaxLoopGuardCollectionDepth) {
15887 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15888 for (auto &Phi : Pair.second->phis())
15889 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15890 }
15891
15892 // Now apply the information from the collected conditions to
15893 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15894 // earliest condition is processed first. This ensures the SCEVs with the
15895 // shortest dependency chains are constructed first.
15896 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15897 SmallVector<Value *, 8> Worklist;
15898 SmallPtrSet<Value *, 8> Visited;
15899 Worklist.push_back(Term);
15900 while (!Worklist.empty()) {
15901 Value *Cond = Worklist.pop_back_val();
15902 if (!Visited.insert(Cond).second)
15903 continue;
15904
15905 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15906 auto Predicate =
15907 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15908 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15909 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15910 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15911 continue;
15912 }
15913
15914 Value *L, *R;
15915 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15916 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15917 Worklist.push_back(L);
15918 Worklist.push_back(R);
15919 }
15920 }
15921 }
15922
15923 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15924 // the replacement expressions are contained in the ranges of the replaced
15925 // expressions.
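// For example, rewriting a hypothetical unknown %x to umin(%x, 9) can only
// shrink its unsigned range, so the containment check below succeeds and
// NUW may still be transferred; a rewrite whose range is not contained in
// the original's range clears the corresponding flag for the whole map.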
15926 Guards.PreserveNUW = true;
15927 Guards.PreserveNSW = true;
15928 for (const SCEV *Expr : ExprsToRewrite) {
15929 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15930 Guards.PreserveNUW &=
15931 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15932 Guards.PreserveNSW &=
15933 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15934 }
15935
15936 // Now that all rewrite information is collected, rewrite the collected
15937 // expressions with the information in the map. This applies information to
15938 // sub-expressions.
15939 if (ExprsToRewrite.size() > 1) {
15940 for (const SCEV *Expr : ExprsToRewrite) {
15941 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15942 Guards.RewriteMap.erase(Expr);
15943 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15944 }
15945 }
15946}
15947
15948 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15949 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15950 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15951 /// replacement is loop invariant in the loop of the AddRec.
15952 class SCEVLoopGuardRewriter
15953 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15954 const DenseMap<const SCEV *, const SCEV *> &Map;
15955
15956 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15957
15958 public:
15959 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15960 const ScalarEvolution::LoopGuards &Guards)
15961 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15962 if (Guards.PreserveNUW)
15963 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15964 if (Guards.PreserveNSW)
15965 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15966 }
15967
15968 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15969
15970 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15971 return Map.lookup_or(Expr, Expr);
15972 }
15973
15974 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15975 if (const SCEV *S = Map.lookup(Expr))
15976 return S;
15977
15978 // If we didn't find the exact ZExt expr in the map, check if there's
15979 // an entry for a smaller ZExt we can use instead.
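// For example, given Expr = (zext i16 %v to i64) for a hypothetical
// unknown %v, and a map entry for (zext i16 %v to i32), the narrower
// entry's replacement S is reused by returning (zext S to i64).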
15980 Type *Ty = Expr->getType();
15981 const SCEV *Op = Expr->getOperand(0);
15982 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15983 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15984 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15985 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15986 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15987 if (const SCEV *S = Map.lookup(NarrowExt))
15988 return SE.getZeroExtendExpr(S, Ty);
15989 Bitwidth = Bitwidth / 2;
15990 }
15991
15992 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15993 Expr);
15994 }
15995
15996 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15997 if (const SCEV *S = Map.lookup(Expr))
15998 return S;
15999 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16000 Expr);
16001 }
16002
16003 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16004 if (const SCEV *S = Map.lookup(Expr))
16005 return S;
16006 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16007 }
16008
16009 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16010 if (const SCEV *S = Map.lookup(Expr))
16011 return S;
16012 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16013 }
16014
16015 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16016 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16017 // (Const + A + B). There may be guard info for A + B, and if so, apply
16018 // it.
16019 // TODO: Could more generally apply guards to Add sub-expressions.
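// For example, for Expr = (4 + %a + %b) with hypothetical unknowns %a and
// %b, an existing map entry S for (%a + %b) is used and the result
// becomes (4 + S).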
16020 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16021 Expr->getNumOperands() == 3) {
16022 if (const SCEV *S = Map.lookup(
16023 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
16024 return SE.getAddExpr(Expr->getOperand(0), S);
16025 }
16026 SmallVector<const SCEV *, 2> Operands;
16027 bool Changed = false;
16028 for (const auto *Op : Expr->operands()) {
16029 Operands.push_back(
16030 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16031 Changed |= Op != Operands.back();
16032 }
16033 // We are only replacing operands with equivalent values, so transfer the
16034 // flags from the original expression.
16035 return !Changed ? Expr
16036 : SE.getAddExpr(Operands,
16037 ScalarEvolution::maskFlags(
16038 Expr->getNoWrapFlags(), FlagMask));
16039 }
16040
16041 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16042 SmallVector<const SCEV *, 2> Operands;
16043 bool Changed = false;
16044 for (const auto *Op : Expr->operands()) {
16045 Operands.push_back(
16046 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16047 Changed |= Op != Operands.back();
16048 }
16049 // We are only replacing operands with equivalent values, so transfer the
16050 // flags from the original expression.
16051 return !Changed ? Expr
16052 : SE.getMulExpr(Operands,
16053 ScalarEvolution::maskFlags(
16054 Expr->getNoWrapFlags(), FlagMask));
16055 }
16056 };
16057
16058 if (RewriteMap.empty())
16059 return Expr;
16060
16061 SCEVLoopGuardRewriter Rewriter(SE, *this);
16062 return Rewriter.visit(Expr);
16063}
16064
16065const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16066 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16067}
16068
16069 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16070 const LoopGuards &Guards) {
16071 return Guards.rewrite(Expr);
16072}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:546
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
mir Rename Register Operands
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:167
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1391
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1512
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:206
APInt abs() const
Get the absolute value.
Definition APInt.h:1795
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:466
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:219
unsigned countTrailingZeros() const
Definition APInt.h:1647
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:356
unsigned logBase2() const
Definition APInt.h:1761
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:827
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1150
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:873
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:440
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1221
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:41
iterator end() const
Definition ArrayRef.h:136
size_t size() const
size - Get the array size.
Definition ArrayRef.h:147
iterator begin() const
Definition ArrayRef.h:135
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition Allocator.h:149
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:950
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:678
@ ICMP_SLT
signed less than
Definition InstrTypes.h:707
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:708
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:702
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:701
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:705
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:703
@ ICMP_NE
not equal
Definition InstrTypes.h:700
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:706
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:704
bool isSigned() const
Definition InstrTypes.h:932
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:829
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:944
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:791
bool isUnsigned() const
Definition InstrTypes.h:938
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:928
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1274
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:214
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:163
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:154
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:63
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:671
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:187
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:165
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:229
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:173
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:161
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:156
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:214
void swap(DenseMap &RHS)
Definition DenseMap.h:744
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:322
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:293
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:330
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:319
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:570
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:597
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V, bool HasCoroSuspendInst=false) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1077
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:107
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
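A hedged sketch of typical use (SE, L and Expr are assumed to be an existing ScalarEvolution reference, a Loop pointer and a SCEV, respectively): collect the guards once, then rewrite any number of expressions against them.

  ScalarEvolution::LoopGuards Guards =
      ScalarEvolution::LoopGuards::collect(L, SE);
  const SCEV *Rewritten = Guards.rewrite(Expr);  // may be tighter than Expr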
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
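A hedged example (SE and L assumed available): the result must be checked for SCEVCouldNotCompute before it is used.

  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BTC))
    return;                                      // backedge count is unknown
  // Backedge-taken count + 1, in an evaluation type wide enough not to overflow.
  const SCEV *TripCount = SE.getTripCountFromExitCount(BTC);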
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Does the operation BinOp between LHS and RHS provably avoid signed/unsigned overflow (as selected by Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
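A hedged example (SE and a SCEV S assumed): ranges are cached by ScalarEvolution, so repeated queries are cheap.

  ConstantRange SR = SE.getSignedRange(S);
  APInt UMax = SE.getUnsignedRangeMax(S);
  bool NonNeg = SE.isKnownNonNegative(S);  // i.e. the signed minimum of S is >= 0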
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
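A hedged sketch (SE, a Loop L and an integer Type Ty assumed): build the canonical induction {0,+,1}<L>. No-wrap flags are an assertion made by the caller, so FlagAnyWrap is the conservative choice.

  const SCEV *Start = SE.getZero(Ty);
  const SCEV *Step = SE.getOne(Ty);
  const SCEV *IV = SE.getAddRecExpr(Start, Step, L, SCEV::FlagAnyWrap);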
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
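A hedged example (SE assumed; A and B are integer Values of the same type): form the difference and query facts about it, or compare the two expressions directly.

  const SCEV *SA = SE.getSCEV(A);
  const SCEV *SB = SE.getSCEV(B);
  const SCEV *Diff = SE.getMinusSCEV(SA, SB);                   // SA - SB
  bool AAtLeastB = SE.isKnownPredicate(ICmpInst::ICMP_SGE, SA, SB);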
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
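A hedged usage sketch (SE, a SCEV S and a Loop L assumed); dispositions are cached by ScalarEvolution:

  switch (SE.getLoopDisposition(S, L)) {
  case ScalarEvolution::LoopInvariant:
    // S does not change within L; a computation of it could be hoisted.
    break;
  case ScalarEvolution::LoopComputable:
    // S varies with L in a predictable (AddRec-like) way.
    break;
  case ScalarEvolution::LoopVariant:
    // Nothing useful is known about how S evolves in L.
    break;
  }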
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
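A hedged example (SE and L assumed): the function returns 0 when the trip count is unknown or not a small constant, so the result can be tested directly.

  if (unsigned TC = SE.getSmallConstantTripCount(L)) {
    // The loop body executes exactly TC times each time the loop is entered.
  }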
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
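A hedged sketch (SE, L and a backedge-taken-count SCEV BTC assumed): guards dominating the loop, such as a check that the bound is non-zero, can tighten the computed range.

  const SCEV *Guarded = SE.applyLoopGuards(BTC, L);
  APInt MaxBTC = SE.getUnsignedRangeMax(Guarded);  // possibly smaller than for BTC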
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
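A hedged example (SE and two SCEVs LHS, RHS assumed): the optional result distinguishes provably true, provably false, and unknown.

  std::optional<bool> Res = SE.evaluatePredicate(ICmpInst::ICMP_ULT, LHS, RHS);
  if (Res && *Res) {
    // LHS u< RHS holds whenever both expressions are evaluated.
  }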
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
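A hedged sketch (SE and same-typed SCEVs X, Y, Z assumed): the n-ary builders take a vector of operands, which they may reorder and fold while canonicalizing.

  SmallVector<const SCEV *, 4> Ops = {X, Y, Z};
  const SCEV *Sum = SE.getAddExpr(Ops, SCEV::FlagAnyWrap);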
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:623
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:654
TypeSize getSizeInBits() const
Definition DataLayout.h:634
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:297
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:295
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:198
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:294
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:301
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:169
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2248
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2253
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2258
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2263
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
class_match< const SCEV > m_SCEV()
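A hedged sketch of the SCEV pattern matchers listed above (S is an existing SCEV; the matchers are assumed to come from the llvm::SCEVPatternMatch namespace): test whether S is an affine AddRec whose start is a constant.

  const APInt *StartC;
  bool AffineWithConstStart =
      match(S, m_scev_AffineAddRec(m_scev_APInt(StartC), m_SCEV()));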
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
Definition MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:330
@ Offset
Definition DWP.cpp:477
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2060
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1727
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:738
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2138
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:252
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2055
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:682
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:157
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:95
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2130
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1734
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:420
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:288
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:71
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1956
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2032
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1869
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:1963
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:257
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1899
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:851
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:853
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:294
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:101
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:189
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:138
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:122
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:98
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.