1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
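// For example, in a simple counted loop such as
//
//   for (i = 0; i != n; ++i)
//     ...
//
// the induction variable i is represented by the polynomial recurrence
// {0,+,1}<%loop>, while a value the analysis cannot reason about (such as a
// loaded value) is simply wrapped in a SCEVUnknown.
//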
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155                            cl::desc("Maximum number of iterations SCEV will "
156                                     "symbolically execute a constant "
157                                     "derived loop"),
158                            cl::init(100));
159
160static cl::opt<bool, true> VerifySCEVOpt(
161    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163static cl::opt<bool> VerifySCEVStrict(
164    "verify-scev-strict", cl::Hidden,
165    cl::desc("Enable stricter verification when -verify-scev is passed"));
166
167static cl::opt<bool> VerifyIR(
168    "scev-verify-ir", cl::Hidden,
169    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170    cl::init(false));
171
172static cl::opt<unsigned> MulOpsInlineThreshold(
173    "scev-mulops-inline-threshold", cl::Hidden,
174    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175    cl::init(32));
176
177static cl::opt<unsigned> AddOpsInlineThreshold(
178    "scev-addops-inline-threshold", cl::Hidden,
179    cl::desc("Threshold for inlining addition operands into a SCEV"),
180    cl::init(500));
181
182static cl::opt<unsigned> MaxSCEVCompareDepth(
183    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185    cl::init(32));
186
187static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
188    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190    cl::init(2));
191
192static cl::opt<unsigned> MaxValueCompareDepth(
193    "scalar-evolution-max-value-compare-depth", cl::Hidden,
194    cl::desc("Maximum depth of recursive value complexity comparisons"),
195    cl::init(2));
196
197static cl::opt<unsigned>
198    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199                  cl::desc("Maximum depth of recursive arithmetics"),
200                  cl::init(32));
201
202static cl::opt<unsigned> MaxConstantEvolvingDepth(
203    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209                 cl::init(8));
210
211static cl::opt<unsigned>
212    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213                  cl::desc("Max coefficients in AddRec during evolving"),
214                  cl::init(8));
215
216static cl::opt<unsigned>
217    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218                      cl::desc("Size of the expression which is considered huge"),
219                      cl::init(4096));
220
221static cl::opt<unsigned> RangeIterThreshold(
222    "scev-range-iter-threshold", cl::Hidden,
223    cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224    cl::init(32));
225
226static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
227    "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228    cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
235static cl::opt<bool> UseExpensiveRangeSharpening(
236    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237    cl::init(false),
238    cl::desc("Use more powerful methods of sharpening expression ranges. May "
239             "be costly in terms of compile time"));
240
241static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
242    "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243    cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244             "Phi strongly connected components"),
245    cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
252static cl::opt<bool> UseContextForNoWrapFlagInference(
253    "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254    cl::desc("Infer nuw/nsw flags using context where suitable"),
255    cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296 const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303 const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
320 OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331 case scSequentialUMinExpr: {
332 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345 case scSequentialUMinExpr:
346 OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376 case scCouldNotCompute:
377 OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384 switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403 case scSequentialUMinExpr:
404 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411 case scCouldNotCompute:
412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418 switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435 case scSequentialUMinExpr:
436 return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439 case scCouldNotCompute:
440 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453 if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467 return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471 FoldingSetNodeID ID;
472 ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482 return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488 return getConstant(ConstantInt::get(ITy, V, isSigned));
489}
490
491const SCEV *ScalarEvolution::getVScale(Type *Ty) {
492 FoldingSetNodeID ID;
493 ID.AddInteger(scVScale);
494 ID.AddPointer(Ty);
495 void *IP = nullptr;
496 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
497 return S;
498 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
499 UniqueSCEVs.InsertNode(S, IP);
500 return S;
501}
502
503const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
504                                             SCEV::NoWrapFlags Flags) {
505 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
506 if (EC.isScalable())
507 Res = getMulExpr(Res, getVScale(Ty), Flags);
508 return Res;
509}
510
511SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
512                           const SCEV *op, Type *ty)
513 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
514
515SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
516 Type *ITy)
517 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
518 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
519 "Must be a non-bit-width-changing pointer-to-integer cast!");
520}
521
522SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
523                                           SCEVTypes SCEVTy, const SCEV *op,
524 Type *ty)
525 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
526
527SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
528 Type *ty)
529    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
530 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
531 "Cannot truncate non-integer value!");
532}
533
534SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
535 const SCEV *op, Type *ty)
536    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
537 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
538 "Cannot zero extend non-integer value!");
539}
540
541SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
542 const SCEV *op, Type *ty)
543    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot sign extend non-integer value!");
546}
547
548void SCEVUnknown::deleted() {
549 // Clear this SCEVUnknown from various maps.
550 SE->forgetMemoizedResults(this);
551
552 // Remove this SCEVUnknown from the uniquing map.
553 SE->UniqueSCEVs.RemoveNode(this);
554
555 // Release the value.
556 setValPtr(nullptr);
557}
558
559void SCEVUnknown::allUsesReplacedWith(Value *New) {
560 // Clear this SCEVUnknown from various maps.
561 SE->forgetMemoizedResults(this);
562
563 // Remove this SCEVUnknown from the uniquing map.
564 SE->UniqueSCEVs.RemoveNode(this);
565
566 // Replace the value pointer in case someone is still using this SCEVUnknown.
567 setValPtr(New);
568}
569
570//===----------------------------------------------------------------------===//
571// SCEV Utilities
572//===----------------------------------------------------------------------===//
573
574/// Compare the two values \p LV and \p RV in terms of their "complexity" where
575/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
576/// operands in SCEV expressions.
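/// For example, under this ordering an integer-typed value sorts before a
/// pointer-typed one, two function arguments are ordered by argument number,
/// and two instructions are compared first by loop depth and then by operand
/// count before their operands are compared recursively.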
577static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
578 Value *RV, unsigned Depth) {
579 if (Depth > MaxValueCompareDepth)
580 return 0;
581
582 // Order pointer values after integer values. This helps SCEVExpander form
583 // GEPs.
584 bool LIsPointer = LV->getType()->isPointerTy(),
585 RIsPointer = RV->getType()->isPointerTy();
586 if (LIsPointer != RIsPointer)
587 return (int)LIsPointer - (int)RIsPointer;
588
589 // Compare getValueID values.
590 unsigned LID = LV->getValueID(), RID = RV->getValueID();
591 if (LID != RID)
592 return (int)LID - (int)RID;
593
594 // Sort arguments by their position.
595 if (const auto *LA = dyn_cast<Argument>(LV)) {
596 const auto *RA = cast<Argument>(RV);
597 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
598 return (int)LArgNo - (int)RArgNo;
599 }
600
601 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
602 const auto *RGV = cast<GlobalValue>(RV);
603
604 if (auto L = LGV->getLinkage() - RGV->getLinkage())
605 return L;
606
607 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
608 auto LT = GV->getLinkage();
609 return !(GlobalValue::isPrivateLinkage(LT) ||
610              GlobalValue::isInternalLinkage(LT));
611 };
612
613 // Use the names to distinguish the two values, but only if the
614 // names are semantically important.
615 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
616 return LGV->getName().compare(RGV->getName());
617 }
618
619 // For instructions, compare their loop depth, and their operand count. This
620 // is pretty loose.
621 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
622 const auto *RInst = cast<Instruction>(RV);
623
624 // Compare loop depths.
625 const BasicBlock *LParent = LInst->getParent(),
626 *RParent = RInst->getParent();
627 if (LParent != RParent) {
628 unsigned LDepth = LI->getLoopDepth(LParent),
629 RDepth = LI->getLoopDepth(RParent);
630 if (LDepth != RDepth)
631 return (int)LDepth - (int)RDepth;
632 }
633
634 // Compare the number of operands.
635 unsigned LNumOps = LInst->getNumOperands(),
636 RNumOps = RInst->getNumOperands();
637 if (LNumOps != RNumOps)
638 return (int)LNumOps - (int)RNumOps;
639
640 for (unsigned Idx : seq(LNumOps)) {
641 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
642 RInst->getOperand(Idx), Depth + 1);
643 if (Result != 0)
644 return Result;
645 }
646 }
647
648 return 0;
649}
650
651// Return negative, zero, or positive, if LHS is less than, equal to, or greater
652// than RHS, respectively. A three-way result allows recursive comparisons to be
653// more efficient.
654// If the max analysis depth was reached, return std::nullopt, assuming we do
655// not know if they are equivalent for sure.
656static std::optional<int>
657CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (Depth > MaxSCEVCompareDepth)
669 return std::nullopt;
670
671 // Aside from the getSCEVType() ordering, the particular ordering
672 // isn't very important except that it's beneficial to be consistent,
673 // so that (a + b) and (b + a) don't end up as different expressions.
674 switch (LType) {
675 case scUnknown: {
676 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
677 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
678
679 int X =
680 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
681 return X;
682 }
683
684 case scConstant: {
685 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
686 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
687
688 // Compare constant values.
689 const APInt &LA = LC->getAPInt();
690 const APInt &RA = RC->getAPInt();
691 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
692 if (LBitWidth != RBitWidth)
693 return (int)LBitWidth - (int)RBitWidth;
694 return LA.ult(RA) ? -1 : 1;
695 }
696
697 case scVScale: {
698 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
699 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
700 return LTy->getBitWidth() - RTy->getBitWidth();
701 }
702
703 case scAddRecExpr: {
704 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
705 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
706
707 // There is always a dominance between two recs that are used by one SCEV,
708 // so we can safely sort recs by loop header dominance. We require such
709 // order in getAddExpr.
710 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
711 if (LLoop != RLoop) {
712 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
713 assert(LHead != RHead && "Two loops share the same header?");
714 if (DT.dominates(LHead, RHead))
715 return 1;
716 assert(DT.dominates(RHead, LHead) &&
717 "No dominance between recurrences used by one SCEV?");
718 return -1;
719 }
720
721 [[fallthrough]];
722 }
723
724 case scTruncate:
725 case scZeroExtend:
726 case scSignExtend:
727 case scPtrToInt:
728 case scAddExpr:
729 case scMulExpr:
730 case scUDivExpr:
731 case scSMaxExpr:
732 case scUMaxExpr:
733 case scSMinExpr:
734 case scUMinExpr:
735 case scSequentialUMinExpr: {
736 ArrayRef<const SCEV *> LOps = LHS->operands();
737 ArrayRef<const SCEV *> ROps = RHS->operands();
738
739 // Lexicographically compare n-ary-like expressions.
740 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
741 if (LNumOps != RNumOps)
742 return (int)LNumOps - (int)RNumOps;
743
744 for (unsigned i = 0; i != LNumOps; ++i) {
745 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
746 if (X != 0)
747 return X;
748 }
749 return 0;
750 }
751
752 case scCouldNotCompute:
753 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
754 }
755 llvm_unreachable("Unknown SCEV kind!");
756}
757
758/// Given a list of SCEV objects, order them by their complexity, and group
759/// objects of the same complexity together by value. When this routine is
760/// finished, we know that any duplicates in the vector are consecutive and that
761/// complexity is monotonically increasing.
762///
763/// Note that we take special precautions to ensure that we get deterministic
764/// results from this routine. In other words, we don't want the results of
765/// this to depend on where the addresses of various SCEV objects happened to
766/// land in memory.
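///
/// For instance, an operand list such as (%a, %b, %a) is rearranged so the two
/// occurrences of %a become adjacent; callers like getAddExpr can then combine
/// the duplicates (e.g. into 2 * %a) with a single linear scan.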
767static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
768                              LoopInfo *LI, DominatorTree &DT) {
769 if (Ops.size() < 2) return; // Noop
770
771 // Whether LHS has provably less complexity than RHS.
772 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
773 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
774 return Complexity && *Complexity < 0;
775 };
776 if (Ops.size() == 2) {
777 // This is the common case, which also happens to be trivially simple.
778 // Special case it.
779 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
780 if (IsLessComplex(RHS, LHS))
781 std::swap(LHS, RHS);
782 return;
783 }
784
785 // Do the rough sort by complexity.
786 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
787 return IsLessComplex(LHS, RHS);
788 });
789
790 // Now that we are sorted by complexity, group elements of the same
791 // complexity. Note that this is, at worst, N^2, but the vector is likely to
792 // be extremely short in practice. Note that we take this approach because we
793 // do not want to depend on the addresses of the objects we are grouping.
794 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
795 const SCEV *S = Ops[i];
796 unsigned Complexity = S->getSCEVType();
797
798 // If there are any objects of the same complexity and same value as this
799 // one, group them.
800 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
801 if (Ops[j] == S) { // Found a duplicate.
802 // Move it to immediately after i'th element.
803 std::swap(Ops[i+1], Ops[j]);
804 ++i; // no need to rescan it.
805 if (i == e-2) return; // Done!
806 }
807 }
808 }
809}
810
811/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
812/// least HugeExprThreshold nodes).
813static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
814 return any_of(Ops, [](const SCEV *S) {
815 return S->getExpressionSize() >= HugeExprThreshold;
816 });
817}
818
819/// Performs a number of common optimizations on the passed \p Ops. If the
820/// whole expression reduces down to a single operand, it will be returned.
821///
822/// The following optimizations are performed:
823/// * Fold constants using the \p Fold function.
824/// * Remove identity constants satisfying \p IsIdentity.
825/// * If a constant satisfies \p IsAbsorber, return it.
826/// * Sort operands by complexity.
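///
/// For example, when used from getAddExpr with operands (3, %x, 5), Fold
/// combines the constants into 8 and the result becomes (8 + %x); a folded
/// constant matching IsIdentity (0 for addition) is dropped entirely, and one
/// matching IsAbsorber (0 for multiplication) is returned immediately.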
827template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
828static const SCEV *
829constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
830                        SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
831                        IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
832 const SCEVConstant *Folded = nullptr;
833 for (unsigned Idx = 0; Idx < Ops.size();) {
834 const SCEV *Op = Ops[Idx];
835 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
836 if (!Folded)
837 Folded = C;
838 else
839 Folded = cast<SCEVConstant>(
840 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
841 Ops.erase(Ops.begin() + Idx);
842 continue;
843 }
844 ++Idx;
845 }
846
847 if (Ops.empty()) {
848 assert(Folded && "Must have folded value");
849 return Folded;
850 }
851
852 if (Folded && IsAbsorber(Folded->getAPInt()))
853 return Folded;
854
855 GroupByComplexity(Ops, &LI, DT);
856 if (Folded && !IsIdentity(Folded->getAPInt()))
857 Ops.insert(Ops.begin(), Folded);
858
859 return Ops.size() == 1 ? Ops[0] : nullptr;
860}
861
862//===----------------------------------------------------------------------===//
863// Simple SCEV method implementations
864//===----------------------------------------------------------------------===//
865
866/// Compute BC(It, K). The result has width W. Assume, K > 0.
867static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
868 ScalarEvolution &SE,
869 Type *ResultTy) {
870 // Handle the simplest case efficiently.
871 if (K == 1)
872 return SE.getTruncateOrZeroExtend(It, ResultTy);
873
874 // We are using the following formula for BC(It, K):
875 //
876 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
877 //
878 // Suppose, W is the bitwidth of the return value. We must be prepared for
879 // overflow. Hence, we must assure that the result of our computation is
880 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
881 // safe in modular arithmetic.
882 //
883 // However, this code doesn't use exactly that formula; the formula it uses
884 // is something like the following, where T is the number of factors of 2 in
885 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
886 // exponentiation:
887 //
888 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
889 //
890 // This formula is trivially equivalent to the previous formula. However,
891 // this formula can be implemented much more efficiently. The trick is that
892 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
893 // arithmetic. To do exact division in modular arithmetic, all we have
894 // to do is multiply by the inverse. Therefore, this step can be done at
895 // width W.
896 //
897 // The next issue is how to safely do the division by 2^T. The way this
898 // is done is by doing the multiplication step at a width of at least W + T
899 // bits. This way, the bottom W+T bits of the product are accurate. Then,
900 // when we perform the division by 2^T (which is equivalent to a right shift
901 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
902 // truncated out after the division by 2^T.
903 //
904 // In comparison to just directly using the first formula, this technique
905 // is much more efficient; using the first formula requires W * K bits,
906// but this formula needs less than W + K bits. Also, the first formula requires
907 // a division step, whereas this formula only requires multiplies and shifts.
908 //
909 // It doesn't matter whether the subtraction step is done in the calculation
910 // width or the input iteration count's width; if the subtraction overflows,
911 // the result must be zero anyway. We prefer here to do it in the width of
912 // the induction variable because it helps a lot for certain cases; CodeGen
913 // isn't smart enough to ignore the overflow, which leads to much less
914 // efficient code if the width of the subtraction is wider than the native
915 // register width.
916 //
917 // (It's possible to not widen at all by pulling out factors of 2 before
918 // the multiplication; for example, K=2 can be calculated as
919 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
920 // extra arithmetic, so it's not an obvious win, and it gets
921 // much more complicated for K > 3.)
922
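  // As a concrete illustration, for K = 3 and W = 32: K! = 6 = 2^1 * 3, so
  // T = 1 and the odd part of the factorial is 3. The product It*(It-1)*(It-2)
  // is formed at W + T = 33 bits, divided by 2^1, truncated back to 32 bits,
  // and then multiplied by the 32-bit multiplicative inverse of 3 to obtain
  // BC(It, 3) modulo 2^32.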
923 // Protection from insane SCEVs; this bound is conservative,
924 // but it probably doesn't matter.
925 if (K > 1000)
926 return SE.getCouldNotCompute();
927
928 unsigned W = SE.getTypeSizeInBits(ResultTy);
929
930 // Calculate K! / 2^T and T; we divide out the factors of two before
931 // multiplying for calculating K! / 2^T to avoid overflow.
932 // Other overflow doesn't matter because we only care about the bottom
933 // W bits of the result.
934 APInt OddFactorial(W, 1);
935 unsigned T = 1;
936 for (unsigned i = 3; i <= K; ++i) {
937 unsigned TwoFactors = countr_zero(i);
938 T += TwoFactors;
939 OddFactorial *= (i >> TwoFactors);
940 }
941
942 // We need at least W + T bits for the multiplication step
943 unsigned CalculationBits = W + T;
944
945 // Calculate 2^T, at width T+W.
946 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
947
948 // Calculate the multiplicative inverse of K! / 2^T;
949 // this multiplication factor will perform the exact division by
950 // K! / 2^T.
951 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
952
953 // Calculate the product, at width T+W
954 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
955 CalculationBits);
956 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
957 for (unsigned i = 1; i != K; ++i) {
958 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
959 Dividend = SE.getMulExpr(Dividend,
960 SE.getTruncateOrZeroExtend(S, CalculationTy));
961 }
962
963 // Divide by 2^T
964 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
965
966 // Truncate the result, and divide by K! / 2^T.
967
968 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
969 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
970}
971
972/// Return the value of this chain of recurrences at the specified iteration
973/// number. We can evaluate this recurrence by multiplying each element in the
974/// chain by the binomial coefficient corresponding to it. In other words, we
975/// can evaluate {A,+,B,+,C,+,D} as:
976///
977/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
978///
979/// where BC(It, k) stands for binomial coefficient.
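///
/// For instance, the recurrence {0,+,1,+,1} evaluates at iteration It to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2.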
980const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
981                                                ScalarEvolution &SE) const {
982 return evaluateAtIteration(operands(), It, SE);
983}
984
985const SCEV *
986SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
987                                    const SCEV *It, ScalarEvolution &SE) {
988 assert(Operands.size() > 0);
989 const SCEV *Result = Operands[0];
990 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
991 // The computation is correct in the face of overflow provided that the
992 // multiplication is performed _after_ the evaluation of the binomial
993 // coefficient.
994 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
995 if (isa<SCEVCouldNotCompute>(Coeff))
996 return Coeff;
997
998 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
999 }
1000 return Result;
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// SCEV Expression folder implementations
1005//===----------------------------------------------------------------------===//
1006
1007const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1008                                                      unsigned Depth) {
1009 assert(Depth <= 1 &&
1010 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1011
1012 // We could be called with an integer-typed operands during SCEV rewrites.
1013 // Since the operand is an integer already, just perform zext/trunc/self cast.
1014 if (!Op->getType()->isPointerTy())
1015 return Op;
1016
1017 // What would be an ID for such a SCEV cast expression?
1019 ID.AddInteger(scPtrToInt);
1020 ID.AddPointer(Op);
1021
1022 void *IP = nullptr;
1023
1024 // Is there already an expression for such a cast?
1025 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1026 return S;
1027
1028 // It isn't legal for optimizations to construct new ptrtoint expressions
1029 // for non-integral pointers.
1030 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1031 return getCouldNotCompute();
1032
1033 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1034
1035 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1036 // is sufficiently wide to represent all possible pointer values.
1037 // We could theoretically teach SCEV to truncate wider pointers, but
1038 // that isn't implemented for now.
1040 getDataLayout().getTypeSizeInBits(IntPtrTy))
1041 return getCouldNotCompute();
1042
1043 // If not, is this expression something we can't reduce any further?
1044 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1045 // Perform some basic constant folding. If the operand of the ptr2int cast
1046 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1047 // left as-is), but produce a zero constant.
1048 // NOTE: We could handle a more general case, but lack motivational cases.
1049 if (isa<ConstantPointerNull>(U->getValue()))
1050 return getZero(IntPtrTy);
1051
1052 // Create an explicit cast node.
1053 // We can reuse the existing insert position since if we get here,
1054 // we won't have made any changes which would invalidate it.
1055 SCEV *S = new (SCEVAllocator)
1056 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1057 UniqueSCEVs.InsertNode(S, IP);
1058 registerUser(S, Op);
1059 return S;
1060 }
1061
1062 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1063 "non-SCEVUnknown's.");
1064
1065 // Otherwise, we've got some expression that is more complex than just a
1066 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1067 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1068 // only, and the expressions must otherwise be integer-typed.
1069 // So sink the cast down to the SCEVUnknown's.
1070
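  // For example, a pointer-typed expression such as (%base + 4 * %i) is
  // rewritten to (ptrtoint %base) + 4 * %i, so the cast ends up on the lone
  // pointer-typed SCEVUnknown rather than on the whole addition.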
1071 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1072 /// which computes a pointer-typed value, and rewrites the whole expression
1073 /// tree so that *all* the computations are done on integers, and the only
1074 /// pointer-typed operands in the expression are SCEVUnknown.
1075 class SCEVPtrToIntSinkingRewriter
1076 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1078
1079 public:
1080 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1081
1082 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1083 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1084 return Rewriter.visit(Scev);
1085 }
1086
1087 const SCEV *visit(const SCEV *S) {
1088 Type *STy = S->getType();
1089 // If the expression is not pointer-typed, just keep it as-is.
1090 if (!STy->isPointerTy())
1091 return S;
1092 // Else, recursively sink the cast down into it.
1093 return Base::visit(S);
1094 }
1095
1096 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1098 bool Changed = false;
1099 for (const auto *Op : Expr->operands()) {
1100 Operands.push_back(visit(Op));
1101 Changed |= Op != Operands.back();
1102 }
1103 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1104 }
1105
1106 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1108 bool Changed = false;
1109 for (const auto *Op : Expr->operands()) {
1110 Operands.push_back(visit(Op));
1111 Changed |= Op != Operands.back();
1112 }
1113 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1114 }
1115
1116 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1117 assert(Expr->getType()->isPointerTy() &&
1118 "Should only reach pointer-typed SCEVUnknown's.");
1119 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1120 }
1121 };
1122
1123 // And actually perform the cast sinking.
1124 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1125 assert(IntOp->getType()->isIntegerTy() &&
1126 "We must have succeeded in sinking the cast, "
1127 "and ending up with an integer-typed expression!");
1128 return IntOp;
1129}
1130
1131const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1132 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1133
1134 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1135 if (isa<SCEVCouldNotCompute>(IntOp))
1136 return IntOp;
1137
1138 return getTruncateOrZeroExtend(IntOp, Ty);
1139}
1140
1141const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1142                                             unsigned Depth) {
1143 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1144 "This is not a truncating conversion!");
1145 assert(isSCEVable(Ty) &&
1146 "This is not a conversion to a SCEVable type!");
1147 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1148 Ty = getEffectiveSCEVType(Ty);
1149
1151 ID.AddInteger(scTruncate);
1152 ID.AddPointer(Op);
1153 ID.AddPointer(Ty);
1154 void *IP = nullptr;
1155 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1156
1157 // Fold if the operand is constant.
1158 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1159 return getConstant(
1160 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1161
1162 // trunc(trunc(x)) --> trunc(x)
1164 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1165
1166 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1168 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1169
1170 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1172 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1173
1174 if (Depth > MaxCastDepth) {
1175 SCEV *S =
1176 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1177 UniqueSCEVs.InsertNode(S, IP);
1178 registerUser(S, Op);
1179 return S;
1180 }
1181
1182 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1183 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1184 // if after transforming we have at most one truncate, not counting truncates
1185 // that replace other casts.
1187 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1189 unsigned numTruncs = 0;
1190 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1191 ++i) {
1192 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1193 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1195 numTruncs++;
1196 Operands.push_back(S);
1197 }
1198 if (numTruncs < 2) {
1199 if (isa<SCEVAddExpr>(Op))
1200 return getAddExpr(Operands);
1201 if (isa<SCEVMulExpr>(Op))
1202 return getMulExpr(Operands);
1203 llvm_unreachable("Unexpected SCEV type for Op.");
1204 }
1205 // Although we checked in the beginning that ID is not in the cache, it is
1206 // possible that during recursion and different modifications ID was inserted
1207 // into the cache. So if we find it, just return it.
1208 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1209 return S;
1210 }
1211
1212 // If the input value is a chrec scev, truncate the chrec's operands.
1213 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1215 for (const SCEV *Op : AddRec->operands())
1216 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1217 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1218 }
1219
1220 // Return zero if truncating to known zeros.
1221 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1222 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1223 return getZero(Ty);
1224
1225 // The cast wasn't folded; create an explicit cast node. We can reuse
1226 // the existing insert position since if we get here, we won't have
1227 // made any changes which would invalidate it.
1228 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1229 Op, Ty);
1230 UniqueSCEVs.InsertNode(S, IP);
1231 registerUser(S, Op);
1232 return S;
1233}
1234
1235// Get the limit of a recurrence such that incrementing by Step cannot cause
1236// signed overflow as long as the value of the recurrence within the
1237// loop does not exceed this limit before incrementing.
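// For example, for a known-positive Step whose maximum value is 1 in a 32-bit
// recurrence, the limit is SIGNED_MAX - 1 and *Pred is set to ICMP_SLT: while
// the recurrence value is signed-less-than that limit, adding Step cannot
// overflow.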
1238static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1239 ICmpInst::Predicate *Pred,
1240 ScalarEvolution *SE) {
1241 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1242 if (SE->isKnownPositive(Step)) {
1243 *Pred = ICmpInst::ICMP_SLT;
1244 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1245                        SE->getSignedRangeMax(Step));
1246 }
1247 if (SE->isKnownNegative(Step)) {
1248 *Pred = ICmpInst::ICMP_SGT;
1249 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1250                        SE->getSignedRangeMin(Step));
1251 }
1252 return nullptr;
1253}
1254
1255// Get the limit of a recurrence such that incrementing by Step cannot cause
1256// unsigned overflow as long as the value of the recurrence within the loop does
1257// not exceed this limit before incrementing.
1258static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1259                                                   ICmpInst::Predicate *Pred,
1260 ScalarEvolution *SE) {
1261 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1262 *Pred = ICmpInst::ICMP_ULT;
1263
1264 return SE->getConstant(APInt::getMinValue(BitWidth) -
1265                        SE->getUnsignedRangeMax(Step));
1266}
1267
1268namespace {
1269
1270struct ExtendOpTraitsBase {
1271 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1272 unsigned);
1273};
1274
1275// Used to make code generic over signed and unsigned overflow.
1276template <typename ExtendOp> struct ExtendOpTraits {
1277 // Members present:
1278 //
1279 // static const SCEV::NoWrapFlags WrapType;
1280 //
1281 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1282 //
1283 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1284 // ICmpInst::Predicate *Pred,
1285 // ScalarEvolution *SE);
1286};
1287
1288template <>
1289struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1290 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1291
1292 static const GetExtendExprTy GetExtendExpr;
1293
1294 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1295 ICmpInst::Predicate *Pred,
1296 ScalarEvolution *SE) {
1297 return getSignedOverflowLimitForStep(Step, Pred, SE);
1298 }
1299};
1300
1301const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1303
1304template <>
1305struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1306 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1307
1308 static const GetExtendExprTy GetExtendExpr;
1309
1310 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1311 ICmpInst::Predicate *Pred,
1312 ScalarEvolution *SE) {
1313 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1314 }
1315};
1316
1317const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1319
1320} // end anonymous namespace
1321
1322// The recurrence AR has been shown to have no signed/unsigned wrap or something
1323// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1324// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1325// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1326// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1327// expression "Step + sext/zext(PreIncAR)" is congruent with
1328// "sext/zext(PostIncAR)"
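//
// For example, given AR = {%x + %step,+,%step}, PreStart is %x, and when the
// rewrite can be justified by one of the checks below (e.g. a no-wrap
// pre-increment recurrence {%x,+,%step} together with a positive
// backedge-taken count), sext(AR) can be built as
// {sext(%x) + sext(%step),+,sext(%step)} instead of sign-extending the whole
// start expression as an opaque unit.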
1329template <typename ExtendOpTy>
1330static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1331 ScalarEvolution *SE, unsigned Depth) {
1332 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1333 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1334
1335 const Loop *L = AR->getLoop();
1336 const SCEV *Start = AR->getStart();
1337 const SCEV *Step = AR->getStepRecurrence(*SE);
1338
1339 // Check for a simple looking step prior to loop entry.
1340 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1341 if (!SA)
1342 return nullptr;
1343
1344 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1345 // subtraction is expensive. For this purpose, perform a quick and dirty
1346 // difference, by checking for Step in the operand list. Note that
1347 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1349 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1350 if (*It == Step) {
1351 DiffOps.erase(It);
1352 break;
1353 }
1354
1355 if (DiffOps.size() == SA->getNumOperands())
1356 return nullptr;
1357
1358 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1359 // `Step`:
1360
1361 // 1. NSW/NUW flags on the step increment.
1362 auto PreStartFlags =
1363      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1364 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1365 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1366      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1367
1368 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1369 // "S+X does not sign/unsign-overflow".
1370 //
1371
1372 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1373 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1374 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1375 return PreStart;
1376
1377 // 2. Direct overflow check on the step operation's expression.
1378 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1379 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1380 const SCEV *OperandExtendedStart =
1381 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1382 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1383 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1384 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1385 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1386 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1387 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1388 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1389 }
1390 return PreStart;
1391 }
1392
1393 // 3. Loop precondition.
1394 ICmpInst::Predicate Pred;
1395 const SCEV *OverflowLimit =
1396 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1397
1398 if (OverflowLimit &&
1399 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1400 return PreStart;
1401
1402 return nullptr;
1403}
1404
1405// Get the normalized zero or sign extended expression for this AddRec's Start.
1406template <typename ExtendOpTy>
1407static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1408 ScalarEvolution *SE,
1409 unsigned Depth) {
1410 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1411
1412 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1413 if (!PreStart)
1414 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1415
1416 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1417 Depth),
1418 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1419}
1420
1421// Try to prove away overflow by looking at "nearby" add recurrences. A
1422// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1423// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1424//
1425// Formally:
1426//
1427// {S,+,X} == {S-T,+,X} + T
1428// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1429//
1430// If ({S-T,+,X} + T) does not overflow ... (1)
1431//
1432// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1433//
1434// If {S-T,+,X} does not overflow ... (2)
1435//
1436// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1437// == {Ext(S-T)+Ext(T),+,Ext(X)}
1438//
1439// If (S-T)+T does not overflow ... (3)
1440//
1441// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1442// == {Ext(S),+,Ext(X)} == LHS
1443//
1444// Thus, if (1), (2) and (3) are true for some T, then
1445// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1446//
1447// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1448// does not overflow" restricted to the 0th iteration. Therefore we only need
1449// to check for (1) and (2).
1450//
1451// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1452// is `Delta` (defined below).
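//
// Concretely, with S = 1, X = 4 and T = 1: {1,+,4} == {0,+,4} + 1, so if
// {0,+,4} is already known not to wrap (2) and {0,+,4} + 1 cannot overflow (1),
// then zext({1,+,4}) == {zext(1),+,zext(4)}.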
1453template <typename ExtendOpTy>
1454bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1455 const SCEV *Step,
1456 const Loop *L) {
1457 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1458
1459 // We restrict `Start` to a constant to prevent SCEV from spending too much
1460 // time here. It is correct (but more expensive) to continue with a
1461 // non-constant `Start` and do a general SCEV subtraction to compute
1462 // `PreStart` below.
1463 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1464 if (!StartC)
1465 return false;
1466
1467 APInt StartAI = StartC->getAPInt();
1468
1469 for (unsigned Delta : {-2, -1, 1, 2}) {
1470 const SCEV *PreStart = getConstant(StartAI - Delta);
1471
1472 FoldingSetNodeID ID;
1473 ID.AddInteger(scAddRecExpr);
1474 ID.AddPointer(PreStart);
1475 ID.AddPointer(Step);
1476 ID.AddPointer(L);
1477 void *IP = nullptr;
1478 const auto *PreAR =
1479 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1480
1481 // Give up if we don't already have the add recurrence we need because
1482 // actually constructing an add recurrence is relatively expensive.
1483 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1484 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1486 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1487 DeltaS, &Pred, this);
1488 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1489 return true;
1490 }
1491 }
1492
1493 return false;
1494}
1495
1496// Finds an integer D for an expression (C + x + y + ...) such that the top
1497// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1498// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1499// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1500// the (C + x + y + ...) expression is \p WholeAddExpr.
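//
// For example, with C = 21 (0b10101) and remaining operands that are all
// multiples of 8, the trailing-zero count TZ is 3, so D = 21 mod 8 = 5; since
// D fits entirely in the low zero bits of (C - D + x + y + ...), the top-level
// addition of D cannot wrap.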
1501static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1502                                            const SCEVConstant *ConstantTerm,
1503 const SCEVAddExpr *WholeAddExpr) {
1504 const APInt &C = ConstantTerm->getAPInt();
1505 const unsigned BitWidth = C.getBitWidth();
1506 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1507 uint32_t TZ = BitWidth;
1508 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1509 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1510 if (TZ) {
1511 // Set D to be as many least significant bits of C as possible while still
1512 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1513 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1514 }
1515 return APInt(BitWidth, 0);
1516}
1517
1518// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1519// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1520// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1521// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1523                                            const APInt &ConstantStart,
1524 const SCEV *Step) {
1525 const unsigned BitWidth = ConstantStart.getBitWidth();
1526 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1527 if (TZ)
1528 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1529 : ConstantStart;
1530 return APInt(BitWidth, 0);
1531}
1532
1533static void insertFoldCacheEntry(
1534    const ScalarEvolution::FoldID &ID, const SCEV *S,
1535    DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1536    DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1537        &FoldCacheUser) {
1538 auto I = FoldCache.insert({ID, S});
1539 if (!I.second) {
1540 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1541 // entry.
1542 auto &UserIDs = FoldCacheUser[I.first->second];
1543 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1544 for (unsigned I = 0; I != UserIDs.size(); ++I)
1545 if (UserIDs[I] == ID) {
1546 std::swap(UserIDs[I], UserIDs.back());
1547 break;
1548 }
1549 UserIDs.pop_back();
1550 I.first->second = S;
1551 }
1552 FoldCacheUser[S].push_back(ID);
1553}
1554
1555const SCEV *
1556ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 if (const SCEV *S = FoldCache.lookup(ID))
1566 return S;
1567
1568 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1569 if (!isa<SCEVCouldNotCompute>(S))
1570 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1571 return S;
1572}
1573
1574const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1575                                                   unsigned Depth) {
1576 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1577 "This is not an extending conversion!");
1578 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1579 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1580
1581 // Fold if the operand is constant.
1582 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1583 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1584
1585 // zext(zext(x)) --> zext(x)
1586 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1587 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1588
1589 // Before doing any expensive analysis, check to see if we've already
1590 // computed a SCEV for this Op and Ty.
1592 ID.AddInteger(scZeroExtend);
1593 ID.AddPointer(Op);
1594 ID.AddPointer(Ty);
1595 void *IP = nullptr;
1596 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1597 if (Depth > MaxCastDepth) {
1598 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1599 Op, Ty);
1600 UniqueSCEVs.InsertNode(S, IP);
1601 registerUser(S, Op);
1602 return S;
1603 }
1604
1605 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1607 // It's possible the bits taken off by the truncate were all zero bits. If
1608 // so, we should be able to simplify this further.
1609 const SCEV *X = ST->getOperand();
1611 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1612 unsigned NewBits = getTypeSizeInBits(Ty);
1613 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1614 CR.zextOrTrunc(NewBits)))
1615 return getTruncateOrZeroExtend(X, Ty, Depth);
1616 }
1617
1618 // If the input value is a chrec scev, and we can prove that the value
1619 // did not overflow the old, smaller, value, we can zero extend all of the
1620 // operands (often constants). This allows analysis of something like
1621 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1623 if (AR->isAffine()) {
1624 const SCEV *Start = AR->getStart();
1625 const SCEV *Step = AR->getStepRecurrence(*this);
1626 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1627 const Loop *L = AR->getLoop();
1628
1629 // If we have special knowledge that this addrec won't overflow,
1630 // we don't need to do any further analysis.
1631 if (AR->hasNoUnsignedWrap()) {
1632 Start =
1634 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1635 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1636 }
1637
1638 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1639 // Note that this serves two purposes: It filters out loops that are
1640 // simply not analyzable, and it covers the case where this code is
1641 // being called from within backedge-taken count analysis, such that
1642 // attempting to ask for the backedge-taken count would likely result
1643 // in infinite recursion. In the later case, the analysis code will
1644 // in infinite recursion. In the latter case, the analysis code will
1645 // that value once it has finished.
1646 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1647 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1648 // Manually compute the final value for AR, checking for overflow.
1649
1650 // Check whether the backedge-taken count can be losslessly casted to
1651 // the addrec's type. The count is always unsigned.
1652 const SCEV *CastedMaxBECount =
1653 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1654 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1655 CastedMaxBECount, MaxBECount->getType(), Depth);
1656 if (MaxBECount == RecastedMaxBECount) {
1657 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1658 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1659 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1660 SCEV::FlagAnyWrap, Depth + 1);
1661 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1662 SCEV::FlagAnyWrap,
1663 Depth + 1),
1664 WideTy, Depth + 1);
1665 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1666 const SCEV *WideMaxBECount =
1667 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1668 const SCEV *OperandExtendedAdd =
1669 getAddExpr(WideStart,
1670 getMulExpr(WideMaxBECount,
1671 getZeroExtendExpr(Step, WideTy, Depth + 1),
1672 SCEV::FlagAnyWrap, Depth + 1),
1673 SCEV::FlagAnyWrap, Depth + 1);
1674 if (ZAdd == OperandExtendedAdd) {
1675 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1676 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1677 // Return the expression with the addrec on the outside.
1678 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1679 Depth + 1);
1680 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1681 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1682 }
1683 // Similar to above, only this time treat the step value as signed.
1684 // This covers loops that count down.
1685 OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getSignExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NW, which is propagated to this AddRec.
1693 // Negative step causes unsigned wrap, but it still can't self-wrap.
1694 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1695 // Return the expression with the addrec on the outside.
1696 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1697 Depth + 1);
1698 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1699 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1700 }
1701 }
1702 }
1703
1704 // Normally, in the cases we can prove no-overflow via a
1705 // backedge guarding condition, we can also compute a backedge
1706 // taken count for the loop. The exceptions are assumptions and
1707 // guards present in the loop -- SCEV is not great at exploiting
1708 // these to compute max backedge taken counts, but can still use
1709 // these to prove lack of overflow. Use this fact to avoid
1710 // doing extra work that may not pay off.
1711 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1712 !AC.assumptions().empty()) {
1713
1714 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1715 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1716 if (AR->hasNoUnsignedWrap()) {
1717 // Same as nuw case above - duplicated here to avoid a compile time
1718 // issue. It's not clear that the order of checks matters, but
1719 // it's one of two possible causes for a change which was
1720 // reverted. Be conservative for the moment.
1721 Start =
1722 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1723 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1724 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1725 }
1726
1727 // For a negative step, we can extend the operands iff doing so only
1728 // traverses values in the range zext([0,UINT_MAX]).
1729 if (isKnownNegative(Step)) {
1730 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731 getSignedRangeMin(Step));
1732 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734 // Cache knowledge of AR NW, which is propagated to this
1735 // AddRec. Negative step causes unsigned wrap, but it
1736 // still can't self-wrap.
1737 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738 // Return the expression with the addrec on the outside.
1739 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1740 Depth + 1);
1741 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1742 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1743 }
1744 }
1745 }
1746
1747 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1748 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1749 // where D maximizes the number of trailing zeros of (C - D + Step * n)
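// For instance (illustrative only): with C = 7 and Step = 4, a split such as
// D = 3 leaves the residual {4,+,4}, whose operands share two trailing zero
// bits, so zext({7,+,4}) may become zext(3) + zext({4,+,4}) when the no-wrap
// requirement above can be shown.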
1750 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1751 const APInt &C = SC->getAPInt();
1752 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1753 if (D != 0) {
1754 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1755 const SCEV *SResidual =
1756 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1757 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1758 return getAddExpr(SZExtD, SZExtR,
1759 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1760 Depth + 1);
1761 }
1762 }
1763
1764 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1765 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1766 Start =
1767 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1768 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1769 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1770 }
1771 }
1772
1773 // zext(A % B) --> zext(A) % zext(B)
1774 {
1775 const SCEV *LHS;
1776 const SCEV *RHS;
1777 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1778 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779 getZeroExtendExpr(RHS, Ty, Depth + 1));
1780 }
1781
1782 // zext(A / B) --> zext(A) / zext(B).
1783 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786
1787 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789 if (SA->hasNoUnsignedWrap()) {
1790 // If the addition does not unsign overflow then we can, by definition,
1791 // commute the zero extension with the addition operation.
1792 SmallVector<const SCEV *, 4> Ops;
1793 for (const auto *Op : SA->operands())
1794 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796 }
1797
1798 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801 //
1802 // Often address arithmetics contain expressions like
1803 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804 // This transformation is useful while proving that such expressions are
1805 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
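// For instance (illustrative only): zext(5 + (4 * X)) could be rewritten as
// 1 + zext(4 + (4 * X)), so two addresses differing by the small constant 1
// keep structurally similar SCEVs.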
1806 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808 if (D != 0) {
1809 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810 const SCEV *SResidual =
1811 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1812 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813 return getAddExpr(SZExtD, SZExtR,
1814 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815 Depth + 1);
1816 }
1817 }
1818 }
1819
1820 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822 if (SM->hasNoUnsignedWrap()) {
1823 // If the multiply does not unsign overflow then we can, by definition,
1824 // commute the zero extension with the multiply operation.
1825 SmallVector<const SCEV *, 4> Ops;
1826 for (const auto *Op : SM->operands())
1827 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829 }
1830
1831 // zext(2^K * (trunc X to iN)) to iM ->
1832 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833 //
1834 // Proof:
1835 //
1836 // zext(2^K * (trunc X to iN)) to iM
1837 // = zext((trunc X to iN) << K) to iM
1838 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839 // (because shl removes the top K bits)
1840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842 //
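// A concrete instance (illustrative only): zext(4 * (trunc X to i8)) to i32
// can be treated as 4 * (zext(trunc X to i6) to i32)<nuw>, since a 6-bit
// value shifted left by 2 cannot overflow i8.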
1843 const APInt *C;
1844 const SCEV *TruncRHS;
1845 if (match(SM,
1846 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1847 C->isPowerOf2()) {
1848 int NewTruncBits =
1849 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1850 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1851 return getMulExpr(
1852 getZeroExtendExpr(SM->getOperand(0), Ty),
1853 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1854 SCEV::FlagNUW, Depth + 1);
1855 }
1856 }
1857
1858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1859 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1860 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1861 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1862 SmallVector<const SCEV *, 4> Operands;
1863 for (auto *Operand : MinMax->operands())
1864 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1865 if (isa<SCEVUMinExpr>(Op))
1866 return getUMinExpr(Operands);
1867 return getUMaxExpr(Operands);
1868 }
1869
1870 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1871 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1872 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1873 SmallVector<const SCEV *, 4> Operands;
1874 for (auto *Operand : MinMax->operands())
1875 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1876 return getUMinExpr(Operands, /*Sequential*/ true);
1877 }
1878
1879 // The cast wasn't folded; create an explicit cast node.
1880 // Recompute the insert position, as it may have been invalidated.
1881 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1882 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1883 Op, Ty);
1884 UniqueSCEVs.InsertNode(S, IP);
1885 registerUser(S, Op);
1886 return S;
1887}
1888
1889const SCEV *
1890 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1891 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1892 "This is not an extending conversion!");
1893 assert(isSCEVable(Ty) &&
1894 "This is not a conversion to a SCEVable type!");
1895 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1896 Ty = getEffectiveSCEVType(Ty);
1897
1898 FoldID ID(scSignExtend, Op, Ty);
1899 if (const SCEV *S = FoldCache.lookup(ID))
1900 return S;
1901
1902 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1903 if (!isa<SCEVSignExtendExpr>(S))
1904 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1905 return S;
1906}
1907
1908 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1909 unsigned Depth) {
1910 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1911 "This is not an extending conversion!");
1912 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1913 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1914 Ty = getEffectiveSCEVType(Ty);
1915
1916 // Fold if the operand is constant.
1917 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1918 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1919
1920 // sext(sext(x)) --> sext(x)
1921 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923
1924 // sext(zext(x)) --> zext(x)
1925 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927
1928 // Before doing any expensive analysis, check to see if we've already
1929 // computed a SCEV for this Op and Ty.
1930 FoldingSetNodeID ID;
1931 ID.AddInteger(scSignExtend);
1932 ID.AddPointer(Op);
1933 ID.AddPointer(Ty);
1934 void *IP = nullptr;
1935 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936 // Limit recursion depth.
1937 if (Depth > MaxCastDepth) {
1938 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939 Op, Ty);
1940 UniqueSCEVs.InsertNode(S, IP);
1941 registerUser(S, Op);
1942 return S;
1943 }
1944
1945 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947 // It's possible the bits taken off by the truncate were all sign bits. If
1948 // so, we should be able to simplify this further.
1949 const SCEV *X = ST->getOperand();
1950 ConstantRange CR = getSignedRange(X);
1951 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952 unsigned NewBits = getTypeSizeInBits(Ty);
1953 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954 CR.sextOrTrunc(NewBits)))
1955 return getTruncateOrSignExtend(X, Ty, Depth);
1956 }
1957
1958 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960 if (SA->hasNoSignedWrap()) {
1961 // If the addition does not sign overflow then we can, by definition,
1962 // commute the sign extension with the addition operation.
1963 SmallVector<const SCEV *, 4> Ops;
1964 for (const auto *Op : SA->operands())
1965 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967 }
1968
1969 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972 //
1973 // For instance, this will bring two seemingly different expressions:
1974 // 1 + sext(5 + 20 * %x + 24 * %y) and
1975 // sext(6 + 20 * %x + 24 * %y)
1976 // to the same form:
1977 // 2 + sext(4 + 20 * %x + 24 * %y)
1978 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980 if (D != 0) {
1981 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982 const SCEV *SResidual =
1983 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1984 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985 return getAddExpr(SSExtD, SSExtR,
1986 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987 Depth + 1);
1988 }
1989 }
1990 }
1991 // If the input value is a chrec scev, and we can prove that the value
1992 // did not overflow the old, smaller, value, we can sign extend all of the
1993 // operands (often constants). This allows analysis of something like
1994 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996 if (AR->isAffine()) {
1997 const SCEV *Start = AR->getStart();
1998 const SCEV *Step = AR->getStepRecurrence(*this);
1999 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000 const Loop *L = AR->getLoop();
2001
2002 // If we have special knowledge that this addrec won't overflow,
2003 // we don't need to do any further analysis.
2004 if (AR->hasNoSignedWrap()) {
2005 Start =
2006 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2007 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2008 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2009 }
2010
2011 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2012 // Note that this serves two purposes: It filters out loops that are
2013 // simply not analyzable, and it covers the case where this code is
2014 // being called from within backedge-taken count analysis, such that
2015 // attempting to ask for the backedge-taken count would likely result
2016 in infinite recursion. In the latter case, the analysis code will
2017 // cope with a conservative value, and it will take care to purge
2018 // that value once it has finished.
2019 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2020 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2021 // Manually compute the final value for AR, checking for
2022 // overflow.
2023
2024 // Check whether the backedge-taken count can be losslessly casted to
2025 // the addrec's type. The count is always unsigned.
2026 const SCEV *CastedMaxBECount =
2027 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2028 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2029 CastedMaxBECount, MaxBECount->getType(), Depth);
2030 if (MaxBECount == RecastedMaxBECount) {
2031 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2032 // Check whether Start+Step*MaxBECount has no signed overflow.
2033 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2034 SCEV::FlagAnyWrap, Depth + 1);
2035 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2036 SCEV::FlagAnyWrap,
2037 Depth + 1),
2038 WideTy, Depth + 1);
2039 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2040 const SCEV *WideMaxBECount =
2041 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2042 const SCEV *OperandExtendedAdd =
2043 getAddExpr(WideStart,
2044 getMulExpr(WideMaxBECount,
2045 getSignExtendExpr(Step, WideTy, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1),
2047 SCEV::FlagAnyWrap, Depth + 1);
2048 if (SAdd == OperandExtendedAdd) {
2049 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2050 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2051 // Return the expression with the addrec on the outside.
2052 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2053 Depth + 1);
2054 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2055 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2056 }
2057 // Similar to above, only this time treat the step value as unsigned.
2058 // This covers loops that count up with an unsigned step.
2059 OperandExtendedAdd =
2060 getAddExpr(WideStart,
2061 getMulExpr(WideMaxBECount,
2062 getZeroExtendExpr(Step, WideTy, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1);
2065 if (SAdd == OperandExtendedAdd) {
2066 // If AR wraps around then
2067 //
2068 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2069 // => SAdd != OperandExtendedAdd
2070 //
2071 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2072 // (SAdd == OperandExtendedAdd => AR is NW)
2073
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2075
2076 // Return the expression with the addrec on the outside.
2077 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2078 Depth + 1);
2079 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2080 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2081 }
2082 }
2083 }
2084
2085 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2086 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2087 if (AR->hasNoSignedWrap()) {
2088 // Same as nsw case above - duplicated here to avoid a compile time
2089 // issue. It's not clear that the order of checks matters, but
2090 // it's one of two possible causes for a change which was
2091 // reverted. Be conservative for the moment.
2092 Start =
2093 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2094 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2095 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2096 }
2097
2098 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2099 // if D + (C - D + Step * n) could be proven to not signed wrap
2100 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2101 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2102 const APInt &C = SC->getAPInt();
2103 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2104 if (D != 0) {
2105 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2106 const SCEV *SResidual =
2107 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2108 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2109 return getAddExpr(SSExtD, SSExtR,
2110 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2111 Depth + 1);
2112 }
2113 }
2114
2115 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2116 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2117 Start =
2118 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2119 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2120 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2121 }
2122 }
2123
2124 // If the input value is provably positive and we could not simplify
2125 // away the sext build a zext instead.
2126 if (isKnownNonNegative(Op))
2127 return getZeroExtendExpr(Op, Ty, Depth + 1);
2128
2129 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2130 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2131 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2132 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2133 SmallVector<const SCEV *, 4> Operands;
2134 for (auto *Operand : MinMax->operands())
2135 Operands.push_back(getSignExtendExpr(Operand, Ty));
2136 if (isa<SCEVSMinExpr>(Op))
2137 return getSMinExpr(Operands);
2138 return getSMaxExpr(Operands);
2139 }
2140
2141 // The cast wasn't folded; create an explicit cast node.
2142 // Recompute the insert position, as it may have been invalidated.
2143 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2144 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2145 Op, Ty);
2146 UniqueSCEVs.InsertNode(S, IP);
2147 registerUser(S, { Op });
2148 return S;
2149}
2150
2151 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2152 Type *Ty) {
2153 switch (Kind) {
2154 case scTruncate:
2155 return getTruncateExpr(Op, Ty);
2156 case scZeroExtend:
2157 return getZeroExtendExpr(Op, Ty);
2158 case scSignExtend:
2159 return getSignExtendExpr(Op, Ty);
2160 case scPtrToInt:
2161 return getPtrToIntExpr(Op, Ty);
2162 default:
2163 llvm_unreachable("Not a SCEV cast expression!");
2164 }
2165}
2166
2167/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2168/// unspecified bits out to the given type.
2169 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2170 Type *Ty) {
2171 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2172 "This is not an extending conversion!");
2173 assert(isSCEVable(Ty) &&
2174 "This is not a conversion to a SCEVable type!");
2175 Ty = getEffectiveSCEVType(Ty);
2176
2177 // Sign-extend negative constants.
2178 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2179 if (SC->getAPInt().isNegative())
2180 return getSignExtendExpr(Op, Ty);
2181
2182 // Peel off a truncate cast.
2183 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2184 const SCEV *NewOp = T->getOperand();
2185 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2186 return getAnyExtendExpr(NewOp, Ty);
2187 return getTruncateOrNoop(NewOp, Ty);
2188 }
2189
2190 // Next try a zext cast. If the cast is folded, use it.
2191 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2192 if (!isa<SCEVZeroExtendExpr>(ZExt))
2193 return ZExt;
2194
2195 // Next try a sext cast. If the cast is folded, use it.
2196 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2197 if (!isa<SCEVSignExtendExpr>(SExt))
2198 return SExt;
2199
2200 // Force the cast to be folded into the operands of an addrec.
2201 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2202 SmallVector<const SCEV *, 4> Ops;
2203 for (const SCEV *Op : AR->operands())
2204 Ops.push_back(getAnyExtendExpr(Op, Ty));
2205 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2206 }
2207
2208 // If the expression is obviously signed, use the sext cast value.
2209 if (isa<SCEVSMaxExpr>(Op))
2210 return SExt;
2211
2212 // Absent any other information, use the zext cast value.
2213 return ZExt;
2214}
2215
2216/// Process the given Ops list, which is a list of operands to be added under
2217/// the given scale, update the given map. This is a helper function for
2218/// getAddRecExpr. As an example of what it does, given a sequence of operands
2219/// that would form an add expression like this:
2220///
2221/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2222///
2223/// where A and B are constants, update the map with these values:
2224///
2225/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2226///
2227/// and add 13 + A*B*29 to AccumulatedConstant.
2228/// This will allow getAddRecExpr to produce this:
2229///
2230/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2231///
2232/// This form often exposes folding opportunities that are hidden in
2233/// the original operand list.
2234///
2235/// Return true iff it appears that any interesting folding opportunities
2236/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2237/// the common case where no interesting opportunities are present, and
2238/// is also used as a check to avoid infinite recursion.
2239static bool
2240 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2241 SmallVectorImpl<const SCEV *> &NewOps,
2242 APInt &AccumulatedConstant,
2243 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2244 ScalarEvolution &SE) {
2245 bool Interesting = false;
2246
2247 // Iterate over the add operands. They are sorted, with constants first.
2248 unsigned i = 0;
2249 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2250 ++i;
2251 // Pull a buried constant out to the outside.
2252 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2253 Interesting = true;
2254 AccumulatedConstant += Scale * C->getAPInt();
2255 }
2256
2257 // Next comes everything else. We're especially interested in multiplies
2258 // here, but they're in the middle, so just visit the rest with one loop.
2259 for (; i != Ops.size(); ++i) {
2260 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2261 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2262 APInt NewScale =
2263 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2264 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2265 // A multiplication of a constant with another add; recurse.
2266 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2267 Interesting |=
2268 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2269 Add->operands(), NewScale, SE);
2270 } else {
2271 // A multiplication of a constant with some other value. Update
2272 // the map.
2273 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2274 const SCEV *Key = SE.getMulExpr(MulOps);
2275 auto Pair = M.insert({Key, NewScale});
2276 if (Pair.second) {
2277 NewOps.push_back(Pair.first->first);
2278 } else {
2279 Pair.first->second += NewScale;
2280 // The map already had an entry for this value, which may indicate
2281 // a folding opportunity.
2282 Interesting = true;
2283 }
2284 }
2285 } else {
2286 // An ordinary operand. Update the map.
2287 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2288 M.insert({Ops[i], Scale});
2289 if (Pair.second) {
2290 NewOps.push_back(Pair.first->first);
2291 } else {
2292 Pair.first->second += Scale;
2293 // The map already had an entry for this value, which may indicate
2294 // a folding opportunity.
2295 Interesting = true;
2296 }
2297 }
2298 }
2299
2300 return Interesting;
2301}
2302
2303 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2304 const SCEV *LHS, const SCEV *RHS,
2305 const Instruction *CtxI) {
2306 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2307 SCEV::NoWrapFlags, unsigned);
2308 switch (BinOp) {
2309 default:
2310 llvm_unreachable("Unsupported binary op");
2311 case Instruction::Add:
2312 Operation = &ScalarEvolution::getAddExpr;
2313 break;
2314 case Instruction::Sub:
2315 Operation = &ScalarEvolution::getMinusSCEV;
2316 break;
2317 case Instruction::Mul:
2318 Operation = &ScalarEvolution::getMulExpr;
2319 break;
2320 }
2321
2322 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2323 Signed ? &ScalarEvolution::getSignExtendExpr
2324 : &ScalarEvolution::getZeroExtendExpr;
2325
2326 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2327 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2328 auto *WideTy =
2329 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2330
2331 const SCEV *A = (this->*Extension)(
2332 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2333 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2334 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2335 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2336 if (A == B)
2337 return true;
2338 // Can we use context to prove the fact we need?
2339 if (!CtxI)
2340 return false;
2341 // TODO: Support mul.
2342 if (BinOp == Instruction::Mul)
2343 return false;
2344 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2345 // TODO: Lift this limitation.
2346 if (!RHSC)
2347 return false;
2348 APInt C = RHSC->getAPInt();
2349 unsigned NumBits = C.getBitWidth();
2350 bool IsSub = (BinOp == Instruction::Sub);
2351 bool IsNegativeConst = (Signed && C.isNegative());
2352 // Compute the direction and magnitude by which we need to check overflow.
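// For example (illustrative only): deducing nuw for an unsigned subtraction
// of a positive constant C amounts to showing MIN + C <= LHS at the context
// instruction, which is what the checks below perform.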
2353 bool OverflowDown = IsSub ^ IsNegativeConst;
2354 APInt Magnitude = C;
2355 if (IsNegativeConst) {
2356 if (C == APInt::getSignedMinValue(NumBits))
2357 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2358 // want to deal with that.
2359 return false;
2360 Magnitude = -C;
2361 }
2362
2363 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2364 if (OverflowDown) {
2365 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2366 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2367 : APInt::getMinValue(NumBits);
2368 APInt Limit = Min + Magnitude;
2369 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2370 } else {
2371 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2372 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2373 : APInt::getMaxValue(NumBits);
2374 APInt Limit = Max - Magnitude;
2375 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2376 }
2377}
2378
2379std::optional<SCEV::NoWrapFlags>
2380 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2381 const OverflowingBinaryOperator *OBO) {
2382 // It cannot be done any better.
2383 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2384 return std::nullopt;
2385
2386 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2387
2388 if (OBO->hasNoUnsignedWrap())
2389 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2390 if (OBO->hasNoSignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2392
2393 bool Deduced = false;
2394
2395 if (OBO->getOpcode() != Instruction::Add &&
2396 OBO->getOpcode() != Instruction::Sub &&
2397 OBO->getOpcode() != Instruction::Mul)
2398 return std::nullopt;
2399
2400 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2401 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2402
2403 const Instruction *CtxI =
2404 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2405 if (!OBO->hasNoUnsignedWrap() &&
2406 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2407 /* Signed */ false, LHS, RHS, CtxI)) {
2408 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2409 Deduced = true;
2410 }
2411
2412 if (!OBO->hasNoSignedWrap() &&
2413 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2414 /* Signed */ true, LHS, RHS, CtxI)) {
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416 Deduced = true;
2417 }
2418
2419 if (Deduced)
2420 return Flags;
2421 return std::nullopt;
2422}
2423
2424// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2425// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2426// can't-overflow flags for the operation if possible.
2427 static SCEV::NoWrapFlags
2428 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2429 const ArrayRef<const SCEV *> Ops,
2430 SCEV::NoWrapFlags Flags) {
2431 using namespace std::placeholders;
2432
2433 using OBO = OverflowingBinaryOperator;
2434
2435 bool CanAnalyze =
2436 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2437 (void)CanAnalyze;
2438 assert(CanAnalyze && "don't call from other places!");
2439
2440 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2441 SCEV::NoWrapFlags SignOrUnsignWrap =
2442 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2443
2444 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2445 auto IsKnownNonNegative = [&](const SCEV *S) {
2446 return SE->isKnownNonNegative(S);
2447 };
2448
2449 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2450 Flags =
2451 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2452
2453 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 if (SignOrUnsignWrap != SignOrUnsignMask &&
2456 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2457 isa<SCEVConstant>(Ops[0])) {
2458
2459 auto Opcode = [&] {
2460 switch (Type) {
2461 case scAddExpr:
2462 return Instruction::Add;
2463 case scMulExpr:
2464 return Instruction::Mul;
2465 default:
2466 llvm_unreachable("Unexpected SCEV op.");
2467 }
2468 }();
2469
2470 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2471
2472 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2473 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2474 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2475 Opcode, C, OBO::NoSignedWrap);
2476 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2477 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2478 }
2479
2480 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2481 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2482 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2483 Opcode, C, OBO::NoUnsignedWrap);
2484 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2485 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2486 }
2487 }
2488
2489 // <0,+,nonnegative><nw> is also nuw
2490 // TODO: Add corresponding nsw case
2491 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2492 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2493 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2494 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2495
2496 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2497 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2498 Ops.size() == 2) {
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2500 if (UDiv->getOperand(1) == Ops[1])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2503 if (UDiv->getOperand(1) == Ops[0])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 }
2506
2507 return Flags;
2508}
2509
2510 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2511 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2512}
2513
2514/// Get a canonical add expression, or something simpler if possible.
2515 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2516 SCEV::NoWrapFlags OrigFlags,
2517 unsigned Depth) {
2518 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2519 "only nuw or nsw allowed");
2520 assert(!Ops.empty() && "Cannot get empty add!");
2521 if (Ops.size() == 1) return Ops[0];
2522#ifndef NDEBUG
2523 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2524 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2525 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2526 "SCEVAddExpr operand types don't match!");
2527 unsigned NumPtrs = count_if(
2528 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2529 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2530#endif
2531
2532 const SCEV *Folded = constantFoldAndGroupOps(
2533 *this, LI, DT, Ops,
2534 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2535 [](const APInt &C) { return C.isZero(); }, // identity
2536 [](const APInt &C) { return false; }); // absorber
2537 if (Folded)
2538 return Folded;
2539
2540 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2541
2542 // Delay expensive flag strengthening until necessary.
2543 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2544 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2545 };
2546
2547 // Limit recursion calls depth.
2548 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2549 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2550
2551 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2552 // Don't strengthen flags if we have no new information.
2553 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2554 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2555 Add->setNoWrapFlags(ComputeFlags(Ops));
2556 return S;
2557 }
2558
2559 // Okay, check to see if the same value occurs in the operand list more than
2560 // once. If so, merge them together into a multiply expression. Since we
2561 // sorted the list, these values are required to be adjacent.
2562 Type *Ty = Ops[0]->getType();
2563 bool FoundMatch = false;
2564 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2565 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2566 // Scan ahead to count how many equal operands there are.
2567 unsigned Count = 2;
2568 while (i+Count != e && Ops[i+Count] == Ops[i])
2569 ++Count;
2570 // Merge the values into a multiply.
2571 const SCEV *Scale = getConstant(Ty, Count);
2572 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2573 if (Ops.size() == Count)
2574 return Mul;
2575 Ops[i] = Mul;
2576 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2577 --i; e -= Count - 1;
2578 FoundMatch = true;
2579 }
2580 if (FoundMatch)
2581 return getAddExpr(Ops, OrigFlags, Depth + 1);
2582
2583 // Check for truncates. If all the operands are truncated from the same
2584 // type, see if factoring out the truncate would permit the result to be
2585 // folded, e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2586 // if the contents of the resulting outer trunc fold to something simple.
2587 auto FindTruncSrcType = [&]() -> Type * {
2588 // We're ultimately looking to fold an addrec of truncs and muls of only
2589 // constants and truncs, so if we find any other types of SCEV
2590 // as operands of the addrec then we bail and return nullptr here.
2591 // Otherwise, we return the type of the operand of a trunc that we find.
2592 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2593 return T->getOperand()->getType();
2594 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2595 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2596 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2597 return T->getOperand()->getType();
2598 }
2599 return nullptr;
2600 };
2601 if (auto *SrcType = FindTruncSrcType()) {
2602 SmallVector<const SCEV *, 8> LargeOps;
2603 bool Ok = true;
2604 // Check all the operands to see if they can be represented in the
2605 // source type of the truncate.
2606 for (const SCEV *Op : Ops) {
2607 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2608 if (T->getOperand()->getType() != SrcType) {
2609 Ok = false;
2610 break;
2611 }
2612 LargeOps.push_back(T->getOperand());
2613 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2614 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2615 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2616 SmallVector<const SCEV *, 8> LargeMulOps;
2617 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2618 if (const SCEVTruncateExpr *T =
2619 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeMulOps.push_back(T->getOperand());
2625 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2626 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else {
2628 Ok = false;
2629 break;
2630 }
2631 }
2632 if (Ok)
2633 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2634 } else {
2635 Ok = false;
2636 break;
2637 }
2638 }
2639 if (Ok) {
2640 // Evaluate the expression in the larger type.
2641 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2642 // If it folds to something simple, use it. Otherwise, don't.
2643 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2644 return getTruncateExpr(Fold, Ty);
2645 }
2646 }
2647
2648 if (Ops.size() == 2) {
2649 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2650 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2651 // C1).
2652 const SCEV *A = Ops[0];
2653 const SCEV *B = Ops[1];
2654 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2655 auto *C = dyn_cast<SCEVConstant>(A);
2656 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2657 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2658 auto C2 = C->getAPInt();
2659 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2660
2661 APInt ConstAdd = C1 + C2;
2662 auto AddFlags = AddExpr->getNoWrapFlags();
2663 // Adding a smaller constant is NUW if the original AddExpr was NUW.
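// For example (illustrative only): ((X + 5)<nuw>) + (-2) can fold to
// (X + 3)<nuw>, since 3 u<= 5 keeps the sum below the original bound.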
2664 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2665 ConstAdd.ule(C1)) {
2666 PreservedFlags =
2667 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2668 }
2669
2670 // Adding a constant with the same sign and small magnitude is NSW, if the
2671 // original AddExpr was NSW.
2672 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2673 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2674 ConstAdd.abs().ule(C1.abs())) {
2675 PreservedFlags =
2676 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2677 }
2678
2679 if (PreservedFlags != SCEV::FlagAnyWrap) {
2680 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2681 NewOps[0] = getConstant(ConstAdd);
2682 return getAddExpr(NewOps, PreservedFlags);
2683 }
2684 }
2685
2686 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2687 // (B), if trunc (A) + -A + B does not unsigned-wrap.
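// For example (illustrative only): with A = 16 (i64) and an inner i32 add
// (-16 + %n), the sum 16 + zext(-16 + %n) can become zext(%n) when the
// narrowed i32 addition is known not to wrap unsigned.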
2688 const SCEVAddExpr *InnerAdd;
2689 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2690 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2691 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2692 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2693 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2694 SCEV::FlagAnyWrap),
2695 SCEV::FlagNUW)) {
2696 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2697 }
2698 }
2699 }
2700
2701 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
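// For example (illustrative only): with X = 14 and Y = 4, (-1 * (14 urem 4))
// + 14 is 12, which equals 4 * (14 /u 4); the rewrite keeps that closed form.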
2702 const SCEV *Y;
2703 if (Ops.size() == 2 &&
2704 match(Ops[0],
2705 m_scev_Mul(m_scev_AllOnes(),
2706 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2707 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2708
2709 // Skip past any other cast SCEVs.
2710 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2711 ++Idx;
2712
2713 // If there are add operands they would be next.
2714 if (Idx < Ops.size()) {
2715 bool DeletedAdd = false;
2716 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2717 // common NUW flag for expression after inlining. Other flags cannot be
2718 // preserved, because they may depend on the original order of operations.
2719 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2720 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2721 if (Ops.size() > AddOpsInlineThreshold ||
2722 Add->getNumOperands() > AddOpsInlineThreshold)
2723 break;
2724 // If we have an add, expand the add operands onto the end of the operands
2725 // list.
2726 Ops.erase(Ops.begin()+Idx);
2727 append_range(Ops, Add->operands());
2728 DeletedAdd = true;
2729 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2730 }
2731
2732 // If we deleted at least one add, we added operands to the end of the list,
2733 // and they are not necessarily sorted. Recurse to resort and resimplify
2734 // any operands we just acquired.
2735 if (DeletedAdd)
2736 return getAddExpr(Ops, CommonFlags, Depth + 1);
2737 }
2738
2739 // Skip over the add expression until we get to a multiply.
2740 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2741 ++Idx;
2742
2743 // Check to see if there are any folding opportunities present with
2744 // operands multiplied by constant values.
2745 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2746 uint64_t BitWidth = getTypeSizeInBits(Ty);
2747 DenseMap<const SCEV *, APInt> M;
2748 SmallVector<const SCEV *, 8> NewOps;
2749 APInt AccumulatedConstant(BitWidth, 0);
2750 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2751 Ops, APInt(BitWidth, 1), *this)) {
2752 struct APIntCompare {
2753 bool operator()(const APInt &LHS, const APInt &RHS) const {
2754 return LHS.ult(RHS);
2755 }
2756 };
2757
2758 // Some interesting folding opportunity is present, so it's worthwhile to
2759 // re-generate the operands list. Group the operands by constant scale,
2760 // to avoid multiplying by the same constant scale multiple times.
2761 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2762 for (const SCEV *NewOp : NewOps)
2763 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2764 // Re-generate the operands list.
2765 Ops.clear();
2766 if (AccumulatedConstant != 0)
2767 Ops.push_back(getConstant(AccumulatedConstant));
2768 for (auto &MulOp : MulOpLists) {
2769 if (MulOp.first == 1) {
2770 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2771 } else if (MulOp.first != 0) {
2772 Ops.push_back(getMulExpr(
2773 getConstant(MulOp.first),
2774 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2775 SCEV::FlagAnyWrap, Depth + 1));
2776 }
2777 }
2778 if (Ops.empty())
2779 return getZero(Ty);
2780 if (Ops.size() == 1)
2781 return Ops[0];
2782 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2783 }
2784 }
2785
2786 // If we are adding something to a multiply expression, make sure the
2787 // something is not already an operand of the multiply. If so, merge it into
2788 // the multiply.
2789 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2790 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2791 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2792 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2793 if (isa<SCEVConstant>(MulOpSCEV))
2794 continue;
2795 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2796 if (MulOpSCEV == Ops[AddOp]) {
2797 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2798 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2799 if (Mul->getNumOperands() != 2) {
2800 // If the multiply has more than two operands, we must get the
2801 // Y*Z term.
2802 SmallVector<const SCEV *, 4> MulOps(
2803 Mul->operands().take_front(MulOp));
2804 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2805 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2806 }
2807 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2808 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2809 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2810 SCEV::FlagAnyWrap, Depth + 1);
2811 if (Ops.size() == 2) return OuterMul;
2812 if (AddOp < Idx) {
2813 Ops.erase(Ops.begin()+AddOp);
2814 Ops.erase(Ops.begin()+Idx-1);
2815 } else {
2816 Ops.erase(Ops.begin()+Idx);
2817 Ops.erase(Ops.begin()+AddOp-1);
2818 }
2819 Ops.push_back(OuterMul);
2820 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2821 }
2822
2823 // Check this multiply against other multiplies being added together.
2824 for (unsigned OtherMulIdx = Idx+1;
2825 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2826 ++OtherMulIdx) {
2827 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 // If MulOp occurs in OtherMul, we can fold the two multiplies
2829 // together.
2830 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2831 OMulOp != e; ++OMulOp)
2832 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2833 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2834 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2835 if (Mul->getNumOperands() != 2) {
2836 SmallVector<const SCEV *, 4> MulOps(
2837 Mul->operands().take_front(MulOp));
2838 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2839 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2840 }
2841 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2842 if (OtherMul->getNumOperands() != 2) {
2843 SmallVector<const SCEV *, 4> MulOps(
2844 OtherMul->operands().take_front(OMulOp));
2845 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2846 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2847 }
2848 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2849 const SCEV *InnerMulSum =
2850 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2851 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2852 SCEV::FlagAnyWrap, Depth + 1);
2853 if (Ops.size() == 2) return OuterMul;
2854 Ops.erase(Ops.begin()+Idx);
2855 Ops.erase(Ops.begin()+OtherMulIdx-1);
2856 Ops.push_back(OuterMul);
2857 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2858 }
2859 }
2860 }
2861 }
2862
2863 // If there are any add recurrences in the operands list, see if any other
2864 // added values are loop invariant. If so, we can fold them into the
2865 // recurrence.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2867 ++Idx;
2868
2869 // Scan over all recurrences, trying to fold loop invariants into them.
2870 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2871 // Scan all of the other operands to this add and add them to the vector if
2872 // they are loop invariant w.r.t. the recurrence.
2873 SmallVector<const SCEV *, 8> LIOps;
2874 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2875 const Loop *AddRecLoop = AddRec->getLoop();
2876 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2877 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2878 LIOps.push_back(Ops[i]);
2879 Ops.erase(Ops.begin()+i);
2880 --i; --e;
2881 }
2882
2883 // If we found some loop invariants, fold them into the recurrence.
2884 if (!LIOps.empty()) {
2885 // Compute nowrap flags for the addition of the loop-invariant ops and
2886 // the addrec. Temporarily push it as an operand for that purpose. These
2887 // flags are valid in the scope of the addrec only.
2888 LIOps.push_back(AddRec);
2889 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2890 LIOps.pop_back();
2891
2892 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2893 LIOps.push_back(AddRec->getStart());
2894
2895 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2896
2897 // It is not in general safe to propagate flags valid on an add within
2898 // the addrec scope to one outside it. We must prove that the inner
2899 // scope is guaranteed to execute if the outer one does to be able to
2900 // safely propagate. We know the program is undefined if poison is
2901 // produced on the inner scoped addrec. We also know that *for this use*
2902 // the outer scoped add can't overflow (because of the flags we just
2903 // computed for the inner scoped add) without the program being undefined.
2904 // Proving that entry to the outer scope necessitates entry to the inner
2905 // scope, thus proves the program undefined if the flags would be violated
2906 // in the outer scope.
2907 SCEV::NoWrapFlags AddFlags = Flags;
2908 if (AddFlags != SCEV::FlagAnyWrap) {
2909 auto *DefI = getDefiningScopeBound(LIOps);
2910 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2911 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2912 AddFlags = SCEV::FlagAnyWrap;
2913 }
2914 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2915
2916 // Build the new addrec. Propagate the NUW and NSW flags if both the
2917 // outer add and the inner addrec are guaranteed to have no overflow.
2918 // Always propagate NW.
2919 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2920 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2921
2922 // If all of the other operands were loop invariant, we are done.
2923 if (Ops.size() == 1) return NewRec;
2924
2925 // Otherwise, add the folded AddRec by the non-invariant parts.
2926 for (unsigned i = 0;; ++i)
2927 if (Ops[i] == AddRec) {
2928 Ops[i] = NewRec;
2929 break;
2930 }
2931 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2932 }
2933
2934 // Okay, if there weren't any loop invariants to be folded, check to see if
2935 // there are multiple AddRec's with the same loop induction variable being
2936 // added together. If so, we can fold them.
2937 for (unsigned OtherIdx = Idx+1;
2938 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2939 ++OtherIdx) {
2940 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2941 // so that the 1st found AddRecExpr is dominated by all others.
2942 assert(DT.dominates(
2943 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2944 AddRec->getLoop()->getHeader()) &&
2945 "AddRecExprs are not sorted in reverse dominance order?");
2946 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2947 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2948 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2949 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2950 ++OtherIdx) {
2951 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 if (OtherAddRec->getLoop() == AddRecLoop) {
2953 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2954 i != e; ++i) {
2955 if (i >= AddRecOps.size()) {
2956 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2957 break;
2958 }
2959 SmallVector<const SCEV *, 2> TwoOps = {
2960 AddRecOps[i], OtherAddRec->getOperand(i)};
2961 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2962 }
2963 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2964 }
2965 }
2966 // Step size has changed, so we cannot guarantee no self-wraparound.
2967 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2968 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2969 }
2970 }
2971
2972 // Otherwise couldn't fold anything into this recurrence. Move onto the
2973 // next one.
2974 }
2975
2976 // Okay, it looks like we really DO need an add expr. Check to see if we
2977 // already have one, otherwise create a new one.
2978 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2979}
2980
2981const SCEV *
2982ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2983 SCEV::NoWrapFlags Flags) {
2984 FoldingSetNodeID ID;
2985 ID.AddInteger(scAddExpr);
2986 for (const SCEV *Op : Ops)
2987 ID.AddPointer(Op);
2988 void *IP = nullptr;
2989 SCEVAddExpr *S =
2990 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2991 if (!S) {
2992 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2993 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2994 S = new (SCEVAllocator)
2995 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2996 UniqueSCEVs.InsertNode(S, IP);
2997 registerUser(S, Ops);
2998 }
2999 S->setNoWrapFlags(Flags);
3000 return S;
3001}
3002
3003const SCEV *
3004ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3005 const Loop *L, SCEV::NoWrapFlags Flags) {
3006 FoldingSetNodeID ID;
3007 ID.AddInteger(scAddRecExpr);
3008 for (const SCEV *Op : Ops)
3009 ID.AddPointer(Op);
3010 ID.AddPointer(L);
3011 void *IP = nullptr;
3012 SCEVAddRecExpr *S =
3013 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3014 if (!S) {
3015 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3016 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3017 S = new (SCEVAllocator)
3018 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3019 UniqueSCEVs.InsertNode(S, IP);
3020 LoopUsers[L].push_back(S);
3021 registerUser(S, Ops);
3022 }
3023 setNoWrapFlags(S, Flags);
3024 return S;
3025}
3026
3027const SCEV *
3028ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3029 SCEV::NoWrapFlags Flags) {
3030 FoldingSetNodeID ID;
3031 ID.AddInteger(scMulExpr);
3032 for (const SCEV *Op : Ops)
3033 ID.AddPointer(Op);
3034 void *IP = nullptr;
3035 SCEVMulExpr *S =
3036 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3037 if (!S) {
3038 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3039 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3040 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3041 O, Ops.size());
3042 UniqueSCEVs.InsertNode(S, IP);
3043 registerUser(S, Ops);
3044 }
3045 S->setNoWrapFlags(Flags);
3046 return S;
3047}
3048
3049static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3050 uint64_t k = i*j;
3051 if (j > 1 && k / j != i) Overflow = true;
3052 return k;
3053}
3054
3055/// Compute the result of "n choose k", the binomial coefficient. If an
3056/// intermediate computation overflows, Overflow will be set and the return will
3057/// be garbage. Overflow is not cleared on absence of overflow.
3058static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3059 // We use the multiplicative formula:
3060 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3061 // At each iteration, we multiply by the next term of the numerator and divide
3062 // by the iteration index. This division will always produce an
3063 // integral result, and helps reduce the chance of overflow in the
3064 // intermediate computations. However, we can still overflow even when the
3065 // final result would fit.
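// For example (illustrative only): Choose(5, 2) proceeds as r = 1*5/1 = 5,
// then r = 5*4/2 = 10, matching C(5,2) = 10 with no overflow.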
3066
3067 if (n == 0 || n == k) return 1;
3068 if (k > n) return 0;
3069
3070 if (k > n/2)
3071 k = n-k;
3072
3073 uint64_t r = 1;
3074 for (uint64_t i = 1; i <= k; ++i) {
3075 r = umul_ov(r, n-(i-1), Overflow);
3076 r /= i;
3077 }
3078 return r;
3079}
3080
3081/// Determine if any of the operands in this SCEV are a constant or if
3082/// any of the add or multiply expressions in this SCEV contain a constant.
3083static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3084 struct FindConstantInAddMulChain {
3085 bool FoundConstant = false;
3086
3087 bool follow(const SCEV *S) {
3088 FoundConstant |= isa<SCEVConstant>(S);
3089 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3090 }
3091
3092 bool isDone() const {
3093 return FoundConstant;
3094 }
3095 };
3096
3097 FindConstantInAddMulChain F;
3098 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3099 ST.visitAll(StartExpr);
3100 return F.FoundConstant;
3101}
3102
3103/// Get a canonical multiply expression, or something simpler if possible.
3104 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3105 SCEV::NoWrapFlags OrigFlags,
3106 unsigned Depth) {
3107 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3108 "only nuw or nsw allowed");
3109 assert(!Ops.empty() && "Cannot get empty mul!");
3110 if (Ops.size() == 1) return Ops[0];
3111#ifndef NDEBUG
3112 Type *ETy = Ops[0]->getType();
3113 assert(!ETy->isPointerTy());
3114 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3115 assert(Ops[i]->getType() == ETy &&
3116 "SCEVMulExpr operand types don't match!");
3117#endif
3118
3119 const SCEV *Folded = constantFoldAndGroupOps(
3120 *this, LI, DT, Ops,
3121 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3122 [](const APInt &C) { return C.isOne(); }, // identity
3123 [](const APInt &C) { return C.isZero(); }); // absorber
3124 if (Folded)
3125 return Folded;
3126
3127 // Delay expensive flag strengthening until necessary.
3128 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3129 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3130 };
3131
3132 // Limit recursion calls depth.
3134 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3135
3136 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3137 // Don't strengthen flags if we have no new information.
3138 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3139 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3140 Mul->setNoWrapFlags(ComputeFlags(Ops));
3141 return S;
3142 }
3143
3144 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3145 if (Ops.size() == 2) {
3146 // C1*(C2+V) -> C1*C2 + C1*V
3147 // If any of Add's ops are Adds or Muls with a constant, apply this
3148 // transformation as well.
3149 //
3150 // TODO: There are some cases where this transformation is not
3151 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3152 // this transformation should be narrowed down.
3153 const SCEV *Op0, *Op1;
3154 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3155 containsConstantInAddMulChain(Ops[1])) {
3156 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3157 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3158 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3159 }
3160
3161 if (Ops[0]->isAllOnesValue()) {
3162 // If we have a mul by -1 of an add, try distributing the -1 among the
3163 // add operands.
3164 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3165 SmallVector<const SCEV *, 4> NewOps;
3166 bool AnyFolded = false;
3167 for (const SCEV *AddOp : Add->operands()) {
3168 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3169 Depth + 1);
3170 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3171 NewOps.push_back(Mul);
3172 }
3173 if (AnyFolded)
3174 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3175 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3176 // Negation preserves a recurrence's no self-wrap property.
3177 SmallVector<const SCEV *, 4> Operands;
3178 for (const SCEV *AddRecOp : AddRec->operands())
3179 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3180 Depth + 1));
3181 // Let M be the minimum representable signed value. AddRec with nsw
3182 // multiplied by -1 can have signed overflow if and only if it takes a
3183 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3184 // maximum signed value. In all other cases signed overflow is
3185 // impossible.
3186 auto FlagsMask = SCEV::FlagNW;
3187 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3188 auto MinInt =
3189 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3190 if (getSignedRangeMin(AddRec) != MinInt)
3191 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3192 }
3193 return getAddRecExpr(Operands, AddRec->getLoop(),
3194 AddRec->getNoWrapFlags(FlagsMask));
3195 }
3196 }
3197
3198 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3199 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
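// For example (illustrative only): 4 * zext i32 (%a + %b) to i64 can become
// zext (4*%a + 4*%b) to i64 when the scaled i32 sum provably stays below 2^32.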
3200 const SCEVAddExpr *InnerAdd;
3201 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3202 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3203 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3204 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3205 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3206 SCEV::FlagAnyWrap),
3207 SCEV::FlagNUW)) {
3208 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3209 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3210 };
3211 }
3212
3213 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3214 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3215 // of C1, fold to (D /u (C2 /u C1)).
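// E.g. with D known to be a multiple of 8: (8 * D) /u 2 folds to 4 * D,
// and (2 * D) /u 8 folds to D /u 4.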
3216 const SCEV *D;
3217 APInt C1V = LHSC->getAPInt();
3218 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3219 // as -1 * 1, as it won't enable additional folds.
3220 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3221 C1V = C1V.abs();
3222 const SCEVConstant *C2;
3223 if (C1V.isPowerOf2() &&
3225 C2->getAPInt().isPowerOf2() &&
3226 C1V.logBase2() <= getMinTrailingZeros(D)) {
3227 const SCEV *NewMul = nullptr;
3228 if (C1V.uge(C2->getAPInt())) {
3229 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3230 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3231 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3232 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3233 }
3234 if (NewMul)
3235 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3236 }
3237 }
3238 }
3239
3240 // Skip over the add expression until we get to a multiply.
3241 unsigned Idx = 0;
3242 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3243 ++Idx;
3244
3245 // If there are mul operands inline them all into this expression.
3246 if (Idx < Ops.size()) {
3247 bool DeletedMul = false;
3248 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3249 if (Ops.size() > MulOpsInlineThreshold)
3250 break;
3251 // If we have a mul, expand the mul operands onto the end of the
3252 // operands list.
3253 Ops.erase(Ops.begin()+Idx);
3254 append_range(Ops, Mul->operands());
3255 DeletedMul = true;
3256 }
3257
3258 // If we deleted at least one mul, we added operands to the end of the
3259 // list, and they are not necessarily sorted. Recurse to resort and
3260 // resimplify any operands we just acquired.
3261 if (DeletedMul)
3262 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3263 }
3264
3265 // If there are any add recurrences in the operands list, see if any other
3266 // added values are loop invariant. If so, we can fold them into the
3267 // recurrence.
3268 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3269 ++Idx;
3270
3271 // Scan over all recurrences, trying to fold loop invariants into them.
3272 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3273 // Scan all of the other operands to this mul and add them to the vector
3274 // if they are loop invariant w.r.t. the recurrence.
3275 SmallVector<const SCEV *, 8> LIOps;
3276 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3277 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3278 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3279 LIOps.push_back(Ops[i]);
3280 Ops.erase(Ops.begin()+i);
3281 --i; --e;
3282 }
3283
3284 // If we found some loop invariants, fold them into the recurrence.
3285 if (!LIOps.empty()) {
3286 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3287 SmallVector<const SCEV *, 4> NewOps;
3288 NewOps.reserve(AddRec->getNumOperands());
3289 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3290
3291 // If both the mul and addrec are nuw, we can preserve nuw.
3292 // If both the mul and addrec are nsw, we can only preserve nsw if either
3293 // a) they are also nuw, or
3294 // b) all multiplications of addrec operands with scale are nsw.
3295 SCEV::NoWrapFlags Flags =
3296 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3297
3298 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3299 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3300 SCEV::FlagAnyWrap, Depth + 1));
3301
3302 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3303 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3304 Instruction::Mul, getSignedRange(Scale),
3305 OverflowingBinaryOperator::NoSignedWrap);
3306 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3307 Flags = clearFlags(Flags, SCEV::FlagNSW);
3308 }
3309 }
3310
3311 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3312
3313 // If all of the other operands were loop invariant, we are done.
3314 if (Ops.size() == 1) return NewRec;
3315
3316 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3317 for (unsigned i = 0;; ++i)
3318 if (Ops[i] == AddRec) {
3319 Ops[i] = NewRec;
3320 break;
3321 }
3322 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3323 }
3324
3325 // Okay, if there weren't any loop invariants to be folded, check to see
3326 // if there are multiple AddRec's with the same loop induction variable
3327 // being multiplied together. If so, we can fold them.
3328
3329 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3330 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3331 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3332 // ]]],+,...up to x=2n}.
3333 // Note that the arguments to choose() are always integers with values
3334 // known at compile time, never SCEV objects.
3335 //
3336 // The implementation avoids pointless extra computations when the two
3337 // addrec's are of different length (mathematically, it's equivalent to
3338 // an infinite stream of zeros on the right).
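// For two affine recurrences this reduces to the familiar identity
// {A1,+,A2}<L> * {B1,+,B2}<L>
//   = {A1*B1,+, A1*B2 + A2*B1 + A2*B2,+, 2*A2*B2}<L>.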
3339 bool OpsModified = false;
3340 for (unsigned OtherIdx = Idx+1;
3341 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3342 ++OtherIdx) {
3343 const SCEVAddRecExpr *OtherAddRec =
3344 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3345 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3346 continue;
3347
3348 // Limit max number of arguments to avoid creation of unreasonably big
3349 // SCEVAddRecs with very complex operands.
3350 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3351 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3352 continue;
3353
3354 bool Overflow = false;
3355 Type *Ty = AddRec->getType();
3356 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3357 SmallVector<const SCEV *, 7> AddRecOps;
3358 for (int x = 0, xe = AddRec->getNumOperands() +
3359 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3360 SmallVector<const SCEV *, 7> SumOps;
3361 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3362 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3363 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3364 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3365 z < ze && !Overflow; ++z) {
3366 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3367 uint64_t Coeff;
3368 if (LargerThan64Bits)
3369 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3370 else
3371 Coeff = Coeff1*Coeff2;
3372 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3373 const SCEV *Term1 = AddRec->getOperand(y-z);
3374 const SCEV *Term2 = OtherAddRec->getOperand(z);
3375 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3376 SCEV::FlagAnyWrap, Depth + 1));
3377 }
3378 }
3379 if (SumOps.empty())
3380 SumOps.push_back(getZero(Ty));
3381 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3382 }
3383 if (!Overflow) {
3384 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3385 SCEV::FlagAnyWrap);
3386 if (Ops.size() == 2) return NewAddRec;
3387 Ops[Idx] = NewAddRec;
3388 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3389 OpsModified = true;
3390 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3391 if (!AddRec)
3392 break;
3393 }
3394 }
3395 if (OpsModified)
3396 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3397
3398 // Otherwise couldn't fold anything into this recurrence. Move on to the
3399 // next one.
3400 }
3401
3402 // Okay, it looks like we really DO need a mul expr. Check to see if we
3403 // already have one, otherwise create a new one.
3404 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3405}
3406
3407/// Represents an unsigned remainder expression based on unsigned division.
3408const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3409 const SCEV *RHS) {
3410 assert(getEffectiveSCEVType(LHS->getType()) ==
3411 getEffectiveSCEVType(RHS->getType()) &&
3412 "SCEVURemExpr operand types don't match!");
3413
3414 // Short-circuit easy cases
3415 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3416 // If constant is one, the result is trivial
3417 if (RHSC->getValue()->isOne())
3418 return getZero(LHS->getType()); // X urem 1 --> 0
3419
3420 // If constant is a power of two, fold into a zext(trunc(LHS)).
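// E.g. for i32 %x: %x urem 8 becomes (zext i3 (trunc i32 %x to i3) to i32),
// i.e. just the low log2(8) = 3 bits of %x.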
3421 if (RHSC->getAPInt().isPowerOf2()) {
3422 Type *FullTy = LHS->getType();
3423 Type *TruncTy =
3424 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3425 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3426 }
3427 }
3428
3429 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3430 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3431 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3432 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3433}
3434
3435/// Get a canonical unsigned division expression, or something simpler if
3436/// possible.
3437const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3438 const SCEV *RHS) {
3439 assert(!LHS->getType()->isPointerTy() &&
3440 "SCEVUDivExpr operand can't be pointer!");
3441 assert(LHS->getType() == RHS->getType() &&
3442 "SCEVUDivExpr operand types don't match!");
3443
3444 FoldingSetNodeID ID;
3445 ID.AddInteger(scUDivExpr);
3446 ID.AddPointer(LHS);
3447 ID.AddPointer(RHS);
3448 void *IP = nullptr;
3449 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3450 return S;
3451
3452 // 0 udiv Y == 0
3453 if (match(LHS, m_scev_Zero()))
3454 return LHS;
3455
3456 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3457 if (RHSC->getValue()->isOne())
3458 return LHS; // X udiv 1 --> x
3459 // If the denominator is zero, the result of the udiv is undefined. Don't
3460 // try to analyze it, because the resolution chosen here may differ from
3461 // the resolution chosen in other parts of the compiler.
3462 if (!RHSC->getValue()->isZero()) {
3463 // Determine if the division can be folded into the operands of
3464 // the expression being divided (the LHS).
3465 // TODO: Generalize this to non-constants by using known-bits information.
3466 Type *Ty = LHS->getType();
3467 unsigned LZ = RHSC->getAPInt().countl_zero();
3468 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3469 // For non-power-of-two values, effectively round the value up to the
3470 // nearest power of two.
3471 if (!RHSC->getAPInt().isPowerOf2())
3472 ++MaxShiftAmt;
3473 IntegerType *ExtTy =
3474 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3475 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3476 if (const SCEVConstant *Step =
3477 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3478 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3479 const APInt &StepInt = Step->getAPInt();
3480 const APInt &DivInt = RHSC->getAPInt();
3481 if (!StepInt.urem(DivInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 SmallVector<const SCEV *, 4> Operands;
3487 for (const SCEV *Op : AR->operands())
3488 Operands.push_back(getUDivExpr(Op, RHS));
3489 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3490 }
3491 // Get a canonical UDivExpr for a recurrence.
3492 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
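// E.g. {7,+,4}/u 8 becomes {4,+,4}/u 8: the start is rounded down by
// 7 urem 4 = 3, and since N=4 and C=8 are powers of two with N <= C,
// every iteration still produces the same quotient.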
3493 const APInt *StartRem;
3494 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3495 m_scev_APInt(StartRem))) {
3496 bool NoWrap =
3497 getZeroExtendExpr(AR, ExtTy) ==
3498 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3499 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3500 SCEV::FlagAnyWrap);
3501
3502 // With N <= C and both N, C as powers-of-2, the transformation
3503 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3504 // if wrapping occurs, as the division results remain equivalent for
3505 // all offsets in [(X - X%N), X).
3506 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3507 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3508 // Only fold if the subtraction can be folded in the start
3509 // expression.
3510 const SCEV *NewStart =
3511 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3512 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3513 !isa<SCEVAddExpr>(NewStart)) {
3514 const SCEV *NewLHS =
3515 getAddRecExpr(NewStart, Step, AR->getLoop(),
3516 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3517 if (LHS != NewLHS) {
3518 LHS = NewLHS;
3519
3520 // Reset the ID to include the new LHS, and check if it is
3521 // already cached.
3522 ID.clear();
3523 ID.AddInteger(scUDivExpr);
3524 ID.AddPointer(LHS);
3525 ID.AddPointer(RHS);
3526 IP = nullptr;
3527 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3528 return S;
3529 }
3530 }
3531 }
3532 }
3533 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3534 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3535 SmallVector<const SCEV *, 4> Operands;
3536 for (const SCEV *Op : M->operands())
3537 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3538 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3539 // Find an operand that's safely divisible.
3540 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3541 const SCEV *Op = M->getOperand(i);
3542 const SCEV *Div = getUDivExpr(Op, RHSC);
3543 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3544 Operands = SmallVector<const SCEV *, 4>(M->operands());
3545 Operands[i] = Div;
3546 return getMulExpr(Operands);
3547 }
3548 }
3549 }
3550
3551 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3552 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3553 if (auto *DivisorConstant =
3554 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3555 bool Overflow = false;
3556 APInt NewRHS =
3557 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3558 if (Overflow) {
3559 return getConstant(RHSC->getType(), 0, false);
3560 }
3561 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3562 }
3563 }
3564
3565 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3566 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3567 SmallVector<const SCEV *, 4> Operands;
3568 for (const SCEV *Op : A->operands())
3569 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3570 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3571 Operands.clear();
3572 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3573 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3574 if (isa<SCEVUDivExpr>(Op) ||
3575 getMulExpr(Op, RHS) != A->getOperand(i))
3576 break;
3577 Operands.push_back(Op);
3578 }
3579 if (Operands.size() == A->getNumOperands())
3580 return getAddExpr(Operands);
3581 }
3582 }
3583
3584 // Fold if both operands are constant.
3585 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3586 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3587 }
3588 }
3589
3590 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3591 const APInt *NegC, *C;
3592 if (match(LHS,
3595 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3596 return getZero(LHS->getType());
3597
3598 // TODO: Generalize to handle any common factors.
3599 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3600 const SCEV *NewLHS, *NewRHS;
3601 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3602 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3603 return getUDivExpr(NewLHS, NewRHS);
3604
3605 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3606 // changes). Make sure we get a new one.
3607 IP = nullptr;
3608 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3609 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3610 LHS, RHS);
3611 UniqueSCEVs.InsertNode(S, IP);
3612 registerUser(S, {LHS, RHS});
3613 return S;
3614}
3615
3616APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3617 APInt A = C1->getAPInt().abs();
3618 APInt B = C2->getAPInt().abs();
3619 uint32_t ABW = A.getBitWidth();
3620 uint32_t BBW = B.getBitWidth();
3621
3622 if (ABW > BBW)
3623 B = B.zext(ABW);
3624 else if (ABW < BBW)
3625 A = A.zext(BBW);
3626
3627 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3628}
3629
3630/// Get a canonical unsigned division expression, or something simpler if
3631/// possible. There is no representation for an exact udiv in SCEV IR, but we
3632/// can attempt to remove factors from the LHS and RHS. We can't do this when
3633/// it's not exact because the udiv may be clearing bits.
3634const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3635 const SCEV *RHS) {
3636 // TODO: we could try to find factors in all sorts of things, but for now we
3637 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3638 // end of this file for inspiration.
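// E.g. an exact (8 * %n) /u 4 simplifies to 2 * %n, and an exact
// (%a * %b) /u %b simplifies to %a, provided the multiplies are nuw.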
3639
3640 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3641 if (!Mul || !Mul->hasNoUnsignedWrap())
3642 return getUDivExpr(LHS, RHS);
3643
3644 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3645 // If the mulexpr multiplies by a constant, then that constant must be the
3646 // first element of the mulexpr.
3647 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3648 if (LHSCst == RHSCst) {
3649 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3650 return getMulExpr(Operands);
3651 }
3652
3653 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3654 // that there's a factor provided by one of the other terms. We need to
3655 // check.
3656 APInt Factor = gcd(LHSCst, RHSCst);
3657 if (!Factor.isIntN(1)) {
3658 LHSCst =
3659 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3660 RHSCst =
3661 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3662 SmallVector<const SCEV *, 2> Operands;
3663 Operands.push_back(LHSCst);
3664 append_range(Operands, Mul->operands().drop_front());
3665 LHS = getMulExpr(Operands);
3666 RHS = RHSCst;
3667 Mul = dyn_cast<SCEVMulExpr>(LHS);
3668 if (!Mul)
3669 return getUDivExactExpr(LHS, RHS);
3670 }
3671 }
3672 }
3673
3674 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3675 if (Mul->getOperand(i) == RHS) {
3676 SmallVector<const SCEV *, 2> Operands;
3677 append_range(Operands, Mul->operands().take_front(i));
3678 append_range(Operands, Mul->operands().drop_front(i + 1));
3679 return getMulExpr(Operands);
3680 }
3681 }
3682
3683 return getUDivExpr(LHS, RHS);
3684}
3685
3686/// Get an add recurrence expression for the specified loop. Simplify the
3687/// expression as much as possible.
3688const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3689 const Loop *L,
3690 SCEV::NoWrapFlags Flags) {
3691 SmallVector<const SCEV *, 4> Operands;
3692 Operands.push_back(Start);
3693 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3694 if (StepChrec->getLoop() == L) {
3695 append_range(Operands, StepChrec->operands());
3696 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3697 }
3698
3699 Operands.push_back(Step);
3700 return getAddRecExpr(Operands, L, Flags);
3701}
3702
3703/// Get an add recurrence expression for the specified loop. Simplify the
3704/// expression as much as possible.
3705 const SCEV *
3706ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3707 const Loop *L, SCEV::NoWrapFlags Flags) {
3708 if (Operands.size() == 1) return Operands[0];
3709#ifndef NDEBUG
3710 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3711 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3712 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3713 "SCEVAddRecExpr operand types don't match!");
3714 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3715 }
3716 for (const SCEV *Op : Operands)
3717 assert(isAvailableAtLoopEntry(Op, L) &&
3718 "SCEVAddRecExpr operand is not available at loop entry!");
3719#endif
3720
3721 if (Operands.back()->isZero()) {
3722 Operands.pop_back();
3723 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3724 }
3725
3726 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3727 // use that information to infer NUW and NSW flags. However, computing a
3728 // BE count requires calling getAddRecExpr, so we may not yet have a
3729 // meaningful BE count at this point (and if we don't, we'd be stuck
3730 // with a SCEVCouldNotCompute as the cached BE count).
3731
3732 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3733
3734 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3735 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3736 const Loop *NestedLoop = NestedAR->getLoop();
3737 if (L->contains(NestedLoop)
3738 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3739 : (!NestedLoop->contains(L) &&
3740 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3741 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3742 Operands[0] = NestedAR->getStart();
3743 // AddRecs require their operands be loop-invariant with respect to their
3744 // loops. Don't perform this transformation if it would break this
3745 // requirement.
3746 bool AllInvariant = all_of(
3747 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3748
3749 if (AllInvariant) {
3750 // Create a recurrence for the outer loop with the same step size.
3751 //
3752 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3753 // inner recurrence has the same property.
3754 SCEV::NoWrapFlags OuterFlags =
3755 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3756
3757 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3758 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3759 return isLoopInvariant(Op, NestedLoop);
3760 });
3761
3762 if (AllInvariant) {
3763 // Ok, both add recurrences are valid after the transformation.
3764 //
3765 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3766 // the outer recurrence has the same property.
3767 SCEV::NoWrapFlags InnerFlags =
3768 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3769 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3770 }
3771 }
3772 // Reset Operands to its original state.
3773 Operands[0] = NestedAR;
3774 }
3775 }
3776
3777 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3778 // already have one, otherwise create a new one.
3779 return getOrCreateAddRecExpr(Operands, L, Flags);
3780}
3781
3782const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3783 ArrayRef<const SCEV *> IndexExprs) {
3784 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3785 // getSCEV(Base)->getType() has the same address space as Base->getType()
3786 // because SCEV::getType() preserves the address space.
3787 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3788 if (NW != GEPNoWrapFlags::none()) {
3789 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3790 // but to do that, we have to ensure that said flag is valid in the entire
3791 // defined scope of the SCEV.
3792 // TODO: non-instructions have global scope. We might be able to prove
3793 // some global scope cases
3794 auto *GEPI = dyn_cast<Instruction>(GEP);
3795 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3796 NW = GEPNoWrapFlags::none();
3797 }
3798
3799 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3800}
3801
3802const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3803 ArrayRef<const SCEV *> IndexExprs,
3804 Type *SrcElementTy, GEPNoWrapFlags NW) {
3805 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3806 if (NW.hasNoUnsignedSignedWrap())
3807 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3808 if (NW.hasNoUnsignedWrap())
3809 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3810
3811 Type *CurTy = BaseExpr->getType();
3812 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3813 bool FirstIter = true;
3814 SmallVector<const SCEV *, 4> Offsets;
3815 for (const SCEV *IndexExpr : IndexExprs) {
3816 // Compute the (potentially symbolic) offset in bytes for this index.
3817 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3818 // For a struct, add the member offset.
3819 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3820 unsigned FieldNo = Index->getZExtValue();
3821 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3822 Offsets.push_back(FieldOffset);
3823
3824 // Update CurTy to the type of the field at Index.
3825 CurTy = STy->getTypeAtIndex(Index);
3826 } else {
3827 // Update CurTy to its element type.
3828 if (FirstIter) {
3829 assert(isa<PointerType>(CurTy) &&
3830 "The first index of a GEP indexes a pointer");
3831 CurTy = SrcElementTy;
3832 FirstIter = false;
3833 } else {
3835 }
3836 // For an array, add the element offset, explicitly scaled.
3837 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3838 // Getelementptr indices are signed.
3839 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3840
3841 // Multiply the index by the element size to compute the element offset.
3842 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3843 Offsets.push_back(LocalOffset);
3844 }
3845 }
3846
3847 // Handle degenerate case of GEP without offsets.
3848 if (Offsets.empty())
3849 return BaseExpr;
3850
3851 // Add the offsets together, assuming nsw if inbounds.
3852 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3853 // Add the base address and the offset. We cannot use the nsw flag, as the
3854 // base address is unsigned. However, if we know that the offset is
3855 // non-negative, we can use nuw.
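// E.g. indexing i32 elements by %i contributes a byte offset of 4 * %i
// (after %i is truncated or sign-extended to the index type), and the final
// base + offset add is only tagged nuw when that offset is known
// non-negative.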
3856 bool NUW = NW.hasNoUnsignedWrap() ||
3857 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3858 auto BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3859 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3860 assert(BaseExpr->getType() == GEPExpr->getType() &&
3861 "GEP should not change type mid-flight.");
3862 return GEPExpr;
3863}
3864
3865SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3866 ArrayRef<const SCEV *> Ops) {
3867 FoldingSetNodeID ID;
3868 ID.AddInteger(SCEVType);
3869 for (const SCEV *Op : Ops)
3870 ID.AddPointer(Op);
3871 void *IP = nullptr;
3872 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3873}
3874
3875const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3876 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3877 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3878}
3879
3880const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3881 SmallVectorImpl<const SCEV *> &Ops) {
3882 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3883 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3884 if (Ops.size() == 1) return Ops[0];
3885#ifndef NDEBUG
3886 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3887 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3888 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3889 "Operand types don't match!");
3890 assert(Ops[0]->getType()->isPointerTy() ==
3891 Ops[i]->getType()->isPointerTy() &&
3892 "min/max should be consistently pointerish");
3893 }
3894#endif
3895
3896 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3897 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3898
3899 const SCEV *Folded = constantFoldAndGroupOps(
3900 *this, LI, DT, Ops,
3901 [&](const APInt &C1, const APInt &C2) {
3902 switch (Kind) {
3903 case scSMaxExpr:
3904 return APIntOps::smax(C1, C2);
3905 case scSMinExpr:
3906 return APIntOps::smin(C1, C2);
3907 case scUMaxExpr:
3908 return APIntOps::umax(C1, C2);
3909 case scUMinExpr:
3910 return APIntOps::umin(C1, C2);
3911 default:
3912 llvm_unreachable("Unknown SCEV min/max opcode");
3913 }
3914 },
3915 [&](const APInt &C) {
3916 // identity
3917 if (IsMax)
3918 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3919 else
3920 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3921 },
3922 [&](const APInt &C) {
3923 // absorber
3924 if (IsMax)
3925 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3926 else
3927 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3928 });
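// E.g. for umax the identity is 0 (umax(0, %x) == %x) and the absorber is
// UINT_MAX; for smin they are SINT_MAX and SINT_MIN respectively.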
3929 if (Folded)
3930 return Folded;
3931
3932 // Check if we have created the same expression before.
3933 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3934 return S;
3935 }
3936
3937 // Find the first operation of the same kind
3938 unsigned Idx = 0;
3939 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3940 ++Idx;
3941
3942 // Check to see if one of the operands is of the same kind. If so, expand its
3943 // operands onto our operand list, and recurse to simplify.
3944 if (Idx < Ops.size()) {
3945 bool DeletedAny = false;
3946 while (Ops[Idx]->getSCEVType() == Kind) {
3947 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3948 Ops.erase(Ops.begin()+Idx);
3949 append_range(Ops, SMME->operands());
3950 DeletedAny = true;
3951 }
3952
3953 if (DeletedAny)
3954 return getMinMaxExpr(Kind, Ops);
3955 }
3956
3957 // Okay, check to see if the same value occurs in the operand list twice. If
3958 // so, delete one. Since we sorted the list, these values are required to
3959 // be adjacent.
3960 llvm::CmpInst::Predicate GEPred =
3961 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3962 llvm::CmpInst::Predicate LEPred =
3963 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3964 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3965 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3966 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3967 if (Ops[i] == Ops[i + 1] ||
3968 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3969 // X op Y op Y --> X op Y
3970 // X op Y --> X, if we know X, Y are ordered appropriately
3971 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3972 --i;
3973 --e;
3974 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3975 Ops[i + 1])) {
3976 // X op Y --> Y, if we know X, Y are ordered appropriately
3977 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3978 --i;
3979 --e;
3980 }
3981 }
3982
3983 if (Ops.size() == 1) return Ops[0];
3984
3985 assert(!Ops.empty() && "Reduced smax down to nothing!");
3986
3987 // Okay, it looks like we really DO need an expr. Check to see if we
3988 // already have one, otherwise create a new one.
3989 FoldingSetNodeID ID;
3990 ID.AddInteger(Kind);
3991 for (const SCEV *Op : Ops)
3992 ID.AddPointer(Op);
3993 void *IP = nullptr;
3994 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3995 if (ExistingSCEV)
3996 return ExistingSCEV;
3997 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3998 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3999 SCEV *S = new (SCEVAllocator)
4000 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4001
4002 UniqueSCEVs.InsertNode(S, IP);
4003 registerUser(S, Ops);
4004 return S;
4005}
4006
4007namespace {
4008
4009class SCEVSequentialMinMaxDeduplicatingVisitor final
4010 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4011 std::optional<const SCEV *>> {
4012 using RetVal = std::optional<const SCEV *>;
4013 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
4014
4015 ScalarEvolution &SE;
4016 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4017 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4018 SmallPtrSet<const SCEV *, 16> SeenOps;
4019
4020 bool canRecurseInto(SCEVTypes Kind) const {
4021 // We can only recurse into the SCEV expression of the same effective type
4022 // as the type of our root SCEV expression.
4023 return RootKind == Kind || NonSequentialRootKind == Kind;
4024 };
4025
4026 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4027 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
4028 "Only for min/max expressions.");
4029 SCEVTypes Kind = S->getSCEVType();
4030
4031 if (!canRecurseInto(Kind))
4032 return S;
4033
4034 auto *NAry = cast<SCEVNAryExpr>(S);
4035 SmallVector<const SCEV *> NewOps;
4036 bool Changed = visit(Kind, NAry->operands(), NewOps);
4037
4038 if (!Changed)
4039 return S;
4040 if (NewOps.empty())
4041 return std::nullopt;
4042
4043 return isa<SCEVSequentialMinMaxExpr>(S)
4044 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4045 : SE.getMinMaxExpr(Kind, NewOps);
4046 }
4047
4048 RetVal visit(const SCEV *S) {
4049 // Has the whole operand been seen already?
4050 if (!SeenOps.insert(S).second)
4051 return std::nullopt;
4052 return Base::visit(S);
4053 }
4054
4055public:
4056 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4057 SCEVTypes RootKind)
4058 : SE(SE), RootKind(RootKind),
4059 NonSequentialRootKind(
4060 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4061 RootKind)) {}
4062
4063 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4064 SmallVectorImpl<const SCEV *> &NewOps) {
4065 bool Changed = false;
4066 SmallVector<const SCEV *> Ops;
4067 Ops.reserve(OrigOps.size());
4068
4069 for (const SCEV *Op : OrigOps) {
4070 RetVal NewOp = visit(Op);
4071 if (NewOp != Op)
4072 Changed = true;
4073 if (NewOp)
4074 Ops.emplace_back(*NewOp);
4075 }
4076
4077 if (Changed)
4078 NewOps = std::move(Ops);
4079 return Changed;
4080 }
4081
4082 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4083
4084 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4085
4086 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4087
4088 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4089
4090 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4091
4092 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4093
4094 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4095
4096 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4097
4098 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4099
4100 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4101
4102 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4103 return visitAnyMinMaxExpr(Expr);
4104 }
4105
4106 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4107 return visitAnyMinMaxExpr(Expr);
4108 }
4109
4110 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4111 return visitAnyMinMaxExpr(Expr);
4112 }
4113
4114 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4115 return visitAnyMinMaxExpr(Expr);
4116 }
4117
4118 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4119 return visitAnyMinMaxExpr(Expr);
4120 }
4121
4122 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4123
4124 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4125};
4126
4127} // namespace
4128
4130 switch (Kind) {
4131 case scConstant:
4132 case scVScale:
4133 case scTruncate:
4134 case scZeroExtend:
4135 case scSignExtend:
4136 case scPtrToInt:
4137 case scAddExpr:
4138 case scMulExpr:
4139 case scUDivExpr:
4140 case scAddRecExpr:
4141 case scUMaxExpr:
4142 case scSMaxExpr:
4143 case scUMinExpr:
4144 case scSMinExpr:
4145 case scUnknown:
4146 // If any operand is poison, the whole expression is poison.
4147 return true;
4148 case scSequentialUMinExpr:
4149 // FIXME: if the *first* operand is poison, the whole expression is poison.
4150 return false; // Pessimistically, say that it does not propagate poison.
4151 case scCouldNotCompute:
4152 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4153 }
4154 llvm_unreachable("Unknown SCEV kind!");
4155}
4156
4157namespace {
4158// The only way poison may be introduced in a SCEV expression is from a
4159// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4160// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4161// introduce poison -- they encode guaranteed, non-speculated knowledge.
4162//
4163// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4164// with the notable exception of umin_seq, where only poison from the first
4165// operand is (unconditionally) propagated.
4166struct SCEVPoisonCollector {
4167 bool LookThroughMaybePoisonBlocking;
4168 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4169 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4170 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4171
4172 bool follow(const SCEV *S) {
4173 if (!LookThroughMaybePoisonBlocking &&
4174 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4175 return false;
4176
4177 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4178 if (!isGuaranteedNotToBePoison(SU->getValue()))
4179 MaybePoison.insert(SU);
4180 }
4181 return true;
4182 }
4183 bool isDone() const { return false; }
4184};
4185} // namespace
4186
4187/// Return true if V is poison given that AssumedPoison is already poison.
4188static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4189 // First collect all SCEVs that might result in AssumedPoison to be poison.
4190 // We need to look through potentially poison-blocking operations here,
4191 // because we want to find all SCEVs that *might* result in poison, not only
4192 // those that are *required* to.
4193 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4194 visitAll(AssumedPoison, PC1);
4195
4196 // AssumedPoison is never poison. As the assumption is false, the implication
4197 // is true. Don't bother walking the other SCEV in this case.
4198 if (PC1.MaybePoison.empty())
4199 return true;
4200
4201 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4202 // as well. We cannot look through potentially poison-blocking operations
4203 // here, as their arguments only *may* make the result poison.
4204 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4205 visitAll(S, PC2);
4206
4207 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4208 // it will also make S poison by being part of PC2.MaybePoison.
4209 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4210}
4211
4212void ScalarEvolution::getPoisonGeneratingValues(
4213 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4214 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4215 visitAll(S, PC);
4216 for (const SCEVUnknown *SU : PC.MaybePoison)
4217 Result.insert(SU->getValue());
4218}
4219
4220bool ScalarEvolution::canReuseInstruction(
4221 const SCEV *S, Instruction *I,
4222 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4223 // If the instruction cannot be poison, it's always safe to reuse.
4224 if (isGuaranteedNotToBePoison(I))
4225 return true;
4226
4227 // Otherwise, it is possible that I is more poisonous than S. Collect the
4228 // poison-contributors of S, and then check whether I has any additional
4229 // poison-contributors. Poison that is contributed through poison-generating
4230 // flags is handled by dropping those flags instead.
4231 SmallPtrSet<const Value *, 8> PoisonVals;
4232 getPoisonGeneratingValues(PoisonVals, S);
4233
4234 SmallVector<Value *> Worklist;
4235 SmallPtrSet<Value *, 8> Visited;
4236 Worklist.push_back(I);
4237 while (!Worklist.empty()) {
4238 Value *V = Worklist.pop_back_val();
4239 if (!Visited.insert(V).second)
4240 continue;
4241
4242 // Avoid walking large instruction graphs.
4243 if (Visited.size() > 16)
4244 return false;
4245
4246 // Either the value can't be poison, or the S would also be poison if it
4247 // is.
4248 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4249 continue;
4250
4251 auto *I = dyn_cast<Instruction>(V);
4252 if (!I)
4253 return false;
4254
4255 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4256 // can't replace an arbitrary add with disjoint or, even if we drop the
4257 // flag. We would need to convert the or into an add.
4258 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4259 if (PDI->isDisjoint())
4260 return false;
4261
4262 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4263 // because SCEV currently assumes it can't be poison. Remove this special
4264 // case once we properly model when vscale can be poison.
4265 if (auto *II = dyn_cast<IntrinsicInst>(I);
4266 II && II->getIntrinsicID() == Intrinsic::vscale)
4267 continue;
4268
4269 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4270 return false;
4271
4272 // If the instruction can't create poison, we can recurse to its operands.
4273 if (I->hasPoisonGeneratingAnnotations())
4274 DropPoisonGeneratingInsts.push_back(I);
4275
4276 llvm::append_range(Worklist, I->operands());
4277 }
4278 return true;
4279}
4280
4281const SCEV *
4282ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4283 SmallVectorImpl<const SCEV *> &Ops) {
4284 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4285 "Not a SCEVSequentialMinMaxExpr!");
4286 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4287 if (Ops.size() == 1)
4288 return Ops[0];
4289#ifndef NDEBUG
4290 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4291 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4292 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4293 "Operand types don't match!");
4294 assert(Ops[0]->getType()->isPointerTy() ==
4295 Ops[i]->getType()->isPointerTy() &&
4296 "min/max should be consistently pointerish");
4297 }
4298#endif
4299
4300 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4301 // so we can *NOT* do any kind of sorting of the expressions!
4302
4303 // Check if we have created the same expression before.
4304 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4305 return S;
4306
4307 // FIXME: there are *some* simplifications that we can do here.
4308
4309 // Keep only the first instance of an operand.
4310 {
4311 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4312 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4313 if (Changed)
4314 return getSequentialMinMaxExpr(Kind, Ops);
4315 }
4316
4317 // Check to see if one of the operands is of the same kind. If so, expand its
4318 // operands onto our operand list, and recurse to simplify.
4319 {
4320 unsigned Idx = 0;
4321 bool DeletedAny = false;
4322 while (Idx < Ops.size()) {
4323 if (Ops[Idx]->getSCEVType() != Kind) {
4324 ++Idx;
4325 continue;
4326 }
4327 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4328 Ops.erase(Ops.begin() + Idx);
4329 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4330 SMME->operands().end());
4331 DeletedAny = true;
4332 }
4333
4334 if (DeletedAny)
4335 return getSequentialMinMaxExpr(Kind, Ops);
4336 }
4337
4338 const SCEV *SaturationPoint;
4339 ICmpInst::Predicate Pred;
4340 switch (Kind) {
4341 case scSequentialUMinExpr:
4342 SaturationPoint = getZero(Ops[0]->getType());
4343 Pred = ICmpInst::ICMP_ULE;
4344 break;
4345 default:
4346 llvm_unreachable("Not a sequential min/max type.");
4347 }
4348
4349 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4350 if (!isGuaranteedNotToCauseUB(Ops[i]))
4351 continue;
4352 // We can replace %x umin_seq %y with %x umin %y if either:
4353 // * %y being poison implies %x is also poison.
4354 // * %x cannot be the saturating value (e.g. zero for umin).
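// E.g. %x umin_seq %y can be relaxed to %x umin %y when %x is known to be
// non-zero, because %x then can never be the saturating value that makes
// the sequential form skip the evaluation of %y.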
4355 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4356 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4357 SaturationPoint)) {
4358 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4359 Ops[i - 1] = getMinMaxExpr(
4360 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4361 SeqOps);
4362 Ops.erase(Ops.begin() + i);
4363 return getSequentialMinMaxExpr(Kind, Ops);
4364 }
4365 // Fold %x umin_seq %y to %x if %x ule %y.
4366 // TODO: We might be able to prove the predicate for a later operand.
4367 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4368 Ops.erase(Ops.begin() + i);
4369 return getSequentialMinMaxExpr(Kind, Ops);
4370 }
4371 }
4372
4373 // Okay, it looks like we really DO need an expr. Check to see if we
4374 // already have one, otherwise create a new one.
4375 FoldingSetNodeID ID;
4376 ID.AddInteger(Kind);
4377 for (const SCEV *Op : Ops)
4378 ID.AddPointer(Op);
4379 void *IP = nullptr;
4380 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4381 if (ExistingSCEV)
4382 return ExistingSCEV;
4383
4384 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4385 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4386 SCEV *S = new (SCEVAllocator)
4387 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4388
4389 UniqueSCEVs.InsertNode(S, IP);
4390 registerUser(S, Ops);
4391 return S;
4392}
4393
4394const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4395 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4396 return getSMaxExpr(Ops);
4397}
4398
4402
4403const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4404 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4405 return getUMaxExpr(Ops);
4406}
4407
4411
4413 const SCEV *RHS) {
4414 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4415 return getSMinExpr(Ops);
4416}
4417
4421
4422const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4423 bool Sequential) {
4424 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4425 return getUMinExpr(Ops, Sequential);
4426}
4427
4433
4434const SCEV *
4436 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4437 if (Size.isScalable())
4438 Res = getMulExpr(Res, getVScale(IntTy));
4439 return Res;
4440}
4441
4443 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4444}
4445
4447 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4448}
4449
4451 StructType *STy,
4452 unsigned FieldNo) {
4453 // We can bypass creating a target-independent constant expression and then
4454 // folding it back into a ConstantInt. This is just a compile-time
4455 // optimization.
4456 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4457 assert(!SL->getSizeInBits().isScalable() &&
4458 "Cannot get offset for structure containing scalable vector types");
4459 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4460}
4461
4462const SCEV *ScalarEvolution::getUnknown(Value *V) {
4463 // Don't attempt to do anything other than create a SCEVUnknown object
4464 // here. createSCEV only calls getUnknown after checking for all other
4465 // interesting possibilities, and any other code that calls getUnknown
4466 // is doing so in order to hide a value from SCEV canonicalization.
4467
4468 FoldingSetNodeID ID;
4469 ID.AddInteger(scUnknown);
4470 ID.AddPointer(V);
4471 void *IP = nullptr;
4472 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4473 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4474 "Stale SCEVUnknown in uniquing map!");
4475 return S;
4476 }
4477 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4478 FirstUnknown);
4479 FirstUnknown = cast<SCEVUnknown>(S);
4480 UniqueSCEVs.InsertNode(S, IP);
4481 return S;
4482}
4483
4484//===----------------------------------------------------------------------===//
4485// Basic SCEV Analysis and PHI Idiom Recognition Code
4486//
4487
4488/// Test if values of the given type are analyzable within the SCEV
4489/// framework. This primarily includes integer types, and it can optionally
4490/// include pointer types if the ScalarEvolution class has access to
4491/// target-specific information.
4493 // Integers and pointers are always SCEVable.
4494 return Ty->isIntOrPtrTy();
4495}
4496
4497/// Return the size in bits of the specified type, for which isSCEVable must
4498/// return true.
4500 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4501 if (Ty->isPointerTy())
4503 return getDataLayout().getTypeSizeInBits(Ty);
4504}
4505
4506/// Return a type with the same bitwidth as the given type and which represents
4507/// how SCEV will treat the given type, for which isSCEVable must return
4508/// true. For pointer types, this is the pointer index sized integer type.
4510 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4511
4512 if (Ty->isIntegerTy())
4513 return Ty;
4514
4515 // The only other supported type is pointer.
4516 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4517 return getDataLayout().getIndexType(Ty);
4518}
4519
4521 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4522}
4523
4525 const SCEV *B) {
4526 /// For a valid use point to exist, the defining scope of one operand
4527 /// must dominate the other.
4528 bool PreciseA, PreciseB;
4529 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4530 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4531 if (!PreciseA || !PreciseB)
4532 // Can't tell.
4533 return false;
4534 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4535 DT.dominates(ScopeB, ScopeA);
4536}
4537
4539 return CouldNotCompute.get();
4540}
4541
4542bool ScalarEvolution::checkValidity(const SCEV *S) const {
4543 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4544 auto *SU = dyn_cast<SCEVUnknown>(S);
4545 return SU && SU->getValue() == nullptr;
4546 });
4547
4548 return !ContainsNulls;
4549}
4550
4552 HasRecMapType::iterator I = HasRecMap.find(S);
4553 if (I != HasRecMap.end())
4554 return I->second;
4555
4556 bool FoundAddRec =
4557 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4558 HasRecMap.insert({S, FoundAddRec});
4559 return FoundAddRec;
4560}
4561
4562/// Return the ValueOffsetPair set for \p S. \p S can be represented
4563/// by the value and offset from any ValueOffsetPair in the set.
4564ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4565 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4566 if (SI == ExprValueMap.end())
4567 return {};
4568 return SI->second.getArrayRef();
4569}
4570
4571/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4572/// cannot be used separately. eraseValueFromMap should be used to remove
4573/// V from ValueExprMap and ExprValueMap at the same time.
4574void ScalarEvolution::eraseValueFromMap(Value *V) {
4575 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4576 if (I != ValueExprMap.end()) {
4577 auto EVIt = ExprValueMap.find(I->second);
4578 bool Removed = EVIt->second.remove(V);
4579 (void) Removed;
4580 assert(Removed && "Value not in ExprValueMap?");
4581 ValueExprMap.erase(I);
4582 }
4583}
4584
4585void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4586 // A recursive query may have already computed the SCEV. It should be
4587 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4588 // inferred nowrap flags.
4589 auto It = ValueExprMap.find_as(V);
4590 if (It == ValueExprMap.end()) {
4591 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4592 ExprValueMap[S].insert(V);
4593 }
4594}
4595
4596/// Return an existing SCEV if it exists, otherwise analyze the expression and
4597/// create a new one.
4598const SCEV *ScalarEvolution::getSCEV(Value *V) {
4599 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4600
4601 if (const SCEV *S = getExistingSCEV(V))
4602 return S;
4603 return createSCEVIter(V);
4604}
4605
4606const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4607 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4608
4609 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4610 if (I != ValueExprMap.end()) {
4611 const SCEV *S = I->second;
4612 assert(checkValidity(S) &&
4613 "existing SCEV has not been properly invalidated");
4614 return S;
4615 }
4616 return nullptr;
4617}
4618
4619/// Return a SCEV corresponding to -V = -1*V
4620const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4621 SCEV::NoWrapFlags Flags) {
4622 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4623 return getConstant(
4624 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4625
4626 Type *Ty = V->getType();
4627 Ty = getEffectiveSCEVType(Ty);
4628 return getMulExpr(V, getMinusOne(Ty), Flags);
4629}
4630
4631/// If Expr computes ~A, return A else return nullptr
4632static const SCEV *MatchNotExpr(const SCEV *Expr) {
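// In SCEV form ~A is represented as (-1 + (-1 * A)), so the matcher below
// looks for exactly that add-of-a-negated-value shape.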
4633 const SCEV *MulOp;
4634 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4635 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4636 return MulOp;
4637 return nullptr;
4638}
4639
4640/// Return a SCEV corresponding to ~V = -1-V
4642 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4643
4644 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4645 return getConstant(
4646 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4647
4648 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
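// E.g. ~smin(~a, ~b) == smax(a, b), since ~x == -1 - x is an
// order-reversing map on the signed integers.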
4649 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4650 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4651 SmallVector<const SCEV *, 2> MatchedOperands;
4652 for (const SCEV *Operand : MME->operands()) {
4653 const SCEV *Matched = MatchNotExpr(Operand);
4654 if (!Matched)
4655 return (const SCEV *)nullptr;
4656 MatchedOperands.push_back(Matched);
4657 }
4658 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4659 MatchedOperands);
4660 };
4661 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4662 return Replaced;
4663 }
4664
4665 Type *Ty = V->getType();
4666 Ty = getEffectiveSCEVType(Ty);
4667 return getMinusSCEV(getMinusOne(Ty), V);
4668}
4669
4670const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4671 assert(P->getType()->isPointerTy());
4672
4673 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4674 // The base of an AddRec is the first operand.
4675 SmallVector<const SCEV *> Ops{AddRec->operands()};
4676 Ops[0] = removePointerBase(Ops[0]);
4677 // Don't try to transfer nowrap flags for now. We could in some cases
4678 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4679 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4680 }
4681 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4682 // The base of an Add is the pointer operand.
4683 SmallVector<const SCEV *> Ops{Add->operands()};
4684 const SCEV **PtrOp = nullptr;
4685 for (const SCEV *&AddOp : Ops) {
4686 if (AddOp->getType()->isPointerTy()) {
4687 assert(!PtrOp && "Cannot have multiple pointer ops");
4688 PtrOp = &AddOp;
4689 }
4690 }
4691 *PtrOp = removePointerBase(*PtrOp);
4692 // Don't try to transfer nowrap flags for now. We could in some cases
4693 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4694 return getAddExpr(Ops);
4695 }
4696 // Any other expression must be a pointer base.
4697 return getZero(P->getType());
4698}
4699
4700const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4701 SCEV::NoWrapFlags Flags,
4702 unsigned Depth) {
4703 // Fast path: X - X --> 0.
4704 if (LHS == RHS)
4705 return getZero(LHS->getType());
4706
4707 // If we subtract two pointers with different pointer bases, bail.
4708 // Eventually, we're going to add an assertion to getMulExpr that we
4709 // can't multiply by a pointer.
4710 if (RHS->getType()->isPointerTy()) {
4711 if (!LHS->getType()->isPointerTy() ||
4712 getPointerBase(LHS) != getPointerBase(RHS))
4713 return getCouldNotCompute();
4714 LHS = removePointerBase(LHS);
4715 RHS = removePointerBase(RHS);
4716 }
4717
4718 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4719 // makes it so that we cannot make much use of NUW.
4720 auto AddFlags = SCEV::FlagAnyWrap;
4721 const bool RHSIsNotMinSigned =
4722 !getSignedRangeMin(RHS).isMinSignedValue();
4723 if (hasFlags(Flags, SCEV::FlagNSW)) {
4724 // Let M be the minimum representable signed value. Then (-1)*RHS
4725 // signed-wraps if and only if RHS is M. That can happen even for
4726 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4727 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4728 // (-1)*RHS, we need to prove that RHS != M.
4729 //
4730 // If LHS is non-negative and we know that LHS - RHS does not
4731 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4732 // either by proving that RHS > M or that LHS >= 0.
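// E.g. for i8, (-128) * -1 wraps back to -128, so "%a -nsw %b" only becomes
// "%a +nsw (-1 * %b)" once %b is known to differ from -128 (or %a >= 0).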
4733 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4734 AddFlags = SCEV::FlagNSW;
4735 }
4736 }
4737
4738 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4739 // RHS is NSW and LHS >= 0.
4740 //
4741 // The difficulty here is that the NSW flag may have been proven
4742 // relative to a loop that is to be found in a recurrence in LHS and
4743 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4744 // larger scope than intended.
4745 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4746
4747 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4748}
4749
4750const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4751 unsigned Depth) {
4752 Type *SrcTy = V->getType();
4753 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4754 "Cannot truncate or zero extend with non-integer arguments!");
4755 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4756 return V; // No conversion
4757 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4758 return getTruncateExpr(V, Ty, Depth);
4759 return getZeroExtendExpr(V, Ty, Depth);
4760}
4761
4762const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4763 unsigned Depth) {
4764 Type *SrcTy = V->getType();
4765 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4766 "Cannot truncate or zero extend with non-integer arguments!");
4767 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4768 return V; // No conversion
4769 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4770 return getTruncateExpr(V, Ty, Depth);
4771 return getSignExtendExpr(V, Ty, Depth);
4772}
4773
4774const SCEV *
4775ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4776 Type *SrcTy = V->getType();
4777 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4778 "Cannot noop or zero extend with non-integer arguments!");
4780 "getNoopOrZeroExtend cannot truncate!");
4781 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4782 return V; // No conversion
4783 return getZeroExtendExpr(V, Ty);
4784}
4785
4786const SCEV *
4787ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4788 Type *SrcTy = V->getType();
4789 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4790 "Cannot noop or sign extend with non-integer arguments!");
4792 "getNoopOrSignExtend cannot truncate!");
4793 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4794 return V; // No conversion
4795 return getSignExtendExpr(V, Ty);
4796}
4797
4798const SCEV *
4799ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4800 Type *SrcTy = V->getType();
4801 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4802 "Cannot noop or any extend with non-integer arguments!");
4804 "getNoopOrAnyExtend cannot truncate!");
4805 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4806 return V; // No conversion
4807 return getAnyExtendExpr(V, Ty);
4808}
4809
4810const SCEV *
4811ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4812 Type *SrcTy = V->getType();
4813 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4814 "Cannot truncate or noop with non-integer arguments!");
4816 "getTruncateOrNoop cannot extend!");
4817 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4818 return V; // No conversion
4819 return getTruncateExpr(V, Ty);
4820}
4821
4822const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4823 const SCEV *RHS) {
4824 const SCEV *PromotedLHS = LHS;
4825 const SCEV *PromotedRHS = RHS;
4826
4827 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4828 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4829 else
4830 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4831
4832 return getUMaxExpr(PromotedLHS, PromotedRHS);
4833}
4834
4835const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4836 const SCEV *RHS,
4837 bool Sequential) {
4838 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4839 return getUMinFromMismatchedTypes(Ops, Sequential);
4840}
4841
4842const SCEV *
4843ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4844 bool Sequential) {
4845 assert(!Ops.empty() && "At least one operand must be!");
4846 // Trivial case.
4847 if (Ops.size() == 1)
4848 return Ops[0];
4849
4850 // Find the max type first.
4851 Type *MaxType = nullptr;
4852 for (const auto *S : Ops)
4853 if (MaxType)
4854 MaxType = getWiderType(MaxType, S->getType());
4855 else
4856 MaxType = S->getType();
4857 assert(MaxType && "Failed to find maximum type!");
4858
4859 // Extend all ops to max type.
4860 SmallVector<const SCEV *, 2> PromotedOps;
4861 for (const auto *S : Ops)
4862 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4863
4864 // Generate umin.
4865 return getUMinExpr(PromotedOps, Sequential);
4866}
4867
4868const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4869 // A pointer operand may evaluate to a nonpointer expression, such as null.
4870 if (!V->getType()->isPointerTy())
4871 return V;
4872
4873 while (true) {
4874 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4875 V = AddRec->getStart();
4876 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4877 const SCEV *PtrOp = nullptr;
4878 for (const SCEV *AddOp : Add->operands()) {
4879 if (AddOp->getType()->isPointerTy()) {
4880 assert(!PtrOp && "Cannot have multiple pointer ops");
4881 PtrOp = AddOp;
4882 }
4883 }
4884 assert(PtrOp && "Must have pointer op");
4885 V = PtrOp;
4886 } else // Not something we can look further into.
4887 return V;
4888 }
4889}
4890
4891/// Push users of the given Instruction onto the given Worklist.
4892static void PushDefUseChildren(Instruction *I,
4893 SmallVectorImpl<Instruction *> &Worklist,
4894 SmallPtrSetImpl<Instruction *> &Visited) {
4895 // Push the def-use children onto the Worklist stack.
4896 for (User *U : I->users()) {
4897 auto *UserInsn = cast<Instruction>(U);
4898 if (Visited.insert(UserInsn).second)
4899 Worklist.push_back(UserInsn);
4900 }
4901}
4902
4903namespace {
4904
4905/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4906/// expression if its loop is L. If its loop is not L, then use the AddRec
4907/// itself when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4908/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
4909/// done either.
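/// For example, rewriting {%start,+,%step}<L> with respect to L yields
/// %start; with respect to a different loop it yields the AddRec itself when
/// IgnoreOtherLoops is true, and CouldNotCompute otherwise.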
4910class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4911public:
4912 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4913 bool IgnoreOtherLoops = true) {
4914 SCEVInitRewriter Rewriter(L, SE);
4915 const SCEV *Result = Rewriter.visit(S);
4916 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4917 return SE.getCouldNotCompute();
4918 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4919 ? SE.getCouldNotCompute()
4920 : Result;
4921 }
4922
4923 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4924 if (!SE.isLoopInvariant(Expr, L))
4925 SeenLoopVariantSCEVUnknown = true;
4926 return Expr;
4927 }
4928
4929 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4930 // Only re-write AddRecExprs for this loop.
4931 if (Expr->getLoop() == L)
4932 return Expr->getStart();
4933 SeenOtherLoops = true;
4934 return Expr;
4935 }
4936
4937 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4938
4939 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4940
4941private:
4942 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4943 : SCEVRewriteVisitor(SE), L(L) {}
4944
4945 const Loop *L;
4946 bool SeenLoopVariantSCEVUnknown = false;
4947 bool SeenOtherLoops = false;
4948};
4949
4950/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4951/// post-increment expression if its loop is L. If its loop is not L, then
4952/// use the AddRec itself.
4953/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
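/// For example, rewriting {%start,+,%step}<L> with respect to L yields the
/// post-increment AddRec {(%start + %step),+,%step}<L>.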
4954class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4955public:
4956 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4957 SCEVPostIncRewriter Rewriter(L, SE);
4958 const SCEV *Result = Rewriter.visit(S);
4959 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4960 ? SE.getCouldNotCompute()
4961 : Result;
4962 }
4963
4964 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4965 if (!SE.isLoopInvariant(Expr, L))
4966 SeenLoopVariantSCEVUnknown = true;
4967 return Expr;
4968 }
4969
4970 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4971 // Only re-write AddRecExprs for this loop.
4972 if (Expr->getLoop() == L)
4973 return Expr->getPostIncExpr(SE);
4974 SeenOtherLoops = true;
4975 return Expr;
4976 }
4977
4978 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4979
4980 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4981
4982private:
4983 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4984 : SCEVRewriteVisitor(SE), L(L) {}
4985
4986 const Loop *L;
4987 bool SeenLoopVariantSCEVUnknown = false;
4988 bool SeenOtherLoops = false;
4989};
4990
4991/// This class evaluates a compare condition by matching it against the
4992/// branch condition of the loop latch. If there is a match, we assume the
4993/// condition is true while building SCEV nodes.
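/// For example, if the latch branches back to the header when %cond is true,
/// then a loop-variant (select %cond, %a, %b) appearing in the rewritten
/// expression folds to the SCEV of %a.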
4994class SCEVBackedgeConditionFolder
4995 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4996public:
4997 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4998 ScalarEvolution &SE) {
4999 bool IsPosBECond = false;
5000 Value *BECond = nullptr;
5001 if (BasicBlock *Latch = L->getLoopLatch()) {
5002 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5003 if (BI && BI->isConditional()) {
5004 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5005 "Both outgoing branches should not target same header!");
5006 BECond = BI->getCondition();
5007 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5008 } else {
5009 return S;
5010 }
5011 }
5012 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5013 return Rewriter.visit(S);
5014 }
5015
5016 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5017 const SCEV *Result = Expr;
5018 bool InvariantF = SE.isLoopInvariant(Expr, L);
5019
5020 if (!InvariantF) {
5021 Instruction *I = cast<Instruction>(Expr->getValue());
5022 switch (I->getOpcode()) {
5023 case Instruction::Select: {
5024 SelectInst *SI = cast<SelectInst>(I);
5025 std::optional<const SCEV *> Res =
5026 compareWithBackedgeCondition(SI->getCondition());
5027 if (Res) {
5028 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5029 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5030 }
5031 break;
5032 }
5033 default: {
5034 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5035 if (Res)
5036 Result = *Res;
5037 break;
5038 }
5039 }
5040 }
5041 return Result;
5042 }
5043
5044private:
5045 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5046 bool IsPosBECond, ScalarEvolution &SE)
5047 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5048 IsPositiveBECond(IsPosBECond) {}
5049
5050 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5051
5052 const Loop *L;
5053 /// Loop back condition.
5054 Value *BackedgeCond = nullptr;
5055 /// Set to true if loop back is on positive branch condition.
5056 bool IsPositiveBECond;
5057};
5058
5059std::optional<const SCEV *>
5060SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5061
5062 // If value matches the backedge condition for loop latch,
5063 // then return a constant evolution node based on loopback
5064 // branch taken.
5065 if (BackedgeCond == IC)
5066 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5067 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5068 return std::nullopt;
5069}
5070
5071class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5072public:
5073 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5074 ScalarEvolution &SE) {
5075 SCEVShiftRewriter Rewriter(L, SE);
5076 const SCEV *Result = Rewriter.visit(S);
5077 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5078 }
5079
5080 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5081 // Only allow AddRecExprs for this loop.
5082 if (!SE.isLoopInvariant(Expr, L))
5083 Valid = false;
5084 return Expr;
5085 }
5086
5087 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5088 if (Expr->getLoop() == L && Expr->isAffine())
5089 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5090 Valid = false;
5091 return Expr;
5092 }
5093
5094 bool isValid() { return Valid; }
5095
5096private:
5097 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5098 : SCEVRewriteVisitor(SE), L(L) {}
5099
5100 const Loop *L;
5101 bool Valid = true;
5102};
5103
5104} // end anonymous namespace
5105
5106SCEV::NoWrapFlags
5107ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5108 if (!AR->isAffine())
5109 return SCEV::FlagAnyWrap;
5110
5111 using OBO = OverflowingBinaryOperator;
5112
5113 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5114
5115 if (!AR->hasNoSelfWrap()) {
5116 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5117 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5118 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5119 const APInt &BECountAP = BECountMax->getAPInt();
5120 unsigned NoOverflowBitWidth =
5121 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5122 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5123 Result = setFlags(Result, SCEV::FlagNW);
5124 }
5125 }
5126
5127 if (!AR->hasNoSignedWrap()) {
5128 ConstantRange AddRecRange = getSignedRange(AR);
5129 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5130
5131 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5132 Instruction::Add, IncRange, OBO::NoSignedWrap);
5133 if (NSWRegion.contains(AddRecRange))
5134 Result = setFlags(Result, SCEV::FlagNSW);
5135 }
5136
5137 if (!AR->hasNoUnsignedWrap()) {
5138 ConstantRange AddRecRange = getUnsignedRange(AR);
5139 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5140
5141 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5142 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5143 if (NUWRegion.contains(AddRecRange))
5144 Result = setFlags(Result, SCEV::FlagNUW);
5145 }
5146
5147 return Result;
5148}
5149
5150SCEV::NoWrapFlags
5151ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5152 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5153
5154 if (AR->hasNoSignedWrap())
5155 return Result;
5156
5157 if (!AR->isAffine())
5158 return Result;
5159
5160 // This function can be expensive, only try to prove NSW once per AddRec.
5161 if (!SignedWrapViaInductionTried.insert(AR).second)
5162 return Result;
5163
5164 const SCEV *Step = AR->getStepRecurrence(*this);
5165 const Loop *L = AR->getLoop();
5166
5167 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5168 // Note that this serves two purposes: It filters out loops that are
5169 // simply not analyzable, and it covers the case where this code is
5170 // being called from within backedge-taken count analysis, such that
5171 // attempting to ask for the backedge-taken count would likely result
5172 // in infinite recursion. In the latter case, the analysis code will
5173 // cope with a conservative value, and it will take care to purge
5174 // that value once it has finished.
5175 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5176
5177 // Normally, in the cases we can prove no-overflow via a
5178 // backedge guarding condition, we can also compute a backedge
5179 // taken count for the loop. The exceptions are assumptions and
5180 // guards present in the loop -- SCEV is not great at exploiting
5181 // these to compute max backedge taken counts, but can still use
5182 // these to prove lack of overflow. Use this fact to avoid
5183 // doing extra work that may not pay off.
5184
5185 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5186 AC.assumptions().empty())
5187 return Result;
5188
5189 // If the backedge is guarded by a comparison with the pre-inc value the
5190 // addrec is safe. Also, if the entry is guarded by a comparison with the
5191 // start value and the backedge is guarded by a comparison with the post-inc
5192 // value, the addrec is safe.
5193 ICmpInst::Predicate Pred;
5194 const SCEV *OverflowLimit =
5195 getSignedOverflowLimitForStep(Step, &Pred, this);
5196 if (OverflowLimit &&
5197 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5198 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5199 Result = setFlags(Result, SCEV::FlagNSW);
5200 }
5201 return Result;
5202}
5203SCEV::NoWrapFlags
5204ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5205 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5206
5207 if (AR->hasNoUnsignedWrap())
5208 return Result;
5209
5210 if (!AR->isAffine())
5211 return Result;
5212
5213 // This function can be expensive, only try to prove NUW once per AddRec.
5214 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5215 return Result;
5216
5217 const SCEV *Step = AR->getStepRecurrence(*this);
5218 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5219 const Loop *L = AR->getLoop();
5220
5221 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5222 // Note that this serves two purposes: It filters out loops that are
5223 // simply not analyzable, and it covers the case where this code is
5224 // being called from within backedge-taken count analysis, such that
5225 // attempting to ask for the backedge-taken count would likely result
5226 // in infinite recursion. In the latter case, the analysis code will
5227 // cope with a conservative value, and it will take care to purge
5228 // that value once it has finished.
5229 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5230
5231 // Normally, in the cases we can prove no-overflow via a
5232 // backedge guarding condition, we can also compute a backedge
5233 // taken count for the loop. The exceptions are assumptions and
5234 // guards present in the loop -- SCEV is not great at exploiting
5235 // these to compute max backedge taken counts, but can still use
5236 // these to prove lack of overflow. Use this fact to avoid
5237 // doing extra work that may not pay off.
5238
5239 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5240 AC.assumptions().empty())
5241 return Result;
5242
5243 // If the backedge is guarded by a comparison with the pre-inc value the
5244 // addrec is safe. Also, if the entry is guarded by a comparison with the
5245 // start value and the backedge is guarded by a comparison with the post-inc
5246 // value, the addrec is safe.
5247 if (isKnownPositive(Step)) {
5248 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5249 getUnsignedRangeMax(Step));
5250 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5251 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5252 Result = setFlags(Result, SCEV::FlagNUW);
5253 }
5254 }
5255
5256 return Result;
5257}
5258
5259namespace {
5260
5261/// Represents an abstract binary operation. This may exist as a
5262/// normal instruction or constant expression, or may have been
5263/// derived from an expression tree.
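/// For example, the IR instruction (add nsw i32 %x, %y) maps to a BinaryOp
/// with Opcode Instruction::Add, LHS %x, RHS %y, and IsNSW set.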
5264struct BinaryOp {
5265 unsigned Opcode;
5266 Value *LHS;
5267 Value *RHS;
5268 bool IsNSW = false;
5269 bool IsNUW = false;
5270
5271 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5272 /// constant expression.
5273 Operator *Op = nullptr;
5274
5275 explicit BinaryOp(Operator *Op)
5276 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5277 Op(Op) {
5278 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5279 IsNSW = OBO->hasNoSignedWrap();
5280 IsNUW = OBO->hasNoUnsignedWrap();
5281 }
5282 }
5283
5284 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5285 bool IsNUW = false)
5286 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5287};
5288
5289} // end anonymous namespace
5290
5291/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5292static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5293 AssumptionCache &AC,
5294 const DominatorTree &DT,
5295 const Instruction *CxtI) {
5296 auto *Op = dyn_cast<Operator>(V);
5297 if (!Op)
5298 return std::nullopt;
5299
5300 // Implementation detail: all the cleverness here should happen without
5301 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5302 // SCEV expressions when possible, and we should not break that.
5303
5304 switch (Op->getOpcode()) {
5305 case Instruction::Add:
5306 case Instruction::Sub:
5307 case Instruction::Mul:
5308 case Instruction::UDiv:
5309 case Instruction::URem:
5310 case Instruction::And:
5311 case Instruction::AShr:
5312 case Instruction::Shl:
5313 return BinaryOp(Op);
5314
5315 case Instruction::Or: {
5316 // Convert or disjoint into add nuw nsw.
5317 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5318 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5319 /*IsNSW=*/true, /*IsNUW=*/true);
5320 return BinaryOp(Op);
5321 }
5322
5323 case Instruction::Xor:
5324 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5325 // If the RHS of the xor is a signmask, then this is just an add.
5326 // Instcombine turns add of signmask into xor as a strength reduction step.
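 // E.g. for i8, (xor %x, 0x80) flips only the sign bit, which is the same
 // as (add %x, 0x80) modulo 2^8.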
5327 if (RHSC->getValue().isSignMask())
5328 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5329 // Binary `xor` is a bit-wise `add`.
5330 if (V->getType()->isIntegerTy(1))
5331 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5332 return BinaryOp(Op);
5333
5334 case Instruction::LShr:
5335 // Turn a logical shift right by a constant into an unsigned divide.
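 // E.g. (lshr i32 %x, 3) is modeled as (udiv i32 %x, 8).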
5336 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5337 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5338
5339 // If the shift count is not less than the bitwidth, the result of
5340 // the shift is undefined. Don't try to analyze it, because the
5341 // resolution chosen here may differ from the resolution chosen in
5342 // other parts of the compiler.
5343 if (SA->getValue().ult(BitWidth)) {
5344 Constant *X =
5345 ConstantInt::get(SA->getContext(),
5346 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5347 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5348 }
5349 }
5350 return BinaryOp(Op);
5351
5352 case Instruction::ExtractValue: {
5353 auto *EVI = cast<ExtractValueInst>(Op);
5354 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5355 break;
5356
5357 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5358 if (!WO)
5359 break;
5360
5361 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5362 bool Signed = WO->isSigned();
5363 // TODO: Should add nuw/nsw flags for mul as well.
5364 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5365 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5366
5367 // Now that we know that all uses of the arithmetic-result component of
5368 // CI are guarded by the overflow check, we can go ahead and pretend
5369 // that the arithmetic is non-overflowing.
5370 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5371 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5372 }
5373
5374 default:
5375 break;
5376 }
5377
5378 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5379 // semantics as a Sub, return a binary sub expression.
5380 if (auto *II = dyn_cast<IntrinsicInst>(V))
5381 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5382 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5383
5384 return std::nullopt;
5385}
5386
5387/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5388/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5389/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5390/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5391/// follows one of the following patterns:
5392/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5393/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5394/// If the SCEV expression of \p Op conforms with one of the expected patterns
5395/// we return the type of the truncation operation, and indicate whether the
5396/// truncated type should be treated as signed/unsigned by setting
5397/// \p Signed to true/false, respectively.
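/// For example, with a 64-bit %SymbolicPHI and
/// Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64), this returns the
/// i32 truncation type and sets \p Signed to true.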
5398static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5399 bool &Signed, ScalarEvolution &SE) {
5400 // The case where Op == SymbolicPHI (that is, with no type conversions on
5401 // the way) is handled by the regular add recurrence creating logic and
5402 // would have already been triggered in createAddRecFromPHI. Reaching it here
5403 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5404 // because one of the other operands of the SCEVAddExpr updating this PHI is
5405 // not invariant).
5406 //
5407 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5408 // this case predicates that allow us to prove that Op == SymbolicPHI will
5409 // be added.
5410 if (Op == SymbolicPHI)
5411 return nullptr;
5412
5413 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5414 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5415 if (SourceBits != NewBits)
5416 return nullptr;
5417
5418 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5419 Signed = true;
5420 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5421 }
5422 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5423 Signed = false;
5424 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5425 }
5426 return nullptr;
5427}
5428
5429static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5430 if (!PN->getType()->isIntegerTy())
5431 return nullptr;
5432 const Loop *L = LI.getLoopFor(PN->getParent());
5433 if (!L || L->getHeader() != PN->getParent())
5434 return nullptr;
5435 return L;
5436}
5437
5438// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5439// computation that updates the phi follows the following pattern:
5440// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5441// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5442// If so, try to see if it can be rewritten as an AddRecExpr under some
5443// Predicates. If successful, return them as a pair. Also cache the results
5444// of the analysis.
5445//
5446// Example usage scenario:
5447// Say the Rewriter is called for the following SCEV:
5448// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5449// where:
5450// %X = phi i64 (%Start, %BEValue)
5451// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5452// and call this function with %SymbolicPHI = %X.
5453//
5454// The analysis will find that the value coming around the backedge has
5455// the following SCEV:
5456// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5457// Upon concluding that this matches the desired pattern, the function
5458// will return the pair {NewAddRec, SmallPredsVec} where:
5459// NewAddRec = {%Start,+,%Step}
5460// SmallPredsVec = {P1, P2, P3} as follows:
5461// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5462// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5463// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5464// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5465// under the predicates {P1,P2,P3}.
5466// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5467// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5468//
5469// TODO's:
5470//
5471// 1) Extend the Induction descriptor to also support inductions that involve
5472// casts: When needed (namely, when we are called in the context of the
5473// vectorizer induction analysis), a Set of cast instructions will be
5474// populated by this method, and provided back to isInductionPHI. This is
5475// needed to allow the vectorizer to properly record them to be ignored by
5476// the cost model and to avoid vectorizing them (otherwise these casts,
5477// which are redundant under the runtime overflow checks, will be
5478// vectorized, which can be costly).
5479//
5480// 2) Support additional induction/PHISCEV patterns: We also want to support
5481// inductions where the sext-trunc / zext-trunc operations (partly) occur
5482// after the induction update operation (the induction increment):
5483//
5484// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5485// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5486//
5487// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5488// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5489//
5490// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5491std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5492ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5493 SmallVector<const SCEVPredicate *, 3> Predicates;
5494
5495 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5496 // return an AddRec expression under some predicate.
5497
5498 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5499 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5500 assert(L && "Expecting an integer loop header phi");
5501
5502 // The loop may have multiple entrances or multiple exits; we can analyze
5503 // this phi as an addrec if it has a unique entry value and a unique
5504 // backedge value.
5505 Value *BEValueV = nullptr, *StartValueV = nullptr;
5506 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5507 Value *V = PN->getIncomingValue(i);
5508 if (L->contains(PN->getIncomingBlock(i))) {
5509 if (!BEValueV) {
5510 BEValueV = V;
5511 } else if (BEValueV != V) {
5512 BEValueV = nullptr;
5513 break;
5514 }
5515 } else if (!StartValueV) {
5516 StartValueV = V;
5517 } else if (StartValueV != V) {
5518 StartValueV = nullptr;
5519 break;
5520 }
5521 }
5522 if (!BEValueV || !StartValueV)
5523 return std::nullopt;
5524
5525 const SCEV *BEValue = getSCEV(BEValueV);
5526
5527 // If the value coming around the backedge is an add with the symbolic
5528 // value we just inserted, possibly with casts that we can ignore under
5529 // an appropriate runtime guard, then we found a simple induction variable!
5530 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5531 if (!Add)
5532 return std::nullopt;
5533
5534 // If there is a single occurrence of the symbolic value, possibly
5535 // casted, replace it with a recurrence.
5536 unsigned FoundIndex = Add->getNumOperands();
5537 Type *TruncTy = nullptr;
5538 bool Signed;
5539 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5540 if ((TruncTy =
5541 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5542 if (FoundIndex == e) {
5543 FoundIndex = i;
5544 break;
5545 }
5546
5547 if (FoundIndex == Add->getNumOperands())
5548 return std::nullopt;
5549
5550 // Create an add with everything but the specified operand.
5551 SmallVector<const SCEV *, 8> Ops;
5552 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5553 if (i != FoundIndex)
5554 Ops.push_back(Add->getOperand(i));
5555 const SCEV *Accum = getAddExpr(Ops);
5556
5557 // The runtime checks will not be valid if the step amount is
5558 // varying inside the loop.
5559 if (!isLoopInvariant(Accum, L))
5560 return std::nullopt;
5561
5562 // *** Part2: Create the predicates
5563
5564 // Analysis was successful: we have a phi-with-cast pattern for which we
5565 // can return an AddRec expression under the following predicates:
5566 //
5567 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5568 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5569 // P2: An Equal predicate that guarantees that
5570 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5571 // P3: An Equal predicate that guarantees that
5572 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5573 //
5574 // As we next prove, the above predicates guarantee that:
5575 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5576 //
5577 //
5578 // More formally, we want to prove that:
5579 // Expr(i+1) = Start + (i+1) * Accum
5580 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5581 //
5582 // Given that:
5583 // 1) Expr(0) = Start
5584 // 2) Expr(1) = Start + Accum
5585 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5586 // 3) Induction hypothesis (step i):
5587 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5588 //
5589 // Proof:
5590 // Expr(i+1) =
5591 // = Start + (i+1)*Accum
5592 // = (Start + i*Accum) + Accum
5593 // = Expr(i) + Accum
5594 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5595 // :: from step i
5596 //
5597 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5598 //
5599 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5600 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5601 // + Accum :: from P3
5602 //
5603 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5604 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5605 //
5606 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5607 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5608 //
5609 // By induction, the same applies to all iterations 1<=i<n:
5610 //
5611
5612 // Create a truncated addrec for which we will add a no overflow check (P1).
5613 const SCEV *StartVal = getSCEV(StartValueV);
5614 const SCEV *PHISCEV =
5615 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5616 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5617
5618 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5619 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5620 // will be constant.
5621 //
5622 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5623 // add P1.
5624 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5625 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5626 Signed ? SCEVWrapPredicate::IncrementNSSW
5627 : SCEVWrapPredicate::IncrementNUSW;
5628 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5629 Predicates.push_back(AddRecPred);
5630 }
5631
5632 // Create the Equal Predicates P2,P3:
5633
5634 // It is possible that the predicates P2 and/or P3 are computable at
5635 // compile time due to StartVal and/or Accum being constants.
5636 // If either one is, then we can check that now and escape if either P2
5637 // or P3 is false.
5638
5639 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5640 // for each of StartVal and Accum
5641 auto getExtendedExpr = [&](const SCEV *Expr,
5642 bool CreateSignExtend) -> const SCEV * {
5643 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5644 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5645 const SCEV *ExtendedExpr =
5646 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5647 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5648 return ExtendedExpr;
5649 };
5650
5651 // Given:
5652 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5653 // = getExtendedExpr(Expr)
5654 // Determine whether the predicate P: Expr == ExtendedExpr
5655 // is known to be false at compile time
5656 auto PredIsKnownFalse = [&](const SCEV *Expr,
5657 const SCEV *ExtendedExpr) -> bool {
5658 return Expr != ExtendedExpr &&
5659 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5660 };
5661
5662 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5663 if (PredIsKnownFalse(StartVal, StartExtended)) {
5664 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5665 return std::nullopt;
5666 }
5667
5668 // The Step is always Signed (because the overflow checks are either
5669 // NSSW or NUSW)
5670 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5671 if (PredIsKnownFalse(Accum, AccumExtended)) {
5672 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5673 return std::nullopt;
5674 }
5675
5676 auto AppendPredicate = [&](const SCEV *Expr,
5677 const SCEV *ExtendedExpr) -> void {
5678 if (Expr != ExtendedExpr &&
5679 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5680 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5681 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5682 Predicates.push_back(Pred);
5683 }
5684 };
5685
5686 AppendPredicate(StartVal, StartExtended);
5687 AppendPredicate(Accum, AccumExtended);
5688
5689 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5690 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5691 // into NewAR if it will also add the runtime overflow checks specified in
5692 // Predicates.
5693 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5694
5695 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5696 std::make_pair(NewAR, Predicates);
5697 // Remember the result of the analysis for this SCEV at this location.
5698 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5699 return PredRewrite;
5700}
5701
5702std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5703ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5704 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5705 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5706 if (!L)
5707 return std::nullopt;
5708
5709 // Check to see if we already analyzed this PHI.
5710 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5711 if (I != PredicatedSCEVRewrites.end()) {
5712 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5713 I->second;
5714 // Analysis was done before and failed to create an AddRec:
5715 if (Rewrite.first == SymbolicPHI)
5716 return std::nullopt;
5717 // Analysis was done before and succeeded to create an AddRec under
5718 // a predicate:
5719 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5720 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5721 return Rewrite;
5722 }
5723
5724 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5725 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5726
5727 // Record in the cache that the analysis failed
5728 if (!Rewrite) {
5729 SmallVector<const SCEVPredicate *, 3> Predicates;
5730 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5731 return std::nullopt;
5732 }
5733
5734 return Rewrite;
5735}
5736
5737// FIXME: This utility is currently required because the Rewriter currently
5738// does not rewrite this expression:
5739// {0, +, (sext ix (trunc iy to ix) to iy)}
5740// into {0, +, %step},
5741// even when the following Equal predicate exists:
5742// "%step == (sext ix (trunc iy to ix) to iy)".
5743bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5744 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5745 if (AR1 == AR2)
5746 return true;
5747
5748 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5749 if (Expr1 != Expr2 &&
5750 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5751 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5752 return false;
5753 return true;
5754 };
5755
5756 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5757 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5758 return false;
5759 return true;
5760}
5761
5762/// A helper function for createAddRecFromPHI to handle simple cases.
5763///
5764/// This function tries to find an AddRec expression for the simplest (yet most
5765/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5766/// If it fails, createAddRecFromPHI will use a more general, but slow,
5767/// technique for finding the AddRec expression.
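/// For example, given
///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, 4
/// this recognizes %iv as the AddRec {0,+,4}<nuw><nsw> for the loop.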
5768const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5769 Value *BEValueV,
5770 Value *StartValueV) {
5771 const Loop *L = LI.getLoopFor(PN->getParent());
5772 assert(L && L->getHeader() == PN->getParent());
5773 assert(BEValueV && StartValueV);
5774
5775 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5776 if (!BO)
5777 return nullptr;
5778
5779 if (BO->Opcode != Instruction::Add)
5780 return nullptr;
5781
5782 const SCEV *Accum = nullptr;
5783 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5784 Accum = getSCEV(BO->RHS);
5785 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5786 Accum = getSCEV(BO->LHS);
5787
5788 if (!Accum)
5789 return nullptr;
5790
5791 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5792 if (BO->IsNUW)
5793 Flags = setFlags(Flags, SCEV::FlagNUW);
5794 if (BO->IsNSW)
5795 Flags = setFlags(Flags, SCEV::FlagNSW);
5796
5797 const SCEV *StartVal = getSCEV(StartValueV);
5798 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5799 insertValueToMap(PN, PHISCEV);
5800
5801 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5802 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5803 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5804 proveNoWrapViaConstantRanges(AR)));
5805 }
5806
5807 // We can add Flags to the post-inc expression only if we
5808 // know that it is *undefined behavior* for BEValueV to
5809 // overflow.
5810 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5811 assert(isLoopInvariant(Accum, L) &&
5812 "Accum is defined outside L, but is not invariant?");
5813 if (isAddRecNeverPoison(BEInst, L))
5814 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5815 }
5816
5817 return PHISCEV;
5818}
5819
5820const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5821 const Loop *L = LI.getLoopFor(PN->getParent());
5822 if (!L || L->getHeader() != PN->getParent())
5823 return nullptr;
5824
5825 // The loop may have multiple entrances or multiple exits; we can analyze
5826 // this phi as an addrec if it has a unique entry value and a unique
5827 // backedge value.
5828 Value *BEValueV = nullptr, *StartValueV = nullptr;
5829 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5830 Value *V = PN->getIncomingValue(i);
5831 if (L->contains(PN->getIncomingBlock(i))) {
5832 if (!BEValueV) {
5833 BEValueV = V;
5834 } else if (BEValueV != V) {
5835 BEValueV = nullptr;
5836 break;
5837 }
5838 } else if (!StartValueV) {
5839 StartValueV = V;
5840 } else if (StartValueV != V) {
5841 StartValueV = nullptr;
5842 break;
5843 }
5844 }
5845 if (!BEValueV || !StartValueV)
5846 return nullptr;
5847
5848 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5849 "PHI node already processed?");
5850
5851 // First, try to find an AddRec expression without creating a fictitious
5852 // symbolic value for PN.
5853 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5854 return S;
5855
5856 // Handle PHI node value symbolically.
5857 const SCEV *SymbolicName = getUnknown(PN);
5858 insertValueToMap(PN, SymbolicName);
5859
5860 // Using this symbolic name for the PHI, analyze the value coming around
5861 // the back-edge.
5862 const SCEV *BEValue = getSCEV(BEValueV);
5863
5864 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5865 // has a special value for the first iteration of the loop.
5866
5867 // If the value coming around the backedge is an add with the symbolic
5868 // value we just inserted, then we found a simple induction variable!
5869 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5870 // If there is a single occurrence of the symbolic value, replace it
5871 // with a recurrence.
5872 unsigned FoundIndex = Add->getNumOperands();
5873 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5874 if (Add->getOperand(i) == SymbolicName)
5875 if (FoundIndex == e) {
5876 FoundIndex = i;
5877 break;
5878 }
5879
5880 if (FoundIndex != Add->getNumOperands()) {
5881 // Create an add with everything but the specified operand.
5882 SmallVector<const SCEV *, 8> Ops;
5883 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5884 if (i != FoundIndex)
5885 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5886 L, *this));
5887 const SCEV *Accum = getAddExpr(Ops);
5888
5889 // This is not a valid addrec if the step amount is varying each
5890 // loop iteration, but is not itself an addrec in this loop.
5891 if (isLoopInvariant(Accum, L) ||
5892 (isa<SCEVAddRecExpr>(Accum) &&
5893 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5894 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5895
5896 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5897 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5898 if (BO->IsNUW)
5899 Flags = setFlags(Flags, SCEV::FlagNUW);
5900 if (BO->IsNSW)
5901 Flags = setFlags(Flags, SCEV::FlagNSW);
5902 }
5903 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5904 if (GEP->getOperand(0) == PN) {
5905 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5906 // If the increment has any nowrap flags, then we know the address
5907 // space cannot be wrapped around.
5908 if (NW != GEPNoWrapFlags::none())
5909 Flags = setFlags(Flags, SCEV::FlagNW);
5910 // If the GEP is nuw or nusw with non-negative offset, we know that
5911 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5912 // offset is treated as signed, while the base is unsigned.
5913 if (NW.hasNoUnsignedWrap() ||
5914 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5915 Flags = setFlags(Flags, SCEV::FlagNUW);
5916 }
5917
5918 // We cannot transfer nuw and nsw flags from subtraction
5919 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5920 // for instance.
5921 }
5922
5923 const SCEV *StartVal = getSCEV(StartValueV);
5924 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5925
5926 // Okay, for the entire analysis of this edge we assumed the PHI
5927 // to be symbolic. We now need to go back and purge all of the
5928 // entries for the scalars that use the symbolic expression.
5929 forgetMemoizedResults(SymbolicName);
5930 insertValueToMap(PN, PHISCEV);
5931
5932 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5933 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5934 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5935 proveNoWrapViaConstantRanges(AR)));
5936 }
5937
5938 // We can add Flags to the post-inc expression only if we
5939 // know that it is *undefined behavior* for BEValueV to
5940 // overflow.
5941 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5942 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5943 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5944
5945 return PHISCEV;
5946 }
5947 }
5948 } else {
5949 // Otherwise, this could be a loop like this:
5950 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5951 // In this case, j = {1,+,1} and BEValue is j.
5952 // Because the other in-value of i (0) fits the evolution of BEValue
5953 // i really is an addrec evolution.
5954 //
5955 // We can generalize this saying that i is the shifted value of BEValue
5956 // by one iteration:
5957 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5958
5959 // Do not allow refinement in rewriting of BEValue.
5960 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5961 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5962 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5963 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5964 const SCEV *StartVal = getSCEV(StartValueV);
5965 if (Start == StartVal) {
5966 // Okay, for the entire analysis of this edge we assumed the PHI
5967 // to be symbolic. We now need to go back and purge all of the
5968 // entries for the scalars that use the symbolic expression.
5969 forgetMemoizedResults(SymbolicName);
5970 insertValueToMap(PN, Shifted);
5971 return Shifted;
5972 }
5973 }
5974 }
5975
5976 // Remove the temporary PHI node SCEV that has been inserted while intending
5977 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5978 // as it would prevent later (possibly simpler) SCEV expressions from being
5979 // added to the ValueExprMap.
5980 eraseValueFromMap(PN);
5981
5982 return nullptr;
5983}
5984
5985// Try to match a control flow sequence that branches out at BI and merges back
5986// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5987// match.
5988static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5989 Value *&C, Value *&LHS, Value *&RHS) {
5990 C = BI->getCondition();
5991
5992 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5993 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5994
5995 if (!LeftEdge.isSingleEdge())
5996 return false;
5997
5998 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5999
6000 Use &LeftUse = Merge->getOperandUse(0);
6001 Use &RightUse = Merge->getOperandUse(1);
6002
6003 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6004 LHS = LeftUse;
6005 RHS = RightUse;
6006 return true;
6007 }
6008
6009 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6010 LHS = RightUse;
6011 RHS = LeftUse;
6012 return true;
6013 }
6014
6015 return false;
6016}
6017
6018const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6019 auto IsReachable =
6020 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6021 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6022 // Try to match
6023 //
6024 // br %cond, label %left, label %right
6025 // left:
6026 // br label %merge
6027 // right:
6028 // br label %merge
6029 // merge:
6030 // V = phi [ %x, %left ], [ %y, %right ]
6031 //
6032 // as "select %cond, %x, %y"
6033
6034 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6035 assert(IDom && "At least the entry block should dominate PN");
6036
6037 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6038 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6039
6040 if (BI && BI->isConditional() &&
6041 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6042 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6043 properlyDominates(getSCEV(RHS), PN->getParent()))
6044 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6045 }
6046
6047 return nullptr;
6048}
6049
6050/// Returns SCEV for the first operand of a phi if all phi operands have
6051/// identical opcodes and operands
6052/// eg.
6053/// a: %add = %a + %b
6054/// br %c
6055/// b: %add1 = %a + %b
6056/// br %c
6057/// c: %phi = phi [%add, a], [%add1, b]
6058/// scev(%phi) => scev(%add)
6059const SCEV *
6060ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6061 BinaryOperator *CommonInst = nullptr;
6062 // Check if instructions are identical.
6063 for (Value *Incoming : PN->incoming_values()) {
6064 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6065 if (!IncomingInst)
6066 return nullptr;
6067 if (CommonInst) {
6068 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6069 return nullptr; // Not identical, give up
6070 } else {
6071 // Remember binary operator
6072 CommonInst = IncomingInst;
6073 }
6074 }
6075 if (!CommonInst)
6076 return nullptr;
6077
6078 // Check if SCEV exprs for instructions are identical.
6079 const SCEV *CommonSCEV = getSCEV(CommonInst);
6080 bool SCEVExprsIdentical =
6081 all_of(drop_begin(PN->incoming_values()),
6082 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6083 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6084}
6085
6086const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6087 if (const SCEV *S = createAddRecFromPHI(PN))
6088 return S;
6089
6090 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6091 // phi node for X.
6092 if (Value *V = simplifyInstruction(
6093 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6094 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6095 return getSCEV(V);
6096
6097 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6098 return S;
6099
6100 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6101 return S;
6102
6103 // If it's not a loop phi, we can't handle it yet.
6104 return getUnknown(PN);
6105}
6106
6107bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6108 SCEVTypes RootKind) {
6109 struct FindClosure {
6110 const SCEV *OperandToFind;
6111 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6112 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6113
6114 bool Found = false;
6115
6116 bool canRecurseInto(SCEVTypes Kind) const {
6117 // We can only recurse into the SCEV expression of the same effective type
6118 // as the type of our root SCEV expression, and into zero-extensions.
6119 return RootKind == Kind || NonSequentialRootKind == Kind ||
6120 scZeroExtend == Kind;
6121 };
6122
6123 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6124 : OperandToFind(OperandToFind), RootKind(RootKind),
6125 NonSequentialRootKind(
6126 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6127 RootKind)) {}
6128
6129 bool follow(const SCEV *S) {
6130 Found = S == OperandToFind;
6131
6132 return !isDone() && canRecurseInto(S->getSCEVType());
6133 }
6134
6135 bool isDone() const { return Found; }
6136 };
6137
6138 FindClosure FC(OperandToFind, RootKind);
6139 visitAll(Root, FC);
6140 return FC.Found;
6141}
6142
6143std::optional<const SCEV *>
6144ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6145 ICmpInst *Cond,
6146 Value *TrueVal,
6147 Value *FalseVal) {
6148 // Try to match some simple smax or umax patterns.
6149 auto *ICI = Cond;
6150
6151 Value *LHS = ICI->getOperand(0);
6152 Value *RHS = ICI->getOperand(1);
6153
6154 switch (ICI->getPredicate()) {
6155 case ICmpInst::ICMP_SLT:
6156 case ICmpInst::ICMP_SLE:
6157 case ICmpInst::ICMP_ULT:
6158 case ICmpInst::ICMP_ULE:
6159 std::swap(LHS, RHS);
6160 [[fallthrough]];
6161 case ICmpInst::ICMP_SGT:
6162 case ICmpInst::ICMP_SGE:
6163 case ICmpInst::ICMP_UGT:
6164 case ICmpInst::ICMP_UGE:
6165 // a > b ? a+x : b+x -> max(a, b)+x
6166 // a > b ? b+x : a+x -> min(a, b)+x
6167 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6168 bool Signed = ICI->isSigned();
6169 const SCEV *LA = getSCEV(TrueVal);
6170 const SCEV *RA = getSCEV(FalseVal);
6171 const SCEV *LS = getSCEV(LHS);
6172 const SCEV *RS = getSCEV(RHS);
6173 if (LA->getType()->isPointerTy()) {
6174 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6175 // Need to make sure we can't produce weird expressions involving
6176 // negated pointers.
6177 if (LA == LS && RA == RS)
6178 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6179 if (LA == RS && RA == LS)
6180 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6181 }
6182 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6183 if (Op->getType()->isPointerTy()) {
6184 Op = getLosslessPtrToIntExpr(Op);
6185 if (isa<SCEVCouldNotCompute>(Op))
6186 return Op;
6187 }
6188 if (Signed)
6189 Op = getNoopOrSignExtend(Op, Ty);
6190 else
6191 Op = getNoopOrZeroExtend(Op, Ty);
6192 return Op;
6193 };
6194 LS = CoerceOperand(LS);
6195 RS = CoerceOperand(RS);
6196 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6197 break;
6198 const SCEV *LDiff = getMinusSCEV(LA, LS);
6199 const SCEV *RDiff = getMinusSCEV(RA, RS);
6200 if (LDiff == RDiff)
6201 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6202 LDiff);
6203 LDiff = getMinusSCEV(LA, RS);
6204 RDiff = getMinusSCEV(RA, LS);
6205 if (LDiff == RDiff)
6206 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6207 LDiff);
6208 }
6209 break;
6210 case ICmpInst::ICMP_NE:
6211 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6212 std::swap(TrueVal, FalseVal);
6213 [[fallthrough]];
6214 case ICmpInst::ICMP_EQ:
6215 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6216 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6217 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6218 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6219 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6220 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6221 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6222 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6223 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6224 return getAddExpr(getUMaxExpr(X, C), Y);
6225 }
6226 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6227 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6228 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6229 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6230 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6231 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6232 const SCEV *X = getSCEV(LHS);
6233 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6234 X = ZExt->getOperand();
6235 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6236 const SCEV *FalseValExpr = getSCEV(FalseVal);
6237 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6238 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6239 /*Sequential=*/true);
6240 }
6241 }
6242 break;
6243 default:
6244 break;
6245 }
6246
6247 return std::nullopt;
6248}
6249
6250static std::optional<const SCEV *>
6251createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6252 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6253 assert(CondExpr->getType()->isIntegerTy(1) &&
6254 TrueExpr->getType() == FalseExpr->getType() &&
6255 TrueExpr->getType()->isIntegerTy(1) &&
6256 "Unexpected operands of a select.");
6257
6258 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6259 // --> C + (umin_seq cond, x - C)
6260 //
6261 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6262 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6263 // --> C + (umin_seq ~cond, x - C)
6264
6265 // FIXME: while we can't legally model the case where both of the hands
6266 // are fully variable, we only require that the *difference* is constant.
6267 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6268 return std::nullopt;
6269
6270 const SCEV *X, *C;
6271 if (isa<SCEVConstant>(TrueExpr)) {
6272 CondExpr = SE->getNotSCEV(CondExpr);
6273 X = FalseExpr;
6274 C = TrueExpr;
6275 } else {
6276 X = TrueExpr;
6277 C = FalseExpr;
6278 }
6279 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6280 /*Sequential=*/true));
6281}
6282
6283static std::optional<const SCEV *>
6284createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6285 Value *FalseVal) {
6286 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6287 return std::nullopt;
6288
6289 const auto *SECond = SE->getSCEV(Cond);
6290 const auto *SETrue = SE->getSCEV(TrueVal);
6291 const auto *SEFalse = SE->getSCEV(FalseVal);
6292 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6293}
6294
6295const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6296 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6297 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6298 assert(TrueVal->getType() == FalseVal->getType() &&
6299 V->getType() == TrueVal->getType() &&
6300 "Types of select hands and of the result must match.");
6301
6302 // For now, only deal with i1-typed `select`s.
6303 if (!V->getType()->isIntegerTy(1))
6304 return getUnknown(V);
6305
6306 if (std::optional<const SCEV *> S =
6307 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6308 return *S;
6309
6310 return getUnknown(V);
6311}
6312
6313const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6314 Value *TrueVal,
6315 Value *FalseVal) {
6316 // Handle "constant" branch or select. This can occur for instance when a
6317 // loop pass transforms an inner loop and moves on to process the outer loop.
6318 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6319 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6320
6321 if (auto *I = dyn_cast<Instruction>(V)) {
6322 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6323 if (std::optional<const SCEV *> S =
6324 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6325 TrueVal, FalseVal))
6326 return *S;
6327 }
6328 }
6329
6330 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6331}
6332
6333/// Expand GEP instructions into add and multiply operations. This allows them
6334/// to be analyzed by regular SCEV code.
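/// For example, on a target whose DataLayout gives i32 a size of 4 bytes,
/// (getelementptr i32, ptr %p, i64 %i) is modeled as the SCEV (%p + 4 * %i).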
6335const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6336 assert(GEP->getSourceElementType()->isSized() &&
6337 "GEP source element type must be sized");
6338
6339 SmallVector<const SCEV *, 4> IndexExprs;
6340 for (Value *Index : GEP->indices())
6341 IndexExprs.push_back(getSCEV(Index));
6342 return getGEPExpr(GEP, IndexExprs);
6343}
6344
6345APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6346 const Instruction *CtxI) {
6347 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6348 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6349 return TrailingZeros >= BitWidth
6350 ? APInt::getZero(BitWidth)
6351 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6352 };
6353 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6354 // The result is GCD of all operands results.
6355 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6356 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6357 Res = APIntOps::GreatestCommonDivisor(
6358 Res, getConstantMultiple(N->getOperand(I), CtxI));
6359 return Res;
6360 };
6361
6362 switch (S->getSCEVType()) {
6363 case scConstant:
6364 return cast<SCEVConstant>(S)->getAPInt();
6365 case scPtrToInt:
6366 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
6367 case scUDivExpr:
6368 case scVScale:
6369 return APInt(BitWidth, 1);
6370 case scTruncate: {
6371 // Only multiples that are a power of 2 will hold after truncation.
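 // E.g. a known multiple of 12 (= 4 * 3) in i32 only remains a known
 // multiple of 4 after truncation to i8.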
6372 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6373 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6374 return GetShiftedByZeros(TZ);
6375 }
6376 case scZeroExtend: {
6377 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6378 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6379 }
6380 case scSignExtend: {
6381 // Only multiples that are a power of 2 will hold after sext.
6382 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6383 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6384 return GetShiftedByZeros(TZ);
6385 }
6386 case scMulExpr: {
6387 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6388 if (M->hasNoUnsignedWrap()) {
6389 // The result is the product of all operand results.
6390 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6391 for (const SCEV *Operand : M->operands().drop_front())
6392 Res = Res * getConstantMultiple(Operand, CtxI);
6393 return Res;
6394 }
6395
6396 // If there are no wrap guarantees, find the trailing zeros, which is the
6397 // sum of trailing zeros for all its operands.
6398 uint32_t TZ = 0;
6399 for (const SCEV *Operand : M->operands())
6400 TZ += getMinTrailingZeros(Operand, CtxI);
6401 return GetShiftedByZeros(TZ);
6402 }
6403 case scAddExpr:
6404 case scAddRecExpr: {
6405 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6406 if (N->hasNoUnsignedWrap())
6407 return GetGCDMultiple(N);
6408 // Find the trailing bits, which is the minimum of its operands.
6409 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6410 for (const SCEV *Operand : N->operands().drop_front())
6411 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6412 return GetShiftedByZeros(TZ);
6413 }
6414 case scUMaxExpr:
6415 case scSMaxExpr:
6416 case scUMinExpr:
6417 case scSMinExpr:
6418 case scSequentialUMinExpr:
6419 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6420 case scUnknown: {
6421 // Ask ValueTracking for known bits. SCEVUnknowns only become available at
6422 // the point their underlying IR instruction has been defined. If CtxI was
6423 // not provided, use:
6424 // * the first instruction in the entry block if it is an argument
6425 // * the instruction itself otherwise.
6426 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6427 if (!CtxI) {
6428 if (isa<Argument>(U->getValue()))
6429 CtxI = &*F.getEntryBlock().begin();
6430 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6431 CtxI = I;
6432 }
6433 unsigned Known =
6434 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6435 .countMinTrailingZeros();
6436 return GetShiftedByZeros(Known);
6437 }
6438 case scCouldNotCompute:
6439 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6440 }
6441 llvm_unreachable("Unknown SCEV kind!");
6442}
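// Editor's note (worked example, not part of the LLVM source): assuming
// nothing is known about %x, getConstantMultiple of the SCEV
// (8 + (4 * %x)<nuw>)<nuw> is GCD(8, 4 * multiple(%x)) = GCD(8, 4) = 4,
// i.e. the expression is always a multiple of 4.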
6443
6444 APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6445 const Instruction *CtxI) {
6446 // Skip looking up and updating the cache if there is a context instruction,
6447 // as the result will only be valid in the specified context.
6448 if (CtxI)
6449 return getConstantMultipleImpl(S, CtxI);
6450
6451 auto I = ConstantMultipleCache.find(S);
6452 if (I != ConstantMultipleCache.end())
6453 return I->second;
6454
6455 APInt Result = getConstantMultipleImpl(S, CtxI);
6456 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6457 assert(InsertPair.second && "Should insert a new key");
6458 return InsertPair.first->second;
6459}
6460
6461 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6462 APInt Multiple = getConstantMultiple(S);
6463 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6464}
6465
6466 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6467 const Instruction *CtxI) {
6468 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6469 (unsigned)getTypeSizeInBits(S->getType()));
6470}
6471
6472/// Helper method to assign a range to V from metadata present in the IR.
6473static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6474 if (Instruction *I = dyn_cast<Instruction>(V)) {
6475 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6476 return getConstantRangeFromMetadata(*MD);
6477 if (const auto *CB = dyn_cast<CallBase>(V))
6478 if (std::optional<ConstantRange> Range = CB->getRange())
6479 return Range;
6480 }
6481 if (auto *A = dyn_cast<Argument>(V))
6482 if (std::optional<ConstantRange> Range = A->getRange())
6483 return Range;
6484
6485 return std::nullopt;
6486}
6487
6488 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6489 SCEV::NoWrapFlags Flags) {
6490 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6491 AddRec->setNoWrapFlags(Flags);
6492 UnsignedRanges.erase(AddRec);
6493 SignedRanges.erase(AddRec);
6494 ConstantMultipleCache.erase(AddRec);
6495 }
6496}
6497
6498ConstantRange ScalarEvolution::
6499getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6500 const DataLayout &DL = getDataLayout();
6501
6502 unsigned BitWidth = getTypeSizeInBits(U->getType());
6503 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6504
6505 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6506 // use information about the trip count to improve our available range. Note
6507 // that the trip count independent cases are already handled by known bits.
6508 // WARNING: The definition of recurrence used here is subtly different than
6509 // the one used by AddRec (and thus most of this file). Step is allowed to
6510 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6511 // and other addrecs in the same loop (for non-affine addrecs). The code
6512 // below intentionally handles the case where step is not loop invariant.
6513 auto *P = dyn_cast<PHINode>(U->getValue());
6514 if (!P)
6515 return FullSet;
6516
6517 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6518 // even the values that are not available in these blocks may come from them,
6519 // and this leads to a false-positive recurrence test.
6520 for (auto *Pred : predecessors(P->getParent()))
6521 if (!DT.isReachableFromEntry(Pred))
6522 return FullSet;
6523
6524 BinaryOperator *BO;
6525 Value *Start, *Step;
6526 if (!matchSimpleRecurrence(P, BO, Start, Step))
6527 return FullSet;
6528
6529 // If we found a recurrence in reachable code, we must be in a loop. Note
6530 // that BO might be in some subloop of L, and that's completely okay.
6531 auto *L = LI.getLoopFor(P->getParent());
6532 assert(L && L->getHeader() == P->getParent());
6533 if (!L->contains(BO->getParent()))
6534 // NOTE: This bailout should be an assert instead. However, asserting
6535 // the condition here exposes a case where LoopFusion is querying SCEV
6536 // with malformed loop information in the midst of the transform.
6537 // There doesn't appear to be an obvious fix, so for the moment bail out
6538 // until the caller issue can be fixed. PR49566 tracks the bug.
6539 return FullSet;
6540
6541 // TODO: Extend to other opcodes such as mul, and div
6542 switch (BO->getOpcode()) {
6543 default:
6544 return FullSet;
6545 case Instruction::AShr:
6546 case Instruction::LShr:
6547 case Instruction::Shl:
6548 break;
6549 };
6550
6551 if (BO->getOperand(0) != P)
6552 // TODO: Handle the power function forms some day.
6553 return FullSet;
6554
6555 unsigned TC = getSmallConstantMaxTripCount(L);
6556 if (!TC || TC >= BitWidth)
6557 return FullSet;
6558
6559 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6560 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6561 assert(KnownStart.getBitWidth() == BitWidth &&
6562 KnownStep.getBitWidth() == BitWidth);
6563
6564 // Compute total shift amount, being careful of overflow and bitwidths.
6565 auto MaxShiftAmt = KnownStep.getMaxValue();
6566 APInt TCAP(BitWidth, TC-1);
6567 bool Overflow = false;
6568 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6569 if (Overflow)
6570 return FullSet;
6571
6572 switch (BO->getOpcode()) {
6573 default:
6574 llvm_unreachable("filtered out above");
6575 case Instruction::AShr: {
6576 // For each ashr, three cases:
6577 // shift = 0 => unchanged value
6578 // saturation => 0 or -1
6579 // other => a value closer to zero (of the same sign)
6580 // Thus, the end value is closer to zero than the start.
6581 auto KnownEnd = KnownBits::ashr(KnownStart,
6582 KnownBits::makeConstant(TotalShift));
6583 if (KnownStart.isNonNegative())
6584 // Analogous to lshr (simply not yet canonicalized)
6585 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6586 KnownStart.getMaxValue() + 1);
6587 if (KnownStart.isNegative())
6588 // End >=u Start && End <=s Start
6589 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6590 KnownEnd.getMaxValue() + 1);
6591 break;
6592 }
6593 case Instruction::LShr: {
6594 // For each lshr, three cases:
6595 // shift = 0 => unchanged value
6596 // saturation => 0
6597 // other => a smaller positive number
6598 // Thus, the low end of the unsigned range is the last value produced.
6599 auto KnownEnd = KnownBits::lshr(KnownStart,
6600 KnownBits::makeConstant(TotalShift));
6601 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6602 KnownStart.getMaxValue() + 1);
6603 }
6604 case Instruction::Shl: {
6605 // Iff no bits are shifted out, value increases on every shift.
6606 auto KnownEnd = KnownBits::shl(KnownStart,
6607 KnownBits::makeConstant(TotalShift));
6608 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6609 return ConstantRange(KnownStart.getMinValue(),
6610 KnownEnd.getMaxValue() + 1);
6611 break;
6612 }
6613 };
6614 return FullSet;
6615}
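// Editor's note (worked example, not part of the LLVM source): for an i8 shift
// recurrence %p = phi [128, preheader], [%p.next, latch] with
// %p.next = lshr i8 %p, 1 and an assumed constant max trip count of 3, the
// total shift is at most 1 * (3 - 1) = 2, so the lshr case above yields the
// range [lshr(128, 2), 128 + 1) = [32, 129).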
6616
6617const ConstantRange &
6618ScalarEvolution::getRangeRefIter(const SCEV *S,
6619 ScalarEvolution::RangeSignHint SignHint) {
6620 DenseMap<const SCEV *, ConstantRange> &Cache =
6621 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6622 : SignedRanges;
6623 SmallVector<const SCEV *> WorkList;
6624 SmallPtrSet<const SCEV *, 8> Seen;
6625
6626 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6627 // SCEVUnknown PHI node.
6628 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6629 if (!Seen.insert(Expr).second)
6630 return;
6631 if (Cache.contains(Expr))
6632 return;
6633 switch (Expr->getSCEVType()) {
6634 case scUnknown:
6635 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6636 break;
6637 [[fallthrough]];
6638 case scConstant:
6639 case scVScale:
6640 case scTruncate:
6641 case scZeroExtend:
6642 case scSignExtend:
6643 case scPtrToInt:
6644 case scAddExpr:
6645 case scMulExpr:
6646 case scUDivExpr:
6647 case scAddRecExpr:
6648 case scUMaxExpr:
6649 case scSMaxExpr:
6650 case scUMinExpr:
6651 case scSMinExpr:
6652 case scSequentialUMinExpr:
6653 WorkList.push_back(Expr);
6654 break;
6655 case scCouldNotCompute:
6656 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6657 }
6658 };
6659 AddToWorklist(S);
6660
6661 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6662 for (unsigned I = 0; I != WorkList.size(); ++I) {
6663 const SCEV *P = WorkList[I];
6664 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6665 // If it is not a `SCEVUnknown`, just recurse into operands.
6666 if (!UnknownS) {
6667 for (const SCEV *Op : P->operands())
6668 AddToWorklist(Op);
6669 continue;
6670 }
6671 // `SCEVUnknown`'s require special treatment.
6672 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6673 if (!PendingPhiRangesIter.insert(P).second)
6674 continue;
6675 for (auto &Op : reverse(P->operands()))
6676 AddToWorklist(getSCEV(Op));
6677 }
6678 }
6679
6680 if (!WorkList.empty()) {
6681 // Use getRangeRef to compute ranges for items in the worklist in reverse
6682 // order. This will force ranges for earlier operands to be computed before
6683 // their users in most cases.
6684 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6685 getRangeRef(P, SignHint);
6686
6687 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6688 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6689 PendingPhiRangesIter.erase(P);
6690 }
6691 }
6692
6693 return getRangeRef(S, SignHint, 0);
6694}
6695
6696/// Determine the range for a particular SCEV. If SignHint is
6697/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6698/// with a "cleaner" unsigned (resp. signed) representation.
6699const ConstantRange &ScalarEvolution::getRangeRef(
6700 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6701 DenseMap<const SCEV *, ConstantRange> &Cache =
6702 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6703 : SignedRanges;
6704 ConstantRange::PreferredRangeType RangeType =
6705 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6706 : ConstantRange::Signed;
6707
6708 // See if we've computed this range already.
6709 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6710 if (I != Cache.end())
6711 return I->second;
6712
6713 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6714 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6715
6716 // Switch to iteratively computing the range for S, if it is part of a deeply
6717 // nested expression.
6718 if (Depth > RangeIterThreshold)
6719 return getRangeRefIter(S, SignHint);
6720
6721 unsigned BitWidth = getTypeSizeInBits(S->getType());
6722 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6723 using OBO = OverflowingBinaryOperator;
6724
6725 // If the value has known zeros, the maximum value will have those known zeros
6726 // as well.
6727 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6728 APInt Multiple = getNonZeroConstantMultiple(S);
6729 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6730 if (!Remainder.isZero())
6731 ConservativeResult =
6732 ConstantRange(APInt::getMinValue(BitWidth),
6733 APInt::getMaxValue(BitWidth) - Remainder + 1);
6734 }
6735 else {
6736 uint32_t TZ = getMinTrailingZeros(S);
6737 if (TZ != 0) {
6738 ConservativeResult = ConstantRange(
6739 APInt::getSignedMinValue(BitWidth),
6740 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6741 }
6742 }
6743
6744 switch (S->getSCEVType()) {
6745 case scConstant:
6746 llvm_unreachable("Already handled above.");
6747 case scVScale:
6748 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6749 case scTruncate: {
6750 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6751 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6752 return setRange(
6753 Trunc, SignHint,
6754 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6755 }
6756 case scZeroExtend: {
6757 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6758 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6759 return setRange(
6760 ZExt, SignHint,
6761 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6762 }
6763 case scSignExtend: {
6764 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6765 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6766 return setRange(
6767 SExt, SignHint,
6768 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6769 }
6770 case scPtrToInt: {
6771 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6772 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6773 return setRange(PtrToInt, SignHint, X);
6774 }
6775 case scAddExpr: {
6776 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6777 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6778 unsigned WrapType = OBO::AnyWrap;
6779 if (Add->hasNoSignedWrap())
6780 WrapType |= OBO::NoSignedWrap;
6781 if (Add->hasNoUnsignedWrap())
6782 WrapType |= OBO::NoUnsignedWrap;
6783 for (const SCEV *Op : drop_begin(Add->operands()))
6784 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6785 RangeType);
6786 return setRange(Add, SignHint,
6787 ConservativeResult.intersectWith(X, RangeType));
6788 }
6789 case scMulExpr: {
6790 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6791 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6792 for (const SCEV *Op : drop_begin(Mul->operands()))
6793 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6794 return setRange(Mul, SignHint,
6795 ConservativeResult.intersectWith(X, RangeType));
6796 }
6797 case scUDivExpr: {
6798 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6799 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6800 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6801 return setRange(UDiv, SignHint,
6802 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6803 }
6804 case scAddRecExpr: {
6805 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6806 // If there's no unsigned wrap, the value will never be less than its
6807 // initial value.
6808 if (AddRec->hasNoUnsignedWrap()) {
6809 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6810 if (!UnsignedMinValue.isZero())
6811 ConservativeResult = ConservativeResult.intersectWith(
6812 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6813 }
6814
6815 // If there's no signed wrap, and all the operands except the initial value have
6816 // the same sign or zero, the value won't ever be:
6817 // 1: smaller than the initial value if the operands are non-negative,
6818 // 2: bigger than the initial value if the operands are non-positive.
6819 // In both cases, the value cannot cross the signed min/max boundary.
6820 if (AddRec->hasNoSignedWrap()) {
6821 bool AllNonNeg = true;
6822 bool AllNonPos = true;
6823 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6824 if (!isKnownNonNegative(AddRec->getOperand(i)))
6825 AllNonNeg = false;
6826 if (!isKnownNonPositive(AddRec->getOperand(i)))
6827 AllNonPos = false;
6828 }
6829 if (AllNonNeg)
6830 ConservativeResult = ConservativeResult.intersectWith(
6831 ConstantRange(getSignedRangeMin(AddRec->getStart()),
6832 APInt::getSignedMinValue(BitWidth)),
6833 RangeType);
6834 else if (AllNonPos)
6835 ConservativeResult = ConservativeResult.intersectWith(
6836 ConstantRange(APInt::getSignedMinValue(BitWidth),
6837 getSignedRangeMax(AddRec->getStart()) +
6838 1),
6839 RangeType);
6840 }
6841
6842 // TODO: non-affine addrec
6843 if (AddRec->isAffine()) {
6844 const SCEV *MaxBEScev =
6845 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6846 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6847 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6848
6849 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6850 // MaxBECount's active bits are all <= AddRec's bit width.
6851 if (MaxBECount.getBitWidth() > BitWidth &&
6852 MaxBECount.getActiveBits() <= BitWidth)
6853 MaxBECount = MaxBECount.trunc(BitWidth);
6854 else if (MaxBECount.getBitWidth() < BitWidth)
6855 MaxBECount = MaxBECount.zext(BitWidth);
6856
6857 if (MaxBECount.getBitWidth() == BitWidth) {
6858 auto RangeFromAffine = getRangeForAffineAR(
6859 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6860 ConservativeResult =
6861 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6862
6863 auto RangeFromFactoring = getRangeViaFactoring(
6864 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6865 ConservativeResult =
6866 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6867 }
6868 }
6869
6870 // Now try symbolic BE count and more powerful methods.
6871 if (UseExpensiveRangeSharpening) {
6872 const SCEV *SymbolicMaxBECount =
6873 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6874 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6875 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6876 AddRec->hasNoSelfWrap()) {
6877 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6878 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6879 ConservativeResult =
6880 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6881 }
6882 }
6883 }
6884
6885 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6886 }
6887 case scUMaxExpr:
6888 case scSMaxExpr:
6889 case scUMinExpr:
6890 case scSMinExpr:
6891 case scSequentialUMinExpr: {
6892 Intrinsic::ID ID;
6893 switch (S->getSCEVType()) {
6894 case scUMaxExpr:
6895 ID = Intrinsic::umax;
6896 break;
6897 case scSMaxExpr:
6898 ID = Intrinsic::smax;
6899 break;
6900 case scUMinExpr:
6901 case scSequentialUMinExpr:
6902 ID = Intrinsic::umin;
6903 break;
6904 case scSMinExpr:
6905 ID = Intrinsic::smin;
6906 break;
6907 default:
6908 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6909 }
6910
6911 const auto *NAry = cast<SCEVNAryExpr>(S);
6912 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6913 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6914 X = X.intrinsic(
6915 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6916 return setRange(S, SignHint,
6917 ConservativeResult.intersectWith(X, RangeType));
6918 }
6919 case scUnknown: {
6920 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6921 Value *V = U->getValue();
6922
6923 // Check if the IR explicitly contains !range metadata.
6924 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6925 if (MDRange)
6926 ConservativeResult =
6927 ConservativeResult.intersectWith(*MDRange, RangeType);
6928
6929 // Use facts about recurrences in the underlying IR. Note that add
6930 // recurrences are AddRecExprs and thus don't hit this path. This
6931 // primarily handles shift recurrences.
6932 auto CR = getRangeForUnknownRecurrence(U);
6933 ConservativeResult = ConservativeResult.intersectWith(CR);
6934
6935 // See if ValueTracking can give us a useful range.
6936 const DataLayout &DL = getDataLayout();
6937 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6938 if (Known.getBitWidth() != BitWidth)
6939 Known = Known.zextOrTrunc(BitWidth);
6940
6941 // ValueTracking may be able to compute a tighter result for the number of
6942 // sign bits than for the value of those sign bits.
6943 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6944 if (U->getType()->isPointerTy()) {
6945 // If the pointer size is larger than the index size type, this can cause
6946 // NS to be larger than BitWidth. So compensate for this.
6947 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6948 int ptrIdxDiff = ptrSize - BitWidth;
6949 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6950 NS -= ptrIdxDiff;
6951 }
6952
6953 if (NS > 1) {
6954 // If we know any of the sign bits, we know all of the sign bits.
6955 if (!Known.Zero.getHiBits(NS).isZero())
6956 Known.Zero.setHighBits(NS);
6957 if (!Known.One.getHiBits(NS).isZero())
6958 Known.One.setHighBits(NS);
6959 }
6960
6961 if (Known.getMinValue() != Known.getMaxValue() + 1)
6962 ConservativeResult = ConservativeResult.intersectWith(
6963 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6964 RangeType);
6965 if (NS > 1)
6966 ConservativeResult = ConservativeResult.intersectWith(
6967 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6968 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6969 RangeType);
6970
6971 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6972 // Strengthen the range if the underlying IR value is a
6973 // global/alloca/heap allocation using the size of the object.
6974 bool CanBeNull, CanBeFreed;
6975 uint64_t DerefBytes =
6976 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6977 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6978 // The highest address the object can start is DerefBytes bytes before
6979 // the end (unsigned max value). If this value is not a multiple of the
6980 // alignment, the last possible start value is the next lowest multiple
6981 // of the alignment. Note: The computations below cannot overflow,
6982 // because if they would there's no possible start address for the
6983 // object.
6984 APInt MaxVal =
6985 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6986 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6987 uint64_t Rem = MaxVal.urem(Align);
6988 MaxVal -= APInt(BitWidth, Rem);
6989 APInt MinVal = APInt::getZero(BitWidth);
6990 if (llvm::isKnownNonZero(V, DL))
6991 MinVal = Align;
6992 ConservativeResult = ConservativeResult.intersectWith(
6993 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6994 }
6995 }
6996
6997 // A range of Phi is a subset of union of all ranges of its input.
6998 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6999 // Make sure that we do not run over cycled Phis.
7000 if (PendingPhiRanges.insert(Phi).second) {
7001 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7002
7003 for (const auto &Op : Phi->operands()) {
7004 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7005 RangeFromOps = RangeFromOps.unionWith(OpRange);
7006 // No point to continue if we already have a full set.
7007 if (RangeFromOps.isFullSet())
7008 break;
7009 }
7010 ConservativeResult =
7011 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7012 bool Erased = PendingPhiRanges.erase(Phi);
7013 assert(Erased && "Failed to erase Phi properly?");
7014 (void)Erased;
7015 }
7016 }
7017
7018 // vscale can't be equal to zero
7019 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7020 if (II->getIntrinsicID() == Intrinsic::vscale) {
7021 ConstantRange Disallowed = APInt::getZero(BitWidth);
7022 ConservativeResult = ConservativeResult.difference(Disallowed);
7023 }
7024
7025 return setRange(U, SignHint, std::move(ConservativeResult));
7026 }
7027 case scCouldNotCompute:
7028 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7029 }
7030
7031 return setRange(S, SignHint, std::move(ConservativeResult));
7032}
7033
7034// Given a StartRange, Step and MaxBECount for an expression compute a range of
7035// values that the expression can take. Initially, the expression has a value
7036// from StartRange and then is changed by Step up to MaxBECount times. Signed
7037// argument defines if we treat Step as signed or unsigned.
7038 static ConstantRange getRangeForAffineARHelper(APInt Step,
7039 const ConstantRange &StartRange,
7040 const APInt &MaxBECount,
7041 bool Signed) {
7042 unsigned BitWidth = Step.getBitWidth();
7043 assert(BitWidth == StartRange.getBitWidth() &&
7044 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7045 // If either Step or MaxBECount is 0, then the expression won't change, and we
7046 // just need to return the initial range.
7047 if (Step == 0 || MaxBECount == 0)
7048 return StartRange;
7049
7050 // If we don't know anything about the initial value (i.e. StartRange is
7051 // FullRange), then we don't know anything about the final range either.
7052 // Return FullRange.
7053 if (StartRange.isFullSet())
7054 return ConstantRange::getFull(BitWidth);
7055
7056 // If Step is signed and negative, then we use its absolute value, but we also
7057 // note that we're moving in the opposite direction.
7058 bool Descending = Signed && Step.isNegative();
7059
7060 if (Signed)
7061 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7062 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7063 // These equations hold true due to the well-defined wrap-around behavior of
7064 // APInt.
7065 Step = Step.abs();
7066
7067 // Check if Offset is more than full span of BitWidth. If it is, the
7068 // expression is guaranteed to overflow.
7069 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7070 return ConstantRange::getFull(BitWidth);
7071
7072 // Offset is by how much the expression can change. Checks above guarantee no
7073 // overflow here.
7074 APInt Offset = Step * MaxBECount;
7075
7076 // Minimum value of the final range will match the minimal value of StartRange
7077 // if the expression is increasing and will be decreased by Offset otherwise.
7078 // Maximum value of the final range will match the maximal value of StartRange
7079 // if the expression is decreasing and will be increased by Offset otherwise.
7080 APInt StartLower = StartRange.getLower();
7081 APInt StartUpper = StartRange.getUpper() - 1;
7082 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7083 : (StartUpper + std::move(Offset));
7084
7085 // It's possible that the new minimum/maximum value will fall into the initial
7086 // range (due to wrap around). This means that the expression can take any
7087 // value in this bitwidth, and we have to return full range.
7088 if (StartRange.contains(MovedBoundary))
7089 return ConstantRange::getFull(BitWidth);
7090
7091 APInt NewLower =
7092 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7093 APInt NewUpper =
7094 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7095 NewUpper += 1;
7096
7097 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7098 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7099}
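// Editor's note (worked example, not part of the LLVM source): for an 8-bit
// expression with Step = 2, StartRange = [0, 10) and MaxBECount = 3 (treated
// as unsigned), Offset = 2 * 3 = 6 and the moved upper boundary is 9 + 6 = 15,
// which does not wrap back into StartRange, so the helper returns [0, 16).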
7100
7101ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7102 const SCEV *Step,
7103 const APInt &MaxBECount) {
7104 assert(getTypeSizeInBits(Start->getType()) ==
7105 getTypeSizeInBits(Step->getType()) &&
7106 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7107 "mismatched bit widths");
7108
7109 // First, consider step signed.
7110 ConstantRange StartSRange = getSignedRange(Start);
7111 ConstantRange StepSRange = getSignedRange(Step);
7112
7113 // If Step can be both positive and negative, we need to find ranges for the
7114 // maximum absolute step values in both directions and union them.
7115 ConstantRange SR = getRangeForAffineARHelper(
7116 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7117 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7118 StartSRange, MaxBECount,
7119 /* Signed = */ true));
7120
7121 // Next, consider step unsigned.
7122 ConstantRange UR = getRangeForAffineARHelper(
7123 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7124 /* Signed = */ false);
7125
7126 // Finally, intersect signed and unsigned ranges.
7127 return SR.intersectWith(UR, ConstantRange::Smallest);
7128}
7129
7130ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7131 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7132 ScalarEvolution::RangeSignHint SignHint) {
7133 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7134 assert(AddRec->hasNoSelfWrap() &&
7135 "This only works for non-self-wrapping AddRecs!");
7136 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7137 const SCEV *Step = AddRec->getStepRecurrence(*this);
7138 // Only deal with constant step to save compile time.
7139 if (!isa<SCEVConstant>(Step))
7140 return ConstantRange::getFull(BitWidth);
7141 // Let's make sure that we can prove that we do not self-wrap during
7142 // MaxBECount iterations. We need this because MaxBECount is a maximum
7143 // iteration count estimate, and we might infer nw from some exit for which we
7144 // do not know max exit count (or any other side reasoning).
7145 // TODO: Turn into assert at some point.
7146 if (getTypeSizeInBits(MaxBECount->getType()) >
7147 getTypeSizeInBits(AddRec->getType()))
7148 return ConstantRange::getFull(BitWidth);
7149 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7150 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7151 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7152 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7153 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7154 MaxItersWithoutWrap))
7155 return ConstantRange::getFull(BitWidth);
7156
7157 ICmpInst::Predicate LEPred =
7158 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7159 ICmpInst::Predicate GEPred =
7160 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7161 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7162
7163 // We know that there is no self-wrap. Let's take Start and End values and
7164 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7165 // the iteration. They either lie inside the range [Min(Start, End),
7166 // Max(Start, End)] or outside it:
7167 //
7168 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7169 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7170 //
7171 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7172 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7173 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7174 // Start <= End and step is positive, or Start >= End and step is negative.
7175 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7176 ConstantRange StartRange = getRangeRef(Start, SignHint);
7177 ConstantRange EndRange = getRangeRef(End, SignHint);
7178 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7179 // If they already cover full iteration space, we will know nothing useful
7180 // even if we prove what we want to prove.
7181 if (RangeBetween.isFullSet())
7182 return RangeBetween;
7183 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7184 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7185 : RangeBetween.isWrappedSet();
7186 if (IsWrappedSet)
7187 return ConstantRange::getFull(BitWidth);
7188
7189 if (isKnownPositive(Step) &&
7190 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7191 return RangeBetween;
7192 if (isKnownNegative(Step) &&
7193 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7194 return RangeBetween;
7195 return ConstantRange::getFull(BitWidth);
7196}
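// Editor's note (illustrative example, not part of the LLVM source): for an
// AddRec {%n,+,1}<nw> whose start %n is known to be in [0, 100) and whose
// value End at the symbolic max backedge-taken count is known to be in
// [100, 200), the union of the two ranges is [0, 200); the step is positive
// and Start <= End follows from the constant ranges, so Case 1 applies and
// [0, 200) is returned instead of the full set.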
7197
7198ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7199 const SCEV *Step,
7200 const APInt &MaxBECount) {
7201 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7202 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7203
7204 unsigned BitWidth = MaxBECount.getBitWidth();
7205 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7206 getTypeSizeInBits(Step->getType()) == BitWidth &&
7207 "mismatched bit widths");
7208
7209 struct SelectPattern {
7210 Value *Condition = nullptr;
7211 APInt TrueValue;
7212 APInt FalseValue;
7213
7214 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7215 const SCEV *S) {
7216 std::optional<unsigned> CastOp;
7217 APInt Offset(BitWidth, 0);
7218
7219 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7220 "Should be!");
7221
7222 // Peel off a constant offset. In the future we could consider being
7223 // smarter here and handle {Start+Step,+,Step} too.
7224 const APInt *Off;
7225 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7226 Offset = *Off;
7227
7228 // Peel off a cast operation
7229 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7230 CastOp = SCast->getSCEVType();
7231 S = SCast->getOperand();
7232 }
7233
7234 using namespace llvm::PatternMatch;
7235
7236 auto *SU = dyn_cast<SCEVUnknown>(S);
7237 const APInt *TrueVal, *FalseVal;
7238 if (!SU ||
7239 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7240 m_APInt(FalseVal)))) {
7241 Condition = nullptr;
7242 return;
7243 }
7244
7245 TrueValue = *TrueVal;
7246 FalseValue = *FalseVal;
7247
7248 // Re-apply the cast we peeled off earlier
7249 if (CastOp)
7250 switch (*CastOp) {
7251 default:
7252 llvm_unreachable("Unknown SCEV cast type!");
7253
7254 case scTruncate:
7255 TrueValue = TrueValue.trunc(BitWidth);
7256 FalseValue = FalseValue.trunc(BitWidth);
7257 break;
7258 case scZeroExtend:
7259 TrueValue = TrueValue.zext(BitWidth);
7260 FalseValue = FalseValue.zext(BitWidth);
7261 break;
7262 case scSignExtend:
7263 TrueValue = TrueValue.sext(BitWidth);
7264 FalseValue = FalseValue.sext(BitWidth);
7265 break;
7266 }
7267
7268 // Re-apply the constant offset we peeled off earlier
7269 TrueValue += Offset;
7270 FalseValue += Offset;
7271 }
7272
7273 bool isRecognized() { return Condition != nullptr; }
7274 };
7275
7276 SelectPattern StartPattern(*this, BitWidth, Start);
7277 if (!StartPattern.isRecognized())
7278 return ConstantRange::getFull(BitWidth);
7279
7280 SelectPattern StepPattern(*this, BitWidth, Step);
7281 if (!StepPattern.isRecognized())
7282 return ConstantRange::getFull(BitWidth);
7283
7284 if (StartPattern.Condition != StepPattern.Condition) {
7285 // We don't handle this case today; but we could, by considering four
7286 // possibilities below instead of two. I'm not sure if there are cases where
7287 // that will help over what getRange already does, though.
7288 return ConstantRange::getFull(BitWidth);
7289 }
7290
7291 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7292 // construct arbitrary general SCEV expressions here. This function is called
7293 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7294 // say) can end up caching a suboptimal value.
7295
7296 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7297 // C2352 and C2512 (otherwise it isn't needed).
7298
7299 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7300 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7301 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7302 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7303
7304 ConstantRange TrueRange =
7305 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7306 ConstantRange FalseRange =
7307 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7308
7309 return TrueRange.unionWith(FalseRange);
7310}
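// Editor's note (worked example, not part of the LLVM source): for
// Start = select(%c, 0, 10) and Step = select(%c, 1, -1) with MaxBECount = 5,
// the identity above factors the recurrence into {0,+,1} and {10,+,-1}, whose
// affine ranges are computed separately and then unioned, instead of giving
// up on the select-typed start and step.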
7311
7312SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7313 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7314 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7315
7316 // Return early if there are no flags to propagate to the SCEV.
7317 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7318 if (BinOp->hasNoUnsignedWrap())
7319 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7320 if (BinOp->hasNoSignedWrap())
7321 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7322 if (Flags == SCEV::FlagAnyWrap)
7323 return SCEV::FlagAnyWrap;
7324
7325 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7326}
7327
7328const Instruction *
7329ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7330 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7331 return &*AddRec->getLoop()->getHeader()->begin();
7332 if (auto *U = dyn_cast<SCEVUnknown>(S))
7333 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7334 return I;
7335 return nullptr;
7336}
7337
7338const Instruction *
7339ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7340 bool &Precise) {
7341 Precise = true;
7342 // Do a bounded search of the def relation of the requested SCEVs.
7343 SmallPtrSet<const SCEV *, 16> Visited;
7344 SmallVector<const SCEV *> Worklist;
7345 auto pushOp = [&](const SCEV *S) {
7346 if (!Visited.insert(S).second)
7347 return;
7348 // Threshold of 30 here is arbitrary.
7349 if (Visited.size() > 30) {
7350 Precise = false;
7351 return;
7352 }
7353 Worklist.push_back(S);
7354 };
7355
7356 for (const auto *S : Ops)
7357 pushOp(S);
7358
7359 const Instruction *Bound = nullptr;
7360 while (!Worklist.empty()) {
7361 auto *S = Worklist.pop_back_val();
7362 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7363 if (!Bound || DT.dominates(Bound, DefI))
7364 Bound = DefI;
7365 } else {
7366 for (const auto *Op : S->operands())
7367 pushOp(Op);
7368 }
7369 }
7370 return Bound ? Bound : &*F.getEntryBlock().begin();
7371}
7372
7373const Instruction *
7374ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7375 bool Discard;
7376 return getDefiningScopeBound(Ops, Discard);
7377}
7378
7379bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7380 const Instruction *B) {
7381 if (A->getParent() == B->getParent() &&
7382 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7383 B->getIterator()))
7384 return true;
7385
7386 auto *BLoop = LI.getLoopFor(B->getParent());
7387 if (BLoop && BLoop->getHeader() == B->getParent() &&
7388 BLoop->getLoopPreheader() == A->getParent() &&
7389 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7390 A->getParent()->end()) &&
7391 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7392 B->getIterator()))
7393 return true;
7394 return false;
7395}
7396
7397bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7398 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7399 visitAll(Op, PC);
7400 return PC.MaybePoison.empty();
7401}
7402
7403bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7404 return !SCEVExprContains(Op, [this](const SCEV *S) {
7405 const SCEV *Op1;
7406 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7407 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7408 // is a non-zero constant, we have to assume the UDiv may be UB.
7409 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7410 });
7411}
7412
7413bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7414 // Only proceed if we can prove that I does not yield poison.
7415 if (!programUndefinedIfPoison(I))
7416 return false;
7417
7418 // At this point we know that if I is executed, then it does not wrap
7419 // according to at least one of NSW or NUW. If I is not executed, then we do
7420 // not know if the calculation that I represents would wrap. Multiple
7421 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7422 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7423 // derived from other instructions that map to the same SCEV. We cannot make
7424 // that guarantee for cases where I is not executed. So we need to find an
7425 // upper bound on the defining scope for the SCEV, and prove that I is
7426 // executed every time we enter that scope. When the bounding scope is a
7427 // loop (the common case), this is equivalent to proving I executes on every
7428 // iteration of that loop.
7429 SmallVector<const SCEV *> SCEVOps;
7430 for (const Use &Op : I->operands()) {
7431 // I could be an extractvalue from a call to an overflow intrinsic.
7432 // TODO: We can do better here in some cases.
7433 if (isSCEVable(Op->getType()))
7434 SCEVOps.push_back(getSCEV(Op));
7435 }
7436 auto *DefI = getDefiningScopeBound(SCEVOps);
7437 return isGuaranteedToTransferExecutionTo(DefI, I);
7438}
7439
7440bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7441 // If we know that \c I can never be poison period, then that's enough.
7442 if (isSCEVExprNeverPoison(I))
7443 return true;
7444
7445 // If the loop only has one exit, then we know that, if the loop is entered,
7446 // any instruction dominating that exit will be executed. If any such
7447 // instruction would result in UB, the addrec cannot be poison.
7448 //
7449 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7450 // also handles uses outside the loop header (they just need to dominate the
7451 // single exit).
7452
7453 auto *ExitingBB = L->getExitingBlock();
7454 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7455 return false;
7456
7457 SmallPtrSet<const Value *, 16> KnownPoison;
7458 SmallVector<const Instruction *, 8> Worklist;
7459
7460 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7461 // things that are known to be poison under that assumption go on the
7462 // Worklist.
7463 KnownPoison.insert(I);
7464 Worklist.push_back(I);
7465
7466 while (!Worklist.empty()) {
7467 const Instruction *Poison = Worklist.pop_back_val();
7468
7469 for (const Use &U : Poison->uses()) {
7470 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7471 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7472 DT.dominates(PoisonUser->getParent(), ExitingBB))
7473 return true;
7474
7475 if (propagatesPoison(U) && L->contains(PoisonUser))
7476 if (KnownPoison.insert(PoisonUser).second)
7477 Worklist.push_back(PoisonUser);
7478 }
7479 }
7480
7481 return false;
7482}
7483
7484ScalarEvolution::LoopProperties
7485ScalarEvolution::getLoopProperties(const Loop *L) {
7486 using LoopProperties = ScalarEvolution::LoopProperties;
7487
7488 auto Itr = LoopPropertiesCache.find(L);
7489 if (Itr == LoopPropertiesCache.end()) {
7490 auto HasSideEffects = [](Instruction *I) {
7491 if (auto *SI = dyn_cast<StoreInst>(I))
7492 return !SI->isSimple();
7493
7494 if (I->mayThrow())
7495 return true;
7496
7497 // Non-volatile memset / memcpy do not count as side-effect for forward
7498 // progress.
7499 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7500 return false;
7501
7502 return I->mayWriteToMemory();
7503 };
7504
7505 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7506 /*HasNoSideEffects*/ true};
7507
7508 for (auto *BB : L->getBlocks())
7509 for (auto &I : *BB) {
7510 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7511 LP.HasNoAbnormalExits = false;
7512 if (HasSideEffects(&I))
7513 LP.HasNoSideEffects = false;
7514 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7515 break; // We're already as pessimistic as we can get.
7516 }
7517
7518 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7519 assert(InsertPair.second && "We just checked!");
7520 Itr = InsertPair.first;
7521 }
7522
7523 return Itr->second;
7524}
7525
7526 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7527 // A mustprogress loop without side effects must be finite.
7528 // TODO: The check used here is very conservative. It's only *specific*
7529 // side effects which are well defined in infinite loops.
7530 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7531}
7532
7533const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7534 // Worklist item with a Value and a bool indicating whether all operands have
7535 // been visited already.
7536 using PointerTy = PointerIntPair<Value *, 1, bool>;
7537 SmallVector<PointerTy> Stack;
7538
7539 Stack.emplace_back(V, true);
7540 Stack.emplace_back(V, false);
7541 while (!Stack.empty()) {
7542 auto E = Stack.pop_back_val();
7543 Value *CurV = E.getPointer();
7544
7545 if (getExistingSCEV(CurV))
7546 continue;
7547
7548 SmallVector<Value *> Ops;
7549 const SCEV *CreatedSCEV = nullptr;
7550 // If all operands have been visited already, create the SCEV.
7551 if (E.getInt()) {
7552 CreatedSCEV = createSCEV(CurV);
7553 } else {
7554 // Otherwise get the operands we need to create SCEVs for before creating
7555 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7556 // just use it.
7557 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7558 }
7559
7560 if (CreatedSCEV) {
7561 insertValueToMap(CurV, CreatedSCEV);
7562 } else {
7563 // Queue CurV for SCEV creation, followed by its operands which need to
7564 // be constructed first.
7565 Stack.emplace_back(CurV, true);
7566 for (Value *Op : Ops)
7567 Stack.emplace_back(Op, false);
7568 }
7569 }
7570
7571 return getExistingSCEV(V);
7572}
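// Editor's note (illustrative trace, not part of the LLVM source): for
// %a = add i32 %b, %c, createSCEVIter first pops (%a, false), queues
// (%a, true) plus (%b, false) and (%c, false), creates SCEVs for %b and %c,
// and only then revisits (%a, true) to build the add expression, so deeply
// nested expressions are handled without deep recursion.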
7573
7574const SCEV *
7575ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7576 if (!isSCEVable(V->getType()))
7577 return getUnknown(V);
7578
7579 if (Instruction *I = dyn_cast<Instruction>(V)) {
7580 // Don't attempt to analyze instructions in blocks that aren't
7581 // reachable. Such instructions don't matter, and they aren't required
7582 // to obey basic rules for definitions dominating uses which this
7583 // analysis depends on.
7584 if (!DT.isReachableFromEntry(I->getParent()))
7585 return getUnknown(PoisonValue::get(V->getType()));
7586 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7587 return getConstant(CI);
7588 else if (isa<GlobalAlias>(V))
7589 return getUnknown(V);
7590 else if (!isa<ConstantExpr>(V))
7591 return getUnknown(V);
7592
7593 Operator *U = cast<Operator>(V);
7594 if (auto BO =
7595 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7596 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7597 switch (BO->Opcode) {
7598 case Instruction::Add:
7599 case Instruction::Mul: {
7600 // For additions and multiplications, traverse add/mul chains for which we
7601 // can potentially create a single SCEV, to reduce the number of
7602 // get{Add,Mul}Expr calls.
7603 do {
7604 if (BO->Op) {
7605 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7606 Ops.push_back(BO->Op);
7607 break;
7608 }
7609 }
7610 Ops.push_back(BO->RHS);
7611 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7612 dyn_cast<Instruction>(V));
7613 if (!NewBO ||
7614 (BO->Opcode == Instruction::Add &&
7615 (NewBO->Opcode != Instruction::Add &&
7616 NewBO->Opcode != Instruction::Sub)) ||
7617 (BO->Opcode == Instruction::Mul &&
7618 NewBO->Opcode != Instruction::Mul)) {
7619 Ops.push_back(BO->LHS);
7620 break;
7621 }
7622 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7623 // requires a SCEV for the LHS.
7624 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7625 auto *I = dyn_cast<Instruction>(BO->Op);
7626 if (I && programUndefinedIfPoison(I)) {
7627 Ops.push_back(BO->LHS);
7628 break;
7629 }
7630 }
7631 BO = NewBO;
7632 } while (true);
7633 return nullptr;
7634 }
7635 case Instruction::Sub:
7636 case Instruction::UDiv:
7637 case Instruction::URem:
7638 break;
7639 case Instruction::AShr:
7640 case Instruction::Shl:
7641 case Instruction::Xor:
7642 if (!IsConstArg)
7643 return nullptr;
7644 break;
7645 case Instruction::And:
7646 case Instruction::Or:
7647 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7648 return nullptr;
7649 break;
7650 case Instruction::LShr:
7651 return getUnknown(V);
7652 default:
7653 llvm_unreachable("Unhandled binop");
7654 break;
7655 }
7656
7657 Ops.push_back(BO->LHS);
7658 Ops.push_back(BO->RHS);
7659 return nullptr;
7660 }
7661
7662 switch (U->getOpcode()) {
7663 case Instruction::Trunc:
7664 case Instruction::ZExt:
7665 case Instruction::SExt:
7666 case Instruction::PtrToInt:
7667 Ops.push_back(U->getOperand(0));
7668 return nullptr;
7669
7670 case Instruction::BitCast:
7671 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7672 Ops.push_back(U->getOperand(0));
7673 return nullptr;
7674 }
7675 return getUnknown(V);
7676
7677 case Instruction::SDiv:
7678 case Instruction::SRem:
7679 Ops.push_back(U->getOperand(0));
7680 Ops.push_back(U->getOperand(1));
7681 return nullptr;
7682
7683 case Instruction::GetElementPtr:
7684 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7685 "GEP source element type must be sized");
7686 llvm::append_range(Ops, U->operands());
7687 return nullptr;
7688
7689 case Instruction::IntToPtr:
7690 return getUnknown(V);
7691
7692 case Instruction::PHI:
7693 // Keep constructing SCEVs for phis recursively for now.
7694 return nullptr;
7695
7696 case Instruction::Select: {
7697 // Check if U is a select that can be simplified to a SCEVUnknown.
7698 auto CanSimplifyToUnknown = [this, U]() {
7699 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7700 return false;
7701
7702 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7703 if (!ICI)
7704 return false;
7705 Value *LHS = ICI->getOperand(0);
7706 Value *RHS = ICI->getOperand(1);
7707 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7708 ICI->getPredicate() == CmpInst::ICMP_NE) {
7709 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7710 return true;
7711 } else if (getTypeSizeInBits(LHS->getType()) >
7712 getTypeSizeInBits(U->getType()))
7713 return true;
7714 return false;
7715 };
7716 if (CanSimplifyToUnknown())
7717 return getUnknown(U);
7718
7719 llvm::append_range(Ops, U->operands());
7720 return nullptr;
7721 break;
7722 }
7723 case Instruction::Call:
7724 case Instruction::Invoke:
7725 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7726 Ops.push_back(RV);
7727 return nullptr;
7728 }
7729
7730 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7731 switch (II->getIntrinsicID()) {
7732 case Intrinsic::abs:
7733 Ops.push_back(II->getArgOperand(0));
7734 return nullptr;
7735 case Intrinsic::umax:
7736 case Intrinsic::umin:
7737 case Intrinsic::smax:
7738 case Intrinsic::smin:
7739 case Intrinsic::usub_sat:
7740 case Intrinsic::uadd_sat:
7741 Ops.push_back(II->getArgOperand(0));
7742 Ops.push_back(II->getArgOperand(1));
7743 return nullptr;
7744 case Intrinsic::start_loop_iterations:
7745 case Intrinsic::annotation:
7746 case Intrinsic::ptr_annotation:
7747 Ops.push_back(II->getArgOperand(0));
7748 return nullptr;
7749 default:
7750 break;
7751 }
7752 }
7753 break;
7754 }
7755
7756 return nullptr;
7757}
7758
7759const SCEV *ScalarEvolution::createSCEV(Value *V) {
7760 if (!isSCEVable(V->getType()))
7761 return getUnknown(V);
7762
7763 if (Instruction *I = dyn_cast<Instruction>(V)) {
7764 // Don't attempt to analyze instructions in blocks that aren't
7765 // reachable. Such instructions don't matter, and they aren't required
7766 // to obey basic rules for definitions dominating uses which this
7767 // analysis depends on.
7768 if (!DT.isReachableFromEntry(I->getParent()))
7769 return getUnknown(PoisonValue::get(V->getType()));
7770 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7771 return getConstant(CI);
7772 else if (isa<GlobalAlias>(V))
7773 return getUnknown(V);
7774 else if (!isa<ConstantExpr>(V))
7775 return getUnknown(V);
7776
7777 const SCEV *LHS;
7778 const SCEV *RHS;
7779
7780 Operator *U = cast<Operator>(V);
7781 if (auto BO =
7782 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7783 switch (BO->Opcode) {
7784 case Instruction::Add: {
7785 // The simple thing to do would be to just call getSCEV on both operands
7786 // and call getAddExpr with the result. However if we're looking at a
7787 // bunch of things all added together, this can be quite inefficient,
7788 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7789 // Instead, gather up all the operands and make a single getAddExpr call.
7790 // LLVM IR canonical form means we need only traverse the left operands.
7791 SmallVector<const SCEV *, 4> AddOps;
7792 do {
7793 if (BO->Op) {
7794 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7795 AddOps.push_back(OpSCEV);
7796 break;
7797 }
7798
7799 // If a NUW or NSW flag can be applied to the SCEV for this
7800 // addition, then compute the SCEV for this addition by itself
7801 // with a separate call to getAddExpr. We need to do that
7802 // instead of pushing the operands of the addition onto AddOps,
7803 // since the flags are only known to apply to this particular
7804 // addition - they may not apply to other additions that can be
7805 // formed with operands from AddOps.
7806 const SCEV *RHS = getSCEV(BO->RHS);
7807 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7808 if (Flags != SCEV::FlagAnyWrap) {
7809 const SCEV *LHS = getSCEV(BO->LHS);
7810 if (BO->Opcode == Instruction::Sub)
7811 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7812 else
7813 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7814 break;
7815 }
7816 }
7817
7818 if (BO->Opcode == Instruction::Sub)
7819 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7820 else
7821 AddOps.push_back(getSCEV(BO->RHS));
7822
7823 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7824 dyn_cast<Instruction>(V));
7825 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7826 NewBO->Opcode != Instruction::Sub)) {
7827 AddOps.push_back(getSCEV(BO->LHS));
7828 break;
7829 }
7830 BO = NewBO;
7831 } while (true);
7832
7833 return getAddExpr(AddOps);
7834 }
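// Editor's note (illustrative example, not part of the LLVM source): for IR
// like %t1 = add i32 %a, %b; %t2 = add i32 %t1, %c; %t3 = add i32 %t2, %d,
// the loop above collects {%a, %b, %c, %d} by walking the left operands and
// issues one getAddExpr call for %t3 instead of three nested ones.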
7835
7836 case Instruction::Mul: {
7837 SmallVector<const SCEV *, 4> MulOps;
7838 do {
7839 if (BO->Op) {
7840 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7841 MulOps.push_back(OpSCEV);
7842 break;
7843 }
7844
7845 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7846 if (Flags != SCEV::FlagAnyWrap) {
7847 LHS = getSCEV(BO->LHS);
7848 RHS = getSCEV(BO->RHS);
7849 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7850 break;
7851 }
7852 }
7853
7854 MulOps.push_back(getSCEV(BO->RHS));
7855 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7856 dyn_cast<Instruction>(V));
7857 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7858 MulOps.push_back(getSCEV(BO->LHS));
7859 break;
7860 }
7861 BO = NewBO;
7862 } while (true);
7863
7864 return getMulExpr(MulOps);
7865 }
7866 case Instruction::UDiv:
7867 LHS = getSCEV(BO->LHS);
7868 RHS = getSCEV(BO->RHS);
7869 return getUDivExpr(LHS, RHS);
7870 case Instruction::URem:
7871 LHS = getSCEV(BO->LHS);
7872 RHS = getSCEV(BO->RHS);
7873 return getURemExpr(LHS, RHS);
7874 case Instruction::Sub: {
7875 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7876 if (BO->Op)
7877 Flags = getNoWrapFlagsFromUB(BO->Op);
7878 LHS = getSCEV(BO->LHS);
7879 RHS = getSCEV(BO->RHS);
7880 return getMinusSCEV(LHS, RHS, Flags);
7881 }
7882 case Instruction::And:
7883 // For an expression like x&255 that merely masks off the high bits,
7884 // use zext(trunc(x)) as the SCEV expression.
7885 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7886 if (CI->isZero())
7887 return getSCEV(BO->RHS);
7888 if (CI->isMinusOne())
7889 return getSCEV(BO->LHS);
7890 const APInt &A = CI->getValue();
7891
7892 // Instcombine's ShrinkDemandedConstant may strip bits out of
7893 // constants, obscuring what would otherwise be a low-bits mask.
7894 // Use computeKnownBits to compute what ShrinkDemandedConstant
7895 // knew about to reconstruct a low-bits mask value.
7896 unsigned LZ = A.countl_zero();
7897 unsigned TZ = A.countr_zero();
7898 unsigned BitWidth = A.getBitWidth();
7899 KnownBits Known(BitWidth);
7900 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7901
7902 APInt EffectiveMask =
7903 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7904 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7905 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7906 const SCEV *LHS = getSCEV(BO->LHS);
7907 const SCEV *ShiftedLHS = nullptr;
7908 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7909 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7910 // For an expression like (x * 8) & 8, simplify the multiply.
7911 unsigned MulZeros = OpC->getAPInt().countr_zero();
7912 unsigned GCD = std::min(MulZeros, TZ);
7913 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7914 SmallVector<const SCEV *, 4> MulOps;
7915 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7916 append_range(MulOps, LHSMul->operands().drop_front());
7917 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7918 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7919 }
7920 }
7921 if (!ShiftedLHS)
7922 ShiftedLHS = getUDivExpr(LHS, MulCount);
7923 return getMulExpr(
7924 getZeroExtendExpr(
7925 getTruncateExpr(ShiftedLHS,
7926 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7927 BO->LHS->getType()),
7928 MulCount);
7929 }
7930 }
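// Editor's note (illustrative example, not part of the LLVM source): for
// %m = and i32 %x, 255 the mask is a low-bits mask (LZ = 24, TZ = 0), so the
// code above models %m as (zext i8 (trunc i32 %x to i8) to i32) rather than
// leaving it as an opaque SCEVUnknown.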
7931 // Binary `and` is a bit-wise `umin`.
7932 if (BO->LHS->getType()->isIntegerTy(1)) {
7933 LHS = getSCEV(BO->LHS);
7934 RHS = getSCEV(BO->RHS);
7935 return getUMinExpr(LHS, RHS);
7936 }
7937 break;
7938
7939 case Instruction::Or:
7940 // Binary `or` is a bit-wise `umax`.
7941 if (BO->LHS->getType()->isIntegerTy(1)) {
7942 LHS = getSCEV(BO->LHS);
7943 RHS = getSCEV(BO->RHS);
7944 return getUMaxExpr(LHS, RHS);
7945 }
7946 break;
7947
7948 case Instruction::Xor:
7949 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7950 // If the RHS of xor is -1, then this is a not operation.
7951 if (CI->isMinusOne())
7952 return getNotSCEV(getSCEV(BO->LHS));
7953
7954 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7955 // This is a variant of the check for xor with -1, and it handles
7956 // the case where instcombine has trimmed non-demanded bits out
7957 // of an xor with -1.
7958 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7959 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7960 if (LBO->getOpcode() == Instruction::And &&
7961 LCI->getValue() == CI->getValue())
7962 if (const SCEVZeroExtendExpr *Z =
7963 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7964 Type *UTy = BO->LHS->getType();
7965 const SCEV *Z0 = Z->getOperand();
7966 Type *Z0Ty = Z0->getType();
7967 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7968
7969 // If C is a low-bits mask, the zero extend is serving to
7970 // mask off the high bits. Complement the operand and
7971 // re-apply the zext.
7972 if (CI->getValue().isMask(Z0TySize))
7973 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7974
7975 // If C is a single bit, it may be in the sign-bit position
7976 // before the zero-extend. In this case, represent the xor
7977 // using an add, which is equivalent, and re-apply the zext.
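// For an i8 value %a, "%a ^ 0x80" equals "%a + 0x80" modulo 256: both flip
// only the sign bit, and any carry out of the add is discarded.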
7978 APInt Trunc = CI->getValue().trunc(Z0TySize);
7979 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7980 Trunc.isSignMask())
7981 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7982 UTy);
7983 }
7984 }
7985 break;
7986
7987 case Instruction::Shl:
7988 // Turn shift left of a constant amount into a multiply.
7989 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7990 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7991
7992 // If the shift count is not less than the bitwidth, the result of
7993 // the shift is undefined. Don't try to analyze it, because the
7994 // resolution chosen here may differ from the resolution chosen in
7995 // other parts of the compiler.
7996 if (SA->getValue().uge(BitWidth))
7997 break;
7998
7999 // We can safely preserve the nuw flag in all cases. It's also safe to
8000 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8001 // requires special handling. It can be preserved as long as we're not
8002 // left shifting by bitwidth - 1.
8003 auto Flags = SCEV::FlagAnyWrap;
8004 if (BO->Op) {
8005 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8006 if ((MulFlags & SCEV::FlagNSW) &&
8007 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8008 Flags = setFlags(Flags, SCEV::FlagNSW);
8009 if (MulFlags & SCEV::FlagNUW)
8010 Flags = setFlags(Flags, SCEV::FlagNUW);
8011 }
8012
8013 ConstantInt *X = ConstantInt::get(
8014 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8015 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8016 }
8017 break;
8018
8019 case Instruction::AShr:
8020 // AShr X, C, where C is a constant.
8021 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8022 if (!CI)
8023 break;
8024
8025 Type *OuterTy = BO->LHS->getType();
8026 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8027 // If the shift count is not less than the bitwidth, the result of
8028 // the shift is undefined. Don't try to analyze it, because the
8029 // resolution chosen here may differ from the resolution chosen in
8030 // other parts of the compiler.
8031 if (CI->getValue().uge(BitWidth))
8032 break;
8033
8034 if (CI->isZero())
8035 return getSCEV(BO->LHS); // shift by zero --> noop
8036
8037 uint64_t AShrAmt = CI->getZExtValue();
8038 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8039
8040 Operator *L = dyn_cast<Operator>(BO->LHS);
8041 const SCEV *AddTruncateExpr = nullptr;
8042 ConstantInt *ShlAmtCI = nullptr;
8043 const SCEV *AddConstant = nullptr;
8044
8045 if (L && L->getOpcode() == Instruction::Add) {
8046 // X = Shl A, n
8047 // Y = Add X, c
8048 // Z = AShr Y, m
8049 // n, c and m are constants.
8050
8051 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8052 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8053 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8054 if (AddOperandCI) {
8055 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8056 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8057 // Since we truncate to TruncTy, the AddConstant should be of the
8058 // same type, so create a new Constant with the same type as TruncTy.
8059 // Also, the Add constant should be shifted right by the AShr amount.
8060 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8061 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8062 // We model the expression as sext(add(trunc(A), c << n)). Since the
8063 // sext(trunc) part is already handled below, we create an
8064 // AddExpr(TruncExp) which will be used later.
8065 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8066 }
8067 }
8068 } else if (L && L->getOpcode() == Instruction::Shl) {
8069 // X = Shl A, n
8070 // Y = AShr X, m
8071 // Both n and m are constant.
8072
8073 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8074 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8075 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8076 }
8077
8078 if (AddTruncateExpr && ShlAmtCI) {
8079 // We can merge the two given cases into a single SCEV statement;
8080 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8081 // a simpler case. The following code handles the two cases:
8082 //
8083 // 1) For a two-shift sext-inreg, i.e. n = m,
8084 // use sext(trunc(x)) as the SCEV expression.
8085 //
8086 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8087 // expression. We already checked that ShlAmt < BitWidth, so
8088 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8089 // ShlAmt - AShrAmt < BitWidth - AShrAmt, the width of TruncTy.
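// For example, with n = m = 24 on an i32 value, "ashr (shl A, 24), 24" is
// the usual sign-extend-in-register idiom and becomes
// sext(trunc(A) to i8) to i32.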
8090 const APInt &ShlAmt = ShlAmtCI->getValue();
8091 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8092 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8093 ShlAmtCI->getZExtValue() - AShrAmt);
8094 const SCEV *CompositeExpr =
8095 getMulExpr(AddTruncateExpr, getConstant(Mul));
8096 if (L->getOpcode() != Instruction::Shl)
8097 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8098
8099 return getSignExtendExpr(CompositeExpr, OuterTy);
8100 }
8101 }
8102 break;
8103 }
8104 }
8105
8106 switch (U->getOpcode()) {
8107 case Instruction::Trunc:
8108 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8109
8110 case Instruction::ZExt:
8111 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8112
8113 case Instruction::SExt:
8114 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8115 dyn_cast<Instruction>(V))) {
8116 // The NSW flag of a subtract does not always survive the conversion to
8117 // A + (-1)*B. By pushing sign extension onto its operands we are much
8118 // more likely to preserve NSW and allow later AddRec optimisations.
8119 //
8120 // NOTE: This is effectively duplicating this logic from getSignExtend:
8121 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8122 // but by that point the NSW information has potentially been lost.
8123 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8124 Type *Ty = U->getType();
8125 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8126 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8127 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8128 }
8129 }
8130 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8131
8132 case Instruction::BitCast:
8133 // BitCasts are no-op casts so we just eliminate the cast.
8134 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8135 return getSCEV(U->getOperand(0));
8136 break;
8137
8138 case Instruction::PtrToInt: {
8139 // Pointer to integer cast is straightforward, so do model it.
8140 const SCEV *Op = getSCEV(U->getOperand(0));
8141 Type *DstIntTy = U->getType();
8142 // But only if effective SCEV (integer) type is wide enough to represent
8143 // all possible pointer values.
8144 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8145 if (isa<SCEVCouldNotCompute>(IntOp))
8146 return getUnknown(V);
8147 return IntOp;
8148 }
8149 case Instruction::IntToPtr:
8150 // Just don't deal with inttoptr casts.
8151 return getUnknown(V);
8152
8153 case Instruction::SDiv:
8154 // If both operands are non-negative, this is just a udiv.
8155 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8156 isKnownNonNegative(getSCEV(U->getOperand(1))))
8157 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8158 break;
8159
8160 case Instruction::SRem:
8161 // If both operands are non-negative, this is just a urem.
8162 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8163 isKnownNonNegative(getSCEV(U->getOperand(1))))
8164 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8165 break;
8166
8167 case Instruction::GetElementPtr:
8168 return createNodeForGEP(cast<GEPOperator>(U));
8169
8170 case Instruction::PHI:
8171 return createNodeForPHI(cast<PHINode>(U));
8172
8173 case Instruction::Select:
8174 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8175 U->getOperand(2));
8176
8177 case Instruction::Call:
8178 case Instruction::Invoke:
8179 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8180 return getSCEV(RV);
8181
8182 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8183 switch (II->getIntrinsicID()) {
8184 case Intrinsic::abs:
8185 return getAbsExpr(
8186 getSCEV(II->getArgOperand(0)),
8187 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8188 case Intrinsic::umax:
8189 LHS = getSCEV(II->getArgOperand(0));
8190 RHS = getSCEV(II->getArgOperand(1));
8191 return getUMaxExpr(LHS, RHS);
8192 case Intrinsic::umin:
8193 LHS = getSCEV(II->getArgOperand(0));
8194 RHS = getSCEV(II->getArgOperand(1));
8195 return getUMinExpr(LHS, RHS);
8196 case Intrinsic::smax:
8197 LHS = getSCEV(II->getArgOperand(0));
8198 RHS = getSCEV(II->getArgOperand(1));
8199 return getSMaxExpr(LHS, RHS);
8200 case Intrinsic::smin:
8201 LHS = getSCEV(II->getArgOperand(0));
8202 RHS = getSCEV(II->getArgOperand(1));
8203 return getSMinExpr(LHS, RHS);
8204 case Intrinsic::usub_sat: {
8205 const SCEV *X = getSCEV(II->getArgOperand(0));
8206 const SCEV *Y = getSCEV(II->getArgOperand(1));
8207 const SCEV *ClampedY = getUMinExpr(X, Y);
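// usub_sat(X, Y) == X - umin(X, Y): if Y <= X this is the plain difference,
// otherwise it clamps to 0. Since umin(X, Y) <= X, the subtraction below
// cannot wrap, hence FlagNUW.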
8208 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8209 }
8210 case Intrinsic::uadd_sat: {
8211 const SCEV *X = getSCEV(II->getArgOperand(0));
8212 const SCEV *Y = getSCEV(II->getArgOperand(1));
8213 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
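// uadd_sat(X, Y) == umin(X, ~Y) + Y, where ~Y is UINT_MAX - Y: the clamp
// keeps the sum from exceeding UINT_MAX, so the add below is NUW.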
8214 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8215 }
8216 case Intrinsic::start_loop_iterations:
8217 case Intrinsic::annotation:
8218 case Intrinsic::ptr_annotation:
8219 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8220 // just equivalent to the first operand for SCEV purposes.
8221 return getSCEV(II->getArgOperand(0));
8222 case Intrinsic::vscale:
8223 return getVScale(II->getType());
8224 default:
8225 break;
8226 }
8227 }
8228 break;
8229 }
8230
8231 return getUnknown(V);
8232}
8233
8234//===----------------------------------------------------------------------===//
8235// Iteration Count Computation Code
8236//
8237
8238const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8239 if (isa<SCEVCouldNotCompute>(ExitCount))
8240 return getCouldNotCompute();
8241
8242 auto *ExitCountType = ExitCount->getType();
8243 assert(ExitCountType->isIntegerTy());
8244 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8245 1 + ExitCountType->getScalarSizeInBits());
8246 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8247}
8248
8249const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8250 Type *EvalTy,
8251 const Loop *L) {
8252 if (isa<SCEVCouldNotCompute>(ExitCount))
8253 return getCouldNotCompute();
8254
8255 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8256 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8257
8258 auto CanAddOneWithoutOverflow = [&]() {
8259 ConstantRange ExitCountRange =
8260 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8261 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8262 return true;
8263
8264 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8265 getMinusOne(ExitCount->getType()));
8266 };
8267
8268 // If we need to zero extend the backedge count, check if we can add one to
8269 // it prior to zero extending without overflow. Provided this is safe, it
8270 // allows better simplification of the +1.
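// For example, an i32 exit count N known to differ from UINT32_MAX becomes
// zext(N + 1) to i64 here, which folds better than zext(N) + 1.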
8271 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8272 return getZeroExtendExpr(
8273 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8274
8275 // Get the total trip count from the count by adding 1. This may wrap.
8276 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8277}
8278
8279static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8280 if (!ExitCount)
8281 return 0;
8282
8283 ConstantInt *ExitConst = ExitCount->getValue();
8284
8285 // Guard against huge trip counts.
8286 if (ExitConst->getValue().getActiveBits() > 32)
8287 return 0;
8288
8289 // In case of integer overflow, this returns 0, which is correct.
8290 return ((unsigned)ExitConst->getZExtValue()) + 1;
8291}
8292
8293unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8294 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8295 return getConstantTripCount(ExitCount);
8296}
8297
8298unsigned
8299ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8300 const BasicBlock *ExitingBlock) {
8301 assert(ExitingBlock && "Must pass a non-null exiting block!");
8302 assert(L->isLoopExiting(ExitingBlock) &&
8303 "Exiting block must actually branch out of the loop!");
8304 const SCEVConstant *ExitCount =
8305 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8306 return getConstantTripCount(ExitCount);
8307}
8308
8309unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8310 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8311
8312 const auto *MaxExitCount =
8313 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8314 : getConstantMaxBackedgeTakenCount(L);
8315 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8316}
8317
8318unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8319 SmallVector<BasicBlock *, 8> ExitingBlocks;
8320 L->getExitingBlocks(ExitingBlocks);
8321
8322 std::optional<unsigned> Res;
8323 for (auto *ExitingBB : ExitingBlocks) {
8324 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8325 if (!Res)
8326 Res = Multiple;
8327 Res = std::gcd(*Res, Multiple);
8328 }
8329 return Res.value_or(1);
8330}
8331
8332unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8333 const SCEV *ExitCount) {
8334 if (isa<SCEVCouldNotCompute>(ExitCount))
8335 return 1;
8336
8337 // Get the trip count
8338 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8339
8340 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8341 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8342 // the greatest power of 2 divisor less than 2^32.
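// For example, a trip count known to be a multiple of 3 * 2^40 has more than
// 32 active bits, so we report 2^31, the largest power of two we can return.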
8343 return Multiple.getActiveBits() > 32
8344 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8345 : (unsigned)Multiple.getZExtValue();
8346}
8347
8348/// Returns the largest constant divisor of the trip count of this loop as a
8349/// normal unsigned value, if possible. This means that the actual trip count is
8350/// always a multiple of the returned value (don't forget the trip count could
8351/// very well be zero as well!).
8352///
8353/// Returns 1 if the trip count is unknown or not guaranteed to be a
8354/// multiple of a constant (which is also the case if the trip count is simply
8355/// a constant; use getSmallConstantTripCount for that case). It will also
8356/// return 1 if the trip count is very large (>= 2^32).
8357///
8358/// As explained in the comments for getSmallConstantTripCount, this assumes
8359/// that control exits the loop via ExitingBlock.
8360unsigned
8361ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8362 const BasicBlock *ExitingBlock) {
8363 assert(ExitingBlock && "Must pass a non-null exiting block!");
8364 assert(L->isLoopExiting(ExitingBlock) &&
8365 "Exiting block must actually branch out of the loop!");
8366 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8367 return getSmallConstantTripMultiple(L, ExitCount);
8368}
8369
8370const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8371 const BasicBlock *ExitingBlock,
8372 ExitCountKind Kind) {
8373 switch (Kind) {
8374 case Exact:
8375 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8376 case SymbolicMaximum:
8377 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8378 case ConstantMaximum:
8379 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8380 };
8381 llvm_unreachable("Invalid ExitCountKind!");
8382}
8383
8384const SCEV *ScalarEvolution::getPredicatedExitCount(
8385 const Loop *L, const BasicBlock *ExitingBlock,
8386 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8387 switch (Kind) {
8388 case Exact:
8389 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8390 Predicates);
8391 case SymbolicMaximum:
8392 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8393 Predicates);
8394 case ConstantMaximum:
8395 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8396 Predicates);
8397 };
8398 llvm_unreachable("Invalid ExitCountKind!");
8399}
8400
8401const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8402 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8403 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8404}
8405
8406const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8407 ExitCountKind Kind) {
8408 switch (Kind) {
8409 case Exact:
8410 return getBackedgeTakenInfo(L).getExact(L, this);
8411 case ConstantMaximum:
8412 return getBackedgeTakenInfo(L).getConstantMax(this);
8413 case SymbolicMaximum:
8414 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8415 };
8416 llvm_unreachable("Invalid ExitCountKind!");
8417}
8418
8419const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8420 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8421 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8422}
8423
8424const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8425 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8426 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8427}
8428
8429bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8430 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8431}
8432
8433/// Push PHI nodes in the header of the given loop onto the given Worklist.
8434static void PushLoopPHIs(const Loop *L,
8435 SmallVectorImpl<Instruction *> &Worklist,
8436 SmallPtrSetImpl<Instruction *> &Visited) {
8437 BasicBlock *Header = L->getHeader();
8438
8439 // Push all Loop-header PHIs onto the Worklist stack.
8440 for (PHINode &PN : Header->phis())
8441 if (Visited.insert(&PN).second)
8442 Worklist.push_back(&PN);
8443}
8444
8445ScalarEvolution::BackedgeTakenInfo &
8446ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8447 auto &BTI = getBackedgeTakenInfo(L);
8448 if (BTI.hasFullInfo())
8449 return BTI;
8450
8451 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8452
8453 if (!Pair.second)
8454 return Pair.first->second;
8455
8456 BackedgeTakenInfo Result =
8457 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8458
8459 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8460}
8461
8462ScalarEvolution::BackedgeTakenInfo &
8463ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8464 // Initially insert an invalid entry for this loop. If the insertion
8465 // succeeds, proceed to actually compute a backedge-taken count and
8466 // update the value. The temporary CouldNotCompute value tells SCEV
8467 // code elsewhere that it shouldn't attempt to request a new
8468 // backedge-taken count, which could result in infinite recursion.
8469 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8470 BackedgeTakenCounts.try_emplace(L);
8471 if (!Pair.second)
8472 return Pair.first->second;
8473
8474 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8475 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8476 // must be cleared in this scope.
8477 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8478
8479 // Now that we know more about the trip count for this loop, forget any
8480 // existing SCEV values for PHI nodes in this loop since they are only
8481 // conservative estimates made without the benefit of trip count
8482 // information. This invalidation is not necessary for correctness, and is
8483 // only done to produce more precise results.
8484 if (Result.hasAnyInfo()) {
8485 // Invalidate any expression using an addrec in this loop.
8486 SmallVector<const SCEV *, 8> ToForget;
8487 auto LoopUsersIt = LoopUsers.find(L);
8488 if (LoopUsersIt != LoopUsers.end())
8489 append_range(ToForget, LoopUsersIt->second);
8490 forgetMemoizedResults(ToForget);
8491
8492 // Invalidate constant-evolved loop header phis.
8493 for (PHINode &PN : L->getHeader()->phis())
8494 ConstantEvolutionLoopExitValue.erase(&PN);
8495 }
8496
8497 // Re-lookup the insert position, since the call to
8498 // computeBackedgeTakenCount above could result in a
8499 // recursive call to getBackedgeTakenInfo (on a different
8500 // loop), which would invalidate the iterator computed
8501 // earlier.
8502 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8503}
8504
8505void ScalarEvolution::forgetAllLoops() {
8506 // This method is intended to forget all info about loops. It should
8507 // invalidate caches as if the following happened:
8508 // - The trip counts of all loops have changed arbitrarily
8509 // - Every llvm::Value has been updated in place to produce a different
8510 // result.
8511 BackedgeTakenCounts.clear();
8512 PredicatedBackedgeTakenCounts.clear();
8513 BECountUsers.clear();
8514 LoopPropertiesCache.clear();
8515 ConstantEvolutionLoopExitValue.clear();
8516 ValueExprMap.clear();
8517 ValuesAtScopes.clear();
8518 ValuesAtScopesUsers.clear();
8519 LoopDispositions.clear();
8520 BlockDispositions.clear();
8521 UnsignedRanges.clear();
8522 SignedRanges.clear();
8523 ExprValueMap.clear();
8524 HasRecMap.clear();
8525 ConstantMultipleCache.clear();
8526 PredicatedSCEVRewrites.clear();
8527 FoldCache.clear();
8528 FoldCacheUser.clear();
8529}
8530void ScalarEvolution::visitAndClearUsers(
8531 SmallVectorImpl<Instruction *> &Worklist,
8532 SmallPtrSetImpl<Instruction *> &Visited,
8533 SmallVectorImpl<const SCEV *> &ToForget) {
8534 while (!Worklist.empty()) {
8535 Instruction *I = Worklist.pop_back_val();
8536 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8537 continue;
8538
8539 ValueExprMapType::iterator It =
8540 ValueExprMap.find_as(static_cast<Value *>(I));
8541 if (It != ValueExprMap.end()) {
8542 eraseValueFromMap(It->first);
8543 ToForget.push_back(It->second);
8544 if (PHINode *PN = dyn_cast<PHINode>(I))
8545 ConstantEvolutionLoopExitValue.erase(PN);
8546 }
8547
8548 PushDefUseChildren(I, Worklist, Visited);
8549 }
8550}
8551
8552void ScalarEvolution::forgetLoop(const Loop *L) {
8553 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8554 SmallVector<Instruction *, 32> Worklist;
8555 SmallPtrSet<Instruction *, 16> Visited;
8556 SmallVector<const SCEV *, 16> ToForget;
8557
8558 // Iterate over all the loops and sub-loops to drop SCEV information.
8559 while (!LoopWorklist.empty()) {
8560 auto *CurrL = LoopWorklist.pop_back_val();
8561
8562 // Drop any stored trip count value.
8563 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8564 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8565
8566 // Drop information about predicated SCEV rewrites for this loop.
8567 for (auto I = PredicatedSCEVRewrites.begin();
8568 I != PredicatedSCEVRewrites.end();) {
8569 std::pair<const SCEV *, const Loop *> Entry = I->first;
8570 if (Entry.second == CurrL)
8571 PredicatedSCEVRewrites.erase(I++);
8572 else
8573 ++I;
8574 }
8575
8576 auto LoopUsersItr = LoopUsers.find(CurrL);
8577 if (LoopUsersItr != LoopUsers.end())
8578 llvm::append_range(ToForget, LoopUsersItr->second);
8579
8580 // Drop information about expressions based on loop-header PHIs.
8581 PushLoopPHIs(CurrL, Worklist, Visited);
8582 visitAndClearUsers(Worklist, Visited, ToForget);
8583
8584 LoopPropertiesCache.erase(CurrL);
8585 // Forget all contained loops too, to avoid dangling entries in the
8586 // ValuesAtScopes map.
8587 LoopWorklist.append(CurrL->begin(), CurrL->end());
8588 }
8589 forgetMemoizedResults(ToForget);
8590}
8591
8592void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8593 forgetLoop(L->getOutermostLoop());
8594}
8595
8596void ScalarEvolution::forgetValue(Value *V) {
8597 Instruction *I = dyn_cast<Instruction>(V);
8598 if (!I) return;
8599
8600 // Drop information about expressions based on loop-header PHIs.
8601 SmallVector<Instruction *, 16> Worklist;
8602 SmallPtrSet<Instruction *, 8> Visited;
8603 SmallVector<const SCEV *, 8> ToForget;
8604 Worklist.push_back(I);
8605 Visited.insert(I);
8606 visitAndClearUsers(Worklist, Visited, ToForget);
8607
8608 forgetMemoizedResults(ToForget);
8609}
8610
8611void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8612 if (!isSCEVable(V->getType()))
8613 return;
8614
8615 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8616 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8617 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8618 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8619 if (const SCEV *S = getExistingSCEV(V)) {
8620 struct InvalidationRootCollector {
8621 Loop *L;
8622 SmallVector<const SCEV *, 8> Roots;
8623
8624 InvalidationRootCollector(Loop *L) : L(L) {}
8625
8626 bool follow(const SCEV *S) {
8627 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8628 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8629 if (L->contains(I))
8630 Roots.push_back(S);
8631 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8632 if (L->contains(AddRec->getLoop()))
8633 Roots.push_back(S);
8634 }
8635 return true;
8636 }
8637 bool isDone() const { return false; }
8638 };
8639
8640 InvalidationRootCollector C(L);
8641 visitAll(S, C);
8642 forgetMemoizedResults(C.Roots);
8643 }
8644
8645 // Also perform the normal invalidation.
8646 forgetValue(V);
8647}
8648
8649void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8650
8651void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8652 // Unless a specific value is passed to invalidation, completely clear both
8653 // caches.
8654 if (!V) {
8655 BlockDispositions.clear();
8656 LoopDispositions.clear();
8657 return;
8658 }
8659
8660 if (!isSCEVable(V->getType()))
8661 return;
8662
8663 const SCEV *S = getExistingSCEV(V);
8664 if (!S)
8665 return;
8666
8667 // Invalidate the block and loop dispositions cached for S. Dispositions of
8668 // S's users may change if S's disposition changes (i.e. a user may change to
8669 // loop-invariant, if S changes to loop invariant), so also invalidate
8670 // dispositions of S's users recursively.
8671 SmallVector<const SCEV *, 8> Worklist = {S};
8672 SmallPtrSet<const SCEV *, 8> Seen = {S};
8673 while (!Worklist.empty()) {
8674 const SCEV *Curr = Worklist.pop_back_val();
8675 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8676 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8677 if (!LoopDispoRemoved && !BlockDispoRemoved)
8678 continue;
8679 auto Users = SCEVUsers.find(Curr);
8680 if (Users != SCEVUsers.end())
8681 for (const auto *User : Users->second)
8682 if (Seen.insert(User).second)
8683 Worklist.push_back(User);
8684 }
8685}
8686
8687/// Get the exact loop backedge taken count considering all loop exits. A
8688/// computable result can only be returned for loops with all exiting blocks
8689/// dominating the latch. howFarToZero assumes that the limit of each loop test
8690/// is never skipped. This is a valid assumption as long as the loop exits via
8691/// that test. For precise results, it is the caller's responsibility to specify
8692/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8693const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8694 const Loop *L, ScalarEvolution *SE,
8695 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8696 // If any exits were not computable, the loop is not computable.
8697 if (!isComplete() || ExitNotTaken.empty())
8698 return SE->getCouldNotCompute();
8699
8700 const BasicBlock *Latch = L->getLoopLatch();
8701 // All exiting blocks we have collected must dominate the only backedge.
8702 if (!Latch)
8703 return SE->getCouldNotCompute();
8704
8705 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8706 // count is simply a minimum out of all these calculated exit counts.
8707 SmallVector<const SCEV *, 2> Ops;
8708 for (const auto &ENT : ExitNotTaken) {
8709 const SCEV *BECount = ENT.ExactNotTaken;
8710 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8711 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8712 "We should only have known counts for exiting blocks that dominate "
8713 "latch!");
8714
8715 Ops.push_back(BECount);
8716
8717 if (Preds)
8718 append_range(*Preds, ENT.Predicates);
8719
8720 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8721 "Predicate should be always true!");
8722 }
8723
8724 // If an earlier exit exits on the first iteration (exit count zero), then
8725 // a later poison exit count should not propagate into the result. These are
8726 // exactly the semantics provided by umin_seq.
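// For example, umin_seq(0, %n) is 0 even when %n is poison, whereas a plain
// umin of the two operands would itself be poison.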
8727 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8728}
8729
8730const ScalarEvolution::ExitNotTakenInfo *
8731ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8732 const BasicBlock *ExitingBlock,
8733 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8734 for (const auto &ENT : ExitNotTaken)
8735 if (ENT.ExitingBlock == ExitingBlock) {
8736 if (ENT.hasAlwaysTruePredicate())
8737 return &ENT;
8738 else if (Predicates) {
8739 append_range(*Predicates, ENT.Predicates);
8740 return &ENT;
8741 }
8742 }
8743
8744 return nullptr;
8745}
8746
8747/// getConstantMax - Get the constant max backedge taken count for the loop.
8748const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8749 ScalarEvolution *SE,
8750 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8751 if (!getConstantMax())
8752 return SE->getCouldNotCompute();
8753
8754 for (const auto &ENT : ExitNotTaken)
8755 if (!ENT.hasAlwaysTruePredicate()) {
8756 if (!Predicates)
8757 return SE->getCouldNotCompute();
8758 append_range(*Predicates, ENT.Predicates);
8759 }
8760
8761 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8762 isa<SCEVConstant>(getConstantMax())) &&
8763 "No point in having a non-constant max backedge taken count!");
8764 return getConstantMax();
8765}
8766
8767const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8768 const Loop *L, ScalarEvolution *SE,
8769 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8770 if (!SymbolicMax) {
8771 // Form an expression for the maximum exit count possible for this loop. We
8772 // merge the max and exact information to approximate a version of
8773 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8774 // constants.
8775 SmallVector<const SCEV *, 4> ExitCounts;
8776
8777 for (const auto &ENT : ExitNotTaken) {
8778 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8779 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8780 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8781 "We should only have known counts for exiting blocks that "
8782 "dominate latch!");
8783 ExitCounts.push_back(ExitCount);
8784 if (Predicates)
8785 append_range(*Predicates, ENT.Predicates);
8786
8787 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8788 "Predicate should be always true!");
8789 }
8790 }
8791 if (ExitCounts.empty())
8792 SymbolicMax = SE->getCouldNotCompute();
8793 else
8794 SymbolicMax =
8795 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8796 }
8797 return SymbolicMax;
8798}
8799
8800bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8801 ScalarEvolution *SE) const {
8802 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8803 return !ENT.hasAlwaysTruePredicate();
8804 };
8805 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8806}
8807
8810
8811ScalarEvolution::ExitLimit::ExitLimit(
8812 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8813 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8814 ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
8815 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8816 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8817 // If we prove the max count is zero, so is the symbolic bound. This happens
8818 // in practice due to differences in a) how context sensitive we've chosen
8819 // to be and b) how we reason about bounds implied by UB.
8820 if (ConstantMaxNotTaken->isZero()) {
8821 this->ExactNotTaken = E = ConstantMaxNotTaken;
8822 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8823 }
8824
8827 "Exact is not allowed to be less precise than Constant Max");
8830 "Exact is not allowed to be less precise than Symbolic Max");
8833 "Symbolic Max is not allowed to be less precise than Constant Max");
8836 "No point in having a non-constant max backedge taken count!");
8837 SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8838 for (const auto PredList : PredLists)
8839 for (const auto *P : PredList) {
8840 if (SeenPreds.contains(P))
8841 continue;
8842 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8843 SeenPreds.insert(P);
8844 Predicates.push_back(P);
8845 }
8846 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8847 "Backedge count should be int");
8848 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8849 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8850 "Max backedge count should be int");
8851}
8852
8860
8861/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8862/// computable exit into a persistent ExitNotTakenInfo array.
8863ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8865 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8866 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8867 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8868
8869 ExitNotTaken.reserve(ExitCounts.size());
8870 std::transform(ExitCounts.begin(), ExitCounts.end(),
8871 std::back_inserter(ExitNotTaken),
8872 [&](const EdgeExitInfo &EEI) {
8873 BasicBlock *ExitBB = EEI.first;
8874 const ExitLimit &EL = EEI.second;
8875 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8876 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8877 EL.Predicates);
8878 });
8879 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8880 isa<SCEVConstant>(ConstantMax)) &&
8881 "No point in having a non-constant max backedge taken count!");
8882}
8883
8884/// Compute the number of times the backedge of the specified loop will execute.
8885ScalarEvolution::BackedgeTakenInfo
8886ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8887 bool AllowPredicates) {
8888 SmallVector<BasicBlock *, 8> ExitingBlocks;
8889 L->getExitingBlocks(ExitingBlocks);
8890
8891 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8892
8893 SmallVector<EdgeExitInfo, 4> ExitCounts;
8894 bool CouldComputeBECount = true;
8895 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8896 const SCEV *MustExitMaxBECount = nullptr;
8897 const SCEV *MayExitMaxBECount = nullptr;
8898 bool MustExitMaxOrZero = false;
8899 bool IsOnlyExit = ExitingBlocks.size() == 1;
8900
8901 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8902 // and compute maxBECount.
8903 // Do a union of all the predicates here.
8904 for (BasicBlock *ExitBB : ExitingBlocks) {
8905 // We canonicalize untaken exits to br (constant); ignore them so that
8906 // proving an exit untaken doesn't negatively impact our ability to reason
8907 // about the loop as a whole.
8908 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8909 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8910 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8911 if (ExitIfTrue == CI->isZero())
8912 continue;
8913 }
8914
8915 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8916
8917 assert((AllowPredicates || EL.Predicates.empty()) &&
8918 "Predicated exit limit when predicates are not allowed!");
8919
8920 // 1. For each exit that can be computed, add an entry to ExitCounts.
8921 // CouldComputeBECount is true only if all exits can be computed.
8922 if (EL.ExactNotTaken != getCouldNotCompute())
8923 ++NumExitCountsComputed;
8924 else
8925 // We couldn't compute an exact value for this exit, so
8926 // we won't be able to compute an exact value for the loop.
8927 CouldComputeBECount = false;
8928 // Remember exit count if either exact or symbolic is known. Because
8929 // Exact always implies symbolic, only check symbolic.
8930 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8931 ExitCounts.emplace_back(ExitBB, EL);
8932 else {
8933 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8934 "Exact is known but symbolic isn't?");
8935 ++NumExitCountsNotComputed;
8936 }
8937
8938 // 2. Derive the loop's MaxBECount from each exit's max number of
8939 // non-exiting iterations. Partition the loop exits into two kinds:
8940 // LoopMustExits and LoopMayExits.
8941 //
8942 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8943 // is a LoopMayExit. If any computable LoopMustExit is found, then
8944 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8945 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8946 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8947 // any
8948 // computable EL.ConstantMaxNotTaken.
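// For example, if one exit dominates the latch with a constant max of 10 and
// another exit that does not dominate the latch has a constant max of 100,
// the loop-level constant max is 10.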
8949 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8950 DT.dominates(ExitBB, Latch)) {
8951 if (!MustExitMaxBECount) {
8952 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8953 MustExitMaxOrZero = EL.MaxOrZero;
8954 } else {
8955 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8956 EL.ConstantMaxNotTaken);
8957 }
8958 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8959 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8960 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8961 else {
8962 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8963 EL.ConstantMaxNotTaken);
8964 }
8965 }
8966 }
8967 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8968 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8969 // The loop backedge will be taken the maximum or zero times if there's
8970 // a single exit that must be taken the maximum or zero times.
8971 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8972
8973 // Remember which SCEVs are used in exit limits for invalidation purposes.
8974 // We only care about non-constant SCEVs here, so we can ignore
8975 // EL.ConstantMaxNotTaken
8976 // and MaxBECount, which must be SCEVConstant.
8977 for (const auto &Pair : ExitCounts) {
8978 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8979 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8980 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8981 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8982 {L, AllowPredicates});
8983 }
8984 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8985 MaxBECount, MaxOrZero);
8986}
8987
8988ScalarEvolution::ExitLimit
8989ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8990 bool IsOnlyExit, bool AllowPredicates) {
8991 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8992 // If our exiting block does not dominate the latch, then its connection with
8993 // loop's exit limit may be far from trivial.
8994 const BasicBlock *Latch = L->getLoopLatch();
8995 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8996 return getCouldNotCompute();
8997
8998 Instruction *Term = ExitingBlock->getTerminator();
8999 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
9000 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9001 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9002 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9003 "It should have one successor in loop and one exit block!");
9004 // Proceed to the next level to examine the exit condition expression.
9005 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9006 /*ControlsOnlyExit=*/IsOnlyExit,
9007 AllowPredicates);
9008 }
9009
9010 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9011 // For switch, make sure that there is a single exit from the loop.
9012 BasicBlock *Exit = nullptr;
9013 for (auto *SBB : successors(ExitingBlock))
9014 if (!L->contains(SBB)) {
9015 if (Exit) // Multiple exit successors.
9016 return getCouldNotCompute();
9017 Exit = SBB;
9018 }
9019 assert(Exit && "Exiting block must have at least one exit");
9020 return computeExitLimitFromSingleExitSwitch(
9021 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9022 }
9023
9024 return getCouldNotCompute();
9025}
9026
9027ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9028 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9029 bool AllowPredicates) {
9030 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9031 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9032 ControlsOnlyExit, AllowPredicates);
9033}
9034
9035std::optional<ScalarEvolution::ExitLimit>
9036ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9037 bool ExitIfTrue, bool ControlsOnlyExit,
9038 bool AllowPredicates) {
9039 (void)this->L;
9040 (void)this->ExitIfTrue;
9041 (void)this->AllowPredicates;
9042
9043 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9044 this->AllowPredicates == AllowPredicates &&
9045 "Variance in assumed invariant key components!");
9046 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9047 if (Itr == TripCountMap.end())
9048 return std::nullopt;
9049 return Itr->second;
9050}
9051
9052void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9053 bool ExitIfTrue,
9054 bool ControlsOnlyExit,
9055 bool AllowPredicates,
9056 const ExitLimit &EL) {
9057 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9058 this->AllowPredicates == AllowPredicates &&
9059 "Variance in assumed invariant key components!");
9060
9061 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9062 assert(InsertResult.second && "Expected successful insertion!");
9063 (void)InsertResult;
9064 (void)ExitIfTrue;
9065}
9066
9067ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9068 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9069 bool ControlsOnlyExit, bool AllowPredicates) {
9070
9071 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9072 AllowPredicates))
9073 return *MaybeEL;
9074
9075 ExitLimit EL = computeExitLimitFromCondImpl(
9076 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9077 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9078 return EL;
9079}
9080
9081ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9082 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9083 bool ControlsOnlyExit, bool AllowPredicates) {
9084 // Handle BinOp conditions (And, Or).
9085 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9086 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9087 return *LimitFromBinOp;
9088
9089 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9090 // Proceed to the next level to examine the icmp.
9091 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9092 ExitLimit EL =
9093 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9094 if (EL.hasFullInfo() || !AllowPredicates)
9095 return EL;
9096
9097 // Try again, but use SCEV predicates this time.
9098 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9099 ControlsOnlyExit,
9100 /*AllowPredicates=*/true);
9101 }
9102
9103 // Check for a constant condition. These are normally stripped out by
9104 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9105 // preserve the CFG and is temporarily leaving constant conditions
9106 // in place.
9107 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9108 if (ExitIfTrue == !CI->getZExtValue())
9109 // The backedge is always taken.
9110 return getCouldNotCompute();
9111 // The backedge is never taken.
9112 return getZero(CI->getType());
9113 }
9114
9115 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9116 // with a constant step, we can form an equivalent icmp predicate and figure
9117 // out how many iterations will be taken before we exit.
9118 const WithOverflowInst *WO;
9119 const APInt *C;
9120 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9121 match(WO->getRHS(), m_APInt(C))) {
9122 ConstantRange NWR =
9123 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9124 WO->getNoWrapKind());
9125 CmpInst::Predicate Pred;
9126 APInt NewRHSC, Offset;
9127 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9128 if (!ExitIfTrue)
9129 Pred = ICmpInst::getInversePredicate(Pred);
9130 auto *LHS = getSCEV(WO->getLHS());
9131 if (Offset != 0)
9132 LHS = getAddExpr(LHS, getConstant(Offset));
9133 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9134 ControlsOnlyExit, AllowPredicates);
9135 if (EL.hasAnyInfo())
9136 return EL;
9137 }
9138
9139 // If it's not an integer or pointer comparison then compute it the hard way.
9140 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9141}
9142
9143std::optional<ScalarEvolution::ExitLimit>
9144ScalarEvolution::computeExitLimitFromCondFromBinOp(
9145 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9146 bool ControlsOnlyExit, bool AllowPredicates) {
9147 // Check if the controlling expression for this loop is an And or Or.
9148 Value *Op0, *Op1;
9149 bool IsAnd = false;
9150 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9151 IsAnd = true;
9152 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9153 IsAnd = false;
9154 else
9155 return std::nullopt;
9156
9157 // EitherMayExit is true in these two cases:
9158 // br (and Op0 Op1), loop, exit
9159 // br (or Op0 Op1), exit, loop
9160 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9161 ExitLimit EL0 = computeExitLimitFromCondCached(
9162 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9163 AllowPredicates);
9164 ExitLimit EL1 = computeExitLimitFromCondCached(
9165 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9166 AllowPredicates);
9167
9168 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9169 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9170 if (isa<ConstantInt>(Op1))
9171 return Op1 == NeutralElement ? EL0 : EL1;
9172 if (isa<ConstantInt>(Op0))
9173 return Op0 == NeutralElement ? EL1 : EL0;
9174
9175 const SCEV *BECount = getCouldNotCompute();
9176 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9177 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9178 if (EitherMayExit) {
9179 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9180 // Both conditions must be the same for the loop to continue executing.
9181 // Choose the less conservative count.
9182 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9183 EL1.ExactNotTaken != getCouldNotCompute()) {
9184 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9185 UseSequentialUMin);
9186 }
9187 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9188 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9189 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9190 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9191 else
9192 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9193 EL1.ConstantMaxNotTaken);
9194 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9195 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9196 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9197 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9198 else
9199 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9200 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9201 } else {
9202 // Both conditions must be the same at the same time for the loop to exit.
9203 // For now, be conservative.
9204 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9205 BECount = EL0.ExactNotTaken;
9206 }
9207
9208 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9209 // to be more aggressive when computing BECount than when computing
9210 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9211 // and EL1.ExactNotTaken to match, but for
9212 // EL0.ConstantMaxNotTaken and EL1.ConstantMaxNotTaken
9213 // to not.
9214 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9215 !isa<SCEVCouldNotCompute>(BECount))
9216 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9217 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9218 SymbolicMaxBECount =
9219 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9220 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9221 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9222}
9223
9224ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9225 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9226 bool AllowPredicates) {
9227 // If the condition was exit on true, convert the condition to exit on false
9228 CmpPredicate Pred;
9229 if (!ExitIfTrue)
9230 Pred = ExitCond->getCmpPredicate();
9231 else
9232 Pred = ExitCond->getInverseCmpPredicate();
9233 const ICmpInst::Predicate OriginalPred = Pred;
9234
9235 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9236 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9237
9238 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9239 AllowPredicates);
9240 if (EL.hasAnyInfo())
9241 return EL;
9242
9243 auto *ExhaustiveCount =
9244 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9245
9246 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9247 return ExhaustiveCount;
9248
9249 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9250 ExitCond->getOperand(1), L, OriginalPred);
9251}
9252ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9253 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9254 bool ControlsOnlyExit, bool AllowPredicates) {
9255
9256 // Try to evaluate any dependencies out of the loop.
9257 LHS = getSCEVAtScope(LHS, L);
9258 RHS = getSCEVAtScope(RHS, L);
9259
9260 // At this point, we would like to compute how many iterations of the
9261 // loop the predicate will return true for these inputs.
9262 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9263 // If there is a loop-invariant, force it into the RHS.
9264 std::swap(LHS, RHS);
9265 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9266 }
9267
9268 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9269 loopIsFiniteByAssumption(L);
9270 // Simplify the operands before analyzing them.
9271 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9272
9273 // If we have a comparison of a chrec against a constant, try to use value
9274 // ranges to answer this query.
9275 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9276 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9277 if (AddRec->getLoop() == L) {
9278 // Form the constant range.
9279 ConstantRange CompRange =
9280 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9281
9282 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9283 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9284 }
9285
9286 // If this loop must exit based on this condition (or execute undefined
9287 // behaviour), see if we can improve wrap flags. This is essentially
9288 // a must execute style proof.
9289 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9290 // If we can prove the test sequence produced must repeat the same values
9291 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9292 // because if it did, we'd have an infinite (undefined) loop.
9293 // TODO: We can peel off any functions which are invertible *in L*. Loop
9294 // invariant terms are effectively constants for our purposes here.
9295 auto *InnerLHS = LHS;
9296 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9297 InnerLHS = ZExt->getOperand();
9298 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9299 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9300 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9301 /*OrNegative=*/true)) {
9302 auto Flags = AR->getNoWrapFlags();
9303 Flags = setFlags(Flags, SCEV::FlagNW);
9304 SmallVector<const SCEV *> Operands{AR->operands()};
9305 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9306 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9307 }
9308
9309 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9310 // From no-self-wrap, this follows trivially from the fact that every
9311 // (un)signed-wrapped, but not self-wrapped value must be less than the
9312 // last value before the (un)signed wrap. Since we know that last value
9313 // didn't exit, neither will any smaller one.
9314 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9315 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9316 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9317 AR && AR->getLoop() == L && AR->isAffine() &&
9318 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9319 isKnownPositive(AR->getStepRecurrence(*this))) {
9320 auto Flags = AR->getNoWrapFlags();
9321 Flags = setFlags(Flags, WrapType);
9322 SmallVector<const SCEV*> Operands{AR->operands()};
9323 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9324 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9325 }
9326 }
9327 }
9328
9329 switch (Pred) {
9330 case ICmpInst::ICMP_NE: { // while (X != Y)
9331 // Convert to: while (X-Y != 0)
9332 if (LHS->getType()->isPointerTy()) {
9333 LHS = getLosslessPtrToIntExpr(LHS);
9334 if (isa<SCEVCouldNotCompute>(LHS))
9335 return LHS;
9336 }
9337 if (RHS->getType()->isPointerTy()) {
9338 RHS = getLosslessPtrToIntExpr(RHS);
9339 if (isa<SCEVCouldNotCompute>(RHS))
9340 return RHS;
9341 }
9342 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9343 AllowPredicates);
9344 if (EL.hasAnyInfo())
9345 return EL;
9346 break;
9347 }
9348 case ICmpInst::ICMP_EQ: { // while (X == Y)
9349 // Convert to: while (X-Y == 0)
9350 if (LHS->getType()->isPointerTy()) {
9351 LHS = getLosslessPtrToIntExpr(LHS);
9352 if (isa<SCEVCouldNotCompute>(LHS))
9353 return LHS;
9354 }
9355 if (RHS->getType()->isPointerTy()) {
9356 RHS = getLosslessPtrToIntExpr(RHS);
9357 if (isa<SCEVCouldNotCompute>(RHS))
9358 return RHS;
9359 }
9360 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9361 if (EL.hasAnyInfo()) return EL;
9362 break;
9363 }
9364 case ICmpInst::ICMP_SLE:
9365 case ICmpInst::ICMP_ULE:
9366 // Since the loop is finite, an invariant RHS cannot include the boundary
9367 // value; otherwise it would loop forever.
9368 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9369 !isLoopInvariant(RHS, L)) {
9370 // Otherwise, perform the addition in a wider type, to avoid overflow.
9371 // If the LHS is an addrec with the appropriate nowrap flag, the
9372 // extension will be sunk into it and the exit count can be analyzed.
9373 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9374 if (!OldType)
9375 break;
9376 // Prefer doubling the bitwidth over adding a single bit to make it more
9377 // likely that we use a legal type.
9378 auto *NewType =
9379 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9380 if (ICmpInst::isSigned(Pred)) {
9381 LHS = getSignExtendExpr(LHS, NewType);
9382 RHS = getSignExtendExpr(RHS, NewType);
9383 } else {
9384 LHS = getZeroExtendExpr(LHS, NewType);
9385 RHS = getZeroExtendExpr(RHS, NewType);
9386 }
9387 }
9388 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9389 [[fallthrough]];
9390 case ICmpInst::ICMP_SLT:
9391 case ICmpInst::ICMP_ULT: { // while (X < Y)
9392 bool IsSigned = ICmpInst::isSigned(Pred);
9393 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9394 AllowPredicates);
9395 if (EL.hasAnyInfo())
9396 return EL;
9397 break;
9398 }
9399 case ICmpInst::ICMP_SGE:
9400 case ICmpInst::ICMP_UGE:
9401 // Since the loop is finite, an invariant RHS cannot include the boundary
9402 // value; otherwise it would loop forever.
9403 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9404 !isLoopInvariant(RHS, L))
9405 break;
9406 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9407 [[fallthrough]];
9408 case ICmpInst::ICMP_SGT:
9409 case ICmpInst::ICMP_UGT: { // while (X > Y)
9410 bool IsSigned = ICmpInst::isSigned(Pred);
9411 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9412 AllowPredicates);
9413 if (EL.hasAnyInfo())
9414 return EL;
9415 break;
9416 }
9417 default:
9418 break;
9419 }
9420
9421 return getCouldNotCompute();
9422}
9423
9424ScalarEvolution::ExitLimit
9425ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9426 SwitchInst *Switch,
9427 BasicBlock *ExitingBlock,
9428 bool ControlsOnlyExit) {
9429 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9430
9431 // Give up if the exit is the default dest of a switch.
9432 if (Switch->getDefaultDest() == ExitingBlock)
9433 return getCouldNotCompute();
9434
9435 assert(L->contains(Switch->getDefaultDest()) &&
9436 "Default case must not exit the loop!");
9437 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9438 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9439
9440 // while (X != Y) --> while (X-Y != 0)
9441 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9442 if (EL.hasAnyInfo())
9443 return EL;
9444
9445 return getCouldNotCompute();
9446}
9447
9448static ConstantInt *
9449EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9450 ScalarEvolution &SE) {
9451 const SCEV *InVal = SE.getConstant(C);
9452 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9454 "Evaluation of SCEV at constant didn't fold correctly?");
9455 return cast<SCEVConstant>(Val)->getValue();
9456}
9457
9458ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9459 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9460 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9461 if (!RHS)
9462 return getCouldNotCompute();
9463
9464 const BasicBlock *Latch = L->getLoopLatch();
9465 if (!Latch)
9466 return getCouldNotCompute();
9467
9468 const BasicBlock *Predecessor = L->getLoopPredecessor();
9469 if (!Predecessor)
9470 return getCouldNotCompute();
9471
9472 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9473 // Return LHS in OutLHS and shift_op in OutOpCode.
9474 auto MatchPositiveShift =
9475 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9476
9477 using namespace PatternMatch;
9478
9479 ConstantInt *ShiftAmt;
9480 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9481 OutOpCode = Instruction::LShr;
9482 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9483 OutOpCode = Instruction::AShr;
9484 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9485 OutOpCode = Instruction::Shl;
9486 else
9487 return false;
9488
9489 return ShiftAmt->getValue().isStrictlyPositive();
9490 };
9491
9492 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9493 //
9494 // loop:
9495 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9496 // %iv.shifted = lshr i32 %iv, <positive constant>
9497 //
9498 // Return true on a successful match. Return the corresponding PHI node (%iv
9499 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9500 auto MatchShiftRecurrence =
9501 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9502 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9503
9504 {
9505 Instruction::BinaryOps OpC;
9506 Value *V;
9507
9508 // If we encounter a shift instruction, "peel off" the shift operation,
9509 // and remember that we did so. Later when we inspect %iv's backedge
9510 // value, we will make sure that the backedge value uses the same
9511 // operation.
9512 //
9513 // Note: the peeled shift operation does not have to be the same
9514 // instruction as the one feeding into the PHI's backedge value. We only
9515 // really care about it being the same *kind* of shift instruction --
9516 // that's all that is required for our later inferences to hold.
9517 if (MatchPositiveShift(LHS, V, OpC)) {
9518 PostShiftOpCode = OpC;
9519 LHS = V;
9520 }
9521 }
9522
9523 PNOut = dyn_cast<PHINode>(LHS);
9524 if (!PNOut || PNOut->getParent() != L->getHeader())
9525 return false;
9526
9527 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9528 Value *OpLHS;
9529
9530 return
9531 // The backedge value for the PHI node must be a shift by a positive
9532 // amount
9533 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9534
9535 // of the PHI node itself
9536 OpLHS == PNOut &&
9537
9538 // and the kind of shift should match the kind of shift we peeled
9539 // off, if any.
9540 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9541 };
9542
9543 PHINode *PN;
9544 Instruction::BinaryOps OpCode;
9545 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9546 return getCouldNotCompute();
9547
9548 const DataLayout &DL = getDataLayout();
9549
9550 // The key rationale for this optimization is that for some kinds of shift
9551 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9552 // within a finite number of iterations. If the condition guarding the
9553 // backedge (in the sense that the backedge is taken if the condition is true)
9554 // is false for the value the shift recurrence stabilizes to, then we know
9555 // that the backedge is taken only a finite number of times.
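// As an illustration (hypothetical IR): with %iv = lshr i32 %iv.prev, 1 and a
// backedge guarded by "icmp ne i32 %iv, 0", the recurrence stabilizes to 0
// within 32 iterations, at which point the guard is false, so the backedge can
// be taken at most bitwidth(%iv) times.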
9556
9557 ConstantInt *StableValue = nullptr;
9558 switch (OpCode) {
9559 default:
9560 llvm_unreachable("Impossible case!");
9561
9562 case Instruction::AShr: {
9563 // {K,ashr,<positive-constant>} stabilizes to 0 (if K is non-negative) or
9564 // to -1 (if K is negative) in at most bitwidth(K) iterations.
9565 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9566 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9567 Predecessor->getTerminator(), &DT);
9568 auto *Ty = cast<IntegerType>(RHS->getType());
9569 if (Known.isNonNegative())
9570 StableValue = ConstantInt::get(Ty, 0);
9571 else if (Known.isNegative())
9572 StableValue = ConstantInt::get(Ty, -1, true);
9573 else
9574 return getCouldNotCompute();
9575
9576 break;
9577 }
9578 case Instruction::LShr:
9579 case Instruction::Shl:
9580 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9581 // stabilize to 0 in at most bitwidth(K) iterations.
9582 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9583 break;
9584 }
9585
9586 auto *Result =
9587 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9588 assert(Result->getType()->isIntegerTy(1) &&
9589 "Otherwise cannot be an operand to a branch instruction");
9590
9591 if (Result->isZeroValue()) {
9592 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9593 const SCEV *UpperBound =
9594 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9595 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9596 }
9597
9598 return getCouldNotCompute();
9599}
9600
9601/// Return true if we can constant fold an instruction of the specified type,
9602/// assuming that all operands were constants.
9603static bool CanConstantFold(const Instruction *I) {
9604 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9605 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9606 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9607 return true;
9608
9609 if (const CallInst *CI = dyn_cast<CallInst>(I))
9610 if (const Function *F = CI->getCalledFunction())
9611 return canConstantFoldCallTo(CI, F);
9612 return false;
9613}
9614
9615/// Determine whether this instruction can constant evolve within this loop
9616/// assuming its operands can all constant evolve.
9617static bool canConstantEvolve(Instruction *I, const Loop *L) {
9618 // An instruction outside of the loop can't be derived from a loop PHI.
9619 if (!L->contains(I)) return false;
9620
9621 if (isa<PHINode>(I)) {
9622 // We don't currently keep track of the control flow needed to evaluate
9623 // PHIs, so we cannot handle PHIs inside of loops.
9624 return L->getHeader() == I->getParent();
9625 }
9626
9627 // If we won't be able to constant fold this expression even if the operands
9628 // are constants, bail early.
9629 return CanConstantFold(I);
9630}
9631
9632/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9633/// recursing through each instruction operand until reaching a loop header phi.
9634static PHINode *
9635 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9636 DenseMap<Instruction *, PHINode *> &PHIMap,
9637 unsigned Depth) {
9638 if (Depth > MaxConstantEvolvingDepth)
9639 return nullptr;
9640
9641 // Otherwise, we can evaluate this instruction if all of its operands are
9642 // constant or derived from a PHI node themselves.
9643 PHINode *PHI = nullptr;
9644 for (Value *Op : UseInst->operands()) {
9645 if (isa<Constant>(Op)) continue;
9646
9647 Instruction *OpInst = dyn_cast<Instruction>(Op);
9648 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9649
9650 PHINode *P = dyn_cast<PHINode>(OpInst);
9651 if (!P)
9652 // If this operand is already visited, reuse the prior result.
9653 // We may have P != PHI if this is the deepest point at which the
9654 // inconsistent paths meet.
9655 P = PHIMap.lookup(OpInst);
9656 if (!P) {
9657 // Recurse and memoize the results, whether a phi is found or not.
9658 // This recursive call invalidates pointers into PHIMap.
9659 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9660 PHIMap[OpInst] = P;
9661 }
9662 if (!P)
9663 return nullptr; // Not evolving from PHI
9664 if (PHI && PHI != P)
9665 return nullptr; // Evolving from multiple different PHIs.
9666 PHI = P;
9667 }
9668 // This is an expression evolving from a constant PHI!
9669 return PHI;
9670}
9671
9672/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9673/// in the loop that V is derived from. We allow arbitrary operations along the
9674/// way, but the operands of an operation must either be constants or a value
9675/// derived from a constant PHI. If this expression does not fit with these
9676/// constraints, return null.
9677 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9678 Instruction *I = dyn_cast<Instruction>(V);
9679 if (!I || !canConstantEvolve(I, L)) return nullptr;
9680
9681 if (PHINode *PN = dyn_cast<PHINode>(I))
9682 return PN;
9683
9684 // Record non-constant instructions contained by the loop.
9685 DenseMap<Instruction *, PHINode *> PHIMap;
9686 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9687}
9688
9689/// EvaluateExpression - Given an expression that passes the
9690/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9691/// in the loop has the value PHIVal. If we can't fold this expression for some
9692/// reason, return null.
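/// For instance (illustrative), evaluating %x = mul i32 %iv, 3 with the map
/// Vals = { %iv -> i32 5 } folds to the constant i32 15.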
9693 static Constant *EvaluateExpression(Value *V, const Loop *L,
9694 DenseMap<Instruction *, Constant *> &Vals,
9695 const DataLayout &DL,
9696 const TargetLibraryInfo *TLI) {
9697 // Convenient constant check, but redundant for recursive calls.
9698 if (Constant *C = dyn_cast<Constant>(V)) return C;
9699 Instruction *I = dyn_cast<Instruction>(V);
9700 if (!I) return nullptr;
9701
9702 if (Constant *C = Vals.lookup(I)) return C;
9703
9704 // An instruction inside the loop depends on a value outside the loop that we
9705 // weren't given a mapping for, or a value such as a call inside the loop.
9706 if (!canConstantEvolve(I, L)) return nullptr;
9707
9708 // An unmapped PHI can be due to a branch or another loop inside this loop,
9709 // or due to this not being the initial iteration through a loop where we
9710 // couldn't compute the evolution of this particular PHI last time.
9711 if (isa<PHINode>(I)) return nullptr;
9712
9713 std::vector<Constant*> Operands(I->getNumOperands());
9714
9715 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9716 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9717 if (!Operand) {
9718 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9719 if (!Operands[i]) return nullptr;
9720 continue;
9721 }
9722 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9723 Vals[Operand] = C;
9724 if (!C) return nullptr;
9725 Operands[i] = C;
9726 }
9727
9728 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9729 /*AllowNonDeterministic=*/false);
9730}
9731
9732
9733// If every incoming value to PN except the one for BB is a specific Constant,
9734// return that, else return nullptr.
9735 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9736 Constant *IncomingVal = nullptr;
9737
9738 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9739 if (PN->getIncomingBlock(i) == BB)
9740 continue;
9741
9742 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9743 if (!CurrentVal)
9744 return nullptr;
9745
9746 if (IncomingVal != CurrentVal) {
9747 if (IncomingVal)
9748 return nullptr;
9749 IncomingVal = CurrentVal;
9750 }
9751 }
9752
9753 return IncomingVal;
9754}
9755
9756/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9757/// in the header of its containing loop, we know the loop executes a
9758/// constant number of times, and the PHI node is just a recurrence
9759/// involving constants, fold it.
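/// For example (illustrative), given a header PHI
///   %iv = phi i32 [ 0, %preheader ], [ %iv.next, %latch ]
///   %iv.next = add i32 %iv, 3
/// and a backedge-taken count of 4, the brute-force evaluation below yields
/// the exit value i32 12.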
9760Constant *
9761ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9762 const APInt &BEs,
9763 const Loop *L) {
9764 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9765 if (!Inserted)
9766 return I->second;
9767
9768 if (BEs.ugt(MaxBruteForceIterations))
9769 return nullptr; // Not going to evaluate it.
9770
9771 Constant *&RetVal = I->second;
9772
9773 DenseMap<Instruction *, Constant *> CurrentIterVals;
9774 BasicBlock *Header = L->getHeader();
9775 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9776
9777 BasicBlock *Latch = L->getLoopLatch();
9778 if (!Latch)
9779 return nullptr;
9780
9781 for (PHINode &PHI : Header->phis()) {
9782 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9783 CurrentIterVals[&PHI] = StartCST;
9784 }
9785 if (!CurrentIterVals.count(PN))
9786 return RetVal = nullptr;
9787
9788 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9789
9790 // Execute the loop symbolically to determine the exit value.
9791 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9792 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9793
9794 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9795 unsigned IterationNum = 0;
9796 const DataLayout &DL = getDataLayout();
9797 for (; ; ++IterationNum) {
9798 if (IterationNum == NumIterations)
9799 return RetVal = CurrentIterVals[PN]; // Got exit value!
9800
9801 // Compute the value of the PHIs for the next iteration.
9802 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9803 DenseMap<Instruction *, Constant *> NextIterVals;
9804 Constant *NextPHI =
9805 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9806 if (!NextPHI)
9807 return nullptr; // Couldn't evaluate!
9808 NextIterVals[PN] = NextPHI;
9809
9810 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9811
9812 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9813 // cease to be able to evaluate one of them or if they stop evolving,
9814 // because that doesn't necessarily prevent us from computing PN.
9815 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9816 for (const auto &I : CurrentIterVals) {
9817 PHINode *PHI = dyn_cast<PHINode>(I.first);
9818 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9819 PHIsToCompute.emplace_back(PHI, I.second);
9820 }
9821 // We use two distinct loops because EvaluateExpression may invalidate any
9822 // iterators into CurrentIterVals.
9823 for (const auto &I : PHIsToCompute) {
9824 PHINode *PHI = I.first;
9825 Constant *&NextPHI = NextIterVals[PHI];
9826 if (!NextPHI) { // Not already computed.
9827 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9828 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9829 }
9830 if (NextPHI != I.second)
9831 StoppedEvolving = false;
9832 }
9833
9834 // If all entries in CurrentIterVals == NextIterVals then we can stop
9835 // iterating, the loop can't continue to change.
9836 if (StoppedEvolving)
9837 return RetVal = CurrentIterVals[PN];
9838
9839 CurrentIterVals.swap(NextIterVals);
9840 }
9841}
9842
9843const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9844 Value *Cond,
9845 bool ExitWhen) {
9846 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9847 if (!PN) return getCouldNotCompute();
9848
9849 // If the loop is canonicalized, the PHI will have exactly two entries.
9850 // That's the only form we support here.
9851 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9852
9853 DenseMap<Instruction *, Constant *> CurrentIterVals;
9854 BasicBlock *Header = L->getHeader();
9855 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9856
9857 BasicBlock *Latch = L->getLoopLatch();
9858 assert(Latch && "Should follow from NumIncomingValues == 2!");
9859
9860 for (PHINode &PHI : Header->phis()) {
9861 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9862 CurrentIterVals[&PHI] = StartCST;
9863 }
9864 if (!CurrentIterVals.count(PN))
9865 return getCouldNotCompute();
9866
9867 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9868 // the loop symbolically to determine when the condition gets a value of
9869 // "ExitWhen".
9870 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9871 const DataLayout &DL = getDataLayout();
9872 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9873 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9874 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9875
9876 // Couldn't symbolically evaluate.
9877 if (!CondVal) return getCouldNotCompute();
9878
9879 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9880 ++NumBruteForceTripCountsComputed;
9881 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9882 }
9883
9884 // Update all the PHI nodes for the next iteration.
9885 DenseMap<Instruction *, Constant *> NextIterVals;
9886
9887 // Create a list of which PHIs we need to compute. We want to do this before
9888 // calling EvaluateExpression on them because that may invalidate iterators
9889 // into CurrentIterVals.
9890 SmallVector<PHINode *, 8> PHIsToCompute;
9891 for (const auto &I : CurrentIterVals) {
9892 PHINode *PHI = dyn_cast<PHINode>(I.first);
9893 if (!PHI || PHI->getParent() != Header) continue;
9894 PHIsToCompute.push_back(PHI);
9895 }
9896 for (PHINode *PHI : PHIsToCompute) {
9897 Constant *&NextPHI = NextIterVals[PHI];
9898 if (NextPHI) continue; // Already computed!
9899
9900 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9901 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9902 }
9903 CurrentIterVals.swap(NextIterVals);
9904 }
9905
9906 // Too many iterations were needed to evaluate.
9907 return getCouldNotCompute();
9908}
9909
9910const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9911 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9912 ValuesAtScopes[V];
9913 // Check to see if we've folded this expression at this loop before.
9914 for (auto &LS : Values)
9915 if (LS.first == L)
9916 return LS.second ? LS.second : V;
9917
9918 Values.emplace_back(L, nullptr);
9919
9920 // Otherwise compute it.
9921 const SCEV *C = computeSCEVAtScope(V, L);
9922 for (auto &LS : reverse(ValuesAtScopes[V]))
9923 if (LS.first == L) {
9924 LS.second = C;
9925 if (!isa<SCEVConstant>(C))
9926 ValuesAtScopesUsers[C].push_back({L, V});
9927 break;
9928 }
9929 return C;
9930}
9931
9932/// This builds up a Constant using the ConstantExpr interface. That way, we
9933/// will return Constants for objects which aren't represented by a
9934/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9935/// Returns NULL if the SCEV isn't representable as a Constant.
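/// For example (illustrative), an add SCEV of the form (4 + @gv), where @gv is
/// a global, is rebuilt below as a getelementptr of i8 over @gv with offset 4.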
9936 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9937 switch (V->getSCEVType()) {
9938 case scCouldNotCompute:
9939 case scAddRecExpr:
9940 case scVScale:
9941 return nullptr;
9942 case scConstant:
9943 return cast<SCEVConstant>(V)->getValue();
9944 case scUnknown:
9945 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9946 case scPtrToInt: {
9947 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9948 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9949 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9950
9951 return nullptr;
9952 }
9953 case scTruncate: {
9954 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9955 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9956 return ConstantExpr::getTrunc(CastOp, ST->getType());
9957 return nullptr;
9958 }
9959 case scAddExpr: {
9960 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9961 Constant *C = nullptr;
9962 for (const SCEV *Op : SA->operands()) {
9963 Constant *OpC = BuildConstantFromSCEV(Op);
9964 if (!OpC)
9965 return nullptr;
9966 if (!C) {
9967 C = OpC;
9968 continue;
9969 }
9970 assert(!C->getType()->isPointerTy() &&
9971 "Can only have one pointer, and it must be last");
9972 if (OpC->getType()->isPointerTy()) {
9973 // The offsets have been converted to bytes. We can add bytes using
9974 // an i8 GEP.
9975 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9976 OpC, C);
9977 } else {
9978 C = ConstantExpr::getAdd(C, OpC);
9979 }
9980 }
9981 return C;
9982 }
9983 case scMulExpr:
9984 case scSignExtend:
9985 case scZeroExtend:
9986 case scUDivExpr:
9987 case scSMaxExpr:
9988 case scUMaxExpr:
9989 case scSMinExpr:
9990 case scUMinExpr:
9991 case scSequentialUMinExpr:
9992 return nullptr;
9993 }
9994 llvm_unreachable("Unknown SCEV kind!");
9995}
9996
9997const SCEV *
9998ScalarEvolution::getWithOperands(const SCEV *S,
9999 SmallVectorImpl<const SCEV *> &NewOps) {
10000 switch (S->getSCEVType()) {
10001 case scTruncate:
10002 case scZeroExtend:
10003 case scSignExtend:
10004 case scPtrToInt:
10005 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10006 case scAddRecExpr: {
10007 auto *AddRec = cast<SCEVAddRecExpr>(S);
10008 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10009 }
10010 case scAddExpr:
10011 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10012 case scMulExpr:
10013 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10014 case scUDivExpr:
10015 return getUDivExpr(NewOps[0], NewOps[1]);
10016 case scUMaxExpr:
10017 case scSMaxExpr:
10018 case scUMinExpr:
10019 case scSMinExpr:
10020 return getMinMaxExpr(S->getSCEVType(), NewOps);
10021 case scSequentialUMinExpr:
10022 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10023 case scConstant:
10024 case scVScale:
10025 case scUnknown:
10026 return S;
10027 case scCouldNotCompute:
10028 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10029 }
10030 llvm_unreachable("Unknown SCEV kind!");
10031}
10032
10033const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10034 switch (V->getSCEVType()) {
10035 case scConstant:
10036 case scVScale:
10037 return V;
10038 case scAddRecExpr: {
10039 // If this is a loop recurrence for a loop that does not contain L, then we
10040 // are dealing with the final value computed by the loop.
10041 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10042 // First, attempt to evaluate each operand.
10043 // Avoid performing the look-up in the common case where the specified
10044 // expression has no loop-variant portions.
10045 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10046 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10047 if (OpAtScope == AddRec->getOperand(i))
10048 continue;
10049
10050 // Okay, at least one of these operands is loop variant but might be
10051 // foldable. Build a new instance of the folded commutative expression.
10052 SmallVector<const SCEV *, 8> NewOps;
10053 NewOps.reserve(AddRec->getNumOperands());
10054 append_range(NewOps, AddRec->operands().take_front(i));
10055 NewOps.push_back(OpAtScope);
10056 for (++i; i != e; ++i)
10057 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10058
10059 const SCEV *FoldedRec = getAddRecExpr(
10060 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10061 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10062 // The addrec may be folded to a nonrecurrence, for example, if the
10063 // induction variable is multiplied by zero after constant folding. Go
10064 // ahead and return the folded value.
10065 if (!AddRec)
10066 return FoldedRec;
10067 break;
10068 }
10069
10070 // If the scope is outside the addrec's loop, evaluate it by using the
10071 // loop exit value of the addrec.
10072 if (!AddRec->getLoop()->contains(L)) {
10073 // To evaluate this recurrence, we need to know how many times the AddRec
10074 // loop iterates. Compute this now.
10075 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10076 if (BackedgeTakenCount == getCouldNotCompute())
10077 return AddRec;
10078
10079 // Then, evaluate the AddRec.
10080 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10081 }
10082
10083 return AddRec;
10084 }
10085 case scTruncate:
10086 case scZeroExtend:
10087 case scSignExtend:
10088 case scPtrToInt:
10089 case scAddExpr:
10090 case scMulExpr:
10091 case scUDivExpr:
10092 case scUMaxExpr:
10093 case scSMaxExpr:
10094 case scUMinExpr:
10095 case scSMinExpr:
10096 case scSequentialUMinExpr: {
10097 ArrayRef<const SCEV *> Ops = V->operands();
10098 // Avoid performing the look-up in the common case where the specified
10099 // expression has no loop-variant portions.
10100 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10101 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10102 if (OpAtScope != Ops[i]) {
10103 // Okay, at least one of these operands is loop variant but might be
10104 // foldable. Build a new instance of the folded commutative expression.
10105 SmallVector<const SCEV *, 8> NewOps;
10106 NewOps.reserve(Ops.size());
10107 append_range(NewOps, Ops.take_front(i));
10108 NewOps.push_back(OpAtScope);
10109
10110 for (++i; i != e; ++i) {
10111 OpAtScope = getSCEVAtScope(Ops[i], L);
10112 NewOps.push_back(OpAtScope);
10113 }
10114
10115 return getWithOperands(V, NewOps);
10116 }
10117 }
10118 // If we got here, all operands are loop invariant.
10119 return V;
10120 }
10121 case scUnknown: {
10122 // If this instruction is evolved from a constant-evolving PHI, compute the
10123 // exit value from the loop without using SCEVs.
10124 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10125 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10126 if (!I)
10127 return V; // This is some other type of SCEVUnknown, just return it.
10128
10129 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10130 const Loop *CurrLoop = this->LI[I->getParent()];
10131 // Looking for loop exit value.
10132 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10133 PN->getParent() == CurrLoop->getHeader()) {
10134 // Okay, there is no closed form solution for the PHI node. Check
10135 // to see if the loop that contains it has a known backedge-taken
10136 // count. If so, we may be able to force computation of the exit
10137 // value.
10138 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10139 // This trivial case can show up in some degenerate cases where
10140 // the incoming IR has not yet been fully simplified.
10141 if (BackedgeTakenCount->isZero()) {
10142 Value *InitValue = nullptr;
10143 bool MultipleInitValues = false;
10144 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10145 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10146 if (!InitValue)
10147 InitValue = PN->getIncomingValue(i);
10148 else if (InitValue != PN->getIncomingValue(i)) {
10149 MultipleInitValues = true;
10150 break;
10151 }
10152 }
10153 }
10154 if (!MultipleInitValues && InitValue)
10155 return getSCEV(InitValue);
10156 }
10157 // Do we have a loop invariant value flowing around the backedge
10158 // for a loop which must execute the backedge?
10159 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10160 isKnownNonZero(BackedgeTakenCount) &&
10161 PN->getNumIncomingValues() == 2) {
10162
10163 unsigned InLoopPred =
10164 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10165 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10166 if (CurrLoop->isLoopInvariant(BackedgeVal))
10167 return getSCEV(BackedgeVal);
10168 }
10169 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10170 // Okay, we know how many times the containing loop executes. If
10171 // this is a constant evolving PHI node, get the final value at
10172 // the specified iteration number.
10173 Constant *RV =
10174 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10175 if (RV)
10176 return getSCEV(RV);
10177 }
10178 }
10179 }
10180
10181 // Okay, this is an expression that we cannot symbolically evaluate
10182 // into a SCEV. Check to see if it's possible to symbolically evaluate
10183 // the arguments into constants, and if so, try to constant propagate the
10184 // result. This is particularly useful for computing loop exit values.
10185 if (!CanConstantFold(I))
10186 return V; // This is some other type of SCEVUnknown, just return it.
10187 
10188 SmallVector<Constant *, 4> Operands;
10189 Operands.reserve(I->getNumOperands());
10190 bool MadeImprovement = false;
10191 for (Value *Op : I->operands()) {
10192 if (Constant *C = dyn_cast<Constant>(Op)) {
10193 Operands.push_back(C);
10194 continue;
10195 }
10196
10197 // If any of the operands is non-constant and if they are
10198 // non-integer and non-pointer, don't even try to analyze them
10199 // with scev techniques.
10200 if (!isSCEVable(Op->getType()))
10201 return V;
10202
10203 const SCEV *OrigV = getSCEV(Op);
10204 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10205 MadeImprovement |= OrigV != OpV;
10206
10207 Constant *C = BuildConstantFromSCEV(OpV);
10208 if (!C)
10209 return V;
10210 assert(C->getType() == Op->getType() && "Type mismatch");
10211 Operands.push_back(C);
10212 }
10213
10214 // Check to see if getSCEVAtScope actually made an improvement.
10215 if (!MadeImprovement)
10216 return V; // This is some other type of SCEVUnknown, just return it.
10217
10218 Constant *C = nullptr;
10219 const DataLayout &DL = getDataLayout();
10220 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10221 /*AllowNonDeterministic=*/false);
10222 if (!C)
10223 return V;
10224 return getSCEV(C);
10225 }
10226 case scCouldNotCompute:
10227 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10228 }
10229 llvm_unreachable("Unknown SCEV type!");
10230}
10231
10232 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10233 return getSCEVAtScope(getSCEV(V), L);
10234}
10235
10236const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10237 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10238 return stripInjectiveFunctions(ZExt->getOperand());
10239 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10240 return stripInjectiveFunctions(SExt->getOperand());
10241 return S;
10242}
10243
10244/// Finds the minimum unsigned root of the following equation:
10245///
10246/// A * X = B (mod N)
10247///
10248/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10249/// A and B isn't important.
10250///
10251/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
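/// Worked example (illustrative): for A = 4, B = 12 and BW = 8 (so N = 256),
/// D = gcd(A, N) = 4 and B is divisible by D; A/D = 1 has multiplicative
/// inverse I = 1, so the minimum unsigned root is
/// (I * B mod N) / D = 12 / 4 = 3, and indeed 4 * 3 == 12 (mod 256).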
10252static const SCEV *
10253 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10254 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10255 ScalarEvolution &SE, const Loop *L) {
10256 uint32_t BW = A.getBitWidth();
10257 assert(BW == SE.getTypeSizeInBits(B->getType()));
10258 assert(A != 0 && "A must be non-zero.");
10259
10260 // 1. D = gcd(A, N)
10261 //
10262 // The gcd of A and N may have only one prime factor: 2. The number of
10263 // trailing zeros in A is its multiplicity
10264 uint32_t Mult2 = A.countr_zero();
10265 // D = 2^Mult2
10266
10267 // 2. Check if B is divisible by D.
10268 //
10269 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10270 // is not less than multiplicity of this prime factor for D.
10271 unsigned MinTZ = SE.getMinTrailingZeros(B);
10272 // Try again with the terminator of the loop predecessor for context-specific
10273 // result, if MinTZ is too small.
10274 if (MinTZ < Mult2 && L->getLoopPredecessor())
10275 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10276 if (MinTZ < Mult2) {
10277 // Check if we can prove there's no remainder using URem.
10278 const SCEV *URem =
10279 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10280 const SCEV *Zero = SE.getZero(B->getType());
10281 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10282 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10283 if (!Predicates)
10284 return SE.getCouldNotCompute();
10285
10286 // Avoid adding a predicate that is known to be false.
10287 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10288 return SE.getCouldNotCompute();
10289 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10290 }
10291 }
10292
10293 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10294 // modulo (N / D).
10295 //
10296 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10297 // (N / D) in general. The inverse itself always fits into BW bits, though,
10298 // so we immediately truncate it.
10299 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10300 APInt I = AD.multiplicativeInverse().zext(BW);
10301
10302 // 4. Compute the minimum unsigned root of the equation:
10303 // I * (B / D) mod (N / D)
10304 // To simplify the computation, we factor out the divide by D:
10305 // (I * B mod N) / D
10306 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10307 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10308}
10309
10310/// For a given quadratic addrec, generate coefficients of the corresponding
10311/// quadratic equation, multiplied by a common value to ensure that they are
10312/// integers.
10313/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10314/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10315/// were multiplied by, and BitWidth is the bit width of the original addrec
10316/// coefficients.
10317/// This function returns std::nullopt if the addrec coefficients are not
10318/// compile- time constants.
10319static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10321 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10322 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10323 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10324 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10325 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10326 << *AddRec << '\n');
10327
10328 // We currently can only solve this if the coefficients are constants.
10329 if (!LC || !MC || !NC) {
10330 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10331 return std::nullopt;
10332 }
10333
10334 APInt L = LC->getAPInt();
10335 APInt M = MC->getAPInt();
10336 APInt N = NC->getAPInt();
10337 assert(!N.isZero() && "This is not a quadratic addrec");
10338
10339 unsigned BitWidth = LC->getAPInt().getBitWidth();
10340 unsigned NewWidth = BitWidth + 1;
10341 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10342 << BitWidth << '\n');
10343 // The sign-extension (as opposed to a zero-extension) here matches the
10344 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10345 N = N.sext(NewWidth);
10346 M = M.sext(NewWidth);
10347 L = L.sext(NewWidth);
10348
10349 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10350 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10351 // L+M, L+2M+N, L+3M+3N, ...
10352 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10353 //
10354 // The equation Acc = 0 is then
10355 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10356 // In a quadratic form it becomes:
10357 // N n^2 + (2M-N) n + 2L = 0.
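// As a worked example (illustrative), for the addrec {1,+,2,+,2} we have
// L = 1, M = 2, N = 2, so Acc(n) = 1 + 2n + n(n-1) and the resulting
// equation (after multiplying by T = 2) is 2n^2 + 2n + 2 = 0.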
10358
10359 APInt A = N;
10360 APInt B = 2 * M - A;
10361 APInt C = 2 * L;
10362 APInt T = APInt(NewWidth, 2);
10363 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10364 << "x + " << C << ", coeff bw: " << NewWidth
10365 << ", multiplied by " << T << '\n');
10366 return std::make_tuple(A, B, C, T, BitWidth);
10367}
10368
10369/// Helper function to compare optional APInts:
10370/// (a) if X and Y both exist, return min(X, Y),
10371/// (b) if neither X nor Y exist, return std::nullopt,
10372/// (c) if exactly one of X and Y exists, return that value.
10373static std::optional<APInt> MinOptional(std::optional<APInt> X,
10374 std::optional<APInt> Y) {
10375 if (X && Y) {
10376 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10377 APInt XW = X->sext(W);
10378 APInt YW = Y->sext(W);
10379 return XW.slt(YW) ? *X : *Y;
10380 }
10381 if (!X && !Y)
10382 return std::nullopt;
10383 return X ? *X : *Y;
10384}
10385
10386/// Helper function to truncate an optional APInt to a given BitWidth.
10387/// When solving addrec-related equations, it is preferable to return a value
10388/// that has the same bit width as the original addrec's coefficients. If the
10389/// solution fits in the original bit width, truncate it (except for i1).
10390/// Returning a value of a different bit width may inhibit some optimizations.
10391///
10392/// In general, a solution to a quadratic equation generated from an addrec
10393/// may require BW+1 bits, where BW is the bit width of the addrec's
10394/// coefficients. The reason is that the coefficients of the quadratic
10395/// equation are BW+1 bits wide (to avoid truncation when converting from
10396/// the addrec to the equation).
10397static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10398 unsigned BitWidth) {
10399 if (!X)
10400 return std::nullopt;
10401 unsigned W = X->getBitWidth();
10402 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10403 return X->trunc(BitWidth);
10404 return X;
10405}
10406
10407/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10408/// iterations. The values L, M, N are assumed to be signed, and they
10409/// should all have the same bit widths.
10410/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10411/// where BW is the bit width of the addrec's coefficients.
10412/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10413/// returned as such, otherwise the bit width of the returned value may
10414/// be greater than BW.
10415///
10416/// This function returns std::nullopt if
10417/// (a) the addrec coefficients are not constant, or
10418/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10419/// like x^2 = 5, no integer solutions exist, in other cases an integer
10420/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
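/// For example (illustrative), the chrec {-10,+,4,+,2} takes the values
/// -10, -6, 0, ..., so the least n with c(n) == 0 is n = 2.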
10421static std::optional<APInt>
10422 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10423 APInt A, B, C, M;
10424 unsigned BitWidth;
10425 auto T = GetQuadraticEquation(AddRec);
10426 if (!T)
10427 return std::nullopt;
10428
10429 std::tie(A, B, C, M, BitWidth) = *T;
10430 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10431 std::optional<APInt> X =
10432 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10433 if (!X)
10434 return std::nullopt;
10435
10436 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10437 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10438 if (!V->isZero())
10439 return std::nullopt;
10440
10441 return TruncIfPossible(X, BitWidth);
10442}
10443
10444/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10445/// iterations. The values M, N are assumed to be signed, and they
10446/// should all have the same bit widths.
10447/// Find the least n such that c(n) does not belong to the given range,
10448/// while c(n-1) does.
10449///
10450/// This function returns std::nullopt if
10451/// (a) the addrec coefficients are not constant, or
10452/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10453/// bounds of the range.
10454static std::optional<APInt>
10456 const ConstantRange &Range, ScalarEvolution &SE) {
10457 assert(AddRec->getOperand(0)->isZero() &&
10458 "Starting value of addrec should be 0");
10459 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10460 << Range << ", addrec " << *AddRec << '\n');
10461 // This case is handled in getNumIterationsInRange. Here we can assume that
10462 // we start in the range.
10463 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10464 "Addrec's initial value should be in range");
10465
10466 APInt A, B, C, M;
10467 unsigned BitWidth;
10468 auto T = GetQuadraticEquation(AddRec);
10469 if (!T)
10470 return std::nullopt;
10471
10472 // Be careful about the return value: there can be two reasons for not
10473 // returning an actual number. First, if no solutions to the equations
10474 // were found, and second, if the solutions don't leave the given range.
10475 // The first case means that the actual solution is "unknown", the second
10476 // means that it's known, but not valid. If the solution is unknown, we
10477 // cannot make any conclusions.
10478 // Return a pair: the optional solution and a flag indicating if the
10479 // solution was found.
10480 auto SolveForBoundary =
10481 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10482 // Solve for signed overflow and unsigned overflow, pick the lower
10483 // solution.
10484 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10485 << Bound << " (before multiplying by " << M << ")\n");
10486 Bound *= M; // The quadratic equation multiplier.
10487
10488 std::optional<APInt> SO;
10489 if (BitWidth > 1) {
10490 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10491 "signed overflow\n");
10492 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10493 }
10494 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10495 "unsigned overflow\n");
10496 std::optional<APInt> UO =
10497 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10498 
10499 auto LeavesRange = [&] (const APInt &X) {
10500 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10501 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10502 if (Range.contains(V0->getValue()))
10503 return false;
10504 // X should be at least 1, so X-1 is non-negative.
10505 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10506 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10507 if (Range.contains(V1->getValue()))
10508 return true;
10509 return false;
10510 };
10511
10512 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10513 // can be a solution, but the function failed to find it. We cannot treat it
10514 // as "no solution".
10515 if (!SO || !UO)
10516 return {std::nullopt, false};
10517
10518 // Check the smaller value first to see if it leaves the range.
10519 // At this point, both SO and UO must have values.
10520 std::optional<APInt> Min = MinOptional(SO, UO);
10521 if (LeavesRange(*Min))
10522 return { Min, true };
10523 std::optional<APInt> Max = Min == SO ? UO : SO;
10524 if (LeavesRange(*Max))
10525 return { Max, true };
10526
10527 // Solutions were found, but were eliminated, hence the "true".
10528 return {std::nullopt, true};
10529 };
10530
10531 std::tie(A, B, C, M, BitWidth) = *T;
10532 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10533 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10534 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10535 auto SL = SolveForBoundary(Lower);
10536 auto SU = SolveForBoundary(Upper);
10537 // If any of the solutions was unknown, no meaningful conclusions can
10538 // be made.
10539 if (!SL.second || !SU.second)
10540 return std::nullopt;
10541
10542 // Claim: The correct solution is not some value between Min and Max.
10543 //
10544 // Justification: Assuming that Min and Max are different values, one of
10545 // them is when the first signed overflow happens, the other is when the
10546 // first unsigned overflow happens. Crossing the range boundary is only
10547 // possible via an overflow (treating 0 as a special case of it, modeling
10548 // an overflow as crossing k*2^W for some k).
10549 //
10550 // The interesting case here is when Min was eliminated as an invalid
10551 // solution, but Max was not. The argument is that if there was another
10552 // overflow between Min and Max, it would also have been eliminated if
10553 // it was considered.
10554 //
10555 // For a given boundary, it is possible to have two overflows of the same
10556 // type (signed/unsigned) without having the other type in between: this
10557 // can happen when the vertex of the parabola is between the iterations
10558 // corresponding to the overflows. This is only possible when the two
10559 // overflows cross k*2^W for the same k. In such case, if the second one
10560 // left the range (and was the first one to do so), the first overflow
10561 // would have to enter the range, which would mean that either we had left
10562 // the range before or that we started outside of it. Both of these cases
10563 // are contradictions.
10564 //
10565 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10566 // solution is not some value between the Max for this boundary and the
10567 // Min of the other boundary.
10568 //
10569 // Justification: Assume that we had such Max_A and Min_B corresponding
10570 // to range boundaries A and B and such that Max_A < Min_B. If there was
10571 // a solution between Max_A and Min_B, it would have to be caused by an
10572 // overflow corresponding to either A or B. It cannot correspond to B,
10573 // since Min_B is the first occurrence of such an overflow. If it
10574 // corresponded to A, it would have to be either a signed or an unsigned
10575 // overflow that is larger than both eliminated overflows for A. But
10576 // between the eliminated overflows and this overflow, the values would
10577 // cover the entire value space, thus crossing the other boundary, which
10578 // is a contradiction.
10579
10580 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10581}
10582
10583ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10584 const Loop *L,
10585 bool ControlsOnlyExit,
10586 bool AllowPredicates) {
10587
10588 // This is only used for loops with a "x != y" exit test. The exit condition
10589 // is now expressed as a single expression, V = x-y. So the exit test is
10590 // effectively V != 0. We know and take advantage of the fact that this
10591 // expression is only used in a comparison-against-zero context.
10592
10593 SmallVector<const SCEVPredicate *, 4> Predicates;
10594 // If the value is a constant
10595 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10596 // If the value is already zero, the branch will execute zero times.
10597 if (C->getValue()->isZero()) return C;
10598 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10599 }
10600
10601 const SCEVAddRecExpr *AddRec =
10602 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10603
10604 if (!AddRec && AllowPredicates)
10605 // Try to make this an AddRec using runtime tests, in the first X
10606 // iterations of this loop, where X is the SCEV expression found by the
10607 // algorithm below.
10608 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10609
10610 if (!AddRec || AddRec->getLoop() != L)
10611 return getCouldNotCompute();
10612
10613 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10614 // the quadratic equation to solve it.
10615 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10616 // We can only use this value if the chrec ends up with an exact zero
10617 // value at this index. When solving for "X*X != 5", for example, we
10618 // should not accept a root of 2.
10619 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10620 const auto *R = cast<SCEVConstant>(getConstant(*S));
10621 return ExitLimit(R, R, R, false, Predicates);
10622 }
10623 return getCouldNotCompute();
10624 }
10625
10626 // Otherwise we can only handle this if it is affine.
10627 if (!AddRec->isAffine())
10628 return getCouldNotCompute();
10629
10630 // If this is an affine expression, the execution count of this branch is
10631 // the minimum unsigned root of the following equation:
10632 //
10633 // Start + Step*N = 0 (mod 2^BW)
10634 //
10635 // equivalent to:
10636 //
10637 // Step*N = -Start (mod 2^BW)
10638 //
10639 // where BW is the common bit width of Start and Step.
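// As a worked example (illustrative), an i8 addrec {6,+,5} reaches zero after
// N = 50 backedges, since 5 * 50 == 250 == -6 (mod 256).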
10640
10641 // Get the initial value for the loop.
10642 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10643 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10644
10645 if (!isLoopInvariant(Step, L))
10646 return getCouldNotCompute();
10647
10648 LoopGuards Guards = LoopGuards::collect(L, *this);
10649 // Specialize step for this loop so we get context sensitive facts below.
10650 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10651
10652 // For positive steps (counting up until unsigned overflow):
10653 // N = -Start/Step (as unsigned)
10654 // For negative steps (counting down to zero):
10655 // N = Start/-Step
10656 // First compute the unsigned distance from zero in the direction of Step.
10657 bool CountDown = isKnownNegative(StepWLG);
10658 if (!CountDown && !isKnownNonNegative(StepWLG))
10659 return getCouldNotCompute();
10660
10661 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10662 // Handle unitary steps, which cannot wraparound.
10663 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10664 // N = Distance (as unsigned)
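// E.g. (illustrative) for a step of -1 with Start = 7, Distance is 7 and the
// backedge is taken exactly 7 times.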
10665
10666 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10667 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10668 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10669
10670 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10671 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10672 // case, and see if we can improve the bound.
10673 //
10674 // Explicitly handling this here is necessary because getUnsignedRange
10675 // isn't context-sensitive; it doesn't know that we only care about the
10676 // range inside the loop.
10677 const SCEV *Zero = getZero(Distance->getType());
10678 const SCEV *One = getOne(Distance->getType());
10679 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10680 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10681 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10682 // as "unsigned_max(Distance + 1) - 1".
10683 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10684 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10685 }
10686 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10687 Predicates);
10688 }
10689
10690 // If the condition controls loop exit (the loop exits only if the expression
10691 // is true) and the addition is no-wrap we can use unsigned divide to
10692 // compute the backedge count. In this case, the step may not divide the
10693 // distance, but we don't care because if the condition is "missed" the loop
10694 // will have undefined behavior due to wrapping.
10695 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10696 loopHasNoAbnormalExits(AddRec->getLoop())) {
10697
10698 // If the stride is zero and the start is non-zero, the loop must be
10699 // infinite. In C++, most loops are finite by assumption, in which case the
10700 // step being zero implies UB must execute if the loop is entered.
10701 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10702 !isKnownNonZero(StepWLG))
10703 return getCouldNotCompute();
10704
10705 const SCEV *Exact =
10706 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10707 const SCEV *ConstantMax = getCouldNotCompute();
10708 if (Exact != getCouldNotCompute()) {
10709 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10710 ConstantMax =
10711 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10712 }
10713 const SCEV *SymbolicMax =
10714 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10715 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10716 }
10717
10718 // Solve the general equation.
10719 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10720 if (!StepC || StepC->getValue()->isZero())
10721 return getCouldNotCompute();
10722 const SCEV *E = SolveLinEquationWithOverflow(
10723 StepC->getAPInt(), getNegativeSCEV(Start),
10724 AllowPredicates ? &Predicates : nullptr, *this, L);
10725
10726 const SCEV *M = E;
10727 if (E != getCouldNotCompute()) {
10728 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10729 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10730 }
10731 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10732 return ExitLimit(E, M, S, false, Predicates);
10733}
10734
10735ScalarEvolution::ExitLimit
10736ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10737 // Loops that look like: while (X == 0) are very strange indeed. We don't
10738 // handle them yet except for the trivial case. This could be expanded in the
10739 // future as needed.
10740
10741 // If the value is a constant, check to see if it is known to be non-zero
10742 // already. If so, the backedge will execute zero times.
10743 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10744 if (!C->getValue()->isZero())
10745 return getZero(C->getType());
10746 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10747 }
10748
10749 // We could implement others, but I really doubt anyone writes loops like
10750 // this, and if they did, they would already be constant folded.
10751 return getCouldNotCompute();
10752}
10753
10754std::pair<const BasicBlock *, const BasicBlock *>
10755ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10756 const {
10757 // If the block has a unique predecessor, then there is no path from the
10758 // predecessor to the block that does not go through the direct edge
10759 // from the predecessor to the block.
10760 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10761 return {Pred, BB};
10762
10763 // A loop's header is defined to be a block that dominates the loop.
10764 // If the header has a unique predecessor outside the loop, it must be
10765 // a block that has exactly one successor that can reach the loop.
10766 if (const Loop *L = LI.getLoopFor(BB))
10767 return {L->getLoopPredecessor(), L->getHeader()};
10768
10769 return {nullptr, BB};
10770}
10771
10772/// SCEV structural equivalence is usually sufficient for testing whether two
10773/// expressions are equal, however for the purposes of looking for a condition
10774/// guarding a loop, it can be useful to be a little more general, since a
10775/// front-end may have replicated the controlling expression.
10776static bool HasSameValue(const SCEV *A, const SCEV *B) {
10777 // Quick check to see if they are the same SCEV.
10778 if (A == B) return true;
10779
10780 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10781 // Not all instructions that are "identical" compute the same value. For
10782 // instance, two distinct alloca instructions allocating the same type are
10783 // identical and do not read memory; but compute distinct values.
10784 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10785 };
10786
10787 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10788 // two different instructions with the same value. Check for this case.
10789 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10790 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10791 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10792 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10793 if (ComputesEqualValues(AI, BI))
10794 return true;
10795
10796 // Otherwise assume they may have a different value.
10797 return false;
10798}
10799
10800static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10801 const SCEV *Op0, *Op1;
10802 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10803 return false;
10804 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10805 LHS = Op1;
10806 return true;
10807 }
10808 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10809 LHS = Op0;
10810 return true;
10811 }
10812 return false;
10813}
10814
10815 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10816 const SCEV *&RHS, unsigned Depth) {
10817 bool Changed = false;
10818 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10819 // '0 != 0'.
10820 auto TrivialCase = [&](bool TriviallyTrue) {
10821 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10822 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10823 return true;
10824 };
10825 // If we hit the max recursion limit bail out.
10826 if (Depth >= 3)
10827 return false;
10828
10829 const SCEV *NewLHS, *NewRHS;
10830 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10831 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10832 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10833 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10834
10835 // (X * vscale) pred (Y * vscale) ==> X pred Y
10836 // when both multiples are NSW.
10837 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10838 // when both multiples are NUW.
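// E.g. (illustrative) (4 * vscale)<nsw> s< (8 * vscale)<nsw> simplifies to
// 4 s< 8.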
10839 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10840 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10841 !ICmpInst::isSigned(Pred))) {
10842 LHS = NewLHS;
10843 RHS = NewRHS;
10844 Changed = true;
10845 }
10846 }
10847
10848 // Canonicalize a constant to the right side.
10849 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10850 // Check for both operands constant.
10851 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10852 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10853 return TrivialCase(false);
10854 return TrivialCase(true);
10855 }
10856 // Otherwise swap the operands to put the constant on the right.
10857 std::swap(LHS, RHS);
10858 Pred = ICmpInst::getSwappedPredicate(Pred);
10859 Changed = true;
10860 }
10861
10862 // If we're comparing an addrec with a value which is loop-invariant in the
10863 // addrec's loop, put the addrec on the left. Also make a dominance check,
10864 // as both operands could be addrecs loop-invariant in each other's loop.
10865 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10866 const Loop *L = AR->getLoop();
10867 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10868 std::swap(LHS, RHS);
10869 Pred = ICmpInst::getSwappedPredicate(Pred);
10870 Changed = true;
10871 }
10872 }
10873
10874 // If there's a constant operand, canonicalize comparisons with boundary
10875 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10876 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10877 const APInt &RA = RC->getAPInt();
10878
10879 bool SimplifiedByConstantRange = false;
10880
10881 if (!ICmpInst::isEquality(Pred)) {
10882 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10883 if (ExactCR.isFullSet())
10884 return TrivialCase(true);
10885 if (ExactCR.isEmptySet())
10886 return TrivialCase(false);
10887
10888 APInt NewRHS;
10889 CmpInst::Predicate NewPred;
10890 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10891 ICmpInst::isEquality(NewPred)) {
10892 // We were able to convert an inequality to an equality.
10893 Pred = NewPred;
10894 RHS = getConstant(NewRHS);
10895 Changed = SimplifiedByConstantRange = true;
10896 }
10897 }
10898
10899 if (!SimplifiedByConstantRange) {
10900 switch (Pred) {
10901 default:
10902 break;
10903 case ICmpInst::ICMP_EQ:
10904 case ICmpInst::ICMP_NE:
10905 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10906 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10907 Changed = true;
10908 break;
10909
10910 // The "Should have been caught earlier!" messages refer to the fact
10911 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10912 // should have fired on the corresponding cases, and canonicalized the
10913 // check to trivial case.
10914
10915 case ICmpInst::ICMP_UGE:
10916 assert(!RA.isMinValue() && "Should have been caught earlier!");
10917 Pred = ICmpInst::ICMP_UGT;
10918 RHS = getConstant(RA - 1);
10919 Changed = true;
10920 break;
10921 case ICmpInst::ICMP_ULE:
10922 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10923 Pred = ICmpInst::ICMP_ULT;
10924 RHS = getConstant(RA + 1);
10925 Changed = true;
10926 break;
10927 case ICmpInst::ICMP_SGE:
10928 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10929 Pred = ICmpInst::ICMP_SGT;
10930 RHS = getConstant(RA - 1);
10931 Changed = true;
10932 break;
10933 case ICmpInst::ICMP_SLE:
10934 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10935 Pred = ICmpInst::ICMP_SLT;
10936 RHS = getConstant(RA + 1);
10937 Changed = true;
10938 break;
10939 }
10940 }
10941 }
10942
10943 // Check for obvious equality.
10944 if (HasSameValue(LHS, RHS)) {
10945 if (ICmpInst::isTrueWhenEqual(Pred))
10946 return TrivialCase(true);
10947 if (ICmpInst::isFalseWhenEqual(Pred))
10948 return TrivialCase(false);
10949 }
10950
10951 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10952 // adding or subtracting 1 from one of the operands.
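// E.g. (illustrative) "X s<= 7" becomes "X s< 8" when 8 is representable in
// the type, and "X u>= 3" becomes "X u> 2".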
10953 switch (Pred) {
10954 case ICmpInst::ICMP_SLE:
10955 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10956 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10957 SCEV::FlagNSW);
10958 Pred = ICmpInst::ICMP_SLT;
10959 Changed = true;
10960 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10961 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10962 SCEV::FlagNSW);
10963 Pred = ICmpInst::ICMP_SLT;
10964 Changed = true;
10965 }
10966 break;
10967 case ICmpInst::ICMP_SGE:
10968 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10969 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10970 SCEV::FlagNSW);
10971 Pred = ICmpInst::ICMP_SGT;
10972 Changed = true;
10973 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10974 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10975 SCEV::FlagNSW);
10976 Pred = ICmpInst::ICMP_SGT;
10977 Changed = true;
10978 }
10979 break;
10980 case ICmpInst::ICMP_ULE:
10981 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10982 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10983 SCEV::FlagNUW);
10984 Pred = ICmpInst::ICMP_ULT;
10985 Changed = true;
10986 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10987 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10988 Pred = ICmpInst::ICMP_ULT;
10989 Changed = true;
10990 }
10991 break;
10992 case ICmpInst::ICMP_UGE:
10993 // If RHS is an op we can fold the -1, try that first.
10994 // Otherwise prefer LHS to preserve the nuw flag.
10995 if ((isa<SCEVConstant>(RHS) ||
10996 (isa<SCEVAddExpr>(RHS) &&
10997 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
10998 !getUnsignedRangeMin(RHS).isMinValue()) {
10999 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11000 Pred = ICmpInst::ICMP_UGT;
11001 Changed = true;
11002 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11003 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11004 SCEV::FlagNUW);
11005 Pred = ICmpInst::ICMP_UGT;
11006 Changed = true;
11007 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11008 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11009 Pred = ICmpInst::ICMP_UGT;
11010 Changed = true;
11011 }
11012 break;
11013 default:
11014 break;
11015 }
11016
11017 // TODO: More simplifications are possible here.
11018
11019 // Recursively simplify until we either hit a recursion limit or nothing
11020 // changes.
11021 if (Changed)
11022 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11023
11024 return Changed;
11025}
11026
11027bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11028 return getSignedRangeMax(S).isNegative();
11029}
11030
11031bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11032 return getSignedRangeMin(S).isStrictlyPositive();
11033}
11034
11035bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11036 return !getSignedRangeMin(S).isNegative();
11037}
11038
11039bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11040 return !getSignedRangeMax(S).isStrictlyPositive();
11041}
11042
11043bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11044 // Query push down for cases where the unsigned range is
11045 // less than sufficient.
11046 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11047 return isKnownNonZero(SExt->getOperand(0));
11048 return getUnsignedRangeMin(S) != 0;
11049}
11050
11051bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11052 bool OrNegative) {
11053 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11054 if (auto *C = dyn_cast<SCEVConstant>(S))
11055 return C->getAPInt().isPowerOf2() ||
11056 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11057
11058 // The vscale_range indicates vscale is a power-of-two.
11059 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11060 };
11061
11062 if (NonRecursive(S))
11063 return true;
11064
11065 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11066 if (!Mul)
11067 return false;
11068 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11069}
11070
11071bool ScalarEvolution::isKnownMultipleOf(
11072 const SCEV *S, uint64_t M,
11073 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11074 if (M == 0)
11075 return false;
11076 if (M == 1)
11077 return true;
11078
11079 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11080 // starts with a multiple of M and at every iteration step S only adds
11081 // multiples of M.
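// Illustrative example: {8,+,4} is a known multiple of 4 because its start
// (8) and its step (4) are both multiples of 4, so every iterate
// 8, 12, 16, ... is a multiple of 4.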
11082 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11083 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11084 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11085
11086 // For a constant, check that "S % M == 0".
11087 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11088 APInt C = Cst->getAPInt();
11089 return C.urem(M) == 0;
11090 }
11091
11092 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11093
11094 // Basic tests have failed.
11095 // Check "S % M == 0" at compile time and record runtime Assumptions.
11096 auto *STy = dyn_cast<IntegerType>(S->getType());
11097 const SCEV *SmodM =
11098 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11099 const SCEV *Zero = getZero(STy);
11100
11101 // Check whether "S % M == 0" is known at compile time.
11102 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11103 return true;
11104
11105 // Check whether "S % M != 0" is known at compile time.
11106 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11107 return false;
11108
11109 const SCEVPredicate *P = getEqualPredicate(SmodM, Zero);
11110
11111 // Detect redundant predicates.
11112 for (auto *A : Assumptions)
11113 if (A->implies(P, *this))
11114 return true;
11115
11116 // Only record non-redundant predicates.
11117 Assumptions.push_back(P);
11118 return true;
11119}
11120
11121bool ScalarEvolution::haveSameSign(const SCEV *S1, const SCEV *S2) {
11122 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11123 (isKnownNegative(S1) && isKnownNegative(S2)));
11124}
11125
11126std::pair<const SCEV *, const SCEV *>
11128 // Compute SCEV on entry of loop L.
11129 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11130 if (Start == getCouldNotCompute())
11131 return { Start, Start };
11132 // Compute post increment SCEV for loop L.
11133 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11134 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11135 return { Start, PostInc };
11136}
11137
11138bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11139 const SCEV *RHS) {
11140 // First collect all loops.
11141 SmallPtrSet<const Loop *, 8> LoopsUsed;
11142 getUsedLoops(LHS, LoopsUsed);
11143 getUsedLoops(RHS, LoopsUsed);
11144
11145 if (LoopsUsed.empty())
11146 return false;
11147
11148 // Domination relationship must be a linear order on collected loops.
11149#ifndef NDEBUG
11150 for (const auto *L1 : LoopsUsed)
11151 for (const auto *L2 : LoopsUsed)
11152 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11153 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11154 "Domination relationship is not a linear order");
11155#endif
11156
11157 const Loop *MDL =
11158 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11159 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11160 });
11161
11162 // Get init and post increment value for LHS.
11163 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11164 // If LHS contains an unknown non-invariant SCEV then bail out.
11165 if (SplitLHS.first == getCouldNotCompute())
11166 return false;
11167 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11168 // Get init and post increment value for RHS.
11169 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11170 // If RHS contains an unknown non-invariant SCEV then bail out.
11171 if (SplitRHS.first == getCouldNotCompute())
11172 return false;
11173 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11174 // It is possible that init SCEV contains an invariant load but it does
11175 // not dominate MDL and is not available at MDL loop entry, so we should
11176 // check it here.
11177 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11178 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11179 return false;
11180
11181 // The backedge guard check seems to be faster than the entry check, so in
11182 // some cases it can speed up the whole estimation by short-circuiting.
11183 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11184 SplitRHS.second) &&
11185 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11186}
11187
11188bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11189 const SCEV *RHS) {
11190 // Canonicalize the inputs first.
11191 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11192
11193 if (isKnownViaInduction(Pred, LHS, RHS))
11194 return true;
11195
11196 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11197 return true;
11198
11199 // Otherwise see what can be done with some simple reasoning.
11200 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11201}
11202
11203std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11204 const SCEV *LHS,
11205 const SCEV *RHS) {
11206 if (isKnownPredicate(Pred, LHS, RHS))
11207 return true;
11208 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11209 return false;
11210 return std::nullopt;
11211}
11212
11213bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11214 const SCEV *RHS,
11215 const Instruction *CtxI) {
11216 // TODO: Analyze guards and assumes from Context's block.
11217 return isKnownPredicate(Pred, LHS, RHS) ||
11218 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11219}
11220
11221std::optional<bool>
11222ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11223 const SCEV *RHS, const Instruction *CtxI) {
11224 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11225 if (KnownWithoutContext)
11226 return KnownWithoutContext;
11227
11228 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11229 return true;
11230 if (isBasicBlockEntryGuardedByCond(
11231 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11232 return false;
11233 return std::nullopt;
11234}
11235
11236bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11237 const SCEVAddRecExpr *LHS,
11238 const SCEV *RHS) {
11239 const Loop *L = LHS->getLoop();
11240 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11241 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11242}
11243
11244std::optional<ScalarEvolution::MonotonicPredicateType>
11245ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11246 ICmpInst::Predicate Pred) {
11247 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11248
11249#ifndef NDEBUG
11250 // Verify an invariant: inverting the predicate should turn a monotonically
11251 // increasing change to a monotonically decreasing one, and vice versa.
11252 if (Result) {
11253 auto ResultSwapped =
11254 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11255
11256 assert(*ResultSwapped != *Result &&
11257 "monotonicity should flip as we flip the predicate");
11258 }
11259#endif
11260
11261 return Result;
11262}
11263
11264std::optional<ScalarEvolution::MonotonicPredicateType>
11265ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11266 ICmpInst::Predicate Pred) {
11267 // A zero step value for LHS means the induction variable is essentially a
11268 // loop invariant value. We don't really depend on the predicate actually
11269 // flipping from false to true (for increasing predicates, and the other way
11270 // around for decreasing predicates), all we care about is that *if* the
11271 // predicate changes then it only changes from false to true.
11272 //
11273 // A zero step value in itself is not very useful, but there may be places
11274 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11275 // as general as possible.
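// Illustrative example: for a <nuw> addrec {0,+,1} the value never
// decreases, so "{0,+,1} u> RHS" (RHS loop-invariant) can only switch from
// false to true (monotonically increasing), while "{0,+,1} u< RHS" can only
// switch from true to false (monotonically decreasing).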
11276
11277 // Only handle LE/LT/GE/GT predicates.
11278 if (!ICmpInst::isRelational(Pred))
11279 return std::nullopt;
11280
11281 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11282 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11283 "Should be greater or less!");
11284
11285 // Check that AR does not wrap.
11286 if (ICmpInst::isUnsigned(Pred)) {
11287 if (!LHS->hasNoUnsignedWrap())
11288 return std::nullopt;
11289 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11290 }
11291 assert(ICmpInst::isSigned(Pred) &&
11292 "Relational predicate is either signed or unsigned!");
11293 if (!LHS->hasNoSignedWrap())
11294 return std::nullopt;
11295
11296 const SCEV *Step = LHS->getStepRecurrence(*this);
11297
11298 if (isKnownNonNegative(Step))
11299 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11300
11301 if (isKnownNonPositive(Step))
11302 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11303
11304 return std::nullopt;
11305}
11306
11307std::optional<ScalarEvolution::LoopInvariantPredicate>
11308ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11309 const SCEV *RHS, const Loop *L,
11310 const Instruction *CtxI) {
11311 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11312 if (!isLoopInvariant(RHS, L)) {
11313 if (!isLoopInvariant(LHS, L))
11314 return std::nullopt;
11315
11316 std::swap(LHS, RHS);
11317 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11318 }
11319
11320 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11321 if (!ArLHS || ArLHS->getLoop() != L)
11322 return std::nullopt;
11323
11324 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11325 if (!MonotonicType)
11326 return std::nullopt;
11327 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11328 // true as the loop iterates, and the backedge is control dependent on
11329 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11330 //
11331 // * if the predicate was false in the first iteration then the predicate
11332 // is never evaluated again, since the loop exits without taking the
11333 // backedge.
11334 // * if the predicate was true in the first iteration then it will
11335 // continue to be true for all future iterations since it is
11336 // monotonically increasing.
11337 //
11338 // For both the above possibilities, we can replace the loop varying
11339 // predicate with its value on the first iteration of the loop (which is
11340 // loop invariant).
11341 //
11342 // A similar reasoning applies for a monotonically decreasing predicate, by
11343 // replacing true with false and false with true in the above two bullets.
11344 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11345 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11346
11347 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11348 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11349 RHS);
11350
11351 if (!CtxI)
11352 return std::nullopt;
11353 // Try to prove via context.
11354 // TODO: Support other cases.
11355 switch (Pred) {
11356 default:
11357 break;
11358 case ICmpInst::ICMP_ULE:
11359 case ICmpInst::ICMP_ULT: {
11360 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11361 // Given preconditions
11362 // (1) ArLHS does not cross the border of positive and negative parts of
11363 // range because of:
11364 // - Positive step; (TODO: lift this limitation)
11365 // - nuw - does not cross zero boundary;
11366 // - nsw - does not cross SINT_MAX boundary;
11367 // (2) ArLHS <s RHS
11368 // (3) RHS >=s 0
11369 // we can replace the loop variant ArLHS <u RHS condition with loop
11370 // invariant Start(ArLHS) <u RHS.
11371 //
11372 // Because of (1) there are two options:
11373 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11374 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11375 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11376 // Because of (2) ArLHS <u RHS is trivially true.
11377 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11378 // We can strengthen this to Start(ArLHS) <u RHS.
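// Illustrative example of the above: with ArLHS = {0,+,1}<nuw><nsw>, a
// positive step, RHS s>= 0, and "ArLHS s< RHS" known at the context
// instruction, the loop-varying "ArLHS u< RHS" can be replaced by the
// loop-invariant "0 u< RHS" (i.e. Start(ArLHS) u< RHS).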
11379 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11380 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11381 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11382 isKnownNonNegative(RHS) &&
11383 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11384 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11385 RHS);
11386 }
11387 }
11388
11389 return std::nullopt;
11390}
11391
11392std::optional<ScalarEvolution::LoopInvariantPredicate>
11393ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11394 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11395 const Instruction *CtxI, const SCEV *MaxIter) {
11396 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11397 Pred, LHS, RHS, L, CtxI, MaxIter))
11398 return LIP;
11399 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11400 // Number of iterations expressed as UMIN isn't always great for expressing
11401 // the value on the last iteration. If the straightforward approach didn't
11402 // work, try the following trick: if a predicate is invariant for X, it
11403 // is also invariant for umin(X, ...). So try to find something that works
11404 // among subexpressions of MaxIter expressed as umin.
11405 for (auto *Op : UMin->operands())
11406 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11407 Pred, LHS, RHS, L, CtxI, Op))
11408 return LIP;
11409 return std::nullopt;
11410}
11411
11412std::optional<ScalarEvolution::LoopInvariantPredicate>
11413ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11414 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11415 const Instruction *CtxI, const SCEV *MaxIter) {
11416 // Try to prove the following set of facts:
11417 // - The predicate is monotonic in the iteration space.
11418 // - If the check does not fail on the 1st iteration:
11419 // - No overflow will happen during first MaxIter iterations;
11420 // - It will not fail on the MaxIter'th iteration.
11421 // If the check does fail on the 1st iteration, we leave the loop and no
11422 // other checks matter.
11423
11424 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11425 if (!isLoopInvariant(RHS, L)) {
11426 if (!isLoopInvariant(LHS, L))
11427 return std::nullopt;
11428
11429 std::swap(LHS, RHS);
11430 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11431 }
11432
11433 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11434 if (!AR || AR->getLoop() != L)
11435 return std::nullopt;
11436
11437 // The predicate must be relational (i.e. <, <=, >=, >).
11438 if (!ICmpInst::isRelational(Pred))
11439 return std::nullopt;
11440
11441 // TODO: Support steps other than +/- 1.
11442 const SCEV *Step = AR->getStepRecurrence(*this);
11443 auto *One = getOne(Step->getType());
11444 auto *MinusOne = getNegativeSCEV(One);
11445 if (Step != One && Step != MinusOne)
11446 return std::nullopt;
11447
11448 // Type mismatch here means that MaxIter is potentially larger than max
11449 // unsigned value in start type, which means we cannot prove no wrap for the
11450 // indvar.
11451 if (AR->getType() != MaxIter->getType())
11452 return std::nullopt;
11453
11454 // Value of IV on suggested last iteration.
11455 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11456 // Does it still meet the requirement?
11457 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11458 return std::nullopt;
11459 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11460 // not exceed max unsigned value of this type), this effectively proves
11461 // that there is no wrap during the iteration. To prove that there is no
11462 // signed/unsigned wrap, we need to check that
11463 // Start <= Last for step = 1 or Start >= Last for step = -1.
11464 ICmpInst::Predicate NoOverflowPred =
11465 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11466 if (Step == MinusOne)
11467 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11468 const SCEV *Start = AR->getStart();
11469 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11470 return std::nullopt;
11471
11472 // Everything is fine.
11473 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11474}
11475
11476bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11477 const SCEV *LHS,
11478 const SCEV *RHS) {
11479 if (HasSameValue(LHS, RHS))
11480 return ICmpInst::isTrueWhenEqual(Pred);
11481
11482 auto CheckRange = [&](bool IsSigned) {
11483 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11484 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11485 return RangeLHS.icmp(Pred, RangeRHS);
11486 };
11487
11488 // The check at the top of the function catches the case where the values are
11489 // known to be equal.
11490 if (Pred == CmpInst::ICMP_EQ)
11491 return false;
11492
11493 if (Pred == CmpInst::ICMP_NE) {
11494 if (CheckRange(true) || CheckRange(false))
11495 return true;
11496 auto *Diff = getMinusSCEV(LHS, RHS);
11497 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11498 }
11499
11500 return CheckRange(CmpInst::isSigned(Pred));
11501}
11502
11503bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11504 const SCEV *LHS,
11505 const SCEV *RHS) {
11506 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11507 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11508 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11509 // OutC1 and OutC2.
11510 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11511 APInt &OutC1, APInt &OutC2,
11512 SCEV::NoWrapFlags ExpectedFlags) {
11513 const SCEV *XNonConstOp, *XConstOp;
11514 const SCEV *YNonConstOp, *YConstOp;
11515 SCEV::NoWrapFlags XFlagsPresent;
11516 SCEV::NoWrapFlags YFlagsPresent;
11517
11518 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11519 XConstOp = getZero(X->getType());
11520 XNonConstOp = X;
11521 XFlagsPresent = ExpectedFlags;
11522 }
11523 if (!isa<SCEVConstant>(XConstOp))
11524 return false;
11525
11526 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11527 YConstOp = getZero(Y->getType());
11528 YNonConstOp = Y;
11529 YFlagsPresent = ExpectedFlags;
11530 }
11531
11532 if (YNonConstOp != XNonConstOp)
11533 return false;
11534
11535 if (!isa<SCEVConstant>(YConstOp))
11536 return false;
11537
11538 // When matching ADDs with NUW flags (and unsigned predicates), only the
11539 // second ADD (with the larger constant) requires NUW.
11540 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11541 return false;
11542 if (ExpectedFlags != SCEV::FlagNUW &&
11543 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11544 return false;
11545 }
11546
11547 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11548 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11549
11550 return true;
11551 };
11552
11553 APInt C1;
11554 APInt C2;
11555
11556 switch (Pred) {
11557 default:
11558 break;
11559
11560 case ICmpInst::ICMP_SGE:
11561 std::swap(LHS, RHS);
11562 [[fallthrough]];
11563 case ICmpInst::ICMP_SLE:
11564 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11565 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11566 return true;
11567
11568 break;
11569
11570 case ICmpInst::ICMP_SGT:
11571 std::swap(LHS, RHS);
11572 [[fallthrough]];
11573 case ICmpInst::ICMP_SLT:
11574 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11575 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11576 return true;
11577
11578 break;
11579
11580 case ICmpInst::ICMP_UGE:
11581 std::swap(LHS, RHS);
11582 [[fallthrough]];
11583 case ICmpInst::ICMP_ULE:
11584 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11585 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11586 return true;
11587
11588 break;
11589
11590 case ICmpInst::ICMP_UGT:
11591 std::swap(LHS, RHS);
11592 [[fallthrough]];
11593 case ICmpInst::ICMP_ULT:
11594 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11595 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11596 return true;
11597 break;
11598 }
11599
11600 return false;
11601}
11602
11603bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11604 const SCEV *LHS,
11605 const SCEV *RHS) {
11606 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11607 return false;
11608
11609 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11610 // on the stack can result in exponential time complexity.
11611 SaveAndRestore Restore(ProvingSplitPredicate, true);
11612
11613 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11614 //
11615 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11616 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11617 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11618 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11619 // use isKnownPredicate later if needed.
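// Illustrative example, in i8: if L = 100 (so L >= 0), then "I u< 100" holds
// exactly when "I s>= 0" and "I s< 100", because any negative I maps to an
// unsigned value of at least 128, which is not u< 100.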
11620 return isKnownNonNegative(RHS) &&
11621 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11622 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11623}
11624
11625bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11626 const SCEV *LHS, const SCEV *RHS) {
11627 // No need to even try if we know the module has no guards.
11628 if (!HasGuards)
11629 return false;
11630
11631 return any_of(*BB, [&](const Instruction &I) {
11632 using namespace llvm::PatternMatch;
11633
11634 Value *Condition;
11635 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11636 m_Value(Condition))) &&
11637 isImpliedCond(Pred, LHS, RHS, Condition, false);
11638 });
11639}
11640
11641/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11642/// protected by a conditional between LHS and RHS. This is used to
11643/// to eliminate casts.
11644bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11645 CmpPredicate Pred,
11646 const SCEV *LHS,
11647 const SCEV *RHS) {
11648 // Interpret a null as meaning no loop, where there is obviously no guard
11649 // (interprocedural conditions notwithstanding). Do not bother about
11650 // unreachable loops.
11651 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11652 return true;
11653
11654 if (VerifyIR)
11655 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11656 "This cannot be done on broken IR!");
11657
11658
11659 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11660 return true;
11661
11662 BasicBlock *Latch = L->getLoopLatch();
11663 if (!Latch)
11664 return false;
11665
11666 BranchInst *LoopContinuePredicate =
11667 dyn_cast<BranchInst>(Latch->getTerminator());
11668 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11669 isImpliedCond(Pred, LHS, RHS,
11670 LoopContinuePredicate->getCondition(),
11671 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11672 return true;
11673
11674 // We don't want more than one activation of the following loops on the stack
11675 // -- that can lead to O(n!) time complexity.
11676 if (WalkingBEDominatingConds)
11677 return false;
11678
11679 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11680
11681 // See if we can exploit a trip count to prove the predicate.
11682 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11683 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11684 if (LatchBECount != getCouldNotCompute()) {
11685 // We know that Latch branches back to the loop header exactly
11686 // LatchBECount times. This means the backedge condition at Latch is
11687 // equivalent to "{0,+,1} u< LatchBECount".
11688 Type *Ty = LatchBECount->getType();
11689 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11690 const SCEV *LoopCounter =
11691 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11692 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11693 LatchBECount))
11694 return true;
11695 }
11696
11697 // Check conditions due to any @llvm.assume intrinsics.
11698 for (auto &AssumeVH : AC.assumptions()) {
11699 if (!AssumeVH)
11700 continue;
11701 auto *CI = cast<CallInst>(AssumeVH);
11702 if (!DT.dominates(CI, Latch->getTerminator()))
11703 continue;
11704
11705 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11706 return true;
11707 }
11708
11709 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11710 return true;
11711
11712 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11713 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11714 assert(DTN && "should reach the loop header before reaching the root!");
11715
11716 BasicBlock *BB = DTN->getBlock();
11717 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11718 return true;
11719
11720 BasicBlock *PBB = BB->getSinglePredecessor();
11721 if (!PBB)
11722 continue;
11723
11724 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11725 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11726 continue;
11727
11728 Value *Condition = ContinuePredicate->getCondition();
11729
11730 // If we have an edge `E` within the loop body that dominates the only
11731 // latch, the condition guarding `E` also guards the backedge. This
11732 // reasoning works only for loops with a single latch.
11733
11734 BasicBlockEdge DominatingEdge(PBB, BB);
11735 if (DominatingEdge.isSingleEdge()) {
11736 // We're constructively (and conservatively) enumerating edges within the
11737 // loop body that dominate the latch. The dominator tree better agree
11738 // with us on this:
11739 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11740
11741 if (isImpliedCond(Pred, LHS, RHS, Condition,
11742 BB != ContinuePredicate->getSuccessor(0)))
11743 return true;
11744 }
11745 }
11746
11747 return false;
11748}
11749
11750bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11751 CmpPredicate Pred,
11752 const SCEV *LHS,
11753 const SCEV *RHS) {
11754 // Do not bother proving facts for unreachable code.
11755 if (!DT.isReachableFromEntry(BB))
11756 return true;
11757 if (VerifyIR)
11758 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11759 "This cannot be done on broken IR!");
11760
11761 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11762 // the facts (a >= b && a != b) separately. A typical situation is when the
11763 // non-strict comparison is known from ranges and non-equality is known from
11764 // dominating predicates. If we are proving strict comparison, we always try
11765 // to prove non-equality and non-strict comparison separately.
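// Illustrative example: to prove "a s< b" it is enough to prove "a s<= b"
// (e.g. from constant ranges) together with "a != b" (e.g. from a dominating
// condition).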
11766 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11767 const bool ProvingStrictComparison =
11768 Pred != NonStrictPredicate.dropSameSign();
11769 bool ProvedNonStrictComparison = false;
11770 bool ProvedNonEquality = false;
11771
11772 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11773 if (!ProvedNonStrictComparison)
11774 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11775 if (!ProvedNonEquality)
11776 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11777 if (ProvedNonStrictComparison && ProvedNonEquality)
11778 return true;
11779 return false;
11780 };
11781
11782 if (ProvingStrictComparison) {
11783 auto ProofFn = [&](CmpPredicate P) {
11784 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11785 };
11786 if (SplitAndProve(ProofFn))
11787 return true;
11788 }
11789
11790 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11791 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11792 const Instruction *CtxI = &BB->front();
11793 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11794 return true;
11795 if (ProvingStrictComparison) {
11796 auto ProofFn = [&](CmpPredicate P) {
11797 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11798 };
11799 if (SplitAndProve(ProofFn))
11800 return true;
11801 }
11802 return false;
11803 };
11804
11805 // Starting at the block's predecessor, climb up the predecessor chain, as long
11806 // as there are predecessors that can be found that have unique successors
11807 // leading to the original block.
11808 const Loop *ContainingLoop = LI.getLoopFor(BB);
11809 const BasicBlock *PredBB;
11810 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11811 PredBB = ContainingLoop->getLoopPredecessor();
11812 else
11813 PredBB = BB->getSinglePredecessor();
11814 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11815 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11816 const BranchInst *BlockEntryPredicate =
11817 dyn_cast<BranchInst>(Pair.first->getTerminator());
11818 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11819 continue;
11820
11821 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11822 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11823 return true;
11824 }
11825
11826 // Check conditions due to any @llvm.assume intrinsics.
11827 for (auto &AssumeVH : AC.assumptions()) {
11828 if (!AssumeVH)
11829 continue;
11830 auto *CI = cast<CallInst>(AssumeVH);
11831 if (!DT.dominates(CI, BB))
11832 continue;
11833
11834 if (ProveViaCond(CI->getArgOperand(0), false))
11835 return true;
11836 }
11837
11838 // Check conditions due to any @llvm.experimental.guard intrinsics.
11839 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11840 F.getParent(), Intrinsic::experimental_guard);
11841 if (GuardDecl)
11842 for (const auto *GU : GuardDecl->users())
11843 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11844 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11845 if (ProveViaCond(Guard->getArgOperand(0), false))
11846 return true;
11847 return false;
11848}
11849
11850bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11851 const SCEV *LHS,
11852 const SCEV *RHS) {
11853 // Interpret a null as meaning no loop, where there is obviously no guard
11854 // (interprocedural conditions notwithstanding).
11855 if (!L)
11856 return false;
11857
11858 // Both LHS and RHS must be available at loop entry.
11859 assert(isAvailableAtLoopEntry(LHS, L) &&
11860 "LHS is not available at Loop Entry");
11861 assert(isAvailableAtLoopEntry(RHS, L) &&
11862 "RHS is not available at Loop Entry");
11863
11864 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11865 return true;
11866
11867 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11868}
11869
11870bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11871 const SCEV *RHS,
11872 const Value *FoundCondValue, bool Inverse,
11873 const Instruction *CtxI) {
11874 // A false condition implies anything. Do not bother analyzing it further.
11875 if (FoundCondValue ==
11876 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11877 return true;
11878
11879 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11880 return false;
11881
11882 auto ClearOnExit =
11883 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11884
11885 // Recursively handle And and Or conditions.
11886 const Value *Op0, *Op1;
11887 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11888 if (!Inverse)
11889 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11890 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11891 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11892 if (Inverse)
11893 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11894 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11895 }
11896
11897 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11898 if (!ICI) return false;
11899
11900 // We have now found a conditional branch that dominates the loop or controls
11901 // the loop latch. Check to see if it is the comparison we are looking for.
11902 CmpPredicate FoundPred;
11903 if (Inverse)
11904 FoundPred = ICI->getInverseCmpPredicate();
11905 else
11906 FoundPred = ICI->getCmpPredicate();
11907
11908 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11909 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11910
11911 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11912}
11913
11914bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11915 const SCEV *RHS, CmpPredicate FoundPred,
11916 const SCEV *FoundLHS, const SCEV *FoundRHS,
11917 const Instruction *CtxI) {
11918 // Balance the types.
11919 if (getTypeSizeInBits(LHS->getType()) <
11920 getTypeSizeInBits(FoundLHS->getType())) {
11921 // For unsigned and equality predicates, try to prove that both found
11922 // operands fit into narrow unsigned range. If so, try to prove facts in
11923 // narrow types.
11924 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11925 !FoundRHS->getType()->isPointerTy()) {
11926 auto *NarrowType = LHS->getType();
11927 auto *WideType = FoundLHS->getType();
11928 auto BitWidth = getTypeSizeInBits(NarrowType);
11929 const SCEV *MaxValue = getZeroExtendExpr(
11930 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11931 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11932 MaxValue) &&
11933 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11934 MaxValue)) {
11935 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11936 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11937 // We cannot preserve samesign after truncation.
11938 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11939 TruncFoundLHS, TruncFoundRHS, CtxI))
11940 return true;
11941 }
11942 }
11943
11944 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11945 return false;
11946 if (CmpInst::isSigned(Pred)) {
11947 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11948 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11949 } else {
11950 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11951 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11952 }
11953 } else if (getTypeSizeInBits(LHS->getType()) >
11954 getTypeSizeInBits(FoundLHS->getType())) {
11955 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11956 return false;
11957 if (CmpInst::isSigned(FoundPred)) {
11958 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11959 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11960 } else {
11961 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11962 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11963 }
11964 }
11965 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11966 FoundRHS, CtxI);
11967}
11968
11969bool ScalarEvolution::isImpliedCondBalancedTypes(
11970 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11971 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11972 assert(getTypeSizeInBits(LHS->getType()) ==
11973 getTypeSizeInBits(FoundLHS->getType()) &&
11974 "Types should be balanced!");
11975 // Canonicalize the query to match the way instcombine will have
11976 // canonicalized the comparison.
11977 if (SimplifyICmpOperands(Pred, LHS, RHS))
11978 if (LHS == RHS)
11979 return CmpInst::isTrueWhenEqual(Pred);
11980 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11981 if (FoundLHS == FoundRHS)
11982 return CmpInst::isFalseWhenEqual(FoundPred);
11983
11984 // Check to see if we can make the LHS or RHS match.
11985 if (LHS == FoundRHS || RHS == FoundLHS) {
11986 if (isa<SCEVConstant>(RHS)) {
11987 std::swap(FoundLHS, FoundRHS);
11988 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11989 } else {
11990 std::swap(LHS, RHS);
11991 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11992 }
11993 }
11994
11995 // Check whether the found predicate is the same as the desired predicate.
11996 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
11997 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11998
11999 // Check whether swapping the found predicate makes it the same as the
12000 // desired predicate.
12001 if (auto P = CmpPredicate::getMatching(
12002 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12003 // We can write the implication
12004 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12005 // using one of the following ways:
12006 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12007 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12008 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12009 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12010 // Forms 1. and 2. require swapping the operands of one condition. Don't
12011 // do this if it would break canonical constant/addrec ordering.
12012 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12013 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12014 LHS, FoundLHS, FoundRHS, CtxI);
12015 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12016 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12017
12018 // There's no clear preference between forms 3. and 4., try both. Avoid
12019 // forming getNotSCEV of pointer values as the resulting subtract is
12020 // not legal.
12021 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12022 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12023 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12024 FoundRHS, CtxI))
12025 return true;
12026
12027 if (!FoundLHS->getType()->isPointerTy() &&
12028 !FoundRHS->getType()->isPointerTy() &&
12029 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12030 getNotSCEV(FoundRHS), CtxI))
12031 return true;
12032
12033 return false;
12034 }
12035
12036 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12037 CmpInst::Predicate P2) {
12038 assert(P1 != P2 && "Handled earlier!");
12039 return CmpInst::isRelational(P2) &&
12040 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
12041 };
12042 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12043 // Unsigned comparison is the same as signed comparison when both the
12044 // operands are non-negative or negative.
12045 if (haveSameSign(FoundLHS, FoundRHS))
12046 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12047 // Create local copies that we can freely swap and canonicalize our
12048 // conditions to "le/lt".
12049 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12050 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12051 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12052 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12053 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12054 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12055 std::swap(CanonicalLHS, CanonicalRHS);
12056 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12057 }
12058 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12059 "Must be!");
12060 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12061 ICmpInst::isLE(CanonicalFoundPred)) &&
12062 "Must be!");
12063 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12064 // Use implication:
12065 // x <u y && y >=s 0 --> x <s y.
12066 // If we can prove the left part, the right part is also proven.
12067 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12068 CanonicalRHS, CanonicalFoundLHS,
12069 CanonicalFoundRHS);
12070 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12071 // Use implication:
12072 // x <s y && y <s 0 --> x <u y.
12073 // If we can prove the left part, the right part is also proven.
12074 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12075 CanonicalRHS, CanonicalFoundLHS,
12076 CanonicalFoundRHS);
12077 }
12078
12079 // Check if we can make progress by sharpening ranges.
12080 if (FoundPred == ICmpInst::ICMP_NE &&
12081 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12082
12083 const SCEVConstant *C = nullptr;
12084 const SCEV *V = nullptr;
12085
12086 if (isa<SCEVConstant>(FoundLHS)) {
12087 C = cast<SCEVConstant>(FoundLHS);
12088 V = FoundRHS;
12089 } else {
12090 C = cast<SCEVConstant>(FoundRHS);
12091 V = FoundLHS;
12092 }
12093
12094 // The guarding predicate tells us that C != V. If the known range
12095 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12096 // range we consider has to correspond to same signedness as the
12097 // predicate we're interested in folding.
12098
12099 APInt Min = ICmpInst::isSigned(Pred) ?
12100 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12101
12102 if (Min == C->getAPInt()) {
12103 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12104 // This is true even if (Min + 1) wraps around -- in case of
12105 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
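// Illustrative example, in i8 with a signed predicate: if the range minimum
// of V is -128 and the guard tells us V != -128, then "V s>= -128" (always
// true) sharpens to "V s>= -127", i.e. SharperMin below is -127.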
12106
12107 APInt SharperMin = Min + 1;
12108
12109 switch (Pred) {
12110 case ICmpInst::ICMP_SGE:
12111 case ICmpInst::ICMP_UGE:
12112 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12113 // RHS, we're done.
12114 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12115 CtxI))
12116 return true;
12117 [[fallthrough]];
12118
12119 case ICmpInst::ICMP_SGT:
12120 case ICmpInst::ICMP_UGT:
12121 // We know from the range information that (V `Pred` Min ||
12122 // V == Min). We know from the guarding condition that !(V
12123 // == Min). This gives us
12124 //
12125 // V `Pred` Min || V == Min && !(V == Min)
12126 // => V `Pred` Min
12127 //
12128 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12129
12130 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12131 return true;
12132 break;
12133
12134 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12135 case ICmpInst::ICMP_SLE:
12136 case ICmpInst::ICMP_ULE:
12137 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12138 LHS, V, getConstant(SharperMin), CtxI))
12139 return true;
12140 [[fallthrough]];
12141
12142 case ICmpInst::ICMP_SLT:
12143 case ICmpInst::ICMP_ULT:
12144 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12145 LHS, V, getConstant(Min), CtxI))
12146 return true;
12147 break;
12148
12149 default:
12150 // No change
12151 break;
12152 }
12153 }
12154 }
12155
12156 // Check whether the actual condition is beyond sufficient.
12157 if (FoundPred == ICmpInst::ICMP_EQ)
12158 if (ICmpInst::isTrueWhenEqual(Pred))
12159 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12160 return true;
12161 if (Pred == ICmpInst::ICMP_NE)
12162 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12163 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12164 return true;
12165
12166 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12167 return true;
12168
12169 // Otherwise assume the worst.
12170 return false;
12171}
12172
12173bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12174 const SCEV *&L, const SCEV *&R,
12175 SCEV::NoWrapFlags &Flags) {
12176 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12177 return false;
12178
12179 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12180 return true;
12181}
12182
12183std::optional<APInt>
12185 // We avoid subtracting expressions here because this function is usually
12186 // fairly deep in the call stack (i.e. is called many times).
12187
12188 unsigned BW = getTypeSizeInBits(More->getType());
12189 APInt Diff(BW, 0);
12190 APInt DiffMul(BW, 1);
12191 // Try various simplifications to reduce the difference to a constant. Limit
12192 // the number of allowed simplifications to keep compile-time low.
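// Illustrative example: for More = (%x + 7) and Less = (%x + 3) the loop
// below cancels %x and returns 7 - 3 = 4; the same holds for {%x+7,+,1}
// versus {%x+3,+,1}, which are first reduced to their start values.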
12193 for (unsigned I = 0; I < 8; ++I) {
12194 if (More == Less)
12195 return Diff;
12196
12197 // Reduce addrecs with identical steps to their start value.
12198 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12199 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12200 const auto *MAR = cast<SCEVAddRecExpr>(More);
12201
12202 if (LAR->getLoop() != MAR->getLoop())
12203 return std::nullopt;
12204
12205 // We look at affine expressions only; not for correctness but to keep
12206 // getStepRecurrence cheap.
12207 if (!LAR->isAffine() || !MAR->isAffine())
12208 return std::nullopt;
12209
12210 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12211 return std::nullopt;
12212
12213 Less = LAR->getStart();
12214 More = MAR->getStart();
12215 continue;
12216 }
12217
12218 // Try to match a common constant multiply.
12219 auto MatchConstMul =
12220 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12221 const APInt *C;
12222 const SCEV *Op;
12223 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12224 return {{Op, *C}};
12225 return std::nullopt;
12226 };
12227 if (auto MatchedMore = MatchConstMul(More)) {
12228 if (auto MatchedLess = MatchConstMul(Less)) {
12229 if (MatchedMore->second == MatchedLess->second) {
12230 More = MatchedMore->first;
12231 Less = MatchedLess->first;
12232 DiffMul *= MatchedMore->second;
12233 continue;
12234 }
12235 }
12236 }
12237
12238 // Try to cancel out common factors in two add expressions.
12239 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12240 auto Add = [&](const SCEV *S, int Mul) {
12241 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12242 if (Mul == 1) {
12243 Diff += C->getAPInt() * DiffMul;
12244 } else {
12245 assert(Mul == -1);
12246 Diff -= C->getAPInt() * DiffMul;
12247 }
12248 } else
12249 Multiplicity[S] += Mul;
12250 };
12251 auto Decompose = [&](const SCEV *S, int Mul) {
12252 if (isa<SCEVAddExpr>(S)) {
12253 for (const SCEV *Op : S->operands())
12254 Add(Op, Mul);
12255 } else
12256 Add(S, Mul);
12257 };
12258 Decompose(More, 1);
12259 Decompose(Less, -1);
12260
12261 // Check whether all the non-constants cancel out, or reduce to new
12262 // More/Less values.
12263 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12264 for (const auto &[S, Mul] : Multiplicity) {
12265 if (Mul == 0)
12266 continue;
12267 if (Mul == 1) {
12268 if (NewMore)
12269 return std::nullopt;
12270 NewMore = S;
12271 } else if (Mul == -1) {
12272 if (NewLess)
12273 return std::nullopt;
12274 NewLess = S;
12275 } else
12276 return std::nullopt;
12277 }
12278
12279 // Values stayed the same, no point in trying further.
12280 if (NewMore == More || NewLess == Less)
12281 return std::nullopt;
12282
12283 More = NewMore;
12284 Less = NewLess;
12285
12286 // Reduced to constant.
12287 if (!More && !Less)
12288 return Diff;
12289
12290 // Left with variable on only one side, bail out.
12291 if (!More || !Less)
12292 return std::nullopt;
12293 }
12294
12295 // Did not reduce to constant.
12296 return std::nullopt;
12297}
12298
12299bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12300 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12301 const SCEV *FoundRHS, const Instruction *CtxI) {
12302 // Try to recognize the following pattern:
12303 //
12304 // FoundRHS = ...
12305 // ...
12306 // loop:
12307 // FoundLHS = {Start,+,W}
12308 // context_bb: // Basic block from the same loop
12309 // known(Pred, FoundLHS, FoundRHS)
12310 //
12311 // If some predicate is known in the context of a loop, it is also known on
12312 // each iteration of this loop, including the first iteration. Therefore, in
12313 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12314 // prove the original pred using this fact.
12315 if (!CtxI)
12316 return false;
12317 const BasicBlock *ContextBB = CtxI->getParent();
12318 // Make sure AR varies in the context block.
12319 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12320 const Loop *L = AR->getLoop();
12321 // Make sure that context belongs to the loop and executes on 1st iteration
12322 // (if it ever executes at all).
12323 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12324 return false;
12325 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12326 return false;
12327 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12328 }
12329
12330 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12331 const Loop *L = AR->getLoop();
12332 // Make sure that context belongs to the loop and executes on 1st iteration
12333 // (if it ever executes at all).
12334 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12335 return false;
12336 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12337 return false;
12338 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12339 }
12340
12341 return false;
12342}
12343
12344bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12345 const SCEV *LHS,
12346 const SCEV *RHS,
12347 const SCEV *FoundLHS,
12348 const SCEV *FoundRHS) {
12349 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12350 return false;
12351
12352 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12353 if (!AddRecLHS)
12354 return false;
12355
12356 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12357 if (!AddRecFoundLHS)
12358 return false;
12359
12360 // We'd like to let SCEV reason about control dependencies, so we constrain
12361 // both the inequalities to be about add recurrences on the same loop. This
12362 // way we can use isLoopEntryGuardedByCond later.
12363
12364 const Loop *L = AddRecFoundLHS->getLoop();
12365 if (L != AddRecLHS->getLoop())
12366 return false;
12367
12368 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12369 //
12370 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12371 // ... (2)
12372 //
12373 // Informal proof for (2), assuming (1) [*]:
12374 //
12375 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12376 //
12377 // Then
12378 //
12379 // FoundLHS s< FoundRHS s< INT_MIN - C
12380 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12381 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12382 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12383 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12384 // <=> FoundLHS + C s< FoundRHS + C
12385 //
12386 // [*]: (1) can be proved by ruling out overflow.
12387 //
12388 // [**]: This can be proved by analyzing all the four possibilities:
12389 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12390 // (A s>= 0, B s>= 0).
12391 //
12392 // Note:
12393 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12394 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12395 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12396 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12397 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12398 // C)".
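// Illustrative example of (1), in i8 unsigned arithmetic with C = 100 (so
// -C = 156): from 10 u< 20 u< 156 we get (10 + 100) u< (20 + 100), i.e.
// 110 u< 120, and neither addition wraps because both operands are u< -C.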
12399
12400 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12401 if (!LDiff)
12402 return false;
12403 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12404 if (!RDiff || *LDiff != *RDiff)
12405 return false;
12406
12407 if (LDiff->isMinValue())
12408 return true;
12409
12410 APInt FoundRHSLimit;
12411
12412 if (Pred == CmpInst::ICMP_ULT) {
12413 FoundRHSLimit = -(*RDiff);
12414 } else {
12415 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12416 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12417 }
12418
12419 // Try to prove (1) or (2), as needed.
12420 return isAvailableAtLoopEntry(FoundRHS, L) &&
12421 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12422 getConstant(FoundRHSLimit));
12423}
12424
12425bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12426 const SCEV *RHS, const SCEV *FoundLHS,
12427 const SCEV *FoundRHS, unsigned Depth) {
12428 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12429
12430 auto ClearOnExit = make_scope_exit([&]() {
12431 if (LPhi) {
12432 bool Erased = PendingMerges.erase(LPhi);
12433 assert(Erased && "Failed to erase LPhi!");
12434 (void)Erased;
12435 }
12436 if (RPhi) {
12437 bool Erased = PendingMerges.erase(RPhi);
12438 assert(Erased && "Failed to erase RPhi!");
12439 (void)Erased;
12440 }
12441 });
12442
12443 // Find the respective Phis and check that they are not already pending.
12444 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12445 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12446 if (!PendingMerges.insert(Phi).second)
12447 return false;
12448 LPhi = Phi;
12449 }
12450 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12451 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12452 // If we detect a loop of Phi nodes being processed by this method, for
12453 // example:
12454 //
12455 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12456 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12457 //
12458 // we don't want to deal with a case that complex, so we return the
12459 // conservative answer false.
12460 if (!PendingMerges.insert(Phi).second)
12461 return false;
12462 RPhi = Phi;
12463 }
12464
12465 // If none of LHS, RHS is a Phi, nothing to do here.
12466 if (!LPhi && !RPhi)
12467 return false;
12468
12469 // If there is a SCEVUnknown Phi we are interested in, make it left.
12470 if (!LPhi) {
12471 std::swap(LHS, RHS);
12472 std::swap(FoundLHS, FoundRHS);
12473 std::swap(LPhi, RPhi);
12474 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12475 }
12476
12477 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12478 const BasicBlock *LBB = LPhi->getParent();
12479 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12480
12481 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12482 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12483 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12484 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12485 };
12486
12487 if (RPhi && RPhi->getParent() == LBB) {
12488 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12489 // If we compare two Phis from the same block, and for each entry block
12490 // the predicate is true for incoming values from this block, then the
12491 // predicate is also true for the Phis.
12492 for (const BasicBlock *IncBB : predecessors(LBB)) {
12493 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12494 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12495 if (!ProvedEasily(L, R))
12496 return false;
12497 }
12498 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12499 // Case two: RHS is an AddRec whose loop's header is LBB. It means that
12500 // there is a loop which has both an AddRec and Unknown PHIs; for it we can
12501 // compare the incoming values of the AddRec from above the loop and from
12502 // the latch with the respective incoming values of LPhi.
12503 // TODO: Generalize to handle loops with many inputs in a header.
12504 if (LPhi->getNumIncomingValues() != 2) return false;
12505
12506 auto *RLoop = RAR->getLoop();
12507 auto *Predecessor = RLoop->getLoopPredecessor();
12508 assert(Predecessor && "Loop with AddRec with no predecessor?");
12509 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12510 if (!ProvedEasily(L1, RAR->getStart()))
12511 return false;
12512 auto *Latch = RLoop->getLoopLatch();
12513 assert(Latch && "Loop with AddRec with no latch?");
12514 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12515 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12516 return false;
12517 } else {
12518 // In all other cases go over the inputs of LHS and compare each of them to
12519 // RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12520 // At this point RHS is either a non-Phi, or it is a Phi from some block
12521 // different from LBB.
12522 for (const BasicBlock *IncBB : predecessors(LBB)) {
12523 // Check that RHS is available in this block.
12524 if (!dominates(RHS, IncBB))
12525 return false;
12526 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12527 // Make sure L does not refer to a value from a potentially previous
12528 // iteration of a loop.
12529 if (!properlyDominates(L, LBB))
12530 return false;
12531 // Addrecs are considered to properly dominate their loop, so are missed
12532 // by the previous check. Discard any values that have computable
12533 // evolution in this loop.
12534 if (auto *Loop = LI.getLoopFor(LBB))
12535 if (hasComputableLoopEvolution(L, Loop))
12536 return false;
12537 if (!ProvedEasily(L, RHS))
12538 return false;
12539 }
12540 }
12541 return true;
12542}
12543
12544bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12545 const SCEV *LHS,
12546 const SCEV *RHS,
12547 const SCEV *FoundLHS,
12548 const SCEV *FoundRHS) {
12549 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12550 // sure that we are dealing with the same LHS.
12551 if (RHS == FoundRHS) {
12552 std::swap(LHS, RHS);
12553 std::swap(FoundLHS, FoundRHS);
12555 }
12556 if (LHS != FoundLHS)
12557 return false;
12558
12559 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12560 if (!SUFoundRHS)
12561 return false;
12562
12563 Value *Shiftee, *ShiftValue;
12564
12565 using namespace PatternMatch;
12566 if (match(SUFoundRHS->getValue(),
12567 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12568 auto *ShifteeS = getSCEV(Shiftee);
12569 // Prove one of the following:
12570 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12571 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12572 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12573 // ---> LHS <s RHS
12574 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12575 // ---> LHS <=s RHS
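    // A concrete instance of the first rule, with shiftvalue = 3: from
    // LHS <u (shiftee >> 3) and shiftee <=u RHS we get
    // LHS <u (shiftee >> 3) <=u shiftee <=u RHS, because a logical right shift
    // can only decrease (or preserve) an unsigned value.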
12576 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12577 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12578 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12579 if (isKnownNonNegative(ShifteeS))
12580 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12581 }
12582
12583 return false;
12584}
12585
12586bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12587 const SCEV *RHS,
12588 const SCEV *FoundLHS,
12589 const SCEV *FoundRHS,
12590 const Instruction *CtxI) {
12591 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12592 FoundRHS) ||
12593 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12594 FoundRHS) ||
12595 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12596 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12597 CtxI) ||
12598 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12599}
12600
12601/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12602template <typename MinMaxExprType>
12603static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12604 const SCEV *Candidate) {
12605 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12606 if (!MinMaxExpr)
12607 return false;
12608
12609 return is_contained(MinMaxExpr->operands(), Candidate);
12610}
12611
12612 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12613 CmpPredicate Pred, const SCEV *LHS,
12614 const SCEV *RHS) {
12615 // If both sides are affine addrecs for the same loop, with equal
12616 // steps, and we know the recurrences don't wrap, then we only
12617 // need to check the predicate on the starting values.
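// For example, with matching no-wrap flags, {A,+,S} and {B,+,S} over the same
// loop advance by the same amount on every iteration, so proving e.g. "A s< B"
// is enough to conclude "{A,+,S} s< {B,+,S}" for the whole recurrence.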
12618
12619 if (!ICmpInst::isRelational(Pred))
12620 return false;
12621
12622 const SCEV *LStart, *RStart, *Step;
12623 const Loop *L;
12624 if (!match(LHS,
12625 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12626 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12627 m_SpecificLoop(L))))
12628 return false;
12633 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12634 return false;
12635
12636 return SE.isKnownPredicate(Pred, LStart, RStart);
12637}
12638
12639 /// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12640/// expression?
12641 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12642 const SCEV *LHS, const SCEV *RHS) {
12643 switch (Pred) {
12644 default:
12645 return false;
12646
12647 case ICmpInst::ICMP_SGE:
12648 std::swap(LHS, RHS);
12649 [[fallthrough]];
12650 case ICmpInst::ICMP_SLE:
12651 return
12652 // min(A, ...) <= A
12653 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12654 // A <= max(A, ...)
12655 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12656
12657 case ICmpInst::ICMP_UGE:
12658 std::swap(LHS, RHS);
12659 [[fallthrough]];
12660 case ICmpInst::ICMP_ULE:
12661 return
12662 // min(A, ...) <= A
12663 // FIXME: what about umin_seq?
12664 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12665 // A <= max(A, ...)
12666 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12667 }
12668
12669 llvm_unreachable("covered switch fell through?!");
12670}
12671
12672bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12673 const SCEV *RHS,
12674 const SCEV *FoundLHS,
12675 const SCEV *FoundRHS,
12676 unsigned Depth) {
12677 assert(getTypeSizeInBits(LHS->getType()) ==
12678 getTypeSizeInBits(RHS->getType()) &&
12679 "LHS and RHS have different sizes?");
12680 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12681 getTypeSizeInBits(FoundRHS->getType()) &&
12682 "FoundLHS and FoundRHS have different sizes?");
12683 // We want to avoid hurting the compile time with analysis of too big trees.
12684 if (Depth > MaxSCEVOperationsImplicationDepth)
12685 return false;
12686
12687 // We only want to work with GT comparison so far.
12688 if (ICmpInst::isLT(Pred)) {
12690 std::swap(LHS, RHS);
12691 std::swap(FoundLHS, FoundRHS);
12692 }
12693
12695
12696 // For unsigned, try to reduce it to corresponding signed comparison.
12697 if (P == ICmpInst::ICMP_UGT)
12698 // We can replace unsigned predicate with its signed counterpart if all
12699 // involved values are non-negative.
12700 // TODO: We could have better support for unsigned.
12701 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12702 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12703 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12704 // use this fact to prove that LHS and RHS are non-negative.
12705 const SCEV *MinusOne = getMinusOne(LHS->getType());
12706 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12707 FoundRHS) &&
12708 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12709 FoundRHS))
12710 P = ICmpInst::ICMP_SGT;
12711 }
12712
12713 if (P != ICmpInst::ICMP_SGT)
12714 return false;
12715
12716 auto GetOpFromSExt = [&](const SCEV *S) {
12717 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12718 return Ext->getOperand();
12719 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12720 // the constant in some cases.
12721 return S;
12722 };
12723
12724 // Acquire values from extensions.
12725 auto *OrigLHS = LHS;
12726 auto *OrigFoundLHS = FoundLHS;
12727 LHS = GetOpFromSExt(LHS);
12728 FoundLHS = GetOpFromSExt(FoundLHS);
12729
12730 // Check whether the SGT predicate can be proved trivially or using the found context.
12731 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12732 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12733 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12734 FoundRHS, Depth + 1);
12735 };
12736
12737 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12738 // We want to avoid creation of any new non-constant SCEV. Since we are
12739 // going to compare the operands to RHS, we should be certain that we don't
12740 // need any size extensions for this. So let's decline all cases when the
12741 // sizes of types of LHS and RHS do not match.
12742 // TODO: Maybe try to get RHS from sext to catch more cases?
12743 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12744 return false;
12745
12746 // Should not overflow.
12747 if (!LHSAddExpr->hasNoSignedWrap())
12748 return false;
12749
12750 auto *LL = LHSAddExpr->getOperand(0);
12751 auto *LR = LHSAddExpr->getOperand(1);
12752 auto *MinusOne = getMinusOne(RHS->getType());
12753
12754 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12755 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12756 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12757 };
12758 // Try to prove the following rule:
12759 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12760 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
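    // A quick numeric check of the first form (illustrative values): with
    // LL = 2, LR = 10 and RHS = 7, LL >= 0 and LR > RHS give
    // LHS = LL + LR = 12 > 7 = RHS.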
12761 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12762 return true;
12763 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12764 Value *LL, *LR;
12765 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12766
12767 using namespace llvm::PatternMatch;
12768
12769 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12770 // Rules for division.
12771 // We are going to perform some comparisons with Denominator and its
12772 // derivative expressions. In general case, creating a SCEV for it may
12773 // lead to a complex analysis of the entire graph, and in particular it
12774 // can request trip count recalculation for the same loop. This would
12775 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12776 // this, we only want to create SCEVs that are constants in this section.
12777 // So we bail if Denominator is not a constant.
12778 if (!isa<ConstantInt>(LR))
12779 return false;
12780
12781 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12782
12783 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12784 // then a SCEV for the numerator already exists and matches with FoundLHS.
12785 auto *Numerator = getExistingSCEV(LL);
12786 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12787 return false;
12788
12789 // Make sure that the numerator matches with FoundLHS and the denominator
12790 // is positive.
12791 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12792 return false;
12793
12794 auto *DTy = Denominator->getType();
12795 auto *FRHSTy = FoundRHS->getType();
12796 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12797 // One of types is a pointer and another one is not. We cannot extend
12798 // them properly to a wider type, so let us just reject this case.
12799 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12800 // to avoid this check.
12801 return false;
12802
12803 // Given that:
12804 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12805 auto *WTy = getWiderType(DTy, FRHSTy);
12806 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12807 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12808
12809 // Try to prove the following rule:
12810 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12811 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12812 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12813 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12814 if (isKnownNonPositive(RHS) &&
12815 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12816 return true;
12817
12818 // Try to prove the following rule:
12819 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12820 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12821 // If we divide it by Denominator > 2, then:
12822 // 1. If FoundLHS is negative, then the result is 0.
12823 // 2. If FoundLHS is non-negative, then the result is non-negative.
12824 // Anyways, the result is non-negative.
12825 auto *MinusOne = getMinusOne(WTy);
12826 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12827 if (isKnownNegative(RHS) &&
12828 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12829 return true;
12830 }
12831 }
12832
12833 // If our expression contained SCEVUnknown Phis, and we split it down and now
12834 // need to prove something for them, try to prove the predicate for every
12835 // possible incoming values of those Phis.
12836 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12837 return true;
12838
12839 return false;
12840}
12841
12843 const SCEV *RHS) {
12844 // zext x u<= sext x, sext x s<= zext x
12845 const SCEV *Op;
12846 switch (Pred) {
12847 case ICmpInst::ICMP_SGE:
12848 std::swap(LHS, RHS);
12849 [[fallthrough]];
12850 case ICmpInst::ICMP_SLE: {
12851 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12852 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12853 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12854 }
12855 case ICmpInst::ICMP_UGE:
12856 std::swap(LHS, RHS);
12857 [[fallthrough]];
12858 case ICmpInst::ICMP_ULE: {
12859 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12860 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12861 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12862 }
12863 default:
12864 return false;
12865 };
12866 llvm_unreachable("unhandled case");
12867}
12868
12869bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12870 const SCEV *LHS,
12871 const SCEV *RHS) {
12872 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12873 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12874 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12875 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12876 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12877}
12878
12879bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12880 const SCEV *LHS,
12881 const SCEV *RHS,
12882 const SCEV *FoundLHS,
12883 const SCEV *FoundRHS) {
12884 switch (Pred) {
12885 default:
12886 llvm_unreachable("Unexpected CmpPredicate value!");
12887 case ICmpInst::ICMP_EQ:
12888 case ICmpInst::ICMP_NE:
12889 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12890 return true;
12891 break;
12892 case ICmpInst::ICMP_SLT:
12893 case ICmpInst::ICMP_SLE:
12894 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12895 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12896 return true;
12897 break;
12898 case ICmpInst::ICMP_SGT:
12899 case ICmpInst::ICMP_SGE:
12900 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12901 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12902 return true;
12903 break;
12904 case ICmpInst::ICMP_ULT:
12905 case ICmpInst::ICMP_ULE:
12906 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12907 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12908 return true;
12909 break;
12910 case ICmpInst::ICMP_UGT:
12911 case ICmpInst::ICMP_UGE:
12912 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12913 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12914 return true;
12915 break;
12916 }
12917
12918 // Maybe it can be proved via operations?
12919 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12920 return true;
12921
12922 return false;
12923}
12924
12925bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12926 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12927 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12928 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12929 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12930 // reduce the compile time impact of this optimization.
12931 return false;
12932
12933 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12934 if (!Addend)
12935 return false;
12936
12937 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12938
12939 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12940 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12941 ConstantRange FoundLHSRange =
12942 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12943
12944 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12945 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
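  // As an illustration with made-up constants: if the antecedent is
  // "FoundLHS u< 10" and LHS = FoundLHS + 5 (Addend = 5), then FoundLHSRange is
  // [0, 10) and LHSRange is [5, 15), so a consequent such as "LHS u< 15" is
  // implied.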
12946
12947 // We can also compute the range of values for `LHS` that satisfy the
12948 // consequent, "`LHS` `Pred` `RHS`":
12949 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12950 // The antecedent implies the consequent if every value of `LHS` that
12951 // satisfies the antecedent also satisfies the consequent.
12952 return LHSRange.icmp(Pred, ConstRHS);
12953}
12954
12955bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12956 bool IsSigned) {
12957 assert(isKnownPositive(Stride) && "Positive stride expected!");
12958
12959 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12960 const SCEV *One = getOne(Stride->getType());
12961
12962 if (IsSigned) {
12963 APInt MaxRHS = getSignedRangeMax(RHS);
12964 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12965 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12966
12967 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
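    // For instance, for i8 with Stride = 2 and SMaxRHS = 127:
    // 127 + 1 > 127, so the IV could still be 126 (< RHS) and wrap past
    // SINT_MAX on the next increment.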
12968 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12969 }
12970
12971 APInt MaxRHS = getUnsignedRangeMax(RHS);
12972 APInt MaxValue = APInt::getMaxValue(BitWidth);
12973 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12974
12975 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12976 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12977}
12978
12979bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12980 bool IsSigned) {
12981
12982 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12983 const SCEV *One = getOne(Stride->getType());
12984
12985 if (IsSigned) {
12986 APInt MinRHS = getSignedRangeMin(RHS);
12987 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12988 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12989
12990 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12991 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12992 }
12993
12994 APInt MinRHS = getUnsignedRangeMin(RHS);
12995 APInt MinValue = APInt::getMinValue(BitWidth);
12996 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12997
12998 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12999 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13000}
13001
13002 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
13003 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13004 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13005 // expression fixes the case of N=0.
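  // For example, N = 7, D = 3: umin(7, 1) = 1 and floor((7 - 1) / 3) = 2,
  // giving 3 = ceil(7 / 3); for N = 0 both terms are 0.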
13006 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13007 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13008 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13009}
13010
13011const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13012 const SCEV *Stride,
13013 const SCEV *End,
13014 unsigned BitWidth,
13015 bool IsSigned) {
13016 // The logic in this function assumes we can represent a positive stride.
13017 // If we can't, the backedge-taken count must be zero.
13018 if (IsSigned && BitWidth == 1)
13019 return getZero(Stride->getType());
13020
13021 // The code below has only been closely audited for negative strides in the
13022 // unsigned comparison case; it may be correct for signed comparison, but
13023 // that needs to be established.
13024 if (IsSigned && isKnownNegative(Stride))
13025 return getCouldNotCompute();
13026
13027 // Calculate the maximum backedge count based on the range of values
13028 // permitted by Start, End, and Stride.
13029 APInt MinStart =
13030 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13031
13032 APInt MinStride =
13033 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13034
13035 // We assume either the stride is positive, or the backedge-taken count
13036 // is zero. So force StrideForMaxBECount to be at least one.
13037 APInt One(BitWidth, 1);
13038 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13039 : APIntOps::umax(One, MinStride);
13040
13041 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13042 : APInt::getMaxValue(BitWidth);
13043 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13044
13045 // Although End can be a MAX expression we estimate MaxEnd considering only
13046 // the case End = RHS of the loop termination condition. This is safe because
13047 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13048 // taken count.
13049 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13050 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13051
13052 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13053 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13054 : APIntOps::umax(MaxEnd, MinStart);
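  // A worked example with illustrative unsigned i8 values: MinStart = 10,
  // StrideForMaxBECount = 4 and getUnsignedRangeMax(End) = 250 give
  // Limit = 255 - 3 = 252, MaxEnd = 250, and a result of
  // ceil((250 - 10) / 4) = 60.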
13055
13056 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13057 getConstant(StrideForMaxBECount) /* Step */);
13058}
13059
13060 ScalarEvolution::ExitLimit
13061 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13062 const Loop *L, bool IsSigned,
13063 bool ControlsOnlyExit, bool AllowPredicates) {
13065
13066 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13067 bool PredicatedIV = false;
13068 if (!IV) {
13069 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13070 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13071 if (AR && AR->getLoop() == L && AR->isAffine()) {
13072 auto canProveNUW = [&]() {
13073 // We can use the comparison to infer no-wrap flags only if it fully
13074 // controls the loop exit.
13075 if (!ControlsOnlyExit)
13076 return false;
13077
13078 if (!isLoopInvariant(RHS, L))
13079 return false;
13080
13081 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13082 // We need the sequence defined by AR to strictly increase in the
13083 // unsigned integer domain for the logic below to hold.
13084 return false;
13085
13086 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13087 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13088 // If RHS <=u Limit, then there must exist a value V in the sequence
13089 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13090 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13091 // overflow occurs. This limit also implies that a signed comparison
13092 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13093 // the high bits on both sides must be zero.
13094 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13095 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13096 Limit = Limit.zext(OuterBitWidth);
13097 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13098 };
13099 auto Flags = AR->getNoWrapFlags();
13100 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13101 Flags = setFlags(Flags, SCEV::FlagNUW);
13102
13103 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13104 if (AR->hasNoUnsignedWrap()) {
13105 // Emulate what getZeroExtendExpr would have done during construction
13106 // if we'd been able to infer the fact just above at that time.
13107 const SCEV *Step = AR->getStepRecurrence(*this);
13108 Type *Ty = ZExt->getType();
13109 auto *S = getAddRecExpr(
13111 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13112 IV = dyn_cast<SCEVAddRecExpr>(S);
13113 }
13114 }
13115 }
13116 }
13117
13118
13119 if (!IV && AllowPredicates) {
13120 // Try to make this an AddRec using runtime tests, in the first X
13121 // iterations of this loop, where X is the SCEV expression found by the
13122 // algorithm below.
13123 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13124 PredicatedIV = true;
13125 }
13126
13127 // Avoid weird loops
13128 if (!IV || IV->getLoop() != L || !IV->isAffine())
13129 return getCouldNotCompute();
13130
13131 // A precondition of this method is that the condition being analyzed
13132 // reaches an exiting branch which dominates the latch. Given that, we can
13133 // assume that an increment which violates the nowrap specification and
13134 // produces poison must cause undefined behavior when the resulting poison
13135 // value is branched upon and thus we can conclude that the backedge is
13136 // taken no more often than would be required to produce that poison value.
13137 // Note that a well defined loop can exit on the iteration which violates
13138 // the nowrap specification if there is another exit (either explicit or
13139 // implicit/exceptional) which causes the loop to execute before the
13140 // exiting instruction we're analyzing would trigger UB.
13141 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13142 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13143 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13144
13145 const SCEV *Stride = IV->getStepRecurrence(*this);
13146
13147 bool PositiveStride = isKnownPositive(Stride);
13148
13149 // Avoid negative or zero stride values.
13150 if (!PositiveStride) {
13151 // We can compute the correct backedge taken count for loops with unknown
13152 // strides if we can prove that the loop is not an infinite loop with side
13153 // effects. Here's the loop structure we are trying to handle -
13154 //
13155 // i = start
13156 // do {
13157 // A[i] = i;
13158 // i += s;
13159 // } while (i < end);
13160 //
13161 // The backedge taken count for such loops is evaluated as -
13162 // (max(end, start + stride) - start - 1) /u stride
13163 //
13164 // The additional preconditions that we need to check to prove correctness
13165 // of the above formula is as follows -
13166 //
13167 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13168 // NoWrap flag).
13169 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13170 // no side effects within the loop)
13171 // c) loop has a single static exit (with no abnormal exits)
13172 //
13173 // Precondition a) implies that if the stride is negative, this is a single
13174 // trip loop. The backedge taken count formula reduces to zero in this case.
13175 //
13176 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13177 // then a zero stride means the backedge can't be taken without executing
13178 // undefined behavior.
13179 //
13180 // The positive stride case is the same as isKnownPositive(Stride) returning
13181 // true (original behavior of the function).
13182 //
13183 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13185 return getCouldNotCompute();
13186
13187 if (!isKnownNonZero(Stride)) {
13188 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13189 // if it might eventually be greater than start and if so, on which
13190 // iteration. We can't even produce a useful upper bound.
13191 if (!isLoopInvariant(RHS, L))
13192 return getCouldNotCompute();
13193
13194 // We allow a potentially zero stride, but we need to divide by stride
13195 // below. Since the loop can't be infinite and this check must control
13196 // the sole exit, we can infer the exit must be taken on the first
13197 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13198 // we know the numerator in the divides below must be zero, so we can
13199 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13200 // and produce the right result.
13201 // FIXME: Handle the case where Stride is poison?
13202 auto wouldZeroStrideBeUB = [&]() {
13203 // Proof by contradiction. Suppose the stride were zero. If we can
13204 // prove that the backedge *is* taken on the first iteration, then since
13205 // we know this condition controls the sole exit, we must have an
13206 // infinite loop. We can't have a (well defined) infinite loop per
13207 // check just above.
13208 // Note: The (Start - Stride) term is used to get the start' term from
13209 // (start' + stride,+,stride). Remember that we only care about the
13210 // result of this expression when stride == 0 at runtime.
13211 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13212 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13213 };
13214 if (!wouldZeroStrideBeUB()) {
13215 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13216 }
13217 }
13218 } else if (!NoWrap) {
13219 // Avoid proven overflow cases: this will ensure that the backedge taken
13220 // count will not generate any unsigned overflow.
13221 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13222 return getCouldNotCompute();
13223 }
13224
13225 // On all paths just preceding, we established the following invariant:
13226 // IV can be assumed not to overflow up to and including the exiting
13227 // iteration. We proved this in one of two ways:
13228 // 1) We can show overflow doesn't occur before the exiting iteration
13229 // 1a) canIVOverflowOnLT, and b) step of one
13230 // 2) We can show that if overflow occurs, the loop must execute UB
13231 // before any possible exit.
13232 // Note that we have not yet proved RHS invariant (in general).
13233
13234 const SCEV *Start = IV->getStart();
13235
13236 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13237 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13238 // Use integer-typed versions for actual computation; we can't subtract
13239 // pointers in general.
13240 const SCEV *OrigStart = Start;
13241 const SCEV *OrigRHS = RHS;
13242 if (Start->getType()->isPointerTy()) {
13243 Start = getLosslessPtrToIntExpr(Start);
13244 if (isa<SCEVCouldNotCompute>(Start))
13245 return Start;
13246 }
13247 if (RHS->getType()->isPointerTy()) {
13248 RHS = getLosslessPtrToIntExpr(RHS);
13249 if (isa<SCEVCouldNotCompute>(RHS))
13250 return RHS;
13251 }
13252
13253 const SCEV *End = nullptr, *BECount = nullptr,
13254 *BECountIfBackedgeTaken = nullptr;
13255 if (!isLoopInvariant(RHS, L)) {
13256 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13257 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13258 RHSAddRec->getNoWrapFlags()) {
13259 // The structure of loop we are trying to calculate backedge count of:
13260 //
13261 // left = left_start
13262 // right = right_start
13263 //
13264 // while(left < right){
13265 // ... do something here ...
13266 // left += s1; // stride of left is s1 (s1 > 0)
13267 // right += s2; // stride of right is s2 (s2 < 0)
13268 // }
13269 //
13270
13271 const SCEV *RHSStart = RHSAddRec->getStart();
13272 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13273
13274 // If Stride - RHSStride is positive and does not overflow, we can write
13275 // backedge count as ->
13276 // ceil((End - Start) /u (Stride - RHSStride))
13277 // Where, End = max(RHSStart, Start)
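      // As a concrete illustration with made-up values: left_start = 0,
      // s1 = 1, right_start = 10, s2 = -1. Then Stride - RHSStride = 2 and the
      // condition left < right holds for exactly ceil((10 - 0) / 2) = 5
      // iterations (i = 0..4).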
13278
13279 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13280 if (isKnownNegative(RHSStride) &&
13281 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13282 RHSStride)) {
13283
13284 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13285 if (isKnownPositive(Denominator)) {
13286 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13287 : getUMaxExpr(RHSStart, Start);
13288
13289 // We can do this because End >= Start, as End = max(RHSStart, Start)
13290 const SCEV *Delta = getMinusSCEV(End, Start);
13291
13292 BECount = getUDivCeilSCEV(Delta, Denominator);
13293 BECountIfBackedgeTaken =
13294 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13295 }
13296 }
13297 }
13298 if (BECount == nullptr) {
13299 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13300 // given the start, stride and max value for the end bound of the
13301 // loop (RHS), and the fact that IV does not overflow (which is
13302 // checked above).
13303 const SCEV *MaxBECount = computeMaxBECountForLT(
13304 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13305 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13306 MaxBECount, false /*MaxOrZero*/, Predicates);
13307 }
13308 } else {
13309 // We use the expression (max(End,Start)-Start)/Stride to describe the
13310 // backedge count, as if the backedge is taken at least once
13311 // max(End,Start) is End and so the result is as above, and if not
13312 // max(End,Start) is Start so we get a backedge count of zero.
13313 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13314 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13315 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13316 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13317 // Can we prove max(RHS,Start) > Start - Stride?
13318 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13319 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13320 // In this case, we can use a refined formula for computing backedge
13321 // taken count. The general formula remains:
13322 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13323 // We want to use the alternate formula:
13324 // "((End - 1) - (Start - Stride)) /u Stride"
13325 // Let's do a quick case analysis to show these are equivalent under
13326 // our precondition that max(RHS,Start) > Start - Stride.
13327 // * For RHS <= Start, the backedge-taken count must be zero.
13328 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13329 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13330 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13331 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13332 // reducing this to the stride of 1 case.
13333 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13334 // Stride".
13335 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13336 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13337 // "((RHS - (Start - Stride) - 1) /u Stride".
13338 // Our preconditions trivially imply no overflow in that form.
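      // A quick numeric check with illustrative values Start = 0, Stride = 4,
      // RHS = 10: ((10 - 1) - (0 - 4)) /u 4 = 13 /u 4 = 3, matching
      // ceil((10 - 0) / 4) = 3.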
13339 const SCEV *MinusOne = getMinusOne(Stride->getType());
13340 const SCEV *Numerator =
13341 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13342 BECount = getUDivExpr(Numerator, Stride);
13343 }
13344
13345 if (!BECount) {
13346 auto canProveRHSGreaterThanEqualStart = [&]() {
13347 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13348 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13349 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13350
13351 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13352 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13353 return true;
13354
13355 // (RHS > Start - 1) implies RHS >= Start.
13356 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13357 // "Start - 1" doesn't overflow.
13358 // * For signed comparison, if Start - 1 does overflow, it's equal
13359 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13360 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13361 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13362 //
13363 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13364 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13365 auto *StartMinusOne =
13366 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13367 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13368 };
13369
13370 // If we know that RHS >= Start in the context of loop, then we know
13371 // that max(RHS, Start) = RHS at this point.
13372 if (canProveRHSGreaterThanEqualStart()) {
13373 End = RHS;
13374 } else {
13375 // If RHS < Start, the backedge will be taken zero times. So in
13376 // general, we can write the backedge-taken count as:
13377 //
13378 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13379 //
13380 // We convert it to the following to make it more convenient for SCEV:
13381 //
13382 // ceil(max(RHS, Start) - Start) / Stride
13383 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13384
13385 // See what would happen if we assume the backedge is taken. This is
13386 // used to compute MaxBECount.
13387 BECountIfBackedgeTaken =
13388 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13389 }
13390
13391 // At this point, we know:
13392 //
13393 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13394 // 2. The index variable doesn't overflow.
13395 //
13396 // Therefore, we know N exists such that
13397 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13398 // doesn't overflow.
13399 //
13400 // Using this information, try to prove whether the addition in
13401 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13402 const SCEV *One = getOne(Stride->getType());
13403 bool MayAddOverflow = [&] {
13404 if (isKnownToBeAPowerOfTwo(Stride)) {
13405 // Suppose Stride is a power of two, and Start/End are unsigned
13406 // integers. Let UMAX be the largest representable unsigned
13407 // integer.
13408 //
13409 // By the preconditions of this function, we know
13410 // "(Start + Stride * N) >= End", and this doesn't overflow.
13411 // As a formula:
13412 //
13413 // End <= (Start + Stride * N) <= UMAX
13414 //
13415 // Subtracting Start from all the terms:
13416 //
13417 // End - Start <= Stride * N <= UMAX - Start
13418 //
13419 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13420 //
13421 // End - Start <= Stride * N <= UMAX
13422 //
13423 // Stride * N is a multiple of Stride. Therefore,
13424 //
13425 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13426 //
13427 // Since Stride is a power of two, UMAX + 1 is divisible by
13428 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13429 // write:
13430 //
13431 // End - Start <= Stride * N <= UMAX - Stride + 1
13432 //
13433 // Dropping the middle term:
13434 //
13435 // End - Start <= UMAX - Stride + 1
13436 //
13437 // Adding Stride - 1 to both sides:
13438 //
13439 // (End - Start) + (Stride - 1) <= UMAX
13440 //
13441 // In other words, the addition doesn't have unsigned overflow.
13442 //
13443 // A similar proof works if we treat Start/End as signed values.
13444 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13445 // to use signed max instead of unsigned max. Note that we're
13446 // trying to prove a lack of unsigned overflow in either case.
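        // A concrete instance with illustrative i8 values: Stride = 4,
        // UMAX = 255. The largest multiple of 4 that fits is 252 = 255 - 3, so
        // End - Start <= 252, and adding Stride - 1 = 3 stays within 255.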
13447 return false;
13448 }
13449 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13450 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13451 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13452 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13453 // 1 <s End.
13454 //
13455 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13456 // End.
13457 return false;
13458 }
13459 return true;
13460 }();
13461
13462 const SCEV *Delta = getMinusSCEV(End, Start);
13463 if (!MayAddOverflow) {
13464 // floor((D + (S - 1)) / S)
13465 // We prefer this formulation if it's legal because it's fewer
13466 // operations.
13467 BECount =
13468 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13469 } else {
13470 BECount = getUDivCeilSCEV(Delta, Stride);
13471 }
13472 }
13473 }
13474
13475 const SCEV *ConstantMaxBECount;
13476 bool MaxOrZero = false;
13477 if (isa<SCEVConstant>(BECount)) {
13478 ConstantMaxBECount = BECount;
13479 } else if (BECountIfBackedgeTaken &&
13480 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13481 // If we know exactly how many times the backedge will be taken if it's
13482 // taken at least once, then the backedge count will either be that or
13483 // zero.
13484 ConstantMaxBECount = BECountIfBackedgeTaken;
13485 MaxOrZero = true;
13486 } else {
13487 ConstantMaxBECount = computeMaxBECountForLT(
13488 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13489 }
13490
13491 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13492 !isa<SCEVCouldNotCompute>(BECount))
13493 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13494
13495 const SCEV *SymbolicMaxBECount =
13496 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13497 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13498 Predicates);
13499}
13500
13501ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13502 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13503 bool ControlsOnlyExit, bool AllowPredicates) {
13505 // We handle only IV > Invariant
13506 if (!isLoopInvariant(RHS, L))
13507 return getCouldNotCompute();
13508
13509 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13510 if (!IV && AllowPredicates)
13511 // Try to make this an AddRec using runtime tests, in the first X
13512 // iterations of this loop, where X is the SCEV expression found by the
13513 // algorithm below.
13514 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13515
13516 // Avoid weird loops
13517 if (!IV || IV->getLoop() != L || !IV->isAffine())
13518 return getCouldNotCompute();
13519
13520 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13521 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13522 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13523
13524 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13525
13526 // Avoid negative or zero stride values
13527 if (!isKnownPositive(Stride))
13528 return getCouldNotCompute();
13529
13530 // Avoid proven overflow cases: this will ensure that the backedge taken count
13531 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13532 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13533 // behaviors like the case of C language.
13534 if (!Stride->isOne() && !NoWrap)
13535 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13536 return getCouldNotCompute();
13537
13538 const SCEV *Start = IV->getStart();
13539 const SCEV *End = RHS;
13540 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13541 // If we know that Start >= RHS in the context of loop, then we know that
13542 // min(RHS, Start) = RHS at this point.
13543 if (isLoopEntryGuardedByCond(
13544 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13545 End = RHS;
13546 else
13547 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13548 }
13549
13550 if (Start->getType()->isPointerTy()) {
13551 Start = getLosslessPtrToIntExpr(Start);
13552 if (isa<SCEVCouldNotCompute>(Start))
13553 return Start;
13554 }
13555 if (End->getType()->isPointerTy()) {
13556 End = getLosslessPtrToIntExpr(End);
13557 if (isa<SCEVCouldNotCompute>(End))
13558 return End;
13559 }
13560
13561 // Compute ((Start - End) + (Stride - 1)) / Stride.
13562 // FIXME: This can overflow. Holding off on fixing this for now;
13563 // howManyGreaterThans will hopefully be gone soon.
13564 const SCEV *One = getOne(Stride->getType());
13565 const SCEV *BECount = getUDivExpr(
13566 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13567
13568 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13569 : getUnsignedRangeMax(Start);
13570
13571 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13572 : getUnsignedRangeMin(Stride);
13573
13574 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13575 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13576 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13577
13578 // Although End can be a MIN expression we estimate MinEnd considering only
13579 // the case End = RHS. This is safe because in the other case (Start - End)
13580 // is zero, leading to a zero maximum backedge taken count.
13581 APInt MinEnd =
13582 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13583 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13584
13585 const SCEV *ConstantMaxBECount =
13586 isa<SCEVConstant>(BECount)
13587 ? BECount
13588 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13589 getConstant(MinStride));
13590
13591 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13592 ConstantMaxBECount = BECount;
13593 const SCEV *SymbolicMaxBECount =
13594 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13595
13596 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13597 Predicates);
13598}
13599
13600 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13601 ScalarEvolution &SE) const {
13602 if (Range.isFullSet()) // Infinite loop.
13603 return SE.getCouldNotCompute();
13604
13605 // If the start is a non-zero constant, shift the range to simplify things.
13606 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13607 if (!SC->getValue()->isZero()) {
13608 SmallVector<const SCEV *, 4> Operands(operands());
13609 Operands[0] = SE.getZero(SC->getType());
13610 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13612 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13613 return ShiftedAddRec->getNumIterationsInRange(
13614 Range.subtract(SC->getAPInt()), SE);
13615 // This is strange and shouldn't happen.
13616 return SE.getCouldNotCompute();
13617 }
13618
13619 // The only time we can solve this is when we have all constant indices.
13620 // Otherwise, we cannot determine the overflow conditions.
13621 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13622 return SE.getCouldNotCompute();
13623
13624 // Okay at this point we know that all elements of the chrec are constants and
13625 // that the start element is zero.
13626
13627 // First check to see if the range contains zero. If not, the first
13628 // iteration exits.
13629 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13630 if (!Range.contains(APInt(BitWidth, 0)))
13631 return SE.getZero(getType());
13632
13633 if (isAffine()) {
13634 // If this is an affine expression then we have this situation:
13635 // Solve {0,+,A} in Range === Ax in Range
13636
13637 // We know that zero is in the range. If A is positive then we know that
13638 // the upper value of the range must be the first possible exit value.
13639 // If A is negative then the lower of the range is the last possible loop
13640 // value. Also note that we already checked for a full range.
13641 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13642 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13643
13644 // The exit value should be (End+A)/A.
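    // For example, with Range = [0, 10) and A = 3: End = 9 and
    // ExitVal = (9 + 3) /u 3 = 4, and indeed 3 * 4 = 12 is the first value of
    // {0,+,3} outside the range.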
13645 APInt ExitVal = (End + A).udiv(A);
13646 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13647
13648 // Evaluate at the exit value. If we really did fall out of the valid
13649 // range, then we computed our trip count, otherwise wrap around or other
13650 // things must have happened.
13651 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13652 if (Range.contains(Val->getValue()))
13653 return SE.getCouldNotCompute(); // Something strange happened
13654
13655 // Ensure that the previous value is in the range.
13656 assert(Range.contains(
13658 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13659 "Linear scev computation is off in a bad way!");
13660 return SE.getConstant(ExitValue);
13661 }
13662
13663 if (isQuadratic()) {
13664 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13665 return SE.getConstant(*S);
13666 }
13667
13668 return SE.getCouldNotCompute();
13669}
13670
13671const SCEVAddRecExpr *
13672 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13673 assert(getNumOperands() > 1 && "AddRec with zero step?");
13674 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13675 // but in this case we cannot guarantee that the value returned will be an
13676 // AddRec because SCEV does not have a fixed point where it stops
13677 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13678 // may happen if we reach arithmetic depth limit while simplifying. So we
13679 // construct the returned value explicitly.
13681 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13682 // (this + Step) is {A+B,+,B+C,+...,+,N}.
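  // For instance, the post-increment form of {1,+,2,+,3} is {3,+,5,+,3}:
  // each new operand is the sum of two adjacent operands, and the last
  // operand is kept as-is.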
13683 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13684 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13685 // We know that the last operand is not a constant zero (otherwise it would
13686 // have been popped out earlier). This guarantees us that if the result has
13687 // the same last operand, then it will also not be popped out, meaning that
13688 // the returned value will be an AddRec.
13689 const SCEV *Last = getOperand(getNumOperands() - 1);
13690 assert(!Last->isZero() && "Recurrency with zero step?");
13691 Ops.push_back(Last);
13694}
13695
13696// Return true when S contains at least an undef value.
13698 return SCEVExprContains(
13699 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13700}
13701
13702// Return true when S contains a value that is a nullptr.
13704 return SCEVExprContains(S, [](const SCEV *S) {
13705 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13706 return SU->getValue() == nullptr;
13707 return false;
13708 });
13709}
13710
13711/// Return the size of an element read or written by Inst.
13713 Type *Ty;
13714 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13715 Ty = Store->getValueOperand()->getType();
13716 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13717 Ty = Load->getType();
13718 else
13719 return nullptr;
13720
13722 return getSizeOfExpr(ETy, Ty);
13723}
13724
13725//===----------------------------------------------------------------------===//
13726// SCEVCallbackVH Class Implementation
13727//===----------------------------------------------------------------------===//
13728
13729 void ScalarEvolution::SCEVCallbackVH::deleted() {
13730 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13731 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13732 SE->ConstantEvolutionLoopExitValue.erase(PN);
13733 SE->eraseValueFromMap(getValPtr());
13734 // this now dangles!
13735}
13736
13737void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13738 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13739
13740 // Forget all the expressions associated with users of the old value,
13741 // so that future queries will recompute the expressions using the new
13742 // value.
13743 SE->forgetValue(getValPtr());
13744 // this now dangles!
13745}
13746
13747ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13748 : CallbackVH(V), SE(se) {}
13749
13750//===----------------------------------------------------------------------===//
13751// ScalarEvolution Class Implementation
13752//===----------------------------------------------------------------------===//
13753
13756 LoopInfo &LI)
13757 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13758 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13759 LoopDispositions(64), BlockDispositions(64) {
13760 // To use guards for proving predicates, we need to scan every instruction in
13761 // relevant basic blocks, and not just terminators. Doing this is a waste of
13762 // time if the IR does not actually contain any calls to
13763 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13764 //
13765 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13766 // to _add_ guards to the module when there weren't any before, and wants
13767 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13768 // efficient in lieu of being smart in that rather obscure case.
13769
13770 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13771 F.getParent(), Intrinsic::experimental_guard);
13772 HasGuards = GuardDecl && !GuardDecl->use_empty();
13773}
13774
13776 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13777 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13778 ValueExprMap(std::move(Arg.ValueExprMap)),
13779 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13780 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13781 PendingMerges(std::move(Arg.PendingMerges)),
13782 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13783 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13784 PredicatedBackedgeTakenCounts(
13785 std::move(Arg.PredicatedBackedgeTakenCounts)),
13786 BECountUsers(std::move(Arg.BECountUsers)),
13787 ConstantEvolutionLoopExitValue(
13788 std::move(Arg.ConstantEvolutionLoopExitValue)),
13789 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13790 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13791 LoopDispositions(std::move(Arg.LoopDispositions)),
13792 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13793 BlockDispositions(std::move(Arg.BlockDispositions)),
13794 SCEVUsers(std::move(Arg.SCEVUsers)),
13795 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13796 SignedRanges(std::move(Arg.SignedRanges)),
13797 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13798 UniquePreds(std::move(Arg.UniquePreds)),
13799 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13800 LoopUsers(std::move(Arg.LoopUsers)),
13801 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13802 FirstUnknown(Arg.FirstUnknown) {
13803 Arg.FirstUnknown = nullptr;
13804}
13805
13806 ScalarEvolution::~ScalarEvolution() {
13807 // Iterate through all the SCEVUnknown instances and call their
13808 // destructors, so that they release their references to their values.
13809 for (SCEVUnknown *U = FirstUnknown; U;) {
13810 SCEVUnknown *Tmp = U;
13811 U = U->Next;
13812 Tmp->~SCEVUnknown();
13813 }
13814 FirstUnknown = nullptr;
13815
13816 ExprValueMap.clear();
13817 ValueExprMap.clear();
13818 HasRecMap.clear();
13819 BackedgeTakenCounts.clear();
13820 PredicatedBackedgeTakenCounts.clear();
13821
13822 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13823 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13824 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13825 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13826 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13827}
13828
13832
13833/// When printing a top-level SCEV for trip counts, it's helpful to include
13834/// a type for constants which are otherwise hard to disambiguate.
13835static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13836 if (isa<SCEVConstant>(S))
13837 OS << *S->getType() << " ";
13838 OS << *S;
13839}
13840
13841 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13842 const Loop *L) {
13843 // Print all inner loops first
13844 for (Loop *I : *L)
13845 PrintLoopInfo(OS, SE, I);
13846
13847 OS << "Loop ";
13848 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13849 OS << ": ";
13850
13851 SmallVector<BasicBlock *, 8> ExitingBlocks;
13852 L->getExitingBlocks(ExitingBlocks);
13853 if (ExitingBlocks.size() != 1)
13854 OS << "<multiple exits> ";
13855
13856 auto *BTC = SE->getBackedgeTakenCount(L);
13857 if (!isa<SCEVCouldNotCompute>(BTC)) {
13858 OS << "backedge-taken count is ";
13859 PrintSCEVWithTypeHint(OS, BTC);
13860 } else
13861 OS << "Unpredictable backedge-taken count.";
13862 OS << "\n";
13863
13864 if (ExitingBlocks.size() > 1)
13865 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13866 OS << " exit count for " << ExitingBlock->getName() << ": ";
13867 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13868 PrintSCEVWithTypeHint(OS, EC);
13869 if (isa<SCEVCouldNotCompute>(EC)) {
13870 // Retry with predicates.
13872 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13873 if (!isa<SCEVCouldNotCompute>(EC)) {
13874 OS << "\n predicated exit count for " << ExitingBlock->getName()
13875 << ": ";
13876 PrintSCEVWithTypeHint(OS, EC);
13877 OS << "\n Predicates:\n";
13878 for (const auto *P : Predicates)
13879 P->print(OS, 4);
13880 }
13881 }
13882 OS << "\n";
13883 }
13884
13885 OS << "Loop ";
13886 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13887 OS << ": ";
13888
13889 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13890 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13891 OS << "constant max backedge-taken count is ";
13892 PrintSCEVWithTypeHint(OS, ConstantBTC);
13894 OS << ", actual taken count either this or zero.";
13895 } else {
13896 OS << "Unpredictable constant max backedge-taken count. ";
13897 }
13898
13899 OS << "\n"
13900 "Loop ";
13901 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13902 OS << ": ";
13903
13904 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13905 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13906 OS << "symbolic max backedge-taken count is ";
13907 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13909 OS << ", actual taken count either this or zero.";
13910 } else {
13911 OS << "Unpredictable symbolic max backedge-taken count. ";
13912 }
13913 OS << "\n";
13914
13915 if (ExitingBlocks.size() > 1)
13916 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13917 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13918 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13920 PrintSCEVWithTypeHint(OS, ExitBTC);
13921 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13922 // Retry with predicates.
13924 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13926 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13927 OS << "\n predicated symbolic max exit count for "
13928 << ExitingBlock->getName() << ": ";
13929 PrintSCEVWithTypeHint(OS, ExitBTC);
13930 OS << "\n Predicates:\n";
13931 for (const auto *P : Predicates)
13932 P->print(OS, 4);
13933 }
13934 }
13935 OS << "\n";
13936 }
13937
13939 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13940 if (PBT != BTC) {
13941 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13942 OS << "Loop ";
13943 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13944 OS << ": ";
13945 if (!isa<SCEVCouldNotCompute>(PBT)) {
13946 OS << "Predicated backedge-taken count is ";
13947 PrintSCEVWithTypeHint(OS, PBT);
13948 } else
13949 OS << "Unpredictable predicated backedge-taken count.";
13950 OS << "\n";
13951 OS << " Predicates:\n";
13952 for (const auto *P : Preds)
13953 P->print(OS, 4);
13954 }
13955 Preds.clear();
13956
13957 auto *PredConstantMax =
13958 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13959 if (PredConstantMax != ConstantBTC) {
13960 assert(!Preds.empty() &&
13961 "different predicated constant max BTC but no predicates");
13962 OS << "Loop ";
13963 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13964 OS << ": ";
13965 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13966 OS << "Predicated constant max backedge-taken count is ";
13967 PrintSCEVWithTypeHint(OS, PredConstantMax);
13968 } else
13969 OS << "Unpredictable predicated constant max backedge-taken count.";
13970 OS << "\n";
13971 OS << " Predicates:\n";
13972 for (const auto *P : Preds)
13973 P->print(OS, 4);
13974 }
13975 Preds.clear();
13976
13977 auto *PredSymbolicMax =
13978 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13979 if (SymbolicBTC != PredSymbolicMax) {
13980 assert(!Preds.empty() &&
13981 "Different predicated symbolic max BTC, but no predicates");
13982 OS << "Loop ";
13983 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13984 OS << ": ";
13985 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13986 OS << "Predicated symbolic max backedge-taken count is ";
13987 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13988 } else
13989 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13990 OS << "\n";
13991 OS << " Predicates:\n";
13992 for (const auto *P : Preds)
13993 P->print(OS, 4);
13994 }
13995
13997 OS << "Loop ";
13998 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13999 OS << ": ";
14000 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14001 }
14002}
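// Illustrative output of PrintLoopInfo for a single-exit loop with a constant
// trip count of 100 and an i32 induction variable (exact wording follows the
// code above):
//   Loop %for.body: backedge-taken count is i32 99
//   Loop %for.body: constant max backedge-taken count is i32 99
//   Loop %for.body: symbolic max backedge-taken count is i32 99
//   Loop %for.body: Trip multiple is 100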
14003
14004namespace llvm {
14005// Note: these overloaded operators need to be in the llvm namespace for them
14006// to be resolved correctly. If we put them outside the llvm namespace, the
14007//
14008// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14009//
14010 // code below "breaks" and starts printing raw enum values as opposed to the
14011// string values.
14012raw_ostream &operator<<(raw_ostream &OS,
14013 ScalarEvolution::LoopDisposition LD) {
14014 switch (LD) {
14015 case ScalarEvolution::LoopVariant:
14016 OS << "Variant";
14017 break;
14018 case ScalarEvolution::LoopInvariant:
14019 OS << "Invariant";
14020 break;
14021 case ScalarEvolution::LoopComputable:
14022 OS << "Computable";
14023 break;
14024 }
14025 return OS;
14026}
14027
14028raw_ostream &operator<<(raw_ostream &OS,
14029 ScalarEvolution::BlockDisposition BD) {
14030 switch (BD) {
14031 case ScalarEvolution::DoesNotDominateBlock:
14032 OS << "DoesNotDominate";
14033 break;
14034 case ScalarEvolution::DominatesBlock:
14035 OS << "Dominates";
14036 break;
14037 case ScalarEvolution::ProperlyDominatesBlock:
14038 OS << "ProperlyDominates";
14039 break;
14040 }
14041 return OS;
14042}
14043} // namespace llvm
14044
14045void ScalarEvolution::print(raw_ostream &OS) const {
14046 // ScalarEvolution's implementation of the print method is to print
14047 // out SCEV values of all instructions that are interesting. Doing
14048 // this potentially causes it to create new SCEV objects though,
14049 // which technically conflicts with the const qualifier. This isn't
14050 // observable from outside the class though, so casting away the
14051 // const isn't dangerous.
14052 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14053
14054 if (ClassifyExpressions) {
14055 OS << "Classifying expressions for: ";
14056 F.printAsOperand(OS, /*PrintType=*/false);
14057 OS << "\n";
14058 for (Instruction &I : instructions(F))
14059 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14060 OS << I << '\n';
14061 OS << " --> ";
14062 const SCEV *SV = SE.getSCEV(&I);
14063 SV->print(OS);
14064 if (!isa<SCEVCouldNotCompute>(SV)) {
14065 OS << " U: ";
14066 SE.getUnsignedRange(SV).print(OS);
14067 OS << " S: ";
14068 SE.getSignedRange(SV).print(OS);
14069 }
14070
14071 const Loop *L = LI.getLoopFor(I.getParent());
14072
14073 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14074 if (AtUse != SV) {
14075 OS << " --> ";
14076 AtUse->print(OS);
14077 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14078 OS << " U: ";
14079 SE.getUnsignedRange(AtUse).print(OS);
14080 OS << " S: ";
14081 SE.getSignedRange(AtUse).print(OS);
14082 }
14083 }
14084
14085 if (L) {
14086 OS << "\t\t" "Exits: ";
14087 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14088 if (!SE.isLoopInvariant(ExitValue, L)) {
14089 OS << "<<Unknown>>";
14090 } else {
14091 OS << *ExitValue;
14092 }
14093
14094 bool First = true;
14095 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14096 if (First) {
14097 OS << "\t\t" "LoopDispositions: { ";
14098 First = false;
14099 } else {
14100 OS << ", ";
14101 }
14102
14103 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14104 OS << ": " << SE.getLoopDisposition(SV, Iter);
14105 }
14106
14107 for (const auto *InnerL : depth_first(L)) {
14108 if (InnerL == L)
14109 continue;
14110 if (First) {
14111 OS << "\t\t" "LoopDispositions: { ";
14112 First = false;
14113 } else {
14114 OS << ", ";
14115 }
14116
14117 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14118 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14119 }
14120
14121 OS << " }";
14122 }
14123
14124 OS << "\n";
14125 }
14126 }
14127
14128 OS << "Determining loop execution counts for: ";
14129 F.printAsOperand(OS, /*PrintType=*/false);
14130 OS << "\n";
14131 for (Loop *I : LI)
14132 PrintLoopInfo(OS, &SE, I);
14133}
14134
14135ScalarEvolution::LoopDisposition
14136ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14137 auto &Values = LoopDispositions[S];
14138 for (auto &V : Values) {
14139 if (V.getPointer() == L)
14140 return V.getInt();
14141 }
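// Seed the cache with a conservative LoopVariant placeholder before
// computing; computeLoopDisposition may itself add entries to
// LoopDispositions, so the bucket is looked up again below before the final
// result is stored.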
14142 Values.emplace_back(L, LoopVariant);
14143 LoopDisposition D = computeLoopDisposition(S, L);
14144 auto &Values2 = LoopDispositions[S];
14145 for (auto &V : llvm::reverse(Values2)) {
14146 if (V.getPointer() == L) {
14147 V.setInt(D);
14148 break;
14149 }
14150 }
14151 return D;
14152}
14153
14154ScalarEvolution::LoopDisposition
14155ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14156 switch (S->getSCEVType()) {
14157 case scConstant:
14158 case scVScale:
14159 return LoopInvariant;
14160 case scAddRecExpr: {
14161 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14162
14163 // If L is the addrec's loop, it's computable.
14164 if (AR->getLoop() == L)
14165 return LoopComputable;
14166
14167 // Add recurrences are never invariant in the function-body (null loop).
14168 if (!L)
14169 return LoopVariant;
14170
14171 // Everything that is not defined at loop entry is variant.
14172 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14173 return LoopVariant;
14174 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14175 " dominate the contained loop's header?");
14176
14177 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14178 if (AR->getLoop()->contains(L))
14179 return LoopInvariant;
14180
14181 // This recurrence is variant w.r.t. L if any of its operands
14182 // are variant.
14183 for (const auto *Op : AR->operands())
14184 if (!isLoopInvariant(Op, L))
14185 return LoopVariant;
14186
14187 // Otherwise it's loop-invariant.
14188 return LoopInvariant;
14189 }
14190 case scTruncate:
14191 case scZeroExtend:
14192 case scSignExtend:
14193 case scPtrToInt:
14194 case scAddExpr:
14195 case scMulExpr:
14196 case scUDivExpr:
14197 case scUMaxExpr:
14198 case scSMaxExpr:
14199 case scUMinExpr:
14200 case scSMinExpr:
14201 case scSequentialUMinExpr: {
14202 bool HasVarying = false;
14203 for (const auto *Op : S->operands()) {
14204 LoopDisposition D = getLoopDisposition(Op, L);
14205 if (D == LoopVariant)
14206 return LoopVariant;
14207 if (D == LoopComputable)
14208 HasVarying = true;
14209 }
14210 return HasVarying ? LoopComputable : LoopInvariant;
14211 }
14212 case scUnknown:
14213 // All non-instruction values are loop invariant. All instructions are loop
14214 // invariant if they are not contained in the specified loop.
14215 // Instructions are never considered invariant in the function body
14216 // (null loop) because they are defined within the "loop".
14217 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14218 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14219 return LoopInvariant;
14220 case scCouldNotCompute:
14221 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14222 }
14223 llvm_unreachable("Unknown SCEV kind!");
14224}
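// For example (illustrative): {0,+,1}<%L> is LoopComputable with respect to
// %L itself, LoopInvariant with respect to loops nested inside %L, and
// LoopVariant with respect to an outer loop containing %L.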
14225
14226bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14227 return getLoopDisposition(S, L) == LoopInvariant;
14228}
14229
14230bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14231 return getLoopDisposition(S, L) == LoopComputable;
14232}
14233
14234ScalarEvolution::BlockDisposition
14235ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14236 auto &Values = BlockDispositions[S];
14237 for (auto &V : Values) {
14238 if (V.getPointer() == BB)
14239 return V.getInt();
14240 }
14241 Values.emplace_back(BB, DoesNotDominateBlock);
14242 BlockDisposition D = computeBlockDisposition(S, BB);
14243 auto &Values2 = BlockDispositions[S];
14244 for (auto &V : llvm::reverse(Values2)) {
14245 if (V.getPointer() == BB) {
14246 V.setInt(D);
14247 break;
14248 }
14249 }
14250 return D;
14251}
14252
14253ScalarEvolution::BlockDisposition
14254ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14255 switch (S->getSCEVType()) {
14256 case scConstant:
14257 case scVScale:
14258 return ProperlyDominatesBlock;
14259 case scAddRecExpr: {
14260 // This uses a "dominates" query instead of "properly dominates" query
14261 // to test for proper dominance too, because the instruction which
14262 // produces the addrec's value is a PHI, and a PHI effectively properly
14263 // dominates its entire containing block.
14264 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14265 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14266 return DoesNotDominateBlock;
14267
14268 // Fall through into SCEVNAryExpr handling.
14269 [[fallthrough]];
14270 }
14271 case scTruncate:
14272 case scZeroExtend:
14273 case scSignExtend:
14274 case scPtrToInt:
14275 case scAddExpr:
14276 case scMulExpr:
14277 case scUDivExpr:
14278 case scUMaxExpr:
14279 case scSMaxExpr:
14280 case scUMinExpr:
14281 case scSMinExpr:
14282 case scSequentialUMinExpr: {
14283 bool Proper = true;
14284 for (const SCEV *NAryOp : S->operands()) {
14285 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14286 if (D == DoesNotDominateBlock)
14287 return DoesNotDominateBlock;
14288 if (D == DominatesBlock)
14289 Proper = false;
14290 }
14291 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14292 }
14293 case scUnknown:
14294 if (Instruction *I =
14295 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14296 if (I->getParent() == BB)
14297 return DominatesBlock;
14298 if (DT.properlyDominates(I->getParent(), BB))
14299 return ProperlyDominatesBlock;
14300 return DoesNotDominateBlock;
14301 }
14302 return ProperlyDominatesBlock;
14303 case scCouldNotCompute:
14304 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14305 }
14306 llvm_unreachable("Unknown SCEV kind!");
14307}
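// For example (illustrative): constants properly dominate every block, while
// a SCEVUnknown wrapping an instruction dominates (but does not properly
// dominate) its own parent block and properly dominates blocks strictly
// dominated by that parent block.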
14308
14309bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14310 return getBlockDisposition(S, BB) >= DominatesBlock;
14311}
14312
14313bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14314 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14315}
14316
14317bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14318 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14319}
14320
14321void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14322 bool Predicated) {
14323 auto &BECounts =
14324 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14325 auto It = BECounts.find(L);
14326 if (It != BECounts.end()) {
14327 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14328 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14329 if (!isa<SCEVConstant>(S)) {
14330 auto UserIt = BECountUsers.find(S);
14331 assert(UserIt != BECountUsers.end());
14332 UserIt->second.erase({L, Predicated});
14333 }
14334 }
14335 }
14336 BECounts.erase(It);
14337 }
14338}
14339
14340void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14341 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14342 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14343
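// Transitively extend the set with every SCEV that uses a forgotten
// expression: e.g. if (%a + %b) is forgotten, a cached ((%a + %b) * %c) must
// be forgotten as well.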
14344 while (!Worklist.empty()) {
14345 const SCEV *Curr = Worklist.pop_back_val();
14346 auto Users = SCEVUsers.find(Curr);
14347 if (Users != SCEVUsers.end())
14348 for (const auto *User : Users->second)
14349 if (ToForget.insert(User).second)
14350 Worklist.push_back(User);
14351 }
14352
14353 for (const auto *S : ToForget)
14354 forgetMemoizedResultsImpl(S);
14355
14356 for (auto I = PredicatedSCEVRewrites.begin();
14357 I != PredicatedSCEVRewrites.end();) {
14358 std::pair<const SCEV *, const Loop *> Entry = I->first;
14359 if (ToForget.count(Entry.first))
14360 PredicatedSCEVRewrites.erase(I++);
14361 else
14362 ++I;
14363 }
14364}
14365
14366void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14367 LoopDispositions.erase(S);
14368 BlockDispositions.erase(S);
14369 UnsignedRanges.erase(S);
14370 SignedRanges.erase(S);
14371 HasRecMap.erase(S);
14372 ConstantMultipleCache.erase(S);
14373
14374 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14375 UnsignedWrapViaInductionTried.erase(AR);
14376 SignedWrapViaInductionTried.erase(AR);
14377 }
14378
14379 auto ExprIt = ExprValueMap.find(S);
14380 if (ExprIt != ExprValueMap.end()) {
14381 for (Value *V : ExprIt->second) {
14382 auto ValueIt = ValueExprMap.find_as(V);
14383 if (ValueIt != ValueExprMap.end())
14384 ValueExprMap.erase(ValueIt);
14385 }
14386 ExprValueMap.erase(ExprIt);
14387 }
14388
14389 auto ScopeIt = ValuesAtScopes.find(S);
14390 if (ScopeIt != ValuesAtScopes.end()) {
14391 for (const auto &Pair : ScopeIt->second)
14392 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14393 llvm::erase(ValuesAtScopesUsers[Pair.second],
14394 std::make_pair(Pair.first, S));
14395 ValuesAtScopes.erase(ScopeIt);
14396 }
14397
14398 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14399 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14400 for (const auto &Pair : ScopeUserIt->second)
14401 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14402 ValuesAtScopesUsers.erase(ScopeUserIt);
14403 }
14404
14405 auto BEUsersIt = BECountUsers.find(S);
14406 if (BEUsersIt != BECountUsers.end()) {
14407 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14408 auto Copy = BEUsersIt->second;
14409 for (const auto &Pair : Copy)
14410 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14411 BECountUsers.erase(BEUsersIt);
14412 }
14413
14414 auto FoldUser = FoldCacheUser.find(S);
14415 if (FoldUser != FoldCacheUser.end())
14416 for (auto &KV : FoldUser->second)
14417 FoldCache.erase(KV);
14418 FoldCacheUser.erase(S);
14419}
14420
14421void
14422ScalarEvolution::getUsedLoops(const SCEV *S,
14423 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14424 struct FindUsedLoops {
14425 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14426 : LoopsUsed(LoopsUsed) {}
14427 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14428 bool follow(const SCEV *S) {
14429 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14430 LoopsUsed.insert(AR->getLoop());
14431 return true;
14432 }
14433
14434 bool isDone() const { return false; }
14435 };
14436
14437 FindUsedLoops F(LoopsUsed);
14438 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14439}
14440
14441void ScalarEvolution::getReachableBlocks(
14442 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14443 SmallVector<BasicBlock *> Worklist;
14444 Worklist.push_back(&F.getEntryBlock());
14445 while (!Worklist.empty()) {
14446 BasicBlock *BB = Worklist.pop_back_val();
14447 if (!Reachable.insert(BB).second)
14448 continue;
14449
14450 Value *Cond;
14451 BasicBlock *TrueBB, *FalseBB;
14452 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14453 m_BasicBlock(FalseBB)))) {
14454 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14455 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14456 continue;
14457 }
14458
14459 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14460 const SCEV *L = getSCEV(Cmp->getOperand(0));
14461 const SCEV *R = getSCEV(Cmp->getOperand(1));
14462 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14463 Worklist.push_back(TrueBB);
14464 continue;
14465 }
14466 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14467 R)) {
14468 Worklist.push_back(FalseBB);
14469 continue;
14470 }
14471 }
14472 }
14473
14474 append_range(Worklist, successors(BB));
14475 }
14476}
14477
14478void ScalarEvolution::verify() const {
14479 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14480 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14481
14482 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14483
14484 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14485 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14486 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14487
14488 const SCEV *visitConstant(const SCEVConstant *Constant) {
14489 return SE.getConstant(Constant->getAPInt());
14490 }
14491
14492 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14493 return SE.getUnknown(Expr->getValue());
14494 }
14495
14496 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14497 return SE.getCouldNotCompute();
14498 }
14499 };
14500
14501 SCEVMapper SCM(SE2);
14502 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14503 SE2.getReachableBlocks(ReachableBlocks, F);
14504
14505 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14506 if (containsUndefs(Old) || containsUndefs(New)) {
14507 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14508 // not propagate undef aggressively). This means we can (and do) fail
14509 // verification in cases where a transform makes a value go from "undef"
14510 // to "undef+1" (say). The transform is fine, since in both cases the
14511 // result is "undef", but SCEV thinks the value increased by 1.
14512 return nullptr;
14513 }
14514
14515 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14516 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14517 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14518 return nullptr;
14519
14520 return Delta;
14521 };
14522
14523 while (!LoopStack.empty()) {
14524 auto *L = LoopStack.pop_back_val();
14525 llvm::append_range(LoopStack, *L);
14526
14527 // Only verify BECounts in reachable loops. For an unreachable loop,
14528 // any BECount is legal.
14529 if (!ReachableBlocks.contains(L->getHeader()))
14530 continue;
14531
14532 // Only verify cached BECounts. Computing new BECounts may change the
14533 // results of subsequent SCEV uses.
14534 auto It = BackedgeTakenCounts.find(L);
14535 if (It == BackedgeTakenCounts.end())
14536 continue;
14537
14538 auto *CurBECount =
14539 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14540 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14541
14542 if (CurBECount == SE2.getCouldNotCompute() ||
14543 NewBECount == SE2.getCouldNotCompute()) {
14544 // NB! This situation is legal, but is very suspicious -- whatever pass
14545 // changed the loop to make a trip count go from could not compute to
14546 // computable or vice-versa *should have* invalidated SCEV. However, we
14547 // choose not to assert here (for now) since we don't want false
14548 // positives.
14549 continue;
14550 }
14551
14552 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14553 SE.getTypeSizeInBits(NewBECount->getType()))
14554 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14555 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14556 SE.getTypeSizeInBits(NewBECount->getType()))
14557 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14558
14559 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14560 if (Delta && !Delta->isZero()) {
14561 dbgs() << "Trip Count for " << *L << " Changed!\n";
14562 dbgs() << "Old: " << *CurBECount << "\n";
14563 dbgs() << "New: " << *NewBECount << "\n";
14564 dbgs() << "Delta: " << *Delta << "\n";
14565 std::abort();
14566 }
14567 }
14568
14569 // Collect all valid loops currently in LoopInfo.
14570 SmallPtrSet<Loop *, 32> ValidLoops;
14571 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14572 while (!Worklist.empty()) {
14573 Loop *L = Worklist.pop_back_val();
14574 if (ValidLoops.insert(L).second)
14575 Worklist.append(L->begin(), L->end());
14576 }
14577 for (const auto &KV : ValueExprMap) {
14578#ifndef NDEBUG
14579 // Check for SCEV expressions referencing invalid/deleted loops.
14580 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14581 assert(ValidLoops.contains(AR->getLoop()) &&
14582 "AddRec references invalid loop");
14583 }
14584#endif
14585
14586 // Check that the value is also part of the reverse map.
14587 auto It = ExprValueMap.find(KV.second);
14588 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14589 dbgs() << "Value " << *KV.first
14590 << " is in ValueExprMap but not in ExprValueMap\n";
14591 std::abort();
14592 }
14593
14594 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14595 if (!ReachableBlocks.contains(I->getParent()))
14596 continue;
14597 const SCEV *OldSCEV = SCM.visit(KV.second);
14598 const SCEV *NewSCEV = SE2.getSCEV(I);
14599 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14600 if (Delta && !Delta->isZero()) {
14601 dbgs() << "SCEV for value " << *I << " changed!\n"
14602 << "Old: " << *OldSCEV << "\n"
14603 << "New: " << *NewSCEV << "\n"
14604 << "Delta: " << *Delta << "\n";
14605 std::abort();
14606 }
14607 }
14608 }
14609
14610 for (const auto &KV : ExprValueMap) {
14611 for (Value *V : KV.second) {
14612 const SCEV *S = ValueExprMap.lookup(V);
14613 if (!S) {
14614 dbgs() << "Value " << *V
14615 << " is in ExprValueMap but not in ValueExprMap\n";
14616 std::abort();
14617 }
14618 if (S != KV.first) {
14619 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14620 << *KV.first << "\n";
14621 std::abort();
14622 }
14623 }
14624 }
14625
14626 // Verify integrity of SCEV users.
14627 for (const auto &S : UniqueSCEVs) {
14628 for (const auto *Op : S.operands()) {
14629 // We do not store dependencies of constants.
14630 if (isa<SCEVConstant>(Op))
14631 continue;
14632 auto It = SCEVUsers.find(Op);
14633 if (It != SCEVUsers.end() && It->second.count(&S))
14634 continue;
14635 dbgs() << "Use of operand " << *Op << " by user " << S
14636 << " is not being tracked!\n";
14637 std::abort();
14638 }
14639 }
14640
14641 // Verify integrity of ValuesAtScopes users.
14642 for (const auto &ValueAndVec : ValuesAtScopes) {
14643 const SCEV *Value = ValueAndVec.first;
14644 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14645 const Loop *L = LoopAndValueAtScope.first;
14646 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14647 if (!isa<SCEVConstant>(ValueAtScope)) {
14648 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14649 if (It != ValuesAtScopesUsers.end() &&
14650 is_contained(It->second, std::make_pair(L, Value)))
14651 continue;
14652 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14653 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14654 std::abort();
14655 }
14656 }
14657 }
14658
14659 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14660 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14661 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14662 const Loop *L = LoopAndValue.first;
14663 const SCEV *Value = LoopAndValue.second;
14664 assert(!isa<SCEVConstant>(Value));
14665 auto It = ValuesAtScopes.find(Value);
14666 if (It != ValuesAtScopes.end() &&
14667 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14668 continue;
14669 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14670 << *ValueAtScope << " missing in ValuesAtScopes\n";
14671 std::abort();
14672 }
14673 }
14674
14675 // Verify integrity of BECountUsers.
14676 auto VerifyBECountUsers = [&](bool Predicated) {
14677 auto &BECounts =
14678 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14679 for (const auto &LoopAndBEInfo : BECounts) {
14680 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14681 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14682 if (!isa<SCEVConstant>(S)) {
14683 auto UserIt = BECountUsers.find(S);
14684 if (UserIt != BECountUsers.end() &&
14685 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14686 continue;
14687 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14688 << " missing from BECountUsers\n";
14689 std::abort();
14690 }
14691 }
14692 }
14693 }
14694 };
14695 VerifyBECountUsers(/* Predicated */ false);
14696 VerifyBECountUsers(/* Predicated */ true);
14697
14698 // Verify integrity of the loop disposition cache.
14699 for (auto &[S, Values] : LoopDispositions) {
14700 for (auto [Loop, CachedDisposition] : Values) {
14701 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14702 if (CachedDisposition != RecomputedDisposition) {
14703 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14704 << " is incorrect: cached " << CachedDisposition << ", actual "
14705 << RecomputedDisposition << "\n";
14706 std::abort();
14707 }
14708 }
14709 }
14710
14711 // Verify integrity of the block disposition cache.
14712 for (auto &[S, Values] : BlockDispositions) {
14713 for (auto [BB, CachedDisposition] : Values) {
14714 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14715 if (CachedDisposition != RecomputedDisposition) {
14716 dbgs() << "Cached disposition of " << *S << " for block %"
14717 << BB->getName() << " is incorrect: cached " << CachedDisposition
14718 << ", actual " << RecomputedDisposition << "\n";
14719 std::abort();
14720 }
14721 }
14722 }
14723
14724 // Verify FoldCache/FoldCacheUser caches.
14725 for (auto [FoldID, Expr] : FoldCache) {
14726 auto I = FoldCacheUser.find(Expr);
14727 if (I == FoldCacheUser.end()) {
14728 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14729 << "!\n";
14730 std::abort();
14731 }
14732 if (!is_contained(I->second, FoldID)) {
14733 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14734 std::abort();
14735 }
14736 }
14737 for (auto [Expr, IDs] : FoldCacheUser) {
14738 for (auto &FoldID : IDs) {
14739 const SCEV *S = FoldCache.lookup(FoldID);
14740 if (!S) {
14741 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14742 << "!\n";
14743 std::abort();
14744 }
14745 if (S != Expr) {
14746 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14747 << " != " << *Expr << "!\n";
14748 std::abort();
14749 }
14750 }
14751 }
14752
14753 // Verify that ConstantMultipleCache computations are correct. We check that
14754 // cached multiples and recomputed multiples are multiples of each other to
14755 // verify correctness. It is possible that a recomputed multiple is different
14756 // from the cached multiple due to strengthened no wrap flags or changes in
14757 // KnownBits computations.
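// For example (illustrative): a cached multiple of 4 with a recomputed
// multiple of 8 passes (8 is divisible by 4), while 4 vs. 6 would abort.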
14758 for (auto [S, Multiple] : ConstantMultipleCache) {
14759 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14760 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14761 Multiple.urem(RecomputedMultiple) != 0 &&
14762 RecomputedMultiple.urem(Multiple) != 0)) {
14763 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14764 << *S << " : Computed " << RecomputedMultiple
14765 << " but cache contains " << Multiple << "!\n";
14766 std::abort();
14767 }
14768 }
14769}
14770
14771bool ScalarEvolution::invalidate(
14772 Function &F, const PreservedAnalyses &PA,
14773 FunctionAnalysisManager::Invalidator &Inv) {
14774 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14775 // of its dependencies is invalidated.
14776 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14777 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14778 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14779 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14780 Inv.invalidate<LoopAnalysis>(F, PA);
14781}
14782
14783AnalysisKey ScalarEvolutionAnalysis::Key;
14784
14785ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14786 FunctionAnalysisManager &AM) {
14787 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14788 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14789 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14790 auto &LI = AM.getResult<LoopAnalysis>(F);
14791 return ScalarEvolution(F, TLI, AC, DT, LI);
14792}
14793
14794PreservedAnalyses
14795ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14796 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14797 return PreservedAnalyses::all();
14798}
14799
14800PreservedAnalyses
14801ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14802 // For compatibility with opt's -analyze feature under legacy pass manager
14803 // which was not ported to NPM. This keeps tests using
14804 // update_analyze_test_checks.py working.
14805 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14806 << F.getName() << "':\n";
14807 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14808 return PreservedAnalyses::all();
14809}
14810
14812 "Scalar Evolution Analysis", false, true)
14818 "Scalar Evolution Analysis", false, true)
14819
14820char ScalarEvolutionWrapperPass::ID = 0;
14821
14822ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {}
14823
14824bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14825 SE.reset(new ScalarEvolution(
14826 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14827 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14828 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14829 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14830 return false;
14831}
14832
14833void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14834
14835void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14836 SE->print(OS);
14837}
14838
14839void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14840 if (!VerifySCEV)
14841 return;
14842
14843 SE->verify();
14844}
14845
14846void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14847 AU.setPreservesAll();
14848 AU.addRequiredTransitive<AssumptionCacheTracker>();
14849 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14850 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14851 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14852}
14853
14854const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14855 const SCEV *RHS) {
14856 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14857}
14858
14859const SCEVPredicate *
14860ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14861 const SCEV *LHS, const SCEV *RHS) {
14862 FoldingSetNodeID ID;
14863 assert(LHS->getType() == RHS->getType() &&
14864 "Type mismatch between LHS and RHS");
14865 // Unique this node based on the arguments
14866 ID.AddInteger(SCEVPredicate::P_Compare);
14867 ID.AddInteger(Pred);
14868 ID.AddPointer(LHS);
14869 ID.AddPointer(RHS);
14870 void *IP = nullptr;
14871 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14872 return S;
14873 SCEVComparePredicate *Eq = new (SCEVAllocator)
14874 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14875 UniquePreds.InsertNode(Eq, IP);
14876 return Eq;
14877}
14878
14879const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14880 const SCEVAddRecExpr *AR,
14881 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14882 FoldingSetNodeID ID;
14883 // Unique this node based on the arguments
14884 ID.AddInteger(SCEVPredicate::P_Wrap);
14885 ID.AddPointer(AR);
14886 ID.AddInteger(AddedFlags);
14887 void *IP = nullptr;
14888 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14889 return S;
14890 auto *OF = new (SCEVAllocator)
14891 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14892 UniquePreds.InsertNode(OF, IP);
14893 return OF;
14894}
14895
14896namespace {
14897
14898class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14899public:
14900
14901 /// Rewrites \p S in the context of a loop L and the SCEV predication
14902 /// infrastructure.
14903 ///
14904 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14905 /// equivalences present in \p Pred.
14906 ///
14907 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14908 /// \p NewPreds such that the result will be an AddRecExpr.
14909 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14910 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14911 const SCEVPredicate *Pred) {
14912 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14913 return Rewriter.visit(S);
14914 }
14915
14916 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14917 if (Pred) {
14918 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14919 for (const auto *Pred : U->getPredicates())
14920 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14921 if (IPred->getLHS() == Expr &&
14922 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14923 return IPred->getRHS();
14924 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14925 if (IPred->getLHS() == Expr &&
14926 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14927 return IPred->getRHS();
14928 }
14929 }
14930 return convertToAddRecWithPreds(Expr);
14931 }
14932
14933 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14934 const SCEV *Operand = visit(Expr->getOperand());
14935 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14936 if (AR && AR->getLoop() == L && AR->isAffine()) {
14937 // This couldn't be folded because the operand didn't have the nuw
14938 // flag. Add the nusw flag as an assumption that we could make.
14939 const SCEV *Step = AR->getStepRecurrence(SE);
14940 Type *Ty = Expr->getType();
14941 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14942 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14943 SE.getSignExtendExpr(Step, Ty), L,
14944 AR->getNoWrapFlags());
14945 }
14946 return SE.getZeroExtendExpr(Operand, Expr->getType());
14947 }
14948
14949 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14950 const SCEV *Operand = visit(Expr->getOperand());
14951 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14952 if (AR && AR->getLoop() == L && AR->isAffine()) {
14953 // This couldn't be folded because the operand didn't have the nsw
14954 // flag. Add the nssw flag as an assumption that we could make.
14955 const SCEV *Step = AR->getStepRecurrence(SE);
14956 Type *Ty = Expr->getType();
14957 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14958 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14959 SE.getSignExtendExpr(Step, Ty), L,
14960 AR->getNoWrapFlags());
14961 }
14962 return SE.getSignExtendExpr(Operand, Expr->getType());
14963 }
14964
14965private:
14966 explicit SCEVPredicateRewriter(
14967 const Loop *L, ScalarEvolution &SE,
14968 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14969 const SCEVPredicate *Pred)
14970 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14971
14972 bool addOverflowAssumption(const SCEVPredicate *P) {
14973 if (!NewPreds) {
14974 // Check if we've already made this assumption.
14975 return Pred && Pred->implies(P, SE);
14976 }
14977 NewPreds->push_back(P);
14978 return true;
14979 }
14980
14981 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14982 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14983 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14984 return addOverflowAssumption(A);
14985 }
14986
14987 // If \p Expr represents a PHINode, we try to see if it can be represented
14988 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14989 // to add this predicate as a runtime overflow check, we return the AddRec.
14990 // If \p Expr does not meet these conditions (is not a PHI node, or we
14991 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14992 // return \p Expr.
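// For example (illustrative): a PHI that evolves like {%start,+,%step} but
// whose increment could not be proven non-wrapping may be returned as that
// AddRec together with an <nusw>/<nssw> wrap predicate to be checked at
// runtime.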
14993 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14994 if (!isa<PHINode>(Expr->getValue()))
14995 return Expr;
14996 std::optional<
14997 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14998 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14999 if (!PredicatedRewrite)
15000 return Expr;
15001 for (const auto *P : PredicatedRewrite->second){
15002 // Wrap predicates from outer loops are not supported.
15003 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15004 if (L != WP->getExpr()->getLoop())
15005 return Expr;
15006 }
15007 if (!addOverflowAssumption(P))
15008 return Expr;
15009 }
15010 return PredicatedRewrite->first;
15011 }
15012
15013 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15014 const SCEVPredicate *Pred;
15015 const Loop *L;
15016};
15017
15018} // end anonymous namespace
15019
15020const SCEV *
15021ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
15022 const SCEVPredicate &Preds) {
15023 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15024}
15025
15026const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15027 const SCEV *S, const Loop *L,
15028 SmallVectorImpl<const SCEVPredicate *> &Preds) {
15029 SmallVector<const SCEVPredicate *, 4> TransformPreds;
15030 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15031 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15032
15033 if (!AddRec)
15034 return nullptr;
15035
15036 // Check if any of the transformed predicates is known to be false. In that
15037 // case, it doesn't make sense to convert to a predicated AddRec, as the
15038 // versioned loop will never execute.
15039 for (const SCEVPredicate *Pred : TransformPreds) {
15040 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15041 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15042 continue;
15043
15044 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15045 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15046 if (isa<SCEVCouldNotCompute>(ExitCount))
15047 continue;
15048
15049 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15050 if (!Step->isOne())
15051 continue;
15052
15053 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15054 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15055 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15056 return nullptr;
15057 }
15058
15059 // Since the transformation was successful, we can now transfer the SCEV
15060 // predicates.
15061 Preds.append(TransformPreds.begin(), TransformPreds.end());
15062
15063 return AddRec;
15064}
15065
15066/// SCEV predicates
15070
15071SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15072 const ICmpInst::Predicate Pred,
15073 const SCEV *LHS, const SCEV *RHS)
15074 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15075 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15076 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15077}
15078
15079bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15080 ScalarEvolution &SE) const {
15081 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15082
15083 if (!Op)
15084 return false;
15085
15086 if (Pred != ICmpInst::ICMP_EQ)
15087 return false;
15088
15089 return Op->LHS == LHS && Op->RHS == RHS;
15090}
15091
15092bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15093
15094void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15095 if (Pred == ICmpInst::ICMP_EQ)
15096 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15097 else
15098 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15099 << *RHS << "\n";
15100
15101}
15102
15103SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15104 const SCEVAddRecExpr *AR,
15105 IncrementWrapFlags Flags)
15106 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15107
15108const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15109
15110bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15111 ScalarEvolution &SE) const {
15112 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15113 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15114 return false;
15115
15116 if (Op->AR == AR)
15117 return true;
15118
15119 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15120 Flags != SCEVWrapPredicate::IncrementNUSW)
15121 return false;
15122
15123 const SCEV *Start = AR->getStart();
15124 const SCEV *OpStart = Op->AR->getStart();
15125 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15126 return false;
15127
15128 // Reject pointers to different address spaces.
15129 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15130 return false;
15131
15132 const SCEV *Step = AR->getStepRecurrence(SE);
15133 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15134 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15135 return false;
15136
15137 // If both steps are positive, this implies N if N's start and step are
15138 // ULE/SLE (for NUSW/NSSW) compared to this predicate's start and step.
15139 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15140 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15141 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15142
15143 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15144 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15145 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15146 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15147 : SE.getNoopOrSignExtend(Start, WiderTy);
15148 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15149 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15150 SE.isKnownPredicate(Pred, OpStart, Start);
15151}
15152
15154 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15155 IncrementWrapFlags IFlags = Flags;
15156
15157 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15158 IFlags = clearFlags(IFlags, IncrementNSSW);
15159
15160 return IFlags == IncrementAnyWrap;
15161}
15162
15163void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15164 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15166 OS << "<nusw>";
15168 OS << "<nssw>";
15169 OS << "\n";
15170}
15171
15172SCEVWrapPredicate::IncrementWrapFlags
15173SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15174 ScalarEvolution &SE) {
15175 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15176 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15177
15178 // We can safely transfer the NSW flag as NSSW.
15179 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15180 ImpliedFlags = IncrementNSSW;
15181
15182 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15183 // If the increment is positive, the SCEV NUW flag will also imply the
15184 // WrapPredicate NUSW flag.
15185 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15186 if (Step->getValue()->getValue().isNonNegative())
15187 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15188 }
15189
15190 return ImpliedFlags;
15191}
15192
15193/// Union predicates don't get cached so create a dummy set ID for it.
15194SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15195 ScalarEvolution &SE)
15196 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15197 for (const auto *P : Preds)
15198 add(P, SE);
15199}
15200
15201bool SCEVUnionPredicate::isAlwaysTrue() const {
15202 return all_of(Preds,
15203 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15204}
15205
15206bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15207 ScalarEvolution &SE) const {
15208 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15209 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15210 return this->implies(I, SE);
15211 });
15212
15213 return any_of(Preds,
15214 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15215}
15216
15217void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15218 for (const auto *Pred : Preds)
15219 Pred->print(OS, Depth);
15220}
15221
15222void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15223 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15224 for (const auto *Pred : Set->Preds)
15225 add(Pred, SE);
15226 return;
15227 }
15228
15229 // Implication checks are quadratic in the number of predicates. Stop doing
15230 // them if there are many predicates, as they should be too expensive to use
15231 // anyway at that point.
15232 bool CheckImplies = Preds.size() < 16;
15233
15234 // Only add predicate if it is not already implied by this union predicate.
15235 if (CheckImplies && implies(N, SE))
15236 return;
15237
15238 // Build a new vector containing the current predicates, except the ones that
15239 // are implied by the new predicate N.
15240 SmallVector<const SCEVPredicate *> PrunedPreds;
15241 for (auto *P : Preds) {
15242 if (CheckImplies && N->implies(P, SE))
15243 continue;
15244 PrunedPreds.push_back(P);
15245 }
15246 Preds = std::move(PrunedPreds);
15247 Preds.push_back(N);
15248}
15249
15250PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15251 Loop &L)
15252 : SE(SE), L(L) {
15253 SmallVector<const SCEVPredicate *, 4> Empty;
15254 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15255}
15256
15257void ScalarEvolution::registerUser(const SCEV *User,
15258 ArrayRef<const SCEV *> Ops) {
15259 for (const auto *Op : Ops)
15260 // We do not expect that forgetting cached data for SCEVConstants will ever
15261 // open any prospects for sharpening or introduce any correctness issues,
15262 // so we don't bother storing their dependencies.
15263 if (!isa<SCEVConstant>(Op))
15264 SCEVUsers[Op].insert(User);
15265}
15266
15267const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15268 const SCEV *Expr = SE.getSCEV(V);
15269 RewriteEntry &Entry = RewriteMap[Expr];
15270
15271 // If we already have an entry and the version matches, return it.
15272 if (Entry.second && Generation == Entry.first)
15273 return Entry.second;
15274
15275 // We found an entry but it's stale. Rewrite the stale entry
15276 // according to the current predicate.
15277 if (Entry.second)
15278 Expr = Entry.second;
15279
15280 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15281 Entry = {Generation, NewSCEV};
15282
15283 return NewSCEV;
15284}
15285
15286const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15287 if (!BackedgeCount) {
15288 SmallVector<const SCEVPredicate *, 4> Preds;
15289 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15290 for (const auto *P : Preds)
15291 addPredicate(*P);
15292 }
15293 return BackedgeCount;
15294}
15295
15296const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15297 if (!SymbolicMaxBackedgeCount) {
15298 SmallVector<const SCEVPredicate *, 4> Preds;
15299 SymbolicMaxBackedgeCount =
15300 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15301 for (const auto *P : Preds)
15302 addPredicate(*P);
15303 }
15304 return SymbolicMaxBackedgeCount;
15305}
15306
15307unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15308 if (!SmallConstantMaxTripCount) {
15309 SmallVector<const SCEVPredicate *, 4> Preds;
15310 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15311 for (const auto *P : Preds)
15312 addPredicate(*P);
15313 }
15314 return *SmallConstantMaxTripCount;
15315}
15316
15317void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15318 if (Preds->implies(&Pred, SE))
15319 return;
15320
15321 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15322 NewPreds.push_back(&Pred);
15323 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15324 updateGeneration();
15325}
15326
15327const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15328 return *Preds;
15329}
15330
15331void PredicatedScalarEvolution::updateGeneration() {
15332 // If the generation number wrapped recompute everything.
15333 if (++Generation == 0) {
15334 for (auto &II : RewriteMap) {
15335 const SCEV *Rewritten = II.second.second;
15336 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15337 }
15338 }
15339}
15340
15341void PredicatedScalarEvolution::setNoOverflow(
15342 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15343 const SCEV *Expr = getSCEV(V);
15344 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15345
15346 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15347
15348 // Clear the statically implied flags.
15349 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15350 addPredicate(*SE.getWrapPredicate(AR, Flags));
15351
15352 auto II = FlagsMap.insert({V, Flags});
15353 if (!II.second)
15354 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15355}
15356
15357bool PredicatedScalarEvolution::hasNoOverflow(
15358 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15359 const SCEV *Expr = getSCEV(V);
15360 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15361
15362 Flags = SCEVWrapPredicate::clearFlags(
15363 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15364
15365 auto II = FlagsMap.find(V);
15366
15367 if (II != FlagsMap.end())
15368 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15369
15370 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15371}
15372
15373const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15374 const SCEV *Expr = this->getSCEV(V);
15375 SmallVector<const SCEVPredicate *, 4> NewPreds;
15376 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15377
15378 if (!New)
15379 return nullptr;
15380
15381 for (const auto *P : NewPreds)
15382 addPredicate(*P);
15383
15384 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15385 return New;
15386}
15387
15388PredicatedScalarEvolution::PredicatedScalarEvolution(
15389 const PredicatedScalarEvolution &Init)
15390 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15391 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15392 SE)),
15393 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15394 for (auto I : Init.FlagsMap)
15395 FlagsMap.insert(I);
15396}
15397
15398void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15399 // For each block.
15400 for (auto *BB : L.getBlocks())
15401 for (auto &I : *BB) {
15402 if (!SE.isSCEVable(I.getType()))
15403 continue;
15404
15405 auto *Expr = SE.getSCEV(&I);
15406 auto II = RewriteMap.find(Expr);
15407
15408 if (II == RewriteMap.end())
15409 continue;
15410
15411 // Don't print things that are not interesting.
15412 if (II->second.second == Expr)
15413 continue;
15414
15415 OS.indent(Depth) << "[PSE]" << I << ":\n";
15416 OS.indent(Depth + 2) << *Expr << "\n";
15417 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15418 }
15419}
15420
15421ScalarEvolution::LoopGuards
15422ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15423 BasicBlock *Header = L->getHeader();
15424 BasicBlock *Pred = L->getLoopPredecessor();
15425 LoopGuards Guards(SE);
15426 if (!Pred)
15427 return Guards;
15428 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15429 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15430 return Guards;
15431}
15432
15433void ScalarEvolution::LoopGuards::collectFromPHI(
15434 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15435 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15436 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15437 unsigned Depth) {
15438 if (!SE.isSCEVable(Phi.getType()))
15439 return;
15440
15441 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15442 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15443 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15444 if (!VisitedBlocks.insert(InBlock).second)
15445 return {nullptr, scCouldNotCompute};
15446
15447 // Avoid analyzing unreachable blocks so that we don't get trapped
15448 // traversing cycles with ill-formed dominance or infinite cycles
15449 if (!SE.DT.isReachableFromEntry(InBlock))
15450 return {nullptr, scCouldNotCompute};
15451
15452 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15453 if (Inserted)
15454 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15455 Depth + 1);
15456 auto &RewriteMap = G->second.RewriteMap;
15457 if (RewriteMap.empty())
15458 return {nullptr, scCouldNotCompute};
15459 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15460 if (S == RewriteMap.end())
15461 return {nullptr, scCouldNotCompute};
15462 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15463 if (!SM)
15464 return {nullptr, scCouldNotCompute};
15465 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15466 return {C0, SM->getSCEVType()};
15467 return {nullptr, scCouldNotCompute};
15468 };
15469 auto MergeMinMaxConst = [](MinMaxPattern P1,
15470 MinMaxPattern P2) -> MinMaxPattern {
15471 auto [C1, T1] = P1;
15472 auto [C2, T2] = P2;
15473 if (!C1 || !C2 || T1 != T2)
15474 return {nullptr, scCouldNotCompute};
15475 switch (T1) {
15476 case scUMaxExpr:
15477 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15478 case scSMaxExpr:
15479 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15480 case scUMinExpr:
15481 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15482 case scSMinExpr:
15483 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15484 default:
15485 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15486 }
15487 };
15488 auto P = GetMinMaxConst(0);
15489 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15490 if (!P.first)
15491 break;
15492 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15493 }
15494 if (P.first) {
15495 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15496 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15497 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15498 Guards.RewriteMap.insert({LHS, RHS});
15499 }
15500}
15501
15502// Return a new SCEV that modifies \p Expr to the closest number that is
15503// divisible by \p Divisor and less than or equal to \p Expr. For now, only
15504// handle constant Expr.
15505static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15506 const APInt &DivisorVal,
15507 ScalarEvolution &SE) {
15508 const APInt *ExprVal;
15509 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15510 DivisorVal.isNonPositive())
15511 return Expr;
15512 APInt Rem = ExprVal->urem(DivisorVal);
15513 // return the SCEV: Expr - Expr % Divisor
15514 return SE.getConstant(*ExprVal - Rem);
15515}
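// For example (illustrative): Expr = 7 and Divisor = 3 yields 6.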
15516
15517// Return a new SCEV that modifies \p Expr to the closest number that is
15518// divisible by \p Divisor and greater than or equal to \p Expr. For now, only
15519// handle constant Expr.
15520static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15521 const APInt &DivisorVal,
15522 ScalarEvolution &SE) {
15523 const APInt *ExprVal;
15524 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15525 DivisorVal.isNonPositive())
15526 return Expr;
15527 APInt Rem = ExprVal->urem(DivisorVal);
15528 if (Rem.isZero())
15529 return Expr;
15530 // return the SCEV: Expr + Divisor - Expr % Divisor
15531 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15532}
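// For example (illustrative): Expr = 7 and Divisor = 3 yields 9.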
15533
15535 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15538 // If we have LHS == 0, check whether LHS computes a property of some
15539 // unknown SCEV %v that lets us rewrite %v explicitly.
15540 if (Predicate != CmpInst::ICMP_EQ || !RHS->isZero())
15541 return false;
15542 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15543 // explicitly express that.
15544 const SCEVUnknown *URemLHS = nullptr;
15545 const SCEV *URemRHS = nullptr;
15546 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15547 return false;
15548
15549 const SCEV *Multiple =
15550 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15551 DivInfo[URemLHS] = Multiple;
15552 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15553 Multiples[URemLHS] = C->getAPInt();
15554 return true;
15555}
15556
15557// Check if the condition is a divisibility guard (A % B == 0).
15558static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15559 ScalarEvolution &SE) {
15560 const SCEV *X, *Y;
15561 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15562}
15563
15564// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15565// recursively. This is done by aligning up/down the constant value to the
15566// Divisor.
15567static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15568 APInt Divisor,
15569 ScalarEvolution &SE) {
15570 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15571 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15572 // the non-constant operand and in \p LHS the constant operand.
15573 auto IsMinMaxSCEVWithNonNegativeConstant =
15574 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15575 const SCEV *&RHS) {
15576 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15577 if (MinMax->getNumOperands() != 2)
15578 return false;
15579 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15580 if (C->getAPInt().isNegative())
15581 return false;
15582 SCTy = MinMax->getSCEVType();
15583 LHS = MinMax->getOperand(0);
15584 RHS = MinMax->getOperand(1);
15585 return true;
15586 }
15587 }
15588 return false;
15589 };
15590
15591 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15592 SCEVTypes SCTy;
15593 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15594 MinMaxRHS))
15595 return MinMaxExpr;
15596 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15597 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15598 auto *DivisibleExpr =
15599 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15600 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15601 SmallVector<const SCEV *> Ops = {
15602 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15603 return SE.getMinMaxExpr(SCTy, Ops);
15604}
15605
15606void ScalarEvolution::LoopGuards::collectFromBlock(
15607 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15608 const BasicBlock *Block, const BasicBlock *Pred,
15609 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15610
15612
15613 SmallVector<const SCEV *> ExprsToRewrite;
15614 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15615 const SCEV *RHS,
15616 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15617 const LoopGuards &DivGuards) {
15618 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15619 // replacement SCEV which isn't directly implied by the structure of that
15620 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15621 // legal. See the scoping rules for flags in the header to understand why.
15622
15623 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15624 // create this form when combining two checks of the form (X u< C2 + C1) and
15625 // (X >=u C1).
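// For example (illustrative): "X u>= 10 && X u< 30" typically reaches us as
// the single check "(-10 + X) u< 20", which is mapped back to the range
// [10, 29] for X below.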
15626 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15627 &ExprsToRewrite]() {
15628 const SCEVConstant *C1;
15629 const SCEVUnknown *LHSUnknown;
15630 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15631 if (!match(LHS,
15632 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15633 !C2)
15634 return false;
15635
15636 auto ExactRegion =
15637 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15638 .sub(C1->getAPInt());
15639
15640 // Bail out, unless we have a non-wrapping, monotonic range.
15641 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15642 return false;
15643 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15644 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15645 I->second = SE.getUMaxExpr(
15646 SE.getConstant(ExactRegion.getUnsignedMin()),
15647 SE.getUMinExpr(RewrittenLHS,
15648 SE.getConstant(ExactRegion.getUnsignedMax())));
15649 ExprsToRewrite.push_back(LHSUnknown);
15650 return true;
15651 };
15652 if (MatchRangeCheckIdiom())
15653 return;
15654
15655 // Do not apply information for constants or if RHS contains an AddRec.
15656 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15657 return;
15658
15659 // If RHS is SCEVUnknown, make sure the information is applied to it.
15660 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15661 std::swap(LHS, RHS);
15662 Predicate = CmpInst::getSwappedPredicate(Predicate);
15663 }
15664
15665 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15666 // and \p FromRewritten are the same (i.e. there has been no rewrite
15667 // registered for \p From), then puts this value in the list of rewritten
15668 // expressions.
15669 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15670 const SCEV *To) {
15671 if (From == FromRewritten)
15672 ExprsToRewrite.push_back(From);
15673 RewriteMap[From] = To;
15674 };
15675
15676 // Checks whether \p S has already been rewritten. In that case returns the
15677 // existing rewrite because we want to chain further rewrites onto the
15678 // already rewritten value. Otherwise returns \p S.
15679 auto GetMaybeRewritten = [&](const SCEV *S) {
15680 return RewriteMap.lookup_or(S, S);
15681 };
15682
15683 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15684 // Apply divisibility information when computing the constant multiple.
15685 const APInt &DividesBy =
15686 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15687
15688 // Collect rewrites for LHS and its transitive operands based on the
15689 // condition.
15690 // For min/max expressions, also apply the guard to its operands:
15691 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15692 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15693 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15694 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15695
15696 // We cannot express strict predicates in SCEV, so instead we replace them
15697 // with non-strict ones against plus or minus one of RHS depending on the
15698 // predicate.
15699 const SCEV *One = SE.getOne(RHS->getType());
15700 switch (Predicate) {
15701 case CmpInst::ICMP_ULT:
15702 if (RHS->getType()->isPointerTy())
15703 return;
15704 RHS = SE.getUMaxExpr(RHS, One);
15705 [[fallthrough]];
15706 case CmpInst::ICMP_SLT: {
15707 RHS = SE.getMinusSCEV(RHS, One);
15708 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15709 break;
15710 }
15711 case CmpInst::ICMP_UGT:
15712 case CmpInst::ICMP_SGT:
15713 RHS = SE.getAddExpr(RHS, One);
15714 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15715 break;
15716 case CmpInst::ICMP_ULE:
15717 case CmpInst::ICMP_SLE:
15718 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15719 break;
15720 case CmpInst::ICMP_UGE:
15721 case CmpInst::ICMP_SGE:
15722 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15723 break;
15724 default:
15725 break;
15726 }
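    // For example, for the guard 'X u< 8' where X is known to be a multiple of
    // 4, ULT becomes ULE against umax(8, 1) - 1 = 7, which is then aligned
    // down to 4; the worklist below will rewrite X to umin(X, 4).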
15727
15728     SmallVector<const SCEV *, 16> Worklist(1, LHS);
15729     SmallPtrSet<const SCEV *, 16> Visited;
15730
15731 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15732 append_range(Worklist, S->operands());
15733 };
15734
15735 while (!Worklist.empty()) {
15736 const SCEV *From = Worklist.pop_back_val();
15737 if (isa<SCEVConstant>(From))
15738 continue;
15739 if (!Visited.insert(From).second)
15740 continue;
15741 const SCEV *FromRewritten = GetMaybeRewritten(From);
15742 const SCEV *To = nullptr;
15743
15744 switch (Predicate) {
15745 case CmpInst::ICMP_ULT:
15746 case CmpInst::ICMP_ULE:
15747 To = SE.getUMinExpr(FromRewritten, RHS);
15748 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15749 EnqueueOperands(UMax);
15750 break;
15751 case CmpInst::ICMP_SLT:
15752 case CmpInst::ICMP_SLE:
15753 To = SE.getSMinExpr(FromRewritten, RHS);
15754 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15755 EnqueueOperands(SMax);
15756 break;
15757 case CmpInst::ICMP_UGT:
15758 case CmpInst::ICMP_UGE:
15759 To = SE.getUMaxExpr(FromRewritten, RHS);
15760 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15761 EnqueueOperands(UMin);
15762 break;
15763 case CmpInst::ICMP_SGT:
15764 case CmpInst::ICMP_SGE:
15765 To = SE.getSMaxExpr(FromRewritten, RHS);
15766 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15767 EnqueueOperands(SMin);
15768 break;
15769       case CmpInst::ICMP_EQ:
15770         if (isa<SCEVConstant>(RHS))
15771           To = RHS;
15772 break;
15773 case CmpInst::ICMP_NE:
15774 if (match(RHS, m_scev_Zero())) {
15775 const SCEV *OneAlignedUp =
15776 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15777 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15778 } else {
15779 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15780 // but creating the subtraction eagerly is expensive. Track the
15781 // inequalities in a separate map, and materialize the rewrite lazily
15782 // when encountering a suitable subtraction while re-writing.
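          // For example, for a guard 'A != B' the (canonically ordered) pair
          // {A, B} is recorded in Guards.NotEqual; when the rewriter later
          // visits the subtraction (A - B), it replaces it with umax(1, A - B).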
15783           if (LHS->getType()->isPointerTy()) {
15784             LHS = SE.getLosslessPtrToIntExpr(LHS);
15785             RHS = SE.getLosslessPtrToIntExpr(RHS);
15786             if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
15787               break;
15788           }
15789           const SCEVConstant *C;
15790           const SCEV *A, *B;
15791           if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15792               match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15793             RHS = A;
15794             LHS = B;
15795           }
15796 if (LHS > RHS)
15797 std::swap(LHS, RHS);
15798 Guards.NotEqual.insert({LHS, RHS});
15799 continue;
15800 }
15801 break;
15802 default:
15803 break;
15804 }
15805
15806 if (To)
15807 AddRewrite(From, FromRewritten, To);
15808 }
15809 };
15810
15811   SmallVector<PointerIntPair<Value *, 1, bool>, 8> Terms;
15812   // First, collect information from assumptions dominating the loop.
15813 for (auto &AssumeVH : SE.AC.assumptions()) {
15814 if (!AssumeVH)
15815 continue;
15816 auto *AssumeI = cast<CallInst>(AssumeVH);
15817 if (!SE.DT.dominates(AssumeI, Block))
15818 continue;
15819 Terms.emplace_back(AssumeI->getOperand(0), true);
15820 }
15821
15822 // Second, collect information from llvm.experimental.guards dominating the loop.
15823 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15824 SE.F.getParent(), Intrinsic::experimental_guard);
15825 if (GuardDecl)
15826 for (const auto *GU : GuardDecl->users())
15827 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15828 if (Guard->getFunction() == Block->getParent() &&
15829 SE.DT.dominates(Guard, Block))
15830 Terms.emplace_back(Guard->getArgOperand(0), true);
15831
15832 // Third, collect conditions from dominating branches. Starting at the loop
15833   // predecessor, climb up the predecessor chain as long as predecessors
15834   // with a unique successor leading to the original header can still be
15835   // found.
15836 // TODO: share this logic with isLoopEntryGuardedByCond.
15837   unsigned NumCollectedConditions = 0;
15838   VisitedBlocks.insert(Block);
15839   std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15840 for (; Pair.first;
15841 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15842 VisitedBlocks.insert(Pair.second);
15843 const BranchInst *LoopEntryPredicate =
15844 dyn_cast<BranchInst>(Pair.first->getTerminator());
15845 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15846 continue;
15847
15848 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15849 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15850 NumCollectedConditions++;
15851
15852     // If we are recursively collecting guards, stop after 2
15853 // conditions to limit compile-time impact for now.
15854 if (Depth > 0 && NumCollectedConditions == 2)
15855 break;
15856 }
15857 // Finally, if we stopped climbing the predecessor chain because
15858 // there wasn't a unique one to continue, try to collect conditions
15859 // for PHINodes by recursively following all of their incoming
15860 // blocks and try to merge the found conditions to build a new one
15861 // for the Phi.
15862   if (Pair.second->hasNPredecessorsOrMore(2) &&
15863       Depth < MaxLoopGuardCollectionDepth) {
15864     SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15865 for (auto &Phi : Pair.second->phis())
15866 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15867 }
15868
15869 // Now apply the information from the collected conditions to
15870 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15871   // earliest conditions are processed first, except guards with divisibility
15872 // information, which are moved to the back. This ensures the SCEVs with the
15873 // shortest dependency chains are constructed first.
15874   SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15875       GuardsToProcess;
15876 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15877 SmallVector<Value *, 8> Worklist;
15878 SmallPtrSet<Value *, 8> Visited;
15879 Worklist.push_back(Term);
15880 while (!Worklist.empty()) {
15881 Value *Cond = Worklist.pop_back_val();
15882 if (!Visited.insert(Cond).second)
15883 continue;
15884
15885 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15886 auto Predicate =
15887 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15888 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15889 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15890 // If LHS is a constant, apply information to the other expression.
15891 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15892 // can improve results.
15893 if (isa<SCEVConstant>(LHS)) {
15894           std::swap(LHS, RHS);
15895           Predicate = CmpInst::getSwappedPredicate(Predicate);
15896         }
15897 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15898 continue;
15899 }
15900
15901 Value *L, *R;
15902 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15903 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15904 Worklist.push_back(L);
15905 Worklist.push_back(R);
15906 }
15907 }
15908 }
15909
15910 // Process divisibility guards in reverse order to populate DivGuards early.
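  // A divisibility guard is a condition such as '(X urem 4) == 0'; recording
  // it first lets CollectCondition align the bounds of the other guards to the
  // known multiples of X.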
15911 DenseMap<const SCEV *, APInt> Multiples;
15912 LoopGuards DivGuards(SE);
15913 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15914 if (!isDivisibilityGuard(LHS, RHS, SE))
15915 continue;
15916 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15917 Multiples, SE);
15918 }
15919
15920 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15921 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15922
15923 // Apply divisibility information last. This ensures it is applied to the
15924 // outermost expression after other rewrites for the given value.
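  // For example, if X is known to be a multiple of 4, the final entry for X
  // becomes ((<rewritten X> /u 4) * 4), wrapped around any min/max rewrites
  // collected for X above.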
15925 for (const auto &[K, Divisor] : Multiples) {
15926 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15927     Guards.RewriteMap[K] =
15928         SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15929                                          Guards.rewrite(K), Divisor, SE),
15930                                      DivisorSCEV),
15931                       DivisorSCEV);
15932 ExprsToRewrite.push_back(K);
15933 }
15934
15935 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15936 // the replacement expressions are contained in the ranges of the replaced
15937 // expressions.
15938 Guards.PreserveNUW = true;
15939 Guards.PreserveNSW = true;
15940 for (const SCEV *Expr : ExprsToRewrite) {
15941 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15942 Guards.PreserveNUW &=
15943 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15944 Guards.PreserveNSW &=
15945 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15946 }
15947
15948   // Now that all rewrite information is collected, rewrite the collected
15949 // expressions with the information in the map. This applies information to
15950 // sub-expressions.
15951 if (ExprsToRewrite.size() > 1) {
15952 for (const SCEV *Expr : ExprsToRewrite) {
15953 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15954 Guards.RewriteMap.erase(Expr);
15955 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15956 }
15957 }
15958}
15959
15960 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15961   /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15962   /// in the map. It skips AddRecExpr because we cannot guarantee that the
15963   /// replacement is loop invariant in the loop of the AddRec.
15964   class SCEVLoopGuardRewriter
15965       : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15966     const DenseMap<const SCEV *, const SCEV *> &Map;
15967     const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
15968
15969     SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15970
15971 public:
15972 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15973 const ScalarEvolution::LoopGuards &Guards)
15974 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15975 NotEqual(Guards.NotEqual) {
15976 if (Guards.PreserveNUW)
15977 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15978 if (Guards.PreserveNSW)
15979 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15980 }
15981
15982 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15983
15984 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15985 return Map.lookup_or(Expr, Expr);
15986 }
15987
15988 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15989 if (const SCEV *S = Map.lookup(Expr))
15990 return S;
15991
15992     // If we didn't find the exact ZExt expr in the map, check if there's
15993 // an entry for a smaller ZExt we can use instead.
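    // For example, when visiting (zext i8 %x to i64) and the map only has an
    // entry for (zext i8 %x to i32), the narrower zext is found at bit width
    // 32 and its replacement is zero-extended back to i64.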
15994 Type *Ty = Expr->getType();
15995 const SCEV *Op = Expr->getOperand(0);
15996 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15997 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15998 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15999 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16000 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16001 if (const SCEV *S = Map.lookup(NarrowExt))
16002 return SE.getZeroExtendExpr(S, Ty);
16003 Bitwidth = Bitwidth / 2;
16004 }
16005
16006       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16007           Expr);
16008   }
16009
16010 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16011 if (const SCEV *S = Map.lookup(Expr))
16012         return S;
16013       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16014           Expr);
16015 }
16016
16017 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16018       if (const SCEV *S = Map.lookup(Expr))
16019         return S;
16020       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16021     }
16022
16023 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16024       if (const SCEV *S = Map.lookup(Expr))
16025         return S;
16026       return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16027     }
16028
16029 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16030 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16031 // return UMax(S, 1).
16032 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16033 const SCEV *LHS, *RHS;
16034 if (MatchBinarySub(S, LHS, RHS)) {
16035 if (LHS > RHS)
16036 std::swap(LHS, RHS);
16037 if (NotEqual.contains({LHS, RHS})) {
16038 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16039 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16040 return SE.getUMaxExpr(OneAlignedUp, S);
16041 }
16042 }
16043 return nullptr;
16044 };
16045
16046 // Check if Expr itself is a subtraction pattern with guard info.
16047 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16048 return Rewritten;
16049
16050 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16051 // (Const + A + B). There may be guard info for A + B, and if so, apply
16052 // it.
16053 // TODO: Could more generally apply guards to Add sub-expressions.
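    // For example, given (1 + A + (-1 * B)) and a guard 'A != B', the
    // non-constant part (A - B) is rewritten to umax(1, A - B), yielding
    // (1 + umax(1, A - B)).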
16054 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16055 Expr->getNumOperands() == 3) {
16056 const SCEV *Add =
16057 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16058 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16059 return SE.getAddExpr(
16060 Expr->getOperand(0), Rewritten,
16061 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16062 if (const SCEV *S = Map.lookup(Add))
16063 return SE.getAddExpr(Expr->getOperand(0), S);
16064       }
16065       SmallVector<const SCEV *, 2> Operands;
16066       bool Changed = false;
16067       for (const auto *Op : Expr->operands()) {
16068         Operands.push_back(
16069             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16070         Changed |= Op != Operands.back();
16071       }
16072 // We are only replacing operands with equivalent values, so transfer the
16073 // flags from the original expression.
16074       return !Changed ? Expr
16075                       : SE.getAddExpr(Operands,
16076                                       ScalarEvolution::maskFlags(
16077                                           Expr->getNoWrapFlags(), FlagMask));
16078 }
16079
16080     const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16081       SmallVector<const SCEV *, 2> Operands;
16082       bool Changed = false;
16083       for (const auto *Op : Expr->operands()) {
16084         Operands.push_back(
16085             SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16086         Changed |= Op != Operands.back();
16087 }
16088 // We are only replacing operands with equivalent values, so transfer the
16089 // flags from the original expression.
16090       return !Changed ? Expr
16091                       : SE.getMulExpr(Operands,
16092                                       ScalarEvolution::maskFlags(
16093                                           Expr->getNoWrapFlags(), FlagMask));
16094 }
16095 };
16096
16097 if (RewriteMap.empty() && NotEqual.empty())
16098 return Expr;
16099
16100 SCEVLoopGuardRewriter Rewriter(SE, *this);
16101 return Rewriter.visit(Expr);
16102}
16103
16104const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16105 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16106}
16107
16108 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16109                                              const LoopGuards &Guards) {
16110 return Guards.rewrite(Expr);
16111}
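A minimal usage sketch (hypothetical caller, not part of this file; SE, L and
Expr stand for a ScalarEvolution instance, a Loop and a SCEV of interest):

  // Tighten an expression using conditions dominating the loop.
  const SCEV *Guarded = SE.applyLoopGuards(Expr, L);

  // When rewriting several expressions for the same loop, collect the guards
  // once and reuse them.
  ScalarEvolution::LoopGuards Guards = ScalarEvolution::LoopGuards::collect(L, SE);
  const SCEV *Guarded2 = SE.applyLoopGuards(Expr, Guards);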
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check the state of analysis infor...
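For reference, a minimal sketch of how a new pass-manager pass could obtain the result of the ScalarEvolutionAnalysis entry above (the pass name MyPass is hypothetical):
  PreservedAnalyses MyPass::run(Function &F, FunctionAnalysisManager &AM) {
    ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
    // ... query SE here ...
    return PreservedAnalyses::all();
  }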
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
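A small sketch of the collect/rewrite pairing (L, SE and Expr are assumed to be in scope); applyLoopGuards further below is the one-shot convenience wrapper:
  ScalarEvolution::LoopGuards Guards = ScalarEvolution::LoopGuards::collect(L, SE);
  // Expr rewritten using the conditions guarding the loop entry, when possible.
  const SCEV *Guarded = Guards.rewrite(Expr);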
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
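A typical guard against the could-not-compute case (L and SE are assumed to be in scope):
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(BTC)) {
    // BTC is a loop-invariant expression for the number of backedge executions.
  }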
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Determine whether the operation BinOp between LHS and RHS provably does not have signed/unsigned overflow (as selected by Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Computes LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
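A brief sketch of a common query (A and B are hypothetical SCEVable Values):
  const SCEV *LHS = SE.getSCEV(A);
  const SCEV *RHS = SE.getSCEV(B);
  if (SE.isKnownPredicate(ICmpInst::ICMP_ULT, LHS, RHS)) {
    // The unsigned comparison LHS < RHS is known to hold for all possible values.
  }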
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:712
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:743
TypeSize getSizeInBits() const
Definition DataLayout.h:723
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:292
Use & Op()
Definition User.h:196
Value * getOperand(unsigned i) const
Definition User.h:232
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1099
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2249
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2254
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2259
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2264
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
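For illustration, a tiny sketch using the APIntOps helpers referenced above:
  APInt A(/*numBits=*/32, 12), B(/*numBits=*/32, 18);
  APInt G = APIntOps::GreatestCommonDivisor(A, B); // 6
  const APInt &M = APIntOps::umax(A, B);           // 18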
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
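A minimal sketch of the SCEV pattern matchers above (assumes the SCEVPatternMatch matchers are in scope and a const SCEV *S is available): test whether S is a zero-extended constant and bind the constant's value.
  const APInt *C;
  if (match(S, m_scev_ZExt(m_scev_APInt(C)))) {
    // S is zext(C); *C holds the narrow constant value.
  }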
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2058
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition ScopeExit.h:59
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors; useful when debugging a pass.
auto successors(const MachineBasicBlock *BB)
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2136
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2053
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2128
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1732
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:336
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
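A brief usage sketch (V, DL, AC, CtxI and DT are assumed to be available from the surrounding pass):
  KnownBits Known(V->getType()->getScalarSizeInBits());
  computeKnownBits(V, Known, DL, &AC, CtxI, &DT);
  if (Known.isNonNegative()) {
    // The sign bit of V is known to be zero.
  }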
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:354
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:74
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:1954
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2030
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1867
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:1961
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1897
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
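A one-line sketch of the predicate-based traversal (S is an arbitrary SCEV expression); containsAddRecurrence above is the cached member-function form of this particular query:
  bool HasAddRec = SCEVExprContains(S, [](const SCEV *X) {
    return isa<SCEVAddRecExpr>(X);
  });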
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:869
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:301
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:196
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:145
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:129
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.