1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
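//
// For example (an illustrative loop, not tied to any particular test case): in
//
//   for (int i = 0; i != n; ++i)
//     A[4*i + 7] = 0;
//
// the PHI node for `i` is represented by the add recurrence {0,+,1}<%loop>,
// and the index expression 4*i + 7 folds to the recurrence {7,+,4}<%loop>.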
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
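//
// As a brief illustration (hedged; the exact expressions depend on the IR):
// for the loop above the backedge-taken count comes out as `n`, so the trip
// count is `n + 1` evaluated in a wider type, and evaluating {0,+,1} at
// iteration `n` recovers the exit value of `i`, namely `n` itself.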
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification with -verify-scev is passed"));
165
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
325 case scSMinExpr:
326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
379
381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
468 // Return true if the value is negative, this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
474
475bool SCEVCouldNotCompute::classof(const SCEV *S) {
476 return S->getSCEVType() == scCouldNotCompute;
477}
478
481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions.
585static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
586 Value *RV, unsigned Depth) {
587 if (Depth > MaxValueCompareDepth)
588 return 0;
589
590 // Order pointer values after integer values. This helps SCEVExpander form
591 // GEPs.
592 bool LIsPointer = LV->getType()->isPointerTy(),
593 RIsPointer = RV->getType()->isPointerTy();
594 if (LIsPointer != RIsPointer)
595 return (int)LIsPointer - (int)RIsPointer;
596
597 // Compare getValueID values.
598 unsigned LID = LV->getValueID(), RID = RV->getValueID();
599 if (LID != RID)
600 return (int)LID - (int)RID;
601
602 // Sort arguments by their position.
603 if (const auto *LA = dyn_cast<Argument>(LV)) {
604 const auto *RA = cast<Argument>(RV);
605 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
606 return (int)LArgNo - (int)RArgNo;
607 }
608
609 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
610 const auto *RGV = cast<GlobalValue>(RV);
611
612 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
613 auto LT = GV->getLinkage();
614 return !(GlobalValue::isPrivateLinkage(LT) ||
615 GlobalValue::isInternalLinkage(LT));
616 };
617
618 // Use the names to distinguish the two values, but only if the
619 // names are semantically important.
620 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
621 return LGV->getName().compare(RGV->getName());
622 }
623
624 // For instructions, compare their loop depth, and their operand count. This
625 // is pretty loose.
626 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
627 const auto *RInst = cast<Instruction>(RV);
628
629 // Compare loop depths.
630 const BasicBlock *LParent = LInst->getParent(),
631 *RParent = RInst->getParent();
632 if (LParent != RParent) {
633 unsigned LDepth = LI->getLoopDepth(LParent),
634 RDepth = LI->getLoopDepth(RParent);
635 if (LDepth != RDepth)
636 return (int)LDepth - (int)RDepth;
637 }
638
639 // Compare the number of operands.
640 unsigned LNumOps = LInst->getNumOperands(),
641 RNumOps = RInst->getNumOperands();
642 if (LNumOps != RNumOps)
643 return (int)LNumOps - (int)RNumOps;
644
645 for (unsigned Idx : seq(LNumOps)) {
646 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
647 RInst->getOperand(Idx), Depth + 1);
648 if (Result != 0)
649 return Result;
650 }
651 }
652
653 return 0;
654}
655
656// Return negative, zero, or positive, if LHS is less than, equal to, or greater
657// than RHS, respectively. A three-way result allows recursive comparisons to be
658// more efficient.
659// If the max analysis depth was reached, return std::nullopt, assuming we do
660// not know if they are equivalent for sure.
661static std::optional<int>
663 const LoopInfo *const LI, const SCEV *LHS,
664 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
665 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
666 if (LHS == RHS)
667 return 0;
668
669 // Primarily, sort the SCEVs by their getSCEVType().
670 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
671 if (LType != RType)
672 return (int)LType - (int)RType;
673
674 if (EqCacheSCEV.isEquivalent(LHS, RHS))
675 return 0;
676
677 if (Depth > MaxSCEVCompareDepth)
678 return std::nullopt;
679
680 // Aside from the getSCEVType() ordering, the particular ordering
681 // isn't very important except that it's beneficial to be consistent,
682 // so that (a + b) and (b + a) don't end up as different expressions.
683 switch (LType) {
684 case scUnknown: {
685 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
686 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
687
688 int X =
689 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
690 if (X == 0)
691 EqCacheSCEV.unionSets(LHS, RHS);
692 return X;
693 }
694
695 case scConstant: {
696 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
697 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
698
699 // Compare constant values.
700 const APInt &LA = LC->getAPInt();
701 const APInt &RA = RC->getAPInt();
702 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
703 if (LBitWidth != RBitWidth)
704 return (int)LBitWidth - (int)RBitWidth;
705 return LA.ult(RA) ? -1 : 1;
706 }
707
708 case scVScale: {
709 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
710 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
711 return LTy->getBitWidth() - RTy->getBitWidth();
712 }
713
714 case scAddRecExpr: {
715 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
716 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
717
718 // There is always a dominance between two recs that are used by one SCEV,
719 // so we can safely sort recs by loop header dominance. We require such
720 // order in getAddExpr.
721 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
722 if (LLoop != RLoop) {
723 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
724 assert(LHead != RHead && "Two loops share the same header?");
725 if (DT.dominates(LHead, RHead))
726 return 1;
727 assert(DT.dominates(RHead, LHead) &&
728 "No dominance between recurrences used by one SCEV?");
729 return -1;
730 }
731
732 [[fallthrough]];
733 }
734
735 case scTruncate:
736 case scZeroExtend:
737 case scSignExtend:
738 case scPtrToInt:
739 case scAddExpr:
740 case scMulExpr:
741 case scUDivExpr:
742 case scSMaxExpr:
743 case scUMaxExpr:
744 case scSMinExpr:
745 case scUMinExpr:
747 ArrayRef<const SCEV *> LOps = LHS->operands();
748 ArrayRef<const SCEV *> ROps = RHS->operands();
749
750 // Lexicographically compare n-ary-like expressions.
751 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
752 if (LNumOps != RNumOps)
753 return (int)LNumOps - (int)RNumOps;
754
755 for (unsigned i = 0; i != LNumOps; ++i) {
756 auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
757 Depth + 1);
758 if (X != 0)
759 return X;
760 }
761 EqCacheSCEV.unionSets(LHS, RHS);
762 return 0;
763 }
764
766 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
767 }
768 llvm_unreachable("Unknown SCEV kind!");
769}
770
771/// Given a list of SCEV objects, order them by their complexity, and group
772/// objects of the same complexity together by value. When this routine is
773/// finished, we know that any duplicates in the vector are consecutive and that
774/// complexity is monotonically increasing.
775///
776/// Note that we take special precautions to ensure that we get deterministic
777/// results from this routine. In other words, we don't want the results of
778/// this to depend on where the addresses of various SCEV objects happened to
779/// land in memory.
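///
/// As a sketch of the intended effect (operand names are hypothetical): an
/// operand list such as (%a, 2, %b, %a) is reordered so that the constant
/// comes first and equal SCEVs become adjacent, e.g. (2, %a, %a, %b), which
/// lets the callers below find duplicate operands with a single linear scan.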
780static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
781 LoopInfo *LI, DominatorTree &DT) {
782 if (Ops.size() < 2) return; // Noop
783
785
786 // Whether LHS has provably less complexity than RHS.
787 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
788 auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
789 return Complexity && *Complexity < 0;
790 };
791 if (Ops.size() == 2) {
792 // This is the common case, which also happens to be trivially simple.
793 // Special case it.
794 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
795 if (IsLessComplex(RHS, LHS))
796 std::swap(LHS, RHS);
797 return;
798 }
799
800 // Do the rough sort by complexity.
801 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
802 return IsLessComplex(LHS, RHS);
803 });
804
805 // Now that we are sorted by complexity, group elements of the same
806 // complexity. Note that this is, at worst, N^2, but the vector is likely to
807 // be extremely short in practice. Note that we take this approach because we
808 // do not want to depend on the addresses of the objects we are grouping.
809 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
810 const SCEV *S = Ops[i];
811 unsigned Complexity = S->getSCEVType();
812
813 // If there are any objects of the same complexity and same value as this
814 // one, group them.
815 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
816 if (Ops[j] == S) { // Found a duplicate.
817 // Move it to immediately after i'th element.
818 std::swap(Ops[i+1], Ops[j]);
819 ++i; // no need to rescan it.
820 if (i == e-2) return; // Done!
821 }
822 }
823 }
824}
825
826/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
827/// least HugeExprThreshold nodes).
829 return any_of(Ops, [](const SCEV *S) {
830 return S->getExpressionSize() >= HugeExprThreshold;
831 });
832}
833
834/// Performs a number of common optimizations on the passed \p Ops. If the
835/// whole expression reduces down to a single operand, it will be returned.
836///
837/// The following optimizations are performed:
838/// * Fold constants using the \p Fold function.
839/// * Remove identity constants satisfying \p IsIdentity.
840/// * If a constant satisfies \p IsAbsorber, return it.
841/// * Sort operands by complexity.
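///
/// For instance (illustrative only): for an add with operands (5, %x, 3) and
/// Fold being APInt addition, the two constants fold to 8; 8 is neither the
/// identity (0) nor an absorber for addition, so after sorting by complexity
/// it is re-inserted at the front, leaving (8, %x) for the caller to combine.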
842template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843static const SCEV *
845 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
846 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847 const SCEVConstant *Folded = nullptr;
848 for (unsigned Idx = 0; Idx < Ops.size();) {
849 const SCEV *Op = Ops[Idx];
850 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851 if (!Folded)
852 Folded = C;
853 else
854 Folded = cast<SCEVConstant>(
855 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856 Ops.erase(Ops.begin() + Idx);
857 continue;
858 }
859 ++Idx;
860 }
861
862 if (Ops.empty()) {
863 assert(Folded && "Must have folded value");
864 return Folded;
865 }
866
867 if (Folded && IsAbsorber(Folded->getAPInt()))
868 return Folded;
869
870 GroupByComplexity(Ops, &LI, DT);
871 if (Folded && !IsIdentity(Folded->getAPInt()))
872 Ops.insert(Ops.begin(), Folded);
873
874 return Ops.size() == 1 ? Ops[0] : nullptr;
875}
876
877//===----------------------------------------------------------------------===//
878// Simple SCEV method implementations
879//===----------------------------------------------------------------------===//
880
881/// Compute BC(It, K). The result has width W. Assume, K > 0.
882static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
883 ScalarEvolution &SE,
884 Type *ResultTy) {
885 // Handle the simplest case efficiently.
886 if (K == 1)
887 return SE.getTruncateOrZeroExtend(It, ResultTy);
888
889 // We are using the following formula for BC(It, K):
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
892 //
893 // Suppose, W is the bitwidth of the return value. We must be prepared for
894 // overflow. Hence, we must assure that the result of our computation is
895 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
896 // safe in modular arithmetic.
897 //
898 // However, this code doesn't use exactly that formula; the formula it uses
899 // is something like the following, where T is the number of factors of 2 in
900 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
901 // exponentiation:
902 //
903 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
904 //
905 // This formula is trivially equivalent to the previous formula. However,
906 // this formula can be implemented much more efficiently. The trick is that
907 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
908 // arithmetic. To do exact division in modular arithmetic, all we have
909 // to do is multiply by the inverse. Therefore, this step can be done at
910 // width W.
911 //
912 // The next issue is how to safely do the division by 2^T. The way this
913 // is done is by doing the multiplication step at a width of at least W + T
914 // bits. This way, the bottom W+T bits of the product are accurate. Then,
915 // when we perform the division by 2^T (which is equivalent to a right shift
916 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
917 // truncated out after the division by 2^T.
918 //
919 // In comparison to just directly using the first formula, this technique
920 // is much more efficient; using the first formula requires W * K bits,
921// but this formula requires less than W + K bits. Also, the first formula requires
922 // a division step, whereas this formula only requires multiplies and shifts.
923 //
924 // It doesn't matter whether the subtraction step is done in the calculation
925 // width or the input iteration count's width; if the subtraction overflows,
926 // the result must be zero anyway. We prefer here to do it in the width of
927 // the induction variable because it helps a lot for certain cases; CodeGen
928 // isn't smart enough to ignore the overflow, which leads to much less
929 // efficient code if the width of the subtraction is wider than the native
930 // register width.
931 //
932 // (It's possible to not widen at all by pulling out factors of 2 before
933 // the multiplication; for example, K=2 can be calculated as
934 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
935 // extra arithmetic, so it's not an obvious win, and it gets
936 // much more complicated for K > 3.)
937
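  // A small worked instance of the scheme above (plain numbers, only for
  // illustration): for K = 4, K! = 24 = 2^3 * 3, so T = 3 and the odd factor
  // is 3. The product It*(It-1)*(It-2)*(It-3) is formed at width W + 3,
  // divided by 2^3, truncated back to W bits, and finally multiplied by the
  // multiplicative inverse of 3 modulo 2^W, which performs the exact division
  // by the odd part of K!.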
938 // Protection from insane SCEVs; this bound is conservative,
939 // but it probably doesn't matter.
940 if (K > 1000)
941 return SE.getCouldNotCompute();
942
943 unsigned W = SE.getTypeSizeInBits(ResultTy);
944
945 // Calculate K! / 2^T and T; we divide out the factors of two before
946 // multiplying for calculating K! / 2^T to avoid overflow.
947 // Other overflow doesn't matter because we only care about the bottom
948 // W bits of the result.
949 APInt OddFactorial(W, 1);
950 unsigned T = 1;
951 for (unsigned i = 3; i <= K; ++i) {
952 unsigned TwoFactors = countr_zero(i);
953 T += TwoFactors;
954 OddFactorial *= (i >> TwoFactors);
955 }
956
957 // We need at least W + T bits for the multiplication step
958 unsigned CalculationBits = W + T;
959
960 // Calculate 2^T, at width T+W.
961 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
962
963 // Calculate the multiplicative inverse of K! / 2^T;
964 // this multiplication factor will perform the exact division by
965 // K! / 2^T.
966 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
967
968 // Calculate the product, at width T+W
969 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
970 CalculationBits);
971 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
972 for (unsigned i = 1; i != K; ++i) {
973 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
974 Dividend = SE.getMulExpr(Dividend,
975 SE.getTruncateOrZeroExtend(S, CalculationTy));
976 }
977
978 // Divide by 2^T
979 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
980
981 // Truncate the result, and divide by K! / 2^T.
982
983 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
984 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
985}
986
987/// Return the value of this chain of recurrences at the specified iteration
988/// number. We can evaluate this recurrence by multiplying each element in the
989/// chain by the binomial coefficient corresponding to it. In other words, we
990/// can evaluate {A,+,B,+,C,+,D} as:
991///
992/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
993///
994/// where BC(It, k) stands for binomial coefficient.
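///
/// For example (a purely numeric illustration): evaluating {5,+,3,+,2} at
/// iteration It gives 5*BC(It,0) + 3*BC(It,1) + 2*BC(It,2)
/// = 5 + 3*It + 2*(It*(It-1)/2) = It^2 + 2*It + 5.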
996 ScalarEvolution &SE) const {
997 return evaluateAtIteration(operands(), It, SE);
998}
999
1000const SCEV *
1002 const SCEV *It, ScalarEvolution &SE) {
1003 assert(Operands.size() > 0);
1004 const SCEV *Result = Operands[0];
1005 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1006 // The computation is correct in the face of overflow provided that the
1007 // multiplication is performed _after_ the evaluation of the binomial
1008 // coefficient.
1009 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1010 if (isa<SCEVCouldNotCompute>(Coeff))
1011 return Coeff;
1012
1013 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1014 }
1015 return Result;
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// SCEV Expression folder implementations
1020//===----------------------------------------------------------------------===//
1021
1023 unsigned Depth) {
1024 assert(Depth <= 1 &&
1025 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1026
1027 // We could be called with an integer-typed operand during SCEV rewrites.
1028 // Since the operand is an integer already, just perform zext/trunc/self cast.
1029 if (!Op->getType()->isPointerTy())
1030 return Op;
1031
1032 // What would be an ID for such a SCEV cast expression?
1034 ID.AddInteger(scPtrToInt);
1035 ID.AddPointer(Op);
1036
1037 void *IP = nullptr;
1038
1039 // Is there already an expression for such a cast?
1040 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1041 return S;
1042
1043 // It isn't legal for optimizations to construct new ptrtoint expressions
1044 // for non-integral pointers.
1045 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1046 return getCouldNotCompute();
1047
1048 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1049
1050 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1051 // is sufficiently wide to represent all possible pointer values.
1052 // We could theoretically teach SCEV to truncate wider pointers, but
1053 // that isn't implemented for now.
1055 getDataLayout().getTypeSizeInBits(IntPtrTy))
1056 return getCouldNotCompute();
1057
1058 // If not, is this expression something we can't reduce any further?
1059 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1060 // Perform some basic constant folding. If the operand of the ptr2int cast
1061 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1062 // left as-is), but produce a zero constant.
1063 // NOTE: We could handle a more general case, but lack motivational cases.
1064 if (isa<ConstantPointerNull>(U->getValue()))
1065 return getZero(IntPtrTy);
1066
1067 // Create an explicit cast node.
1068 // We can reuse the existing insert position since if we get here,
1069 // we won't have made any changes which would invalidate it.
1070 SCEV *S = new (SCEVAllocator)
1071 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1072 UniqueSCEVs.InsertNode(S, IP);
1073 registerUser(S, Op);
1074 return S;
1075 }
1076
1077 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1078 "non-SCEVUnknown's.");
1079
1080 // Otherwise, we've got some expression that is more complex than just a
1081 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1082 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1083 // only, and the expressions must otherwise be integer-typed.
1084 // So sink the cast down to the SCEVUnknown's.
1085
1086 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1087 /// which computes a pointer-typed value, and rewrites the whole expression
1088 /// tree so that *all* the computations are done on integers, and the only
1089 /// pointer-typed operands in the expression are SCEVUnknown.
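  ///
  /// For instance (a hypothetical expression): applying this to
  /// ptrtoint(%base + 4 * %i), where %base is pointer-typed and %i is an
  /// integer, yields (ptrtoint %base) + 4 * %i: only the SCEVUnknown %base
  /// keeps a ptrtoint around it, while the add and mul are rebuilt over
  /// integers.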
1090 class SCEVPtrToIntSinkingRewriter
1091 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1093
1094 public:
1095 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1096
1097 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1098 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1099 return Rewriter.visit(Scev);
1100 }
1101
1102 const SCEV *visit(const SCEV *S) {
1103 Type *STy = S->getType();
1104 // If the expression is not pointer-typed, just keep it as-is.
1105 if (!STy->isPointerTy())
1106 return S;
1107 // Else, recursively sink the cast down into it.
1108 return Base::visit(S);
1109 }
1110
1111 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1113 bool Changed = false;
1114 for (const auto *Op : Expr->operands()) {
1115 Operands.push_back(visit(Op));
1116 Changed |= Op != Operands.back();
1117 }
1118 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1119 }
1120
1121 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1123 bool Changed = false;
1124 for (const auto *Op : Expr->operands()) {
1125 Operands.push_back(visit(Op));
1126 Changed |= Op != Operands.back();
1127 }
1128 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1129 }
1130
1131 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1132 assert(Expr->getType()->isPointerTy() &&
1133 "Should only reach pointer-typed SCEVUnknown's.");
1134 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1135 }
1136 };
1137
1138 // And actually perform the cast sinking.
1139 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1140 assert(IntOp->getType()->isIntegerTy() &&
1141 "We must have succeeded in sinking the cast, "
1142 "and ending up with an integer-typed expression!");
1143 return IntOp;
1144}
1145
1147 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1148
1149 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1150 if (isa<SCEVCouldNotCompute>(IntOp))
1151 return IntOp;
1152
1153 return getTruncateOrZeroExtend(IntOp, Ty);
1154}
1155
1157 unsigned Depth) {
1158 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1159 "This is not a truncating conversion!");
1160 assert(isSCEVable(Ty) &&
1161 "This is not a conversion to a SCEVable type!");
1162 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1163 Ty = getEffectiveSCEVType(Ty);
1164
1166 ID.AddInteger(scTruncate);
1167 ID.AddPointer(Op);
1168 ID.AddPointer(Ty);
1169 void *IP = nullptr;
1170 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1171
1172 // Fold if the operand is constant.
1173 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1174 return getConstant(
1175 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1176
1177 // trunc(trunc(x)) --> trunc(x)
1178 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1179 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1180
1181 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1182 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1183 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1184
1185 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1186 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1187 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1188
1189 if (Depth > MaxCastDepth) {
1190 SCEV *S =
1191 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1192 UniqueSCEVs.InsertNode(S, IP);
1193 registerUser(S, Op);
1194 return S;
1195 }
1196
1197 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1198 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1199 // if after transforming we have at most one truncate, not counting truncates
1200 // that replace other casts.
1201 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1202 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1204 unsigned numTruncs = 0;
1205 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1206 ++i) {
1207 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1208 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1209 isa<SCEVTruncateExpr>(S))
1210 numTruncs++;
1211 Operands.push_back(S);
1212 }
1213 if (numTruncs < 2) {
1214 if (isa<SCEVAddExpr>(Op))
1215 return getAddExpr(Operands);
1216 if (isa<SCEVMulExpr>(Op))
1217 return getMulExpr(Operands);
1218 llvm_unreachable("Unexpected SCEV type for Op.");
1219 }
1220 // Although we checked in the beginning that ID is not in the cache, it is
1221 // possible that during recursion and other modifications this ID was inserted
1222 // into the cache. So if we find it, just return it.
1223 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1224 return S;
1225 }
1226
1227 // If the input value is a chrec scev, truncate the chrec's operands.
1228 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1230 for (const SCEV *Op : AddRec->operands())
1231 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1232 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1233 }
1234
1235 // Return zero if truncating to known zeros.
1236 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1237 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1238 return getZero(Ty);
1239
1240 // The cast wasn't folded; create an explicit cast node. We can reuse
1241 // the existing insert position since if we get here, we won't have
1242 // made any changes which would invalidate it.
1243 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1244 Op, Ty);
1245 UniqueSCEVs.InsertNode(S, IP);
1246 registerUser(S, Op);
1247 return S;
1248}
1249
1250// Get the limit of a recurrence such that incrementing by Step cannot cause
1251// signed overflow as long as the value of the recurrence within the
1252// loop does not exceed this limit before incrementing.
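//
// For example (numbers are only an illustration): for an i8 recurrence whose
// step is known to lie in [1, 3], the limit returned is SINT_MAX - 3 = 124
// with predicate ICMP_SLT: while the recurrence value is slt 124 before the
// increment, adding the step cannot sign-overflow.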
1253static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1254 ICmpInst::Predicate *Pred,
1255 ScalarEvolution *SE) {
1256 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1257 if (SE->isKnownPositive(Step)) {
1258 *Pred = ICmpInst::ICMP_SLT;
1260 SE->getSignedRangeMax(Step));
1261 }
1262 if (SE->isKnownNegative(Step)) {
1263 *Pred = ICmpInst::ICMP_SGT;
1265 SE->getSignedRangeMin(Step));
1266 }
1267 return nullptr;
1268}
1269
1270// Get the limit of a recurrence such that incrementing by Step cannot cause
1271// unsigned overflow as long as the value of the recurrence within the loop does
1272// not exceed this limit before incrementing.
1274 ICmpInst::Predicate *Pred,
1275 ScalarEvolution *SE) {
1276 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1277 *Pred = ICmpInst::ICMP_ULT;
1278
1280 SE->getUnsignedRangeMax(Step));
1281}
1282
1283namespace {
1284
1285struct ExtendOpTraitsBase {
1286 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1287 unsigned);
1288};
1289
1290// Used to make code generic over signed and unsigned overflow.
1291template <typename ExtendOp> struct ExtendOpTraits {
1292 // Members present:
1293 //
1294 // static const SCEV::NoWrapFlags WrapType;
1295 //
1296 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1297 //
1298 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1299 // ICmpInst::Predicate *Pred,
1300 // ScalarEvolution *SE);
1301};
1302
1303template <>
1304struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1305 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1306
1307 static const GetExtendExprTy GetExtendExpr;
1308
1309 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1310 ICmpInst::Predicate *Pred,
1311 ScalarEvolution *SE) {
1312 return getSignedOverflowLimitForStep(Step, Pred, SE);
1313 }
1314};
1315
1316const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318
1319template <>
1320struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1321 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1322
1323 static const GetExtendExprTy GetExtendExpr;
1324
1325 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1326 ICmpInst::Predicate *Pred,
1327 ScalarEvolution *SE) {
1328 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1329 }
1330};
1331
1332const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1334
1335} // end anonymous namespace
1336
1337// The recurrence AR has been shown to have no signed/unsigned wrap or something
1338// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1339// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1340// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1341// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1342// expression "Step + sext/zext(PreIncAR)" is congruent with
1343// "sext/zext(PostIncAR)"
1344template <typename ExtendOpTy>
1345static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1346 ScalarEvolution *SE, unsigned Depth) {
1347 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1348 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1349
1350 const Loop *L = AR->getLoop();
1351 const SCEV *Start = AR->getStart();
1352 const SCEV *Step = AR->getStepRecurrence(*SE);
1353
1354 // Check for a simple looking step prior to loop entry.
1355 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1356 if (!SA)
1357 return nullptr;
1358
1359 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1360 // subtraction is expensive. For this purpose, perform a quick and dirty
1361 // difference, by checking for Step in the operand list. Note, that
1362 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1363 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1364 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1365 if (*It == Step) {
1366 DiffOps.erase(It);
1367 break;
1368 }
1369
1370 if (DiffOps.size() == SA->getNumOperands())
1371 return nullptr;
1372
1373 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1374 // `Step`:
1375
1376 // 1. NSW/NUW flags on the step increment.
1377 auto PreStartFlags =
1379 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1380 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1381 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1382
1383 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1384 // "S+X does not sign/unsign-overflow".
1385 //
1386
1387 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1388 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1389 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1390 return PreStart;
1391
1392 // 2. Direct overflow check on the step operation's expression.
1393 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1394 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1395 const SCEV *OperandExtendedStart =
1396 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1397 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1398 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1399 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1400 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1401 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1402 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1403 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1404 }
1405 return PreStart;
1406 }
1407
1408 // 3. Loop precondition.
1409 ICmpInst::Predicate Pred;
1410 const SCEV *OverflowLimit =
1411 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1412
1413 if (OverflowLimit &&
1414 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1415 return PreStart;
1416
1417 return nullptr;
1418}
1419
1420// Get the normalized zero or sign extended expression for this AddRec's Start.
1421template <typename ExtendOpTy>
1422static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1423 ScalarEvolution *SE,
1424 unsigned Depth) {
1425 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1426
1427 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1428 if (!PreStart)
1429 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1430
1431 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1432 Depth),
1433 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1434}
1435
1436// Try to prove away overflow by looking at "nearby" add recurrences. A
1437// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1438// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1439//
1440// Formally:
1441//
1442// {S,+,X} == {S-T,+,X} + T
1443// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1444//
1445// If ({S-T,+,X} + T) does not overflow ... (1)
1446//
1447// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1448//
1449// If {S-T,+,X} does not overflow ... (2)
1450//
1451// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1452// == {Ext(S-T)+Ext(T),+,Ext(X)}
1453//
1454// If (S-T)+T does not overflow ... (3)
1455//
1456// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1457// == {Ext(S),+,Ext(X)} == LHS
1458//
1459// Thus, if (1), (2) and (3) are true for some T, then
1460// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1461//
1462// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1463// does not overflow" restricted to the 0th iteration. Therefore we only need
1464// to check for (1) and (2).
1465//
1466// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1467// is `Delta` (defined below).
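//
// A concrete instance of the rule (values chosen only for illustration):
// knowing that {0,+,4} is <nuw> and that {0,+,4} + 1 cannot unsigned-wrap,
// conditions (1) and (2) hold with S = 1, X = 4, T = 1, and therefore
// zext({1,+,4}) == {zext(1),+,zext(4)}.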
1468template <typename ExtendOpTy>
1469bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1470 const SCEV *Step,
1471 const Loop *L) {
1472 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1473
1474 // We restrict `Start` to a constant to prevent SCEV from spending too much
1475 // time here. It is correct (but more expensive) to continue with a
1476 // non-constant `Start` and do a general SCEV subtraction to compute
1477 // `PreStart` below.
1478 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1479 if (!StartC)
1480 return false;
1481
1482 APInt StartAI = StartC->getAPInt();
1483
1484 for (unsigned Delta : {-2, -1, 1, 2}) {
1485 const SCEV *PreStart = getConstant(StartAI - Delta);
1486
1488 ID.AddInteger(scAddRecExpr);
1489 ID.AddPointer(PreStart);
1490 ID.AddPointer(Step);
1491 ID.AddPointer(L);
1492 void *IP = nullptr;
1493 const auto *PreAR =
1494 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1495
1496 // Give up if we don't already have the add recurrence we need because
1497 // actually constructing an add recurrence is relatively expensive.
1498 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1499 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1501 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1502 DeltaS, &Pred, this);
1503 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1504 return true;
1505 }
1506 }
1507
1508 return false;
1509}
1510
1511// Finds an integer D for an expression (C + x + y + ...) such that the top
1512// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1513// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1514// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1515// the (C + x + y + ...) expression is \p WholeAddExpr.
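//
// A worked example (numbers chosen for illustration): with C = 13 (0b1101)
// and the remaining operands x, y, ... all having at least two trailing zero
// bits, TZ = 2 and D = C mod 2^TZ = 1. Then C - D = 12 keeps the two trailing
// zeros, and adding D back produces no carries at all, so the top-level
// addition cannot wrap, signed or unsigned.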
1517 const SCEVConstant *ConstantTerm,
1518 const SCEVAddExpr *WholeAddExpr) {
1519 const APInt &C = ConstantTerm->getAPInt();
1520 const unsigned BitWidth = C.getBitWidth();
1521 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1522 uint32_t TZ = BitWidth;
1523 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1524 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1525 if (TZ) {
1526 // Set D to be as many least significant bits of C as possible while still
1527 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1528 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1529 }
1530 return APInt(BitWidth, 0);
1531}
1532
1533// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1534// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1535// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1536// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1538 const APInt &ConstantStart,
1539 const SCEV *Step) {
1540 const unsigned BitWidth = ConstantStart.getBitWidth();
1541 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1542 if (TZ)
1543 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1544 : ConstantStart;
1545 return APInt(BitWidth, 0);
1546}
1547
1549 const ScalarEvolution::FoldID &ID, const SCEV *S,
1552 &FoldCacheUser) {
1553 auto I = FoldCache.insert({ID, S});
1554 if (!I.second) {
1555 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1556 // entry.
1557 auto &UserIDs = FoldCacheUser[I.first->second];
1558 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1559 for (unsigned I = 0; I != UserIDs.size(); ++I)
1560 if (UserIDs[I] == ID) {
1561 std::swap(UserIDs[I], UserIDs.back());
1562 break;
1563 }
1564 UserIDs.pop_back();
1565 I.first->second = S;
1566 }
1567 auto R = FoldCacheUser.insert({S, {}});
1568 R.first->second.push_back(ID);
1569}
1570
1571const SCEV *
1573 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1574 "This is not an extending conversion!");
1575 assert(isSCEVable(Ty) &&
1576 "This is not a conversion to a SCEVable type!");
1577 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1578 Ty = getEffectiveSCEVType(Ty);
1579
1580 FoldID ID(scZeroExtend, Op, Ty);
1581 auto Iter = FoldCache.find(ID);
1582 if (Iter != FoldCache.end())
1583 return Iter->second;
1584
1585 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1586 if (!isa<SCEVZeroExtendExpr>(S))
1587 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1588 return S;
1589}
1590
1592 unsigned Depth) {
1593 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1594 "This is not an extending conversion!");
1595 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1596 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1597
1598 // Fold if the operand is constant.
1599 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1600 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1601
1602 // zext(zext(x)) --> zext(x)
1603 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1604 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1605
1606 // Before doing any expensive analysis, check to see if we've already
1607 // computed a SCEV for this Op and Ty.
1609 ID.AddInteger(scZeroExtend);
1610 ID.AddPointer(Op);
1611 ID.AddPointer(Ty);
1612 void *IP = nullptr;
1613 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1614 if (Depth > MaxCastDepth) {
1615 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1616 Op, Ty);
1617 UniqueSCEVs.InsertNode(S, IP);
1618 registerUser(S, Op);
1619 return S;
1620 }
1621
1622 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1623 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1624 // It's possible the bits taken off by the truncate were all zero bits. If
1625 // so, we should be able to simplify this further.
1626 const SCEV *X = ST->getOperand();
1628 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1629 unsigned NewBits = getTypeSizeInBits(Ty);
1630 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1631 CR.zextOrTrunc(NewBits)))
1632 return getTruncateOrZeroExtend(X, Ty, Depth);
1633 }
1634
1635 // If the input value is a chrec scev, and we can prove that the value
1636 // did not overflow the old, smaller, value, we can zero extend all of the
1637 // operands (often constants). This allows analysis of something like
1638 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1639 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1640 if (AR->isAffine()) {
1641 const SCEV *Start = AR->getStart();
1642 const SCEV *Step = AR->getStepRecurrence(*this);
1643 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1644 const Loop *L = AR->getLoop();
1645
1646 // If we have special knowledge that this addrec won't overflow,
1647 // we don't need to do any further analysis.
1648 if (AR->hasNoUnsignedWrap()) {
1649 Start =
1650 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1651 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1652 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1653 }
1654
1655 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1656 // Note that this serves two purposes: It filters out loops that are
1657 // simply not analyzable, and it covers the case where this code is
1658 // being called from within backedge-taken count analysis, such that
1659 // attempting to ask for the backedge-taken count would likely result
1660 // in infinite recursion. In the latter case, the analysis code will
1661 // cope with a conservative value, and it will take care to purge
1662 // that value once it has finished.
1663 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1664 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1665 // Manually compute the final value for AR, checking for overflow.
1666
1667 // Check whether the backedge-taken count can be losslessly casted to
1668 // the addrec's type. The count is always unsigned.
1669 const SCEV *CastedMaxBECount =
1670 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1671 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1672 CastedMaxBECount, MaxBECount->getType(), Depth);
1673 if (MaxBECount == RecastedMaxBECount) {
1674 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1675 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1676 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1677 SCEV::FlagAnyWrap, Depth + 1);
1678 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1679 SCEV::FlagAnyWrap,
1680 Depth + 1),
1681 WideTy, Depth + 1);
1682 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1683 const SCEV *WideMaxBECount =
1684 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1685 const SCEV *OperandExtendedAdd =
1686 getAddExpr(WideStart,
1687 getMulExpr(WideMaxBECount,
1688 getZeroExtendExpr(Step, WideTy, Depth + 1),
1689 SCEV::FlagAnyWrap, Depth + 1),
1690 SCEV::FlagAnyWrap, Depth + 1);
1691 if (ZAdd == OperandExtendedAdd) {
1692 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1693 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1694 // Return the expression with the addrec on the outside.
1695 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1696 Depth + 1);
1697 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1698 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1699 }
1700 // Similar to above, only this time treat the step value as signed.
1701 // This covers loops that count down.
1702 OperandExtendedAdd =
1703 getAddExpr(WideStart,
1704 getMulExpr(WideMaxBECount,
1705 getSignExtendExpr(Step, WideTy, Depth + 1),
1706 SCEV::FlagAnyWrap, Depth + 1),
1707 SCEV::FlagAnyWrap, Depth + 1);
1708 if (ZAdd == OperandExtendedAdd) {
1709 // Cache knowledge of AR NW, which is propagated to this AddRec.
1710 // Negative step causes unsigned wrap, but it still can't self-wrap.
1711 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1712 // Return the expression with the addrec on the outside.
1713 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1714 Depth + 1);
1715 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1716 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1717 }
1718 }
1719 }
1720
1721 // Normally, in the cases we can prove no-overflow via a
1722 // backedge guarding condition, we can also compute a backedge
1723 // taken count for the loop. The exceptions are assumptions and
1724 // guards present in the loop -- SCEV is not great at exploiting
1725 // these to compute max backedge taken counts, but can still use
1726 // these to prove lack of overflow. Use this fact to avoid
1727 // doing extra work that may not pay off.
1728 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1729 !AC.assumptions().empty()) {
1730
1731 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1732 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1733 if (AR->hasNoUnsignedWrap()) {
1734 // Same as nuw case above - duplicated here to avoid a compile time
1735 // issue. It's not clear that the order of checks matters, but
1736 // it's one of two possible causes of a change which was
1737 // reverted. Be conservative for the moment.
1738 Start =
1739 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1740 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1741 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1742 }
1743
1744 // For a negative step, we can extend the operands iff doing so only
1745 // traverses values in the range zext([0,UINT_MAX]).
1746 if (isKnownNegative(Step)) {
1747 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1748 getSignedRangeMin(Step));
1749 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1750 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1751 // Cache knowledge of AR NW, which is propagated to this
1752 // AddRec. Negative step causes unsigned wrap, but it
1753 // still can't self-wrap.
1754 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1755 // Return the expression with the addrec on the outside.
1756 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1757 Depth + 1);
1758 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1759 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1760 }
1761 }
1762 }
1763
1764 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1765 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1766 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1767 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1768 const APInt &C = SC->getAPInt();
1769 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1770 if (D != 0) {
1771 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1772 const SCEV *SResidual =
1773 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1774 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1775 return getAddExpr(SZExtD, SZExtR,
1776 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1777 Depth + 1);
1778 }
1779 }
1780
1781 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1782 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1783 Start =
1784 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1785 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1786 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1787 }
1788 }
1789
1790 // zext(A % B) --> zext(A) % zext(B)
1791 {
1792 const SCEV *LHS;
1793 const SCEV *RHS;
1794 if (matchURem(Op, LHS, RHS))
1795 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1796 getZeroExtendExpr(RHS, Ty, Depth + 1));
1797 }
1798
1799 // zext(A / B) --> zext(A) / zext(B).
1800 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1801 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1802 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1803
1804 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1805 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1806 if (SA->hasNoUnsignedWrap()) {
1807 // If the addition cannot overflow in the unsigned sense then, by definition,
1808 // we can commute the zero extension with the addition operation.
1809 SmallVector<const SCEV *, 4> Ops;
1810 for (const auto *Op : SA->operands())
1811 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1812 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1813 }
1814
1815 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1816 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1817 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1818 //
1819 // Often address arithmetics contain expressions like
1820 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1821 // This transformation is useful while proving that such expressions are
1822 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
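// In the (zext (5 + (4 * X))) example above, every term but the constant is a
// multiple of 4, so D = 1 here and the expression is rewritten as
// zext(1) + zext(4 + (4 * X)), i.e. 1 + zext(4 + (4 * X)).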
1823 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1824 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1825 if (D != 0) {
1826 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1827 const SCEV *SResidual =
1828 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1829 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1830 return getAddExpr(SZExtD, SZExtR,
1831 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1832 Depth + 1);
1833 }
1834 }
1835 }
1836
1837 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1838 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1839 if (SM->hasNoUnsignedWrap()) {
1840 // If the multiply cannot overflow in the unsigned sense then, by definition,
1841 // we can commute the zero extension with the multiply operation.
1842 SmallVector<const SCEV *, 4> Ops;
1843 for (const auto *Op : SM->operands())
1844 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1845 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1846 }
1847
1848 // zext(2^K * (trunc X to iN)) to iM ->
1849 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1850 //
1851 // Proof:
1852 //
1853 // zext(2^K * (trunc X to iN)) to iM
1854 // = zext((trunc X to iN) << K) to iM
1855 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1856 // (because shl removes the top K bits)
1857 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1858 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1859 //
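// Concretely, zext(4 * (trunc X to i8)) to i32 becomes
// 4 * (zext(trunc X to i6) to i32): multiplying by 4 = 2^2 shifts the top two
// bits of the i8 truncate out of range, so only the low 6 bits of X matter.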
1860 if (SM->getNumOperands() == 2)
1861 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1862 if (MulLHS->getAPInt().isPowerOf2())
1863 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1864 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1865 MulLHS->getAPInt().logBase2();
1866 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1867 return getMulExpr(
1868 getZeroExtendExpr(MulLHS, Ty),
1869 getZeroExtendExpr(
1870 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1871 SCEV::FlagNUW, Depth + 1);
1872 }
1873 }
1874
1875 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1876 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1877 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1878 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1879 SmallVector<const SCEV *, 4> Operands;
1880 for (auto *Operand : MinMax->operands())
1881 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1882 if (isa<SCEVUMinExpr>(MinMax))
1883 return getUMinExpr(Operands);
1884 return getUMaxExpr(Operands);
1885 }
1886
1887 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1888 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1889 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1890 SmallVector<const SCEV *, 4> Operands;
1891 for (auto *Operand : MinMax->operands())
1892 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1893 return getUMinExpr(Operands, /*Sequential*/ true);
1894 }
1895
1896 // The cast wasn't folded; create an explicit cast node.
1897 // Recompute the insert position, as it may have been invalidated.
1898 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1899 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1900 Op, Ty);
1901 UniqueSCEVs.InsertNode(S, IP);
1902 registerUser(S, Op);
1903 return S;
1904}
1905
1906const SCEV *
1907ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1908 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1909 "This is not an extending conversion!");
1910 assert(isSCEVable(Ty) &&
1911 "This is not a conversion to a SCEVable type!");
1912 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1913 Ty = getEffectiveSCEVType(Ty);
1914
1915 FoldID ID(scSignExtend, Op, Ty);
1916 auto Iter = FoldCache.find(ID);
1917 if (Iter != FoldCache.end())
1918 return Iter->second;
1919
1920 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1921 if (!isa<SCEVSignExtendExpr>(S))
1922 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1923 return S;
1924}
1925
1926const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1927 unsigned Depth) {
1928 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1929 "This is not an extending conversion!");
1930 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1931 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1932 Ty = getEffectiveSCEVType(Ty);
1933
1934 // Fold if the operand is constant.
1935 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1936 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1937
1938 // sext(sext(x)) --> sext(x)
1939 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1940 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1941
1942 // sext(zext(x)) --> zext(x)
1943 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1944 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1945
1946 // Before doing any expensive analysis, check to see if we've already
1947 // computed a SCEV for this Op and Ty.
1948 FoldingSetNodeID ID;
1949 ID.AddInteger(scSignExtend);
1950 ID.AddPointer(Op);
1951 ID.AddPointer(Ty);
1952 void *IP = nullptr;
1953 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1954 // Limit recursion depth.
1955 if (Depth > MaxCastDepth) {
1956 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1957 Op, Ty);
1958 UniqueSCEVs.InsertNode(S, IP);
1959 registerUser(S, Op);
1960 return S;
1961 }
1962
1963 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1964 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1965 // It's possible the bits taken off by the truncate were all sign bits. If
1966 // so, we should be able to simplify this further.
1967 const SCEV *X = ST->getOperand();
1968 ConstantRange CR = getSignedRange(X);
1969 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1970 unsigned NewBits = getTypeSizeInBits(Ty);
1971 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1972 CR.sextOrTrunc(NewBits)))
1973 return getTruncateOrSignExtend(X, Ty, Depth);
1974 }
1975
1976 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1977 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1978 if (SA->hasNoSignedWrap()) {
1979 // If the addition cannot overflow in the signed sense then, by definition,
1980 // we can commute the sign extension with the addition operation.
1981 SmallVector<const SCEV *, 4> Ops;
1982 for (const auto *Op : SA->operands())
1983 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1984 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1985 }
1986
1987 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1988 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1989 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1990 //
1991 // For instance, this will bring two seemingly different expressions:
1992 // 1 + sext(5 + 20 * %x + 24 * %y) and
1993 // sext(6 + 20 * %x + 24 * %y)
1994 // to the same form:
1995 // 2 + sext(4 + 20 * %x + 24 * %y)
1996 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1997 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1998 if (D != 0) {
1999 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2000 const SCEV *SResidual =
2001 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
2002 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2003 return getAddExpr(SSExtD, SSExtR,
2004 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2005 Depth + 1);
2006 }
2007 }
2008 }
2009 // If the input value is a chrec scev, and we can prove that the value
2010 // did not overflow the old, smaller, value, we can sign extend all of the
2011 // operands (often constants). This allows analysis of something like
2012 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2013 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
2014 if (AR->isAffine()) {
2015 const SCEV *Start = AR->getStart();
2016 const SCEV *Step = AR->getStepRecurrence(*this);
2017 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2018 const Loop *L = AR->getLoop();
2019
2020 // If we have special knowledge that this addrec won't overflow,
2021 // we don't need to do any further analysis.
2022 if (AR->hasNoSignedWrap()) {
2023 Start =
2024 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2025 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2026 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2027 }
2028
2029 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2030 // Note that this serves two purposes: It filters out loops that are
2031 // simply not analyzable, and it covers the case where this code is
2032 // being called from within backedge-taken count analysis, such that
2033 // attempting to ask for the backedge-taken count would likely result
2034 // in infinite recursion. In the latter case, the analysis code will
2035 // cope with a conservative value, and it will take care to purge
2036 // that value once it has finished.
2037 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2038 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2039 // Manually compute the final value for AR, checking for
2040 // overflow.
2041
2042 // Check whether the backedge-taken count can be losslessly casted to
2043 // the addrec's type. The count is always unsigned.
2044 const SCEV *CastedMaxBECount =
2045 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2046 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2047 CastedMaxBECount, MaxBECount->getType(), Depth);
2048 if (MaxBECount == RecastedMaxBECount) {
2049 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2050 // Check whether Start+Step*MaxBECount has no signed overflow.
2051 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2052 SCEV::FlagAnyWrap, Depth + 1);
2053 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2054 SCEV::FlagAnyWrap,
2055 Depth + 1),
2056 WideTy, Depth + 1);
2057 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2058 const SCEV *WideMaxBECount =
2059 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2060 const SCEV *OperandExtendedAdd =
2061 getAddExpr(WideStart,
2062 getMulExpr(WideMaxBECount,
2063 getSignExtendExpr(Step, WideTy, Depth + 1),
2064 SCEV::FlagAnyWrap, Depth + 1),
2065 SCEV::FlagAnyWrap, Depth + 1);
2066 if (SAdd == OperandExtendedAdd) {
2067 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2068 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2069 // Return the expression with the addrec on the outside.
2070 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2071 Depth + 1);
2072 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2073 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2074 }
2075 // Similar to above, only this time treat the step value as unsigned.
2076 // This covers loops that count up with an unsigned step.
2077 OperandExtendedAdd =
2078 getAddExpr(WideStart,
2079 getMulExpr(WideMaxBECount,
2080 getZeroExtendExpr(Step, WideTy, Depth + 1),
2081 SCEV::FlagAnyWrap, Depth + 1),
2082 SCEV::FlagAnyWrap, Depth + 1);
2083 if (SAdd == OperandExtendedAdd) {
2084 // If AR wraps around then
2085 //
2086 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2087 // => SAdd != OperandExtendedAdd
2088 //
2089 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2090 // (SAdd == OperandExtendedAdd => AR is NW)
2091
2092 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2093
2094 // Return the expression with the addrec on the outside.
2095 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2096 Depth + 1);
2097 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2098 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2099 }
2100 }
2101 }
2102
2103 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2104 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2105 if (AR->hasNoSignedWrap()) {
2106 // Same as nsw case above - duplicated here to avoid a compile time
2107 // issue. It's not clear that the order of checks matters, but
2108 // it's one of two possible causes of a change which was
2109 // reverted. Be conservative for the moment.
2110 Start =
2111 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2112 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2113 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2114 }
2115
2116 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2117 // if D + (C - D + Step * n) could be proven to not signed wrap
2118 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2119 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2120 const APInt &C = SC->getAPInt();
2121 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2122 if (D != 0) {
2123 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2124 const SCEV *SResidual =
2125 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2126 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2127 return getAddExpr(SSExtD, SSExtR,
2128 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2129 Depth + 1);
2130 }
2131 }
2132
2133 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2134 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2135 Start =
2136 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2137 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2138 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2139 }
2140 }
2141
2142 // If the input value is provably positive and we could not simplify
2143 // away the sext, build a zext instead.
2144 if (isKnownNonNegative(Op))
2145 return getZeroExtendExpr(Op, Ty, Depth + 1);
2146
2147 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2148 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2149 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2150 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2151 SmallVector<const SCEV *, 4> Operands;
2152 for (auto *Operand : MinMax->operands())
2153 Operands.push_back(getSignExtendExpr(Operand, Ty));
2154 if (isa<SCEVSMinExpr>(MinMax))
2155 return getSMinExpr(Operands);
2156 return getSMaxExpr(Operands);
2157 }
2158
2159 // The cast wasn't folded; create an explicit cast node.
2160 // Recompute the insert position, as it may have been invalidated.
2161 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2162 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2163 Op, Ty);
2164 UniqueSCEVs.InsertNode(S, IP);
2165 registerUser(S, { Op });
2166 return S;
2167}
2168
2169const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2170 Type *Ty) {
2171 switch (Kind) {
2172 case scTruncate:
2173 return getTruncateExpr(Op, Ty);
2174 case scZeroExtend:
2175 return getZeroExtendExpr(Op, Ty);
2176 case scSignExtend:
2177 return getSignExtendExpr(Op, Ty);
2178 case scPtrToInt:
2179 return getPtrToIntExpr(Op, Ty);
2180 default:
2181 llvm_unreachable("Not a SCEV cast expression!");
2182 }
2183}
2184
2185/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2186/// unspecified bits out to the given type.
2187const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2188 Type *Ty) {
2189 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2190 "This is not an extending conversion!");
2191 assert(isSCEVable(Ty) &&
2192 "This is not a conversion to a SCEVable type!");
2193 Ty = getEffectiveSCEVType(Ty);
2194
2195 // Sign-extend negative constants.
2196 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2197 if (SC->getAPInt().isNegative())
2198 return getSignExtendExpr(Op, Ty);
2199
2200 // Peel off a truncate cast.
2201 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2202 const SCEV *NewOp = T->getOperand();
2203 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2204 return getAnyExtendExpr(NewOp, Ty);
2205 return getTruncateOrNoop(NewOp, Ty);
2206 }
2207
2208 // Next try a zext cast. If the cast is folded, use it.
2209 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2210 if (!isa<SCEVZeroExtendExpr>(ZExt))
2211 return ZExt;
2212
2213 // Next try a sext cast. If the cast is folded, use it.
2214 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2215 if (!isa<SCEVSignExtendExpr>(SExt))
2216 return SExt;
2217
2218 // Force the cast to be folded into the operands of an addrec.
2219 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2220 SmallVector<const SCEV *, 4> Ops;
2221 for (const SCEV *Op : AR->operands())
2222 Ops.push_back(getAnyExtendExpr(Op, Ty));
2223 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2224 }
2225
2226 // If the expression is obviously signed, use the sext cast value.
2227 if (isa<SCEVSMaxExpr>(Op))
2228 return SExt;
2229
2230 // Absent any other information, use the zext cast value.
2231 return ZExt;
2232}
2233
2234/// Process the given Ops list, which is a list of operands to be added under
2235/// the given scale, and update the given map. This is a helper function for
2236/// getAddExpr. As an example of what it does, given a sequence of operands
2237/// that would form an add expression like this:
2238///
2239/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2240///
2241/// where A and B are constants, update the map with these values:
2242///
2243/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2244///
2245/// and add 13 + A*B*29 to AccumulatedConstant.
2246/// This will allow getAddExpr to produce this:
2247///
2248/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2249///
2250/// This form often exposes folding opportunities that are hidden in
2251/// the original operand list.
2252///
2253/// Return true iff it appears that any interesting folding opportunities
2254/// may be exposed. This helps getAddExpr short-circuit extra work in
2255/// the common case where no interesting opportunities are present, and
2256/// is also used as a check to avoid infinite recursion.
2257static bool
2258CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2259 SmallVectorImpl<const SCEV *> &NewOps,
2260 APInt &AccumulatedConstant,
2261 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2262 ScalarEvolution &SE) {
2263 bool Interesting = false;
2264
2265 // Iterate over the add operands. They are sorted, with constants first.
2266 unsigned i = 0;
2267 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2268 ++i;
2269 // Pull a buried constant out to the outside.
2270 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2271 Interesting = true;
2272 AccumulatedConstant += Scale * C->getAPInt();
2273 }
2274
2275 // Next comes everything else. We're especially interested in multiplies
2276 // here, but they're in the middle, so just visit the rest with one loop.
2277 for (; i != Ops.size(); ++i) {
2278 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2279 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2280 APInt NewScale =
2281 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2282 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2283 // A multiplication of a constant with another add; recurse.
2284 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2285 Interesting |=
2286 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2287 Add->operands(), NewScale, SE);
2288 } else {
2289 // A multiplication of a constant with some other value. Update
2290 // the map.
2291 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2292 const SCEV *Key = SE.getMulExpr(MulOps);
2293 auto Pair = M.insert({Key, NewScale});
2294 if (Pair.second) {
2295 NewOps.push_back(Pair.first->first);
2296 } else {
2297 Pair.first->second += NewScale;
2298 // The map already had an entry for this value, which may indicate
2299 // a folding opportunity.
2300 Interesting = true;
2301 }
2302 }
2303 } else {
2304 // An ordinary operand. Update the map.
2305 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2306 M.insert({Ops[i], Scale});
2307 if (Pair.second) {
2308 NewOps.push_back(Pair.first->first);
2309 } else {
2310 Pair.first->second += Scale;
2311 // The map already had an entry for this value, which may indicate
2312 // a folding opportunity.
2313 Interesting = true;
2314 }
2315 }
2316 }
2317
2318 return Interesting;
2319}
2320
2321bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2322 const SCEV *LHS, const SCEV *RHS,
2323 const Instruction *CtxI) {
2324 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2325 SCEV::NoWrapFlags, unsigned);
2326 switch (BinOp) {
2327 default:
2328 llvm_unreachable("Unsupported binary op");
2329 case Instruction::Add:
2330 Operation = &ScalarEvolution::getAddExpr;
2331 break;
2332 case Instruction::Sub:
2333 Operation = &ScalarEvolution::getMinusSCEV;
2334 break;
2335 case Instruction::Mul:
2336 Operation = &ScalarEvolution::getMulExpr;
2337 break;
2338 }
2339
2340 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2341 Signed ? &ScalarEvolution::getSignExtendExpr
2342 : &ScalarEvolution::getZeroExtendExpr;
2343
2344 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
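// If the narrow result, extended to twice the width, matches the same
// operation performed on the individually extended operands, then no bits
// were lost, i.e. the narrow operation cannot have wrapped.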
2345 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2346 auto *WideTy =
2347 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2348
2349 const SCEV *A = (this->*Extension)(
2350 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2351 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2352 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2353 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2354 if (A == B)
2355 return true;
2356 // Can we use context to prove the fact we need?
2357 if (!CtxI)
2358 return false;
2359 // TODO: Support mul.
2360 if (BinOp == Instruction::Mul)
2361 return false;
2362 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2363 // TODO: Lift this limitation.
2364 if (!RHSC)
2365 return false;
2366 APInt C = RHSC->getAPInt();
2367 unsigned NumBits = C.getBitWidth();
2368 bool IsSub = (BinOp == Instruction::Sub);
2369 bool IsNegativeConst = (Signed && C.isNegative());
2370 // Compute the direction and magnitude by which we need to check overflow.
2371 bool OverflowDown = IsSub ^ IsNegativeConst;
2372 APInt Magnitude = C;
2373 if (IsNegativeConst) {
2374 if (C == APInt::getSignedMinValue(NumBits))
2375 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2376 // want to deal with that.
2377 return false;
2378 Magnitude = -C;
2379 }
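// For example, for an unsigned i8 add with C = 10, OverflowDown is false and
// the check below reduces to proving LHS <= 255 - 10 = 245 at the context
// instruction.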
2380
2381 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2382 if (OverflowDown) {
2383 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2384 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2385 : APInt::getMinValue(NumBits);
2386 APInt Limit = Min + Magnitude;
2387 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2388 } else {
2389 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2390 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2391 : APInt::getMaxValue(NumBits);
2392 APInt Limit = Max - Magnitude;
2393 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2394 }
2395}
2396
2397std::optional<SCEV::NoWrapFlags>
2398ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2399 const OverflowingBinaryOperator *OBO) {
2400 // It cannot be done any better.
2401 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2402 return std::nullopt;
2403
2404 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2405
2406 if (OBO->hasNoUnsignedWrap())
2407 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2408 if (OBO->hasNoSignedWrap())
2409 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2410
2411 bool Deduced = false;
2412
2413 if (OBO->getOpcode() != Instruction::Add &&
2414 OBO->getOpcode() != Instruction::Sub &&
2415 OBO->getOpcode() != Instruction::Mul)
2416 return std::nullopt;
2417
2418 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2419 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2420
2421 const Instruction *CtxI =
2422 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2423 if (!OBO->hasNoUnsignedWrap() &&
2424 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2425 /* Signed */ false, LHS, RHS, CtxI)) {
2426 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2427 Deduced = true;
2428 }
2429
2430 if (!OBO->hasNoSignedWrap() &&
2431 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2432 /* Signed */ true, LHS, RHS, CtxI)) {
2433 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2434 Deduced = true;
2435 }
2436
2437 if (Deduced)
2438 return Flags;
2439 return std::nullopt;
2440}
2441
2442// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2443// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2444// can't-overflow flags for the operation if possible.
2445static SCEV::NoWrapFlags
2446StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2447 const ArrayRef<const SCEV *> Ops,
2448 SCEV::NoWrapFlags Flags) {
2449 using namespace std::placeholders;
2450
2451 using OBO = OverflowingBinaryOperator;
2452
2453 bool CanAnalyze =
2454 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2455 (void)CanAnalyze;
2456 assert(CanAnalyze && "don't call from other places!");
2457
2458 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2459 SCEV::NoWrapFlags SignOrUnsignWrap =
2460 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2461
2462 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2463 auto IsKnownNonNegative = [&](const SCEV *S) {
2464 return SE->isKnownNonNegative(S);
2465 };
2466
2467 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2468 Flags =
2469 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2470
2471 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2472
2473 if (SignOrUnsignWrap != SignOrUnsignMask &&
2474 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2475 isa<SCEVConstant>(Ops[0])) {
2476
2477 auto Opcode = [&] {
2478 switch (Type) {
2479 case scAddExpr:
2480 return Instruction::Add;
2481 case scMulExpr:
2482 return Instruction::Mul;
2483 default:
2484 llvm_unreachable("Unexpected SCEV op.");
2485 }
2486 }();
2487
2488 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2489
2490 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2491 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2492 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2493 Opcode, C, OBO::NoSignedWrap);
2494 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2495 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2496 }
2497
2498 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2499 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2500 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2501 Opcode, C, OBO::NoUnsignedWrap);
2502 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2503 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2504 }
2505 }
2506
2507 // <0,+,nonnegative><nw> is also nuw
2508 // TODO: Add corresponding nsw case
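// (Starting from zero with a nonnegative step, the only way to wrap in the
// unsigned sense is to traverse the entire range, which <nw> already rules
// out.)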
2509 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2510 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2511 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2512 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2513
2514 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
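// ((X /u Y) * Y rounds X down to a multiple of Y, so the product never
// exceeds X and cannot overflow.)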
2515 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2516 Ops.size() == 2) {
2517 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2518 if (UDiv->getOperand(1) == Ops[1])
2519 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2520 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2521 if (UDiv->getOperand(1) == Ops[0])
2522 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2523 }
2524
2525 return Flags;
2526}
2527
2528bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2529 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2530}
2531
2532/// Get a canonical add expression, or something simpler if possible.
2533const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2534 SCEV::NoWrapFlags OrigFlags,
2535 unsigned Depth) {
2536 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2537 "only nuw or nsw allowed");
2538 assert(!Ops.empty() && "Cannot get empty add!");
2539 if (Ops.size() == 1) return Ops[0];
2540#ifndef NDEBUG
2541 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2542 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2543 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2544 "SCEVAddExpr operand types don't match!");
2545 unsigned NumPtrs = count_if(
2546 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2547 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2548#endif
2549
2550 const SCEV *Folded = constantFoldAndGroupOps(
2551 *this, LI, DT, Ops,
2552 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2553 [](const APInt &C) { return C.isZero(); }, // identity
2554 [](const APInt &C) { return false; }); // absorber
2555 if (Folded)
2556 return Folded;
2557
2558 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2559
2560 // Delay expensive flag strengthening until necessary.
2561 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2562 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2563 };
2564
2565 // Limit recursion calls depth.
2566 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2567 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2568
2569 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2570 // Don't strengthen flags if we have no new information.
2571 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2572 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2573 Add->setNoWrapFlags(ComputeFlags(Ops));
2574 return S;
2575 }
2576
2577 // Okay, check to see if the same value occurs in the operand list more than
2578 // once. If so, merge them together into a multiply expression. Since we
2579 // sorted the list, these values are required to be adjacent.
2580 Type *Ty = Ops[0]->getType();
2581 bool FoundMatch = false;
2582 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2583 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2584 // Scan ahead to count how many equal operands there are.
2585 unsigned Count = 2;
2586 while (i+Count != e && Ops[i+Count] == Ops[i])
2587 ++Count;
2588 // Merge the values into a multiply.
2589 const SCEV *Scale = getConstant(Ty, Count);
2590 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2591 if (Ops.size() == Count)
2592 return Mul;
2593 Ops[i] = Mul;
2594 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2595 --i; e -= Count - 1;
2596 FoundMatch = true;
2597 }
2598 if (FoundMatch)
2599 return getAddExpr(Ops, OrigFlags, Depth + 1);
2600
2601 // Check for truncates. If all the operands are truncated from the same
2602 // type, see if factoring out the truncate would permit the result to be
2603 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2604 // if the contents of the resulting outer trunc fold to something simple.
2605 auto FindTruncSrcType = [&]() -> Type * {
2606 // We're ultimately looking to fold an addrec of truncs and muls of only
2607 // constants and truncs, so if we find any other types of SCEV
2608 // as operands of the addrec then we bail and return nullptr here.
2609 // Otherwise, we return the type of the operand of a trunc that we find.
2610 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2611 return T->getOperand()->getType();
2612 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2613 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2614 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2615 return T->getOperand()->getType();
2616 }
2617 return nullptr;
2618 };
2619 if (auto *SrcType = FindTruncSrcType()) {
2620 SmallVector<const SCEV *, 8> LargeOps;
2621 bool Ok = true;
2622 // Check all the operands to see if they can be represented in the
2623 // source type of the truncate.
2624 for (const SCEV *Op : Ops) {
2625 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2626 if (T->getOperand()->getType() != SrcType) {
2627 Ok = false;
2628 break;
2629 }
2630 LargeOps.push_back(T->getOperand());
2631 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2632 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2633 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2634 SmallVector<const SCEV *, 8> LargeMulOps;
2635 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2636 if (const SCEVTruncateExpr *T =
2637 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2638 if (T->getOperand()->getType() != SrcType) {
2639 Ok = false;
2640 break;
2641 }
2642 LargeMulOps.push_back(T->getOperand());
2643 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2644 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2645 } else {
2646 Ok = false;
2647 break;
2648 }
2649 }
2650 if (Ok)
2651 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2652 } else {
2653 Ok = false;
2654 break;
2655 }
2656 }
2657 if (Ok) {
2658 // Evaluate the expression in the larger type.
2659 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2660 // If it folds to something simple, use it. Otherwise, don't.
2661 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2662 return getTruncateExpr(Fold, Ty);
2663 }
2664 }
2665
2666 if (Ops.size() == 2) {
2667 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2668 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2669 // C1).
2670 const SCEV *A = Ops[0];
2671 const SCEV *B = Ops[1];
2672 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2673 auto *C = dyn_cast<SCEVConstant>(A);
2674 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2675 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2676 auto C2 = C->getAPInt();
2677 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2678
2679 APInt ConstAdd = C1 + C2;
2680 auto AddFlags = AddExpr->getNoWrapFlags();
2681 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2682 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2683 ConstAdd.ule(C1)) {
2684 PreservedFlags =
2685 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2686 }
2687
2688 // Adding a constant with the same sign and small magnitude is NSW, if the
2689 // original AddExpr was NSW.
2690 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2691 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2692 ConstAdd.abs().ule(C1.abs())) {
2693 PreservedFlags =
2694 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2695 }
2696
2697 if (PreservedFlags != SCEV::FlagAnyWrap) {
2698 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2699 NewOps[0] = getConstant(ConstAdd);
2700 return getAddExpr(NewOps, PreservedFlags);
2701 }
2702 }
2703 }
2704
2705 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
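// (X - (X %u Y) is exactly (X /u Y) * Y; e.g. X = 7, Y = 3: 7 - 1 == 3 * 2.)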
2706 if (Ops.size() == 2) {
2707 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2708 if (Mul && Mul->getNumOperands() == 2 &&
2709 Mul->getOperand(0)->isAllOnesValue()) {
2710 const SCEV *X;
2711 const SCEV *Y;
2712 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2713 return getMulExpr(Y, getUDivExpr(X, Y));
2714 }
2715 }
2716 }
2717
2718 // Skip past any other cast SCEVs.
2719 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2720 ++Idx;
2721
2722 // If there are add operands they would be next.
2723 if (Idx < Ops.size()) {
2724 bool DeletedAdd = false;
2725 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2726 // common NUW flag for the expression after inlining. Other flags cannot be
2727 // preserved, because they may depend on the original order of operations.
2728 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2729 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2730 if (Ops.size() > AddOpsInlineThreshold ||
2731 Add->getNumOperands() > AddOpsInlineThreshold)
2732 break;
2733 // If we have an add, expand the add operands onto the end of the operands
2734 // list.
2735 Ops.erase(Ops.begin()+Idx);
2736 append_range(Ops, Add->operands());
2737 DeletedAdd = true;
2738 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2739 }
2740
2741 // If we deleted at least one add, we added operands to the end of the list,
2742 // and they are not necessarily sorted. Recurse to resort and resimplify
2743 // any operands we just acquired.
2744 if (DeletedAdd)
2745 return getAddExpr(Ops, CommonFlags, Depth + 1);
2746 }
2747
2748 // Skip over the add expression until we get to a multiply.
2749 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2750 ++Idx;
2751
2752 // Check to see if there are any folding opportunities present with
2753 // operands multiplied by constant values.
2754 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2755 uint64_t BitWidth = getTypeSizeInBits(Ty);
2756 DenseMap<const SCEV *, APInt> M;
2757 SmallVector<const SCEV *, 8> NewOps;
2758 APInt AccumulatedConstant(BitWidth, 0);
2759 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2760 Ops, APInt(BitWidth, 1), *this)) {
2761 struct APIntCompare {
2762 bool operator()(const APInt &LHS, const APInt &RHS) const {
2763 return LHS.ult(RHS);
2764 }
2765 };
2766
2767 // Some interesting folding opportunity is present, so it's worthwhile to
2768 // re-generate the operands list. Group the operands by constant scale,
2769 // to avoid multiplying by the same constant scale multiple times.
2770 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2771 for (const SCEV *NewOp : NewOps)
2772 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2773 // Re-generate the operands list.
2774 Ops.clear();
2775 if (AccumulatedConstant != 0)
2776 Ops.push_back(getConstant(AccumulatedConstant));
2777 for (auto &MulOp : MulOpLists) {
2778 if (MulOp.first == 1) {
2779 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2780 } else if (MulOp.first != 0) {
2781 Ops.push_back(getMulExpr(
2782 getConstant(MulOp.first),
2783 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2784 SCEV::FlagAnyWrap, Depth + 1));
2785 }
2786 }
2787 if (Ops.empty())
2788 return getZero(Ty);
2789 if (Ops.size() == 1)
2790 return Ops[0];
2791 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2792 }
2793 }
2794
2795 // If we are adding something to a multiply expression, make sure the
2796 // something is not already an operand of the multiply. If so, merge it into
2797 // the multiply.
2798 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2799 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2800 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2801 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2802 if (isa<SCEVConstant>(MulOpSCEV))
2803 continue;
2804 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2805 if (MulOpSCEV == Ops[AddOp]) {
2806 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2807 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2808 if (Mul->getNumOperands() != 2) {
2809 // If the multiply has more than two operands, we must get the
2810 // Y*Z term.
2811 SmallVector<const SCEV *, 4> MulOps(
2812 Mul->operands().take_front(MulOp));
2813 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2814 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2815 }
2816 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2817 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2818 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2819 SCEV::FlagAnyWrap, Depth + 1);
2820 if (Ops.size() == 2) return OuterMul;
2821 if (AddOp < Idx) {
2822 Ops.erase(Ops.begin()+AddOp);
2823 Ops.erase(Ops.begin()+Idx-1);
2824 } else {
2825 Ops.erase(Ops.begin()+Idx);
2826 Ops.erase(Ops.begin()+AddOp-1);
2827 }
2828 Ops.push_back(OuterMul);
2829 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2830 }
2831
2832 // Check this multiply against other multiplies being added together.
2833 for (unsigned OtherMulIdx = Idx+1;
2834 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2835 ++OtherMulIdx) {
2836 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2837 // If MulOp occurs in OtherMul, we can fold the two multiplies
2838 // together.
2839 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2840 OMulOp != e; ++OMulOp)
2841 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2842 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2843 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2844 if (Mul->getNumOperands() != 2) {
2845 SmallVector<const SCEV *, 4> MulOps(
2846 Mul->operands().take_front(MulOp));
2847 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2848 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2849 }
2850 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2851 if (OtherMul->getNumOperands() != 2) {
2852 SmallVector<const SCEV *, 4> MulOps(
2853 OtherMul->operands().take_front(OMulOp));
2854 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2855 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2856 }
2857 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2858 const SCEV *InnerMulSum =
2859 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2860 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2861 SCEV::FlagAnyWrap, Depth + 1);
2862 if (Ops.size() == 2) return OuterMul;
2863 Ops.erase(Ops.begin()+Idx);
2864 Ops.erase(Ops.begin()+OtherMulIdx-1);
2865 Ops.push_back(OuterMul);
2866 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2867 }
2868 }
2869 }
2870 }
2871
2872 // If there are any add recurrences in the operands list, see if any other
2873 // added values are loop invariant. If so, we can fold them into the
2874 // recurrence.
2875 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2876 ++Idx;
2877
2878 // Scan over all recurrences, trying to fold loop invariants into them.
2879 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2880 // Scan all of the other operands to this add and add them to the vector if
2881 // they are loop invariant w.r.t. the recurrence.
2882 SmallVector<const SCEV *, 8> LIOps;
2883 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2884 const Loop *AddRecLoop = AddRec->getLoop();
2885 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2886 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2887 LIOps.push_back(Ops[i]);
2888 Ops.erase(Ops.begin()+i);
2889 --i; --e;
2890 }
2891
2892 // If we found some loop invariants, fold them into the recurrence.
2893 if (!LIOps.empty()) {
2894 // Compute nowrap flags for the addition of the loop-invariant ops and
2895 // the addrec. Temporarily push it as an operand for that purpose. These
2896 // flags are valid in the scope of the addrec only.
2897 LIOps.push_back(AddRec);
2898 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2899 LIOps.pop_back();
2900
2901 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2902 LIOps.push_back(AddRec->getStart());
2903
2904 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2905
2906 // It is not in general safe to propagate flags valid on an add within
2907 // the addrec scope to one outside it. We must prove that the inner
2908 // scope is guaranteed to execute if the outer one does to be able to
2909 // safely propagate. We know the program is undefined if poison is
2910 // produced on the inner scoped addrec. We also know that *for this use*
2911 // the outer scoped add can't overflow (because of the flags we just
2912 // computed for the inner scoped add) without the program being undefined.
2913 // Proving that entry to the outer scope necessitates entry to the inner
2914 // scope, thus proves the program undefined if the flags would be violated
2915 // in the outer scope.
2916 SCEV::NoWrapFlags AddFlags = Flags;
2917 if (AddFlags != SCEV::FlagAnyWrap) {
2918 auto *DefI = getDefiningScopeBound(LIOps);
2919 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2920 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2921 AddFlags = SCEV::FlagAnyWrap;
2922 }
2923 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2924
2925 // Build the new addrec. Propagate the NUW and NSW flags if both the
2926 // outer add and the inner addrec are guaranteed to have no overflow.
2927 // Always propagate NW.
2928 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2929 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2930
2931 // If all of the other operands were loop invariant, we are done.
2932 if (Ops.size() == 1) return NewRec;
2933
2934 // Otherwise, add the folded AddRec by the non-invariant parts.
2935 for (unsigned i = 0;; ++i)
2936 if (Ops[i] == AddRec) {
2937 Ops[i] = NewRec;
2938 break;
2939 }
2940 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2941 }
2942
2943 // Okay, if there weren't any loop invariants to be folded, check to see if
2944 // there are multiple AddRec's with the same loop induction variable being
2945 // added together. If so, we can fold them.
2946 for (unsigned OtherIdx = Idx+1;
2947 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2948 ++OtherIdx) {
2949 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2950 // so that the 1st found AddRecExpr is dominated by all others.
2951 assert(DT.dominates(
2952 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2953 AddRec->getLoop()->getHeader()) &&
2954 "AddRecExprs are not sorted in reverse dominance order?");
2955 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2956 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2957 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2958 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2959 ++OtherIdx) {
2960 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2961 if (OtherAddRec->getLoop() == AddRecLoop) {
2962 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2963 i != e; ++i) {
2964 if (i >= AddRecOps.size()) {
2965 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2966 break;
2967 }
2968 SmallVector<const SCEV *, 2> TwoOps = {
2969 AddRecOps[i], OtherAddRec->getOperand(i)};
2970 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2973 }
2974 }
2975 // Step size has changed, so we cannot guarantee no self-wraparound.
2976 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2977 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2978 }
2979 }
2980
2981 // Otherwise couldn't fold anything into this recurrence. Move onto the
2982 // next one.
2983 }
2984
2985 // Okay, it looks like we really DO need an add expr. Check to see if we
2986 // already have one, otherwise create a new one.
2987 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2988}
2989
2990const SCEV *
2991ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2992 SCEV::NoWrapFlags Flags) {
2993 FoldingSetNodeID ID;
2994 ID.AddInteger(scAddExpr);
2995 for (const SCEV *Op : Ops)
2996 ID.AddPointer(Op);
2997 void *IP = nullptr;
2998 SCEVAddExpr *S =
2999 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3000 if (!S) {
3001 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3002 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3003 S = new (SCEVAllocator)
3004 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3005 UniqueSCEVs.InsertNode(S, IP);
3006 registerUser(S, Ops);
3007 }
3008 S->setNoWrapFlags(Flags);
3009 return S;
3010}
3011
3012const SCEV *
3013ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3014 const Loop *L, SCEV::NoWrapFlags Flags) {
3015 FoldingSetNodeID ID;
3016 ID.AddInteger(scAddRecExpr);
3017 for (const SCEV *Op : Ops)
3018 ID.AddPointer(Op);
3019 ID.AddPointer(L);
3020 void *IP = nullptr;
3021 SCEVAddRecExpr *S =
3022 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3023 if (!S) {
3024 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3025 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3026 S = new (SCEVAllocator)
3027 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3028 UniqueSCEVs.InsertNode(S, IP);
3029 LoopUsers[L].push_back(S);
3030 registerUser(S, Ops);
3031 }
3032 setNoWrapFlags(S, Flags);
3033 return S;
3034}
3035
3036const SCEV *
3037ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3038 SCEV::NoWrapFlags Flags) {
3039 FoldingSetNodeID ID;
3040 ID.AddInteger(scMulExpr);
3041 for (const SCEV *Op : Ops)
3042 ID.AddPointer(Op);
3043 void *IP = nullptr;
3044 SCEVMulExpr *S =
3045 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3046 if (!S) {
3047 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3048 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3049 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3050 O, Ops.size());
3051 UniqueSCEVs.InsertNode(S, IP);
3052 registerUser(S, Ops);
3053 }
3054 S->setNoWrapFlags(Flags);
3055 return S;
3056}
3057
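// Multiply two uint64_t values, reporting any wraparound through Overflow.
// (The j > 1 guard sidesteps division by zero and the trivial j == 1 case,
// neither of which can overflow.)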
3058static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3059 uint64_t k = i*j;
3060 if (j > 1 && k / j != i) Overflow = true;
3061 return k;
3062}
3063
3064/// Compute the result of "n choose k", the binomial coefficient. If an
3065/// intermediate computation overflows, Overflow will be set and the return will
3066/// be garbage. Overflow is not cleared on absence of overflow.
3067static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3068 // We use the multiplicative formula:
3069 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3070 // At the i-th iteration, we multiply by the i-th term of the numerator and
3071 // then divide by i. This division will always produce an
3072 // integral result, and helps reduce the chance of overflow in the
3073 // intermediate computations. However, we can still overflow even when the
3074 // final result would fit.
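// For example, Choose(6, 2) computes r = 1*6 = 6, r /= 1; then r = 6*5 = 30,
// r /= 2, giving 15.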
3075
3076 if (n == 0 || n == k) return 1;
3077 if (k > n) return 0;
3078
3079 if (k > n/2)
3080 k = n-k;
3081
3082 uint64_t r = 1;
3083 for (uint64_t i = 1; i <= k; ++i) {
3084 r = umul_ov(r, n-(i-1), Overflow);
3085 r /= i;
3086 }
3087 return r;
3088}
3089
3090/// Determine if any of the operands in this SCEV are a constant or if
3091/// any of the add or multiply expressions in this SCEV contain a constant.
3092static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3093 struct FindConstantInAddMulChain {
3094 bool FoundConstant = false;
3095
3096 bool follow(const SCEV *S) {
3097 FoundConstant |= isa<SCEVConstant>(S);
3098 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3099 }
3100
3101 bool isDone() const {
3102 return FoundConstant;
3103 }
3104 };
3105
3106 FindConstantInAddMulChain F;
3107 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3108 ST.visitAll(StartExpr);
3109 return F.FoundConstant;
3110}
3111
3112/// Get a canonical multiply expression, or something simpler if possible.
3113const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3114 SCEV::NoWrapFlags OrigFlags,
3115 unsigned Depth) {
3116 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3117 "only nuw or nsw allowed");
3118 assert(!Ops.empty() && "Cannot get empty mul!");
3119 if (Ops.size() == 1) return Ops[0];
3120#ifndef NDEBUG
3121 Type *ETy = Ops[0]->getType();
3122 assert(!ETy->isPointerTy());
3123 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3124 assert(Ops[i]->getType() == ETy &&
3125 "SCEVMulExpr operand types don't match!");
3126#endif
3127
3128 const SCEV *Folded = constantFoldAndGroupOps(
3129 *this, LI, DT, Ops,
3130 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3131 [](const APInt &C) { return C.isOne(); }, // identity
3132 [](const APInt &C) { return C.isZero(); }); // absorber
3133 if (Folded)
3134 return Folded;
3135
3136 // Delay expensive flag strengthening until necessary.
3137 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3138 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3139 };
3140
3141 // Limit recursion calls depth.
3142 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3143 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3144
3145 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3146 // Don't strengthen flags if we have no new information.
3147 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3148 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3149 Mul->setNoWrapFlags(ComputeFlags(Ops));
3150 return S;
3151 }
3152
3153 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3154 if (Ops.size() == 2) {
3155 // C1*(C2+V) -> C1*C2 + C1*V
3156 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3157 // If any of Add's ops are Adds or Muls with a constant, apply this
3158 // transformation as well.
3159 //
3160 // TODO: There are some cases where this transformation is not
3161 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3162 // this transformation should be narrowed down.
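 // For example, 3 * (2 + %x) becomes 6 + 3 * %x when the add/mul chain
 // contains a constant to fold with.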
3163 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3164 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3165 SCEV::FlagAnyWrap, Depth + 1);
3166 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3167 SCEV::FlagAnyWrap, Depth + 1);
3168 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3169 }
3170
3171 if (Ops[0]->isAllOnesValue()) {
3172 // If we have a mul by -1 of an add, try distributing the -1 among the
3173 // add operands.
3174 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3175 SmallVector<const SCEV *, 4> NewOps;
3176 bool AnyFolded = false;
3177 for (const SCEV *AddOp : Add->operands()) {
3178 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3179 Depth + 1);
3180 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3181 NewOps.push_back(Mul);
3182 }
3183 if (AnyFolded)
3184 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3185 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3186 // Negation preserves a recurrence's no self-wrap property.
3187 SmallVector<const SCEV *, 4> Operands;
3188 for (const SCEV *AddRecOp : AddRec->operands())
3189 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3190 Depth + 1));
3191 // Let M be the minimum representable signed value. AddRec with nsw
3192 // multiplied by -1 can have signed overflow if and only if it takes a
3193 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3194 // maximum signed value. In all other cases signed overflow is
3195 // impossible.
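 // For example, in i8 the AddRec {-128,+,1}<nsw> multiplied by -1 wraps at
 // -128, so NSW can only be kept when the recurrence provably never takes
 // the value -128.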
3196 auto FlagsMask = SCEV::FlagNW;
3197 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3198 auto MinInt =
3199 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3200 if (getSignedRangeMin(AddRec) != MinInt)
3201 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3202 }
3203 return getAddRecExpr(Operands, AddRec->getLoop(),
3204 AddRec->getNoWrapFlags(FlagsMask));
3205 }
3206 }
3207 }
3208 }
3209
3210 // Skip over the add expression until we get to a multiply.
3211 unsigned Idx = 0;
3212 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3213 ++Idx;
3214
3215 // If there are mul operands inline them all into this expression.
3216 if (Idx < Ops.size()) {
3217 bool DeletedMul = false;
3218 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3219 if (Ops.size() > MulOpsInlineThreshold)
3220 break;
3221 // If we have a mul, expand the mul operands onto the end of the
3222 // operands list.
3223 Ops.erase(Ops.begin()+Idx);
3224 append_range(Ops, Mul->operands());
3225 DeletedMul = true;
3226 }
3227
3228 // If we deleted at least one mul, we added operands to the end of the
3229 // list, and they are not necessarily sorted. Recurse to resort and
3230 // resimplify any operands we just acquired.
3231 if (DeletedMul)
3232 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3233 }
3234
3235 // If there are any add recurrences in the operands list, see if any other
3236 // added values are loop invariant. If so, we can fold them into the
3237 // recurrence.
3238 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3239 ++Idx;
3240
3241 // Scan over all recurrences, trying to fold loop invariants into them.
3242 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3243 // Scan all of the other operands to this mul and add them to the vector
3244 // if they are loop invariant w.r.t. the recurrence.
3245 SmallVector<const SCEV *, 8> LIOps;
3246 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3247 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3248 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3249 LIOps.push_back(Ops[i]);
3250 Ops.erase(Ops.begin()+i);
3251 --i; --e;
3252 }
3253
3254 // If we found some loop invariants, fold them into the recurrence.
3255 if (!LIOps.empty()) {
3256 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3257 SmallVector<const SCEV *, 4> NewOps;
3258 NewOps.reserve(AddRec->getNumOperands());
3259 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3260
3261 // If both the mul and addrec are nuw, we can preserve nuw.
3262 // If both the mul and addrec are nsw, we can only preserve nsw if either
3263 // a) they are also nuw, or
3264 // b) all multiplications of addrec operands with scale are nsw.
3265 SCEV::NoWrapFlags Flags =
3266 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3267
3268 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3269 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3270 SCEV::FlagAnyWrap, Depth + 1));
3271
3272 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3273 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3274 Instruction::Mul, getSignedRange(Scale),
3275 OverflowingBinaryOperator::NoSignedWrap);
3276 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3277 Flags = clearFlags(Flags, SCEV::FlagNSW);
3278 }
3279 }
3280
3281 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3282
3283 // If all of the other operands were loop invariant, we are done.
3284 if (Ops.size() == 1) return NewRec;
3285
3286 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3287 for (unsigned i = 0;; ++i)
3288 if (Ops[i] == AddRec) {
3289 Ops[i] = NewRec;
3290 break;
3291 }
3292 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3293 }
3294
3295 // Okay, if there weren't any loop invariants to be folded, check to see
3296 // if there are multiple AddRec's with the same loop induction variable
3297 // being multiplied together. If so, we can fold them.
3298
3299 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3300 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3301 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3302 // ]]],+,...up to x=2n}.
3303 // Note that the arguments to choose() are always integers with values
3304 // known at compile time, never SCEV objects.
3305 //
3306 // The implementation avoids pointless extra computations when the two
3307 // addrec's are of different length (mathematically, it's equivalent to
3308 // an infinite stream of zeros on the right).
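 // For example, for two affine recurrences over the same loop:
 //   {A1,+,A2}<L> * {B1,+,B2}<L>
 // = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>.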
3309 bool OpsModified = false;
3310 for (unsigned OtherIdx = Idx+1;
3311 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3312 ++OtherIdx) {
3313 const SCEVAddRecExpr *OtherAddRec =
3314 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3315 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3316 continue;
3317
3318 // Limit max number of arguments to avoid creation of unreasonably big
3319 // SCEVAddRecs with very complex operands.
3320 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3321 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3322 continue;
3323
3324 bool Overflow = false;
3325 Type *Ty = AddRec->getType();
3326 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3327 SmallVector<const SCEV *, 7> AddRecOps;
3328 for (int x = 0, xe = AddRec->getNumOperands() +
3329 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3330 SmallVector <const SCEV *, 7> SumOps;
3331 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3332 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3333 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3334 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3335 z < ze && !Overflow; ++z) {
3336 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3337 uint64_t Coeff;
3338 if (LargerThan64Bits)
3339 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3340 else
3341 Coeff = Coeff1*Coeff2;
3342 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3343 const SCEV *Term1 = AddRec->getOperand(y-z);
3344 const SCEV *Term2 = OtherAddRec->getOperand(z);
3345 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3346 SCEV::FlagAnyWrap, Depth + 1));
3347 }
3348 }
3349 if (SumOps.empty())
3350 SumOps.push_back(getZero(Ty));
3351 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3352 }
3353 if (!Overflow) {
3354 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3355 SCEV::FlagAnyWrap);
3356 if (Ops.size() == 2) return NewAddRec;
3357 Ops[Idx] = NewAddRec;
3358 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3359 OpsModified = true;
3360 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3361 if (!AddRec)
3362 break;
3363 }
3364 }
3365 if (OpsModified)
3366 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3367
3368 // Otherwise couldn't fold anything into this recurrence. Move onto the
3369 // next one.
3370 }
3371
3372 // Okay, it looks like we really DO need a mul expr. Check to see if we
3373 // already have one, otherwise create a new one.
3374 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3375}
3376
3377/// Represents an unsigned remainder expression based on unsigned division.
3378const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3379 const SCEV *RHS) {
3380 assert(getEffectiveSCEVType(LHS->getType()) ==
3381 getEffectiveSCEVType(RHS->getType()) &&
3382 "SCEVURemExpr operand types don't match!");
3383
3384 // Short-circuit easy cases
3385 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3386 // If constant is one, the result is trivial
3387 if (RHSC->getValue()->isOne())
3388 return getZero(LHS->getType()); // X urem 1 --> 0
3389
3390 // If constant is a power of two, fold into a zext(trunc(LHS)).
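 // For example, for i32 %x, (%x urem 8) becomes (zext i3 (trunc i32 %x to i3)
 // to i32), keeping only the low log2(8) = 3 bits.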
3391 if (RHSC->getAPInt().isPowerOf2()) {
3392 Type *FullTy = LHS->getType();
3393 Type *TruncTy =
3394 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3395 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3396 }
3397 }
3398
3399 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3400 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3401 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3402 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3403}
3404
3405/// Get a canonical unsigned division expression, or something simpler if
3406/// possible.
3407const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3408 const SCEV *RHS) {
3409 assert(!LHS->getType()->isPointerTy() &&
3410 "SCEVUDivExpr operand can't be pointer!");
3411 assert(LHS->getType() == RHS->getType() &&
3412 "SCEVUDivExpr operand types don't match!");
3413
3414 FoldingSetNodeID ID;
3415 ID.AddInteger(scUDivExpr);
3416 ID.AddPointer(LHS);
3417 ID.AddPointer(RHS);
3418 void *IP = nullptr;
3419 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3420 return S;
3421
3422 // 0 udiv Y == 0
3423 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3424 if (LHSC->getValue()->isZero())
3425 return LHS;
3426
3427 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3428 if (RHSC->getValue()->isOne())
3429 return LHS; // X udiv 1 --> x
3430 // If the denominator is zero, the result of the udiv is undefined. Don't
3431 // try to analyze it, because the resolution chosen here may differ from
3432 // the resolution chosen in other parts of the compiler.
3433 if (!RHSC->getValue()->isZero()) {
3434 // Determine if the division can be folded into the operands of
3435 // its operands.
3436 // TODO: Generalize this to non-constants by using known-bits information.
3437 Type *Ty = LHS->getType();
3438 unsigned LZ = RHSC->getAPInt().countl_zero();
3439 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3440 // For non-power-of-two values, effectively round the value up to the
3441 // nearest power of two.
3442 if (!RHSC->getAPInt().isPowerOf2())
3443 ++MaxShiftAmt;
3444 IntegerType *ExtTy =
3445 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3446 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3447 if (const SCEVConstant *Step =
3448 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3449 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3450 const APInt &StepInt = Step->getAPInt();
3451 const APInt &DivInt = RHSC->getAPInt();
3452 if (!StepInt.urem(DivInt) &&
3453 getZeroExtendExpr(AR, ExtTy) ==
3454 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3455 getZeroExtendExpr(Step, ExtTy),
3456 AR->getLoop(), SCEV::FlagAnyWrap)) {
3457 SmallVector<const SCEV *, 4> Operands;
3458 for (const SCEV *Op : AR->operands())
3459 Operands.push_back(getUDivExpr(Op, RHS));
3460 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3461 }
3462 // Get a canonical UDivExpr for a recurrence.
3463 // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3464 // We can currently only fold X%N if X is constant.
3465 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3466 if (StartC && !DivInt.urem(StepInt) &&
3467 getZeroExtendExpr(AR, ExtTy) ==
3468 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3469 getZeroExtendExpr(Step, ExtTy),
3470 AR->getLoop(), SCEV::FlagAnyWrap)) {
3471 const APInt &StartInt = StartC->getAPInt();
3472 const APInt &StartRem = StartInt.urem(StepInt);
3473 if (StartRem != 0) {
3474 const SCEV *NewLHS =
3475 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3476 AR->getLoop(), SCEV::FlagNW);
3477 if (LHS != NewLHS) {
3478 LHS = NewLHS;
3479
3480 // Reset the ID to include the new LHS, and check if it is
3481 // already cached.
3482 ID.clear();
3483 ID.AddInteger(scUDivExpr);
3484 ID.AddPointer(LHS);
3485 ID.AddPointer(RHS);
3486 IP = nullptr;
3487 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3488 return S;
3489 }
3490 }
3491 }
3492 }
3493 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3494 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3495 SmallVector<const SCEV *, 4> Operands;
3496 for (const SCEV *Op : M->operands())
3497 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3498 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3499 // Find an operand that's safely divisible.
3500 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3501 const SCEV *Op = M->getOperand(i);
3502 const SCEV *Div = getUDivExpr(Op, RHSC);
3503 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3504 Operands = SmallVector<const SCEV *, 4>(M->operands());
3505 Operands[i] = Div;
3506 return getMulExpr(Operands);
3507 }
3508 }
3509 }
3510
3511 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3512 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3513 if (auto *DivisorConstant =
3514 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3515 bool Overflow = false;
3516 APInt NewRHS =
3517 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3518 if (Overflow) {
3519 return getConstant(RHSC->getType(), 0, false);
3520 }
3521 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3522 }
3523 }
3524
3525 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3526 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3527 SmallVector<const SCEV *, 4> Operands;
3528 for (const SCEV *Op : A->operands())
3529 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3530 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3531 Operands.clear();
3532 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3533 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3534 if (isa<SCEVUDivExpr>(Op) ||
3535 getMulExpr(Op, RHS) != A->getOperand(i))
3536 break;
3537 Operands.push_back(Op);
3538 }
3539 if (Operands.size() == A->getNumOperands())
3540 return getAddExpr(Operands);
3541 }
3542 }
3543
3544 // Fold if both operands are constant.
3545 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3546 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3547 }
3548 }
3549
3550 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3551 // changes). Make sure we get a new one.
3552 IP = nullptr;
3553 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3554 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3555 LHS, RHS);
3556 UniqueSCEVs.InsertNode(S, IP);
3557 registerUser(S, {LHS, RHS});
3558 return S;
3559}
3560
3561APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3562 APInt A = C1->getAPInt().abs();
3563 APInt B = C2->getAPInt().abs();
3564 uint32_t ABW = A.getBitWidth();
3565 uint32_t BBW = B.getBitWidth();
3566
3567 if (ABW > BBW)
3568 B = B.zext(ABW);
3569 else if (ABW < BBW)
3570 A = A.zext(BBW);
3571
3572 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3573}
3574
3575/// Get a canonical unsigned division expression, or something simpler if
3576/// possible. There is no representation for an exact udiv in SCEV IR, but we
3577/// can attempt to remove factors from the LHS and RHS. We can't do this when
3578/// it's not exact because the udiv may be clearing bits.
3579const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3580 const SCEV *RHS) {
3581 // TODO: we could try to find factors in all sorts of things, but for now we
3582 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3583 // end of this file for inspiration.
3584
3585 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3586 if (!Mul || !Mul->hasNoUnsignedWrap())
3587 return getUDivExpr(LHS, RHS);
3588
3589 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3590 // If the mulexpr multiplies by a constant, then that constant must be the
3591 // first element of the mulexpr.
3592 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3593 if (LHSCst == RHSCst) {
3594 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3595 return getMulExpr(Operands);
3596 }
3597
3598 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3599 // that there's a factor provided by one of the other terms. We need to
3600 // check.
3601 APInt Factor = gcd(LHSCst, RHSCst);
3602 if (!Factor.isIntN(1)) {
3603 LHSCst =
3604 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3605 RHSCst =
3606 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3607 SmallVector<const SCEV *, 2> Operands;
3608 Operands.push_back(LHSCst);
3609 append_range(Operands, Mul->operands().drop_front());
3610 LHS = getMulExpr(Operands);
3611 RHS = RHSCst;
3612 Mul = dyn_cast<SCEVMulExpr>(LHS);
3613 if (!Mul)
3614 return getUDivExactExpr(LHS, RHS);
3615 }
3616 }
3617 }
3618
3619 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3620 if (Mul->getOperand(i) == RHS) {
3621 SmallVector<const SCEV *, 2> Operands;
3622 append_range(Operands, Mul->operands().take_front(i));
3623 append_range(Operands, Mul->operands().drop_front(i + 1));
3624 return getMulExpr(Operands);
3625 }
3626 }
3627
3628 return getUDivExpr(LHS, RHS);
3629}
3630
3631/// Get an add recurrence expression for the specified loop. Simplify the
3632/// expression as much as possible.
3633const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3634 const Loop *L,
3635 SCEV::NoWrapFlags Flags) {
3636 SmallVector<const SCEV *, 4> Operands;
3637 Operands.push_back(Start);
3638 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3639 if (StepChrec->getLoop() == L) {
3640 append_range(Operands, StepChrec->operands());
3641 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3642 }
3643
3644 Operands.push_back(Step);
3645 return getAddRecExpr(Operands, L, Flags);
3646}
3647
3648/// Get an add recurrence expression for the specified loop. Simplify the
3649/// expression as much as possible.
3650const SCEV *
3651ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3652 const Loop *L, SCEV::NoWrapFlags Flags) {
3653 if (Operands.size() == 1) return Operands[0];
3654#ifndef NDEBUG
3655 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3656 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3657 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3658 "SCEVAddRecExpr operand types don't match!");
3659 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3660 }
3661 for (const SCEV *Op : Operands)
3663 "SCEVAddRecExpr operand is not available at loop entry!");
3664#endif
3665
3666 if (Operands.back()->isZero()) {
3667 Operands.pop_back();
3668 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3669 }
3670
3671 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3672 // use that information to infer NUW and NSW flags. However, computing a
3673 // BE count requires calling getAddRecExpr, so we may not yet have a
3674 // meaningful BE count at this point (and if we don't, we'd be stuck
3675 // with a SCEVCouldNotCompute as the cached BE count).
3676
3677 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3678
3679 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3680 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3681 const Loop *NestedLoop = NestedAR->getLoop();
3682 if (L->contains(NestedLoop)
3683 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3684 : (!NestedLoop->contains(L) &&
3685 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3686 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3687 Operands[0] = NestedAR->getStart();
3688 // AddRecs require their operands be loop-invariant with respect to their
3689 // loops. Don't perform this transformation if it would break this
3690 // requirement.
3691 bool AllInvariant = all_of(
3692 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3693
3694 if (AllInvariant) {
3695 // Create a recurrence for the outer loop with the same step size.
3696 //
3697 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3698 // inner recurrence has the same property.
3699 SCEV::NoWrapFlags OuterFlags =
3700 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3701
3702 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3703 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3704 return isLoopInvariant(Op, NestedLoop);
3705 });
3706
3707 if (AllInvariant) {
3708 // Ok, both add recurrences are valid after the transformation.
3709 //
3710 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3711 // the outer recurrence has the same property.
3712 SCEV::NoWrapFlags InnerFlags =
3713 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3714 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3715 }
3716 }
3717 // Reset Operands to its original state.
3718 Operands[0] = NestedAR;
3719 }
3720 }
3721
3722 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3723 // already have one, otherwise create a new one.
3724 return getOrCreateAddRecExpr(Operands, L, Flags);
3725}
3726
3727const SCEV *
3728ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3729 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3730 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3731 // getSCEV(Base)->getType() has the same address space as Base->getType()
3732 // because SCEV::getType() preserves the address space.
3733 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3734 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3735 if (NW != GEPNoWrapFlags::none()) {
3736 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3737 // but to do that, we have to ensure that said flag is valid in the entire
3738 // defined scope of the SCEV.
3739 // TODO: non-instructions have global scope. We might be able to prove
3740 // some global scope cases
3741 auto *GEPI = dyn_cast<Instruction>(GEP);
3742 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3743 NW = GEPNoWrapFlags::none();
3744 }
3745
3746 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3747 if (NW.hasNoUnsignedSignedWrap())
3748 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3749 if (NW.hasNoUnsignedWrap())
3750 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3751
3752 Type *CurTy = GEP->getType();
3753 bool FirstIter = true;
3754 SmallVector<const SCEV *, 4> Offsets;
3755 for (const SCEV *IndexExpr : IndexExprs) {
3756 // Compute the (potentially symbolic) offset in bytes for this index.
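 // For example (illustrative), for "getelementptr inbounds {i32, i32}, ptr %p,
 // i64 %i, i32 1" this loop collects the offsets (8 * %i) for the array index
 // and 4 for field 1, which are later summed and added to the base %p.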
3757 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3758 // For a struct, add the member offset.
3759 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3760 unsigned FieldNo = Index->getZExtValue();
3761 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3762 Offsets.push_back(FieldOffset);
3763
3764 // Update CurTy to the type of the field at Index.
3765 CurTy = STy->getTypeAtIndex(Index);
3766 } else {
3767 // Update CurTy to its element type.
3768 if (FirstIter) {
3769 assert(isa<PointerType>(CurTy) &&
3770 "The first index of a GEP indexes a pointer");
3771 CurTy = GEP->getSourceElementType();
3772 FirstIter = false;
3773 } else {
3774 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3775 }
3776 // For an array, add the element offset, explicitly scaled.
3777 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3778 // Getelementptr indices are signed.
3779 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3780
3781 // Multiply the index by the element size to compute the element offset.
3782 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3783 Offsets.push_back(LocalOffset);
3784 }
3785 }
3786
3787 // Handle degenerate case of GEP without offsets.
3788 if (Offsets.empty())
3789 return BaseExpr;
3790
3791 // Add the offsets together, assuming nsw if inbounds.
3792 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3793 // Add the base address and the offset. We cannot use the nsw flag, as the
3794 // base address is unsigned. However, if we know that the offset is
3795 // non-negative, we can use nuw.
3796 bool NUW = NW.hasNoUnsignedWrap() ||
3797 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3798 SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3799 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3800 assert(BaseExpr->getType() == GEPExpr->getType() &&
3801 "GEP should not change type mid-flight.");
3802 return GEPExpr;
3803}
3804
3805SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3806 ArrayRef<const SCEV *> Ops) {
3807 FoldingSetNodeID ID;
3808 ID.AddInteger(SCEVType);
3809 for (const SCEV *Op : Ops)
3810 ID.AddPointer(Op);
3811 void *IP = nullptr;
3812 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3813}
3814
3815const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3816 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3817 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3818}
3819
3820const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3821 SmallVectorImpl<const SCEV *> &Ops) {
3822 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3823 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3824 if (Ops.size() == 1) return Ops[0];
3825#ifndef NDEBUG
3826 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3827 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3828 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3829 "Operand types don't match!");
3830 assert(Ops[0]->getType()->isPointerTy() ==
3831 Ops[i]->getType()->isPointerTy() &&
3832 "min/max should be consistently pointerish");
3833 }
3834#endif
3835
3836 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3837 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3838
3839 const SCEV *Folded = constantFoldAndGroupOps(
3840 *this, LI, DT, Ops,
3841 [&](const APInt &C1, const APInt &C2) {
3842 switch (Kind) {
3843 case scSMaxExpr:
3844 return APIntOps::smax(C1, C2);
3845 case scSMinExpr:
3846 return APIntOps::smin(C1, C2);
3847 case scUMaxExpr:
3848 return APIntOps::umax(C1, C2);
3849 case scUMinExpr:
3850 return APIntOps::umin(C1, C2);
3851 default:
3852 llvm_unreachable("Unknown SCEV min/max opcode");
3853 }
3854 },
3855 [&](const APInt &C) {
3856 // identity
3857 if (IsMax)
3858 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3859 else
3860 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3861 },
3862 [&](const APInt &C) {
3863 // absorber
3864 if (IsMax)
3865 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3866 else
3867 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3868 });
3869 if (Folded)
3870 return Folded;
3871
3872 // Check if we have created the same expression before.
3873 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3874 return S;
3875 }
3876
3877 // Find the first operation of the same kind
3878 unsigned Idx = 0;
3879 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3880 ++Idx;
3881
3882 // Check to see if one of the operands is of the same kind. If so, expand its
3883 // operands onto our operand list, and recurse to simplify.
3884 if (Idx < Ops.size()) {
3885 bool DeletedAny = false;
3886 while (Ops[Idx]->getSCEVType() == Kind) {
3887 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3888 Ops.erase(Ops.begin()+Idx);
3889 append_range(Ops, SMME->operands());
3890 DeletedAny = true;
3891 }
3892
3893 if (DeletedAny)
3894 return getMinMaxExpr(Kind, Ops);
3895 }
3896
3897 // Okay, check to see if the same value occurs in the operand list twice. If
3898 // so, delete one. Since we sorted the list, these values are required to
3899 // be adjacent.
3900 llvm::CmpInst::Predicate GEPred =
3901 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3902 llvm::CmpInst::Predicate LEPred =
3903 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3904 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3905 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3906 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3907 if (Ops[i] == Ops[i + 1] ||
3908 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3909 // X op Y op Y --> X op Y
3910 // X op Y --> X, if we know X, Y are ordered appropriately
3911 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3912 --i;
3913 --e;
3914 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3915 Ops[i + 1])) {
3916 // X op Y --> Y, if we know X, Y are ordered appropriately
3917 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3918 --i;
3919 --e;
3920 }
3921 }
3922
3923 if (Ops.size() == 1) return Ops[0];
3924
3925 assert(!Ops.empty() && "Reduced smax down to nothing!");
3926
3927 // Okay, it looks like we really DO need an expr. Check to see if we
3928 // already have one, otherwise create a new one.
3929 FoldingSetNodeID ID;
3930 ID.AddInteger(Kind);
3931 for (const SCEV *Op : Ops)
3932 ID.AddPointer(Op);
3933 void *IP = nullptr;
3934 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3935 if (ExistingSCEV)
3936 return ExistingSCEV;
3937 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3938 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3939 SCEV *S = new (SCEVAllocator)
3940 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3941
3942 UniqueSCEVs.InsertNode(S, IP);
3943 registerUser(S, Ops);
3944 return S;
3945}
3946
3947namespace {
3948
3949class SCEVSequentialMinMaxDeduplicatingVisitor final
3950 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3951 std::optional<const SCEV *>> {
3952 using RetVal = std::optional<const SCEV *>;
3953 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3954
3955 ScalarEvolution &SE;
3956 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3957 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3958 SmallPtrSet<const SCEV *, 16> SeenOps;
3959
3960 bool canRecurseInto(SCEVTypes Kind) const {
3961 // We can only recurse into the SCEV expression of the same effective type
3962 // as the type of our root SCEV expression.
3963 return RootKind == Kind || NonSequentialRootKind == Kind;
3964 };
3965
3966 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3967 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3968 "Only for min/max expressions.");
3969 SCEVTypes Kind = S->getSCEVType();
3970
3971 if (!canRecurseInto(Kind))
3972 return S;
3973
3974 auto *NAry = cast<SCEVNAryExpr>(S);
3975 SmallVector<const SCEV *> NewOps;
3976 bool Changed = visit(Kind, NAry->operands(), NewOps);
3977
3978 if (!Changed)
3979 return S;
3980 if (NewOps.empty())
3981 return std::nullopt;
3982
3983 return isa<SCEVSequentialMinMaxExpr>(S)
3984 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3985 : SE.getMinMaxExpr(Kind, NewOps);
3986 }
3987
3988 RetVal visit(const SCEV *S) {
3989 // Has the whole operand been seen already?
3990 if (!SeenOps.insert(S).second)
3991 return std::nullopt;
3992 return Base::visit(S);
3993 }
3994
3995public:
3996 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3997 SCEVTypes RootKind)
3998 : SE(SE), RootKind(RootKind),
3999 NonSequentialRootKind(
4000 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4001 RootKind)) {}
4002
4003 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4004 SmallVectorImpl<const SCEV *> &NewOps) {
4005 bool Changed = false;
4006 SmallVector<const SCEV *> Ops;
4007 Ops.reserve(OrigOps.size());
4008
4009 for (const SCEV *Op : OrigOps) {
4010 RetVal NewOp = visit(Op);
4011 if (NewOp != Op)
4012 Changed = true;
4013 if (NewOp)
4014 Ops.emplace_back(*NewOp);
4015 }
4016
4017 if (Changed)
4018 NewOps = std::move(Ops);
4019 return Changed;
4020 }
4021
4022 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4023
4024 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4025
4026 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4027
4028 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4029
4030 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4031
4032 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4033
4034 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4035
4036 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4037
4038 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4039
4040 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4041
4042 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4043 return visitAnyMinMaxExpr(Expr);
4044 }
4045
4046 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4047 return visitAnyMinMaxExpr(Expr);
4048 }
4049
4050 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4051 return visitAnyMinMaxExpr(Expr);
4052 }
4053
4054 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4055 return visitAnyMinMaxExpr(Expr);
4056 }
4057
4058 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4059 return visitAnyMinMaxExpr(Expr);
4060 }
4061
4062 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4063
4064 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4065};
4066
4067} // namespace
4068
4069static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4070 switch (Kind) {
4071 case scConstant:
4072 case scVScale:
4073 case scTruncate:
4074 case scZeroExtend:
4075 case scSignExtend:
4076 case scPtrToInt:
4077 case scAddExpr:
4078 case scMulExpr:
4079 case scUDivExpr:
4080 case scAddRecExpr:
4081 case scUMaxExpr:
4082 case scSMaxExpr:
4083 case scUMinExpr:
4084 case scSMinExpr:
4085 case scUnknown:
4086 // If any operand is poison, the whole expression is poison.
4087 return true;
4088 case scSequentialUMinExpr:
4089 // FIXME: if the *first* operand is poison, the whole expression is poison.
4090 return false; // Pessimistically, say that it does not propagate poison.
4091 case scCouldNotCompute:
4092 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4093 }
4094 llvm_unreachable("Unknown SCEV kind!");
4095}
4096
4097namespace {
4098// The only way poison may be introduced in a SCEV expression is from a
4099// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4100// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4101// introduce poison -- they encode guaranteed, non-speculated knowledge.
4102//
4103// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4104// with the notable exception of umin_seq, where only poison from the first
4105// operand is (unconditionally) propagated.
4106struct SCEVPoisonCollector {
4107 bool LookThroughMaybePoisonBlocking;
4109 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4110 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4111
4112 bool follow(const SCEV *S) {
4113 if (!LookThroughMaybePoisonBlocking &&
4114 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4115 return false;
4116
4117 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4118 if (!isGuaranteedNotToBePoison(SU->getValue()))
4119 MaybePoison.insert(SU);
4120 }
4121 return true;
4122 }
4123 bool isDone() const { return false; }
4124};
4125} // namespace
4126
4127/// Return true if V is poison given that AssumedPoison is already poison.
4128static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4129 // First collect all SCEVs that might result in AssumedPoison to be poison.
4130 // We need to look through potentially poison-blocking operations here,
4131 // because we want to find all SCEVs that *might* result in poison, not only
4132 // those that are *required* to.
4133 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4134 visitAll(AssumedPoison, PC1);
4135
4136 // AssumedPoison is never poison. As the assumption is false, the implication
4137 // is true. Don't bother walking the other SCEV in this case.
4138 if (PC1.MaybePoison.empty())
4139 return true;
4140
4141 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4142 // as well. We cannot look through potentially poison-blocking operations
4143 // here, as their arguments only *may* make the result poison.
4144 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4145 visitAll(S, PC2);
4146
4147 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4148 // it will also make S poison by being part of PC2.MaybePoison.
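 // For example, with AssumedPoison = (1 + %x) and S = (%x * %y), the only
 // value that can make AssumedPoison poison is %x, and %x also appears in S
 // outside of any poison-blocking operation, so the implication holds.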
4149 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4150}
4151
4152void ScalarEvolution::getPoisonGeneratingValues(
4153 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4154 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4155 visitAll(S, PC);
4156 for (const SCEVUnknown *SU : PC.MaybePoison)
4157 Result.insert(SU->getValue());
4158}
4159
4160bool ScalarEvolution::canReuseInstruction(
4161 const SCEV *S, Instruction *I,
4162 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4163 // If the instruction cannot be poison, it's always safe to reuse.
4164 if (isGuaranteedNotToBePoison(I))
4165 return true;
4166
4167 // Otherwise, it is possible that I is more poisonous than S. Collect the
4168 // poison-contributors of S, and then check whether I has any additional
4169 // poison-contributors. Poison that is contributed through poison-generating
4170 // flags is handled by dropping those flags instead.
4171 SmallPtrSet<const Value *, 8> PoisonVals;
4172 getPoisonGeneratingValues(PoisonVals, S);
4173
4174 SmallVector<Value *> Worklist;
4175 SmallPtrSet<Value *, 8> Visited;
4176 Worklist.push_back(I);
4177 while (!Worklist.empty()) {
4178 Value *V = Worklist.pop_back_val();
4179 if (!Visited.insert(V).second)
4180 continue;
4181
4182 // Avoid walking large instruction graphs.
4183 if (Visited.size() > 16)
4184 return false;
4185
4186 // Either the value can't be poison, or the S would also be poison if it
4187 // is.
4188 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4189 continue;
4190
4191 auto *I = dyn_cast<Instruction>(V);
4192 if (!I)
4193 return false;
4194
4195 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4196 // can't replace an arbitrary add with disjoint or, even if we drop the
4197 // flag. We would need to convert the or into an add.
4198 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4199 if (PDI->isDisjoint())
4200 return false;
4201
4202 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4203 // because SCEV currently assumes it can't be poison. Remove this special
4204 // case once we properly model when vscale can be poison.
4205 if (auto *II = dyn_cast<IntrinsicInst>(I);
4206 II && II->getIntrinsicID() == Intrinsic::vscale)
4207 continue;
4208
4209 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4210 return false;
4211
4212 // If the instruction can't create poison, we can recurse to its operands.
4213 if (I->hasPoisonGeneratingAnnotations())
4214 DropPoisonGeneratingInsts.push_back(I);
4215
4216 for (Value *Op : I->operands())
4217 Worklist.push_back(Op);
4218 }
4219 return true;
4220}
4221
4222const SCEV *
4223ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4224 SmallVectorImpl<const SCEV *> &Ops) {
4225 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4226 "Not a SCEVSequentialMinMaxExpr!");
4227 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4228 if (Ops.size() == 1)
4229 return Ops[0];
4230#ifndef NDEBUG
4231 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4232 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4233 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4234 "Operand types don't match!");
4235 assert(Ops[0]->getType()->isPointerTy() ==
4236 Ops[i]->getType()->isPointerTy() &&
4237 "min/max should be consistently pointerish");
4238 }
4239#endif
4240
4241 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4242 // so we can *NOT* do any kind of sorting of the expressions!
4243
4244 // Check if we have created the same expression before.
4245 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4246 return S;
4247
4248 // FIXME: there are *some* simplifications that we can do here.
4249
4250 // Keep only the first instance of an operand.
4251 {
4252 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4253 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4254 if (Changed)
4255 return getSequentialMinMaxExpr(Kind, Ops);
4256 }
4257
4258 // Check to see if one of the operands is of the same kind. If so, expand its
4259 // operands onto our operand list, and recurse to simplify.
4260 {
4261 unsigned Idx = 0;
4262 bool DeletedAny = false;
4263 while (Idx < Ops.size()) {
4264 if (Ops[Idx]->getSCEVType() != Kind) {
4265 ++Idx;
4266 continue;
4267 }
4268 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4269 Ops.erase(Ops.begin() + Idx);
4270 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4271 SMME->operands().end());
4272 DeletedAny = true;
4273 }
4274
4275 if (DeletedAny)
4276 return getSequentialMinMaxExpr(Kind, Ops);
4277 }
4278
4279 const SCEV *SaturationPoint;
4280 ICmpInst::Predicate Pred;
4281 switch (Kind) {
4282 case scSequentialUMinExpr:
4283 SaturationPoint = getZero(Ops[0]->getType());
4284 Pred = ICmpInst::ICMP_ULE;
4285 break;
4286 default:
4287 llvm_unreachable("Not a sequential min/max type.");
4288 }
4289
4290 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4291 // We can replace %x umin_seq %y with %x umin %y if either:
4292 // * %y being poison implies %x is also poison.
4293 // * %x cannot be the saturating value (e.g. zero for umin).
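 // For example, %x umin_seq (%x + 1) can use a plain umin, since poison in
 // the second operand can only come from %x itself; likewise a first operand
 // known to be non-zero rules out the saturating value.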
4294 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4295 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4296 SaturationPoint)) {
4297 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4298 Ops[i - 1] = getMinMaxExpr(
4299 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4300 SeqOps);
4301 Ops.erase(Ops.begin() + i);
4302 return getSequentialMinMaxExpr(Kind, Ops);
4303 }
4304 // Fold %x umin_seq %y to %x if %x ule %y.
4305 // TODO: We might be able to prove the predicate for a later operand.
4306 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4307 Ops.erase(Ops.begin() + i);
4308 return getSequentialMinMaxExpr(Kind, Ops);
4309 }
4310 }
4311
4312 // Okay, it looks like we really DO need an expr. Check to see if we
4313 // already have one, otherwise create a new one.
4314 FoldingSetNodeID ID;
4315 ID.AddInteger(Kind);
4316 for (const SCEV *Op : Ops)
4317 ID.AddPointer(Op);
4318 void *IP = nullptr;
4319 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4320 if (ExistingSCEV)
4321 return ExistingSCEV;
4322
4323 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4324 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4325 SCEV *S = new (SCEVAllocator)
4326 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4327
4328 UniqueSCEVs.InsertNode(S, IP);
4329 registerUser(S, Ops);
4330 return S;
4331}
4332
4333const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4334 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4335 return getSMaxExpr(Ops);
4336}
4337
4338const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4339 return getMinMaxExpr(scSMaxExpr, Ops);
4340}
4341
4342const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4343 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4344 return getUMaxExpr(Ops);
4345}
4346
4347const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4348 return getMinMaxExpr(scUMaxExpr, Ops);
4349}
4350
4351const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4352 const SCEV *RHS) {
4353 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4354 return getSMinExpr(Ops);
4355}
4356
4357const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4358 return getMinMaxExpr(scSMinExpr, Ops);
4359}
4360
4361const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4362 bool Sequential) {
4363 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4364 return getUMinExpr(Ops, Sequential);
4365}
4366
4367const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4368 bool Sequential) {
4369 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4370 : getMinMaxExpr(scUMinExpr, Ops);
4371}
4372
4373const SCEV *
4374ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4375 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4376 if (Size.isScalable())
4377 Res = getMulExpr(Res, getVScale(IntTy));
4378 return Res;
4379}
4380
4381const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4382 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4383}
4384
4385const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4386 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4387}
4388
4389const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4390 StructType *STy,
4391 unsigned FieldNo) {
4392 // We can bypass creating a target-independent constant expression and then
4393 // folding it back into a ConstantInt. This is just a compile-time
4394 // optimization.
4395 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4396 assert(!SL->getSizeInBits().isScalable() &&
4397 "Cannot get offset for structure containing scalable vector types");
4398 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4399}
4400
4401const SCEV *ScalarEvolution::getUnknown(Value *V) {
4402 // Don't attempt to do anything other than create a SCEVUnknown object
4403 // here. createSCEV only calls getUnknown after checking for all other
4404 // interesting possibilities, and any other code that calls getUnknown
4405 // is doing so in order to hide a value from SCEV canonicalization.
4406
4407 FoldingSetNodeID ID;
4408 ID.AddInteger(scUnknown);
4409 ID.AddPointer(V);
4410 void *IP = nullptr;
4411 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4412 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4413 "Stale SCEVUnknown in uniquing map!");
4414 return S;
4415 }
4416 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4417 FirstUnknown);
4418 FirstUnknown = cast<SCEVUnknown>(S);
4419 UniqueSCEVs.InsertNode(S, IP);
4420 return S;
4421}
4422
4423//===----------------------------------------------------------------------===//
4424// Basic SCEV Analysis and PHI Idiom Recognition Code
4425//
4426
4427/// Test if values of the given type are analyzable within the SCEV
4428/// framework. This primarily includes integer types, and it can optionally
4429/// include pointer types if the ScalarEvolution class has access to
4430/// target-specific information.
4431bool ScalarEvolution::isSCEVable(Type *Ty) const {
4432 // Integers and pointers are always SCEVable.
4433 return Ty->isIntOrPtrTy();
4434}
4435
4436/// Return the size in bits of the specified type, for which isSCEVable must
4437/// return true.
4438uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4439 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4440 if (Ty->isPointerTy())
4441 return getDataLayout().getIndexTypeSizeInBits(Ty);
4442 return getDataLayout().getTypeSizeInBits(Ty);
4443}
4444
4445/// Return a type with the same bitwidth as the given type and which represents
4446/// how SCEV will treat the given type, for which isSCEVable must return
4447/// true. For pointer types, this is the pointer index sized integer type.
4448Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4449 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4450
4451 if (Ty->isIntegerTy())
4452 return Ty;
4453
4454 // The only other support type is pointer.
4455 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4456 return getDataLayout().getIndexType(Ty);
4457}
4458
4459Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4460 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4461}
4462
4463bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4464 const SCEV *B) {
4465 /// For a valid use point to exist, the defining scope of one operand
4466 /// must dominate the other.
4467 bool PreciseA, PreciseB;
4468 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4469 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4470 if (!PreciseA || !PreciseB)
4471 // Can't tell.
4472 return false;
4473 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4474 DT.dominates(ScopeB, ScopeA);
4475}
4476
4477const SCEV *ScalarEvolution::getCouldNotCompute() const {
4478 return CouldNotCompute.get();
4479}
4480
4481bool ScalarEvolution::checkValidity(const SCEV *S) const {
4482 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4483 auto *SU = dyn_cast<SCEVUnknown>(S);
4484 return SU && SU->getValue() == nullptr;
4485 });
4486
4487 return !ContainsNulls;
4488}
4489
4490bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4491 HasRecMapType::iterator I = HasRecMap.find(S);
4492 if (I != HasRecMap.end())
4493 return I->second;
4494
4495 bool FoundAddRec =
4496 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4497 HasRecMap.insert({S, FoundAddRec});
4498 return FoundAddRec;
4499}
4500
4501/// Return the ValueOffsetPair set for \p S. \p S can be represented
4502/// by the value and offset from any ValueOffsetPair in the set.
4503ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4504 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4505 if (SI == ExprValueMap.end())
4506 return std::nullopt;
4507 return SI->second.getArrayRef();
4508}
4509
4510/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4511/// cannot be used separately. eraseValueFromMap should be used to remove
4512/// V from ValueExprMap and ExprValueMap at the same time.
4513void ScalarEvolution::eraseValueFromMap(Value *V) {
4514 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4515 if (I != ValueExprMap.end()) {
4516 auto EVIt = ExprValueMap.find(I->second);
4517 bool Removed = EVIt->second.remove(V);
4518 (void) Removed;
4519 assert(Removed && "Value not in ExprValueMap?");
4520 ValueExprMap.erase(I);
4521 }
4522}
4523
4524void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4525 // A recursive query may have already computed the SCEV. It should be
4526 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4527 // inferred nowrap flags.
4528 auto It = ValueExprMap.find_as(V);
4529 if (It == ValueExprMap.end()) {
4530 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4531 ExprValueMap[S].insert(V);
4532 }
4533}
4534
4535/// Return an existing SCEV if it exists, otherwise analyze the expression and
4536/// create a new one.
4538 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4539
4540 if (const SCEV *S = getExistingSCEV(V))
4541 return S;
4542 return createSCEVIter(V);
4543}
4544
4545const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4546 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4547
4548 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4549 if (I != ValueExprMap.end()) {
4550 const SCEV *S = I->second;
4551 assert(checkValidity(S) &&
4552 "existing SCEV has not been properly invalidated");
4553 return S;
4554 }
4555 return nullptr;
4556}
4557
4558/// Return a SCEV corresponding to -V = -1*V
4559const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4560 SCEV::NoWrapFlags Flags) {
4561 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4562 return getConstant(
4563 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4564
4565 Type *Ty = V->getType();
4566 Ty = getEffectiveSCEVType(Ty);
4567 return getMulExpr(V, getMinusOne(Ty), Flags);
4568}
4569
4570/// If Expr computes ~A, return A else return nullptr
4571static const SCEV *MatchNotExpr(const SCEV *Expr) {
4572 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4573 if (!Add || Add->getNumOperands() != 2 ||
4574 !Add->getOperand(0)->isAllOnesValue())
4575 return nullptr;
4576
4577 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4578 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4579 !AddRHS->getOperand(0)->isAllOnesValue())
4580 return nullptr;
4581
4582 return AddRHS->getOperand(1);
4583}
4584
4585/// Return a SCEV corresponding to ~V = -1-V
4586const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4587 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4588
4589 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4590 return getConstant(
4591 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4592
4593 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4594 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4595 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4596 SmallVector<const SCEV *, 2> MatchedOperands;
4597 for (const SCEV *Operand : MME->operands()) {
4598 const SCEV *Matched = MatchNotExpr(Operand);
4599 if (!Matched)
4600 return (const SCEV *)nullptr;
4601 MatchedOperands.push_back(Matched);
4602 }
4603 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4604 MatchedOperands);
4605 };
4606 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4607 return Replaced;
4608 }
4609
4610 Type *Ty = V->getType();
4611 Ty = getEffectiveSCEVType(Ty);
4612 return getMinusSCEV(getMinusOne(Ty), V);
4613}
4614
4615const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4616 assert(P->getType()->isPointerTy());
4617
4618 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4619 // The base of an AddRec is the first operand.
4620 SmallVector<const SCEV *> Ops{AddRec->operands()};
4621 Ops[0] = removePointerBase(Ops[0]);
4622 // Don't try to transfer nowrap flags for now. We could in some cases
4623 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4624 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4625 }
4626 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4627 // The base of an Add is the pointer operand.
4628 SmallVector<const SCEV *> Ops{Add->operands()};
4629 const SCEV **PtrOp = nullptr;
4630 for (const SCEV *&AddOp : Ops) {
4631 if (AddOp->getType()->isPointerTy()) {
4632 assert(!PtrOp && "Cannot have multiple pointer ops");
4633 PtrOp = &AddOp;
4634 }
4635 }
4636 *PtrOp = removePointerBase(*PtrOp);
4637 // Don't try to transfer nowrap flags for now. We could in some cases
4638 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4639 return getAddExpr(Ops);
4640 }
4641 // Any other expression must be a pointer base.
4642 return getZero(P->getType());
4643}
4644
4645const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4646 SCEV::NoWrapFlags Flags,
4647 unsigned Depth) {
4648 // Fast path: X - X --> 0.
4649 if (LHS == RHS)
4650 return getZero(LHS->getType());
4651
4652 // If we subtract two pointers with different pointer bases, bail.
4653 // Eventually, we're going to add an assertion to getMulExpr that we
4654 // can't multiply by a pointer.
4655 if (RHS->getType()->isPointerTy()) {
4656 if (!LHS->getType()->isPointerTy() ||
4657 getPointerBase(LHS) != getPointerBase(RHS))
4658 return getCouldNotCompute();
4659 LHS = removePointerBase(LHS);
4660 RHS = removePointerBase(RHS);
4661 }
4662
4663 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4664 // makes it so that we cannot make much use of NUW.
4665 auto AddFlags = SCEV::FlagAnyWrap;
4666 const bool RHSIsNotMinSigned =
4667 !getSignedRangeMin(RHS).isMinSignedValue();
4668 if (hasFlags(Flags, SCEV::FlagNSW)) {
4669 // Let M be the minimum representable signed value. Then (-1)*RHS
4670 // signed-wraps if and only if RHS is M. That can happen even for
4671 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4672 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4673 // (-1)*RHS, we need to prove that RHS != M.
4674 //
4675 // If LHS is non-negative and we know that LHS - RHS does not
4676 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4677 // either by proving that RHS > M or that LHS >= 0.
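 // For example, in i8 with RHS known to lie in [-100, 100], RHS cannot be
 // -128, so the rewritten form LHS + (-1)*RHS can keep the NSW flag.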
4678 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4679 AddFlags = SCEV::FlagNSW;
4680 }
4681 }
4682
4683 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4684 // RHS is NSW and LHS >= 0.
4685 //
4686 // The difficulty here is that the NSW flag may have been proven
4687 // relative to a loop that is to be found in a recurrence in LHS and
4688 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4689 // larger scope than intended.
4690 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4691
4692 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4693}
4694
4695const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4696 unsigned Depth) {
4697 Type *SrcTy = V->getType();
4698 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4699 "Cannot truncate or zero extend with non-integer arguments!");
4700 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4701 return V; // No conversion
4702 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4703 return getTruncateExpr(V, Ty, Depth);
4704 return getZeroExtendExpr(V, Ty, Depth);
4705}
4706
4707const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4708 unsigned Depth) {
4709 Type *SrcTy = V->getType();
4710 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4711 "Cannot truncate or zero extend with non-integer arguments!");
4712 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4713 return V; // No conversion
4714 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4715 return getTruncateExpr(V, Ty, Depth);
4716 return getSignExtendExpr(V, Ty, Depth);
4717}
4718
4719const SCEV *
4720ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4721 Type *SrcTy = V->getType();
4722 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4723 "Cannot noop or zero extend with non-integer arguments!");
4725 "getNoopOrZeroExtend cannot truncate!");
4726 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4727 return V; // No conversion
4728 return getZeroExtendExpr(V, Ty);
4729}
4730
4731const SCEV *
4732ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4733 Type *SrcTy = V->getType();
4734 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4735 "Cannot noop or sign extend with non-integer arguments!");
4737 "getNoopOrSignExtend cannot truncate!");
4738 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4739 return V; // No conversion
4740 return getSignExtendExpr(V, Ty);
4741}
4742
4743const SCEV *
4744ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4745 Type *SrcTy = V->getType();
4746 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4747 "Cannot noop or any extend with non-integer arguments!");
4749 "getNoopOrAnyExtend cannot truncate!");
4750 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4751 return V; // No conversion
4752 return getAnyExtendExpr(V, Ty);
4753}
4754
4755const SCEV *
4756ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4757 Type *SrcTy = V->getType();
4758 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4759 "Cannot truncate or noop with non-integer arguments!");
4761 "getTruncateOrNoop cannot extend!");
4762 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4763 return V; // No conversion
4764 return getTruncateExpr(V, Ty);
4765}
4766
4767const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4768 const SCEV *RHS) {
4769 const SCEV *PromotedLHS = LHS;
4770 const SCEV *PromotedRHS = RHS;
4771
4772  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4773    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4774 else
4775 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4776
4777 return getUMaxExpr(PromotedLHS, PromotedRHS);
4778}
4779
4780const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4781                                                        const SCEV *RHS,
4782 bool Sequential) {
4783  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4784  return getUMinFromMismatchedTypes(Ops, Sequential);
4785}
4786
4787const SCEV *
4788ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4789                                            bool Sequential) {
4790 assert(!Ops.empty() && "At least one operand must be!");
4791 // Trivial case.
4792 if (Ops.size() == 1)
4793 return Ops[0];
4794
4795 // Find the max type first.
4796 Type *MaxType = nullptr;
4797 for (const auto *S : Ops)
4798 if (MaxType)
4799 MaxType = getWiderType(MaxType, S->getType());
4800 else
4801 MaxType = S->getType();
4802 assert(MaxType && "Failed to find maximum type!");
4803
4804 // Extend all ops to max type.
4805 SmallVector<const SCEV *, 2> PromotedOps;
4806 for (const auto *S : Ops)
4807 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4808
4809 // Generate umin.
4810 return getUMinExpr(PromotedOps, Sequential);
4811}
4812
4813const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4814  // A pointer operand may evaluate to a nonpointer expression, such as null.
4815 if (!V->getType()->isPointerTy())
4816 return V;
4817
4818 while (true) {
4819 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4820 V = AddRec->getStart();
4821 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4822 const SCEV *PtrOp = nullptr;
4823 for (const SCEV *AddOp : Add->operands()) {
4824 if (AddOp->getType()->isPointerTy()) {
4825 assert(!PtrOp && "Cannot have multiple pointer ops");
4826 PtrOp = AddOp;
4827 }
4828 }
4829 assert(PtrOp && "Must have pointer op");
4830 V = PtrOp;
4831 } else // Not something we can look further into.
4832 return V;
4833 }
4834}
4835
4836/// Push users of the given Instruction onto the given Worklist.
4837static void PushDefUseChildren(Instruction *I,
4838                               SmallVectorImpl<Instruction *> &Worklist,
4839                               SmallPtrSetImpl<Instruction *> &Visited) {
4840  // Push the def-use children onto the Worklist stack.
4841 for (User *U : I->users()) {
4842 auto *UserInsn = cast<Instruction>(U);
4843 if (Visited.insert(UserInsn).second)
4844 Worklist.push_back(UserInsn);
4845 }
4846}
4847
4848namespace {
4849
4850/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4851/// expression if its loop is L. If the loop is not L, then use the AddRec
4852/// itself when IgnoreOtherLoops is true; otherwise the rewrite cannot be
4853/// done.
4854/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4855class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4856public:
4857 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4858 bool IgnoreOtherLoops = true) {
4859 SCEVInitRewriter Rewriter(L, SE);
4860 const SCEV *Result = Rewriter.visit(S);
4861 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4862 return SE.getCouldNotCompute();
4863 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4864 ? SE.getCouldNotCompute()
4865 : Result;
4866 }
4867
4868 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4869 if (!SE.isLoopInvariant(Expr, L))
4870 SeenLoopVariantSCEVUnknown = true;
4871 return Expr;
4872 }
4873
4874 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4875 // Only re-write AddRecExprs for this loop.
4876 if (Expr->getLoop() == L)
4877 return Expr->getStart();
4878 SeenOtherLoops = true;
4879 return Expr;
4880 }
4881
4882 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4883
4884 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4885
4886private:
4887 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4888 : SCEVRewriteVisitor(SE), L(L) {}
4889
4890 const Loop *L;
4891 bool SeenLoopVariantSCEVUnknown = false;
4892 bool SeenOtherLoops = false;
4893};
4894
4895/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4896/// increment expression if its loop is L. If the loop is not L, then use the
4897/// AddRec itself.
4898/// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4899class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4900public:
4901 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4902 SCEVPostIncRewriter Rewriter(L, SE);
4903 const SCEV *Result = Rewriter.visit(S);
4904 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4905 ? SE.getCouldNotCompute()
4906 : Result;
4907 }
4908
4909 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4910 if (!SE.isLoopInvariant(Expr, L))
4911 SeenLoopVariantSCEVUnknown = true;
4912 return Expr;
4913 }
4914
4915 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4916 // Only re-write AddRecExprs for this loop.
4917 if (Expr->getLoop() == L)
4918 return Expr->getPostIncExpr(SE);
4919 SeenOtherLoops = true;
4920 return Expr;
4921 }
4922
4923 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4924
4925 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4926
4927private:
4928 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4929 : SCEVRewriteVisitor(SE), L(L) {}
4930
4931 const Loop *L;
4932 bool SeenLoopVariantSCEVUnknown = false;
4933 bool SeenOtherLoops = false;
4934};
4935
4936/// This class evaluates the compare condition by matching it against the
4937/// condition of loop latch. If there is a match we assume a true value
4938/// for the condition while building SCEV nodes.
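/// For instance, if the latch branches back to the header when %cond is true,
/// then an expression such as "select %cond, %a, %b" reached on the backedge
/// path can be folded to %a while computing backedge-taken counts.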
4939class SCEVBackedgeConditionFolder
4940 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4941public:
4942 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4943 ScalarEvolution &SE) {
4944 bool IsPosBECond = false;
4945 Value *BECond = nullptr;
4946 if (BasicBlock *Latch = L->getLoopLatch()) {
4947 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4948 if (BI && BI->isConditional()) {
4949 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4950 "Both outgoing branches should not target same header!");
4951 BECond = BI->getCondition();
4952 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4953 } else {
4954 return S;
4955 }
4956 }
4957 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4958 return Rewriter.visit(S);
4959 }
4960
4961 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4962 const SCEV *Result = Expr;
4963 bool InvariantF = SE.isLoopInvariant(Expr, L);
4964
4965 if (!InvariantF) {
4966 Instruction *I = cast<Instruction>(Expr->getValue());
4967 switch (I->getOpcode()) {
4968 case Instruction::Select: {
4969 SelectInst *SI = cast<SelectInst>(I);
4970 std::optional<const SCEV *> Res =
4971 compareWithBackedgeCondition(SI->getCondition());
4972 if (Res) {
4973 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4974 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4975 }
4976 break;
4977 }
4978 default: {
4979 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4980 if (Res)
4981 Result = *Res;
4982 break;
4983 }
4984 }
4985 }
4986 return Result;
4987 }
4988
4989private:
4990 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4991 bool IsPosBECond, ScalarEvolution &SE)
4992 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4993 IsPositiveBECond(IsPosBECond) {}
4994
4995 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4996
4997 const Loop *L;
4998 /// Loop back condition.
4999 Value *BackedgeCond = nullptr;
5000 /// Set to true if loop back is on positive branch condition.
5001 bool IsPositiveBECond;
5002};
5003
5004std::optional<const SCEV *>
5005SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5006
5007 // If value matches the backedge condition for loop latch,
5008 // then return a constant evolution node based on loopback
5009 // branch taken.
5010 if (BackedgeCond == IC)
5011    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5012                            : SE.getZero(Type::getInt1Ty(SE.getContext()));
5013 return std::nullopt;
5014}
5015
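/// Rewrites affine AddRecs of loop L by subtracting one step, i.e. it turns
/// {A,+,B}<L> into {A-B,+,B}<L>, effectively shifting the recurrence back by
/// one iteration. Any loop-variant SCEVUnknown invalidates the rewrite. This
/// lets createAddRecFromPHI recognize PHIs whose incoming values are one
/// iteration apart, e.g. PHI(f(0), f({1,+,1})).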
5016class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5017public:
5018 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5019 ScalarEvolution &SE) {
5020 SCEVShiftRewriter Rewriter(L, SE);
5021 const SCEV *Result = Rewriter.visit(S);
5022 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5023 }
5024
5025 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5026 // Only allow AddRecExprs for this loop.
5027 if (!SE.isLoopInvariant(Expr, L))
5028 Valid = false;
5029 return Expr;
5030 }
5031
5032 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5033 if (Expr->getLoop() == L && Expr->isAffine())
5034 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5035 Valid = false;
5036 return Expr;
5037 }
5038
5039 bool isValid() { return Valid; }
5040
5041private:
5042 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5043 : SCEVRewriteVisitor(SE), L(L) {}
5044
5045 const Loop *L;
5046 bool Valid = true;
5047};
5048
5049} // end anonymous namespace
5050
5051SCEV::NoWrapFlags
5052ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5053 if (!AR->isAffine())
5054 return SCEV::FlagAnyWrap;
5055
5056 using OBO = OverflowingBinaryOperator;
5057
5058  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5059
5060 if (!AR->hasNoSelfWrap()) {
5061 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5062 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5063 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5064 const APInt &BECountAP = BECountMax->getAPInt();
5065 unsigned NoOverflowBitWidth =
5066 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5067 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5068        Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5069    }
5070 }
5071
5072 if (!AR->hasNoSignedWrap()) {
5073 ConstantRange AddRecRange = getSignedRange(AR);
5074 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5075
5076    auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5077        Instruction::Add, IncRange, OBO::NoSignedWrap);
5078 if (NSWRegion.contains(AddRecRange))
5079      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5080  }
5081
5082 if (!AR->hasNoUnsignedWrap()) {
5083 ConstantRange AddRecRange = getUnsignedRange(AR);
5084 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5085
5086    auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5087        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5088 if (NUWRegion.contains(AddRecRange))
5089      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5090  }
5091
5092 return Result;
5093}
5094
5095SCEV::NoWrapFlags
5096ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5097  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5098
5099 if (AR->hasNoSignedWrap())
5100 return Result;
5101
5102 if (!AR->isAffine())
5103 return Result;
5104
5105 // This function can be expensive, only try to prove NSW once per AddRec.
5106 if (!SignedWrapViaInductionTried.insert(AR).second)
5107 return Result;
5108
5109 const SCEV *Step = AR->getStepRecurrence(*this);
5110 const Loop *L = AR->getLoop();
5111
5112 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5113 // Note that this serves two purposes: It filters out loops that are
5114 // simply not analyzable, and it covers the case where this code is
5115 // being called from within backedge-taken count analysis, such that
5116 // attempting to ask for the backedge-taken count would likely result
5117  // in infinite recursion. In the latter case, the analysis code will
5118 // cope with a conservative value, and it will take care to purge
5119 // that value once it has finished.
5120 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5121
5122 // Normally, in the cases we can prove no-overflow via a
5123 // backedge guarding condition, we can also compute a backedge
5124 // taken count for the loop. The exceptions are assumptions and
5125 // guards present in the loop -- SCEV is not great at exploiting
5126 // these to compute max backedge taken counts, but can still use
5127 // these to prove lack of overflow. Use this fact to avoid
5128 // doing extra work that may not pay off.
5129
5130 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5131 AC.assumptions().empty())
5132 return Result;
5133
5134 // If the backedge is guarded by a comparison with the pre-inc value the
5135 // addrec is safe. Also, if the entry is guarded by a comparison with the
5136 // start value and the backedge is guarded by a comparison with the post-inc
5137 // value, the addrec is safe.
5138  ICmpInst::Predicate Pred;
5139  const SCEV *OverflowLimit =
5140 getSignedOverflowLimitForStep(Step, &Pred, this);
5141 if (OverflowLimit &&
5142 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5143 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5144 Result = setFlags(Result, SCEV::FlagNSW);
5145 }
5146 return Result;
5147}
5148SCEV::NoWrapFlags
5149ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5150  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5151
5152 if (AR->hasNoUnsignedWrap())
5153 return Result;
5154
5155 if (!AR->isAffine())
5156 return Result;
5157
5158 // This function can be expensive, only try to prove NUW once per AddRec.
5159 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5160 return Result;
5161
5162 const SCEV *Step = AR->getStepRecurrence(*this);
5163 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5164 const Loop *L = AR->getLoop();
5165
5166 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5167 // Note that this serves two purposes: It filters out loops that are
5168 // simply not analyzable, and it covers the case where this code is
5169 // being called from within backedge-taken count analysis, such that
5170 // attempting to ask for the backedge-taken count would likely result
5171  // in infinite recursion. In the latter case, the analysis code will
5172 // cope with a conservative value, and it will take care to purge
5173 // that value once it has finished.
5174 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5175
5176 // Normally, in the cases we can prove no-overflow via a
5177 // backedge guarding condition, we can also compute a backedge
5178 // taken count for the loop. The exceptions are assumptions and
5179 // guards present in the loop -- SCEV is not great at exploiting
5180 // these to compute max backedge taken counts, but can still use
5181 // these to prove lack of overflow. Use this fact to avoid
5182 // doing extra work that may not pay off.
5183
5184 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5185 AC.assumptions().empty())
5186 return Result;
5187
5188 // If the backedge is guarded by a comparison with the pre-inc value the
5189 // addrec is safe. Also, if the entry is guarded by a comparison with the
5190 // start value and the backedge is guarded by a comparison with the post-inc
5191 // value, the addrec is safe.
5192 if (isKnownPositive(Step)) {
5193    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5194                                getUnsignedRangeMax(Step));
5195    if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5196        isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5197      Result = setFlags(Result, SCEV::FlagNUW);
5198 }
5199 }
5200
5201 return Result;
5202}
5203
5204namespace {
5205
5206/// Represents an abstract binary operation. This may exist as a
5207/// normal instruction or constant expression, or may have been
5208/// derived from an expression tree.
5209struct BinaryOp {
5210 unsigned Opcode;
5211 Value *LHS;
5212 Value *RHS;
5213 bool IsNSW = false;
5214 bool IsNUW = false;
5215
5216 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5217 /// constant expression.
5218 Operator *Op = nullptr;
5219
5220 explicit BinaryOp(Operator *Op)
5221 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5222 Op(Op) {
5223 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5224 IsNSW = OBO->hasNoSignedWrap();
5225 IsNUW = OBO->hasNoUnsignedWrap();
5226 }
5227 }
5228
5229 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5230 bool IsNUW = false)
5231 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5232};
5233
5234} // end anonymous namespace
5235
5236/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5237static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5238 AssumptionCache &AC,
5239 const DominatorTree &DT,
5240 const Instruction *CxtI) {
5241 auto *Op = dyn_cast<Operator>(V);
5242 if (!Op)
5243 return std::nullopt;
5244
5245 // Implementation detail: all the cleverness here should happen without
5246  // creating new SCEV expressions -- our caller knows tricks to avoid creating
5247 // SCEV expressions when possible, and we should not break that.
5248
5249 switch (Op->getOpcode()) {
5250 case Instruction::Add:
5251 case Instruction::Sub:
5252 case Instruction::Mul:
5253 case Instruction::UDiv:
5254 case Instruction::URem:
5255 case Instruction::And:
5256 case Instruction::AShr:
5257 case Instruction::Shl:
5258 return BinaryOp(Op);
5259
5260 case Instruction::Or: {
5261 // Convert or disjoint into add nuw nsw.
5262 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5263 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5264 /*IsNSW=*/true, /*IsNUW=*/true);
5265 return BinaryOp(Op);
5266 }
5267
5268 case Instruction::Xor:
5269 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5270 // If the RHS of the xor is a signmask, then this is just an add.
5271 // Instcombine turns add of signmask into xor as a strength reduction step.
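      // (For i8, for example, "xor %x, 0x80" and "add %x, 0x80" produce the
      // same value, since adding the sign bit cannot carry into any other bit.)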
5272 if (RHSC->getValue().isSignMask())
5273 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5274 // Binary `xor` is a bit-wise `add`.
5275 if (V->getType()->isIntegerTy(1))
5276 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5277 return BinaryOp(Op);
5278
5279 case Instruction::LShr:
5280    // Turn a logical shift right by a constant into an unsigned divide.
5281 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5282 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5283
5284 // If the shift count is not less than the bitwidth, the result of
5285 // the shift is undefined. Don't try to analyze it, because the
5286 // resolution chosen here may differ from the resolution chosen in
5287 // other parts of the compiler.
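      // (For in-range counts, e.g. "lshr i32 %x, 4" is analyzed as
      // "udiv i32 %x, 16".)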
5288 if (SA->getValue().ult(BitWidth)) {
5289 Constant *X =
5290 ConstantInt::get(SA->getContext(),
5291 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5292 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5293 }
5294 }
5295 return BinaryOp(Op);
5296
5297 case Instruction::ExtractValue: {
5298 auto *EVI = cast<ExtractValueInst>(Op);
5299 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5300 break;
5301
5302 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5303 if (!WO)
5304 break;
5305
5306 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5307 bool Signed = WO->isSigned();
5308 // TODO: Should add nuw/nsw flags for mul as well.
5309 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5310 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5311
5312 // Now that we know that all uses of the arithmetic-result component of
5313 // CI are guarded by the overflow check, we can go ahead and pretend
5314 // that the arithmetic is non-overflowing.
5315 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5316 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5317 }
5318
5319 default:
5320 break;
5321 }
5322
5323 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5324 // semantics as a Sub, return a binary sub expression.
5325 if (auto *II = dyn_cast<IntrinsicInst>(V))
5326 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5327 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5328
5329 return std::nullopt;
5330}
5331
5332/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5333/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5334/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5335/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5336/// follows one of the following patterns:
5337/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5338/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5339/// If the SCEV expression of \p Op conforms with one of the expected patterns
5340/// we return the type of the truncation operation, and indicate whether the
5341/// truncated type should be treated as signed/unsigned by setting
5342/// \p Signed to true/false, respectively.
5343static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5344 bool &Signed, ScalarEvolution &SE) {
5345 // The case where Op == SymbolicPHI (that is, with no type conversions on
5346 // the way) is handled by the regular add recurrence creating logic and
5347 // would have already been triggered in createAddRecForPHI. Reaching it here
5348 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5349 // because one of the other operands of the SCEVAddExpr updating this PHI is
5350 // not invariant).
5351 //
5352 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5353 // this case predicates that allow us to prove that Op == SymbolicPHI will
5354 // be added.
5355 if (Op == SymbolicPHI)
5356 return nullptr;
5357
5358 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5359 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5360 if (SourceBits != NewBits)
5361 return nullptr;
5362
5363 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5364 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5365 if (!SExt && !ZExt)
5366 return nullptr;
5367 const SCEVTruncateExpr *Trunc =
5368 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5369 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5370 if (!Trunc)
5371 return nullptr;
5372 const SCEV *X = Trunc->getOperand();
5373 if (X != SymbolicPHI)
5374 return nullptr;
5375 Signed = SExt != nullptr;
5376 return Trunc->getType();
5377}
5378
5379static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5380 if (!PN->getType()->isIntegerTy())
5381 return nullptr;
5382 const Loop *L = LI.getLoopFor(PN->getParent());
5383 if (!L || L->getHeader() != PN->getParent())
5384 return nullptr;
5385 return L;
5386}
5387
5388// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5389// computation that updates the phi follows the following pattern:
5390// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5391// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5392// If so, try to see if it can be rewritten as an AddRecExpr under some
5393// Predicates. If successful, return them as a pair. Also cache the results
5394// of the analysis.
5395//
5396// Example usage scenario:
5397// Say the Rewriter is called for the following SCEV:
5398// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5399// where:
5400// %X = phi i64 (%Start, %BEValue)
5401// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5402// and call this function with %SymbolicPHI = %X.
5403//
5404// The analysis will find that the value coming around the backedge has
5405// the following SCEV:
5406// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5407// Upon concluding that this matches the desired pattern, the function
5408// will return the pair {NewAddRec, SmallPredsVec} where:
5409// NewAddRec = {%Start,+,%Step}
5410// SmallPredsVec = {P1, P2, P3} as follows:
5411// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5412// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5413// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5414// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5415// under the predicates {P1,P2,P3}.
5416// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5417// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5418//
5419// TODO's:
5420//
5421// 1) Extend the Induction descriptor to also support inductions that involve
5422// casts: When needed (namely, when we are called in the context of the
5423// vectorizer induction analysis), a Set of cast instructions will be
5424// populated by this method, and provided back to isInductionPHI. This is
5425// needed to allow the vectorizer to properly record them to be ignored by
5426// the cost model and to avoid vectorizing them (otherwise these casts,
5427// which are redundant under the runtime overflow checks, will be
5428// vectorized, which can be costly).
5429//
5430// 2) Support additional induction/PHISCEV patterns: We also want to support
5431// inductions where the sext-trunc / zext-trunc operations (partly) occur
5432// after the induction update operation (the induction increment):
5433//
5434// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5435// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5436//
5437// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5438// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5439//
5440// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5441std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5442ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5443  SmallVector<const SCEVPredicate *, 3> Predicates;
5444
5445 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5446 // return an AddRec expression under some predicate.
5447
5448 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5449 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5450 assert(L && "Expecting an integer loop header phi");
5451
5452 // The loop may have multiple entrances or multiple exits; we can analyze
5453 // this phi as an addrec if it has a unique entry value and a unique
5454 // backedge value.
5455 Value *BEValueV = nullptr, *StartValueV = nullptr;
5456 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5457 Value *V = PN->getIncomingValue(i);
5458 if (L->contains(PN->getIncomingBlock(i))) {
5459 if (!BEValueV) {
5460 BEValueV = V;
5461 } else if (BEValueV != V) {
5462 BEValueV = nullptr;
5463 break;
5464 }
5465 } else if (!StartValueV) {
5466 StartValueV = V;
5467 } else if (StartValueV != V) {
5468 StartValueV = nullptr;
5469 break;
5470 }
5471 }
5472 if (!BEValueV || !StartValueV)
5473 return std::nullopt;
5474
5475 const SCEV *BEValue = getSCEV(BEValueV);
5476
5477 // If the value coming around the backedge is an add with the symbolic
5478 // value we just inserted, possibly with casts that we can ignore under
5479 // an appropriate runtime guard, then we found a simple induction variable!
5480 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5481 if (!Add)
5482 return std::nullopt;
5483
5484 // If there is a single occurrence of the symbolic value, possibly
5485 // casted, replace it with a recurrence.
5486 unsigned FoundIndex = Add->getNumOperands();
5487 Type *TruncTy = nullptr;
5488 bool Signed;
5489 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5490 if ((TruncTy =
5491 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5492 if (FoundIndex == e) {
5493 FoundIndex = i;
5494 break;
5495 }
5496
5497 if (FoundIndex == Add->getNumOperands())
5498 return std::nullopt;
5499
5500 // Create an add with everything but the specified operand.
5501  SmallVector<const SCEV *, 8> Ops;
5502  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5503 if (i != FoundIndex)
5504 Ops.push_back(Add->getOperand(i));
5505 const SCEV *Accum = getAddExpr(Ops);
5506
5507 // The runtime checks will not be valid if the step amount is
5508 // varying inside the loop.
5509 if (!isLoopInvariant(Accum, L))
5510 return std::nullopt;
5511
5512 // *** Part2: Create the predicates
5513
5514 // Analysis was successful: we have a phi-with-cast pattern for which we
5515 // can return an AddRec expression under the following predicates:
5516 //
5517 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5518 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5519 // P2: An Equal predicate that guarantees that
5520 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5521 // P3: An Equal predicate that guarantees that
5522 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5523 //
5524 // As we next prove, the above predicates guarantee that:
5525 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5526 //
5527 //
5528 // More formally, we want to prove that:
5529 // Expr(i+1) = Start + (i+1) * Accum
5530 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5531 //
5532 // Given that:
5533 // 1) Expr(0) = Start
5534 // 2) Expr(1) = Start + Accum
5535 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5536 // 3) Induction hypothesis (step i):
5537 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5538 //
5539 // Proof:
5540 // Expr(i+1) =
5541 // = Start + (i+1)*Accum
5542 // = (Start + i*Accum) + Accum
5543 // = Expr(i) + Accum
5544 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5545 // :: from step i
5546 //
5547 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5548 //
5549 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5550 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5551 // + Accum :: from P3
5552 //
5553 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5554 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5555 //
5556 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5557 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5558 //
5559 // By induction, the same applies to all iterations 1<=i<n:
5560 //
5561
5562 // Create a truncated addrec for which we will add a no overflow check (P1).
5563 const SCEV *StartVal = getSCEV(StartValueV);
5564 const SCEV *PHISCEV =
5565 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5566 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5567
5568 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5569 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5570 // will be constant.
5571 //
5572 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5573 // add P1.
5574 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5575    SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5576        Signed ? SCEVWrapPredicate::IncrementNSSW
5577               : SCEVWrapPredicate::IncrementNUSW;
5578    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5579 Predicates.push_back(AddRecPred);
5580 }
5581
5582 // Create the Equal Predicates P2,P3:
5583
5584 // It is possible that the predicates P2 and/or P3 are computable at
5585 // compile time due to StartVal and/or Accum being constants.
5586 // If either one is, then we can check that now and escape if either P2
5587 // or P3 is false.
5588
5589 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5590 // for each of StartVal and Accum
5591 auto getExtendedExpr = [&](const SCEV *Expr,
5592 bool CreateSignExtend) -> const SCEV * {
5593 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5594 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5595 const SCEV *ExtendedExpr =
5596 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5597 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5598 return ExtendedExpr;
5599 };
5600
5601 // Given:
5602 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5603 // = getExtendedExpr(Expr)
5604 // Determine whether the predicate P: Expr == ExtendedExpr
5605 // is known to be false at compile time
5606 auto PredIsKnownFalse = [&](const SCEV *Expr,
5607 const SCEV *ExtendedExpr) -> bool {
5608 return Expr != ExtendedExpr &&
5609 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5610 };
5611
5612 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5613 if (PredIsKnownFalse(StartVal, StartExtended)) {
5614 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5615 return std::nullopt;
5616 }
5617
5618 // The Step is always Signed (because the overflow checks are either
5619 // NSSW or NUSW)
5620 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5621 if (PredIsKnownFalse(Accum, AccumExtended)) {
5622 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5623 return std::nullopt;
5624 }
5625
5626 auto AppendPredicate = [&](const SCEV *Expr,
5627 const SCEV *ExtendedExpr) -> void {
5628 if (Expr != ExtendedExpr &&
5629 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5630 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5631 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5632 Predicates.push_back(Pred);
5633 }
5634 };
5635
5636 AppendPredicate(StartVal, StartExtended);
5637 AppendPredicate(Accum, AccumExtended);
5638
5639 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5640 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5641 // into NewAR if it will also add the runtime overflow checks specified in
5642 // Predicates.
5643 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5644
5645 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5646 std::make_pair(NewAR, Predicates);
5647  // Remember the result of the analysis for this SCEV at this location.
5648 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5649 return PredRewrite;
5650}
5651
5652std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5653ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5654  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5655 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5656 if (!L)
5657 return std::nullopt;
5658
5659 // Check to see if we already analyzed this PHI.
5660 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5661 if (I != PredicatedSCEVRewrites.end()) {
5662 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5663 I->second;
5664 // Analysis was done before and failed to create an AddRec:
5665 if (Rewrite.first == SymbolicPHI)
5666 return std::nullopt;
5667 // Analysis was done before and succeeded to create an AddRec under
5668 // a predicate:
5669 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5670 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5671 return Rewrite;
5672 }
5673
5674 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5675 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5676
5677 // Record in the cache that the analysis failed
5678 if (!Rewrite) {
5679    SmallVector<const SCEVPredicate *, 3> Predicates;
5680    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5681 return std::nullopt;
5682 }
5683
5684 return Rewrite;
5685}
5686
5687// FIXME: This utility is currently required because the Rewriter currently
5688// does not rewrite this expression:
5689// {0, +, (sext ix (trunc iy to ix) to iy)}
5690// into {0, +, %step},
5691// even when the following Equal predicate exists:
5692// "%step == (sext ix (trunc iy to ix) to iy)".
5693bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5694    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5695 if (AR1 == AR2)
5696 return true;
5697
5698 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5699 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5700 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5701 return false;
5702 return true;
5703 };
5704
5705 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5706 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5707 return false;
5708 return true;
5709}
5710
5711/// A helper function for createAddRecFromPHI to handle simple cases.
5712///
5713/// This function tries to find an AddRec expression for the simplest (yet most
5714/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5715/// If it fails, createAddRecFromPHI will use a more general, but slow,
5716/// technique for finding the AddRec expression.
5717const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5718 Value *BEValueV,
5719 Value *StartValueV) {
5720 const Loop *L = LI.getLoopFor(PN->getParent());
5721 assert(L && L->getHeader() == PN->getParent());
5722 assert(BEValueV && StartValueV);
5723
5724 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5725 if (!BO)
5726 return nullptr;
5727
5728 if (BO->Opcode != Instruction::Add)
5729 return nullptr;
5730
5731 const SCEV *Accum = nullptr;
5732 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5733 Accum = getSCEV(BO->RHS);
5734 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5735 Accum = getSCEV(BO->LHS);
5736
5737 if (!Accum)
5738 return nullptr;
5739
5740  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5741  if (BO->IsNUW)
5742 Flags = setFlags(Flags, SCEV::FlagNUW);
5743 if (BO->IsNSW)
5744 Flags = setFlags(Flags, SCEV::FlagNSW);
5745
5746 const SCEV *StartVal = getSCEV(StartValueV);
5747 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5748 insertValueToMap(PN, PHISCEV);
5749
5750 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5751 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5752                   (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5753                                       proveNoWrapViaConstantRanges(AR)));
5754 }
5755
5756 // We can add Flags to the post-inc expression only if we
5757 // know that it is *undefined behavior* for BEValueV to
5758 // overflow.
5759 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5760 assert(isLoopInvariant(Accum, L) &&
5761 "Accum is defined outside L, but is not invariant?");
5762 if (isAddRecNeverPoison(BEInst, L))
5763 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5764 }
5765
5766 return PHISCEV;
5767}
5768
5769const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5770 const Loop *L = LI.getLoopFor(PN->getParent());
5771 if (!L || L->getHeader() != PN->getParent())
5772 return nullptr;
5773
5774 // The loop may have multiple entrances or multiple exits; we can analyze
5775 // this phi as an addrec if it has a unique entry value and a unique
5776 // backedge value.
5777 Value *BEValueV = nullptr, *StartValueV = nullptr;
5778 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5779 Value *V = PN->getIncomingValue(i);
5780 if (L->contains(PN->getIncomingBlock(i))) {
5781 if (!BEValueV) {
5782 BEValueV = V;
5783 } else if (BEValueV != V) {
5784 BEValueV = nullptr;
5785 break;
5786 }
5787 } else if (!StartValueV) {
5788 StartValueV = V;
5789 } else if (StartValueV != V) {
5790 StartValueV = nullptr;
5791 break;
5792 }
5793 }
5794 if (!BEValueV || !StartValueV)
5795 return nullptr;
5796
5797 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5798 "PHI node already processed?");
5799
5800  // First, try to find an AddRec expression without creating a fictitious
5801  // symbolic value for PN.
5802 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5803 return S;
5804
5805 // Handle PHI node value symbolically.
5806 const SCEV *SymbolicName = getUnknown(PN);
5807 insertValueToMap(PN, SymbolicName);
5808
5809 // Using this symbolic name for the PHI, analyze the value coming around
5810 // the back-edge.
5811 const SCEV *BEValue = getSCEV(BEValueV);
5812
5813 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5814 // has a special value for the first iteration of the loop.
5815
5816 // If the value coming around the backedge is an add with the symbolic
5817 // value we just inserted, then we found a simple induction variable!
5818 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5819 // If there is a single occurrence of the symbolic value, replace it
5820 // with a recurrence.
5821 unsigned FoundIndex = Add->getNumOperands();
5822 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5823 if (Add->getOperand(i) == SymbolicName)
5824 if (FoundIndex == e) {
5825 FoundIndex = i;
5826 break;
5827 }
5828
5829 if (FoundIndex != Add->getNumOperands()) {
5830 // Create an add with everything but the specified operand.
5831      SmallVector<const SCEV *, 8> Ops;
5832      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5833 if (i != FoundIndex)
5834 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5835 L, *this));
5836 const SCEV *Accum = getAddExpr(Ops);
5837
5838 // This is not a valid addrec if the step amount is varying each
5839 // loop iteration, but is not itself an addrec in this loop.
5840 if (isLoopInvariant(Accum, L) ||
5841 (isa<SCEVAddRecExpr>(Accum) &&
5842 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5843      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5844
5845 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5846 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5847 if (BO->IsNUW)
5848 Flags = setFlags(Flags, SCEV::FlagNUW);
5849 if (BO->IsNSW)
5850 Flags = setFlags(Flags, SCEV::FlagNSW);
5851 }
5852 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5853 if (GEP->getOperand(0) == PN) {
5854 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5855 // If the increment has any nowrap flags, then we know the address
5856 // space cannot be wrapped around.
5857 if (NW != GEPNoWrapFlags::none())
5858 Flags = setFlags(Flags, SCEV::FlagNW);
5859 // If the GEP is nuw or nusw with non-negative offset, we know that
5860 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5861 // offset is treated as signed, while the base is unsigned.
5862 if (NW.hasNoUnsignedWrap() ||
5863                (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5864              Flags = setFlags(Flags, SCEV::FlagNUW);
5865 }
5866
5867 // We cannot transfer nuw and nsw flags from subtraction
5868 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5869 // for instance.
5870 }
5871
5872 const SCEV *StartVal = getSCEV(StartValueV);
5873 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5874
5875 // Okay, for the entire analysis of this edge we assumed the PHI
5876 // to be symbolic. We now need to go back and purge all of the
5877 // entries for the scalars that use the symbolic expression.
5878 forgetMemoizedResults(SymbolicName);
5879 insertValueToMap(PN, PHISCEV);
5880
5881 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5882 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5883                       (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5884                                           proveNoWrapViaConstantRanges(AR)));
5885 }
5886
5887 // We can add Flags to the post-inc expression only if we
5888 // know that it is *undefined behavior* for BEValueV to
5889 // overflow.
5890 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5891 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5892 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5893
5894 return PHISCEV;
5895 }
5896 }
5897 } else {
5898 // Otherwise, this could be a loop like this:
5899 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5900 // In this case, j = {1,+,1} and BEValue is j.
5901    // Because the other incoming value of i (0) fits the evolution of BEValue,
5902    // i really is an addrec evolution.
5903    //
5904    // We can generalize this by saying that i is the shifted value of BEValue
5905 // by one iteration:
5906 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5907 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5908 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5909 if (Shifted != getCouldNotCompute() &&
5910 Start != getCouldNotCompute()) {
5911 const SCEV *StartVal = getSCEV(StartValueV);
5912 if (Start == StartVal) {
5913 // Okay, for the entire analysis of this edge we assumed the PHI
5914 // to be symbolic. We now need to go back and purge all of the
5915 // entries for the scalars that use the symbolic expression.
5916 forgetMemoizedResults(SymbolicName);
5917 insertValueToMap(PN, Shifted);
5918 return Shifted;
5919 }
5920 }
5921 }
5922
5923 // Remove the temporary PHI node SCEV that has been inserted while intending
5924  // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5925  // as it would prevent later (possibly simpler) SCEV expressions from being
5926  // added to the ValueExprMap.
5927 eraseValueFromMap(PN);
5928
5929 return nullptr;
5930}
5931
5932// Try to match a control flow sequence that branches out at BI and merges back
5933// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5934// match.
5935static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5936                          Value *&C, Value *&LHS, Value *&RHS) {
5937 C = BI->getCondition();
5938
5939 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5940 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5941
5942 if (!LeftEdge.isSingleEdge())
5943 return false;
5944
5945 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5946
5947 Use &LeftUse = Merge->getOperandUse(0);
5948 Use &RightUse = Merge->getOperandUse(1);
5949
5950 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5951 LHS = LeftUse;
5952 RHS = RightUse;
5953 return true;
5954 }
5955
5956 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5957 LHS = RightUse;
5958 RHS = LeftUse;
5959 return true;
5960 }
5961
5962 return false;
5963}
5964
5965const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5966 auto IsReachable =
5967 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5968 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5969 // Try to match
5970 //
5971 // br %cond, label %left, label %right
5972 // left:
5973 // br label %merge
5974 // right:
5975 // br label %merge
5976 // merge:
5977 // V = phi [ %x, %left ], [ %y, %right ]
5978 //
5979 // as "select %cond, %x, %y"
5980
5981 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5982 assert(IDom && "At least the entry block should dominate PN");
5983
5984 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5985 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5986
5987 if (BI && BI->isConditional() &&
5988 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5989 properlyDominates(getSCEV(LHS), PN->getParent()) &&
5990 properlyDominates(getSCEV(RHS), PN->getParent()))
5991 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5992 }
5993
5994 return nullptr;
5995}
5996
5997const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5998 if (const SCEV *S = createAddRecFromPHI(PN))
5999 return S;
6000
6001 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6002 return getSCEV(V);
6003
6004 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6005 return S;
6006
6007 // If it's not a loop phi, we can't handle it yet.
6008 return getUnknown(PN);
6009}
6010
6011bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6012 SCEVTypes RootKind) {
6013 struct FindClosure {
6014 const SCEV *OperandToFind;
6015 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6016 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6017
6018 bool Found = false;
6019
6020 bool canRecurseInto(SCEVTypes Kind) const {
6021 // We can only recurse into the SCEV expression of the same effective type
6022 // as the type of our root SCEV expression, and into zero-extensions.
6023 return RootKind == Kind || NonSequentialRootKind == Kind ||
6024 scZeroExtend == Kind;
6025 };
6026
6027 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6028 : OperandToFind(OperandToFind), RootKind(RootKind),
6029 NonSequentialRootKind(
6030              SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6031                  RootKind)) {}
6032
6033 bool follow(const SCEV *S) {
6034 Found = S == OperandToFind;
6035
6036 return !isDone() && canRecurseInto(S->getSCEVType());
6037 }
6038
6039 bool isDone() const { return Found; }
6040 };
6041
6042 FindClosure FC(OperandToFind, RootKind);
6043 visitAll(Root, FC);
6044 return FC.Found;
6045}
6046
6047std::optional<const SCEV *>
6048ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6049 ICmpInst *Cond,
6050 Value *TrueVal,
6051 Value *FalseVal) {
6052 // Try to match some simple smax or umax patterns.
6053 auto *ICI = Cond;
6054
6055 Value *LHS = ICI->getOperand(0);
6056 Value *RHS = ICI->getOperand(1);
6057
6058 switch (ICI->getPredicate()) {
6059 case ICmpInst::ICMP_SLT:
6060 case ICmpInst::ICMP_SLE:
6061 case ICmpInst::ICMP_ULT:
6062 case ICmpInst::ICMP_ULE:
6063 std::swap(LHS, RHS);
6064 [[fallthrough]];
6065 case ICmpInst::ICMP_SGT:
6066 case ICmpInst::ICMP_SGE:
6067 case ICmpInst::ICMP_UGT:
6068 case ICmpInst::ICMP_UGE:
6069 // a > b ? a+x : b+x -> max(a, b)+x
6070 // a > b ? b+x : a+x -> min(a, b)+x
6071    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6072      bool Signed = ICI->isSigned();
6073 const SCEV *LA = getSCEV(TrueVal);
6074 const SCEV *RA = getSCEV(FalseVal);
6075 const SCEV *LS = getSCEV(LHS);
6076 const SCEV *RS = getSCEV(RHS);
6077 if (LA->getType()->isPointerTy()) {
6078 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6079 // Need to make sure we can't produce weird expressions involving
6080 // negated pointers.
6081 if (LA == LS && RA == RS)
6082 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6083 if (LA == RS && RA == LS)
6084 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6085 }
6086 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6087 if (Op->getType()->isPointerTy()) {
6088        Op = getLosslessPtrToIntExpr(Op);
6089        if (isa<SCEVCouldNotCompute>(Op))
6090 return Op;
6091 }
6092 if (Signed)
6093 Op = getNoopOrSignExtend(Op, Ty);
6094 else
6095 Op = getNoopOrZeroExtend(Op, Ty);
6096 return Op;
6097 };
6098 LS = CoerceOperand(LS);
6099 RS = CoerceOperand(RS);
6100 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6101 break;
6102 const SCEV *LDiff = getMinusSCEV(LA, LS);
6103 const SCEV *RDiff = getMinusSCEV(RA, RS);
6104 if (LDiff == RDiff)
6105 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6106 LDiff);
6107 LDiff = getMinusSCEV(LA, RS);
6108 RDiff = getMinusSCEV(RA, LS);
6109 if (LDiff == RDiff)
6110 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6111 LDiff);
6112 }
6113 break;
6114 case ICmpInst::ICMP_NE:
6115 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6116 std::swap(TrueVal, FalseVal);
6117 [[fallthrough]];
6118 case ICmpInst::ICMP_EQ:
6119 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6120    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6121        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6122 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6123 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6124 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6125 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6126 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6127 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6128 return getAddExpr(getUMaxExpr(X, C), Y);
6129 }
6130 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6131 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6132 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6133 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6134 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6135 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6136 const SCEV *X = getSCEV(LHS);
6137 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6138 X = ZExt->getOperand();
6139 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6140 const SCEV *FalseValExpr = getSCEV(FalseVal);
6141 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6142 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6143 /*Sequential=*/true);
6144 }
6145 }
6146 break;
6147 default:
6148 break;
6149 }
6150
6151 return std::nullopt;
6152}
6153
6154static std::optional<const SCEV *>
6155createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6156                              const SCEV *TrueExpr, const SCEV *FalseExpr) {
6157 assert(CondExpr->getType()->isIntegerTy(1) &&
6158 TrueExpr->getType() == FalseExpr->getType() &&
6159 TrueExpr->getType()->isIntegerTy(1) &&
6160 "Unexpected operands of a select.");
6161
6162 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6163 // --> C + (umin_seq cond, x - C)
6164 //
6165 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6166 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6167 // --> C + (umin_seq ~cond, x - C)
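  //
  // For instance, "i1 %cond ? i1 %x : i1 false" has C = 0 and therefore
  // simplifies to (umin_seq %cond, %x).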
6168
6169 // FIXME: while we can't legally model the case where both of the hands
6170 // are fully variable, we only require that the *difference* is constant.
6171 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6172 return std::nullopt;
6173
6174 const SCEV *X, *C;
6175 if (isa<SCEVConstant>(TrueExpr)) {
6176 CondExpr = SE->getNotSCEV(CondExpr);
6177 X = FalseExpr;
6178 C = TrueExpr;
6179 } else {
6180 X = TrueExpr;
6181 C = FalseExpr;
6182 }
6183 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6184 /*Sequential=*/true));
6185}
6186
6187static std::optional<const SCEV *>
6188createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6189                              Value *FalseVal) {
6190 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6191 return std::nullopt;
6192
6193 const auto *SECond = SE->getSCEV(Cond);
6194 const auto *SETrue = SE->getSCEV(TrueVal);
6195 const auto *SEFalse = SE->getSCEV(FalseVal);
6196 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6197}
6198
6199const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6200 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6201 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6202 assert(TrueVal->getType() == FalseVal->getType() &&
6203 V->getType() == TrueVal->getType() &&
6204 "Types of select hands and of the result must match.");
6205
6206 // For now, only deal with i1-typed `select`s.
6207 if (!V->getType()->isIntegerTy(1))
6208 return getUnknown(V);
6209
6210 if (std::optional<const SCEV *> S =
6211 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6212 return *S;
6213
6214 return getUnknown(V);
6215}
6216
6217const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6218 Value *TrueVal,
6219 Value *FalseVal) {
6220 // Handle "constant" branch or select. This can occur for instance when a
6221 // loop pass transforms an inner loop and moves on to process the outer loop.
6222 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6223 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6224
6225 if (auto *I = dyn_cast<Instruction>(V)) {
6226 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6227 if (std::optional<const SCEV *> S =
6228 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6229 TrueVal, FalseVal))
6230 return *S;
6231 }
6232 }
6233
6234 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6235}
6236
6237/// Expand GEP instructions into add and multiply operations. This allows them
6238/// to be analyzed by regular SCEV code.
6239const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6240 assert(GEP->getSourceElementType()->isSized() &&
6241 "GEP source element type must be sized");
6242
6243  SmallVector<const SCEV *, 4> IndexExprs;
6244  for (Value *Index : GEP->indices())
6245 IndexExprs.push_back(getSCEV(Index));
6246 return getGEPExpr(GEP, IndexExprs);
6247}
6248
6249APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6250  uint32_t BitWidth = getTypeSizeInBits(S->getType());
6251  auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6252 return TrailingZeros >= BitWidth
6253               ? APInt::getZero(BitWidth)
6254               : APInt::getOneBitSet(BitWidth, TrailingZeros);
6255 };
6256 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6257 // The result is GCD of all operands results.
6258 APInt Res = getConstantMultiple(N->getOperand(0));
6259 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6260      Res = APIntOps::GreatestCommonDivisor(
6261          Res, getConstantMultiple(N->getOperand(I)));
6262 return Res;
6263 };
6264
6265 switch (S->getSCEVType()) {
6266 case scConstant:
6267 return cast<SCEVConstant>(S)->getAPInt();
6268 case scPtrToInt:
6269 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6270 case scUDivExpr:
6271 case scVScale:
6272 return APInt(BitWidth, 1);
6273 case scTruncate: {
6274 // Only multiples that are a power of 2 will hold after truncation.
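    // (Being a multiple of 12 in i32, say, only guarantees being a multiple of
    // 4 after truncation to i8, since gcd(12, 256) = 4.)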
6275 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6276 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6277 return GetShiftedByZeros(TZ);
6278 }
6279 case scZeroExtend: {
6280 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6281 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6282 }
6283 case scSignExtend: {
6284 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6285    return getConstantMultiple(E->getOperand()).sext(BitWidth);
6286  }
6287 case scMulExpr: {
6288 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6289 if (M->hasNoUnsignedWrap()) {
6290 // The result is the product of all operand results.
6291 APInt Res = getConstantMultiple(M->getOperand(0));
6292 for (const SCEV *Operand : M->operands().drop_front())
6293 Res = Res * getConstantMultiple(Operand);
6294 return Res;
6295 }
6296
6297    // If there are no wrap guarantees, find the trailing zeros, which is the
6298 // sum of trailing zeros for all its operands.
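    // (e.g. a multiple of 4 times a multiple of 8 is always a multiple of 32,
    // i.e. 2 + 3 trailing zero bits.)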
6299 uint32_t TZ = 0;
6300 for (const SCEV *Operand : M->operands())
6301 TZ += getMinTrailingZeros(Operand);
6302 return GetShiftedByZeros(TZ);
6303 }
6304 case scAddExpr:
6305 case scAddRecExpr: {
6306 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6307 if (N->hasNoUnsignedWrap())
6308 return GetGCDMultiple(N);
6309    // Otherwise, find the minimum number of trailing zero bits among the operands.
6310 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6311 for (const SCEV *Operand : N->operands().drop_front())
6312 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6313 return GetShiftedByZeros(TZ);
6314 }
6315 case scUMaxExpr:
6316 case scSMaxExpr:
6317 case scUMinExpr:
6318 case scSMinExpr:
6319  case scSequentialUMinExpr:
6320    return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6321 case scUnknown: {
6322    // Ask ValueTracking for the known bits of the value.
6323 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6324 unsigned Known =
6325 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6326 .countMinTrailingZeros();
6327 return GetShiftedByZeros(Known);
6328 }
6329 case scCouldNotCompute:
6330 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6331 }
6332 llvm_unreachable("Unknown SCEV kind!");
6333}
6334
6335 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6336   auto I = ConstantMultipleCache.find(S);
6337 if (I != ConstantMultipleCache.end())
6338 return I->second;
6339
6340 APInt Result = getConstantMultipleImpl(S);
6341 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6342 assert(InsertPair.second && "Should insert a new key");
6343 return InsertPair.first->second;
6344}
6345
6346 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6347   APInt Multiple = getConstantMultiple(S);
6348 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6349}
6350
6351 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6352   return std::min(getConstantMultiple(S).countTrailingZeros(),
6353 (unsigned)getTypeSizeInBits(S->getType()));
6354}
6355
6356/// Helper method to assign a range to V from metadata present in the IR.
6357static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6358 if (Instruction *I = dyn_cast<Instruction>(V)) {
6359 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6360 return getConstantRangeFromMetadata(*MD);
6361 if (const auto *CB = dyn_cast<CallBase>(V))
6362 if (std::optional<ConstantRange> Range = CB->getRange())
6363 return Range;
6364 }
6365 if (auto *A = dyn_cast<Argument>(V))
6366 if (std::optional<ConstantRange> Range = A->getRange())
6367 return Range;
6368
6369 return std::nullopt;
6370}
6371
6372 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6373                                      SCEV::NoWrapFlags Flags) {
6374 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6375 AddRec->setNoWrapFlags(Flags);
6376 UnsignedRanges.erase(AddRec);
6377 SignedRanges.erase(AddRec);
6378 ConstantMultipleCache.erase(AddRec);
6379 }
6380}
6381
6382ConstantRange ScalarEvolution::
6383getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6384 const DataLayout &DL = getDataLayout();
6385
6386 unsigned BitWidth = getTypeSizeInBits(U->getType());
6387 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6388
6389 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6390 // use information about the trip count to improve our available range. Note
6391 // that the trip count independent cases are already handled by known bits.
6392 // WARNING: The definition of recurrence used here is subtly different than
6393 // the one used by AddRec (and thus most of this file). Step is allowed to
6394 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6395 // and other addrecs in the same loop (for non-affine addrecs). The code
6396 // below intentionally handles the case where step is not loop invariant.
6397 auto *P = dyn_cast<PHINode>(U->getValue());
6398 if (!P)
6399 return FullSet;
6400
6401 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6402 // even the values that are not available in these blocks may come from them,
6403   // and this leads to a false-positive recurrence test.
6404 for (auto *Pred : predecessors(P->getParent()))
6405 if (!DT.isReachableFromEntry(Pred))
6406 return FullSet;
6407
6408 BinaryOperator *BO;
6409 Value *Start, *Step;
6410 if (!matchSimpleRecurrence(P, BO, Start, Step))
6411 return FullSet;
6412
6413 // If we found a recurrence in reachable code, we must be in a loop. Note
6414 // that BO might be in some subloop of L, and that's completely okay.
6415 auto *L = LI.getLoopFor(P->getParent());
6416 assert(L && L->getHeader() == P->getParent());
6417 if (!L->contains(BO->getParent()))
6418 // NOTE: This bailout should be an assert instead. However, asserting
6419 // the condition here exposes a case where LoopFusion is querying SCEV
6420 // with malformed loop information during the midst of the transform.
6421     // There doesn't appear to be an obvious fix, so for the moment bail out
6422 // until the caller issue can be fixed. PR49566 tracks the bug.
6423 return FullSet;
6424
6425   // TODO: Extend to other opcodes such as mul and div.
6426 switch (BO->getOpcode()) {
6427 default:
6428 return FullSet;
6429 case Instruction::AShr:
6430 case Instruction::LShr:
6431 case Instruction::Shl:
6432 break;
6433 };
6434
6435 if (BO->getOperand(0) != P)
6436 // TODO: Handle the power function forms some day.
6437 return FullSet;
6438
6439 unsigned TC = getSmallConstantMaxTripCount(L);
6440 if (!TC || TC >= BitWidth)
6441 return FullSet;
6442
6443 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6444 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6445 assert(KnownStart.getBitWidth() == BitWidth &&
6446 KnownStep.getBitWidth() == BitWidth);
6447
6448 // Compute total shift amount, being careful of overflow and bitwidths.
6449 auto MaxShiftAmt = KnownStep.getMaxValue();
6450 APInt TCAP(BitWidth, TC-1);
6451 bool Overflow = false;
6452 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6453 if (Overflow)
6454 return FullSet;
6455
6456 switch (BO->getOpcode()) {
6457 default:
6458 llvm_unreachable("filtered out above");
6459 case Instruction::AShr: {
6460 // For each ashr, three cases:
6461 // shift = 0 => unchanged value
6462 // saturation => 0 or -1
6463 // other => a value closer to zero (of the same sign)
6464 // Thus, the end value is closer to zero than the start.
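    // For example, for an i8 recurrence starting at -16 with shifts of one,
    // the values are -16, -8, -4, -2, -1, -1, ... and stay within [-16, -1].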
6465 auto KnownEnd = KnownBits::ashr(KnownStart,
6466 KnownBits::makeConstant(TotalShift));
6467 if (KnownStart.isNonNegative())
6468 // Analogous to lshr (simply not yet canonicalized)
6469 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6470 KnownStart.getMaxValue() + 1);
6471 if (KnownStart.isNegative())
6472 // End >=u Start && End <=s Start
6473 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6474 KnownEnd.getMaxValue() + 1);
6475 break;
6476 }
6477 case Instruction::LShr: {
6478 // For each lshr, three cases:
6479 // shift = 0 => unchanged value
6480 // saturation => 0
6481 // other => a smaller positive number
6482 // Thus, the low end of the unsigned range is the last value produced.
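    // For example, starting at 12 with shifts of one, the values are
    // 12, 6, 3, 1, 0, 0, ... so the final value is the unsigned minimum.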
6483 auto KnownEnd = KnownBits::lshr(KnownStart,
6484 KnownBits::makeConstant(TotalShift));
6485 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6486 KnownStart.getMaxValue() + 1);
6487 }
6488 case Instruction::Shl: {
6489 // Iff no bits are shifted out, value increases on every shift.
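    // For example, an i8 value starting at 3 (six leading zeros) that is
    // shifted left by a total of at most 4 stays within [3, 48].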
6490 auto KnownEnd = KnownBits::shl(KnownStart,
6491 KnownBits::makeConstant(TotalShift));
6492 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6493 return ConstantRange(KnownStart.getMinValue(),
6494 KnownEnd.getMaxValue() + 1);
6495 break;
6496 }
6497 };
6498 return FullSet;
6499}
6500
6501const ConstantRange &
6502ScalarEvolution::getRangeRefIter(const SCEV *S,
6503 ScalarEvolution::RangeSignHint SignHint) {
6504   DenseMap<const SCEV *, ConstantRange> &Cache =
6505       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6506                                                         : SignedRanges;
6507   SmallVector<const SCEV *> WorkList;
6508   SmallPtrSet<const SCEV *, 8> Seen;
6509 
6510 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6511 // SCEVUnknown PHI node.
6512 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6513 if (!Seen.insert(Expr).second)
6514 return;
6515 if (Cache.contains(Expr))
6516 return;
6517 switch (Expr->getSCEVType()) {
6518 case scUnknown:
6519 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6520 break;
6521 [[fallthrough]];
6522 case scConstant:
6523 case scVScale:
6524 case scTruncate:
6525 case scZeroExtend:
6526 case scSignExtend:
6527 case scPtrToInt:
6528 case scAddExpr:
6529 case scMulExpr:
6530 case scUDivExpr:
6531 case scAddRecExpr:
6532 case scUMaxExpr:
6533 case scSMaxExpr:
6534 case scUMinExpr:
6535 case scSMinExpr:
6536     case scSequentialUMinExpr:
6537       WorkList.push_back(Expr);
6538 break;
6539 case scCouldNotCompute:
6540 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6541 }
6542 };
6543 AddToWorklist(S);
6544
6545 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6546 for (unsigned I = 0; I != WorkList.size(); ++I) {
6547 const SCEV *P = WorkList[I];
6548 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6549 // If it is not a `SCEVUnknown`, just recurse into operands.
6550 if (!UnknownS) {
6551 for (const SCEV *Op : P->operands())
6552 AddToWorklist(Op);
6553 continue;
6554 }
6555 // `SCEVUnknown`'s require special treatment.
6556 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6557 if (!PendingPhiRangesIter.insert(P).second)
6558 continue;
6559 for (auto &Op : reverse(P->operands()))
6560 AddToWorklist(getSCEV(Op));
6561 }
6562 }
6563
6564 if (!WorkList.empty()) {
6565 // Use getRangeRef to compute ranges for items in the worklist in reverse
6566 // order. This will force ranges for earlier operands to be computed before
6567 // their users in most cases.
6568 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6569 getRangeRef(P, SignHint);
6570
6571 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6572 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6573 PendingPhiRangesIter.erase(P);
6574 }
6575 }
6576
6577 return getRangeRef(S, SignHint, 0);
6578}
6579
6580/// Determine the range for a particular SCEV. If SignHint is
6581/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6582/// with a "cleaner" unsigned (resp. signed) representation.
6583const ConstantRange &ScalarEvolution::getRangeRef(
6584 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6585   DenseMap<const SCEV *, ConstantRange> &Cache =
6586       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6587                                                         : SignedRanges;
6588   ConstantRange::PreferredRangeType RangeType =
6589       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6590                                                         : ConstantRange::Signed;
6591 
6592 // See if we've computed this range already.
6593   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6594   if (I != Cache.end())
6595 return I->second;
6596
6597 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6598 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6599
6600 // Switch to iteratively computing the range for S, if it is part of a deeply
6601 // nested expression.
6602   if (Depth > RangeIterThreshold)
6603     return getRangeRefIter(S, SignHint);
6604
6605 unsigned BitWidth = getTypeSizeInBits(S->getType());
6606 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6607 using OBO = OverflowingBinaryOperator;
6608
6609 // If the value has known zeros, the maximum value will have those known zeros
6610 // as well.
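  // For example, if S is known to be a multiple of 8, its largest possible
  // unsigned value is 0xFF...F8, so the range is narrowed to [0, 0xFF...F8 + 1).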
6611 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6612 APInt Multiple = getNonZeroConstantMultiple(S);
6613 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6614 if (!Remainder.isZero())
6615 ConservativeResult =
6616           ConstantRange(APInt::getZero(BitWidth),
6617                         APInt::getMaxValue(BitWidth) - Remainder + 1);
6618 }
6619 else {
6620     uint32_t TZ = getMinTrailingZeros(S);
6621     if (TZ != 0) {
6622       ConservativeResult = ConstantRange(
6623           APInt::getSignedMinValue(BitWidth),
6624           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6625 }
6626 }
6627
6628 switch (S->getSCEVType()) {
6629 case scConstant:
6630 llvm_unreachable("Already handled above.");
6631 case scVScale:
6632 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6633 case scTruncate: {
6634 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6635 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6636 return setRange(
6637 Trunc, SignHint,
6638 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6639 }
6640 case scZeroExtend: {
6641 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6642 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6643 return setRange(
6644 ZExt, SignHint,
6645 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6646 }
6647 case scSignExtend: {
6648 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6649 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6650 return setRange(
6651 SExt, SignHint,
6652 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6653 }
6654 case scPtrToInt: {
6655 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6656 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6657 return setRange(PtrToInt, SignHint, X);
6658 }
6659 case scAddExpr: {
6660 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6661 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6662 unsigned WrapType = OBO::AnyWrap;
6663 if (Add->hasNoSignedWrap())
6664 WrapType |= OBO::NoSignedWrap;
6665 if (Add->hasNoUnsignedWrap())
6666 WrapType |= OBO::NoUnsignedWrap;
6667 for (const SCEV *Op : drop_begin(Add->operands()))
6668 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6669 RangeType);
6670 return setRange(Add, SignHint,
6671 ConservativeResult.intersectWith(X, RangeType));
6672 }
6673 case scMulExpr: {
6674 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6675 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6676 for (const SCEV *Op : drop_begin(Mul->operands()))
6677 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6678 return setRange(Mul, SignHint,
6679 ConservativeResult.intersectWith(X, RangeType));
6680 }
6681 case scUDivExpr: {
6682 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6683 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6684 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6685 return setRange(UDiv, SignHint,
6686 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6687 }
6688 case scAddRecExpr: {
6689 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6690 // If there's no unsigned wrap, the value will never be less than its
6691 // initial value.
6692 if (AddRec->hasNoUnsignedWrap()) {
6693 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6694 if (!UnsignedMinValue.isZero())
6695 ConservativeResult = ConservativeResult.intersectWith(
6696 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6697 }
6698
6699 // If there's no signed wrap, and all the operands except initial value have
6700 // the same sign or zero, the value won't ever be:
6701     //  1: smaller than the initial value if the operands are non-negative,
6702     //  2: bigger than the initial value if the operands are non-positive.
6703     // In either case, the value cannot cross the signed min/max boundary.
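    // For example, {5,+,2}<nsw> never takes a signed value below 5 and never
    // wraps past the signed maximum.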
6704 if (AddRec->hasNoSignedWrap()) {
6705 bool AllNonNeg = true;
6706 bool AllNonPos = true;
6707 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6708 if (!isKnownNonNegative(AddRec->getOperand(i)))
6709 AllNonNeg = false;
6710 if (!isKnownNonPositive(AddRec->getOperand(i)))
6711 AllNonPos = false;
6712 }
6713 if (AllNonNeg)
6714 ConservativeResult = ConservativeResult.intersectWith(
6715             ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6716                                        APInt::getSignedMinValue(BitWidth)),
6717             RangeType);
6718       else if (AllNonPos)
6719         ConservativeResult = ConservativeResult.intersectWith(
6720             ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6721                                        getSignedRangeMax(AddRec->getStart()) +
6722 1),
6723 RangeType);
6724 }
6725
6726 // TODO: non-affine addrec
6727 if (AddRec->isAffine()) {
6728 const SCEV *MaxBEScev =
6729           getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6730       if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6731 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6732
6733 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6734 // MaxBECount's active bits are all <= AddRec's bit width.
6735 if (MaxBECount.getBitWidth() > BitWidth &&
6736 MaxBECount.getActiveBits() <= BitWidth)
6737 MaxBECount = MaxBECount.trunc(BitWidth);
6738 else if (MaxBECount.getBitWidth() < BitWidth)
6739 MaxBECount = MaxBECount.zext(BitWidth);
6740
6741 if (MaxBECount.getBitWidth() == BitWidth) {
6742 auto RangeFromAffine = getRangeForAffineAR(
6743 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6744 ConservativeResult =
6745 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6746
6747 auto RangeFromFactoring = getRangeViaFactoring(
6748 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6749 ConservativeResult =
6750 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6751 }
6752 }
6753
6754 // Now try symbolic BE count and more powerful methods.
6755       if (UseExpensiveRangeSharpening) {
6756         const SCEV *SymbolicMaxBECount =
6757             getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6758         if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6759 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6760 AddRec->hasNoSelfWrap()) {
6761 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6762 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6763 ConservativeResult =
6764 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6765 }
6766 }
6767 }
6768
6769 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6770 }
6771 case scUMaxExpr:
6772 case scSMaxExpr:
6773 case scUMinExpr:
6774 case scSMinExpr:
6775 case scSequentialUMinExpr: {
6776     Intrinsic::ID ID;
6777     switch (S->getSCEVType()) {
6778 case scUMaxExpr:
6779 ID = Intrinsic::umax;
6780 break;
6781 case scSMaxExpr:
6782 ID = Intrinsic::smax;
6783 break;
6784 case scUMinExpr:
6785     case scSequentialUMinExpr:
6786       ID = Intrinsic::umin;
6787 break;
6788 case scSMinExpr:
6789 ID = Intrinsic::smin;
6790 break;
6791 default:
6792 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6793 }
6794
6795 const auto *NAry = cast<SCEVNAryExpr>(S);
6796 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6797 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6798 X = X.intrinsic(
6799 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6800 return setRange(S, SignHint,
6801 ConservativeResult.intersectWith(X, RangeType));
6802 }
6803 case scUnknown: {
6804 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6805 Value *V = U->getValue();
6806
6807 // Check if the IR explicitly contains !range metadata.
6808 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6809 if (MDRange)
6810 ConservativeResult =
6811 ConservativeResult.intersectWith(*MDRange, RangeType);
6812
6813 // Use facts about recurrences in the underlying IR. Note that add
6814 // recurrences are AddRecExprs and thus don't hit this path. This
6815 // primarily handles shift recurrences.
6816 auto CR = getRangeForUnknownRecurrence(U);
6817 ConservativeResult = ConservativeResult.intersectWith(CR);
6818
6819 // See if ValueTracking can give us a useful range.
6820 const DataLayout &DL = getDataLayout();
6821 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6822 if (Known.getBitWidth() != BitWidth)
6823 Known = Known.zextOrTrunc(BitWidth);
6824
6825 // ValueTracking may be able to compute a tighter result for the number of
6826 // sign bits than for the value of those sign bits.
6827 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6828 if (U->getType()->isPointerTy()) {
6829 // If the pointer size is larger than the index size type, this can cause
6830 // NS to be larger than BitWidth. So compensate for this.
6831 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6832 int ptrIdxDiff = ptrSize - BitWidth;
6833 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6834 NS -= ptrIdxDiff;
6835 }
6836
6837 if (NS > 1) {
6838 // If we know any of the sign bits, we know all of the sign bits.
6839 if (!Known.Zero.getHiBits(NS).isZero())
6840 Known.Zero.setHighBits(NS);
6841 if (!Known.One.getHiBits(NS).isZero())
6842 Known.One.setHighBits(NS);
6843 }
6844
6845 if (Known.getMinValue() != Known.getMaxValue() + 1)
6846 ConservativeResult = ConservativeResult.intersectWith(
6847 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6848 RangeType);
6849 if (NS > 1)
6850 ConservativeResult = ConservativeResult.intersectWith(
6851           ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6852                         APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6853 RangeType);
6854
6855 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6856 // Strengthen the range if the underlying IR value is a
6857 // global/alloca/heap allocation using the size of the object.
6858 ObjectSizeOpts Opts;
6859 Opts.RoundToAlign = false;
6860 Opts.NullIsUnknownSize = true;
6861 uint64_t ObjSize;
6862 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6863 isAllocationFn(V, &TLI)) &&
6864 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6865 // The highest address the object can start is ObjSize bytes before the
6866 // end (unsigned max value). If this value is not a multiple of the
6867 // alignment, the last possible start value is the next lowest multiple
6868 // of the alignment. Note: The computations below cannot overflow,
6869 // because if they would there's no possible start address for the
6870 // object.
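        // For example, a 16-byte global in a 32-bit address space cannot start
        // later than 16 bytes before the unsigned maximum address, which
        // narrows the pointer's unsigned range accordingly.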
6871 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6872 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6873 uint64_t Rem = MaxVal.urem(Align);
6874 MaxVal -= APInt(BitWidth, Rem);
6875 APInt MinVal = APInt::getZero(BitWidth);
6876 if (llvm::isKnownNonZero(V, DL))
6877 MinVal = Align;
6878 ConservativeResult = ConservativeResult.intersectWith(
6879 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6880 }
6881 }
6882
6883     // The range of a Phi is a subset of the union of the ranges of its inputs.
6884 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6885 // Make sure that we do not run over cycled Phis.
6886 if (PendingPhiRanges.insert(Phi).second) {
6887 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6888
6889 for (const auto &Op : Phi->operands()) {
6890 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6891 RangeFromOps = RangeFromOps.unionWith(OpRange);
6892 // No point to continue if we already have a full set.
6893 if (RangeFromOps.isFullSet())
6894 break;
6895 }
6896 ConservativeResult =
6897 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6898 bool Erased = PendingPhiRanges.erase(Phi);
6899 assert(Erased && "Failed to erase Phi properly?");
6900 (void)Erased;
6901 }
6902 }
6903
6904 // vscale can't be equal to zero
6905 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6906 if (II->getIntrinsicID() == Intrinsic::vscale) {
6907         ConstantRange Disallowed = APInt::getZero(BitWidth);
6908         ConservativeResult = ConservativeResult.difference(Disallowed);
6909 }
6910
6911 return setRange(U, SignHint, std::move(ConservativeResult));
6912 }
6913 case scCouldNotCompute:
6914 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6915 }
6916
6917 return setRange(S, SignHint, std::move(ConservativeResult));
6918}
6919
6920// Given a StartRange, Step and MaxBECount for an expression compute a range of
6921// values that the expression can take. Initially, the expression has a value
6922// from StartRange and then is changed by Step up to MaxBECount times. Signed
6923// argument defines if we treat Step as signed or unsigned.
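// For example, with Step = 3, MaxBECount = 4 and StartRange = [10, 20), the
// value can change by at most 12, so the resulting range is [10, 32).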
6924 static ConstantRange getRangeForAffineARHelper(APInt Step,
6925                                                const ConstantRange &StartRange,
6926 const APInt &MaxBECount,
6927 bool Signed) {
6928 unsigned BitWidth = Step.getBitWidth();
6929 assert(BitWidth == StartRange.getBitWidth() &&
6930 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6931 // If either Step or MaxBECount is 0, then the expression won't change, and we
6932 // just need to return the initial range.
6933 if (Step == 0 || MaxBECount == 0)
6934 return StartRange;
6935
6936 // If we don't know anything about the initial value (i.e. StartRange is
6937 // FullRange), then we don't know anything about the final range either.
6938 // Return FullRange.
6939 if (StartRange.isFullSet())
6940 return ConstantRange::getFull(BitWidth);
6941
6942 // If Step is signed and negative, then we use its absolute value, but we also
6943 // note that we're moving in the opposite direction.
6944 bool Descending = Signed && Step.isNegative();
6945
6946 if (Signed)
6947 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6948 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6949     // This equation holds due to the well-defined wrap-around behavior of
6950 // APInt.
6951 Step = Step.abs();
6952
6953 // Check if Offset is more than full span of BitWidth. If it is, the
6954 // expression is guaranteed to overflow.
6955 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6956 return ConstantRange::getFull(BitWidth);
6957
6958 // Offset is by how much the expression can change. Checks above guarantee no
6959 // overflow here.
6960 APInt Offset = Step * MaxBECount;
6961
6962 // Minimum value of the final range will match the minimal value of StartRange
6963 // if the expression is increasing and will be decreased by Offset otherwise.
6964 // Maximum value of the final range will match the maximal value of StartRange
6965 // if the expression is decreasing and will be increased by Offset otherwise.
6966 APInt StartLower = StartRange.getLower();
6967 APInt StartUpper = StartRange.getUpper() - 1;
6968 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6969 : (StartUpper + std::move(Offset));
6970
6971 // It's possible that the new minimum/maximum value will fall into the initial
6972 // range (due to wrap around). This means that the expression can take any
6973 // value in this bitwidth, and we have to return full range.
6974 if (StartRange.contains(MovedBoundary))
6975 return ConstantRange::getFull(BitWidth);
6976
6977 APInt NewLower =
6978 Descending ? std::move(MovedBoundary) : std::move(StartLower);
6979 APInt NewUpper =
6980 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
6981 NewUpper += 1;
6982
6983 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
6984 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
6985}
6986
6987ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
6988 const SCEV *Step,
6989 const APInt &MaxBECount) {
6990 assert(getTypeSizeInBits(Start->getType()) ==
6991 getTypeSizeInBits(Step->getType()) &&
6992 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
6993 "mismatched bit widths");
6994
6995 // First, consider step signed.
6996 ConstantRange StartSRange = getSignedRange(Start);
6997 ConstantRange StepSRange = getSignedRange(Step);
6998
6999 // If Step can be both positive and negative, we need to find ranges for the
7000 // maximum absolute step values in both directions and union them.
7001   ConstantRange SR = getRangeForAffineARHelper(
7002       StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7003   SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7004 StartSRange, MaxBECount,
7005 /* Signed = */ true));
7006
7007 // Next, consider step unsigned.
7008   ConstantRange UR = getRangeForAffineARHelper(
7009       getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7010 /* Signed = */ false);
7011
7012 // Finally, intersect signed and unsigned ranges.
7013   return SR.intersectWith(UR, ConstantRange::Smallest);
7014 }
7015
7016ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7017 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7018 ScalarEvolution::RangeSignHint SignHint) {
7019   assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7020 assert(AddRec->hasNoSelfWrap() &&
7021 "This only works for non-self-wrapping AddRecs!");
7022 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7023 const SCEV *Step = AddRec->getStepRecurrence(*this);
7024 // Only deal with constant step to save compile time.
7025 if (!isa<SCEVConstant>(Step))
7026 return ConstantRange::getFull(BitWidth);
7027 // Let's make sure that we can prove that we do not self-wrap during
7028 // MaxBECount iterations. We need this because MaxBECount is a maximum
7029 // iteration count estimate, and we might infer nw from some exit for which we
7030 // do not know max exit count (or any other side reasoning).
7031 // TODO: Turn into assert at some point.
7032 if (getTypeSizeInBits(MaxBECount->getType()) >
7033 getTypeSizeInBits(AddRec->getType()))
7034 return ConstantRange::getFull(BitWidth);
7035 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7036 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7037 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7038 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7039 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7040 MaxItersWithoutWrap))
7041 return ConstantRange::getFull(BitWidth);
7042
7043   ICmpInst::Predicate LEPred =
7044       IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7045   ICmpInst::Predicate GEPred =
7046       IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7047 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7048
7049 // We know that there is no self-wrap. Let's take Start and End values and
7050 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7051 // the iteration. They either lie inside the range [Min(Start, End),
7052 // Max(Start, End)] or outside it:
7053 //
7054 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7055 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7056 //
7057 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7058 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7059 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7060 // Start <= End and step is positive, or Start >= End and step is negative.
7061 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7062 ConstantRange StartRange = getRangeRef(Start, SignHint);
7063 ConstantRange EndRange = getRangeRef(End, SignHint);
7064 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7065 // If they already cover full iteration space, we will know nothing useful
7066 // even if we prove what we want to prove.
7067 if (RangeBetween.isFullSet())
7068 return RangeBetween;
7069 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7070 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7071 : RangeBetween.isWrappedSet();
7072 if (IsWrappedSet)
7073 return ConstantRange::getFull(BitWidth);
7074
7075 if (isKnownPositive(Step) &&
7076 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7077 return RangeBetween;
7078 if (isKnownNegative(Step) &&
7079 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7080 return RangeBetween;
7081 return ConstantRange::getFull(BitWidth);
7082}
7083
7084ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7085 const SCEV *Step,
7086 const APInt &MaxBECount) {
7087 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7088 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
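  // For example, the range of {C ? 0 : 10,+,C ? 1 : 2} is the union of the
  // ranges of {0,+,1} and {10,+,2}.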
7089
7090 unsigned BitWidth = MaxBECount.getBitWidth();
7091 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7092 getTypeSizeInBits(Step->getType()) == BitWidth &&
7093 "mismatched bit widths");
7094
7095 struct SelectPattern {
7096 Value *Condition = nullptr;
7097 APInt TrueValue;
7098 APInt FalseValue;
7099
7100 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7101 const SCEV *S) {
7102 std::optional<unsigned> CastOp;
7103 APInt Offset(BitWidth, 0);
7104
7106 "Should be!");
7107
7108 // Peel off a constant offset:
7109 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7110 // In the future we could consider being smarter here and handle
7111 // {Start+Step,+,Step} too.
7112 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7113 return;
7114
7115 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7116 S = SA->getOperand(1);
7117 }
7118
7119 // Peel off a cast operation
7120 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7121 CastOp = SCast->getSCEVType();
7122 S = SCast->getOperand();
7123 }
7124
7125 using namespace llvm::PatternMatch;
7126
7127 auto *SU = dyn_cast<SCEVUnknown>(S);
7128 const APInt *TrueVal, *FalseVal;
7129 if (!SU ||
7130 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7131 m_APInt(FalseVal)))) {
7132 Condition = nullptr;
7133 return;
7134 }
7135
7136 TrueValue = *TrueVal;
7137 FalseValue = *FalseVal;
7138
7139 // Re-apply the cast we peeled off earlier
7140 if (CastOp)
7141 switch (*CastOp) {
7142 default:
7143 llvm_unreachable("Unknown SCEV cast type!");
7144
7145 case scTruncate:
7146 TrueValue = TrueValue.trunc(BitWidth);
7147 FalseValue = FalseValue.trunc(BitWidth);
7148 break;
7149 case scZeroExtend:
7150 TrueValue = TrueValue.zext(BitWidth);
7151 FalseValue = FalseValue.zext(BitWidth);
7152 break;
7153 case scSignExtend:
7154 TrueValue = TrueValue.sext(BitWidth);
7155 FalseValue = FalseValue.sext(BitWidth);
7156 break;
7157 }
7158
7159 // Re-apply the constant offset we peeled off earlier
7160 TrueValue += Offset;
7161 FalseValue += Offset;
7162 }
7163
7164 bool isRecognized() { return Condition != nullptr; }
7165 };
7166
7167 SelectPattern StartPattern(*this, BitWidth, Start);
7168 if (!StartPattern.isRecognized())
7169 return ConstantRange::getFull(BitWidth);
7170
7171 SelectPattern StepPattern(*this, BitWidth, Step);
7172 if (!StepPattern.isRecognized())
7173 return ConstantRange::getFull(BitWidth);
7174
7175 if (StartPattern.Condition != StepPattern.Condition) {
7176 // We don't handle this case today; but we could, by considering four
7177 // possibilities below instead of two. I'm not sure if there are cases where
7178 // that will help over what getRange already does, though.
7179 return ConstantRange::getFull(BitWidth);
7180 }
7181
7182 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7183 // construct arbitrary general SCEV expressions here. This function is called
7184 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7185 // say) can end up caching a suboptimal value.
7186
7187 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7188 // C2352 and C2512 (otherwise it isn't needed).
7189
7190 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7191 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7192 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7193 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7194
7195 ConstantRange TrueRange =
7196 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7197 ConstantRange FalseRange =
7198 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7199
7200 return TrueRange.unionWith(FalseRange);
7201}
7202
7203SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7204 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7205 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7206
7207 // Return early if there are no flags to propagate to the SCEV.
7208   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7209   if (BinOp->hasNoUnsignedWrap())
7210     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7211   if (BinOp->hasNoSignedWrap())
7212     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7213   if (Flags == SCEV::FlagAnyWrap)
7214 return SCEV::FlagAnyWrap;
7215
7216 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7217}
7218
7219const Instruction *
7220ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7221 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7222 return &*AddRec->getLoop()->getHeader()->begin();
7223 if (auto *U = dyn_cast<SCEVUnknown>(S))
7224 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7225 return I;
7226 return nullptr;
7227}
7228
7229const Instruction *
7230ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7231 bool &Precise) {
7232 Precise = true;
7233 // Do a bounded search of the def relation of the requested SCEVs.
7234   SmallSet<const SCEV *, 16> Visited;
7235   SmallVector<const SCEV *> Worklist;
7236   auto pushOp = [&](const SCEV *S) {
7237 if (!Visited.insert(S).second)
7238 return;
7239 // Threshold of 30 here is arbitrary.
7240 if (Visited.size() > 30) {
7241 Precise = false;
7242 return;
7243 }
7244 Worklist.push_back(S);
7245 };
7246
7247 for (const auto *S : Ops)
7248 pushOp(S);
7249
7250 const Instruction *Bound = nullptr;
7251 while (!Worklist.empty()) {
7252 auto *S = Worklist.pop_back_val();
7253 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7254 if (!Bound || DT.dominates(Bound, DefI))
7255 Bound = DefI;
7256 } else {
7257 for (const auto *Op : S->operands())
7258 pushOp(Op);
7259 }
7260 }
7261 return Bound ? Bound : &*F.getEntryBlock().begin();
7262}
7263
7264const Instruction *
7265ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7266 bool Discard;
7267 return getDefiningScopeBound(Ops, Discard);
7268}
7269
7270bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7271 const Instruction *B) {
7272 if (A->getParent() == B->getParent() &&
7273       isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7274                                                  B->getIterator()))
7275 return true;
7276
7277 auto *BLoop = LI.getLoopFor(B->getParent());
7278 if (BLoop && BLoop->getHeader() == B->getParent() &&
7279 BLoop->getLoopPreheader() == A->getParent() &&
7280       isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7281                                                  A->getParent()->end()) &&
7282 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7283 B->getIterator()))
7284 return true;
7285 return false;
7286}
7287
7288
7289bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7290 // Only proceed if we can prove that I does not yield poison.
7291   if (!programUndefinedIfPoison(I))
7292     return false;
7293
7294 // At this point we know that if I is executed, then it does not wrap
7295 // according to at least one of NSW or NUW. If I is not executed, then we do
7296 // not know if the calculation that I represents would wrap. Multiple
7297 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7298 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7299 // derived from other instructions that map to the same SCEV. We cannot make
7300   // that guarantee for cases where I is not executed. So we need to find an
7301   // upper bound on the defining scope for the SCEV, and prove that I is
7302 // executed every time we enter that scope. When the bounding scope is a
7303 // loop (the common case), this is equivalent to proving I executes on every
7304 // iteration of that loop.
7305   SmallVector<const SCEV *> SCEVOps;
7306   for (const Use &Op : I->operands()) {
7307 // I could be an extractvalue from a call to an overflow intrinsic.
7308 // TODO: We can do better here in some cases.
7309 if (isSCEVable(Op->getType()))
7310 SCEVOps.push_back(getSCEV(Op));
7311 }
7312 auto *DefI = getDefiningScopeBound(SCEVOps);
7313 return isGuaranteedToTransferExecutionTo(DefI, I);
7314}
7315
7316bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7317 // If we know that \c I can never be poison period, then that's enough.
7318 if (isSCEVExprNeverPoison(I))
7319 return true;
7320
7321 // If the loop only has one exit, then we know that, if the loop is entered,
7322 // any instruction dominating that exit will be executed. If any such
7323 // instruction would result in UB, the addrec cannot be poison.
7324 //
7325 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7326 // also handles uses outside the loop header (they just need to dominate the
7327 // single exit).
7328
7329 auto *ExitingBB = L->getExitingBlock();
7330 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7331 return false;
7332
7333   SmallPtrSet<const Value *, 16> KnownPoison;
7334   SmallVector<const Instruction *, 8> Worklist;
7335 
7336 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7337 // things that are known to be poison under that assumption go on the
7338 // Worklist.
7339 KnownPoison.insert(I);
7340 Worklist.push_back(I);
7341
7342 while (!Worklist.empty()) {
7343 const Instruction *Poison = Worklist.pop_back_val();
7344
7345 for (const Use &U : Poison->uses()) {
7346 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7347 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7348 DT.dominates(PoisonUser->getParent(), ExitingBB))
7349 return true;
7350
7351 if (propagatesPoison(U) && L->contains(PoisonUser))
7352 if (KnownPoison.insert(PoisonUser).second)
7353 Worklist.push_back(PoisonUser);
7354 }
7355 }
7356
7357 return false;
7358}
7359
7360ScalarEvolution::LoopProperties
7361ScalarEvolution::getLoopProperties(const Loop *L) {
7362 using LoopProperties = ScalarEvolution::LoopProperties;
7363
7364 auto Itr = LoopPropertiesCache.find(L);
7365 if (Itr == LoopPropertiesCache.end()) {
7366 auto HasSideEffects = [](Instruction *I) {
7367 if (auto *SI = dyn_cast<StoreInst>(I))
7368 return !SI->isSimple();
7369
7370 return I->mayThrow() || I->mayWriteToMemory();
7371 };
7372
7373 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7374 /*HasNoSideEffects*/ true};
7375
7376 for (auto *BB : L->getBlocks())
7377 for (auto &I : *BB) {
7378         if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7379           LP.HasNoAbnormalExits = false;
7380 if (HasSideEffects(&I))
7381 LP.HasNoSideEffects = false;
7382 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7383 break; // We're already as pessimistic as we can get.
7384 }
7385
7386 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7387 assert(InsertPair.second && "We just checked!");
7388 Itr = InsertPair.first;
7389 }
7390
7391 return Itr->second;
7392}
7393
7394 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7395   // A mustprogress loop without side effects must be finite.
7396 // TODO: The check used here is very conservative. It's only *specific*
7397 // side effects which are well defined in infinite loops.
7398 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7399}
7400
7401const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7402 // Worklist item with a Value and a bool indicating whether all operands have
7403 // been visited already.
7404   using PointerTy = PointerIntPair<Value *, 1, bool>;
7405   SmallVector<PointerTy> Stack;
7406 
7407 Stack.emplace_back(V, true);
7408 Stack.emplace_back(V, false);
7409 while (!Stack.empty()) {
7410 auto E = Stack.pop_back_val();
7411 Value *CurV = E.getPointer();
7412
7413 if (getExistingSCEV(CurV))
7414 continue;
7415
7416     SmallVector<Value *> Ops;
7417     const SCEV *CreatedSCEV = nullptr;
7418 // If all operands have been visited already, create the SCEV.
7419 if (E.getInt()) {
7420 CreatedSCEV = createSCEV(CurV);
7421 } else {
7422 // Otherwise get the operands we need to create SCEV's for before creating
7423 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7424 // just use it.
7425 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7426 }
7427
7428 if (CreatedSCEV) {
7429 insertValueToMap(CurV, CreatedSCEV);
7430 } else {
7431       // Queue CurV for SCEV creation, followed by its operands, which need to
7432 // be constructed first.
7433 Stack.emplace_back(CurV, true);
7434 for (Value *Op : Ops)
7435 Stack.emplace_back(Op, false);
7436 }
7437 }
7438
7439 return getExistingSCEV(V);
7440}
7441
7442const SCEV *
7443ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7444 if (!isSCEVable(V->getType()))
7445 return getUnknown(V);
7446
7447 if (Instruction *I = dyn_cast<Instruction>(V)) {
7448 // Don't attempt to analyze instructions in blocks that aren't
7449 // reachable. Such instructions don't matter, and they aren't required
7450 // to obey basic rules for definitions dominating uses which this
7451 // analysis depends on.
7452 if (!DT.isReachableFromEntry(I->getParent()))
7453 return getUnknown(PoisonValue::get(V->getType()));
7454 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7455 return getConstant(CI);
7456 else if (isa<GlobalAlias>(V))
7457 return getUnknown(V);
7458 else if (!isa<ConstantExpr>(V))
7459 return getUnknown(V);
7460
7461 Operator *U = cast<Operator>(V);
7462 if (auto BO =
7463 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7464 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7465 switch (BO->Opcode) {
7466 case Instruction::Add:
7467 case Instruction::Mul: {
7468 // For additions and multiplications, traverse add/mul chains for which we
7469 // can potentially create a single SCEV, to reduce the number of
7470 // get{Add,Mul}Expr calls.
7471 do {
7472 if (BO->Op) {
7473 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7474 Ops.push_back(BO->Op);
7475 break;
7476 }
7477 }
7478 Ops.push_back(BO->RHS);
7479 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7480 dyn_cast<Instruction>(V));
7481 if (!NewBO ||
7482 (BO->Opcode == Instruction::Add &&
7483 (NewBO->Opcode != Instruction::Add &&
7484 NewBO->Opcode != Instruction::Sub)) ||
7485 (BO->Opcode == Instruction::Mul &&
7486 NewBO->Opcode != Instruction::Mul)) {
7487 Ops.push_back(BO->LHS);
7488 break;
7489 }
7490 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7491 // requires a SCEV for the LHS.
7492 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7493 auto *I = dyn_cast<Instruction>(BO->Op);
7494 if (I && programUndefinedIfPoison(I)) {
7495 Ops.push_back(BO->LHS);
7496 break;
7497 }
7498 }
7499 BO = NewBO;
7500 } while (true);
7501 return nullptr;
7502 }
7503 case Instruction::Sub:
7504 case Instruction::UDiv:
7505 case Instruction::URem:
7506 break;
7507 case Instruction::AShr:
7508 case Instruction::Shl:
7509 case Instruction::Xor:
7510 if (!IsConstArg)
7511 return nullptr;
7512 break;
7513 case Instruction::And:
7514 case Instruction::Or:
7515 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7516 return nullptr;
7517 break;
7518 case Instruction::LShr:
7519 return getUnknown(V);
7520 default:
7521 llvm_unreachable("Unhandled binop");
7522 break;
7523 }
7524
7525 Ops.push_back(BO->LHS);
7526 Ops.push_back(BO->RHS);
7527 return nullptr;
7528 }
7529
7530 switch (U->getOpcode()) {
7531 case Instruction::Trunc:
7532 case Instruction::ZExt:
7533 case Instruction::SExt:
7534 case Instruction::PtrToInt:
7535 Ops.push_back(U->getOperand(0));
7536 return nullptr;
7537
7538 case Instruction::BitCast:
7539 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7540 Ops.push_back(U->getOperand(0));
7541 return nullptr;
7542 }
7543 return getUnknown(V);
7544
7545 case Instruction::SDiv:
7546 case Instruction::SRem:
7547 Ops.push_back(U->getOperand(0));
7548 Ops.push_back(U->getOperand(1));
7549 return nullptr;
7550
7551 case Instruction::GetElementPtr:
7552 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7553 "GEP source element type must be sized");
7554 for (Value *Index : U->operands())
7555 Ops.push_back(Index);
7556 return nullptr;
7557
7558 case Instruction::IntToPtr:
7559 return getUnknown(V);
7560
7561 case Instruction::PHI:
7562     // Keep constructing SCEVs for phis recursively for now.
7563 return nullptr;
7564
7565 case Instruction::Select: {
7566 // Check if U is a select that can be simplified to a SCEVUnknown.
7567 auto CanSimplifyToUnknown = [this, U]() {
7568 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7569 return false;
7570
7571 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7572 if (!ICI)
7573 return false;
7574 Value *LHS = ICI->getOperand(0);
7575 Value *RHS = ICI->getOperand(1);
7576 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7577 ICI->getPredicate() == CmpInst::ICMP_NE) {
7578 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7579 return true;
7580 } else if (getTypeSizeInBits(LHS->getType()) >
7581 getTypeSizeInBits(U->getType()))
7582 return true;
7583 return false;
7584 };
7585 if (CanSimplifyToUnknown())
7586 return getUnknown(U);
7587
7588 for (Value *Inc : U->operands())
7589 Ops.push_back(Inc);
7590 return nullptr;
7591 break;
7592 }
7593 case Instruction::Call:
7594 case Instruction::Invoke:
7595 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7596 Ops.push_back(RV);
7597 return nullptr;
7598 }
7599
7600 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7601 switch (II->getIntrinsicID()) {
7602 case Intrinsic::abs:
7603 Ops.push_back(II->getArgOperand(0));
7604 return nullptr;
7605 case Intrinsic::umax:
7606 case Intrinsic::umin:
7607 case Intrinsic::smax:
7608 case Intrinsic::smin:
7609 case Intrinsic::usub_sat:
7610 case Intrinsic::uadd_sat:
7611 Ops.push_back(II->getArgOperand(0));
7612 Ops.push_back(II->getArgOperand(1));
7613 return nullptr;
7614 case Intrinsic::start_loop_iterations:
7615 case Intrinsic::annotation:
7616 case Intrinsic::ptr_annotation:
7617 Ops.push_back(II->getArgOperand(0));
7618 return nullptr;
7619 default:
7620 break;
7621 }
7622 }
7623 break;
7624 }
7625
7626 return nullptr;
7627}
7628
7629const SCEV *ScalarEvolution::createSCEV(Value *V) {
7630 if (!isSCEVable(V->getType()))
7631 return getUnknown(V);
7632
7633 if (Instruction *I = dyn_cast<Instruction>(V)) {
7634 // Don't attempt to analyze instructions in blocks that aren't
7635 // reachable. Such instructions don't matter, and they aren't required
7636 // to obey basic rules for definitions dominating uses which this
7637 // analysis depends on.
7638 if (!DT.isReachableFromEntry(I->getParent()))
7639 return getUnknown(PoisonValue::get(V->getType()));
7640 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7641 return getConstant(CI);
7642 else if (isa<GlobalAlias>(V))
7643 return getUnknown(V);
7644 else if (!isa<ConstantExpr>(V))
7645 return getUnknown(V);
7646
7647 const SCEV *LHS;
7648 const SCEV *RHS;
7649
7650 Operator *U = cast<Operator>(V);
7651 if (auto BO =
7652 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7653 switch (BO->Opcode) {
7654 case Instruction::Add: {
7655 // The simple thing to do would be to just call getSCEV on both operands
7656 // and call getAddExpr with the result. However if we're looking at a
7657 // bunch of things all added together, this can be quite inefficient,
7658 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7659 // Instead, gather up all the operands and make a single getAddExpr call.
7660 // LLVM IR canonical form means we need only traverse the left operands.
7661       SmallVector<const SCEV *, 4> AddOps;
7662       do {
7663 if (BO->Op) {
7664 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7665 AddOps.push_back(OpSCEV);
7666 break;
7667 }
7668
7669 // If a NUW or NSW flag can be applied to the SCEV for this
7670 // addition, then compute the SCEV for this addition by itself
7671 // with a separate call to getAddExpr. We need to do that
7672 // instead of pushing the operands of the addition onto AddOps,
7673 // since the flags are only known to apply to this particular
7674 // addition - they may not apply to other additions that can be
7675 // formed with operands from AddOps.
7676 const SCEV *RHS = getSCEV(BO->RHS);
7677 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7678 if (Flags != SCEV::FlagAnyWrap) {
7679 const SCEV *LHS = getSCEV(BO->LHS);
7680 if (BO->Opcode == Instruction::Sub)
7681 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7682 else
7683 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7684 break;
7685 }
7686 }
7687
7688 if (BO->Opcode == Instruction::Sub)
7689 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7690 else
7691 AddOps.push_back(getSCEV(BO->RHS));
7692
7693 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7694 dyn_cast<Instruction>(V));
7695 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7696 NewBO->Opcode != Instruction::Sub)) {
7697 AddOps.push_back(getSCEV(BO->LHS));
7698 break;
7699 }
7700 BO = NewBO;
7701 } while (true);
7702
7703 return getAddExpr(AddOps);
7704 }
7705
7706 case Instruction::Mul: {
7707       SmallVector<const SCEV *, 4> MulOps;
7708       do {
7709 if (BO->Op) {
7710 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7711 MulOps.push_back(OpSCEV);
7712 break;
7713 }
7714
7715 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7716 if (Flags != SCEV::FlagAnyWrap) {
7717 LHS = getSCEV(BO->LHS);
7718 RHS = getSCEV(BO->RHS);
7719 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7720 break;
7721 }
7722 }
7723
7724 MulOps.push_back(getSCEV(BO->RHS));
7725 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7726 dyn_cast<Instruction>(V));
7727 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7728 MulOps.push_back(getSCEV(BO->LHS));
7729 break;
7730 }
7731 BO = NewBO;
7732 } while (true);
7733
7734 return getMulExpr(MulOps);
7735 }
7736 case Instruction::UDiv:
7737 LHS = getSCEV(BO->LHS);
7738 RHS = getSCEV(BO->RHS);
7739 return getUDivExpr(LHS, RHS);
7740 case Instruction::URem:
7741 LHS = getSCEV(BO->LHS);
7742 RHS = getSCEV(BO->RHS);
7743 return getURemExpr(LHS, RHS);
7744 case Instruction::Sub: {
7745       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7746       if (BO->Op)
7747 Flags = getNoWrapFlagsFromUB(BO->Op);
7748 LHS = getSCEV(BO->LHS);
7749 RHS = getSCEV(BO->RHS);
7750 return getMinusSCEV(LHS, RHS, Flags);
7751 }
7752 case Instruction::And:
7753 // For an expression like x&255 that merely masks off the high bits,
7754 // use zext(trunc(x)) as the SCEV expression.
7755 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7756 if (CI->isZero())
7757 return getSCEV(BO->RHS);
7758 if (CI->isMinusOne())
7759 return getSCEV(BO->LHS);
7760 const APInt &A = CI->getValue();
7761
7762 // Instcombine's ShrinkDemandedConstant may strip bits out of
7763 // constants, obscuring what would otherwise be a low-bits mask.
7764 // Use computeKnownBits to compute what ShrinkDemandedConstant
7765 // knew about to reconstruct a low-bits mask value.
7766 unsigned LZ = A.countl_zero();
7767 unsigned TZ = A.countr_zero();
7768 unsigned BitWidth = A.getBitWidth();
7769 KnownBits Known(BitWidth);
7770 computeKnownBits(BO->LHS, Known, getDataLayout(),
7771 0, &AC, nullptr, &DT);
7772
7773 APInt EffectiveMask =
7774 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7775 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7776 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7777 const SCEV *LHS = getSCEV(BO->LHS);
7778 const SCEV *ShiftedLHS = nullptr;
7779 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7780 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7781 // For an expression like (x * 8) & 8, simplify the multiply.
7782 unsigned MulZeros = OpC->getAPInt().countr_zero();
7783 unsigned GCD = std::min(MulZeros, TZ);
7784 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7785             SmallVector<const SCEV *, 4> MulOps;
7786             MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7787 append_range(MulOps, LHSMul->operands().drop_front());
7788 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7789 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7790 }
7791 }
7792 if (!ShiftedLHS)
7793 ShiftedLHS = getUDivExpr(LHS, MulCount);
7794 return getMulExpr(
7795             getZeroExtendExpr(
7796                 getTruncateExpr(ShiftedLHS,
7797 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7798 BO->LHS->getType()),
7799 MulCount);
7800 }
7801 }
7802 // Binary `and` is a bit-wise `umin`.
7803 if (BO->LHS->getType()->isIntegerTy(1)) {
7804 LHS = getSCEV(BO->LHS);
7805 RHS = getSCEV(BO->RHS);
7806 return getUMinExpr(LHS, RHS);
7807 }
7808 break;
7809
7810 case Instruction::Or:
7811 // Binary `or` is a bit-wise `umax`.
7812 if (BO->LHS->getType()->isIntegerTy(1)) {
7813 LHS = getSCEV(BO->LHS);
7814 RHS = getSCEV(BO->RHS);
7815 return getUMaxExpr(LHS, RHS);
7816 }
7817 break;
7818
7819 case Instruction::Xor:
7820 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7821 // If the RHS of xor is -1, then this is a not operation.
7822 if (CI->isMinusOne())
7823 return getNotSCEV(getSCEV(BO->LHS));
7824
7825 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7826 // This is a variant of the check for xor with -1, and it handles
7827 // the case where instcombine has trimmed non-demanded bits out
7828 // of an xor with -1.
7829 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7830 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7831 if (LBO->getOpcode() == Instruction::And &&
7832 LCI->getValue() == CI->getValue())
7833 if (const SCEVZeroExtendExpr *Z =
7834 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7835 Type *UTy = BO->LHS->getType();
7836 const SCEV *Z0 = Z->getOperand();
7837 Type *Z0Ty = Z0->getType();
7838 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7839
7840 // If C is a low-bits mask, the zero extend is serving to
7841 // mask off the high bits. Complement the operand and
7842 // re-apply the zext.
7843 if (CI->getValue().isMask(Z0TySize))
7844 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7845
7846 // If C is a single bit, it may be in the sign-bit position
7847 // before the zero-extend. In this case, represent the xor
7848 // using an add, which is equivalent, and re-apply the zext.
7849 APInt Trunc = CI->getValue().trunc(Z0TySize);
7850 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7851 Trunc.isSignMask())
7852 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7853 UTy);
7854 }
7855 }
7856 break;
7857
7858 case Instruction::Shl:
7859 // Turn shift left of a constant amount into a multiply.
7860 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7861 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7862
7863 // If the shift count is not less than the bitwidth, the result of
7864 // the shift is undefined. Don't try to analyze it, because the
7865 // resolution chosen here may differ from the resolution chosen in
7866 // other parts of the compiler.
7867 if (SA->getValue().uge(BitWidth))
7868 break;
7869
7870 // We can safely preserve the nuw flag in all cases. It's also safe to
7871 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7872 // requires special handling. It can be preserved as long as we're not
7873 // left shifting by bitwidth - 1.
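        // For example, "shl nuw nsw i32 %x, 2" can become (4 * %x)<nuw><nsw>
        // when the flags can be transferred to the SCEV.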
7874 auto Flags = SCEV::FlagAnyWrap;
7875 if (BO->Op) {
7876 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7877 if ((MulFlags & SCEV::FlagNSW) &&
7878 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7879             Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7880           if (MulFlags & SCEV::FlagNUW)
7881             Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7882         }
7883
7884 ConstantInt *X = ConstantInt::get(
7885 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7886 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7887 }
7888 break;
7889
7890 case Instruction::AShr:
7891 // AShr X, C, where C is a constant.
7892 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7893 if (!CI)
7894 break;
7895
7896 Type *OuterTy = BO->LHS->getType();
7897       unsigned BitWidth = getTypeSizeInBits(OuterTy);
7898       // If the shift count is not less than the bitwidth, the result of
7899 // the shift is undefined. Don't try to analyze it, because the
7900 // resolution chosen here may differ from the resolution chosen in
7901 // other parts of the compiler.
7902 if (CI->getValue().uge(BitWidth))
7903 break;
7904
7905 if (CI->isZero())
7906 return getSCEV(BO->LHS); // shift by zero --> noop
7907
7908 uint64_t AShrAmt = CI->getZExtValue();
7909 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7910
7911 Operator *L = dyn_cast<Operator>(BO->LHS);
7912 const SCEV *AddTruncateExpr = nullptr;
7913 ConstantInt *ShlAmtCI = nullptr;
7914 const SCEV *AddConstant = nullptr;
7915
7916 if (L && L->getOpcode() == Instruction::Add) {
7917 // X = Shl A, n
7918 // Y = Add X, c
7919 // Z = AShr Y, m
7920 // n, c and m are constants.
7921
7922 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7923 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7924 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7925 if (AddOperandCI) {
7926 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7927 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7928 // Since we truncate to TruncTy, the AddConstant should be of the
7929 // same type, so create a new Constant with the same type as TruncTy.
7930 // Also, the Add constant should be shifted right by the AShr amount.
7931 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7932 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7933 // We model the expression as sext(add(trunc(A), c << n)). Since the
7934 // sext(trunc) part is already handled below, we create an
7935 // AddExpr(TruncExp) which will be used later.
7936 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7937 }
7938 }
7939 } else if (L && L->getOpcode() == Instruction::Shl) {
7940 // X = Shl A, n
7941 // Y = AShr X, m
7942 // Both n and m are constant.
7943
7944 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7945 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7946 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7947 }
7948
7949 if (AddTruncateExpr && ShlAmtCI) {
7950 // We can merge the two given cases into a single SCEV statement;
7951 // in case n = m, the mul expression will be 2^0, so it gets resolved to
7952 // a simpler case. The following code handles the two cases:
7953 //
7954 // 1) For a two-shift sext-inreg, i.e. n = m,
7955 // use sext(trunc(x)) as the SCEV expression.
7956 //
7957 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
7958 // expression. We already checked that ShlAmt < BitWidth, so
7959 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7960 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
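// Illustrative example: on i32, (ashr (shl %x, 8), 4) has n = 8 and m = 4,
// so TruncTy is i28 and the result is modeled as
// sext(mul(trunc(%x) to i28, 16)) extended back to i32.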
7961 const APInt &ShlAmt = ShlAmtCI->getValue();
7962 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7963 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7964 ShlAmtCI->getZExtValue() - AShrAmt);
7965 const SCEV *CompositeExpr =
7966 getMulExpr(AddTruncateExpr, getConstant(Mul));
7967 if (L->getOpcode() != Instruction::Shl)
7968 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
7969
7970 return getSignExtendExpr(CompositeExpr, OuterTy);
7971 }
7972 }
7973 break;
7974 }
7975 }
7976
7977 switch (U->getOpcode()) {
7978 case Instruction::Trunc:
7979 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
7980
7981 case Instruction::ZExt:
7982 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7983
7984 case Instruction::SExt:
7985 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
7986 dyn_cast<Instruction>(V))) {
7987 // The NSW flag of a subtract does not always survive the conversion to
7988 // A + (-1)*B. By pushing sign extension onto its operands we are much
7989 // more likely to preserve NSW and allow later AddRec optimisations.
7990 //
7991 // NOTE: This is effectively duplicating this logic from getSignExtend:
7992 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
7993 // but by that point the NSW information has potentially been lost.
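// Illustrative example: sext(sub nsw %a, %b) to i64 becomes
// (sext %a to i64) - (sext %b to i64), which keeps an AddRec %a visible
// (together with the NSW fact) to later analyses.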
7994 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
7995 Type *Ty = U->getType();
7996 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
7997 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
7998 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
7999 }
8000 }
8001 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8002
8003 case Instruction::BitCast:
8004 // BitCasts are no-op casts so we just eliminate the cast.
8005 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8006 return getSCEV(U->getOperand(0));
8007 break;
8008
8009 case Instruction::PtrToInt: {
8010 // Pointer to integer cast is straight-forward, so do model it.
8011 const SCEV *Op = getSCEV(U->getOperand(0));
8012 Type *DstIntTy = U->getType();
8013 // But only if effective SCEV (integer) type is wide enough to represent
8014 // all possible pointer values.
8015 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8016 if (isa<SCEVCouldNotCompute>(IntOp))
8017 return getUnknown(V);
8018 return IntOp;
8019 }
8020 case Instruction::IntToPtr:
8021 // Just don't deal with inttoptr casts.
8022 return getUnknown(V);
8023
8024 case Instruction::SDiv:
8025 // If both operands are non-negative, this is just a udiv.
8026 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8027 isKnownNonNegative(getSCEV(U->getOperand(1))))
8028 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8029 break;
8030
8031 case Instruction::SRem:
8032 // If both operands are non-negative, this is just a urem.
8033 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8034 isKnownNonNegative(getSCEV(U->getOperand(1))))
8035 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8036 break;
8037
8038 case Instruction::GetElementPtr:
8039 return createNodeForGEP(cast<GEPOperator>(U));
8040
8041 case Instruction::PHI:
8042 return createNodeForPHI(cast<PHINode>(U));
8043
8044 case Instruction::Select:
8045 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8046 U->getOperand(2));
8047
8048 case Instruction::Call:
8049 case Instruction::Invoke:
8050 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8051 return getSCEV(RV);
8052
8053 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8054 switch (II->getIntrinsicID()) {
8055 case Intrinsic::abs:
8056 return getAbsExpr(
8057 getSCEV(II->getArgOperand(0)),
8058 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8059 case Intrinsic::umax:
8060 LHS = getSCEV(II->getArgOperand(0));
8061 RHS = getSCEV(II->getArgOperand(1));
8062 return getUMaxExpr(LHS, RHS);
8063 case Intrinsic::umin:
8064 LHS = getSCEV(II->getArgOperand(0));
8065 RHS = getSCEV(II->getArgOperand(1));
8066 return getUMinExpr(LHS, RHS);
8067 case Intrinsic::smax:
8068 LHS = getSCEV(II->getArgOperand(0));
8069 RHS = getSCEV(II->getArgOperand(1));
8070 return getSMaxExpr(LHS, RHS);
8071 case Intrinsic::smin:
8072 LHS = getSCEV(II->getArgOperand(0));
8073 RHS = getSCEV(II->getArgOperand(1));
8074 return getSMinExpr(LHS, RHS);
8075 case Intrinsic::usub_sat: {
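// usub.sat(X, Y) == X - umin(X, Y), and the subtraction cannot wrap because
// umin(X, Y) <= X, hence the NUW flag below.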
8076 const SCEV *X = getSCEV(II->getArgOperand(0));
8077 const SCEV *Y = getSCEV(II->getArgOperand(1));
8078 const SCEV *ClampedY = getUMinExpr(X, Y);
8079 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8080 }
8081 case Intrinsic::uadd_sat: {
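// uadd.sat(X, Y) == umin(X, ~Y) + Y; clamping X to ~Y keeps the sum within
// range, hence the NUW flag below.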
8082 const SCEV *X = getSCEV(II->getArgOperand(0));
8083 const SCEV *Y = getSCEV(II->getArgOperand(1));
8084 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8085 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8086 }
8087 case Intrinsic::start_loop_iterations:
8088 case Intrinsic::annotation:
8089 case Intrinsic::ptr_annotation:
8090 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8091 // just equivalent to the first operand for SCEV purposes.
8092 return getSCEV(II->getArgOperand(0));
8093 case Intrinsic::vscale:
8094 return getVScale(II->getType());
8095 default:
8096 break;
8097 }
8098 }
8099 break;
8100 }
8101
8102 return getUnknown(V);
8103}
8104
8105//===----------------------------------------------------------------------===//
8106// Iteration Count Computation Code
8107//
8108
8109const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8110 if (isa<SCEVCouldNotCompute>(ExitCount))
8111 return getCouldNotCompute();
8112
8113 auto *ExitCountType = ExitCount->getType();
8114 assert(ExitCountType->isIntegerTy());
8115 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8116 1 + ExitCountType->getScalarSizeInBits());
8117 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8118}
8119
8120const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8121 Type *EvalTy,
8122 const Loop *L) {
8123 if (isa<SCEVCouldNotCompute>(ExitCount))
8124 return getCouldNotCompute();
8125
8126 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8127 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8128
8129 auto CanAddOneWithoutOverflow = [&]() {
8130 ConstantRange ExitCountRange =
8131 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8132 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8133 return true;
8134
8135 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8136 getMinusOne(ExitCount->getType()));
8137 };
8138
8139 // If we need to zero extend the backedge count, check if we can add one to
8140 // it prior to zero extending without overflow. Provided this is safe, it
8141 // allows better simplification of the +1.
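// Illustrative example: for an i32 exit count known to differ from
// UINT32_MAX, we emit zext(%count + 1) to the wider type rather than
// zext(%count) + 1; the inner increment cannot wrap and folds more easily.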
8142 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8143 return getZeroExtendExpr(
8144 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8145
8146 // Get the total trip count from the count by adding 1. This may wrap.
8147 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8148}
8149
8150static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8151 if (!ExitCount)
8152 return 0;
8153
8154 ConstantInt *ExitConst = ExitCount->getValue();
8155
8156 // Guard against huge trip counts.
8157 if (ExitConst->getValue().getActiveBits() > 32)
8158 return 0;
8159
8160 // In case of integer overflow, this returns 0, which is correct.
8161 return ((unsigned)ExitConst->getZExtValue()) + 1;
8162}
8163
8164unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8165 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8166 return getConstantTripCount(ExitCount);
8167}
8168
8169 unsigned
8170ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8171 const BasicBlock *ExitingBlock) {
8172 assert(ExitingBlock && "Must pass a non-null exiting block!");
8173 assert(L->isLoopExiting(ExitingBlock) &&
8174 "Exiting block must actually branch out of the loop!");
8175 const SCEVConstant *ExitCount =
8176 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8177 return getConstantTripCount(ExitCount);
8178}
8179
8180unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8181 const auto *MaxExitCount =
8182 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8183 return getConstantTripCount(MaxExitCount);
8184}
8185
8186unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8187 SmallVector<BasicBlock *, 8> ExitingBlocks;
8188 L->getExitingBlocks(ExitingBlocks);
8189
8190 std::optional<unsigned> Res;
8191 for (auto *ExitingBB : ExitingBlocks) {
8192 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8193 if (!Res)
8194 Res = Multiple;
8195 Res = (unsigned)std::gcd(*Res, Multiple);
8196 }
8197 return Res.value_or(1);
8198}
8199
8200unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8201 const SCEV *ExitCount) {
8202 if (ExitCount == getCouldNotCompute())
8203 return 1;
8204
8205 // Get the trip count
8206 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8207
8208 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8209 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8210 // the greatest power of 2 divisor less than 2^32.
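// E.g. a trip count that is a multiple of 2^40 is reported as a multiple of
// 2^31, the largest power of two this interface can return.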
8211 return Multiple.getActiveBits() > 32
8212 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8213 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8214}
8215
8216/// Returns the largest constant divisor of the trip count of this loop as a
8217/// normal unsigned value, if possible. This means that the actual trip count is
8218/// always a multiple of the returned value (don't forget the trip count could
8219/// very well be zero as well!).
8220///
8221 /// Returns 1 if the trip count is unknown or not guaranteed to be the
8222 /// multiple of a constant (which is also the case if the trip count is simply
8223 /// constant; use getSmallConstantTripCount for that case). It will also return
8224 /// 1 if the trip count is very large (>= 2^32).
8225///
8226/// As explained in the comments for getSmallConstantTripCount, this assumes
8227/// that control exits the loop via ExitingBlock.
8228 unsigned
8229ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8230 const BasicBlock *ExitingBlock) {
8231 assert(ExitingBlock && "Must pass a non-null exiting block!");
8232 assert(L->isLoopExiting(ExitingBlock) &&
8233 "Exiting block must actually branch out of the loop!");
8234 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8235 return getSmallConstantTripMultiple(L, ExitCount);
8236}
8237
8238const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8239 const BasicBlock *ExitingBlock,
8240 ExitCountKind Kind) {
8241 switch (Kind) {
8242 case Exact:
8243 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8244 case SymbolicMaximum:
8245 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8246 case ConstantMaximum:
8247 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8248 };
8249 llvm_unreachable("Invalid ExitCountKind!");
8250}
8251
8252 const SCEV *
8253ScalarEvolution::getPredicatedBackedgeTakenCount(
8254 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8255 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8256}
8257
8258const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8259 ExitCountKind Kind) {
8260 switch (Kind) {
8261 case Exact:
8262 return getBackedgeTakenInfo(L).getExact(L, this);
8263 case ConstantMaximum:
8264 return getBackedgeTakenInfo(L).getConstantMax(this);
8265 case SymbolicMaximum:
8266 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8267 };
8268 llvm_unreachable("Invalid ExitCountKind!");
8269}
8270
8271const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8272 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8273 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8274}
8275
8276bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8277 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8278}
8279
8280/// Push PHI nodes in the header of the given loop onto the given Worklist.
8281 static void PushLoopPHIs(const Loop *L,
8282 SmallVectorImpl<Instruction *> &Worklist,
8283 SmallPtrSetImpl<Instruction *> &Visited) {
8284 BasicBlock *Header = L->getHeader();
8285
8286 // Push all Loop-header PHIs onto the Worklist stack.
8287 for (PHINode &PN : Header->phis())
8288 if (Visited.insert(&PN).second)
8289 Worklist.push_back(&PN);
8290}
8291
8292ScalarEvolution::BackedgeTakenInfo &
8293ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8294 auto &BTI = getBackedgeTakenInfo(L);
8295 if (BTI.hasFullInfo())
8296 return BTI;
8297
8298 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8299
8300 if (!Pair.second)
8301 return Pair.first->second;
8302
8303 BackedgeTakenInfo Result =
8304 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8305
8306 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8307}
8308
8309ScalarEvolution::BackedgeTakenInfo &
8310ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8311 // Initially insert an invalid entry for this loop. If the insertion
8312 // succeeds, proceed to actually compute a backedge-taken count and
8313 // update the value. The temporary CouldNotCompute value tells SCEV
8314 // code elsewhere that it shouldn't attempt to request a new
8315 // backedge-taken count, which could result in infinite recursion.
8316 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8317 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8318 if (!Pair.second)
8319 return Pair.first->second;
8320
8321 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8322 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8323 // must be cleared in this scope.
8324 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8325
8326 // Now that we know more about the trip count for this loop, forget any
8327 // existing SCEV values for PHI nodes in this loop since they are only
8328 // conservative estimates made without the benefit of trip count
8329 // information. This invalidation is not necessary for correctness, and is
8330 // only done to produce more precise results.
8331 if (Result.hasAnyInfo()) {
8332 // Invalidate any expression using an addrec in this loop.
8333 SmallVector<const SCEV *, 8> ToForget;
8334 auto LoopUsersIt = LoopUsers.find(L);
8335 if (LoopUsersIt != LoopUsers.end())
8336 append_range(ToForget, LoopUsersIt->second);
8337 forgetMemoizedResults(ToForget);
8338
8339 // Invalidate constant-evolved loop header phis.
8340 for (PHINode &PN : L->getHeader()->phis())
8341 ConstantEvolutionLoopExitValue.erase(&PN);
8342 }
8343
8344 // Re-lookup the insert position, since the call to
8345 // computeBackedgeTakenCount above could result in a
8346 // recursive call to getBackedgeTakenInfo (on a different
8347 // loop), which would invalidate the iterator computed
8348 // earlier.
8349 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8350}
8351
8352void ScalarEvolution::forgetAllLoops() {
8353 // This method is intended to forget all info about loops. It should
8354 // invalidate caches as if the following happened:
8355 // - The trip counts of all loops have changed arbitrarily
8356 // - Every llvm::Value has been updated in place to produce a different
8357 // result.
8358 BackedgeTakenCounts.clear();
8359 PredicatedBackedgeTakenCounts.clear();
8360 BECountUsers.clear();
8361 LoopPropertiesCache.clear();
8362 ConstantEvolutionLoopExitValue.clear();
8363 ValueExprMap.clear();
8364 ValuesAtScopes.clear();
8365 ValuesAtScopesUsers.clear();
8366 LoopDispositions.clear();
8367 BlockDispositions.clear();
8368 UnsignedRanges.clear();
8369 SignedRanges.clear();
8370 ExprValueMap.clear();
8371 HasRecMap.clear();
8372 ConstantMultipleCache.clear();
8373 PredicatedSCEVRewrites.clear();
8374 FoldCache.clear();
8375 FoldCacheUser.clear();
8376}
8377void ScalarEvolution::visitAndClearUsers(
8378 SmallVectorImpl<Instruction *> &Worklist,
8379 SmallPtrSetImpl<Instruction *> &Visited,
8380 SmallVectorImpl<const SCEV *> &ToForget) {
8381 while (!Worklist.empty()) {
8382 Instruction *I = Worklist.pop_back_val();
8383 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8384 continue;
8385
8386 ValueExprMapType::iterator It =
8387 ValueExprMap.find_as(static_cast<Value *>(I));
8388 if (It != ValueExprMap.end()) {
8389 eraseValueFromMap(It->first);
8390 ToForget.push_back(It->second);
8391 if (PHINode *PN = dyn_cast<PHINode>(I))
8392 ConstantEvolutionLoopExitValue.erase(PN);
8393 }
8394
8395 PushDefUseChildren(I, Worklist, Visited);
8396 }
8397}
8398
8399void ScalarEvolution::forgetLoop(const Loop *L) {
8400 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8401 SmallVector<Instruction *, 32> Worklist;
8402 SmallPtrSet<Instruction *, 16> Visited;
8403 SmallVector<const SCEV *, 16> ToForget;
8404
8405 // Iterate over all the loops and sub-loops to drop SCEV information.
8406 while (!LoopWorklist.empty()) {
8407 auto *CurrL = LoopWorklist.pop_back_val();
8408
8409 // Drop any stored trip count value.
8410 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8411 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8412
8413 // Drop information about predicated SCEV rewrites for this loop.
8414 for (auto I = PredicatedSCEVRewrites.begin();
8415 I != PredicatedSCEVRewrites.end();) {
8416 std::pair<const SCEV *, const Loop *> Entry = I->first;
8417 if (Entry.second == CurrL)
8418 PredicatedSCEVRewrites.erase(I++);
8419 else
8420 ++I;
8421 }
8422
8423 auto LoopUsersItr = LoopUsers.find(CurrL);
8424 if (LoopUsersItr != LoopUsers.end()) {
8425 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8426 LoopUsersItr->second.end());
8427 }
8428
8429 // Drop information about expressions based on loop-header PHIs.
8430 PushLoopPHIs(CurrL, Worklist, Visited);
8431 visitAndClearUsers(Worklist, Visited, ToForget);
8432
8433 LoopPropertiesCache.erase(CurrL);
8434 // Forget all contained loops too, to avoid dangling entries in the
8435 // ValuesAtScopes map.
8436 LoopWorklist.append(CurrL->begin(), CurrL->end());
8437 }
8438 forgetMemoizedResults(ToForget);
8439}
8440
8441void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8442 forgetLoop(L->getOutermostLoop());
8443}
8444
8445void ScalarEvolution::forgetValue(Value *V) {
8446 Instruction *I = dyn_cast<Instruction>(V);
8447 if (!I) return;
8448
8449 // Drop information about expressions based on loop-header PHIs.
8450 SmallVector<Instruction *, 16> Worklist;
8451 SmallPtrSet<Instruction *, 8> Visited;
8452 SmallVector<const SCEV *, 8> ToForget;
8453 Worklist.push_back(I);
8454 Visited.insert(I);
8455 visitAndClearUsers(Worklist, Visited, ToForget);
8456
8457 forgetMemoizedResults(ToForget);
8458}
8459
8460void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8461 if (!isSCEVable(V->getType()))
8462 return;
8463
8464 // If SCEV looked through a trivial LCSSA phi node, we might have SCEVs
8465 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8466 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8467 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8468 if (const SCEV *S = getExistingSCEV(V)) {
8469 struct InvalidationRootCollector {
8470 Loop *L;
8471 SmallVector<const SCEV *, 8> Roots;
8472
8473 InvalidationRootCollector(Loop *L) : L(L) {}
8474
8475 bool follow(const SCEV *S) {
8476 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8477 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8478 if (L->contains(I))
8479 Roots.push_back(S);
8480 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8481 if (L->contains(AddRec->getLoop()))
8482 Roots.push_back(S);
8483 }
8484 return true;
8485 }
8486 bool isDone() const { return false; }
8487 };
8488
8489 InvalidationRootCollector C(L);
8490 visitAll(S, C);
8491 forgetMemoizedResults(C.Roots);
8492 }
8493
8494 // Also perform the normal invalidation.
8495 forgetValue(V);
8496}
8497
8498void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8499
8500void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8501 // Unless a specific value is passed to invalidation, completely clear both
8502 // caches.
8503 if (!V) {
8504 BlockDispositions.clear();
8505 LoopDispositions.clear();
8506 return;
8507 }
8508
8509 if (!isSCEVable(V->getType()))
8510 return;
8511
8512 const SCEV *S = getExistingSCEV(V);
8513 if (!S)
8514 return;
8515
8516 // Invalidate the block and loop dispositions cached for S. Dispositions of
8517 // S's users may change if S's disposition changes (i.e. a user may change to
8518 // loop-invariant, if S changes to loop invariant), so also invalidate
8519 // dispositions of S's users recursively.
8520 SmallVector<const SCEV *, 8> Worklist = {S};
8521 SmallPtrSet<const SCEV *, 8> Seen = {S};
8522 while (!Worklist.empty()) {
8523 const SCEV *Curr = Worklist.pop_back_val();
8524 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8525 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8526 if (!LoopDispoRemoved && !BlockDispoRemoved)
8527 continue;
8528 auto Users = SCEVUsers.find(Curr);
8529 if (Users != SCEVUsers.end())
8530 for (const auto *User : Users->second)
8531 if (Seen.insert(User).second)
8532 Worklist.push_back(User);
8533 }
8534}
8535
8536/// Get the exact loop backedge taken count considering all loop exits. A
8537/// computable result can only be returned for loops with all exiting blocks
8538/// dominating the latch. howFarToZero assumes that the limit of each loop test
8539/// is never skipped. This is a valid assumption as long as the loop exits via
8540/// that test. For precise results, it is the caller's responsibility to specify
8541/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8542const SCEV *
8543 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8544 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8545 // If any exits were not computable, the loop is not computable.
8546 if (!isComplete() || ExitNotTaken.empty())
8547 return SE->getCouldNotCompute();
8548
8549 const BasicBlock *Latch = L->getLoopLatch();
8550 // All exiting blocks we have collected must dominate the only backedge.
8551 if (!Latch)
8552 return SE->getCouldNotCompute();
8553
8554 // All exiting blocks we have gathered dominate the loop's latch, so the exact
8555 // trip count is simply a minimum out of all these calculated exit counts.
8556 SmallVector<const SCEV *, 2> Ops;
8557 for (const auto &ENT : ExitNotTaken) {
8558 const SCEV *BECount = ENT.ExactNotTaken;
8559 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8560 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8561 "We should only have known counts for exiting blocks that dominate "
8562 "latch!");
8563
8564 Ops.push_back(BECount);
8565
8566 if (Preds)
8567 for (const auto *P : ENT.Predicates)
8568 Preds->push_back(P);
8569
8570 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8571 "Predicate should be always true!");
8572 }
8573
8574 // If an earlier exit exits on the first iteration (exit count zero), then
8575 // a later poison exit count should not propagate into the result. These are
8576 // exactly the semantics provided by umin_seq.
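// E.g. with per-exit counts {0, poison}, umin_seq(0, poison) is 0, whereas a
// plain umin would let the poison operand leak into the result.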
8577 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8578}
8579
8580/// Get the exact not taken count for this loop exit.
8581const SCEV *
8582ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8583 ScalarEvolution *SE) const {
8584 for (const auto &ENT : ExitNotTaken)
8585 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8586 return ENT.ExactNotTaken;
8587
8588 return SE->getCouldNotCompute();
8589}
8590
8591const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8592 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8593 for (const auto &ENT : ExitNotTaken)
8594 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8595 return ENT.ConstantMaxNotTaken;
8596
8597 return SE->getCouldNotCompute();
8598}
8599
8600const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8601 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8602 for (const auto &ENT : ExitNotTaken)
8603 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8604 return ENT.SymbolicMaxNotTaken;
8605
8606 return SE->getCouldNotCompute();
8607}
8608
8609/// getConstantMax - Get the constant max backedge taken count for the loop.
8610const SCEV *
8611ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8612 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8613 return !ENT.hasAlwaysTruePredicate();
8614 };
8615
8616 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8617 return SE->getCouldNotCompute();
8618
8619 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8620 isa<SCEVConstant>(getConstantMax())) &&
8621 "No point in having a non-constant max backedge taken count!");
8622 return getConstantMax();
8623}
8624
8625const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8626 const Loop *L, ScalarEvolution *SE,
8627 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8628 if (!SymbolicMax) {
8629 // Form an expression for the maximum exit count possible for this loop. We
8630 // merge the max and exact information to approximate a version of
8631 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8632 // constants.
8633 SmallVector<const SCEV *, 4> ExitCounts;
8634
8635 for (const auto &ENT : ExitNotTaken) {
8636 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8637 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8638 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8639 "We should only have known counts for exiting blocks that "
8640 "dominate latch!");
8641 ExitCounts.push_back(ExitCount);
8642 if (Predicates)
8643 for (const auto *P : ENT.Predicates)
8644 Predicates->push_back(P);
8645
8646 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8647 "Predicate should be always true!");
8648 }
8649 }
8650 if (ExitCounts.empty())
8651 SymbolicMax = SE->getCouldNotCompute();
8652 else
8653 SymbolicMax =
8654 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8655 }
8656 return SymbolicMax;
8657}
8658
8659bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8660 ScalarEvolution *SE) const {
8661 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8662 return !ENT.hasAlwaysTruePredicate();
8663 };
8664 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8665}
8666
8667ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8668 : ExitLimit(E, E, E, false, std::nullopt) {}
8669
8670ScalarEvolution::ExitLimit::ExitLimit(
8671 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8672 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8673 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8674 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8675 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8676 // If we prove the max count is zero, so is the symbolic bound. This happens
8677 // in practice due to differences in a) how context sensitive we've chosen
8678 // to be and b) how we reason about bounds implied by UB.
8679 if (ConstantMaxNotTaken->isZero()) {
8680 this->ExactNotTaken = ExactNotTaken = ConstantMaxNotTaken;
8681 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8682 }
8683
8684 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8685 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8686 "Exact is not allowed to be less precise than Constant Max");
8687 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8688 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8689 "Exact is not allowed to be less precise than Symbolic Max");
8690 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8691 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8692 "Symbolic Max is not allowed to be less precise than Constant Max");
8693 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8694 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8695 "No point in having a non-constant max backedge taken count!");
8696 for (const auto *PredSet : PredSetList)
8697 for (const auto *P : *PredSet)
8698 addPredicate(P);
8699 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8700 "Backedge count should be int");
8701 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8702 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8703 "Max backedge count should be int");
8704}
8705
8706ScalarEvolution::ExitLimit::ExitLimit(
8707 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8708 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8709 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8710 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8711 { &PredSet }) {}
8712
8713/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8714/// computable exit into a persistent ExitNotTakenInfo array.
8715 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8716 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8717 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8718 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8719 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8720
8721 ExitNotTaken.reserve(ExitCounts.size());
8722 std::transform(ExitCounts.begin(), ExitCounts.end(),
8723 std::back_inserter(ExitNotTaken),
8724 [&](const EdgeExitInfo &EEI) {
8725 BasicBlock *ExitBB = EEI.first;
8726 const ExitLimit &EL = EEI.second;
8727 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8728 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8729 EL.Predicates);
8730 });
8731 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8732 isa<SCEVConstant>(ConstantMax)) &&
8733 "No point in having a non-constant max backedge taken count!");
8734}
8735
8736/// Compute the number of times the backedge of the specified loop will execute.
8737ScalarEvolution::BackedgeTakenInfo
8738ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8739 bool AllowPredicates) {
8740 SmallVector<BasicBlock *, 8> ExitingBlocks;
8741 L->getExitingBlocks(ExitingBlocks);
8742
8743 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8744
8745 SmallVector<EdgeExitInfo, 4> ExitCounts;
8746 bool CouldComputeBECount = true;
8747 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8748 const SCEV *MustExitMaxBECount = nullptr;
8749 const SCEV *MayExitMaxBECount = nullptr;
8750 bool MustExitMaxOrZero = false;
8751 bool IsOnlyExit = ExitingBlocks.size() == 1;
8752
8753 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8754 // and compute maxBECount.
8755 // Do a union of all the predicates here.
8756 for (BasicBlock *ExitBB : ExitingBlocks) {
8757 // We canonicalize untaken exits to br (constant); ignore them so that
8758 // proving an exit untaken doesn't negatively impact our ability to reason
8759 // about the loop as a whole.
8760 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8761 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8762 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8763 if (ExitIfTrue == CI->isZero())
8764 continue;
8765 }
8766
8767 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8768
8769 assert((AllowPredicates || EL.Predicates.empty()) &&
8770 "Predicated exit limit when predicates are not allowed!");
8771
8772 // 1. For each exit that can be computed, add an entry to ExitCounts.
8773 // CouldComputeBECount is true only if all exits can be computed.
8774 if (EL.ExactNotTaken != getCouldNotCompute())
8775 ++NumExitCountsComputed;
8776 else
8777 // We couldn't compute an exact value for this exit, so
8778 // we won't be able to compute an exact value for the loop.
8779 CouldComputeBECount = false;
8780 // Remember exit count if either exact or symbolic is known. Because
8781 // Exact always implies symbolic, only check symbolic.
8782 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8783 ExitCounts.emplace_back(ExitBB, EL);
8784 else {
8785 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8786 "Exact is known but symbolic isn't?");
8787 ++NumExitCountsNotComputed;
8788 }
8789
8790 // 2. Derive the loop's MaxBECount from each exit's max number of
8791 // non-exiting iterations. Partition the loop exits into two kinds:
8792 // LoopMustExits and LoopMayExits.
8793 //
8794 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8795 // is a LoopMayExit. If any computable LoopMustExit is found, then
8796 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8797 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8798 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8799 // any
8800 // computable EL.ConstantMaxNotTaken.
8801 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8802 DT.dominates(ExitBB, Latch)) {
8803 if (!MustExitMaxBECount) {
8804 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8805 MustExitMaxOrZero = EL.MaxOrZero;
8806 } else {
8807 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8808 EL.ConstantMaxNotTaken);
8809 }
8810 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8811 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8812 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8813 else {
8814 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8815 EL.ConstantMaxNotTaken);
8816 }
8817 }
8818 }
8819 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8820 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8821 // The loop backedge will be taken the maximum or zero times if there's
8822 // a single exit that must be taken the maximum or zero times.
8823 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8824
8825 // Remember which SCEVs are used in exit limits for invalidation purposes.
8826 // We only care about non-constant SCEVs here, so we can ignore
8827 // EL.ConstantMaxNotTaken
8828 // and MaxBECount, which must be SCEVConstant.
8829 for (const auto &Pair : ExitCounts) {
8830 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8831 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8832 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8833 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8834 {L, AllowPredicates});
8835 }
8836 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8837 MaxBECount, MaxOrZero);
8838}
8839
8840ScalarEvolution::ExitLimit
8841ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8842 bool IsOnlyExit, bool AllowPredicates) {
8843 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8844 // If our exiting block does not dominate the latch, then its connection with
8845 // loop's exit limit may be far from trivial.
8846 const BasicBlock *Latch = L->getLoopLatch();
8847 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8848 return getCouldNotCompute();
8849
8850 Instruction *Term = ExitingBlock->getTerminator();
8851 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8852 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8853 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8854 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8855 "It should have one successor in loop and one exit block!");
8856 // Proceed to the next level to examine the exit condition expression.
8857 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8858 /*ControlsOnlyExit=*/IsOnlyExit,
8859 AllowPredicates);
8860 }
8861
8862 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8863 // For switch, make sure that there is a single exit from the loop.
8864 BasicBlock *Exit = nullptr;
8865 for (auto *SBB : successors(ExitingBlock))
8866 if (!L->contains(SBB)) {
8867 if (Exit) // Multiple exit successors.
8868 return getCouldNotCompute();
8869 Exit = SBB;
8870 }
8871 assert(Exit && "Exiting block must have at least one exit");
8872 return computeExitLimitFromSingleExitSwitch(
8873 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8874 }
8875
8876 return getCouldNotCompute();
8877}
8878
8879ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8880 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8881 bool AllowPredicates) {
8882 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8883 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8884 ControlsOnlyExit, AllowPredicates);
8885}
8886
8887std::optional<ScalarEvolution::ExitLimit>
8888ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8889 bool ExitIfTrue, bool ControlsOnlyExit,
8890 bool AllowPredicates) {
8891 (void)this->L;
8892 (void)this->ExitIfTrue;
8893 (void)this->AllowPredicates;
8894
8895 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8896 this->AllowPredicates == AllowPredicates &&
8897 "Variance in assumed invariant key components!");
8898 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8899 if (Itr == TripCountMap.end())
8900 return std::nullopt;
8901 return Itr->second;
8902}
8903
8904void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8905 bool ExitIfTrue,
8906 bool ControlsOnlyExit,
8907 bool AllowPredicates,
8908 const ExitLimit &EL) {
8909 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8910 this->AllowPredicates == AllowPredicates &&
8911 "Variance in assumed invariant key components!");
8912
8913 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8914 assert(InsertResult.second && "Expected successful insertion!");
8915 (void)InsertResult;
8916 (void)ExitIfTrue;
8917}
8918
8919ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8920 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8921 bool ControlsOnlyExit, bool AllowPredicates) {
8922
8923 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8924 AllowPredicates))
8925 return *MaybeEL;
8926
8927 ExitLimit EL = computeExitLimitFromCondImpl(
8928 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8929 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8930 return EL;
8931}
8932
8933ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8934 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8935 bool ControlsOnlyExit, bool AllowPredicates) {
8936 // Handle BinOp conditions (And, Or).
8937 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8938 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8939 return *LimitFromBinOp;
8940
8941 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8942 // Proceed to the next level to examine the icmp.
8943 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8944 ExitLimit EL =
8945 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8946 if (EL.hasFullInfo() || !AllowPredicates)
8947 return EL;
8948
8949 // Try again, but use SCEV predicates this time.
8950 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8951 ControlsOnlyExit,
8952 /*AllowPredicates=*/true);
8953 }
8954
8955 // Check for a constant condition. These are normally stripped out by
8956 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8957 // preserve the CFG and is temporarily leaving constant conditions
8958 // in place.
8959 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8960 if (ExitIfTrue == !CI->getZExtValue())
8961 // The backedge is always taken.
8962 return getCouldNotCompute();
8963 // The backedge is never taken.
8964 return getZero(CI->getType());
8965 }
8966
8967 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8968 // with a constant step, we can form an equivalent icmp predicate and figure
8969 // out how many iterations will be taken before we exit.
8970 const WithOverflowInst *WO;
8971 const APInt *C;
8972 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8973 match(WO->getRHS(), m_APInt(C))) {
8974 ConstantRange NWR =
8975 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8976 WO->getNoWrapKind());
8977 CmpInst::Predicate Pred;
8978 APInt NewRHSC, Offset;
8979 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8980 if (!ExitIfTrue)
8981 Pred = ICmpInst::getInversePredicate(Pred);
8982 auto *LHS = getSCEV(WO->getLHS());
8983 if (Offset != 0)
8984 LHS = getAddExpr(LHS, getConstant(Offset));
8985 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8986 ControlsOnlyExit, AllowPredicates);
8987 if (EL.hasAnyInfo())
8988 return EL;
8989 }
8990
8991 // If it's not an integer or pointer comparison then compute it the hard way.
8992 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8993}
8994
8995std::optional<ScalarEvolution::ExitLimit>
8996ScalarEvolution::computeExitLimitFromCondFromBinOp(
8997 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8998 bool ControlsOnlyExit, bool AllowPredicates) {
8999 // Check if the controlling expression for this loop is an And or Or.
9000 Value *Op0, *Op1;
9001 bool IsAnd = false;
9002 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9003 IsAnd = true;
9004 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9005 IsAnd = false;
9006 else
9007 return std::nullopt;
9008
9009 // EitherMayExit is true in these two cases:
9010 // br (and Op0 Op1), loop, exit
9011 // br (or Op0 Op1), exit, loop
9012 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9013 ExitLimit EL0 = computeExitLimitFromCondCached(
9014 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9015 AllowPredicates);
9016 ExitLimit EL1 = computeExitLimitFromCondCached(
9017 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9018 AllowPredicates);
9019
9020 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9021 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9022 if (isa<ConstantInt>(Op1))
9023 return Op1 == NeutralElement ? EL0 : EL1;
9024 if (isa<ConstantInt>(Op0))
9025 return Op0 == NeutralElement ? EL1 : EL0;
9026
9027 const SCEV *BECount = getCouldNotCompute();
9028 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9029 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9030 if (EitherMayExit) {
9031 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9032 // Both conditions must be the same for the loop to continue executing.
9033 // Choose the less conservative count.
9034 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9035 EL1.ExactNotTaken != getCouldNotCompute()) {
9036 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9037 UseSequentialUMin);
9038 }
9039 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9040 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9041 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9042 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9043 else
9044 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9045 EL1.ConstantMaxNotTaken);
9046 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9047 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9048 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9049 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9050 else
9051 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9052 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9053 } else {
9054 // Both conditions must be the same at the same time for the loop to exit.
9055 // For now, be conservative.
9056 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9057 BECount = EL0.ExactNotTaken;
9058 }
9059
9060 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9061 // to be more aggressive when computing BECount than when computing
9062 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9063 // and
9064 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9065 // EL1.ConstantMaxNotTaken to not.
9066 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9067 !isa<SCEVCouldNotCompute>(BECount))
9068 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9069 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9070 SymbolicMaxBECount =
9071 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9072 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9073 { &EL0.Predicates, &EL1.Predicates });
9074}
9075
9076ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9077 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9078 bool AllowPredicates) {
9079 // If the condition was exit on true, convert the condition to exit on false
9080 ICmpInst::Predicate Pred;
9081 if (!ExitIfTrue)
9082 Pred = ExitCond->getPredicate();
9083 else
9084 Pred = ExitCond->getInversePredicate();
9085 const ICmpInst::Predicate OriginalPred = Pred;
9086
9087 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9088 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9089
9090 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9091 AllowPredicates);
9092 if (EL.hasAnyInfo())
9093 return EL;
9094
9095 auto *ExhaustiveCount =
9096 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9097
9098 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9099 return ExhaustiveCount;
9100
9101 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9102 ExitCond->getOperand(1), L, OriginalPred);
9103}
9104ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9105 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9106 bool ControlsOnlyExit, bool AllowPredicates) {
9107
9108 // Try to evaluate any dependencies out of the loop.
9109 LHS = getSCEVAtScope(LHS, L);
9110 RHS = getSCEVAtScope(RHS, L);
9111
9112 // At this point, we would like to compute how many iterations of the
9113 // loop the predicate will return true for these inputs.
9114 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9115 // If there is a loop-invariant, force it into the RHS.
9116 std::swap(LHS, RHS);
9117 Pred = ICmpInst::getSwappedPredicate(Pred);
9118 }
9119
9120 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9121 loopIsFiniteByAssumption(L);
9122 // Simplify the operands before analyzing them.
9123 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9124
9125 // If we have a comparison of a chrec against a constant, try to use value
9126 // ranges to answer this query.
9127 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9128 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9129 if (AddRec->getLoop() == L) {
9130 // Form the constant range.
9131 ConstantRange CompRange =
9132 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9133
9134 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9135 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9136 }
9137
9138 // If this loop must exit based on this condition (or execute undefined
9139 // behaviour), see if we can improve wrap flags. This is essentially
9140 // a must execute style proof.
9141 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9142 // If we can prove the test sequence produced must repeat the same values
9143 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9144 // because if it did, we'd have an infinite (undefined) loop.
9145 // TODO: We can peel off any functions which are invertible *in L*. Loop
9146 // invariant terms are effectively constants for our purposes here.
9147 auto *InnerLHS = LHS;
9148 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9149 InnerLHS = ZExt->getOperand();
9150 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9151 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9152 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9153 /*OrNegative=*/true)) {
9154 auto Flags = AR->getNoWrapFlags();
9155 Flags = setFlags(Flags, SCEV::FlagNW);
9156 SmallVector<const SCEV *> Operands{AR->operands()};
9157 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9158 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9159 }
9160
9161 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9162 // From no-self-wrap, this follows trivially from the fact that every
9163 // (un)signed-wrapped, but not self-wrapped value must be less than the
9164 // last value before (un)signed wrap. Since we know that last value
9165 // didn't exit, neither will any smaller one.
9166 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9167 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9168 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9169 AR && AR->getLoop() == L && AR->isAffine() &&
9170 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9171 isKnownPositive(AR->getStepRecurrence(*this))) {
9172 auto Flags = AR->getNoWrapFlags();
9173 Flags = setFlags(Flags, WrapType);
9174 SmallVector<const SCEV *> Operands{AR->operands()};
9175 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9176 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9177 }
9178 }
9179 }
9180
9181 switch (Pred) {
9182 case ICmpInst::ICMP_NE: { // while (X != Y)
9183 // Convert to: while (X-Y != 0)
9184 if (LHS->getType()->isPointerTy()) {
9185 LHS = getLosslessPtrToIntExpr(LHS);
9186 if (isa<SCEVCouldNotCompute>(LHS))
9187 return LHS;
9188 }
9189 if (RHS->getType()->isPointerTy()) {
9190 RHS = getLosslessPtrToIntExpr(RHS);
9191 if (isa<SCEVCouldNotCompute>(RHS))
9192 return RHS;
9193 }
9194 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9195 AllowPredicates);
9196 if (EL.hasAnyInfo())
9197 return EL;
9198 break;
9199 }
9200 case ICmpInst::ICMP_EQ: { // while (X == Y)
9201 // Convert to: while (X-Y == 0)
9202 if (LHS->getType()->isPointerTy()) {
9203 LHS = getLosslessPtrToIntExpr(LHS);
9204 if (isa<SCEVCouldNotCompute>(LHS))
9205 return LHS;
9206 }
9207 if (RHS->getType()->isPointerTy()) {
9208 RHS = getLosslessPtrToIntExpr(RHS);
9209 if (isa<SCEVCouldNotCompute>(RHS))
9210 return RHS;
9211 }
9212 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9213 if (EL.hasAnyInfo()) return EL;
9214 break;
9215 }
9216 case ICmpInst::ICMP_SLE:
9217 case ICmpInst::ICMP_ULE:
9218 // Since the loop is finite, an invariant RHS cannot include the boundary
9219 // value; otherwise it would loop forever.
9220 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9221 !isLoopInvariant(RHS, L)) {
9222 // Otherwise, perform the addition in a wider type, to avoid overflow.
9223 // If the LHS is an addrec with the appropriate nowrap flag, the
9224 // extension will be sunk into it and the exit count can be analyzed.
9225 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9226 if (!OldType)
9227 break;
9228 // Prefer doubling the bitwidth over adding a single bit to make it more
9229 // likely that we use a legal type.
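// E.g. an i32 "ule" exit test is evaluated in i64 here, so forming RHS + 1
// for the "ult" rewrite below cannot wrap.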
9230 auto *NewType =
9231 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9232 if (ICmpInst::isSigned(Pred)) {
9233 LHS = getSignExtendExpr(LHS, NewType);
9234 RHS = getSignExtendExpr(RHS, NewType);
9235 } else {
9236 LHS = getZeroExtendExpr(LHS, NewType);
9237 RHS = getZeroExtendExpr(RHS, NewType);
9238 }
9239 }
9240 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9241 [[fallthrough]];
9242 case ICmpInst::ICMP_SLT:
9243 case ICmpInst::ICMP_ULT: { // while (X < Y)
9244 bool IsSigned = ICmpInst::isSigned(Pred);
9245 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9246 AllowPredicates);
9247 if (EL.hasAnyInfo())
9248 return EL;
9249 break;
9250 }
9251 case ICmpInst::ICMP_SGE:
9252 case ICmpInst::ICMP_UGE:
9253 // Since the loop is finite, an invariant RHS cannot include the boundary
9254 // value; otherwise it would loop forever.
9255 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9256 !isLoopInvariant(RHS, L))
9257 break;
9258 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9259 [[fallthrough]];
9260 case ICmpInst::ICMP_SGT:
9261 case ICmpInst::ICMP_UGT: { // while (X > Y)
9262 bool IsSigned = ICmpInst::isSigned(Pred);
9263 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9264 AllowPredicates);
9265 if (EL.hasAnyInfo())
9266 return EL;
9267 break;
9268 }
9269 default:
9270 break;
9271 }
9272
9273 return getCouldNotCompute();
9274}
9275
9276ScalarEvolution::ExitLimit
9277ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9278 SwitchInst *Switch,
9279 BasicBlock *ExitingBlock,
9280 bool ControlsOnlyExit) {
9281 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9282
9283 // Give up if the exit is the default dest of a switch.
9284 if (Switch->getDefaultDest() == ExitingBlock)
9285 return getCouldNotCompute();
9286
9287 assert(L->contains(Switch->getDefaultDest()) &&
9288 "Default case must not exit the loop!");
9289 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9290 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9291
9292 // while (X != Y) --> while (X-Y != 0)
9293 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9294 if (EL.hasAnyInfo())
9295 return EL;
9296
9297 return getCouldNotCompute();
9298}
9299
9300 static ConstantInt *
9301EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9302 ScalarEvolution &SE) {
9303 const SCEV *InVal = SE.getConstant(C);
9304 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9305 assert(isa<SCEVConstant>(Val) &&
9306 "Evaluation of SCEV at constant didn't fold correctly?");
9307 return cast<SCEVConstant>(Val)->getValue();
9308}
9309
9310ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9311 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9312 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9313 if (!RHS)
9314 return getCouldNotCompute();
9315
9316 const BasicBlock *Latch = L->getLoopLatch();
9317 if (!Latch)
9318 return getCouldNotCompute();
9319
9320 const BasicBlock *Predecessor = L->getLoopPredecessor();
9321 if (!Predecessor)
9322 return getCouldNotCompute();
9323
9324 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9325 // Return LHS in OutLHS and shift_op in OutOpCode.
9326 auto MatchPositiveShift =
9327 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9328
9329 using namespace PatternMatch;
9330
9331 ConstantInt *ShiftAmt;
9332 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9333 OutOpCode = Instruction::LShr;
9334 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9335 OutOpCode = Instruction::AShr;
9336 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9337 OutOpCode = Instruction::Shl;
9338 else
9339 return false;
9340
9341 return ShiftAmt->getValue().isStrictlyPositive();
9342 };
9343
9344 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9345 //
9346 // loop:
9347 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9348 // %iv.shifted = lshr i32 %iv, <positive constant>
9349 //
9350 // Return true on a successful match. Return the corresponding PHI node (%iv
9351 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9352 auto MatchShiftRecurrence =
9353 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9354 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9355
9356 {
9357 Instruction::BinaryOps OpC;
9358 Value *V;
9359
9360 // If we encounter a shift instruction, "peel off" the shift operation,
9361 // and remember that we did so. Later when we inspect %iv's backedge
9362 // value, we will make sure that the backedge value uses the same
9363 // operation.
9364 //
9365 // Note: the peeled shift operation does not have to be the same
9366 // instruction as the one feeding into the PHI's backedge value. We only
9367 // really care about it being the same *kind* of shift instruction --
9368 // that's all that is required for our later inferences to hold.
9369 if (MatchPositiveShift(LHS, V, OpC)) {
9370 PostShiftOpCode = OpC;
9371 LHS = V;
9372 }
9373 }
9374
9375 PNOut = dyn_cast<PHINode>(LHS);
9376 if (!PNOut || PNOut->getParent() != L->getHeader())
9377 return false;
9378
9379 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9380 Value *OpLHS;
9381
9382 return
9383 // The backedge value for the PHI node must be a shift by a positive
9384 // amount
9385 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9386
9387 // of the PHI node itself
9388 OpLHS == PNOut &&
9389
9390 // and the kind of shift should match the kind of shift we peeled
9391 // off, if any.
9392 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9393 };
9394
9395 PHINode *PN;
9396 Instruction::BinaryOps OpCode;
9397 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9398 return getCouldNotCompute();
9399
9400 const DataLayout &DL = getDataLayout();
9401
9402 // The key rationale for this optimization is that for some kinds of shift
9403 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9404 // within a finite number of iterations. If the condition guarding the
9405 // backedge (in the sense that the backedge is taken if the condition is true)
9406 // is false for the value the shift recurrence stabilizes to, then we know
9407 // that the backedge is taken only a finite number of times.
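// Illustrative example: with "%iv.shifted = lshr i32 %iv, 1" the recurrence
// reaches 0 within 32 iterations, so if the backedge is only taken while
// "%iv != 0" holds, the backedge-taken count is bounded by the bit width.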
9408
9409 ConstantInt *StableValue = nullptr;
9410 switch (OpCode) {
9411 default:
9412 llvm_unreachable("Impossible case!");
9413
9414 case Instruction::AShr: {
9415 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9416 // bitwidth(K) iterations.
9417 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9418 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9419 Predecessor->getTerminator(), &DT);
9420 auto *Ty = cast<IntegerType>(RHS->getType());
9421 if (Known.isNonNegative())
9422 StableValue = ConstantInt::get(Ty, 0);
9423 else if (Known.isNegative())
9424 StableValue = ConstantInt::get(Ty, -1, true);
9425 else
9426 return getCouldNotCompute();
9427
9428 break;
9429 }
9430 case Instruction::LShr:
9431 case Instruction::Shl:
9432 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9433 // stabilize to 0 in at most bitwidth(K) iterations.
9434 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9435 break;
9436 }
9437
9438 auto *Result =
9439 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9440 assert(Result->getType()->isIntegerTy(1) &&
9441 "Otherwise cannot be an operand to a branch instruction");
9442
9443 if (Result->isZeroValue()) {
9444 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9445 const SCEV *UpperBound =
9446 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9447 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9448 }
9449
9450 return getCouldNotCompute();
9451}
9452
9453/// Return true if we can constant fold an instruction of the specified type,
9454/// assuming that all operands were constants.
9455static bool CanConstantFold(const Instruction *I) {
9456 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9457 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9458 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9459 return true;
9460
9461 if (const CallInst *CI = dyn_cast<CallInst>(I))
9462 if (const Function *F = CI->getCalledFunction())
9463 return canConstantFoldCallTo(CI, F);
9464 return false;
9465}
9466
9467/// Determine whether this instruction can constant evolve within this loop
9468/// assuming its operands can all constant evolve.
9469static bool canConstantEvolve(Instruction *I, const Loop *L) {
9470 // An instruction outside of the loop can't be derived from a loop PHI.
9471 if (!L->contains(I)) return false;
9472
9473 if (isa<PHINode>(I)) {
9474 // We don't currently keep track of the control flow needed to evaluate
9475 // PHIs, so we cannot handle PHIs inside of loops.
9476 return L->getHeader() == I->getParent();
9477 }
9478
9479 // If we won't be able to constant fold this expression even if the operands
9480 // are constants, bail early.
9481 return CanConstantFold(I);
9482}
9483
9484/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9485/// recursing through each instruction operand until reaching a loop header phi.
9486 static PHINode *
9487 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9488 DenseMap<Instruction *, PHINode *> &PHIMap,
9489 unsigned Depth) {
9490 if (Depth > MaxConstantEvolvingDepth)
9491 return nullptr;
9492
9493 // Otherwise, we can evaluate this instruction if all of its operands are
9494 // constant or derived from a PHI node themselves.
9495 PHINode *PHI = nullptr;
9496 for (Value *Op : UseInst->operands()) {
9497 if (isa<Constant>(Op)) continue;
9498
9499 Instruction *OpInst = dyn_cast<Instruction>(Op);
9500 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9501
9502 PHINode *P = dyn_cast<PHINode>(OpInst);
9503 if (!P)
9504 // If this operand is already visited, reuse the prior result.
9505 // We may have P != PHI if this is the deepest point at which the
9506 // inconsistent paths meet.
9507 P = PHIMap.lookup(OpInst);
9508 if (!P) {
9509 // Recurse and memoize the results, whether a phi is found or not.
9510 // This recursive call invalidates pointers into PHIMap.
9511 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9512 PHIMap[OpInst] = P;
9513 }
9514 if (!P)
9515 return nullptr; // Not evolving from PHI
9516 if (PHI && PHI != P)
9517 return nullptr; // Evolving from multiple different PHIs.
9518 PHI = P;
9519 }
9520 // This is an expression evolving from a constant PHI!
9521 return PHI;
9522}
9523
9524/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9525/// in the loop that V is derived from. We allow arbitrary operations along the
9526/// way, but the operands of an operation must either be constants or a value
9527/// derived from a constant PHI. If this expression does not fit with these
9528/// constraints, return null.
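/// For example (illustrative), if %iv is a header phi of L and the loop body
/// computes %t = mul i32 %iv, 3, then getConstantEvolvingPHI(%t, L) returns
/// %iv, since %t is derived from %iv through constant-foldable operations
/// whose other operands are constants.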
9529 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9530 Instruction *I = dyn_cast<Instruction>(V);
9531 if (!I || !canConstantEvolve(I, L)) return nullptr;
9532
9533 if (PHINode *PN = dyn_cast<PHINode>(I))
9534 return PN;
9535
9536 // Record non-constant instructions contained by the loop.
9537 DenseMap<Instruction *, PHINode *> PHIMap;
9538 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9539}
9540
9541/// EvaluateExpression - Given an expression that passes the
9542/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9543/// in the loop has the value PHIVal. If we can't fold this expression for some
9544/// reason, return null.
9545 static Constant *EvaluateExpression(Value *V, const Loop *L,
9546 DenseMap<Instruction *, Constant *> &Vals,
9547 const DataLayout &DL,
9548 const TargetLibraryInfo *TLI) {
9549 // Convenient constant check, but redundant for recursive calls.
9550 if (Constant *C = dyn_cast<Constant>(V)) return C;
9551 Instruction *I = dyn_cast<Instruction>(V);
9552 if (!I) return nullptr;
9553
9554 if (Constant *C = Vals.lookup(I)) return C;
9555
9556 // An instruction inside the loop depends on a value outside the loop that we
9557 // weren't given a mapping for, or a value such as a call inside the loop.
9558 if (!canConstantEvolve(I, L)) return nullptr;
9559
9560 // An unmapped PHI can be due to a branch or another loop inside this loop,
9561 // or due to this not being the initial iteration through a loop where we
9562 // couldn't compute the evolution of this particular PHI last time.
9563 if (isa<PHINode>(I)) return nullptr;
9564
9565 std::vector<Constant*> Operands(I->getNumOperands());
9566
9567 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9568 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9569 if (!Operand) {
9570 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9571 if (!Operands[i]) return nullptr;
9572 continue;
9573 }
9574 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9575 Vals[Operand] = C;
9576 if (!C) return nullptr;
9577 Operands[i] = C;
9578 }
9579
9580 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9581 /*AllowNonDeterministic=*/false);
9582}
9583
9584
9585// If every incoming value to PN except the one for BB is a specific Constant,
9586// return that, else return nullptr.
9587 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9588 Constant *IncomingVal = nullptr;
9589
9590 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9591 if (PN->getIncomingBlock(i) == BB)
9592 continue;
9593
9594 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9595 if (!CurrentVal)
9596 return nullptr;
9597
9598 if (IncomingVal != CurrentVal) {
9599 if (IncomingVal)
9600 return nullptr;
9601 IncomingVal = CurrentVal;
9602 }
9603 }
9604
9605 return IncomingVal;
9606}
9607
9608 /// getConstantEvolutionLoopExitValue - If the specified PHI is in the header
9609 /// of its containing loop, the loop executes a known constant number of
9610 /// times, and the PHI node is just a recurrence involving constants, fold
9611 /// it.
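///
/// For example (illustrative), for %iv = phi i32 [ 0, %preheader ],
/// [ %iv.next, %latch ] with %iv.next = add i32 %iv, 3 and a backedge-taken
/// count of 4, the symbolic execution below yields an exit value of 12.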
9612Constant *
9613ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9614 const APInt &BEs,
9615 const Loop *L) {
9616 auto I = ConstantEvolutionLoopExitValue.find(PN);
9617 if (I != ConstantEvolutionLoopExitValue.end())
9618 return I->second;
9619
9620 if (BEs.ugt(MaxBruteForceIterations))
9621 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9622
9623 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9624
9625 DenseMap<Instruction *, Constant *> CurrentIterVals;
9626 BasicBlock *Header = L->getHeader();
9627 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9628
9629 BasicBlock *Latch = L->getLoopLatch();
9630 if (!Latch)
9631 return nullptr;
9632
9633 for (PHINode &PHI : Header->phis()) {
9634 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9635 CurrentIterVals[&PHI] = StartCST;
9636 }
9637 if (!CurrentIterVals.count(PN))
9638 return RetVal = nullptr;
9639
9640 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9641
9642 // Execute the loop symbolically to determine the exit value.
9643 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9644 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9645
9646 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9647 unsigned IterationNum = 0;
9648 const DataLayout &DL = getDataLayout();
9649 for (; ; ++IterationNum) {
9650 if (IterationNum == NumIterations)
9651 return RetVal = CurrentIterVals[PN]; // Got exit value!
9652
9653 // Compute the value of the PHIs for the next iteration.
9654 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9655 DenseMap<Instruction *, Constant *> NextIterVals;
9656 Constant *NextPHI =
9657 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9658 if (!NextPHI)
9659 return nullptr; // Couldn't evaluate!
9660 NextIterVals[PN] = NextPHI;
9661
9662 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9663
9664 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9665 // cease to be able to evaluate one of them or if they stop evolving,
9666 // because that doesn't necessarily prevent us from computing PN.
9667 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9668 for (const auto &I : CurrentIterVals) {
9669 PHINode *PHI = dyn_cast<PHINode>(I.first);
9670 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9671 PHIsToCompute.emplace_back(PHI, I.second);
9672 }
9673 // We use two distinct loops because EvaluateExpression may invalidate any
9674 // iterators into CurrentIterVals.
9675 for (const auto &I : PHIsToCompute) {
9676 PHINode *PHI = I.first;
9677 Constant *&NextPHI = NextIterVals[PHI];
9678 if (!NextPHI) { // Not already computed.
9679 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9680 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9681 }
9682 if (NextPHI != I.second)
9683 StoppedEvolving = false;
9684 }
9685
9686 // If all entries in CurrentIterVals == NextIterVals then we can stop
9687 // iterating, the loop can't continue to change.
9688 if (StoppedEvolving)
9689 return RetVal = CurrentIterVals[PN];
9690
9691 CurrentIterVals.swap(NextIterVals);
9692 }
9693}
9694
9695const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9696 Value *Cond,
9697 bool ExitWhen) {
9698 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9699 if (!PN) return getCouldNotCompute();
9700
9701 // If the loop is canonicalized, the PHI will have exactly two entries.
9702 // That's the only form we support here.
9703 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9704
9705 DenseMap<Instruction *, Constant *> CurrentIterVals;
9706 BasicBlock *Header = L->getHeader();
9707 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9708
9709 BasicBlock *Latch = L->getLoopLatch();
9710 assert(Latch && "Should follow from NumIncomingValues == 2!");
9711
9712 for (PHINode &PHI : Header->phis()) {
9713 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9714 CurrentIterVals[&PHI] = StartCST;
9715 }
9716 if (!CurrentIterVals.count(PN))
9717 return getCouldNotCompute();
9718
9719 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9720 // the loop symbolically to determine when the condition gets a value of
9721 // "ExitWhen".
9722 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9723 const DataLayout &DL = getDataLayout();
9724 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9725 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9726 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9727
9728 // Couldn't symbolically evaluate.
9729 if (!CondVal) return getCouldNotCompute();
9730
9731 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9732 ++NumBruteForceTripCountsComputed;
9733 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9734 }
9735
9736 // Update all the PHI nodes for the next iteration.
9737 DenseMap<Instruction *, Constant *> NextIterVals;
9738
9739 // Create a list of which PHIs we need to compute. We want to do this before
9740 // calling EvaluateExpression on them because that may invalidate iterators
9741 // into CurrentIterVals.
9742 SmallVector<PHINode *, 8> PHIsToCompute;
9743 for (const auto &I : CurrentIterVals) {
9744 PHINode *PHI = dyn_cast<PHINode>(I.first);
9745 if (!PHI || PHI->getParent() != Header) continue;
9746 PHIsToCompute.push_back(PHI);
9747 }
9748 for (PHINode *PHI : PHIsToCompute) {
9749 Constant *&NextPHI = NextIterVals[PHI];
9750 if (NextPHI) continue; // Already computed!
9751
9752 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9753 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9754 }
9755 CurrentIterVals.swap(NextIterVals);
9756 }
9757
9758 // Too many iterations were needed to evaluate.
9759 return getCouldNotCompute();
9760}
9761
9762const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9763 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9764 ValuesAtScopes[V];
9765 // Check to see if we've folded this expression at this loop before.
9766 for (auto &LS : Values)
9767 if (LS.first == L)
9768 return LS.second ? LS.second : V;
9769
9770 Values.emplace_back(L, nullptr);
9771
9772 // Otherwise compute it.
9773 const SCEV *C = computeSCEVAtScope(V, L);
9774 for (auto &LS : reverse(ValuesAtScopes[V]))
9775 if (LS.first == L) {
9776 LS.second = C;
9777 if (!isa<SCEVConstant>(C))
9778 ValuesAtScopesUsers[C].push_back({L, V});
9779 break;
9780 }
9781 return C;
9782}
9783
9784/// This builds up a Constant using the ConstantExpr interface. That way, we
9785/// will return Constants for objects which aren't represented by a
9786/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9787/// Returns NULL if the SCEV isn't representable as a Constant.
9788 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9789 switch (V->getSCEVType()) {
9790 case scCouldNotCompute:
9791 case scAddRecExpr:
9792 case scVScale:
9793 return nullptr;
9794 case scConstant:
9795 return cast<SCEVConstant>(V)->getValue();
9796 case scUnknown:
9797 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9798 case scPtrToInt: {
9799 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9800 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9801 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9802
9803 return nullptr;
9804 }
9805 case scTruncate: {
9806 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9807 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9808 return ConstantExpr::getTrunc(CastOp, ST->getType());
9809 return nullptr;
9810 }
9811 case scAddExpr: {
9812 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9813 Constant *C = nullptr;
9814 for (const SCEV *Op : SA->operands()) {
9815 Constant *OpC = BuildConstantFromSCEV(Op);
9816 if (!OpC)
9817 return nullptr;
9818 if (!C) {
9819 C = OpC;
9820 continue;
9821 }
9822 assert(!C->getType()->isPointerTy() &&
9823 "Can only have one pointer, and it must be last");
9824 if (OpC->getType()->isPointerTy()) {
9825 // The offsets have been converted to bytes. We can add bytes using
9826 // an i8 GEP.
9827 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9828 OpC, C);
9829 } else {
9830 C = ConstantExpr::getAdd(C, OpC);
9831 }
9832 }
9833 return C;
9834 }
9835 case scMulExpr:
9836 case scSignExtend:
9837 case scZeroExtend:
9838 case scUDivExpr:
9839 case scSMaxExpr:
9840 case scUMaxExpr:
9841 case scSMinExpr:
9842 case scUMinExpr:
9843 case scSequentialUMinExpr:
9844 return nullptr;
9845 }
9846 llvm_unreachable("Unknown SCEV kind!");
9847}
9848
9849const SCEV *
9850ScalarEvolution::getWithOperands(const SCEV *S,
9851 SmallVectorImpl<const SCEV *> &NewOps) {
9852 switch (S->getSCEVType()) {
9853 case scTruncate:
9854 case scZeroExtend:
9855 case scSignExtend:
9856 case scPtrToInt:
9857 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9858 case scAddRecExpr: {
9859 auto *AddRec = cast<SCEVAddRecExpr>(S);
9860 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9861 }
9862 case scAddExpr:
9863 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9864 case scMulExpr:
9865 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9866 case scUDivExpr:
9867 return getUDivExpr(NewOps[0], NewOps[1]);
9868 case scUMaxExpr:
9869 case scSMaxExpr:
9870 case scUMinExpr:
9871 case scSMinExpr:
9872 return getMinMaxExpr(S->getSCEVType(), NewOps);
9873 case scSequentialUMinExpr:
9874 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9875 case scConstant:
9876 case scVScale:
9877 case scUnknown:
9878 return S;
9879 case scCouldNotCompute:
9880 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9881 }
9882 llvm_unreachable("Unknown SCEV kind!");
9883}
9884
9885const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9886 switch (V->getSCEVType()) {
9887 case scConstant:
9888 case scVScale:
9889 return V;
9890 case scAddRecExpr: {
9891 // If this is a loop recurrence for a loop that does not contain L, then we
9892 // are dealing with the final value computed by the loop.
9893 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9894 // First, attempt to evaluate each operand.
9895 // Avoid performing the look-up in the common case where the specified
9896 // expression has no loop-variant portions.
9897 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9898 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9899 if (OpAtScope == AddRec->getOperand(i))
9900 continue;
9901
9902 // Okay, at least one of these operands is loop variant but might be
9903 // foldable. Build a new instance of the folded commutative expression.
9904 SmallVector<const SCEV *, 8> NewOps;
9905 NewOps.reserve(AddRec->getNumOperands());
9906 append_range(NewOps, AddRec->operands().take_front(i));
9907 NewOps.push_back(OpAtScope);
9908 for (++i; i != e; ++i)
9909 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9910
9911 const SCEV *FoldedRec = getAddRecExpr(
9912 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9913 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9914 // The addrec may be folded to a nonrecurrence, for example, if the
9915 // induction variable is multiplied by zero after constant folding. Go
9916 // ahead and return the folded value.
9917 if (!AddRec)
9918 return FoldedRec;
9919 break;
9920 }
9921
9922 // If the scope is outside the addrec's loop, evaluate it by using the
9923 // loop exit value of the addrec.
9924 if (!AddRec->getLoop()->contains(L)) {
9925 // To evaluate this recurrence, we need to know how many times the AddRec
9926 // loop iterates. Compute this now.
9927 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9928 if (BackedgeTakenCount == getCouldNotCompute())
9929 return AddRec;
9930
9931 // Then, evaluate the AddRec.
9932 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9933 }
9934
9935 return AddRec;
9936 }
9937 case scTruncate:
9938 case scZeroExtend:
9939 case scSignExtend:
9940 case scPtrToInt:
9941 case scAddExpr:
9942 case scMulExpr:
9943 case scUDivExpr:
9944 case scUMaxExpr:
9945 case scSMaxExpr:
9946 case scUMinExpr:
9947 case scSMinExpr:
9948 case scSequentialUMinExpr: {
9949 ArrayRef<const SCEV *> Ops = V->operands();
9950 // Avoid performing the look-up in the common case where the specified
9951 // expression has no loop-variant portions.
9952 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9953 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9954 if (OpAtScope != Ops[i]) {
9955 // Okay, at least one of these operands is loop variant but might be
9956 // foldable. Build a new instance of the folded commutative expression.
9957 SmallVector<const SCEV *, 8> NewOps;
9958 NewOps.reserve(Ops.size());
9959 append_range(NewOps, Ops.take_front(i));
9960 NewOps.push_back(OpAtScope);
9961
9962 for (++i; i != e; ++i) {
9963 OpAtScope = getSCEVAtScope(Ops[i], L);
9964 NewOps.push_back(OpAtScope);
9965 }
9966
9967 return getWithOperands(V, NewOps);
9968 }
9969 }
9970 // If we got here, all operands are loop invariant.
9971 return V;
9972 }
9973 case scUnknown: {
9974 // If this instruction is evolved from a constant-evolving PHI, compute the
9975 // exit value from the loop without using SCEVs.
9976 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9977 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9978 if (!I)
9979 return V; // This is some other type of SCEVUnknown, just return it.
9980
9981 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9982 const Loop *CurrLoop = this->LI[I->getParent()];
9983 // Looking for loop exit value.
9984 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9985 PN->getParent() == CurrLoop->getHeader()) {
9986 // Okay, there is no closed form solution for the PHI node. Check
9987 // to see if the loop that contains it has a known backedge-taken
9988 // count. If so, we may be able to force computation of the exit
9989 // value.
9990 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9991 // This trivial case can show up in some degenerate cases where
9992 // the incoming IR has not yet been fully simplified.
9993 if (BackedgeTakenCount->isZero()) {
9994 Value *InitValue = nullptr;
9995 bool MultipleInitValues = false;
9996 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9997 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9998 if (!InitValue)
9999 InitValue = PN->getIncomingValue(i);
10000 else if (InitValue != PN->getIncomingValue(i)) {
10001 MultipleInitValues = true;
10002 break;
10003 }
10004 }
10005 }
10006 if (!MultipleInitValues && InitValue)
10007 return getSCEV(InitValue);
10008 }
10009 // Do we have a loop invariant value flowing around the backedge
10010 // for a loop which must execute the backedge?
10011 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10012 isKnownNonZero(BackedgeTakenCount) &&
10013 PN->getNumIncomingValues() == 2) {
10014
10015 unsigned InLoopPred =
10016 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10017 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10018 if (CurrLoop->isLoopInvariant(BackedgeVal))
10019 return getSCEV(BackedgeVal);
10020 }
10021 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10022 // Okay, we know how many times the containing loop executes. If
10023 // this is a constant evolving PHI node, get the final value at
10024 // the specified iteration number.
10025 Constant *RV =
10026 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10027 if (RV)
10028 return getSCEV(RV);
10029 }
10030 }
10031 }
10032
10033 // Okay, this is an expression that we cannot symbolically evaluate
10034 // into a SCEV. Check to see if it's possible to symbolically evaluate
10035 // the arguments into constants, and if so, try to constant propagate the
10036 // result. This is particularly useful for computing loop exit values.
10037 if (!CanConstantFold(I))
10038 return V; // This is some other type of SCEVUnknown, just return it.
10039
10040 SmallVector<Constant *, 4> Operands;
10041 Operands.reserve(I->getNumOperands());
10042 bool MadeImprovement = false;
10043 for (Value *Op : I->operands()) {
10044 if (Constant *C = dyn_cast<Constant>(Op)) {
10045 Operands.push_back(C);
10046 continue;
10047 }
10048
10049 // If any of the operands is non-constant and if they are
10050 // non-integer and non-pointer, don't even try to analyze them
10051 // with scev techniques.
10052 if (!isSCEVable(Op->getType()))
10053 return V;
10054
10055 const SCEV *OrigV = getSCEV(Op);
10056 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10057 MadeImprovement |= OrigV != OpV;
10058
10059 Constant *C = BuildConstantFromSCEV(OpV);
10060 if (!C)
10061 return V;
10062 assert(C->getType() == Op->getType() && "Type mismatch");
10063 Operands.push_back(C);
10064 }
10065
10066 // Check to see if getSCEVAtScope actually made an improvement.
10067 if (!MadeImprovement)
10068 return V; // This is some other type of SCEVUnknown, just return it.
10069
10070 Constant *C = nullptr;
10071 const DataLayout &DL = getDataLayout();
10072 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10073 /*AllowNonDeterministic=*/false);
10074 if (!C)
10075 return V;
10076 return getSCEV(C);
10077 }
10078 case scCouldNotCompute:
10079 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10080 }
10081 llvm_unreachable("Unknown SCEV type!");
10082}
10083
10084 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10085 return getSCEVAtScope(getSCEV(V), L);
10086}
10087
10088const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10089 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10090 return stripInjectiveFunctions(ZExt->getOperand());
10091 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10092 return stripInjectiveFunctions(SExt->getOperand());
10093 return S;
10094}
10095
10096/// Finds the minimum unsigned root of the following equation:
10097///
10098/// A * X = B (mod N)
10099///
10100/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10101/// A and B isn't important.
10102///
10103/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
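///
/// For example (illustrative), with BW = 8, A = 12 and B = 20:
///   D = gcd(12, 256) = 4 and B is divisible by D, so a solution exists;
///   A/D = 3, whose multiplicative inverse modulo N/D = 64 is 43; the minimum
///   unsigned root is then (43 * 20 mod 256) / 4 = 92 / 4 = 23, and indeed
///   12 * 23 = 276 == 20 (mod 256).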
10104static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10105 ScalarEvolution &SE) {
10106 uint32_t BW = A.getBitWidth();
10107 assert(BW == SE.getTypeSizeInBits(B->getType()));
10108 assert(A != 0 && "A must be non-zero.");
10109
10110 // 1. D = gcd(A, N)
10111 //
10112 // The gcd of A and N may have only one prime factor: 2. The number of
10113 // trailing zeros in A is its multiplicity
10114 uint32_t Mult2 = A.countr_zero();
10115 // D = 2^Mult2
10116
10117 // 2. Check if B is divisible by D.
10118 //
10119 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10120 // is not less than multiplicity of this prime factor for D.
10121 if (SE.getMinTrailingZeros(B) < Mult2)
10122 return SE.getCouldNotCompute();
10123
10124 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10125 // modulo (N / D).
10126 //
10127 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10128 // (N / D) in general. The inverse itself always fits into BW bits, though,
10129 // so we immediately truncate it.
10130 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10131 APInt I = AD.multiplicativeInverse().zext(BW);
10132
10133 // 4. Compute the minimum unsigned root of the equation:
10134 // I * (B / D) mod (N / D)
10135 // To simplify the computation, we factor out the divide by D:
10136 // (I * B mod N) / D
10137 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10138 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10139}
10140
10141/// For a given quadratic addrec, generate coefficients of the corresponding
10142/// quadratic equation, multiplied by a common value to ensure that they are
10143/// integers.
10144/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10145/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10146/// were multiplied by, and BitWidth is the bit width of the original addrec
10147/// coefficients.
10148/// This function returns std::nullopt if the addrec coefficients are not
10149 /// compile-time constants.
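///
/// For example (illustrative), the addrec {-4,+,1,+,2} accumulates to
/// -4 + n*1 + n(n-1)/2 * 2 = n^2 - 4 after n iterations, which (multiplied
/// by 2) yields the equation 2n^2 + 0n - 8 = 0, i.e. A = 2, B = 0, C = -8,
/// M = 2.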
10150static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10151 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10152 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10153 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10154 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10155 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10156 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10157 << *AddRec << '\n');
10158
10159 // We currently can only solve this if the coefficients are constants.
10160 if (!LC || !MC || !NC) {
10161 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10162 return std::nullopt;
10163 }
10164
10165 APInt L = LC->getAPInt();
10166 APInt M = MC->getAPInt();
10167 APInt N = NC->getAPInt();
10168 assert(!N.isZero() && "This is not a quadratic addrec");
10169
10170 unsigned BitWidth = LC->getAPInt().getBitWidth();
10171 unsigned NewWidth = BitWidth + 1;
10172 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10173 << BitWidth << '\n');
10174 // The sign-extension (as opposed to a zero-extension) here matches the
10175 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10176 N = N.sext(NewWidth);
10177 M = M.sext(NewWidth);
10178 L = L.sext(NewWidth);
10179
10180 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10181 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10182 // L+M, L+2M+N, L+3M+3N, ...
10183 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10184 //
10185 // The equation Acc = 0 is then
10186 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10187 // In a quadratic form it becomes:
10188 // N n^2 + (2M-N) n + 2L = 0.
10189
10190 APInt A = N;
10191 APInt B = 2 * M - A;
10192 APInt C = 2 * L;
10193 APInt T = APInt(NewWidth, 2);
10194 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10195 << "x + " << C << ", coeff bw: " << NewWidth
10196 << ", multiplied by " << T << '\n');
10197 return std::make_tuple(A, B, C, T, BitWidth);
10198}
10199
10200/// Helper function to compare optional APInts:
10201/// (a) if X and Y both exist, return min(X, Y),
10202/// (b) if neither X nor Y exist, return std::nullopt,
10203/// (c) if exactly one of X and Y exists, return that value.
10204static std::optional<APInt> MinOptional(std::optional<APInt> X,
10205 std::optional<APInt> Y) {
10206 if (X && Y) {
10207 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10208 APInt XW = X->sext(W);
10209 APInt YW = Y->sext(W);
10210 return XW.slt(YW) ? *X : *Y;
10211 }
10212 if (!X && !Y)
10213 return std::nullopt;
10214 return X ? *X : *Y;
10215}
10216
10217/// Helper function to truncate an optional APInt to a given BitWidth.
10218/// When solving addrec-related equations, it is preferable to return a value
10219/// that has the same bit width as the original addrec's coefficients. If the
10220/// solution fits in the original bit width, truncate it (except for i1).
10221/// Returning a value of a different bit width may inhibit some optimizations.
10222///
10223/// In general, a solution to a quadratic equation generated from an addrec
10224/// may require BW+1 bits, where BW is the bit width of the addrec's
10225/// coefficients. The reason is that the coefficients of the quadratic
10226/// equation are BW+1 bits wide (to avoid truncation when converting from
10227/// the addrec to the equation).
10228static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10229 unsigned BitWidth) {
10230 if (!X)
10231 return std::nullopt;
10232 unsigned W = X->getBitWidth();
10233 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10234 return X->trunc(BitWidth);
10235 return X;
10236}
10237
10238/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10239/// iterations. The values L, M, N are assumed to be signed, and they
10240/// should all have the same bit widths.
10241/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10242/// where BW is the bit width of the addrec's coefficients.
10243/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10244/// returned as such, otherwise the bit width of the returned value may
10245/// be greater than BW.
10246///
10247/// This function returns std::nullopt if
10248/// (a) the addrec coefficients are not constant, or
10249/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10250/// like x^2 = 5, no integer solutions exist, in other cases an integer
10251/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10252static std::optional<APInt>
10253 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10254 APInt A, B, C, M;
10255 unsigned BitWidth;
10256 auto T = GetQuadraticEquation(AddRec);
10257 if (!T)
10258 return std::nullopt;
10259
10260 std::tie(A, B, C, M, BitWidth) = *T;
10261 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10262 std::optional<APInt> X =
10263 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10264 if (!X)
10265 return std::nullopt;
10266
10267 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10268 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10269 if (!V->isZero())
10270 return std::nullopt;
10271
10272 return TruncIfPossible(X, BitWidth);
10273}
10274
10275/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10276/// iterations. The values M, N are assumed to be signed, and they
10277/// should all have the same bit widths.
10278/// Find the least n such that c(n) does not belong to the given range,
10279/// while c(n-1) does.
10280///
10281/// This function returns std::nullopt if
10282/// (a) the addrec coefficients are not constant, or
10283/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10284/// bounds of the range.
10285static std::optional<APInt>
10286 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10287 const ConstantRange &Range, ScalarEvolution &SE) {
10288 assert(AddRec->getOperand(0)->isZero() &&
10289 "Starting value of addrec should be 0");
10290 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10291 << Range << ", addrec " << *AddRec << '\n');
10292 // This case is handled in getNumIterationsInRange. Here we can assume that
10293 // we start in the range.
10294 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10295 "Addrec's initial value should be in range");
10296
10297 APInt A, B, C, M;
10298 unsigned BitWidth;
10299 auto T = GetQuadraticEquation(AddRec);
10300 if (!T)
10301 return std::nullopt;
10302
10303 // Be careful about the return value: there can be two reasons for not
10304 // returning an actual number. First, if no solutions to the equations
10305 // were found, and second, if the solutions don't leave the given range.
10306 // The first case means that the actual solution is "unknown", the second
10307 // means that it's known, but not valid. If the solution is unknown, we
10308 // cannot make any conclusions.
10309 // Return a pair: the optional solution and a flag indicating if the
10310 // solution was found.
10311 auto SolveForBoundary =
10312 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10313 // Solve for signed overflow and unsigned overflow, pick the lower
10314 // solution.
10315 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10316 << Bound << " (before multiplying by " << M << ")\n");
10317 Bound *= M; // The quadratic equation multiplier.
10318
10319 std::optional<APInt> SO;
10320 if (BitWidth > 1) {
10321 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10322 "signed overflow\n");
10323 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10324 }
10325 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10326 "unsigned overflow\n");
10327 std::optional<APInt> UO =
10328 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10329
10330 auto LeavesRange = [&] (const APInt &X) {
10331 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10332 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10333 if (Range.contains(V0->getValue()))
10334 return false;
10335 // X should be at least 1, so X-1 is non-negative.
10336 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10337 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10338 if (Range.contains(V1->getValue()))
10339 return true;
10340 return false;
10341 };
10342
10343 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10344 // can be a solution, but the function failed to find it. We cannot treat it
10345 // as "no solution".
10346 if (!SO || !UO)
10347 return {std::nullopt, false};
10348
10349 // Check the smaller value first to see if it leaves the range.
10350 // At this point, both SO and UO must have values.
10351 std::optional<APInt> Min = MinOptional(SO, UO);
10352 if (LeavesRange(*Min))
10353 return { Min, true };
10354 std::optional<APInt> Max = Min == SO ? UO : SO;
10355 if (LeavesRange(*Max))
10356 return { Max, true };
10357
10358 // Solutions were found, but were eliminated, hence the "true".
10359 return {std::nullopt, true};
10360 };
10361
10362 std::tie(A, B, C, M, BitWidth) = *T;
10363 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10364 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10365 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10366 auto SL = SolveForBoundary(Lower);
10367 auto SU = SolveForBoundary(Upper);
10368 // If any of the solutions was unknown, no meaningful conclusions can
10369 // be made.
10370 if (!SL.second || !SU.second)
10371 return std::nullopt;
10372
10373 // Claim: The correct solution is not some value between Min and Max.
10374 //
10375 // Justification: Assuming that Min and Max are different values, one of
10376 // them is when the first signed overflow happens, the other is when the
10377 // first unsigned overflow happens. Crossing the range boundary is only
10378 // possible via an overflow (treating 0 as a special case of it, modeling
10379 // an overflow as crossing k*2^W for some k).
10380 //
10381 // The interesting case here is when Min was eliminated as an invalid
10382 // solution, but Max was not. The argument is that if there was another
10383 // overflow between Min and Max, it would also have been eliminated if
10384 // it was considered.
10385 //
10386 // For a given boundary, it is possible to have two overflows of the same
10387 // type (signed/unsigned) without having the other type in between: this
10388 // can happen when the vertex of the parabola is between the iterations
10389 // corresponding to the overflows. This is only possible when the two
10390 // overflows cross k*2^W for the same k. In such case, if the second one
10391 // left the range (and was the first one to do so), the first overflow
10392 // would have to enter the range, which would mean that either we had left
10393 // the range before or that we started outside of it. Both of these cases
10394 // are contradictions.
10395 //
10396 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10397 // solution is not some value between the Max for this boundary and the
10398 // Min of the other boundary.
10399 //
10400 // Justification: Assume that we had such Max_A and Min_B corresponding
10401 // to range boundaries A and B and such that Max_A < Min_B. If there was
10402 // a solution between Max_A and Min_B, it would have to be caused by an
10403 // overflow corresponding to either A or B. It cannot correspond to B,
10404 // since Min_B is the first occurrence of such an overflow. If it
10405 // corresponded to A, it would have to be either a signed or an unsigned
10406 // overflow that is larger than both eliminated overflows for A. But
10407 // between the eliminated overflows and this overflow, the values would
10408 // cover the entire value space, thus crossing the other boundary, which
10409 // is a contradiction.
10410
10411 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10412}
10413
10414ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10415 const Loop *L,
10416 bool ControlsOnlyExit,
10417 bool AllowPredicates) {
10418
10419 // This is only used for loops with a "x != y" exit test. The exit condition
10420 // is now expressed as a single expression, V = x-y. So the exit test is
10421 // effectively V != 0. We know and take advantage of the fact that this
10422 // expression is only used in a comparison-with-zero context.
10423
10424 SmallVector<const SCEVPredicate *> Predicates;
10425 // If the value is a constant
10426 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10427 // If the value is already zero, the branch will execute zero times.
10428 if (C->getValue()->isZero()) return C;
10429 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10430 }
10431
10432 const SCEVAddRecExpr *AddRec =
10433 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10434
10435 if (!AddRec && AllowPredicates)
10436 // Try to make this an AddRec using runtime tests, in the first X
10437 // iterations of this loop, where X is the SCEV expression found by the
10438 // algorithm below.
10439 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10440
10441 if (!AddRec || AddRec->getLoop() != L)
10442 return getCouldNotCompute();
10443
10444 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10445 // the quadratic equation to solve it.
10446 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10447 // We can only use this value if the chrec ends up with an exact zero
10448 // value at this index. When solving for "X*X != 5", for example, we
10449 // should not accept a root of 2.
10450 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10451 const auto *R = cast<SCEVConstant>(getConstant(*S));
10452 return ExitLimit(R, R, R, false, Predicates);
10453 }
10454 return getCouldNotCompute();
10455 }
10456
10457 // Otherwise we can only handle this if it is affine.
10458 if (!AddRec->isAffine())
10459 return getCouldNotCompute();
10460
10461 // If this is an affine expression, the execution count of this branch is
10462 // the minimum unsigned root of the following equation:
10463 //
10464 // Start + Step*N = 0 (mod 2^BW)
10465 //
10466 // equivalent to:
10467 //
10468 // Step*N = -Start (mod 2^BW)
10469 //
10470 // where BW is the common bit width of Start and Step.
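// For example (illustrative), for the i8 addrec {10,+,-2} the equation is
// (-2)*N == -10 (mod 256), i.e. 2*N == 10 (mod 256), whose minimum unsigned
// solution is N = 5.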
10471
10472 // Get the initial value for the loop.
10473 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10474 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10475 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10476
10477 if (!isLoopInvariant(Step, L))
10478 return getCouldNotCompute();
10479
10480 LoopGuards Guards = LoopGuards::collect(L, *this);
10481 // Specialize step for this loop so we get context sensitive facts below.
10482 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10483
10484 // For positive steps (counting up until unsigned overflow):
10485 // N = -Start/Step (as unsigned)
10486 // For negative steps (counting down to zero):
10487 // N = Start/-Step
10488 // First compute the unsigned distance from zero in the direction of Step.
10489 bool CountDown = isKnownNegative(StepWLG);
10490 if (!CountDown && !isKnownNonNegative(StepWLG))
10491 return getCouldNotCompute();
10492
10493 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10494 // Handle unitary steps, which cannot wraparound.
10495 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10496 // N = Distance (as unsigned)
10497 if (StepC &&
10498 (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
10499 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10500 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10501
10502 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10503 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10504 // case, and see if we can improve the bound.
10505 //
10506 // Explicitly handling this here is necessary because getUnsignedRange
10507 // isn't context-sensitive; it doesn't know that we only care about the
10508 // range inside the loop.
10509 const SCEV *Zero = getZero(Distance->getType());
10510 const SCEV *One = getOne(Distance->getType());
10511 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10512 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10513 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10514 // as "unsigned_max(Distance + 1) - 1".
10515 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10516 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10517 }
10518 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10519 Predicates);
10520 }
10521
10522 // If the condition controls loop exit (the loop exits only if the expression
10523 // is true) and the addition is no-wrap we can use unsigned divide to
10524 // compute the backedge count. In this case, the step may not divide the
10525 // distance, but we don't care because if the condition is "missed" the loop
10526 // will have undefined behavior due to wrapping.
10527 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10528 loopHasNoAbnormalExits(AddRec->getLoop())) {
10529
10530 // If the stride is zero, the loop must be infinite. In C++, most loops
10531 // are finite by assumption, in which case the step being zero implies
10532 // UB must execute if the loop is entered.
10533 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10534 return getCouldNotCompute();
10535
10536 const SCEV *Exact =
10537 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10538 const SCEV *ConstantMax = getCouldNotCompute();
10539 if (Exact != getCouldNotCompute()) {
10540 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10541 ConstantMax =
10542 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10543 }
10544 const SCEV *SymbolicMax =
10545 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10546 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10547 }
10548
10549 // Solve the general equation.
10550 if (!StepC || StepC->getValue()->isZero())
10551 return getCouldNotCompute();
10552 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10553 getNegativeSCEV(Start), *this);
10554
10555 const SCEV *M = E;
10556 if (E != getCouldNotCompute()) {
10557 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10558 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10559 }
10560 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10561 return ExitLimit(E, M, S, false, Predicates);
10562}
10563
10564 ScalarEvolution::ExitLimit
10565 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10566 // Loops that look like: while (X == 0) are very strange indeed. We don't
10567 // handle them yet except for the trivial case. This could be expanded in the
10568 // future as needed.
10569
10570 // If the value is a constant, check to see if it is known to be non-zero
10571 // already. If so, the backedge will execute zero times.
10572 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10573 if (!C->getValue()->isZero())
10574 return getZero(C->getType());
10575 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10576 }
10577
10578 // We could implement others, but I really doubt anyone writes loops like
10579 // this, and if they did, they would already be constant folded.
10580 return getCouldNotCompute();
10581}
10582
10583std::pair<const BasicBlock *, const BasicBlock *>
10584ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10585 const {
10586 // If the block has a unique predecessor, then there is no path from the
10587 // predecessor to the block that does not go through the direct edge
10588 // from the predecessor to the block.
10589 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10590 return {Pred, BB};
10591
10592 // A loop's header is defined to be a block that dominates the loop.
10593 // If the header has a unique predecessor outside the loop, it must be
10594 // a block that has exactly one successor that can reach the loop.
10595 if (const Loop *L = LI.getLoopFor(BB))
10596 return {L->getLoopPredecessor(), L->getHeader()};
10597
10598 return {nullptr, nullptr};
10599}
10600
10601/// SCEV structural equivalence is usually sufficient for testing whether two
10602/// expressions are equal, however for the purposes of looking for a condition
10603/// guarding a loop, it can be useful to be a little more general, since a
10604/// front-end may have replicated the controlling expression.
10605static bool HasSameValue(const SCEV *A, const SCEV *B) {
10606 // Quick check to see if they are the same SCEV.
10607 if (A == B) return true;
10608
10609 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10610 // Not all instructions that are "identical" compute the same value. For
10611 // instance, two distinct alloca instructions allocating the same type are
10612 // identical and do not read memory, but they compute distinct values.
10613 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10614 };
10615
10616 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10617 // two different instructions with the same value. Check for this case.
10618 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10619 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10620 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10621 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10622 if (ComputesEqualValues(AI, BI))
10623 return true;
10624
10625 // Otherwise assume they may have a different value.
10626 return false;
10627}
10628
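/// Match S against the pattern ((-1) * RHS) + LHS, which is how SCEV
/// represents the subtraction LHS - RHS; return the two operands on success.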
10629static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10630 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10631 if (!Add || Add->getNumOperands() != 2)
10632 return false;
10633 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10634 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10635 LHS = Add->getOperand(1);
10636 RHS = ME->getOperand(1);
10637 return true;
10638 }
10639 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10640 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10641 LHS = Add->getOperand(0);
10642 RHS = ME->getOperand(1);
10643 return true;
10644 }
10645 return false;
10646}
10647
10648 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10649 const SCEV *&LHS, const SCEV *&RHS,
10650 unsigned Depth) {
10651 bool Changed = false;
10652 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10653 // '0 != 0'.
10654 auto TrivialCase = [&](bool TriviallyTrue) {
10655 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10656 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10657 return true;
10658 };
10659 // If we hit the max recursion limit bail out.
10660 if (Depth >= 3)
10661 return false;
10662
10663 // Canonicalize a constant to the right side.
10664 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10665 // Check for both operands constant.
10666 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10667 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10668 return TrivialCase(false);
10669 return TrivialCase(true);
10670 }
10671 // Otherwise swap the operands to put the constant on the right.
10672 std::swap(LHS, RHS);
10673 Pred = ICmpInst::getSwappedPredicate(Pred);
10674 Changed = true;
10675 }
10676
10677 // If we're comparing an addrec with a value which is loop-invariant in the
10678 // addrec's loop, put the addrec on the left. Also make a dominance check,
10679 // as both operands could be addrecs loop-invariant in each other's loop.
10680 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10681 const Loop *L = AR->getLoop();
10682 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10683 std::swap(LHS, RHS);
10684 Pred = ICmpInst::getSwappedPredicate(Pred);
10685 Changed = true;
10686 }
10687 }
10688
10689 // If there's a constant operand, canonicalize comparisons with boundary
10690 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10691 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10692 const APInt &RA = RC->getAPInt();
10693
10694 bool SimplifiedByConstantRange = false;
10695
10696 if (!ICmpInst::isEquality(Pred)) {
10697 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10698 if (ExactCR.isFullSet())
10699 return TrivialCase(true);
10700 if (ExactCR.isEmptySet())
10701 return TrivialCase(false);
10702
10703 APInt NewRHS;
10704 CmpInst::Predicate NewPred;
10705 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10706 ICmpInst::isEquality(NewPred)) {
10707 // We were able to convert an inequality to an equality.
10708 Pred = NewPred;
10709 RHS = getConstant(NewRHS);
10710 Changed = SimplifiedByConstantRange = true;
10711 }
10712 }
10713
10714 if (!SimplifiedByConstantRange) {
10715 switch (Pred) {
10716 default:
10717 break;
10718 case ICmpInst::ICMP_EQ:
10719 case ICmpInst::ICMP_NE:
10720 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10721 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10722 Changed = true;
10723 break;
10724
10725 // The "Should have been caught earlier!" messages refer to the fact
10726 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10727 // should have fired on the corresponding cases, and canonicalized the
10728 // check to a trivial case.
10729
10730 case ICmpInst::ICMP_UGE:
10731 assert(!RA.isMinValue() && "Should have been caught earlier!");
10732 Pred = ICmpInst::ICMP_UGT;
10733 RHS = getConstant(RA - 1);
10734 Changed = true;
10735 break;
10736 case ICmpInst::ICMP_ULE:
10737 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10738 Pred = ICmpInst::ICMP_ULT;
10739 RHS = getConstant(RA + 1);
10740 Changed = true;
10741 break;
10742 case ICmpInst::ICMP_SGE:
10743 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10744 Pred = ICmpInst::ICMP_SGT;
10745 RHS = getConstant(RA - 1);
10746 Changed = true;
10747 break;
10748 case ICmpInst::ICMP_SLE:
10749 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10750 Pred = ICmpInst::ICMP_SLT;
10751 RHS = getConstant(RA + 1);
10752 Changed = true;
10753 break;
10754 }
10755 }
10756 }
10757
10758 // Check for obvious equality.
10759 if (HasSameValue(LHS, RHS)) {
10760 if (ICmpInst::isTrueWhenEqual(Pred))
10761 return TrivialCase(true);
10762 if (ICmpInst::isFalseWhenEqual(Pred))
10763 return TrivialCase(false);
10764 }
10765
10766 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10767 // adding or subtracting 1 from one of the operands.
10768 switch (Pred) {
10769 case ICmpInst::ICMP_SLE:
10770 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10771 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10772 SCEV::FlagNSW);
10773 Pred = ICmpInst::ICMP_SLT;
10774 Changed = true;
10775 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10776 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10777 SCEV::FlagNSW);
10778 Pred = ICmpInst::ICMP_SLT;
10779 Changed = true;
10780 }
10781 break;
10782 case ICmpInst::ICMP_SGE:
10783 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10784 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10785 SCEV::FlagNSW);
10786 Pred = ICmpInst::ICMP_SGT;
10787 Changed = true;
10788 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10789 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10790 SCEV::FlagNSW);
10791 Pred = ICmpInst::ICMP_SGT;
10792 Changed = true;
10793 }
10794 break;
10795 case ICmpInst::ICMP_ULE:
10796 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10797 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10798 SCEV::FlagNUW);
10799 Pred = ICmpInst::ICMP_ULT;
10800 Changed = true;
10801 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10802 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10803 Pred = ICmpInst::ICMP_ULT;
10804 Changed = true;
10805 }
10806 break;
10807 case ICmpInst::ICMP_UGE:
10808 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10809 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10810 Pred = ICmpInst::ICMP_UGT;
10811 Changed = true;
10812 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10813 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10814 SCEV::FlagNUW);
10815 Pred = ICmpInst::ICMP_UGT;
10816 Changed = true;
10817 }
10818 break;
10819 default:
10820 break;
10821 }
10822
10823 // TODO: More simplifications are possible here.
10824
10825 // Recursively simplify until we either hit a recursion limit or nothing
10826 // changes.
10827 if (Changed)
10828 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10829
10830 return Changed;
10831}
10832
10833 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10834 return getSignedRangeMax(S).isNegative();
10835}
10836
10839}
10840
10841 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10842 return !getSignedRangeMin(S).isNegative();
10843}
10844
10847}
10848
10849 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10850 // Push the query down through a sign extension for cases where the
10851 // unsigned range alone is not sufficient.
10852 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10853 return isKnownNonZero(SExt->getOperand(0));
10854 return getUnsignedRangeMin(S) != 0;
10855}
10856
10857 bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
10858 bool OrNegative) {
10859 auto NonRecursive = [this, OrNegative](const SCEV *S) {
10860 if (auto *C = dyn_cast<SCEVConstant>(S))
10861 return C->getAPInt().isPowerOf2() ||
10862 (OrNegative && C->getAPInt().isNegatedPowerOf2());
10863
10864 // The vscale_range indicates vscale is a power-of-two.
10865 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10866 };
10867
10868 if (NonRecursive(S))
10869 return true;
10870
10871 auto *Mul = dyn_cast<SCEVMulExpr>(S);
10872 if (!Mul)
10873 return false;
10874 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10875}
10876
10877std::pair<const SCEV *, const SCEV *>
10878 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10879 // Compute SCEV on entry of loop L.
10880 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10881 if (Start == getCouldNotCompute())
10882 return { Start, Start };
10883 // Compute post increment SCEV for loop L.
10884 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10885 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10886 return { Start, PostInc };
10887}
10888
10889 bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10890 const SCEV *LHS, const SCEV *RHS) {
10891 // First collect all loops.
10892 SmallPtrSet<const Loop *, 8> LoopsUsed;
10893 getUsedLoops(LHS, LoopsUsed);
10894 getUsedLoops(RHS, LoopsUsed);
10895
10896 if (LoopsUsed.empty())
10897 return false;
10898
10899 // Domination relationship must be a linear order on collected loops.
10900#ifndef NDEBUG
10901 for (const auto *L1 : LoopsUsed)
10902 for (const auto *L2 : LoopsUsed)
10903 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10904 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10905 "Domination relationship is not a linear order");
10906#endif
10907
10908 const Loop *MDL =
10909 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10910 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10911 });
10912
10913 // Get init and post increment value for LHS.
10914 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10915 // If LHS contains an unknown non-invariant SCEV, bail out.
10916 if (SplitLHS.first == getCouldNotCompute())
10917 return false;
10918 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10919 // Get init and post increment value for RHS.
10920 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10921 // If RHS contains an unknown non-invariant SCEV, bail out.
10922 if (SplitRHS.first == getCouldNotCompute())
10923 return false;
10924 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10925 // It is possible that init SCEV contains an invariant load but it does
10926 // not dominate MDL and is not available at MDL loop entry, so we should
10927 // check it here.
10928 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10929 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10930 return false;
10931
10932 // The backedge guard check appears to be faster than the entry check, so
10933 // in some cases it can speed up the whole estimation by short-circuiting.
10934 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10935 SplitRHS.second) &&
10936 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10937}
10938
10939 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10940 const SCEV *LHS, const SCEV *RHS) {
10941 // Canonicalize the inputs first.
10942 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10943
10944 if (isKnownViaInduction(Pred, LHS, RHS))
10945 return true;
10946
10947 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10948 return true;
10949
10950 // Otherwise see what can be done with some simple reasoning.
10951 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10952}
10953
10954 std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10955 const SCEV *LHS,
10956 const SCEV *RHS) {
10957 if (isKnownPredicate(Pred, LHS, RHS))
10958 return true;
10959 if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10960 return false;
10961 return std::nullopt;
10962}
10963
10964 bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10965 const SCEV *LHS, const SCEV *RHS,
10966 const Instruction *CtxI) {
10967 // TODO: Analyze guards and assumes from Context's block.
10968 return isKnownPredicate(Pred, LHS, RHS) ||
10969 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10970}
10971
10972std::optional<bool>
10973 ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10974 const SCEV *RHS, const Instruction *CtxI) {
10975 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10976 if (KnownWithoutContext)
10977 return KnownWithoutContext;
10978
10979 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10980 return true;
10981 else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10982 ICmpInst::getInversePredicate(Pred),
10983 LHS, RHS))
10984 return false;
10985 return std::nullopt;
10986}
10987
10988 bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10989 const SCEVAddRecExpr *LHS,
10990 const SCEV *RHS) {
10991 const Loop *L = LHS->getLoop();
10992 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10993 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10994}
10995
10996std::optional<ScalarEvolution::MonotonicPredicateType>
10997 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10998 ICmpInst::Predicate Pred) {
10999 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11000
11001#ifndef NDEBUG
11002 // Verify an invariant: inverting the predicate should turn a monotonically
11003 // increasing change to a monotonically decreasing one, and vice versa.
11004 if (Result) {
11005 auto ResultSwapped =
11006 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11007
11008 assert(*ResultSwapped != *Result &&
11009 "monotonicity should flip as we flip the predicate");
11010 }
11011#endif
11012
11013 return Result;
11014}
11015
11016std::optional<ScalarEvolution::MonotonicPredicateType>
11017ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11018 ICmpInst::Predicate Pred) {
11019 // A zero step value for LHS means the induction variable is essentially a
11020 // loop invariant value. We don't really depend on the predicate actually
11021 // flipping from false to true (for increasing predicates, and the other way
11022 // around for decreasing predicates), all we care about is that *if* the
11023 // predicate changes then it only changes from false to true.
11024 //
11025 // A zero step value in itself is not very useful, but there may be places
11026 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11027 // as general as possible.
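// Illustrative example: for AR = {0,+,1}<nuw>, the predicate "AR u< N" can
// only go from true to false as the loop iterates (monotonically decreasing),
// while "AR u> N" can only go from false to true (monotonically increasing).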
11028
11029 // Only handle LE/LT/GE/GT predicates.
11030 if (!ICmpInst::isRelational(Pred))
11031 return std::nullopt;
11032
11033 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11034 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11035 "Should be greater or less!");
11036
11037 // Check that AR does not wrap.
11038 if (ICmpInst::isUnsigned(Pred)) {
11039 if (!LHS->hasNoUnsignedWrap())
11040 return std::nullopt;
11041 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11042 }
11043 assert(ICmpInst::isSigned(Pred) &&
11044 "Relational predicate is either signed or unsigned!");
11045 if (!LHS->hasNoSignedWrap())
11046 return std::nullopt;
11047
11048 const SCEV *Step = LHS->getStepRecurrence(*this);
11049
11050 if (isKnownNonNegative(Step))
11051 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11052
11053 if (isKnownNonPositive(Step))
11054 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11055
11056 return std::nullopt;
11057}
11058
11059std::optional<ScalarEvolution::LoopInvariantPredicate>
11060 ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
11061 const SCEV *LHS, const SCEV *RHS,
11062 const Loop *L,
11063 const Instruction *CtxI) {
11064 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11065 if (!isLoopInvariant(RHS, L)) {
11066 if (!isLoopInvariant(LHS, L))
11067 return std::nullopt;
11068
11069 std::swap(LHS, RHS);
11070 Pred = ICmpInst::getSwappedPredicate(Pred);
11071 }
11072
11073 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11074 if (!ArLHS || ArLHS->getLoop() != L)
11075 return std::nullopt;
11076
11077 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11078 if (!MonotonicType)
11079 return std::nullopt;
11080 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11081 // true as the loop iterates, and the backedge is control dependent on
11082 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11083 //
11084 // * if the predicate was false in the first iteration then the predicate
11085 // is never evaluated again, since the loop exits without taking the
11086 // backedge.
11087 // * if the predicate was true in the first iteration then it will
11088 // continue to be true for all future iterations since it is
11089 // monotonically increasing.
11090 //
11091 // For both the above possibilities, we can replace the loop varying
11092 // predicate with its value on the first iteration of the loop (which is
11093 // loop invariant).
11094 //
11095 // A similar reasoning applies for a monotonically decreasing predicate, by
11096 // replacing true with false and false with true in the above two bullets.
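// Illustrative instance of the increasing case: with ArLHS = {S,+,1}<nsw> and
// Pred = s>, the predicate "ArLHS s> RHS" can only switch from false to true,
// so if the backedge is only taken while it holds, every evaluation inside the
// loop equals the loop-invariant "S s> RHS".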
11097 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11098 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11099
11100 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11101 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11102 RHS);
11103
11104 if (!CtxI)
11105 return std::nullopt;
11106 // Try to prove via context.
11107 // TODO: Support other cases.
11108 switch (Pred) {
11109 default:
11110 break;
11111 case ICmpInst::ICMP_ULE:
11112 case ICmpInst::ICMP_ULT: {
11113 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11114 // Given preconditions
11115 // (1) ArLHS does not cross the border of positive and negative parts of
11116 // range because of:
11117 // - Positive step; (TODO: lift this limitation)
11118 // - nuw - does not cross zero boundary;
11119 // - nsw - does not cross SINT_MAX boundary;
11120 // (2) ArLHS <s RHS
11121 // (3) RHS >=s 0
11122 // we can replace the loop variant ArLHS <u RHS condition with loop
11123 // invariant Start(ArLHS) <u RHS.
11124 //
11125 // Because of (1) there are two options:
11126 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11127 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11128 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11129 // Because of (2) ArLHS <u RHS is trivially true.
11130 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11131 // We can strengthen this to Start(ArLHS) <u RHS.
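// Illustrative example: for ArLHS = {0,+,1}<nuw><nsw> and RHS = 100, if
// "ArLHS s< 100" is known at CtxI and 100 >=s 0, the loop-varying
// "ArLHS u< 100" can be replaced by the invariant "0 u< 100".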
11132 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11133 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11134 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11135 isKnownNonNegative(RHS) &&
11136 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11137 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11138 RHS);
11139 }
11140 }
11141
11142 return std::nullopt;
11143}
11144
11145std::optional<ScalarEvolution::LoopInvariantPredicate>
11146 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11147 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11148 const Instruction *CtxI, const SCEV *MaxIter) {
11149 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11150 Pred, LHS, RHS, L, CtxI, MaxIter))
11151 return LIP;
11152 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11153 // Number of iterations expressed as UMIN isn't always great for expressing
11154 // the value on the last iteration. If the straightforward approach didn't
11155 // work, try the following trick: if a predicate is invariant for X, it
11156 // is also invariant for umin(X, ...). So try to find something that works
11157 // among subexpressions of MaxIter expressed as umin.
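// E.g. (illustrative): for MaxIter = umin(%n, %m) it is enough to prove the
// predicate invariant for the bound %n alone, since the actual trip count
// never exceeds %n.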
11158 for (auto *Op : UMin->operands())
11159 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11160 Pred, LHS, RHS, L, CtxI, Op))
11161 return LIP;
11162 return std::nullopt;
11163}
11164
11165std::optional<ScalarEvolution::LoopInvariantPredicate>
11166 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11167 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11168 const Instruction *CtxI, const SCEV *MaxIter) {
11169 // Try to prove the following set of facts:
11170 // - The predicate is monotonic in the iteration space.
11171 // - If the check does not fail on the 1st iteration:
11172 // - No overflow will happen during first MaxIter iterations;
11173 // - It will not fail on the MaxIter'th iteration.
11174 // If the check does fail on the 1st iteration, we leave the loop and no
11175 // other checks matter.
11176
11177 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11178 if (!isLoopInvariant(RHS, L)) {
11179 if (!isLoopInvariant(LHS, L))
11180 return std::nullopt;
11181
11182 std::swap(LHS, RHS);
11183 Pred = ICmpInst::getSwappedPredicate(Pred);
11184 }
11185
11186 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11187 if (!AR || AR->getLoop() != L)
11188 return std::nullopt;
11189
11190 // The predicate must be relational (i.e. <, <=, >=, >).
11191 if (!ICmpInst::isRelational(Pred))
11192 return std::nullopt;
11193
11194 // TODO: Support steps other than +/- 1.
11195 const SCEV *Step = AR->getStepRecurrence(*this);
11196 auto *One = getOne(Step->getType());
11197 auto *MinusOne = getNegativeSCEV(One);
11198 if (Step != One && Step != MinusOne)
11199 return std::nullopt;
11200
11201 // Type mismatch here means that MaxIter is potentially larger than max
11202 // unsigned value in the start type, which means we cannot prove no-wrap for the
11203 // indvar.
11204 if (AR->getType() != MaxIter->getType())
11205 return std::nullopt;
11206
11207 // Value of IV on suggested last iteration.
11208 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11209 // Does it still meet the requirement?
11210 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11211 return std::nullopt;
11212 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11213 // not exceed max unsigned value of this type), this effectively proves
11214 // that there is no wrap during the iteration. To prove that there is no
11215 // signed/unsigned wrap, we need to check that
11216 // Start <= Last for step = 1 or Start >= Last for step = -1.
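// E.g. (illustrative): for {Start,+,1} and MaxIter = 10, Last = Start + 10;
// proving Start <= Start + 10 rules out wrapping of the IV during the first
// 10 iterations.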
11217 ICmpInst::Predicate NoOverflowPred =
11218 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11219 if (Step == MinusOne)
11220 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11221 const SCEV *Start = AR->getStart();
11222 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11223 return std::nullopt;
11224
11225 // Everything is fine.
11226 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11227}
11228
11229bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11230 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11231 if (HasSameValue(LHS, RHS))
11232 return ICmpInst::isTrueWhenEqual(Pred);
11233
11234 // This code is split out from isKnownPredicate because it is called from
11235 // within isLoopEntryGuardedByCond.
11236
11237 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11238 const ConstantRange &RangeRHS) {
11239 return RangeLHS.icmp(Pred, RangeRHS);
11240 };
11241
11242 // The check at the top of the function catches the case where the values are
11243 // known to be equal.
11244 if (Pred == CmpInst::ICMP_EQ)
11245 return false;
11246
11247 if (Pred == CmpInst::ICMP_NE) {
11248 auto SL = getSignedRange(LHS);
11249 auto SR = getSignedRange(RHS);
11250 if (CheckRanges(SL, SR))
11251 return true;
11252 auto UL = getUnsignedRange(LHS);
11253 auto UR = getUnsignedRange(RHS);
11254 if (CheckRanges(UL, UR))
11255 return true;
11256 auto *Diff = getMinusSCEV(LHS, RHS);
11257 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11258 }
11259
11260 if (CmpInst::isSigned(Pred)) {
11261 auto SL = getSignedRange(LHS);
11262 auto SR = getSignedRange(RHS);
11263 return CheckRanges(SL, SR);
11264 }
11265
11266 auto UL = getUnsignedRange(LHS);
11267 auto UR = getUnsignedRange(RHS);
11268 return CheckRanges(UL, UR);
11269}
11270
11271bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11272 const SCEV *LHS,
11273 const SCEV *RHS) {
11274 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11275 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11276 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11277 // OutC1 and OutC2.
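// E.g. (illustrative): X = (%a + 8)<nuw> and Y = (%a + 16)<nuw> both match
// with A = %a, C1 = 8, C2 = 16, so "X u<= Y" follows from C1 u<= C2 below.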
11278 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11279 APInt &OutC1, APInt &OutC2,
11280 SCEV::NoWrapFlags ExpectedFlags) {
11281 const SCEV *XNonConstOp, *XConstOp;
11282 const SCEV *YNonConstOp, *YConstOp;
11283 SCEV::NoWrapFlags XFlagsPresent;
11284 SCEV::NoWrapFlags YFlagsPresent;
11285
11286 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11287 XConstOp = getZero(X->getType());
11288 XNonConstOp = X;
11289 XFlagsPresent = ExpectedFlags;
11290 }
11291 if (!isa<SCEVConstant>(XConstOp) ||
11292 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11293 return false;
11294
11295 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11296 YConstOp = getZero(Y->getType());
11297 YNonConstOp = Y;
11298 YFlagsPresent = ExpectedFlags;
11299 }
11300
11301 if (!isa<SCEVConstant>(YConstOp) ||
11302 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11303 return false;
11304
11305 if (YNonConstOp != XNonConstOp)
11306 return false;
11307
11308 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11309 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11310
11311 return true;
11312 };
11313
11314 APInt C1;
11315 APInt C2;
11316
11317 switch (Pred) {
11318 default:
11319 break;
11320
11321 case ICmpInst::ICMP_SGE:
11322 std::swap(LHS, RHS);
11323 [[fallthrough]];
11324 case ICmpInst::ICMP_SLE:
11325 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11326 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11327 return true;
11328
11329 break;
11330
11331 case ICmpInst::ICMP_SGT:
11332 std::swap(LHS, RHS);
11333 [[fallthrough]];
11334 case ICmpInst::ICMP_SLT:
11335 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11336 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11337 return true;
11338
11339 break;
11340
11341 case ICmpInst::ICMP_UGE:
11342 std::swap(LHS, RHS);
11343 [[fallthrough]];
11344 case ICmpInst::ICMP_ULE:
11345 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11346 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11347 return true;
11348
11349 break;
11350
11351 case ICmpInst::ICMP_UGT:
11352 std::swap(LHS, RHS);
11353 [[fallthrough]];
11354 case ICmpInst::ICMP_ULT:
11355 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11356 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11357 return true;
11358 break;
11359 }
11360
11361 return false;
11362}
11363
11364bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11365 const SCEV *LHS,
11366 const SCEV *RHS) {
11367 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11368 return false;
11369
11370 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11371 // the stack can result in exponential time complexity.
11372 SaveAndRestore Restore(ProvingSplitPredicate, true);
11373
11374 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11375 //
11376 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11377 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11378 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11379 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11380 // use isKnownPredicate later if needed.
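// E.g. (illustrative): to prove "I u< L" with L known non-negative, it is
// enough to prove "I s>= 0" and "I s< L"; both values then lie in the
// non-negative range, where signed and unsigned order agree.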
11381 return isKnownNonNegative(RHS) &&
11382 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11383 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11384}
11385
11386bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11387 ICmpInst::Predicate Pred,
11388 const SCEV *LHS, const SCEV *RHS) {
11389 // No need to even try if we know the module has no guards.
11390 if (!HasGuards)
11391 return false;
11392
11393 return any_of(*BB, [&](const Instruction &I) {
11394 using namespace llvm::PatternMatch;
11395
11396 Value *Condition;
11397 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11398 m_Value(Condition))) &&
11399 isImpliedCond(Pred, LHS, RHS, Condition, false);
11400 });
11401}
11402
11403/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11404/// protected by a conditional between LHS and RHS. This is used to
11405/// eliminate casts.
11406bool
11407 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11408 ICmpInst::Predicate Pred,
11409 const SCEV *LHS, const SCEV *RHS) {
11410 // Interpret a null as meaning no loop, where there is obviously no guard
11411 // (interprocedural conditions notwithstanding). Do not bother about
11412 // unreachable loops.
11413 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11414 return true;
11415
11416 if (VerifyIR)
11417 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11418 "This cannot be done on broken IR!");
11419
11420
11421 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11422 return true;
11423
11424 BasicBlock *Latch = L->getLoopLatch();
11425 if (!Latch)
11426 return false;
11427
11428 BranchInst *LoopContinuePredicate =
11429 dyn_cast<BranchInst>(Latch->getTerminator());
11430 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11431 isImpliedCond(Pred, LHS, RHS,
11432 LoopContinuePredicate->getCondition(),
11433 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11434 return true;
11435
11436 // We don't want more than one activation of the following loops on the stack
11437 // -- that can lead to O(n!) time complexity.
11438 if (WalkingBEDominatingConds)
11439 return false;
11440
11441 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11442
11443 // See if we can exploit a trip count to prove the predicate.
11444 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11445 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11446 if (LatchBECount != getCouldNotCompute()) {
11447 // We know that Latch branches back to the loop header exactly
11448 // LatchBECount times. This means the backedge condition at Latch is
11449 // equivalent to "{0,+,1} u< LatchBECount".
11450 Type *Ty = LatchBECount->getType();
11451 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11452 const SCEV *LoopCounter =
11453 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11454 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11455 LatchBECount))
11456 return true;
11457 }
11458
11459 // Check conditions due to any @llvm.assume intrinsics.
11460 for (auto &AssumeVH : AC.assumptions()) {
11461 if (!AssumeVH)
11462 continue;
11463 auto *CI = cast<CallInst>(AssumeVH);
11464 if (!DT.dominates(CI, Latch->getTerminator()))
11465 continue;
11466
11467 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11468 return true;
11469 }
11470
11471 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11472 return true;
11473
11474 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11475 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11476 assert(DTN && "should reach the loop header before reaching the root!");
11477
11478 BasicBlock *BB = DTN->getBlock();
11479 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11480 return true;
11481
11482 BasicBlock *PBB = BB->getSinglePredecessor();
11483 if (!PBB)
11484 continue;
11485
11486 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11487 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11488 continue;
11489
11490 Value *Condition = ContinuePredicate->getCondition();
11491
11492 // If we have an edge `E` within the loop body that dominates the only
11493 // latch, the condition guarding `E` also guards the backedge. This
11494 // reasoning works only for loops with a single latch.
11495
11496 BasicBlockEdge DominatingEdge(PBB, BB);
11497 if (DominatingEdge.isSingleEdge()) {
11498 // We're constructively (and conservatively) enumerating edges within the
11499 // loop body that dominate the latch. The dominator tree better agree
11500 // with us on this:
11501 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11502
11503 if (isImpliedCond(Pred, LHS, RHS, Condition,
11504 BB != ContinuePredicate->getSuccessor(0)))
11505 return true;
11506 }
11507 }
11508
11509 return false;
11510}
11511
11512 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11513 ICmpInst::Predicate Pred,
11514 const SCEV *LHS,
11515 const SCEV *RHS) {
11516 // Do not bother proving facts for unreachable code.
11517 if (!DT.isReachableFromEntry(BB))
11518 return true;
11519 if (VerifyIR)
11520 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11521 "This cannot be done on broken IR!");
11522
11523 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11524 // the facts (a >= b && a != b) separately. A typical situation is when the
11525 // non-strict comparison is known from ranges and non-equality is known from
11526 // dominating predicates. If we are proving strict comparison, we always try
11527 // to prove non-equality and non-strict comparison separately.
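// E.g. (illustrative): "a u> b" can be proved by combining "a u>= b" (say,
// known from constant ranges) with "a != b" (say, known from a dominating
// condition).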
11528 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11529 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11530 bool ProvedNonStrictComparison = false;
11531 bool ProvedNonEquality = false;
11532
11533 auto SplitAndProve =
11534 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11535 if (!ProvedNonStrictComparison)
11536 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11537 if (!ProvedNonEquality)
11538 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11539 if (ProvedNonStrictComparison && ProvedNonEquality)
11540 return true;
11541 return false;
11542 };
11543
11544 if (ProvingStrictComparison) {
11545 auto ProofFn = [&](ICmpInst::Predicate P) {
11546 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11547 };
11548 if (SplitAndProve(ProofFn))
11549 return true;
11550 }
11551
11552 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11553 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11554 const Instruction *CtxI = &BB->front();
11555 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11556 return true;
11557 if (ProvingStrictComparison) {
11558 auto ProofFn = [&](ICmpInst::Predicate P) {
11559 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11560 };
11561 if (SplitAndProve(ProofFn))
11562 return true;
11563 }
11564 return false;
11565 };
11566
11567 // Starting at the block's predecessor, climb up the predecessor chain, as
11568 // long as we can find predecessors that have unique successors leading to
11569 // the original block.
11570 const Loop *ContainingLoop = LI.getLoopFor(BB);
11571 const BasicBlock *PredBB;
11572 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11573 PredBB = ContainingLoop->getLoopPredecessor();
11574 else
11575 PredBB = BB->getSinglePredecessor();
11576 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11577 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11578 const BranchInst *BlockEntryPredicate =
11579 dyn_cast<BranchInst>(Pair.first->getTerminator());
11580 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11581 continue;
11582
11583 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11584 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11585 return true;
11586 }
11587
11588 // Check conditions due to any @llvm.assume intrinsics.
11589 for (auto &AssumeVH : AC.assumptions()) {
11590 if (!AssumeVH)
11591 continue;
11592 auto *CI = cast<CallInst>(AssumeVH);
11593 if (!DT.dominates(CI, BB))
11594 continue;
11595
11596 if (ProveViaCond(CI->getArgOperand(0), false))
11597 return true;
11598 }
11599
11600 // Check conditions due to any @llvm.experimental.guard intrinsics.
11601 auto *GuardDecl = F.getParent()->getFunction(
11602 Intrinsic::getName(Intrinsic::experimental_guard));
11603 if (GuardDecl)
11604 for (const auto *GU : GuardDecl->users())
11605 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11606 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11607 if (ProveViaCond(Guard->getArgOperand(0), false))
11608 return true;
11609 return false;
11610}
11611
11612 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11613 ICmpInst::Predicate Pred,
11614 const SCEV *LHS,
11615 const SCEV *RHS) {
11616 // Interpret a null as meaning no loop, where there is obviously no guard
11617 // (interprocedural conditions notwithstanding).
11618 if (!L)
11619 return false;
11620
11621 // Both LHS and RHS must be available at loop entry.
11623 "LHS is not available at Loop Entry");
11625 "RHS is not available at Loop Entry");
11626
11627 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11628 return true;
11629
11630 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11631}
11632
11633bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11634 const SCEV *RHS,
11635 const Value *FoundCondValue, bool Inverse,
11636 const Instruction *CtxI) {
11637 // A false condition implies anything. Do not bother analyzing it further.
11638 if (FoundCondValue ==
11639 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11640 return true;
11641
11642 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11643 return false;
11644
11645 auto ClearOnExit =
11646 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11647
11648 // Recursively handle And and Or conditions.
11649 const Value *Op0, *Op1;
11650 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11651 if (!Inverse)
11652 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11653 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11654 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11655 if (Inverse)
11656 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11657 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11658 }
11659
11660 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11661 if (!ICI) return false;
11662
11663 // We have found a conditional branch that dominates the loop or controls
11664 // the loop latch. Check whether it is the comparison we are looking for.
11665 ICmpInst::Predicate FoundPred;
11666 if (Inverse)
11667 FoundPred = ICI->getInversePredicate();
11668 else
11669 FoundPred = ICI->getPredicate();
11670
11671 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11672 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11673
11674 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11675}
11676
11677bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11678 const SCEV *RHS,
11679 ICmpInst::Predicate FoundPred,
11680 const SCEV *FoundLHS, const SCEV *FoundRHS,
11681 const Instruction *CtxI) {
11682 // Balance the types.
11683 if (getTypeSizeInBits(LHS->getType()) <
11684 getTypeSizeInBits(FoundLHS->getType())) {
11685 // For unsigned and equality predicates, try to prove that both found
11686 // operands fit into narrow unsigned range. If so, try to prove facts in
11687 // narrow types.
11688 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11689 !FoundRHS->getType()->isPointerTy()) {
11690 auto *NarrowType = LHS->getType();
11691 auto *WideType = FoundLHS->getType();
11692 auto BitWidth = getTypeSizeInBits(NarrowType);
11693 const SCEV *MaxValue = getZeroExtendExpr(
11694 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11695 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11696 MaxValue) &&
11697 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11698 MaxValue)) {
11699 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11700 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11701 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11702 TruncFoundRHS, CtxI))
11703 return true;
11704 }
11705 }
11706
11707 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11708 return false;
11709 if (CmpInst::isSigned(Pred)) {
11710 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11711 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11712 } else {
11713 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11714 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11715 }
11716 } else if (getTypeSizeInBits(LHS->getType()) >
11717 getTypeSizeInBits(FoundLHS->getType())) {
11718 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11719 return false;
11720 if (CmpInst::isSigned(FoundPred)) {
11721 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11722 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11723 } else {
11724 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11725 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11726 }
11727 }
11728 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11729 FoundRHS, CtxI);
11730}
11731
11732bool ScalarEvolution::isImpliedCondBalancedTypes(
11733 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11734 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11735 const Instruction *CtxI) {
11736 assert(getTypeSizeInBits(LHS->getType()) ==
11737 getTypeSizeInBits(FoundLHS->getType()) &&
11738 "Types should be balanced!");
11739 // Canonicalize the query to match the way instcombine will have
11740 // canonicalized the comparison.
11741 if (SimplifyICmpOperands(Pred, LHS, RHS))
11742 if (LHS == RHS)
11743 return CmpInst::isTrueWhenEqual(Pred);
11744 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11745 if (FoundLHS == FoundRHS)
11746 return CmpInst::isFalseWhenEqual(FoundPred);
11747
11748 // Check to see if we can make the LHS or RHS match.
11749 if (LHS == FoundRHS || RHS == FoundLHS) {
11750 if (isa<SCEVConstant>(RHS)) {
11751 std::swap(FoundLHS, FoundRHS);
11752 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11753 } else {
11754 std::swap(LHS, RHS);
11755 Pred = ICmpInst::getSwappedPredicate(Pred);
11756 }
11757 }
11758
11759 // Check whether the found predicate is the same as the desired predicate.
11760 if (FoundPred == Pred)
11761 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11762
11763 // Check whether swapping the found predicate makes it the same as the
11764 // desired predicate.
11765 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11766 // We can write the implication
11767 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11768 // using one of the following ways:
11769 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11770 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11771 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11772 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11773 // Forms 1. and 2. require swapping the operands of one condition. Don't
11774 // do this if it would break canonical constant/addrec ordering.
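// Forms 3. and 4. rely on getNotSCEV(X) = -1 - X being order-reversing, e.g.
// (illustrative) "~LHS u>= ~RHS" holds exactly when "LHS u<= RHS" holds.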
11775 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11776 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11777 CtxI);
11778 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11779 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11780
11781 // There's no clear preference between forms 3. and 4.; try both. Avoid
11782 // forming getNotSCEV of pointer values as the resulting subtract is
11783 // not legal.
11784 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11785 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11786 FoundLHS, FoundRHS, CtxI))
11787 return true;
11788
11789 if (!FoundLHS->getType()->isPointerTy() &&
11790 !FoundRHS->getType()->isPointerTy() &&
11791 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11792 getNotSCEV(FoundRHS), CtxI))
11793 return true;
11794
11795 return false;
11796 }
11797
11798 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11799 CmpInst::Predicate P2) {
11800 assert(P1 != P2 && "Handled earlier!");
11801 return CmpInst::isRelational(P2) &&
11802 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
11803 };
11804 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11805 // Unsigned comparison is the same as signed comparison when both the
11806 // operands are non-negative or negative.
11807 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11808 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11809 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11810 // Create local copies that we can freely swap and canonicalize our
11811 // conditions to "le/lt".
11812 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11813 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11814 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11815 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11816 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11817 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11818 std::swap(CanonicalLHS, CanonicalRHS);
11819 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11820 }
11821 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11822 "Must be!");
11823 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11824 ICmpInst::isLE(CanonicalFoundPred)) &&
11825 "Must be!");
11826 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11827 // Use implication:
11828 // x <u y && y >=s 0 --> x <s y.
11829 // If we can prove the left part, the right part is also proven.
11830 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11831 CanonicalRHS, CanonicalFoundLHS,
11832 CanonicalFoundRHS);
11833 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11834 // Use implication:
11835 // x <s y && y <s 0 --> x <u y.
11836 // If we can prove the left part, the right part is also proven.
11837 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11838 CanonicalRHS, CanonicalFoundLHS,
11839 CanonicalFoundRHS);
11840 }
11841
11842 // Check if we can make progress by sharpening ranges.
11843 if (FoundPred == ICmpInst::ICMP_NE &&
11844 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11845
11846 const SCEVConstant *C = nullptr;
11847 const SCEV *V = nullptr;
11848
11849 if (isa<SCEVConstant>(FoundLHS)) {
11850 C = cast<SCEVConstant>(FoundLHS);
11851 V = FoundRHS;
11852 } else {
11853 C = cast<SCEVConstant>(FoundRHS);
11854 V = FoundLHS;
11855 }
11856
11857 // The guarding predicate tells us that C != V. If the known range
11858 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11859 // range we consider has to correspond to same signedness as the
11860 // predicate we're interested in folding.
11861
11862 APInt Min = ICmpInst::isSigned(Pred) ?
11863 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11864
11865 if (Min == C->getAPInt()) {
11866 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11867 // This is true even if (Min + 1) wraps around -- in case of
11868 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
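// E.g. (illustrative): if the unsigned range of V is [0, 100) and the guard
// gives V != 0, then V u>= 1, i.e. SharperMin below is 1.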
11869
11870 APInt SharperMin = Min + 1;
11871
11872 switch (Pred) {
11873 case ICmpInst::ICMP_SGE:
11874 case ICmpInst::ICMP_UGE:
11875 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11876 // RHS, we're done.
11877 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11878 CtxI))
11879 return true;
11880 [[fallthrough]];
11881
11882 case ICmpInst::ICMP_SGT:
11883 case ICmpInst::ICMP_UGT:
11884 // We know from the range information that (V `Pred` Min ||
11885 // V == Min). We know from the guarding condition that !(V
11886 // == Min). This gives us
11887 //
11888 // V `Pred` Min || V == Min && !(V == Min)
11889 // => V `Pred` Min
11890 //
11891 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11892
11893 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11894 return true;
11895 break;
11896
11897 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11898 case ICmpInst::ICMP_SLE:
11899 case ICmpInst::ICMP_ULE:
11900 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11901 LHS, V, getConstant(SharperMin), CtxI))
11902 return true;
11903 [[fallthrough]];
11904
11905 case ICmpInst::ICMP_SLT:
11906 case ICmpInst::ICMP_ULT:
11907 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11908 LHS, V, getConstant(Min), CtxI))
11909 return true;
11910 break;
11911
11912 default:
11913 // No change
11914 break;
11915 }
11916 }
11917 }
11918
11919 // Check whether the actual condition is beyond sufficient.
11920 if (FoundPred == ICmpInst::ICMP_EQ)
11921 if (ICmpInst::isTrueWhenEqual(Pred))
11922 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11923 return true;
11924 if (Pred == ICmpInst::ICMP_NE)
11925 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11926 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11927 return true;
11928
11929 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11930 return true;
11931
11932 // Otherwise assume the worst.
11933 return false;
11934}
11935
11936bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11937 const SCEV *&L, const SCEV *&R,
11938 SCEV::NoWrapFlags &Flags) {
11939 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11940 if (!AE || AE->getNumOperands() != 2)
11941 return false;
11942
11943 L = AE->getOperand(0);
11944 R = AE->getOperand(1);
11945 Flags = AE->getNoWrapFlags();
11946 return true;
11947}
11948
11949std::optional<APInt>
11950 ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11951 // We avoid subtracting expressions here because this function is usually
11952 // fairly deep in the call stack (i.e. is called many times).
11953
11954 unsigned BW = getTypeSizeInBits(More->getType());
11955 APInt Diff(BW, 0);
11956 APInt DiffMul(BW, 1);
11957 // Try various simplifications to reduce the difference to a constant. Limit
11958 // the number of allowed simplifications to keep compile-time low.
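// E.g. (illustrative): for More = (%x + 10) and Less = (%x + 4) the loop
// below cancels %x and returns the constant difference 6.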
11959 for (unsigned I = 0; I < 8; ++I) {
11960 if (More == Less)
11961 return Diff;
11962
11963 // Reduce addrecs with identical steps to their start value.
11964 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11965 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11966 const auto *MAR = cast<SCEVAddRecExpr>(More);
11967
11968 if (LAR->getLoop() != MAR->getLoop())
11969 return std::nullopt;
11970
11971 // We look at affine expressions only; not for correctness but to keep
11972 // getStepRecurrence cheap.
11973 if (!LAR->isAffine() || !MAR->isAffine())
11974 return std::nullopt;
11975
11976 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11977 return std::nullopt;
11978
11979 Less = LAR->getStart();
11980 More = MAR->getStart();
11981 continue;
11982 }
11983
11984 // Try to match a common constant multiply.
11985 auto MatchConstMul =
11986 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
11987 auto *M = dyn_cast<SCEVMulExpr>(S);
11988 if (!M || M->getNumOperands() != 2 ||
11989 !isa<SCEVConstant>(M->getOperand(0)))
11990 return std::nullopt;
11991 return {
11992 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
11993 };
11994 if (auto MatchedMore = MatchConstMul(More)) {
11995 if (auto MatchedLess = MatchConstMul(Less)) {
11996 if (MatchedMore->second == MatchedLess->second) {
11997 More = MatchedMore->first;
11998 Less = MatchedLess->first;
11999 DiffMul *= MatchedMore->second;
12000 continue;
12001 }
12002 }
12003 }
12004
12005 // Try to cancel out common factors in two add expressions.
12006 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12007 auto Add = [&](const SCEV *S, int Mul) {
12008 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12009 if (Mul == 1) {
12010 Diff += C->getAPInt() * DiffMul;
12011 } else {
12012 assert(Mul == -1);
12013 Diff -= C->getAPInt() * DiffMul;
12014 }
12015 } else
12016 Multiplicity[S] += Mul;
12017 };
12018 auto Decompose = [&](const SCEV *S, int Mul) {
12019 if (isa<SCEVAddExpr>(S)) {
12020 for (const SCEV *Op : S->operands())
12021 Add(Op, Mul);
12022 } else
12023 Add(S, Mul);
12024 };
12025 Decompose(More, 1);
12026 Decompose(Less, -1);
12027
12028 // Check whether all the non-constants cancel out, or reduce to new
12029 // More/Less values.
12030 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12031 for (const auto &[S, Mul] : Multiplicity) {
12032 if (Mul == 0)
12033 continue;
12034 if (Mul == 1) {
12035 if (NewMore)
12036 return std::nullopt;
12037 NewMore = S;
12038 } else if (Mul == -1) {
12039 if (NewLess)
12040 return std::nullopt;
12041 NewLess = S;
12042 } else
12043 return std::nullopt;
12044 }
12045
12046 // Values stayed the same, no point in trying further.
12047 if (NewMore == More || NewLess == Less)
12048 return std::nullopt;
12049
12050 More = NewMore;
12051 Less = NewLess;
12052
12053 // Reduced to constant.
12054 if (!More && !Less)
12055 return Diff;
12056
12057 // Left with variable on only one side, bail out.
12058 if (!More || !Less)
12059 return std::nullopt;
12060 }
12061
12062 // Did not reduce to constant.
12063 return std::nullopt;
12064}
12065
12066bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12067 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12068 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12069 // Try to recognize the following pattern:
12070 //
12071 // FoundRHS = ...
12072 // ...
12073 // loop:
12074 // FoundLHS = {Start,+,W}
12075 // context_bb: // Basic block from the same loop
12076 // known(Pred, FoundLHS, FoundRHS)
12077 //
12078 // If some predicate is known in the context of a loop, it is also known on
12079 // each iteration of this loop, including the first iteration. Therefore, in
12080 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12081 // prove the original pred using this fact.
12082 if (!CtxI)
12083 return false;
12084 const BasicBlock *ContextBB = CtxI->getParent();
12085 // Make sure AR varies in the context block.
12086 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12087 const Loop *L = AR->getLoop();
12088 // Make sure that context belongs to the loop and executes on 1st iteration
12089 // (if it ever executes at all).
12090 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12091 return false;
12092 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12093 return false;
12094 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12095 }
12096
12097 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12098 const Loop *L = AR->getLoop();
12099 // Make sure that context belongs to the loop and executes on 1st iteration
12100 // (if it ever executes at all).
12101 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12102 return false;
12103 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12104 return false;
12105 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12106 }
12107
12108 return false;
12109}
12110
12111bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
12112 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12113 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12114 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12115 return false;
12116
12117 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12118 if (!AddRecLHS)
12119 return false;
12120
12121 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12122 if (!AddRecFoundLHS)
12123 return false;
12124
12125 // We'd like to let SCEV reason about control dependencies, so we constrain
12126 // both the inequalities to be about add recurrences on the same loop. This
12127 // way we can use isLoopEntryGuardedByCond later.
12128
12129 const Loop *L = AddRecFoundLHS->getLoop();
12130 if (L != AddRecLHS->getLoop())
12131 return false;
12132
12133 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12134 //
12135 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12136 // ... (2)
12137 //
12138 // Informal proof for (2), assuming (1) [*]:
12139 //
12140 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12141 //
12142 // Then
12143 //
12144 // FoundLHS s< FoundRHS s< INT_MIN - C
12145 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12146 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12147 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12148 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12149 // <=> FoundLHS + C s< FoundRHS + C
12150 //
12151 // [*]: (1) can be proved by ruling out overflow.
12152 //
12153 // [**]: This can be proved by analyzing all the four possibilities:
12154 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12155 // (A s>= 0, B s>= 0).
12156 //
12157 // Note:
12158 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12159 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12160 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12161 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12162 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12163 // C)".
12164
12165 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12166 if (!LDiff)
12167 return false;
12168 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12169 if (!RDiff || *LDiff != *RDiff)
12170 return false;
12171
12172 if (LDiff->isMinValue())
12173 return true;
12174
12175 APInt FoundRHSLimit;
12176
12177 if (Pred == CmpInst::ICMP_ULT) {
12178 FoundRHSLimit = -(*RDiff);
12179 } else {
12180 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12181 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12182 }
12183
12184 // Try to prove (1) or (2), as needed.
12185 return isAvailableAtLoopEntry(FoundRHS, L) &&
12186 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12187 getConstant(FoundRHSLimit));
12188}
12189
12190bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12191 const SCEV *LHS, const SCEV *RHS,
12192 const SCEV *FoundLHS,
12193 const SCEV *FoundRHS, unsigned Depth) {
12194 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12195
12196 auto ClearOnExit = make_scope_exit([&]() {
12197 if (LPhi) {
12198 bool Erased = PendingMerges.erase(LPhi);
12199 assert(Erased && "Failed to erase LPhi!");
12200 (void)Erased;
12201 }
12202 if (RPhi) {
12203 bool Erased = PendingMerges.erase(RPhi);
12204 assert(Erased && "Failed to erase RPhi!");
12205 (void)Erased;
12206 }
12207 });
12208
12209 // Find the respective Phis and check that they are not pending.
12210 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12211 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12212 if (!PendingMerges.insert(Phi).second)
12213 return false;
12214 LPhi = Phi;
12215 }
12216 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12217 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12218 // If we detect a loop of Phi nodes being processed by this method, for
12219 // example:
12220 //
12221 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12222 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12223 //
12224 // we don't want to deal with a case that complex, so return the conservative
12225 // answer false.
12226 if (!PendingMerges.insert(Phi).second)
12227 return false;
12228 RPhi = Phi;
12229 }
12230
12231 // If none of LHS, RHS is a Phi, nothing to do here.
12232 if (!LPhi && !RPhi)
12233 return false;
12234
12235 // If there is a SCEVUnknown Phi we are interested in, make it left.
12236 if (!LPhi) {
12237 std::swap(LHS, RHS);
12238 std::swap(FoundLHS, FoundRHS);
12239 std::swap(LPhi, RPhi);
12240 Pred = ICmpInst::getSwappedPredicate(Pred);
12241 }
12242
12243 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12244 const BasicBlock *LBB = LPhi->getParent();
12245 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12246
12247 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12248 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12249 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12250 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12251 };
12252
12253 if (RPhi && RPhi->getParent() == LBB) {
12254 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12255 // If we compare two Phis from the same block, and for each entry block
12256 // the predicate is true for incoming values from this block, then the
12257 // predicate is also true for the Phis.
12258 for (const BasicBlock *IncBB : predecessors(LBB)) {
12259 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12260 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12261 if (!ProvedEasily(L, R))
12262 return false;
12263 }
12264 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12265 // Case two: RHS is also a Phi from the same basic block, and it is an
12266 // AddRec. It means that there is a loop which has both AddRec and Unknown
12267 // PHIs, for it we can compare incoming values of AddRec from above the loop
12268 // and latch with their respective incoming values of LPhi.
12269 // TODO: Generalize to handle loops with many inputs in a header.
12270 if (LPhi->getNumIncomingValues() != 2) return false;
12271
12272 auto *RLoop = RAR->getLoop();
12273 auto *Predecessor = RLoop->getLoopPredecessor();
12274 assert(Predecessor && "Loop with AddRec with no predecessor?");
12275 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12276 if (!ProvedEasily(L1, RAR->getStart()))
12277 return false;
12278 auto *Latch = RLoop->getLoopLatch();
12279 assert(Latch && "Loop with AddRec with no latch?");
12280 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12281 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12282 return false;
12283 } else {
12284 // In all other cases go over the inputs of LHS and compare each of them to RHS;
12285 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12286 // At this point RHS is either a non-Phi, or it is a Phi from some block
12287 // different from LBB.
12288 for (const BasicBlock *IncBB : predecessors(LBB)) {
12289 // Check that RHS is available in this block.
12290 if (!dominates(RHS, IncBB))
12291 return false;
12292 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12293 // Make sure L does not refer to a value from a potentially previous
12294 // iteration of a loop.
12295 if (!properlyDominates(L, LBB))
12296 return false;
12297 if (!ProvedEasily(L, RHS))
12298 return false;
12299 }
12300 }
12301 return true;
12302}
12303
12304bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12305 const SCEV *LHS,
12306 const SCEV *RHS,
12307 const SCEV *FoundLHS,
12308 const SCEV *FoundRHS) {
12309 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12310 // sure that we are dealing with the same LHS.
12311 if (RHS == FoundRHS) {
12312 std::swap(LHS, RHS);
12313 std::swap(FoundLHS, FoundRHS);
12314 Pred = ICmpInst::getSwappedPredicate(Pred);
12315 }
12316 if (LHS != FoundLHS)
12317 return false;
12318
12319 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12320 if (!SUFoundRHS)
12321 return false;
12322
12323 Value *Shiftee, *ShiftValue;
12324
12325 using namespace PatternMatch;
12326 if (match(SUFoundRHS->getValue(),
12327 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12328 auto *ShifteeS = getSCEV(Shiftee);
12329 // Prove one of the following:
12330 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12331 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12332 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12333 // ---> LHS <s RHS
12334 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12335 // ---> LHS <=s RHS
12336 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12337 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12338 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12339 if (isKnownNonNegative(ShifteeS))
12340 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12341 }
12342
12343 return false;
12344}
12345
12346bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12347 const SCEV *LHS, const SCEV *RHS,
12348 const SCEV *FoundLHS,
12349 const SCEV *FoundRHS,
12350 const Instruction *CtxI) {
12351 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12352 return true;
12353
12354 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12355 return true;
12356
12357 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12358 return true;
12359
12360 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12361 CtxI))
12362 return true;
12363
12364 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12365 FoundLHS, FoundRHS);
12366}
12367
12368/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12369template <typename MinMaxExprType>
12370static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12371 const SCEV *Candidate) {
12372 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12373 if (!MinMaxExpr)
12374 return false;
12375
12376 return is_contained(MinMaxExpr->operands(), Candidate);
12377}
12378
12379 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12380 ICmpInst::Predicate Pred,
12381 const SCEV *LHS, const SCEV *RHS) {
12382 // If both sides are affine addrecs for the same loop, with equal
12383 // steps, and we know the recurrences don't wrap, then we only
12384 // need to check the predicate on the starting values.
12385
12386 if (!ICmpInst::isRelational(Pred))
12387 return false;
12388
12389 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12390 if (!LAR)
12391 return false;
12392 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12393 if (!RAR)
12394 return false;
12395 if (LAR->getLoop() != RAR->getLoop())
12396 return false;
12397 if (!LAR->isAffine() || !RAR->isAffine())
12398 return false;
12399
12400 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12401 return false;
12402
12403 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12404 SCEV::FlagNSW : SCEV::FlagNUW;
12405 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12406 return false;
12407
12408 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12409}
12410
12411/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12412/// expression?
12413 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12414 ICmpInst::Predicate Pred,
12415 const SCEV *LHS, const SCEV *RHS) {
12416 switch (Pred) {
12417 default:
12418 return false;
12419
12420 case ICmpInst::ICMP_SGE:
12421 std::swap(LHS, RHS);
12422 [[fallthrough]];
12423 case ICmpInst::ICMP_SLE:
12424 return
12425 // min(A, ...) <= A
12426 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12427 // A <= max(A, ...)
12428 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12429
12430 case ICmpInst::ICMP_UGE:
12431 std::swap(LHS, RHS);
12432 [[fallthrough]];
12433 case ICmpInst::ICMP_ULE:
12434 return
12435 // min(A, ...) <= A
12436 // FIXME: what about umin_seq?
12437 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12438 // A <= max(A, ...)
12439 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12440 }
12441
12442 llvm_unreachable("covered switch fell through?!");
12443}
12444
12445bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12446 const SCEV *LHS, const SCEV *RHS,
12447 const SCEV *FoundLHS,
12448 const SCEV *FoundRHS,
12449 unsigned Depth) {
12450 assert(getTypeSizeInBits(LHS->getType()) ==
12451 getTypeSizeInBits(RHS->getType()) &&
12452 "LHS and RHS have different sizes?");
12453 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12454 getTypeSizeInBits(FoundRHS->getType()) &&
12455 "FoundLHS and FoundRHS have different sizes?");
12456 // We want to avoid hurting the compile time with analysis of too big trees.
12457 if (Depth > MaxSCEVOperationsImplicationDepth)
12458 return false;
12459
12460 // We only want to work with GT comparison so far.
12461 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12462 Pred = CmpInst::getSwappedPredicate(Pred);
12463 std::swap(LHS, RHS);
12464 std::swap(FoundLHS, FoundRHS);
12465 }
12466
12467 // For unsigned, try to reduce it to the corresponding signed comparison.
12468 if (Pred == ICmpInst::ICMP_UGT)
12469 // We can replace unsigned predicate with its signed counterpart if all
12470 // involved values are non-negative.
12471 // TODO: We could have better support for unsigned.
12472 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12473 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12474 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12475 // use this fact to prove that LHS and RHS are non-negative.
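 // (Proving "X >s -1" below is just another way of proving "X >=s 0", which is
 // why the two queries that follow compare LHS and RHS against MinusOne.)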
12476 const SCEV *MinusOne = getMinusOne(LHS->getType());
12477 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12478 FoundRHS) &&
12479 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12480 FoundRHS))
12481 Pred = ICmpInst::ICMP_SGT;
12482 }
12483
12484 if (Pred != ICmpInst::ICMP_SGT)
12485 return false;
12486
12487 auto GetOpFromSExt = [&](const SCEV *S) {
12488 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12489 return Ext->getOperand();
12490 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12491 // the constant in some cases.
12492 return S;
12493 };
12494
12495 // Acquire values from extensions.
12496 auto *OrigLHS = LHS;
12497 auto *OrigFoundLHS = FoundLHS;
12498 LHS = GetOpFromSExt(LHS);
12499 FoundLHS = GetOpFromSExt(FoundLHS);
12500
12501 // Check whether the SGT predicate can be proved trivially or using the found context.
12502 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12503 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12504 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12505 FoundRHS, Depth + 1);
12506 };
12507
12508 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12509 // We want to avoid creation of any new non-constant SCEV. Since we are
12510 // going to compare the operands to RHS, we should be certain that we don't
12511 // need any size extensions for this. So let's decline all cases when the
12512 // sizes of types of LHS and RHS do not match.
12513 // TODO: Maybe try to get RHS from sext to catch more cases?
12514 if (RHS->getType() != LHS->getType())
12515 return false;
12516
12517 // Should not overflow.
12518 if (!LHSAddExpr->hasNoSignedWrap())
12519 return false;
12520
12521 auto *LL = LHSAddExpr->getOperand(0);
12522 auto *LR = LHSAddExpr->getOperand(1);
12523 auto *MinusOne = getMinusOne(RHS->getType());
12524
12525 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12526 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12527 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12528 };
12529 // Try to prove the following rule:
12530 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12531 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
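    // As a quick illustration with made-up numbers: LL = 3 >= 0 and
    // LR = 10 > RHS = 7 give, with the no-signed-wrap guarantee checked above,
    // LHS = 3 + 10 = 13 > 7.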
12532 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12533 return true;
12534 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12535 Value *LL, *LR;
12536 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12537
12538 using namespace llvm::PatternMatch;
12539
12540 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12541 // Rules for division.
12542 // We are going to perform some comparisons with Denominator and its
12543 // derivative expressions. In the general case, creating a SCEV for it may
12544 // lead to a complex analysis of the entire graph, and in particular it
12545 // can request trip count recalculation for the same loop. The result would
12546 // be cached as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12547 // this, we only want to create SCEVs that are constants in this section.
12548 // So we bail if Denominator is not a constant.
12549 if (!isa<ConstantInt>(LR))
12550 return false;
12551
12552 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12553
12554 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12555 // then a SCEV for the numerator already exists and matches with FoundLHS.
12556 auto *Numerator = getExistingSCEV(LL);
12557 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12558 return false;
12559
12560 // Make sure that the numerator matches with FoundLHS and the denominator
12561 // is positive.
12562 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12563 return false;
12564
12565 auto *DTy = Denominator->getType();
12566 auto *FRHSTy = FoundRHS->getType();
12567 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12568 // One of types is a pointer and another one is not. We cannot extend
12569 // them properly to a wider type, so let us just reject this case.
12570 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12571 // to avoid this check.
12572 return false;
12573
12574 // Given that:
12575 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12576 auto *WTy = getWiderType(DTy, FRHSTy);
12577 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12578 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12579
12580 // Try to prove the following rule:
12581 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12582 // For example, if FoundRHS = 2, then FoundLHS > 2 means FoundLHS is at
12583 // least 3, and dividing it by a Denominator < 4 yields at least 1.
12584 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12585 if (isKnownNonPositive(RHS) &&
12586 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12587 return true;
12588
12589 // Try to prove the following rule:
12590 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12591 // For example, if FoundLHS > -3, then FoundLHS is at least -2.
12592 // If we divide it by a Denominator > 2, then:
12593 // 1. If FoundLHS is negative, the result is 0.
12594 // 2. If FoundLHS is non-negative, the result is non-negative.
12595 // Either way, the result is non-negative, and since RHS < 0, LHS > RHS.
12596 auto *MinusOne = getMinusOne(WTy);
12597 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12598 if (isKnownNegative(RHS) &&
12599 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12600 return true;
12601 }
12602 }
12603
12604 // If our expression contained SCEVUnknown Phis, and we split it down and now
12605 // need to prove something for them, try to prove the predicate for every
12606 // possible incoming value of those Phis.
12607 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12608 return true;
12609
12610 return false;
12611}
12612
12613bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12614 const SCEV *LHS, const SCEV *RHS) {
12615 // zext x u<= sext x, sext x s<= zext x
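 // For example, for an i8 value x = -1 (0xFF): zext to i16 gives 255 while
 // sext gives 0xFFFF, i.e. 65535 unsigned but -1 signed; so 255 u<= 65535 and
 // -1 s<= 255, matching the two facts above.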
12616 switch (Pred) {
12617 case ICmpInst::ICMP_SGE:
12618 std::swap(LHS, RHS);
12619 [[fallthrough]];
12620 case ICmpInst::ICMP_SLE: {
12621 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12622 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12623 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12624 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12625 return true;
12626 break;
12627 }
12628 case ICmpInst::ICMP_UGE:
12629 std::swap(LHS, RHS);
12630 [[fallthrough]];
12631 case ICmpInst::ICMP_ULE: {
12632 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12633 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12634 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12635 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12636 return true;
12637 break;
12638 }
12639 default:
12640 break;
12641 };
12642 return false;
12643}
12644
12645bool
12646ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12647 const SCEV *LHS, const SCEV *RHS) {
12648 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12649 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12650 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12651 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12652 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12653}
12654
12655bool
12656ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12657 const SCEV *LHS, const SCEV *RHS,
12658 const SCEV *FoundLHS,
12659 const SCEV *FoundRHS) {
12660 switch (Pred) {
12661 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12662 case ICmpInst::ICMP_EQ:
12663 case ICmpInst::ICMP_NE:
12664 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12665 return true;
12666 break;
12667 case ICmpInst::ICMP_SLT:
12668 case ICmpInst::ICMP_SLE:
12669 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12670 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12671 return true;
12672 break;
12673 case ICmpInst::ICMP_SGT:
12674 case ICmpInst::ICMP_SGE:
12675 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12676 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12677 return true;
12678 break;
12679 case ICmpInst::ICMP_ULT:
12680 case ICmpInst::ICMP_ULE:
12681 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12682 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12683 return true;
12684 break;
12685 case ICmpInst::ICMP_UGT:
12686 case ICmpInst::ICMP_UGE:
12687 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12688 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12689 return true;
12690 break;
12691 }
12692
12693 // Maybe it can be proved via operations?
12694 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12695 return true;
12696
12697 return false;
12698}
12699
12700bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12701 const SCEV *LHS,
12702 const SCEV *RHS,
12703 ICmpInst::Predicate FoundPred,
12704 const SCEV *FoundLHS,
12705 const SCEV *FoundRHS) {
12706 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12707 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12708 // reduce the compile time impact of this optimization.
12709 return false;
12710
12711 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12712 if (!Addend)
12713 return false;
12714
12715 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12716
12717 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12718 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12719 ConstantRange FoundLHSRange =
12720 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12721
12722 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12723 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12724
12725 // We can also compute the range of values for `LHS` that satisfy the
12726 // consequent, "`LHS` `Pred` `RHS`":
12727 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12728 // The antecedent implies the consequent if every value of `LHS` that
12729 // satisfies the antecedent also satisfies the consequent.
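  // As an illustration with made-up constants: if the antecedent is
  // "FoundLHS <u 10", FoundLHSRange is [0, 10); with Addend = 5, LHSRange is
  // [5, 15), and a consequent such as "LHS <u 20" then holds for every value
  // in that range, so the implication is proven.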
12730 return LHSRange.icmp(Pred, ConstRHS);
12731}
12732
12733bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12734 bool IsSigned) {
12735 assert(isKnownPositive(Stride) && "Positive stride expected!");
12736
12737 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12738 const SCEV *One = getOne(Stride->getType());
12739
12740 if (IsSigned) {
12741 APInt MaxRHS = getSignedRangeMax(RHS);
12742 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12743 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12744
12745 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12746 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12747 }
12748
12749 APInt MaxRHS = getUnsignedRangeMax(RHS);
12750 APInt MaxValue = APInt::getMaxValue(BitWidth);
12751 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12752
12753 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
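  // E.g., with illustrative 8-bit values: UMaxRHS = 250 and
  // UMaxStrideMinusOne = 9 give 255 - 9 = 246 < 250, so an IV just below RHS
  // could step past 255 and wrap; we conservatively report possible overflow.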
12754 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12755}
12756
12757bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12758 bool IsSigned) {
12759
12760 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12761 const SCEV *One = getOne(Stride->getType());
12762
12763 if (IsSigned) {
12764 APInt MinRHS = getSignedRangeMin(RHS);
12765 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12766 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12767
12768 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12769 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12770 }
12771
12772 APInt MinRHS = getUnsignedRangeMin(RHS);
12773 APInt MinValue = APInt::getMinValue(BitWidth);
12774 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12775
12776 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12777 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12778}
12779
12780const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12781 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12782 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12783 // expression fixes the case of N=0.
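 // A quick check with illustrative numbers: N = 7, D = 3 gives
 // umin(7,1) + floor((7 - 1) / 3) = 1 + 2 = 3 = ceil(7/3), while N = 0 gives
 // 0, avoiding the unsigned wrap that a naive "1 + (N - 1)/D" would hit at
 // N = 0.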
12784 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12785 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12786 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12787}
12788
12789const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12790 const SCEV *Stride,
12791 const SCEV *End,
12792 unsigned BitWidth,
12793 bool IsSigned) {
12794 // The logic in this function assumes we can represent a positive stride.
12795 // If we can't, the backedge-taken count must be zero.
12796 if (IsSigned && BitWidth == 1)
12797 return getZero(Stride->getType());
12798
12799 // The code below has only been closely audited for negative strides in the
12800 // unsigned comparison case; it may be correct for signed comparisons, but
12801 // that needs to be established.
12802 if (IsSigned && isKnownNegative(Stride))
12803 return getCouldNotCompute();
12804
12805 // Calculate the maximum backedge count based on the range of values
12806 // permitted by Start, End, and Stride.
12807 APInt MinStart =
12808 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12809
12810 APInt MinStride =
12811 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12812
12813 // We assume either the stride is positive, or the backedge-taken count
12814 // is zero. So force StrideForMaxBECount to be at least one.
12815 APInt One(BitWidth, 1);
12816 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12817 : APIntOps::umax(One, MinStride);
12818
12819 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12820 : APInt::getMaxValue(BitWidth);
12821 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12822
12823 // Although End can be a MAX expression we estimate MaxEnd considering only
12824 // the case End = RHS of the loop termination condition. This is safe because
12825 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12826 // taken count.
12827 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12828 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12829
12830 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12831 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12832 : APIntOps::umax(MaxEnd, MinStart);
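 // For instance (illustrative, unsigned case): MinStart = 2, MaxEnd = 10 and
 // StrideForMaxBECount = 3 give ceil((10 - 2) / 3) = 3 as the constant max
 // backedge-taken count.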
12833
12834 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12835 getConstant(StrideForMaxBECount) /* Step */);
12836}
12837
12838ScalarEvolution::ExitLimit
12839ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12840 const Loop *L, bool IsSigned,
12841 bool ControlsOnlyExit, bool AllowPredicates) {
12842 SmallVector<const SCEVPredicate *, 4> Predicates;
12843
12844 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12845 bool PredicatedIV = false;
12846 if (!IV) {
12847 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12848 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12849 if (AR && AR->getLoop() == L && AR->isAffine()) {
12850 auto canProveNUW = [&]() {
12851 // We can use the comparison to infer no-wrap flags only if it fully
12852 // controls the loop exit.
12853 if (!ControlsOnlyExit)
12854 return false;
12855
12856 if (!isLoopInvariant(RHS, L))
12857 return false;
12858
12859 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12860 // We need the sequence defined by AR to strictly increase in the
12861 // unsigned integer domain for the logic below to hold.
12862 return false;
12863
12864 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12865 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12866 // If RHS <=u Limit, then there must exist a value V in the sequence
12867 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12868 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12869 // overflow occurs. This limit also implies that a signed comparison
12870 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12871 // the high bits on both sides must be zero.
12872 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12873 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12874 Limit = Limit.zext(OuterBitWidth);
12875 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12876 };
12877 auto Flags = AR->getNoWrapFlags();
12878 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12879 Flags = setFlags(Flags, SCEV::FlagNUW);
12880
12881 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12882 if (AR->hasNoUnsignedWrap()) {
12883 // Emulate what getZeroExtendExpr would have done during construction
12884 // if we'd been able to infer the fact just above at that time.
12885 const SCEV *Step = AR->getStepRecurrence(*this);
12886 Type *Ty = ZExt->getType();
12887 auto *S = getAddRecExpr(
12888 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12889 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12890 IV = dyn_cast<SCEVAddRecExpr>(S);
12891 }
12892 }
12893 }
12894 }
12895
12896
12897 if (!IV && AllowPredicates) {
12898 // Try to make this an AddRec using runtime tests, in the first X
12899 // iterations of this loop, where X is the SCEV expression found by the
12900 // algorithm below.
12901 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12902 PredicatedIV = true;
12903 }
12904
12905 // Avoid weird loops
12906 if (!IV || IV->getLoop() != L || !IV->isAffine())
12907 return getCouldNotCompute();
12908
12909 // A precondition of this method is that the condition being analyzed
12910 // reaches an exiting branch which dominates the latch. Given that, we can
12911 // assume that an increment which violates the nowrap specification and
12912 // produces poison must cause undefined behavior when the resulting poison
12913 // value is branched upon and thus we can conclude that the backedge is
12914 // taken no more often than would be required to produce that poison value.
12915 // Note that a well defined loop can exit on the iteration which violates
12916 // the nowrap specification if there is another exit (either explicit or
12917 // implicit/exceptional) which causes the loop to execute before the
12918 // exiting instruction we're analyzing would trigger UB.
12919 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12920 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12921 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12922
12923 const SCEV *Stride = IV->getStepRecurrence(*this);
12924
12925 bool PositiveStride = isKnownPositive(Stride);
12926
12927 // Avoid negative or zero stride values.
12928 if (!PositiveStride) {
12929 // We can compute the correct backedge taken count for loops with unknown
12930 // strides if we can prove that the loop is not an infinite loop with side
12931 // effects. Here's the loop structure we are trying to handle -
12932 //
12933 // i = start
12934 // do {
12935 // A[i] = i;
12936 // i += s;
12937 // } while (i < end);
12938 //
12939 // The backedge taken count for such loops is evaluated as -
12940 // (max(end, start + stride) - start - 1) /u stride
12941 //
12942 // The additional preconditions that we need to check to prove correctness
12943 // of the above formula is as follows -
12944 //
12945 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12946 // NoWrap flag).
12947 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12948 // no side effects within the loop)
12949 // c) loop has a single static exit (with no abnormal exits)
12950 //
12951 // Precondition a) implies that if the stride is negative, this is a single
12952 // trip loop. The backedge taken count formula reduces to zero in this case.
12953 //
12954 // Precondition b) and c) combine to imply that if rhs is invariant in L,
12955 // then a zero stride means the backedge can't be taken without executing
12956 // undefined behavior.
12957 //
12958 // The positive stride case is the same as isKnownPositive(Stride) returning
12959 // true (original behavior of the function).
12960 //
12961 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12963 return getCouldNotCompute();
12964
12965 if (!isKnownNonZero(Stride)) {
12966 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12967 // if it might eventually be greater than start and if so, on which
12968 // iteration. We can't even produce a useful upper bound.
12969 if (!isLoopInvariant(RHS, L))
12970 return getCouldNotCompute();
12971
12972 // We allow a potentially zero stride, but we need to divide by stride
12973 // below. Since the loop can't be infinite and this check must control
12974 // the sole exit, we can infer the exit must be taken on the first
12975 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12976 // we know the numerator in the divides below must be zero, so we can
12977 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12978 // and produce the right result.
12979 // FIXME: Handle the case where Stride is poison?
12980 auto wouldZeroStrideBeUB = [&]() {
12981 // Proof by contradiction. Suppose the stride were zero. If we can
12982 // prove that the backedge *is* taken on the first iteration, then since
12983 // we know this condition controls the sole exit, we must have an
12984 // infinite loop. We can't have a (well defined) infinite loop per
12985 // check just above.
12986 // Note: The (Start - Stride) term is used to get the start' term from
12987 // (start' + stride,+,stride). Remember that we only care about the
12988 // result of this expression when stride == 0 at runtime.
12989 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12990 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12991 };
12992 if (!wouldZeroStrideBeUB()) {
12993 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12994 }
12995 }
12996 } else if (!NoWrap) {
12997 // Avoid proven overflow cases: this will ensure that the backedge taken
12998 // count will not generate any unsigned overflow.
12999 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13000 return getCouldNotCompute();
13001 }
13002
13003 // On all paths just preceding, we established the following invariant:
13004 // IV can be assumed not to overflow up to and including the exiting
13005 // iteration. We proved this in one of two ways:
13006 // 1) We can show overflow doesn't occur before the exiting iteration
13007 // 1a) canIVOverflowOnLT, and b) step of one
13008 // 2) We can show that if overflow occurs, the loop must execute UB
13009 // before any possible exit.
13010 // Note that we have not yet proved RHS invariant (in general).
13011
13012 const SCEV *Start = IV->getStart();
13013
13014 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13015 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13016 // Use integer-typed versions for actual computation; we can't subtract
13017 // pointers in general.
13018 const SCEV *OrigStart = Start;
13019 const SCEV *OrigRHS = RHS;
13020 if (Start->getType()->isPointerTy()) {
13021 Start = getLosslessPtrToIntExpr(Start);
13022 if (isa<SCEVCouldNotCompute>(Start))
13023 return Start;
13024 }
13025 if (RHS->getType()->isPointerTy()) {
13026 RHS = getLosslessPtrToIntExpr(RHS);
13027 if (isa<SCEVCouldNotCompute>(RHS))
13028 return RHS;
13029 }
13030
13031 const SCEV *End = nullptr, *BECount = nullptr,
13032 *BECountIfBackedgeTaken = nullptr;
13033 if (!isLoopInvariant(RHS, L)) {
13034 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13035 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13036 RHSAddRec->getNoWrapFlags()) {
13037 // The structure of loop we are trying to calculate backedge count of:
13038 //
13039 // left = left_start
13040 // right = right_start
13041 //
13042 // while(left < right){
13043 // ... do something here ...
13044 // left += s1; // stride of left is s1 (s1 > 0)
13045 // right += s2; // stride of right is s2 (s2 < 0)
13046 // }
13047 //
13048
13049 const SCEV *RHSStart = RHSAddRec->getStart();
13050 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13051
13052 // If Stride - RHSStride is positive and does not overflow, we can write
13053 // backedge count as ->
13054 // ceil((End - Start) /u (Stride - RHSStride))
13055 // Where, End = max(RHSStart, Start)
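      // For instance, with illustrative values Start = 0, Stride = 2,
      // RHSStart = 10 and RHSStride = -1: the gap closes by
      // Stride - RHSStride = 3 per iteration, End = max(10, 0) = 10, and the
      // formula yields ceil((10 - 0) /u 3) = 4.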
13056
13057 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13058 if (isKnownNegative(RHSStride) &&
13059 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13060 RHSStride)) {
13061
13062 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13063 if (isKnownPositive(Denominator)) {
13064 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13065 : getUMaxExpr(RHSStart, Start);
13066
13067 // We can do this because End >= Start, as End = max(RHSStart, Start)
13068 const SCEV *Delta = getMinusSCEV(End, Start);
13069
13070 BECount = getUDivCeilSCEV(Delta, Denominator);
13071 BECountIfBackedgeTaken =
13072 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13073 }
13074 }
13075 }
13076 if (BECount == nullptr) {
13077 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13078 // given the start, stride and max value for the end bound of the
13079 // loop (RHS), and the fact that IV does not overflow (which is
13080 // checked above).
13081 const SCEV *MaxBECount = computeMaxBECountForLT(
13082 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13083 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13084 MaxBECount, false /*MaxOrZero*/, Predicates);
13085 }
13086 } else {
13087 // We use the expression (max(End,Start)-Start)/Stride to describe the
13088 // backedge count: if the backedge is taken at least once, max(End,Start)
13089 // is End and the result is as above; if not, max(End,Start) is Start and
13090 // we get a backedge count of zero.
13091 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13092 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13093 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13094 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13095 // Can we prove (max(RHS,Start) > Start - Stride)?
13096 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13097 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13098 // In this case, we can use a refined formula for computing backedge
13099 // taken count. The general formula remains:
13100 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13101 // We want to use the alternate formula:
13102 // "((End - 1) - (Start - Stride)) /u Stride"
13103 // Let's do a quick case analysis to show these are equivalent under
13104 // our precondition that max(RHS,Start) > Start - Stride.
13105 // * For RHS <= Start, the backedge-taken count must be zero.
13106 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13107 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13108 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13109 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13110 // reducing this to the stride of 1 case.
13111 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13112 // Stride".
13113 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13114 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13115 // "((RHS - (Start - Stride) - 1) /u Stride".
13116 // Our preconditions trivially imply no overflow in that form.
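      // Plugging in illustrative values Start = 0, Stride = 3, RHS = 10: the
      // general form gives ceil((10 - 0) / 3) = 4 and the alternate form gives
      // ((10 - 1) - (0 - 3)) /u 3 = 12 /u 3 = 4, as expected.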
13117 const SCEV *MinusOne = getMinusOne(Stride->getType());
13118 const SCEV *Numerator =
13119 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13120 BECount = getUDivExpr(Numerator, Stride);
13121 }
13122
13123 if (!BECount) {
13124 auto canProveRHSGreaterThanEqualStart = [&]() {
13125 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13126 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13127 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13128
13129 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13130 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13131 return true;
13132
13133 // (RHS > Start - 1) implies RHS >= Start.
13134 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13135 // "Start - 1" doesn't overflow.
13136 // * For signed comparison, if Start - 1 does overflow, it's equal
13137 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13138 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13139 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13140 //
13141 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13142 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13143 auto *StartMinusOne =
13144 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13145 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13146 };
13147
13148 // If we know that RHS >= Start in the context of loop, then we know
13149 // that max(RHS, Start) = RHS at this point.
13150 if (canProveRHSGreaterThanEqualStart()) {
13151 End = RHS;
13152 } else {
13153 // If RHS < Start, the backedge will be taken zero times. So in
13154 // general, we can write the backedge-taken count as:
13155 //
13156 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13157 //
13158 // We convert it to the following to make it more convenient for SCEV:
13159 //
13160 // ceil(max(RHS, Start) - Start) / Stride
13161 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13162
13163 // See what would happen if we assume the backedge is taken. This is
13164 // used to compute MaxBECount.
13165 BECountIfBackedgeTaken =
13166 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13167 }
13168
13169 // At this point, we know:
13170 //
13171 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13172 // 2. The index variable doesn't overflow.
13173 //
13174 // Therefore, we know N exists such that
13175 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13176 // doesn't overflow.
13177 //
13178 // Using this information, try to prove whether the addition in
13179 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13180 const SCEV *One = getOne(Stride->getType());
13181 bool MayAddOverflow = [&] {
13182 if (isKnownToBeAPowerOfTwo(Stride)) {
13183 // Suppose Stride is a power of two, and Start/End are unsigned
13184 // integers. Let UMAX be the largest representable unsigned
13185 // integer.
13186 //
13187 // By the preconditions of this function, we know
13188 // "(Start + Stride * N) >= End", and this doesn't overflow.
13189 // As a formula:
13190 //
13191 // End <= (Start + Stride * N) <= UMAX
13192 //
13193 // Subtracting Start from all the terms:
13194 //
13195 // End - Start <= Stride * N <= UMAX - Start
13196 //
13197 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13198 //
13199 // End - Start <= Stride * N <= UMAX
13200 //
13201 // Stride * N is a multiple of Stride. Therefore,
13202 //
13203 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13204 //
13205 // Since Stride is a power of two, UMAX + 1 is divisible by
13206 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13207 // write:
13208 //
13209 // End - Start <= Stride * N <= UMAX - Stride + 1
13210 //
13211 // Dropping the middle term:
13212 //
13213 // End - Start <= UMAX - Stride + 1
13214 //
13215 // Adding Stride - 1 to both sides:
13216 //
13217 // (End - Start) + (Stride - 1) <= UMAX
13218 //
13219 // In other words, the addition doesn't have unsigned overflow.
13220 //
13221 // A similar proof works if we treat Start/End as signed values.
13222 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13223 // to use signed max instead of unsigned max. Note that we're
13224 // trying to prove a lack of unsigned overflow in either case.
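        // As a sanity check of the key step with illustrative 8-bit values:
        // Stride = 4 gives UMAX mod Stride = 255 mod 4 = 3 = Stride - 1, and
        // adding Stride - 1 = 3 to UMAX - Stride + 1 = 252 lands exactly on
        // UMAX = 255.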
13225 return false;
13226 }
13227 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13228 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13229 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13230 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13231 // 1 <s End.
13232 //
13233 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13234 // End.
13235 return false;
13236 }
13237 return true;
13238 }();
13239
13240 const SCEV *Delta = getMinusSCEV(End, Start);
13241 if (!MayAddOverflow) {
13242 // floor((D + (S - 1)) / S)
13243 // We prefer this formulation if it's legal because it's fewer
13244 // operations.
13245 BECount =
13246 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13247 } else {
13248 BECount = getUDivCeilSCEV(Delta, Stride);
13249 }
13250 }
13251 }
13252
13253 const SCEV *ConstantMaxBECount;
13254 bool MaxOrZero = false;
13255 if (isa<SCEVConstant>(BECount)) {
13256 ConstantMaxBECount = BECount;
13257 } else if (BECountIfBackedgeTaken &&
13258 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13259 // If we know exactly how many times the backedge will be taken if it's
13260 // taken at least once, then the backedge count will either be that or
13261 // zero.
13262 ConstantMaxBECount = BECountIfBackedgeTaken;
13263 MaxOrZero = true;
13264 } else {
13265 ConstantMaxBECount = computeMaxBECountForLT(
13266 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13267 }
13268
13269 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13270 !isa<SCEVCouldNotCompute>(BECount))
13271 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13272
13273 const SCEV *SymbolicMaxBECount =
13274 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13275 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13276 Predicates);
13277}
13278
13279ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13280 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13281 bool ControlsOnlyExit, bool AllowPredicates) {
13282 SmallVector<const SCEVPredicate *, 4> Predicates;
13283 // We handle only IV > Invariant
13284 if (!isLoopInvariant(RHS, L))
13285 return getCouldNotCompute();
13286
13287 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13288 if (!IV && AllowPredicates)
13289 // Try to make this an AddRec using runtime tests, in the first X
13290 // iterations of this loop, where X is the SCEV expression found by the
13291 // algorithm below.
13292 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13293
13294 // Avoid weird loops
13295 if (!IV || IV->getLoop() != L || !IV->isAffine())
13296 return getCouldNotCompute();
13297
13298 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13299 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13300 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13301
13302 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13303
13304 // Avoid negative or zero stride values
13305 if (!isKnownPositive(Stride))
13306 return getCouldNotCompute();
13307
13308 // Avoid proven overflow cases: this will ensure that the backedge taken count
13309 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13310 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
13311 // behavior, as in the C language.
13312 if (!Stride->isOne() && !NoWrap)
13313 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13314 return getCouldNotCompute();
13315
13316 const SCEV *Start = IV->getStart();
13317 const SCEV *End = RHS;
13318 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13319 // If we know that Start >= RHS in the context of loop, then we know that
13320 // min(RHS, Start) = RHS at this point.
13321 if (isLoopEntryGuardedByCond(
13322 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13323 End = RHS;
13324 else
13325 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13326 }
13327
13328 if (Start->getType()->isPointerTy()) {
13329 Start = getLosslessPtrToIntExpr(Start);
13330 if (isa<SCEVCouldNotCompute>(Start))
13331 return Start;
13332 }
13333 if (End->getType()->isPointerTy()) {
13334 End = getLosslessPtrToIntExpr(End);
13335 if (isa<SCEVCouldNotCompute>(End))
13336 return End;
13337 }
13338
13339 // Compute ((Start - End) + (Stride - 1)) / Stride.
13340 // FIXME: This can overflow. Holding off on fixing this for now;
13341 // howManyGreaterThans will hopefully be gone soon.
13342 const SCEV *One = getOne(Stride->getType());
13343 const SCEV *BECount = getUDivExpr(
13344 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13345
13346 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13347 : getUnsignedRangeMax(Start);
13348
13349 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13350 : getUnsignedRangeMin(Stride);
13351
13352 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13353 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13354 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13355
13356 // Although End can be a MIN expression we estimate MinEnd considering only
13357 // the case End = RHS. This is safe because in the other case (Start - End)
13358 // is zero, leading to a zero maximum backedge taken count.
13359 APInt MinEnd =
13360 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13361 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13362
13363 const SCEV *ConstantMaxBECount =
13364 isa<SCEVConstant>(BECount)
13365 ? BECount
13366 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13367 getConstant(MinStride));
13368
13369 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13370 ConstantMaxBECount = BECount;
13371 const SCEV *SymbolicMaxBECount =
13372 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13373
13374 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13375 Predicates);
13376}
13377
13378const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13379 ScalarEvolution &SE) const {
13380 if (Range.isFullSet()) // Infinite loop.
13381 return SE.getCouldNotCompute();
13382
13383 // If the start is a non-zero constant, shift the range to simplify things.
13384 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13385 if (!SC->getValue()->isZero()) {
13386 SmallVector<const SCEV *, 4> Operands(operands());
13387 Operands[0] = SE.getZero(SC->getType());
13388 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13389 getNoWrapFlags(FlagNW));
13390 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13391 return ShiftedAddRec->getNumIterationsInRange(
13392 Range.subtract(SC->getAPInt()), SE);
13393 // This is strange and shouldn't happen.
13394 return SE.getCouldNotCompute();
13395 }
13396
13397 // The only time we can solve this is when we have all constant indices.
13398 // Otherwise, we cannot determine the overflow conditions.
13399 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13400 return SE.getCouldNotCompute();
13401
13402 // Okay at this point we know that all elements of the chrec are constants and
13403 // that the start element is zero.
13404
13405 // First check to see if the range contains zero. If not, the first
13406 // iteration exits.
13407 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13408 if (!Range.contains(APInt(BitWidth, 0)))
13409 return SE.getZero(getType());
13410
13411 if (isAffine()) {
13412 // If this is an affine expression then we have this situation:
13413 // Solve {0,+,A} in Range === Ax in Range
13414
13415 // We know that zero is in the range. If A is positive then we know that
13416 // the upper value of the range must be the first possible exit value.
13417 // If A is negative then the lower of the range is the last possible loop
13418 // value. Also note that we already checked for a full range.
13419 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13420 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13421
13422 // The exit value should be (End+A)/A.
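    // For example (illustrative): A = 3 and Range = [0, 10) give End = 9 and
    // ExitVal = (9 + 3) /u 3 = 4; evaluating {0,+,3} at 4 yields 12, outside
    // the range, while the value at iteration 3 is 9, still inside it.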
13423 APInt ExitVal = (End + A).udiv(A);
13424 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13425
13426 // Evaluate at the exit value. If we really did fall out of the valid
13427 // range, then we computed our trip count, otherwise wrap around or other
13428 // things must have happened.
13429 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13430 if (Range.contains(Val->getValue()))
13431 return SE.getCouldNotCompute(); // Something strange happened
13432
13433 // Ensure that the previous value is in the range.
13434 assert(Range.contains(
13435 EvaluateConstantChrecAtConstant(this,
13436 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13437 "Linear scev computation is off in a bad way!");
13438 return SE.getConstant(ExitValue);
13439 }
13440
13441 if (isQuadratic()) {
13442 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13443 return SE.getConstant(*S);
13444 }
13445
13446 return SE.getCouldNotCompute();
13447}
13448
13449const SCEVAddRecExpr *
13450SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13451 assert(getNumOperands() > 1 && "AddRec with zero step?");
13452 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13453 // but in this case we cannot guarantee that the value returned will be an
13454 // AddRec because SCEV does not have a fixed point where it stops
13455 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13456 // may happen if we reach arithmetic depth limit while simplifying. So we
13457 // construct the returned value explicitly.
13459 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13460 // (this + Step) is {A+B,+,B+C,+...,+,N}.
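 // E.g. (illustrative): for {1,+,2,+,3} the step is {2,+,3}, and the sum built
 // below is {1+2,+,2+3,+,3} = {3,+,5,+,3}.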
13461 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13462 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13463 // We know that the last operand is not a constant zero (otherwise it would
13464 // have been popped out earlier). This guarantees us that if the result has
13465 // the same last operand, then it will also not be popped out, meaning that
13466 // the returned value will be an AddRec.
13467 const SCEV *Last = getOperand(getNumOperands() - 1);
13468 assert(!Last->isZero() && "Recurrence with zero step?");
13469 Ops.push_back(Last);
13470 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13472}
13473
13474// Return true when S contains at least an undef value.
13476 return SCEVExprContains(S, [](const SCEV *S) {
13477 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13478 return isa<UndefValue>(SU->getValue());
13479 return false;
13480 });
13481}
13482
13483// Return true when S contains a value that is a nullptr.
13485 return SCEVExprContains(S, [](const SCEV *S) {
13486 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13487 return SU->getValue() == nullptr;
13488 return false;
13489 });
13490}
13491
13492/// Return the size of an element read or written by Inst.
13494 Type *Ty;
13495 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13496 Ty = Store->getValueOperand()->getType();
13497 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13498 Ty = Load->getType();
13499 else
13500 return nullptr;
13501
13503 return getSizeOfExpr(ETy, Ty);
13504}
13505
13506//===----------------------------------------------------------------------===//
13507// SCEVCallbackVH Class Implementation
13508//===----------------------------------------------------------------------===//
13509
13510void ScalarEvolution::SCEVCallbackVH::deleted() {
13511 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13512 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13513 SE->ConstantEvolutionLoopExitValue.erase(PN);
13514 SE->eraseValueFromMap(getValPtr());
13515 // this now dangles!
13516}
13517
13518void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13519 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13520
13521 // Forget all the expressions associated with users of the old value,
13522 // so that future queries will recompute the expressions using the new
13523 // value.
13524 SE->forgetValue(getValPtr());
13525 // this now dangles!
13526}
13527
13528ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13529 : CallbackVH(V), SE(se) {}
13530
13531//===----------------------------------------------------------------------===//
13532// ScalarEvolution Class Implementation
13533//===----------------------------------------------------------------------===//
13534
13535ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13536 AssumptionCache &AC, DominatorTree &DT,
13537 LoopInfo &LI)
13538 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13539 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13540 LoopDispositions(64), BlockDispositions(64) {
13541 // To use guards for proving predicates, we need to scan every instruction in
13542 // relevant basic blocks, and not just terminators. Doing this is a waste of
13543 // time if the IR does not actually contain any calls to
13544 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13545 //
13546 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13547 // to _add_ guards to the module when there weren't any before, and wants
13548 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13549 // efficient in lieu of being smart in that rather obscure case.
13550
13551 auto *GuardDecl = F.getParent()->getFunction(
13552 Intrinsic::getName(Intrinsic::experimental_guard));
13553 HasGuards = GuardDecl && !GuardDecl->use_empty();
13554}
13555
13556ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13557 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13558 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13559 ValueExprMap(std::move(Arg.ValueExprMap)),
13560 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13561 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13562 PendingMerges(std::move(Arg.PendingMerges)),
13563 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13564 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13565 PredicatedBackedgeTakenCounts(
13566 std::move(Arg.PredicatedBackedgeTakenCounts)),
13567 BECountUsers(std::move(Arg.BECountUsers)),
13568 ConstantEvolutionLoopExitValue(
13569 std::move(Arg.ConstantEvolutionLoopExitValue)),
13570 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13571 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13572 LoopDispositions(std::move(Arg.LoopDispositions)),
13573 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13574 BlockDispositions(std::move(Arg.BlockDispositions)),
13575 SCEVUsers(std::move(Arg.SCEVUsers)),
13576 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13577 SignedRanges(std::move(Arg.SignedRanges)),
13578 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13579 UniquePreds(std::move(Arg.UniquePreds)),
13580 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13581 LoopUsers(std::move(Arg.LoopUsers)),
13582 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13583 FirstUnknown(Arg.FirstUnknown) {
13584 Arg.FirstUnknown = nullptr;
13585}
13586
13587ScalarEvolution::~ScalarEvolution() {
13588 // Iterate through all the SCEVUnknown instances and call their
13589 // destructors, so that they release their references to their values.
13590 for (SCEVUnknown *U = FirstUnknown; U;) {
13591 SCEVUnknown *Tmp = U;
13592 U = U->Next;
13593 Tmp->~SCEVUnknown();
13594 }
13595 FirstUnknown = nullptr;
13596
13597 ExprValueMap.clear();
13598 ValueExprMap.clear();
13599 HasRecMap.clear();
13600 BackedgeTakenCounts.clear();
13601 PredicatedBackedgeTakenCounts.clear();
13602
13603 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13604 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13605 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13606 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13607 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13608}
13609
13610bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13611 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13612}
13613
13614/// When printing a top-level SCEV for trip counts, it's helpful to include
13615/// a type for constants which are otherwise hard to disambiguate.
13616static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13617 if (isa<SCEVConstant>(S))
13618 OS << *S->getType() << " ";
13619 OS << *S;
13620}
13621
13622static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13623 const Loop *L) {
13624 // Print all inner loops first
13625 for (Loop *I : *L)
13626 PrintLoopInfo(OS, SE, I);
13627
13628 OS << "Loop ";
13629 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13630 OS << ": ";
13631
13632 SmallVector<BasicBlock *, 8> ExitingBlocks;
13633 L->getExitingBlocks(ExitingBlocks);
13634 if (ExitingBlocks.size() != 1)
13635 OS << "<multiple exits> ";
13636
13637 auto *BTC = SE->getBackedgeTakenCount(L);
13638 if (!isa<SCEVCouldNotCompute>(BTC)) {
13639 OS << "backedge-taken count is ";
13640 PrintSCEVWithTypeHint(OS, BTC);
13641 } else
13642 OS << "Unpredictable backedge-taken count.";
13643 OS << "\n";
13644
13645 if (ExitingBlocks.size() > 1)
13646 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13647 OS << " exit count for " << ExitingBlock->getName() << ": ";
13648 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13649 OS << "\n";
13650 }
13651
13652 OS << "Loop ";
13653 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13654 OS << ": ";
13655
13656 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13657 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13658 OS << "constant max backedge-taken count is ";
13659 PrintSCEVWithTypeHint(OS, ConstantBTC);
13661 OS << ", actual taken count either this or zero.";
13662 } else {
13663 OS << "Unpredictable constant max backedge-taken count. ";
13664 }
13665
13666 OS << "\n"
13667 "Loop ";
13668 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13669 OS << ": ";
13670
13671 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13672 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13673 OS << "symbolic max backedge-taken count is ";
13674 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13676 OS << ", actual taken count either this or zero.";
13677 } else {
13678 OS << "Unpredictable symbolic max backedge-taken count. ";
13679 }
13680 OS << "\n";
13681
13682 if (ExitingBlocks.size() > 1)
13683 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13684 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13685 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13687 PrintSCEVWithTypeHint(OS, ExitBTC);
13688 OS << "\n";
13689 }
13690
13691 SmallVector<const SCEVPredicate *, 4> Preds;
13692 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13693 if (PBT != BTC || !Preds.empty()) {
13694 OS << "Loop ";
13695 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13696 OS << ": ";
13697 if (!isa<SCEVCouldNotCompute>(PBT)) {
13698 OS << "Predicated backedge-taken count is ";
13699 PrintSCEVWithTypeHint(OS, PBT);
13700 } else
13701 OS << "Unpredictable predicated backedge-taken count.";
13702 OS << "\n";
13703 OS << " Predicates:\n";
13704 for (const auto *P : Preds)
13705 P->print(OS, 4);
13706 }
13707
13708 Preds.clear();
13709 auto *PredSymbolicMax =
13711 if (SymbolicBTC != PredSymbolicMax) {
13712 OS << "Loop ";
13713 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13714 OS << ": ";
13715 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13716 OS << "Predicated symbolic max backedge-taken count is ";
13717 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13718 } else
13719 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13720 OS << "\n";
13721 OS << " Predicates:\n";
13722 for (const auto *P : Preds)
13723 P->print(OS, 4);
13724 }
13725
13726 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13727 OS << "Loop ";
13728 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13729 OS << ": ";
13730 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13731 }
13732}
13733
13734namespace llvm {
13735raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13736 switch (LD) {
13737 case ScalarEvolution::LoopVariant:
13738 OS << "Variant";
13739 break;
13740 case ScalarEvolution::LoopInvariant:
13741 OS << "Invariant";
13742 break;
13743 case ScalarEvolution::LoopComputable:
13744 OS << "Computable";
13745 break;
13746 }
13747 return OS;
13748}
13749
13750raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13751 switch (BD) {
13752 case ScalarEvolution::DoesNotDominateBlock:
13753 OS << "DoesNotDominate";
13754 break;
13755 case ScalarEvolution::DominatesBlock:
13756 OS << "Dominates";
13757 break;
13758 case ScalarEvolution::ProperlyDominatesBlock:
13759 OS << "ProperlyDominates";
13760 break;
13761 }
13762 return OS;
13763}
13764} // namespace llvm
13765
13766void ScalarEvolution::print(raw_ostream &OS) const {
13767 // ScalarEvolution's implementation of the print method is to print
13768 // out SCEV values of all instructions that are interesting. Doing
13769 // this potentially causes it to create new SCEV objects though,
13770 // which technically conflicts with the const qualifier. This isn't
13771 // observable from outside the class though, so casting away the
13772 // const isn't dangerous.
13773 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13774
13775 if (ClassifyExpressions) {
13776 OS << "Classifying expressions for: ";
13777 F.printAsOperand(OS, /*PrintType=*/false);
13778 OS << "\n";
13779 for (Instruction &I : instructions(F))
13780 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13781 OS << I << '\n';
13782 OS << " --> ";
13783 const SCEV *SV = SE.getSCEV(&I);
13784 SV->print(OS);
13785 if (!isa<SCEVCouldNotCompute>(SV)) {
13786 OS << " U: ";
13787 SE.getUnsignedRange(SV).print(OS);
13788 OS << " S: ";
13789 SE.getSignedRange(SV).print(OS);
13790 }
13791
13792 const Loop *L = LI.getLoopFor(I.getParent());
13793
13794 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13795 if (AtUse != SV) {
13796 OS << " --> ";
13797 AtUse->print(OS);
13798 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13799 OS << " U: ";
13800 SE.getUnsignedRange(AtUse).print(OS);
13801 OS << " S: ";
13802 SE.getSignedRange(AtUse).print(OS);
13803 }
13804 }
13805
13806 if (L) {
13807 OS << "\t\t" "Exits: ";
13808 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13809 if (!SE.isLoopInvariant(ExitValue, L)) {
13810 OS << "<<Unknown>>";
13811 } else {
13812 OS << *ExitValue;
13813 }
13814
13815 bool First = true;
13816 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13817 if (First) {
13818 OS << "\t\t" "LoopDispositions: { ";
13819 First = false;
13820 } else {
13821 OS << ", ";
13822 }
13823
13824 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13825 OS << ": " << SE.getLoopDisposition(SV, Iter);
13826 }
13827
13828 for (const auto *InnerL : depth_first(L)) {
13829 if (InnerL == L)
13830 continue;
13831 if (First) {
13832 OS << "\t\t" "LoopDispositions: { ";
13833 First = false;
13834 } else {
13835 OS << ", ";
13836 }
13837
13838 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13839 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13840 }
13841
13842 OS << " }";
13843 }
13844
13845 OS << "\n";
13846 }
13847 }
13848
13849 OS << "Determining loop execution counts for: ";
13850 F.printAsOperand(OS, /*PrintType=*/false);
13851 OS << "\n";
13852 for (Loop *I : LI)
13853 PrintLoopInfo(OS, &SE, I);
13854}
13855
13856ScalarEvolution::LoopDisposition
13857ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13858 auto &Values = LoopDispositions[S];
13859 for (auto &V : Values) {
13860 if (V.getPointer() == L)
13861 return V.getInt();
13862 }
13863 Values.emplace_back(L, LoopVariant);
13864 LoopDisposition D = computeLoopDisposition(S, L);
13865 auto &Values2 = LoopDispositions[S];
13866 for (auto &V : llvm::reverse(Values2)) {
13867 if (V.getPointer() == L) {
13868 V.setInt(D);
13869 break;
13870 }
13871 }
13872 return D;
13873}
13874
13875ScalarEvolution::LoopDisposition
13876ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13877 switch (S->getSCEVType()) {
13878 case scConstant:
13879 case scVScale:
13880 return LoopInvariant;
13881 case scAddRecExpr: {
13882 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13883
13884 // If L is the addrec's loop, it's computable.
13885 if (AR->getLoop() == L)
13886 return LoopComputable;
13887
13888 // Add recurrences are never invariant in the function-body (null loop).
13889 if (!L)
13890 return LoopVariant;
13891
13892 // Everything that is not defined at loop entry is variant.
13893 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13894 return LoopVariant;
13895 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13896 " dominate the contained loop's header?");
13897
13898 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13899 if (AR->getLoop()->contains(L))
13900 return LoopInvariant;
13901
13902 // This recurrence is variant w.r.t. L if any of its operands
13903 // are variant.
13904 for (const auto *Op : AR->operands())
13905 if (!isLoopInvariant(Op, L))
13906 return LoopVariant;
13907
13908 // Otherwise it's loop-invariant.
13909 return LoopInvariant;
13910 }
13911 case scTruncate:
13912 case scZeroExtend:
13913 case scSignExtend:
13914 case scPtrToInt:
13915 case scAddExpr:
13916 case scMulExpr:
13917 case scUDivExpr:
13918 case scUMaxExpr:
13919 case scSMaxExpr:
13920 case scUMinExpr:
13921 case scSMinExpr:
13922 case scSequentialUMinExpr: {
13923 bool HasVarying = false;
13924 for (const auto *Op : S->operands()) {
13925 LoopDisposition D = getLoopDisposition(Op, L);
13926 if (D == LoopVariant)
13927 return LoopVariant;
13928 if (D == LoopComputable)
13929 HasVarying = true;
13930 }
13931 return HasVarying ? LoopComputable : LoopInvariant;
13932 }
13933 case scUnknown:
13934 // All non-instruction values are loop invariant. All instructions are loop
13935 // invariant if they are not contained in the specified loop.
13936 // Instructions are never considered invariant in the function body
13937 // (null loop) because they are defined within the "loop".
13938 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13939 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13940 return LoopInvariant;
13941 case scCouldNotCompute:
13942 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13943 }
13944 llvm_unreachable("Unknown SCEV kind!");
13945}
13946
13947bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13948 return getLoopDisposition(S, L) == LoopInvariant;
13949}
13950
13951bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13952 return getLoopDisposition(S, L) == LoopComputable;
13953}
13954
13955ScalarEvolution::BlockDisposition
13956ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13957 auto &Values = BlockDispositions[S];
13958 for (auto &V : Values) {
13959 if (V.getPointer() == BB)
13960 return V.getInt();
13961 }
13962 Values.emplace_back(BB, DoesNotDominateBlock);
13963 BlockDisposition D = computeBlockDisposition(S, BB);
13964 auto &Values2 = BlockDispositions[S];
13965 for (auto &V : llvm::reverse(Values2)) {
13966 if (V.getPointer() == BB) {
13967 V.setInt(D);
13968 break;
13969 }
13970 }
13971 return D;
13972}
13973
13974ScalarEvolution::BlockDisposition
13975ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13976 switch (S->getSCEVType()) {
13977 case scConstant:
13978 case scVScale:
13979 return ProperlyDominatesBlock;
13980 case scAddRecExpr: {
13981 // This uses a "dominates" query instead of "properly dominates" query
13982 // to test for proper dominance too, because the instruction which
13983 // produces the addrec's value is a PHI, and a PHI effectively properly
13984 // dominates its entire containing block.
13985 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13986 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13987 return DoesNotDominateBlock;
13988
13989 // Fall through into SCEVNAryExpr handling.
13990 [[fallthrough]];
13991 }
13992 case scTruncate:
13993 case scZeroExtend:
13994 case scSignExtend:
13995 case scPtrToInt:
13996 case scAddExpr:
13997 case scMulExpr:
13998 case scUDivExpr:
13999 case scUMaxExpr:
14000 case scSMaxExpr:
14001 case scUMinExpr:
14002 case scSMinExpr:
14003 case scSequentialUMinExpr: {
14004 bool Proper = true;
14005 for (const SCEV *NAryOp : S->operands()) {
14006 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14007 if (D == DoesNotDominateBlock)
14008 return DoesNotDominateBlock;
14009 if (D == DominatesBlock)
14010 Proper = false;
14011 }
14012 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14013 }
14014 case scUnknown:
14015 if (Instruction *I =
14016 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14017 if (I->getParent() == BB)
14018 return DominatesBlock;
14019 if (DT.properlyDominates(I->getParent(), BB))
14020 return ProperlyDominatesBlock;
14021 return DoesNotDominateBlock;
14022 }
14023 return ProperlyDominatesBlock;
14024 case scCouldNotCompute:
14025 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14026 }
14027 llvm_unreachable("Unknown SCEV kind!");
14028}
14029
14030bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14031 return getBlockDisposition(S, BB) >= DominatesBlock;
14032}
14033
14034bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14035 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14036}
14037
14038bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14039 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14040}
14041
14042void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14043 bool Predicated) {
14044 auto &BECounts =
14045 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14046 auto It = BECounts.find(L);
14047 if (It != BECounts.end()) {
14048 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14049 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14050 if (!isa<SCEVConstant>(S)) {
14051 auto UserIt = BECountUsers.find(S);
14052 assert(UserIt != BECountUsers.end());
14053 UserIt->second.erase({L, Predicated});
14054 }
14055 }
14056 }
14057 BECounts.erase(It);
14058 }
14059}
14060
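// forgetMemoizedResults drops every piece of cached information for the given
// SCEVs and, transitively, for every expression that uses them: the worklist
// below walks the SCEVUsers graph, so invalidating %x also invalidates cached
// data for expressions built on top of it, such as (%x + 4) or {%x,+,1}.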
14061void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14062 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14063 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14064
14065 while (!Worklist.empty()) {
14066 const SCEV *Curr = Worklist.pop_back_val();
14067 auto Users = SCEVUsers.find(Curr);
14068 if (Users != SCEVUsers.end())
14069 for (const auto *User : Users->second)
14070 if (ToForget.insert(User).second)
14071 Worklist.push_back(User);
14072 }
14073
14074 for (const auto *S : ToForget)
14075 forgetMemoizedResultsImpl(S);
14076
14077 for (auto I = PredicatedSCEVRewrites.begin();
14078 I != PredicatedSCEVRewrites.end();) {
14079 std::pair<const SCEV *, const Loop *> Entry = I->first;
14080 if (ToForget.count(Entry.first))
14081 PredicatedSCEVRewrites.erase(I++);
14082 else
14083 ++I;
14084 }
14085}
14086
14087void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14088 LoopDispositions.erase(S);
14089 BlockDispositions.erase(S);
14090 UnsignedRanges.erase(S);
14091 SignedRanges.erase(S);
14092 HasRecMap.erase(S);
14093 ConstantMultipleCache.erase(S);
14094
14095 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14096 UnsignedWrapViaInductionTried.erase(AR);
14097 SignedWrapViaInductionTried.erase(AR);
14098 }
14099
14100 auto ExprIt = ExprValueMap.find(S);
14101 if (ExprIt != ExprValueMap.end()) {
14102 for (Value *V : ExprIt->second) {
14103 auto ValueIt = ValueExprMap.find_as(V);
14104 if (ValueIt != ValueExprMap.end())
14105 ValueExprMap.erase(ValueIt);
14106 }
14107 ExprValueMap.erase(ExprIt);
14108 }
14109
14110 auto ScopeIt = ValuesAtScopes.find(S);
14111 if (ScopeIt != ValuesAtScopes.end()) {
14112 for (const auto &Pair : ScopeIt->second)
14113 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14114 llvm::erase(ValuesAtScopesUsers[Pair.second],
14115 std::make_pair(Pair.first, S));
14116 ValuesAtScopes.erase(ScopeIt);
14117 }
14118
14119 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14120 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14121 for (const auto &Pair : ScopeUserIt->second)
14122 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14123 ValuesAtScopesUsers.erase(ScopeUserIt);
14124 }
14125
14126 auto BEUsersIt = BECountUsers.find(S);
14127 if (BEUsersIt != BECountUsers.end()) {
14128 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14129 auto Copy = BEUsersIt->second;
14130 for (const auto &Pair : Copy)
14131 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14132 BECountUsers.erase(BEUsersIt);
14133 }
14134
14135 auto FoldUser = FoldCacheUser.find(S);
14136 if (FoldUser != FoldCacheUser.end())
14137 for (auto &KV : FoldUser->second)
14138 FoldCache.erase(KV);
14139 FoldCacheUser.erase(S);
14140}
14141
14142void
14143ScalarEvolution::getUsedLoops(const SCEV *S,
14144 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14145 struct FindUsedLoops {
14146 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14147 : LoopsUsed(LoopsUsed) {}
14148 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14149 bool follow(const SCEV *S) {
14150 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14151 LoopsUsed.insert(AR->getLoop());
14152 return true;
14153 }
14154
14155 bool isDone() const { return false; }
14156 };
14157
14158 FindUsedLoops F(LoopsUsed);
14159 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14160}
14161
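// getReachableBlocks computes a conservative reachability set for the
// verifier. It follows branch conditions that are constant or that SCEV can
// already decide via constant ranges, so blocks that are provably dead are
// excluded from verification.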
14162void ScalarEvolution::getReachableBlocks(
14163 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14164 SmallVector<BasicBlock *> Worklist;
14165 Worklist.push_back(&F.getEntryBlock());
14166 while (!Worklist.empty()) {
14167 BasicBlock *BB = Worklist.pop_back_val();
14168 if (!Reachable.insert(BB).second)
14169 continue;
14170
14171 Value *Cond;
14172 BasicBlock *TrueBB, *FalseBB;
14173 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14174 m_BasicBlock(FalseBB)))) {
14175 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14176 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14177 continue;
14178 }
14179
14180 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14181 const SCEV *L = getSCEV(Cmp->getOperand(0));
14182 const SCEV *R = getSCEV(Cmp->getOperand(1));
14183 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14184 Worklist.push_back(TrueBB);
14185 continue;
14186 }
14187 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14188 R)) {
14189 Worklist.push_back(FalseBB);
14190 continue;
14191 }
14192 }
14193 }
14194
14195 append_range(Worklist, successors(BB));
14196 }
14197}
14198
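// verify() cross-checks the cached state of this ScalarEvolution against a
// freshly computed instance (SE2). Cached expressions are translated into the
// new "universe" with SCEVMapper, and any disagreement in trip counts,
// per-value SCEVs, dispositions, or the various user/use maps aborts the
// compiler, since it indicates a missed invalidation.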
14199void ScalarEvolution::verify() const {
14200 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14201 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14202
14203 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14204
14205 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14206 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14207 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14208
14209 const SCEV *visitConstant(const SCEVConstant *Constant) {
14210 return SE.getConstant(Constant->getAPInt());
14211 }
14212
14213 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14214 return SE.getUnknown(Expr->getValue());
14215 }
14216
14217 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14218 return SE.getCouldNotCompute();
14219 }
14220 };
14221
14222 SCEVMapper SCM(SE2);
14223 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14224 SE2.getReachableBlocks(ReachableBlocks, F);
14225
14226 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14227 if (containsUndefs(Old) || containsUndefs(New)) {
14228 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14229 // not propagate undef aggressively). This means we can (and do) fail
14230 // verification in cases where a transform makes a value go from "undef"
14231 // to "undef+1" (say). The transform is fine, since in both cases the
14232 // result is "undef", but SCEV thinks the value increased by 1.
14233 return nullptr;
14234 }
14235
14236 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14237 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14238 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14239 return nullptr;
14240
14241 return Delta;
14242 };
14243
14244 while (!LoopStack.empty()) {
14245 auto *L = LoopStack.pop_back_val();
14246 llvm::append_range(LoopStack, *L);
14247
14248 // Only verify BECounts in reachable loops. For an unreachable loop,
14249 // any BECount is legal.
14250 if (!ReachableBlocks.contains(L->getHeader()))
14251 continue;
14252
14253 // Only verify cached BECounts. Computing new BECounts may change the
14254 // results of subsequent SCEV uses.
14255 auto It = BackedgeTakenCounts.find(L);
14256 if (It == BackedgeTakenCounts.end())
14257 continue;
14258
14259 auto *CurBECount =
14260 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14261 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14262
14263 if (CurBECount == SE2.getCouldNotCompute() ||
14264 NewBECount == SE2.getCouldNotCompute()) {
14265 // NB! This situation is legal, but is very suspicious -- whatever pass
14266 // changed the loop to make a trip count go from could not compute to
14267 // computable or vice-versa *should have* invalidated SCEV. However, we
14268 // choose not to assert here (for now) since we don't want false
14269 // positives.
14270 continue;
14271 }
14272
14273 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14274 SE.getTypeSizeInBits(NewBECount->getType()))
14275 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14276 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14277 SE.getTypeSizeInBits(NewBECount->getType()))
14278 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14279
14280 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14281 if (Delta && !Delta->isZero()) {
14282 dbgs() << "Trip Count for " << *L << " Changed!\n";
14283 dbgs() << "Old: " << *CurBECount << "\n";
14284 dbgs() << "New: " << *NewBECount << "\n";
14285 dbgs() << "Delta: " << *Delta << "\n";
14286 std::abort();
14287 }
14288 }
14289
14290 // Collect all valid loops currently in LoopInfo.
14291 SmallPtrSet<Loop *, 32> ValidLoops;
14292 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14293 while (!Worklist.empty()) {
14294 Loop *L = Worklist.pop_back_val();
14295 if (ValidLoops.insert(L).second)
14296 Worklist.append(L->begin(), L->end());
14297 }
14298 for (const auto &KV : ValueExprMap) {
14299#ifndef NDEBUG
14300 // Check for SCEV expressions referencing invalid/deleted loops.
14301 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14302 assert(ValidLoops.contains(AR->getLoop()) &&
14303 "AddRec references invalid loop");
14304 }
14305#endif
14306
14307 // Check that the value is also part of the reverse map.
14308 auto It = ExprValueMap.find(KV.second);
14309 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14310 dbgs() << "Value " << *KV.first
14311 << " is in ValueExprMap but not in ExprValueMap\n";
14312 std::abort();
14313 }
14314
14315 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14316 if (!ReachableBlocks.contains(I->getParent()))
14317 continue;
14318 const SCEV *OldSCEV = SCM.visit(KV.second);
14319 const SCEV *NewSCEV = SE2.getSCEV(I);
14320 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14321 if (Delta && !Delta->isZero()) {
14322 dbgs() << "SCEV for value " << *I << " changed!\n"
14323 << "Old: " << *OldSCEV << "\n"
14324 << "New: " << *NewSCEV << "\n"
14325 << "Delta: " << *Delta << "\n";
14326 std::abort();
14327 }
14328 }
14329 }
14330
14331 for (const auto &KV : ExprValueMap) {
14332 for (Value *V : KV.second) {
14333 auto It = ValueExprMap.find_as(V);
14334 if (It == ValueExprMap.end()) {
14335 dbgs() << "Value " << *V
14336 << " is in ExprValueMap but not in ValueExprMap\n";
14337 std::abort();
14338 }
14339 if (It->second != KV.first) {
14340 dbgs() << "Value " << *V << " mapped to " << *It->second
14341 << " rather than " << *KV.first << "\n";
14342 std::abort();
14343 }
14344 }
14345 }
14346
14347 // Verify integrity of SCEV users.
14348 for (const auto &S : UniqueSCEVs) {
14349 for (const auto *Op : S.operands()) {
14350 // We do not store dependencies of constants.
14351 if (isa<SCEVConstant>(Op))
14352 continue;
14353 auto It = SCEVUsers.find(Op);
14354 if (It != SCEVUsers.end() && It->second.count(&S))
14355 continue;
14356 dbgs() << "Use of operand " << *Op << " by user " << S
14357 << " is not being tracked!\n";
14358 std::abort();
14359 }
14360 }
14361
14362 // Verify integrity of ValuesAtScopes users.
14363 for (const auto &ValueAndVec : ValuesAtScopes) {
14364 const SCEV *Value = ValueAndVec.first;
14365 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14366 const Loop *L = LoopAndValueAtScope.first;
14367 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14368 if (!isa<SCEVConstant>(ValueAtScope)) {
14369 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14370 if (It != ValuesAtScopesUsers.end() &&
14371 is_contained(It->second, std::make_pair(L, Value)))
14372 continue;
14373 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14374 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14375 std::abort();
14376 }
14377 }
14378 }
14379
14380 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14381 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14382 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14383 const Loop *L = LoopAndValue.first;
14384 const SCEV *Value = LoopAndValue.second;
14385 assert(!isa<SCEVConstant>(Value));
14386 auto It = ValuesAtScopes.find(Value);
14387 if (It != ValuesAtScopes.end() &&
14388 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14389 continue;
14390 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14391 << *ValueAtScope << " missing in ValuesAtScopes\n";
14392 std::abort();
14393 }
14394 }
14395
14396 // Verify integrity of BECountUsers.
14397 auto VerifyBECountUsers = [&](bool Predicated) {
14398 auto &BECounts =
14399 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14400 for (const auto &LoopAndBEInfo : BECounts) {
14401 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14402 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14403 if (!isa<SCEVConstant>(S)) {
14404 auto UserIt = BECountUsers.find(S);
14405 if (UserIt != BECountUsers.end() &&
14406 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14407 continue;
14408 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14409 << " missing from BECountUsers\n";
14410 std::abort();
14411 }
14412 }
14413 }
14414 }
14415 };
14416 VerifyBECountUsers(/* Predicated */ false);
14417 VerifyBECountUsers(/* Predicated */ true);
14418
14419 // Verify integrity of the loop disposition cache.
14420 for (auto &[S, Values] : LoopDispositions) {
14421 for (auto [Loop, CachedDisposition] : Values) {
14422 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14423 if (CachedDisposition != RecomputedDisposition) {
14424 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14425 << " is incorrect: cached " << CachedDisposition << ", actual "
14426 << RecomputedDisposition << "\n";
14427 std::abort();
14428 }
14429 }
14430 }
14431
14432 // Verify integrity of the block disposition cache.
14433 for (auto &[S, Values] : BlockDispositions) {
14434 for (auto [BB, CachedDisposition] : Values) {
14435 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14436 if (CachedDisposition != RecomputedDisposition) {
14437 dbgs() << "Cached disposition of " << *S << " for block %"
14438 << BB->getName() << " is incorrect: cached " << CachedDisposition
14439 << ", actual " << RecomputedDisposition << "\n";
14440 std::abort();
14441 }
14442 }
14443 }
14444
14445 // Verify FoldCache/FoldCacheUser caches.
14446 for (auto [FoldID, Expr] : FoldCache) {
14447 auto I = FoldCacheUser.find(Expr);
14448 if (I == FoldCacheUser.end()) {
14449 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14450 << "!\n";
14451 std::abort();
14452 }
14453 if (!is_contained(I->second, FoldID)) {
14454 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14455 std::abort();
14456 }
14457 }
14458 for (auto [Expr, IDs] : FoldCacheUser) {
14459 for (auto &FoldID : IDs) {
14460 auto I = FoldCache.find(FoldID);
14461 if (I == FoldCache.end()) {
14462 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14463 << "!\n";
14464 std::abort();
14465 }
14466 if (I->second != Expr) {
14467 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14468 << *I->second << " != " << *Expr << "!\n";
14469 std::abort();
14470 }
14471 }
14472 }
14473
14474 // Verify that ConstantMultipleCache computations are correct. We check that
14475 // cached multiples and recomputed multiples are multiples of each other to
14476 // verify correctness. It is possible that a recomputed multiple is different
14477 // from the cached multiple due to strengthened no wrap flags or changes in
14478 // KnownBits computations.
14479 for (auto [S, Multiple] : ConstantMultipleCache) {
14480 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14481 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14482 Multiple.urem(RecomputedMultiple) != 0 &&
14483 RecomputedMultiple.urem(Multiple) != 0)) {
14484 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14485 << *S << " : Computed " << RecomputedMultiple
14486 << " but cache contains " << Multiple << "!\n";
14487 std::abort();
14488 }
14489 }
14490}
14491
14492bool ScalarEvolution::invalidate(
14493 Function &F, const PreservedAnalyses &PA,
14494 FunctionAnalysisManager::Invalidator &Inv) {
14495 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14496 // of its dependencies is invalidated.
14497 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14498 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14499 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14500 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14501 Inv.invalidate<LoopAnalysis>(F, PA);
14502}
14503
14504AnalysisKey ScalarEvolutionAnalysis::Key;
14505
14506ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14507 FunctionAnalysisManager &AM) {
14508 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14509 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14510 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14511 auto &LI = AM.getResult<LoopAnalysis>(F);
14512 return ScalarEvolution(F, TLI, AC, DT, LI);
14513}
14514
14515PreservedAnalyses
14516ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14517 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14518 return PreservedAnalyses::all();
14519}
14520
14521PreservedAnalyses
14522ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14523 // For compatibility with opt's -analyze feature under legacy pass manager
14524 // which was not ported to NPM. This keeps tests using
14525 // update_analyze_test_checks.py working.
14526 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14527 << F.getName() << "':\n";
14528 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14529 return PreservedAnalyses::all();
14530}
14531
14532INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14533 "Scalar Evolution Analysis", false, true)
14534INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14535INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14536INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14537INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14538INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14539 "Scalar Evolution Analysis", false, true)
14540
14541char ScalarEvolutionWrapperPass::ID = 0;
14542
14543ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14544 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14545}
14546
14547bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14548 SE.reset(new ScalarEvolution(
14549 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14550 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14551 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14552 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14553 return false;
14554}
14555
14556void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14557
14558void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14559 SE->print(OS);
14560}
14561
14562void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14563 if (!VerifySCEV)
14564 return;
14565
14566 SE->verify();
14567}
14568
14569void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14570 AU.setPreservesAll();
14571 AU.addRequiredTransitive<AssumptionCacheTracker>();
14572 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14573 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14574 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14575}
14576
14577const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14578 const SCEV *RHS) {
14579 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14580}
14581
14582const SCEVPredicate *
14583ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14584 const SCEV *LHS, const SCEV *RHS) {
14585 FoldingSetNodeID ID;
14586 assert(LHS->getType() == RHS->getType() &&
14587 "Type mismatch between LHS and RHS");
14588 // Unique this node based on the arguments
14589 ID.AddInteger(SCEVPredicate::P_Compare);
14590 ID.AddInteger(Pred);
14591 ID.AddPointer(LHS);
14592 ID.AddPointer(RHS);
14593 void *IP = nullptr;
14594 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14595 return S;
14596 SCEVComparePredicate *Eq = new (SCEVAllocator)
14597 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14598 UniquePreds.InsertNode(Eq, IP);
14599 return Eq;
14600}
14601
14602const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14603 const SCEVAddRecExpr *AR,
14604 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14605 FoldingSetNodeID ID;
14606 // Unique this node based on the arguments
14607 ID.AddInteger(SCEVPredicate::P_Wrap);
14608 ID.AddPointer(AR);
14609 ID.AddInteger(AddedFlags);
14610 void *IP = nullptr;
14611 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14612 return S;
14613 auto *OF = new (SCEVAllocator)
14614 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14615 UniquePreds.InsertNode(OF, IP);
14616 return OF;
14617}
14618
14619namespace {
14620
14621class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14622public:
14623
14624 /// Rewrites \p S in the context of a loop L and the SCEV predication
14625 /// infrastructure.
14626 ///
14627 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14628 /// equivalences present in \p Pred.
14629 ///
14630 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14631 /// \p NewPreds such that the result will be an AddRecExpr.
14632 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14633 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14634 const SCEVPredicate *Pred) {
14635 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14636 return Rewriter.visit(S);
14637 }
14638
14639 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14640 if (Pred) {
14641 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14642 for (const auto *Pred : U->getPredicates())
14643 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14644 if (IPred->getLHS() == Expr &&
14645 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14646 return IPred->getRHS();
14647 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14648 if (IPred->getLHS() == Expr &&
14649 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14650 return IPred->getRHS();
14651 }
14652 }
14653 return convertToAddRecWithPreds(Expr);
14654 }
14655
14656 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14657 const SCEV *Operand = visit(Expr->getOperand());
14658 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14659 if (AR && AR->getLoop() == L && AR->isAffine()) {
14660 // This couldn't be folded because the operand didn't have the nuw
14661 // flag. Add the nusw flag as an assumption that we could make.
14662 const SCEV *Step = AR->getStepRecurrence(SE);
14663 Type *Ty = Expr->getType();
14664 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14665 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14666 SE.getSignExtendExpr(Step, Ty), L,
14667 AR->getNoWrapFlags());
14668 }
14669 return SE.getZeroExtendExpr(Operand, Expr->getType());
14670 }
14671
14672 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14673 const SCEV *Operand = visit(Expr->getOperand());
14674 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14675 if (AR && AR->getLoop() == L && AR->isAffine()) {
14676 // This couldn't be folded because the operand didn't have the nsw
14677 // flag. Add the nssw flag as an assumption that we could make.
14678 const SCEV *Step = AR->getStepRecurrence(SE);
14679 Type *Ty = Expr->getType();
14680 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14681 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14682 SE.getSignExtendExpr(Step, Ty), L,
14683 AR->getNoWrapFlags());
14684 }
14685 return SE.getSignExtendExpr(Operand, Expr->getType());
14686 }
14687
14688private:
14689 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14690 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14691 const SCEVPredicate *Pred)
14692 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14693
14694 bool addOverflowAssumption(const SCEVPredicate *P) {
14695 if (!NewPreds) {
14696 // Check if we've already made this assumption.
14697 return Pred && Pred->implies(P);
14698 }
14699 NewPreds->insert(P);
14700 return true;
14701 }
14702
14703 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14704 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14705 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14706 return addOverflowAssumption(A);
14707 }
14708
14709 // If \p Expr represents a PHINode, we try to see if it can be represented
14710 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14711 // to add this predicate as a runtime overflow check, we return the AddRec.
14712 // If \p Expr does not meet these conditions (is not a PHI node, or we
14713 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14714 // return \p Expr.
14715 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14716 if (!isa<PHINode>(Expr->getValue()))
14717 return Expr;
14718 std::optional<
14719 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14720 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14721 if (!PredicatedRewrite)
14722 return Expr;
14723 for (const auto *P : PredicatedRewrite->second){
14724 // Wrap predicates from outer loops are not supported.
14725 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14726 if (L != WP->getExpr()->getLoop())
14727 return Expr;
14728 }
14729 if (!addOverflowAssumption(P))
14730 return Expr;
14731 }
14732 return PredicatedRewrite->first;
14733 }
14734
14735 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14736 const SCEVPredicate *Pred;
14737 const Loop *L;
14738};
14739
14740} // end anonymous namespace
14741
14742const SCEV *
14743ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14744 const SCEVPredicate &Preds) {
14745 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14746}
14747
14748const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14749 const SCEV *S, const Loop *L,
14750 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14751 SmallVector<const SCEVPredicate *, 4> TransformPreds;
14752 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14753 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14754
14755 if (!AddRec)
14756 return nullptr;
14757
14758 // Since the transformation was successful, we can now transfer the SCEV
14759 // predicates.
14760 for (const auto *P : TransformPreds)
14761 Preds.insert(P);
14762
14763 return AddRec;
14764}
14765
14766/// SCEV predicates
14767SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14768 SCEVPredicateKind Kind)
14769 : FastID(ID), Kind(Kind) {}
14770
14771SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14772 const ICmpInst::Predicate Pred,
14773 const SCEV *LHS, const SCEV *RHS)
14774 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14775 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14776 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14777}
14778
14779bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14780 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14781
14782 if (!Op)
14783 return false;
14784
14785 if (Pred != ICmpInst::ICMP_EQ)
14786 return false;
14787
14788 return Op->LHS == LHS && Op->RHS == RHS;
14789}
14790
14791bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14792
14793void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14794 if (Pred == ICmpInst::ICMP_EQ)
14795 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14796 else
14797 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14798 << *RHS << "\n";
14799
14800}
14801
14802SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14803 const SCEVAddRecExpr *AR,
14804 IncrementWrapFlags Flags)
14805 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14806
14807const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14808
14809bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14810 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14811
14812 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14813}
14814
14815bool SCEVWrapPredicate::isAlwaysTrue() const {
14816 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14817 IncrementWrapFlags IFlags = Flags;
14818
14819 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14820 IFlags = clearFlags(IFlags, IncrementNSSW);
14821
14822 return IFlags == IncrementAnyWrap;
14823}
14824
14825void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14826 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14827 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14828 OS << "<nusw>";
14829 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14830 OS << "<nssw>";
14831 OS << "\n";
14832}
14833
14834SCEVWrapPredicate::IncrementWrapFlags
14835SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14836 ScalarEvolution &SE) {
14837 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14838 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14839
14840 // We can safely transfer the NSW flag as NSSW.
14841 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14842 ImpliedFlags = IncrementNSSW;
14843
14844 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14845 // If the increment is positive, the SCEV NUW flag will also imply the
14846 // WrapPredicate NUSW flag.
14847 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14848 if (Step->getValue()->getValue().isNonNegative())
14849 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14850 }
14851
14852 return ImpliedFlags;
14853}
14854
14855/// Union predicates don't get cached so create a dummy set ID for it.
14856SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14857 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14858 for (const auto *P : Preds)
14859 add(P);
14860}
14861
14862bool SCEVUnionPredicate::isAlwaysTrue() const {
14863 return all_of(Preds,
14864 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14865}
14866
14867bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14868 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14869 return all_of(Set->Preds,
14870 [this](const SCEVPredicate *I) { return this->implies(I); });
14871
14872 return any_of(Preds,
14873 [N](const SCEVPredicate *I) { return I->implies(N); });
14874}
14875
14876void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14877 for (const auto *Pred : Preds)
14878 Pred->print(OS, Depth);
14879}
14880
14881void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14882 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14883 for (const auto *Pred : Set->Preds)
14884 add(Pred);
14885 return;
14886 }
14887
14888 // Only add predicate if it is not already implied by this union predicate.
14889 if (!implies(N))
14890 Preds.push_back(N);
14891}
14892
14893PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14894 Loop &L)
14895 : SE(SE), L(L) {
14896 SmallVector<const SCEVPredicate *, 4> Empty;
14897 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14898}
14899
14900void ScalarEvolution::registerUser(const SCEV *User,
14901 ArrayRef<const SCEV *> Ops) {
14902 for (const auto *Op : Ops)
14903 // We do not expect that forgetting cached data for SCEVConstants will ever
14904 // open any prospects for sharpening or introduce any correctness issues,
14905 // so we don't bother storing their dependencies.
14906 if (!isa<SCEVConstant>(Op))
14907 SCEVUsers[Op].insert(User);
14908}
14909
14910const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14911 const SCEV *Expr = SE.getSCEV(V);
14912 RewriteEntry &Entry = RewriteMap[Expr];
14913
14914 // If we already have an entry and the version matches, return it.
14915 if (Entry.second && Generation == Entry.first)
14916 return Entry.second;
14917
14918 // We found an entry but it's stale. Rewrite the stale entry
14919 // according to the current predicate.
14920 if (Entry.second)
14921 Expr = Entry.second;
14922
14923 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14924 Entry = {Generation, NewSCEV};
14925
14926 return NewSCEV;
14927}
14928
14929const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14930 if (!BackedgeCount) {
14931 SmallVector<const SCEVPredicate *, 4> Preds;
14932 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14933 for (const auto *P : Preds)
14934 addPredicate(*P);
14935 }
14936 return BackedgeCount;
14937}
14938
14939const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
14940 if (!SymbolicMaxBackedgeCount) {
14941 SmallVector<const SCEVPredicate *, 4> Preds;
14942 SymbolicMaxBackedgeCount =
14943 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
14944 for (const auto *P : Preds)
14945 addPredicate(*P);
14946 }
14947 return SymbolicMaxBackedgeCount;
14948}
14949
14950void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14951 if (Preds->implies(&Pred))
14952 return;
14953
14954 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
14955 NewPreds.push_back(&Pred);
14956 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14957 updateGeneration();
14958}
14959
14960const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14961 return *Preds;
14962}
14963
14964void PredicatedScalarEvolution::updateGeneration() {
14965 // If the generation number wrapped recompute everything.
14966 if (++Generation == 0) {
14967 for (auto &II : RewriteMap) {
14968 const SCEV *Rewritten = II.second.second;
14969 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14970 }
14971 }
14972}
14973
14974void PredicatedScalarEvolution::setNoOverflow(
14975 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14976 const SCEV *Expr = getSCEV(V);
14977 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14978
14979 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14980
14981 // Clear the statically implied flags.
14982 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14983 addPredicate(*SE.getWrapPredicate(AR, Flags));
14984
14985 auto II = FlagsMap.insert({V, Flags});
14986 if (!II.second)
14987 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14988}
14989
14990bool PredicatedScalarEvolution::hasNoOverflow(
14991 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14992 const SCEV *Expr = getSCEV(V);
14993 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14994
14995 Flags = SCEVWrapPredicate::clearFlags(
14996 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14997
14998 auto II = FlagsMap.find(V);
14999
15000 if (II != FlagsMap.end())
15001 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15002
15003 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15004}
15005
15006const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15007 const SCEV *Expr = this->getSCEV(V);
15008 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
15009 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15010
15011 if (!New)
15012 return nullptr;
15013
15014 for (const auto *P : NewPreds)
15015 addPredicate(*P);
15016
15017 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15018 return New;
15019}
15020
15021PredicatedScalarEvolution::PredicatedScalarEvolution(
15022 const PredicatedScalarEvolution &Init)
15023 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15024 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15025 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15026 for (auto I : Init.FlagsMap)
15027 FlagsMap.insert(I);
15028}
15029
15030void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15031 // For each block.
15032 for (auto *BB : L.getBlocks())
15033 for (auto &I : *BB) {
15034 if (!SE.isSCEVable(I.getType()))
15035 continue;
15036
15037 auto *Expr = SE.getSCEV(&I);
15038 auto II = RewriteMap.find(Expr);
15039
15040 if (II == RewriteMap.end())
15041 continue;
15042
15043 // Don't print things that are not interesting.
15044 if (II->second.second == Expr)
15045 continue;
15046
15047 OS.indent(Depth) << "[PSE]" << I << ":\n";
15048 OS.indent(Depth + 2) << *Expr << "\n";
15049 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15050 }
15051}
15052
15053// Match the mathematical pattern A - (A / B) * B, where A and B can be
15054// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15055// for URem with constant power-of-2 second operands.
15056// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15057// 4, A / B becomes X / 8).
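// For illustration, both of the following match:
//   %a + (-1 * (%a /u %b) * %b)  (i.e. %a - (%a /u %b) * %b)
//     --> LHS = %a, RHS = %b
//   zext i3 (trunc i32 %a to i3) to i32, which keeps the low 3 bits of %a
//     --> LHS = %a (zero-extended to i32), RHS = 8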
15058bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15059 const SCEV *&RHS) {
15060 if (Expr->getType()->isPointerTy())
15061 return false;
15062
15063 // Try to match 'zext (trunc A to iB) to iY', which is used
15064 // for URem with constant power-of-2 second operands. Make sure the size of
15065 // the operand A matches the size of the whole expressions.
15066 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15067 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15068 LHS = Trunc->getOperand();
15069 // Bail out if the type of the LHS is larger than the type of the
15070 // expression for now.
15071 if (getTypeSizeInBits(LHS->getType()) >
15072 getTypeSizeInBits(Expr->getType()))
15073 return false;
15074 if (LHS->getType() != Expr->getType())
15075 LHS = getZeroExtendExpr(LHS, Expr->getType());
15076 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15077 << getTypeSizeInBits(Trunc->getType()));
15078 return true;
15079 }
15080 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15081 if (Add == nullptr || Add->getNumOperands() != 2)
15082 return false;
15083
15084 const SCEV *A = Add->getOperand(1);
15085 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15086
15087 if (Mul == nullptr)
15088 return false;
15089
15090 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15091 // (SomeExpr + (-(SomeExpr / B) * B)).
15092 if (Expr == getURemExpr(A, B)) {
15093 LHS = A;
15094 RHS = B;
15095 return true;
15096 }
15097 return false;
15098 };
15099
15100 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15101 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15102 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15103 MatchURemWithDivisor(Mul->getOperand(2));
15104
15105 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15106 if (Mul->getNumOperands() == 2)
15107 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15108 MatchURemWithDivisor(Mul->getOperand(0)) ||
15109 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15110 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15111 return false;
15112}
15113
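// collect() gathers conditions known to hold on entry to \p L: llvm.assume
// calls and llvm.experimental.guard intrinsics that dominate the header, plus
// the branch conditions guarding the loop's unique entry path. Each condition
// is turned into a rewrite rule on a SCEVUnknown; for example a dominating
// check 'x u< 16' records x -> umin(x, 15), which applyLoopGuards later uses
// to tighten ranges of expressions involving x.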
15114ScalarEvolution::LoopGuards
15115ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15116 LoopGuards Guards(SE);
15117 SmallVector<const SCEV *> ExprsToRewrite;
15118 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15119 const SCEV *RHS,
15121 &RewriteMap) {
15122 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15123 // replacement SCEV which isn't directly implied by the structure of that
15124 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15125 // legal. See the scoping rules for flags in the header to understand why.
15126
15127 // If LHS is a constant, apply information to the other expression.
15128 if (isa<SCEVConstant>(LHS)) {
15129 std::swap(LHS, RHS);
15130 Predicate = CmpInst::getSwappedPredicate(Predicate);
15131 }
15132
15133 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15134 // create this form when combining two checks of the form (X u< C2 + C1) and
15135 // (X >=u C1).
15136 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15137 &ExprsToRewrite]() {
15138 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15139 if (!AddExpr || AddExpr->getNumOperands() != 2)
15140 return false;
15141
15142 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15143 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15144 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15145 if (!C1 || !C2 || !LHSUnknown)
15146 return false;
15147
15148 auto ExactRegion =
15149 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15150 .sub(C1->getAPInt());
15151
15152 // Bail out, unless we have a non-wrapping, monotonic range.
15153 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15154 return false;
15155 auto I = RewriteMap.find(LHSUnknown);
15156 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15157 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15158 SE.getConstant(ExactRegion.getUnsignedMin()),
15159 SE.getUMinExpr(RewrittenLHS,
15160 SE.getConstant(ExactRegion.getUnsignedMax())));
15161 ExprsToRewrite.push_back(LHSUnknown);
15162 return true;
15163 };
15164 if (MatchRangeCheckIdiom())
15165 return;
15166
15167 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15168 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15169 // the non-constant operand and in \p LHS the constant operand.
15170 auto IsMinMaxSCEVWithNonNegativeConstant =
15171 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15172 const SCEV *&RHS) {
15173 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15174 if (MinMax->getNumOperands() != 2)
15175 return false;
15176 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15177 if (C->getAPInt().isNegative())
15178 return false;
15179 SCTy = MinMax->getSCEVType();
15180 LHS = MinMax->getOperand(0);
15181 RHS = MinMax->getOperand(1);
15182 return true;
15183 }
15184 }
15185 return false;
15186 };
15187
15188 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15189 // constant, and returns their APInt in ExprVal and in DivisorVal.
15190 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15191 APInt &ExprVal, APInt &DivisorVal) {
15192 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15193 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15194 if (!ConstExpr || !ConstDivisor)
15195 return false;
15196 ExprVal = ConstExpr->getAPInt();
15197 DivisorVal = ConstDivisor->getAPInt();
15198 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15199 };
15200
15201 // Return a new SCEV that modifies \p Expr to the closest number divisible
15202 // by \p Divisor and greater than or equal to Expr.
15203 // For now, only handle constant Expr and Divisor.
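 // For example, with Expr = 5 and Divisor = 4 this returns the constant 8;
 // if Expr is already a multiple of Divisor it is returned unchanged.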
15204 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15205 const SCEV *Divisor) {
15206 APInt ExprVal;
15207 APInt DivisorVal;
15208 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15209 return Expr;
15210 APInt Rem = ExprVal.urem(DivisorVal);
15211 if (!Rem.isZero())
15212 // return the SCEV: Expr + Divisor - Expr % Divisor
15213 return SE.getConstant(ExprVal + DivisorVal - Rem);
15214 return Expr;
15215 };
15216
15217 // Return a new SCEV that modifies \p Expr to the closest number divisible
15218 // by \p Divisor and less than or equal to Expr.
15219 // For now, only handle constant Expr and Divisor.
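 // For example, with Expr = 5 and Divisor = 4 this returns the constant 4.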
15220 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15221 const SCEV *Divisor) {
15222 APInt ExprVal;
15223 APInt DivisorVal;
15224 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15225 return Expr;
15226 APInt Rem = ExprVal.urem(DivisorVal);
15227 // return the SCEV: Expr - Expr % Divisor
15228 return SE.getConstant(ExprVal - Rem);
15229 };
15230
15231 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15232 // recursively. This is done by aligning up/down the constant value to the
15233 // Divisor.
15234 std::function<const SCEV *(const SCEV *, const SCEV *)>
15235 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15236 const SCEV *Divisor) {
15237 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15238 SCEVTypes SCTy;
15239 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15240 MinMaxRHS))
15241 return MinMaxExpr;
15242 auto IsMin =
15243 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15244 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15245 "Expected non-negative operand!");
15246 auto *DivisibleExpr =
15247 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15248 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15249 SmallVector<const SCEV *> Ops = {
15250 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15251 return SE.getMinMaxExpr(SCTy, Ops);
15252 };
15253
15254 // If we have LHS == 0, check if LHS is computing a property of some unknown
15255 // SCEV %v, and if so rewrite %v to express that property explicitly.
15256 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15257 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15258 RHSC->getValue()->isNullValue()) {
15259 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15260 // explicitly express that.
15261 const SCEV *URemLHS = nullptr;
15262 const SCEV *URemRHS = nullptr;
15263 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15264 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15265 auto I = RewriteMap.find(LHSUnknown);
15266 const SCEV *RewrittenLHS =
15267 I != RewriteMap.end() ? I->second : LHSUnknown;
15268 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15269 const auto *Multiple =
15270 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15271 RewriteMap[LHSUnknown] = Multiple;
15272 ExprsToRewrite.push_back(LHSUnknown);
15273 return;
15274 }
15275 }
15276 }
15277
15278 // Do not apply information for constants or if RHS contains an AddRec.
15279 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15280 return;
15281
15282 // If RHS is SCEVUnknown, make sure the information is applied to it.
15283 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15284 std::swap(LHS, RHS);
15285 Predicate = CmpInst::getSwappedPredicate(Predicate);
15286 }
15287
15288 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15289 // and \p FromRewritten are the same (i.e. there has been no rewrite
15290 // registered for \p From), then puts this value in the list of rewritten
15291 // expressions.
15292 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15293 const SCEV *To) {
15294 if (From == FromRewritten)
15295 ExprsToRewrite.push_back(From);
15296 RewriteMap[From] = To;
15297 };
15298
15299 // Checks whether \p S has already been rewritten. In that case returns the
15300 // existing rewrite because we want to chain further rewrites onto the
15301 // already rewritten value. Otherwise returns \p S.
15302 auto GetMaybeRewritten = [&](const SCEV *S) {
15303 auto I = RewriteMap.find(S);
15304 return I != RewriteMap.end() ? I->second : S;
15305 };
15306
15307 // Check for the SCEV expression (A /u B) * B where B is a constant, inside
15308 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15309 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15310 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15311 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15312 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15313 // DividesBy.
15314 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15315 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15316 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15317 if (Mul->getNumOperands() != 2)
15318 return false;
15319 auto *MulLHS = Mul->getOperand(0);
15320 auto *MulRHS = Mul->getOperand(1);
15321 if (isa<SCEVConstant>(MulLHS))
15322 std::swap(MulLHS, MulRHS);
15323 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15324 if (Div->getOperand(1) == MulRHS) {
15325 DividesBy = MulRHS;
15326 return true;
15327 }
15328 }
15329 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15330 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15331 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15332 return false;
15333 };
15334
15335 // Return true if Expr is known to be divisible by \p DividesBy.
15336 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15337 [&](const SCEV *Expr, const SCEV *DividesBy) {
15338 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15339 return true;
15340 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15341 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15342 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15343 return false;
15344 };
15345
15346 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15347 const SCEV *DividesBy = nullptr;
15348 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15349 // Check that the whole expression is divided by DividesBy
15350 DividesBy =
15351 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15352
15353 // Collect rewrites for LHS and its transitive operands based on the
15354 // condition.
15355 // For min/max expressions, also apply the guard to its operands:
15356 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15357 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15358 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15359 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15360
15361 // We cannot express strict predicates in SCEV, so instead we replace them
15362 // with non-strict ones against plus or minus one of RHS depending on the
15363 // predicate.
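 // For example, 'x u< 16' is handled as 'x u<= 15'; if x is additionally
 // known to be a multiple of 8, the bound is rounded down further to 'x u<= 8'.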
15364 const SCEV *One = SE.getOne(RHS->getType());
15365 switch (Predicate) {
15366 case CmpInst::ICMP_ULT:
15367 if (RHS->getType()->isPointerTy())
15368 return;
15369 RHS = SE.getUMaxExpr(RHS, One);
15370 [[fallthrough]];
15371 case CmpInst::ICMP_SLT: {
15372 RHS = SE.getMinusSCEV(RHS, One);
15373 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15374 break;
15375 }
15376 case CmpInst::ICMP_UGT:
15377 case CmpInst::ICMP_SGT:
15378 RHS = SE.getAddExpr(RHS, One);
15379 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15380 break;
15381 case CmpInst::ICMP_ULE:
15382 case CmpInst::ICMP_SLE:
15383 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15384 break;
15385 case CmpInst::ICMP_UGE:
15386 case CmpInst::ICMP_SGE:
15387 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15388 break;
15389 default:
15390 break;
15391 }
15392
15393 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15394 SmallPtrSet<const SCEV *, 16> Visited;
15395
15396 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15397 append_range(Worklist, S->operands());
15398 };
15399
15400 while (!Worklist.empty()) {
15401 const SCEV *From = Worklist.pop_back_val();
15402 if (isa<SCEVConstant>(From))
15403 continue;
15404 if (!Visited.insert(From).second)
15405 continue;
15406 const SCEV *FromRewritten = GetMaybeRewritten(From);
15407 const SCEV *To = nullptr;
15408
15409 switch (Predicate) {
15410 case CmpInst::ICMP_ULT:
15411 case CmpInst::ICMP_ULE:
15412 To = SE.getUMinExpr(FromRewritten, RHS);
15413 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15414 EnqueueOperands(UMax);
15415 break;
15416 case CmpInst::ICMP_SLT:
15417 case CmpInst::ICMP_SLE:
15418 To = SE.getSMinExpr(FromRewritten, RHS);
15419 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15420 EnqueueOperands(SMax);
15421 break;
15422 case CmpInst::ICMP_UGT:
15423 case CmpInst::ICMP_UGE:
15424 To = SE.getUMaxExpr(FromRewritten, RHS);
15425 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15426 EnqueueOperands(UMin);
15427 break;
15428 case CmpInst::ICMP_SGT:
15429 case CmpInst::ICMP_SGE:
15430 To = SE.getSMaxExpr(FromRewritten, RHS);
15431 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15432 EnqueueOperands(SMin);
15433 break;
15434 case CmpInst::ICMP_EQ:
15435 if (isa<SCEVConstant>(RHS))
15436 To = RHS;
15437 break;
15438 case CmpInst::ICMP_NE:
15439 if (isa<SCEVConstant>(RHS) &&
15440 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15441 const SCEV *OneAlignedUp =
15442 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15443 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15444 }
15445 break;
15446 default:
15447 break;
15448 }
15449
15450 if (To)
15451 AddRewrite(From, FromRewritten, To);
15452 }
15453 };
15454
15455 BasicBlock *Header = L->getHeader();
15456 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
15457 // First, collect information from assumptions dominating the loop.
15458 for (auto &AssumeVH : SE.AC.assumptions()) {
15459 if (!AssumeVH)
15460 continue;
15461 auto *AssumeI = cast<CallInst>(AssumeVH);
15462 if (!SE.DT.dominates(AssumeI, Header))
15463 continue;
15464 Terms.emplace_back(AssumeI->getOperand(0), true);
15465 }
15466
15467 // Second, collect information from llvm.experimental.guards dominating the loop.
15468 auto *GuardDecl = SE.F.getParent()->getFunction(
15469 Intrinsic::getName(Intrinsic::experimental_guard));
15470 if (GuardDecl)
15471 for (const auto *GU : GuardDecl->users())
15472 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15473 if (Guard->getFunction() == Header->getParent() &&
15474 SE.DT.dominates(Guard, Header))
15475 Terms.emplace_back(Guard->getArgOperand(0), true);
15476
15477 // Third, collect conditions from dominating branches. Starting at the loop
15478 // predecessor, climb up the predecessor chain, as long as there are
15479 // predecessors that can be found that have unique successors leading to the
15480 // original header.
15481 // TODO: share this logic with isLoopEntryGuardedByCond.
15482 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15483 L->getLoopPredecessor(), Header);
15484 Pair.first;
15485 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15486
15487 const BranchInst *LoopEntryPredicate =
15488 dyn_cast<BranchInst>(Pair.first->getTerminator());
15489 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15490 continue;
15491
15492 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15493 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15494 }
15495
15496 // Now apply the information from the collected conditions to
15497 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15498 // earliest conditions are processed first. This ensures the SCEVs with the
15499 // shortest dependency chains are constructed first.
15500 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15501 SmallVector<Value *, 8> Worklist;
15502 SmallPtrSet<Value *, 8> Visited;
15503 Worklist.push_back(Term);
15504 while (!Worklist.empty()) {
15505 Value *Cond = Worklist.pop_back_val();
15506 if (!Visited.insert(Cond).second)
15507 continue;
15508
15509 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15510 auto Predicate =
15511 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15512 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15513 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15514 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15515 continue;
15516 }
15517
15518 Value *L, *R;
15519 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15520 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15521 Worklist.push_back(L);
15522 Worklist.push_back(R);
15523 }
15524 }
15525 }
15526
15527 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15528 // the replacement expressions are contained in the ranges of the replaced
15529 // expressions.
15530 Guards.PreserveNUW = true;
15531 Guards.PreserveNSW = true;
15532 for (const SCEV *Expr : ExprsToRewrite) {
15533 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15534 Guards.PreserveNUW &=
15535 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15536 Guards.PreserveNSW &=
15537 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15538 }
15539
15540 // Now that all rewrite information is collected, rewrite the collected
15541 // expressions with the information in the map. This applies information to
15542 // sub-expressions.
15543 if (ExprsToRewrite.size() > 1) {
15544 for (const SCEV *Expr : ExprsToRewrite) {
15545 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15546 Guards.RewriteMap.erase(Expr);
15547 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15548 }
15549 }
15550 return Guards;
15551}
15552
15554 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15555 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15556 /// replacement is loop invariant in the loop of the AddRec.
15557 class SCEVLoopGuardRewriter
15558 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15560
15562
15563 public:
15564 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15565 const ScalarEvolution::LoopGuards &Guards)
15566 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15567 if (Guards.PreserveNUW)
15568 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15569 if (Guards.PreserveNSW)
15570 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15571 }
15572
15573 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15574
15575 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15576 auto I = Map.find(Expr);
15577 if (I == Map.end())
15578 return Expr;
15579 return I->second;
15580 }
15581
15582 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15583 auto I = Map.find(Expr);
15584 if (I == Map.end()) {
15585 // If we didn't find the exact ZExt expr in the map, check if there's
15586 // an entry for a smaller ZExt we can use instead.
15587 Type *Ty = Expr->getType();
15588 const SCEV *Op = Expr->getOperand(0);
15589 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15590 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15591 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15592 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15593 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15594 auto I = Map.find(NarrowExt);
15595 if (I != Map.end())
15596 return SE.getZeroExtendExpr(I->second, Ty);
15597 Bitwidth = Bitwidth / 2;
15598 }
15599 
15600 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15601 Expr);
15602 }
15603 return I->second;
15604 }
15605
15606 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15607 auto I = Map.find(Expr);
15608 if (I == Map.end())
15609 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15610 Expr);
15611 return I->second;
15612 }
15613
15614 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15615 auto I = Map.find(Expr);
15616 if (I == Map.end())
15617 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15618 return I->second;
15619 }
15620
15621 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15622 auto I = Map.find(Expr);
15623 if (I == Map.end())
15624 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15625 return I->second;
15626 }
15627
15628 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15629 SmallVector<const SCEV *, 2> Operands;
15630 bool Changed = false;
15631 for (const auto *Op : Expr->operands()) {
15632 Operands.push_back(
15633 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15634 Changed |= Op != Operands.back();
15635 }
15636 // We are only replacing operands with equivalent values, so transfer the
15637 // flags from the original expression.
15638 return !Changed ? Expr
15639 : SE.getAddExpr(Operands,
15640 ScalarEvolution::maskFlags(
15641 Expr->getNoWrapFlags(), FlagMask));
15642 }
15643
15644 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15645 SmallVector<const SCEV *, 2> Operands;
15646 bool Changed = false;
15647 for (const auto *Op : Expr->operands()) {
15648 Operands.push_back(
15649 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15650 Changed |= Op != Operands.back();
15651 }
15652 // We are only replacing operands with equivalent values, so transfer the
15653 // flags from the original expression.
15654 return !Changed ? Expr
15655 : SE.getMulExpr(Operands,
15656 ScalarEvolution::maskFlags(
15657 Expr->getNoWrapFlags(), FlagMask));
15658 }
15659 };
15660
15661 if (RewriteMap.empty())
15662 return Expr;
15663
15664 SCEVLoopGuardRewriter Rewriter(SE, *this);
15665 return Rewriter.visit(Expr);
15666}
15667
15668const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15669 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15670}
15671 
15672 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
15673 const LoopGuards &Guards) {
15674 return Guards.rewrite(Expr);
15675}
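// Illustrative sketch, not part of the original source: when the guarded form
// of several expressions in the same loop is needed, the LoopGuards can be
// collected once and reused through the applyLoopGuards overload above,
// instead of re-walking the dominating conditions for every query. The helper
// below is hypothetical and only shows the assumed calling pattern.
static void printGuardedForms(ScalarEvolution &SE, const Loop *L,
                              ArrayRef<const SCEV *> Exprs, raw_ostream &OS) {
  // Walk the dominating assumes/guards/branch conditions for L a single time.
  ScalarEvolution::LoopGuards Guards =
      ScalarEvolution::LoopGuards::collect(L, SE);
  for (const SCEV *Expr : Exprs) {
    // Rewrites Expr (but not AddRecs) using the collected rewrite map.
    const SCEV *Guarded = SE.applyLoopGuards(Expr, Guards);
    OS << *Expr << " -> " << *Guarded << "\n";
  }
}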
@ Poison
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:533
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:512
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
static bool rewrite(Function &F)
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1941
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1543
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:401
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1498
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1370
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1470
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:184
APInt abs() const
Get the absolute value.
Definition: APInt.h:1751
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1160
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:358
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:444
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1636
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1446
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1089
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:187
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:194
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:307
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1144
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:197
unsigned countTrailingZeros() const
Definition: APInt.h:1604
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:334
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:805
APInt multiplicativeInverse() const
Definition: APInt.cpp:1244
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1128
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:851
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:284
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:319
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1108
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:178
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:410
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:217
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1199
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
const Instruction & front() const
Definition: BasicBlock.h:471
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:442
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1062
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:757
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:786
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:787
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:781
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:780
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:784
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:782
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:785
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:783
bool isSigned() const
Definition: InstrTypes.h:1007
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:909
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1056
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:953
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:871
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:847
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1050
bool isUnsigned() const
Definition: InstrTypes.h:1013
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1003
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2605
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2267
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1253
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2611
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2599
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2253
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:218
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:212
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:206
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:857
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:155
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:146
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:864
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:42
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:695
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:846
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:749
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:873
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:621
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:336
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:176
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:146
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const BasicBlock & getEntryBlock() const
Definition: Function.h:807
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:743
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:266
An instruction for reading from memory.
Definition: Instructions.h:174
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:39
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1069
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:193
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1852
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
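A minimal sketch of comparing two IR values through their SCEVs (provablySLT is a hypothetical helper; it assumes both values are of the same SCEVable integer type):

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Return true if A is provably signed-less-than B on all executions.
static bool provablySLT(ScalarEvolution &SE, Value *A, Value *B) {
  if (!SE.isSCEVable(A->getType()) || A->getType() != B->getType())
    return false;
  return SE.isKnownPredicate(ICmpInst::ICMP_SLT, SE.getSCEV(A), SE.getSCEV(B));
}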
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
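A short sketch of constructing the canonical induction {0,+,1}<L> with this API (makeCanonicalIV is a hypothetical helper name):

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Build the canonical induction {0,+,1}<L> for the given integer type Ty.
static const SCEV *makeCanonicalIV(ScalarEvolution &SE, Type *Ty,
                                   const Loop *L) {
  return SE.getAddRecExpr(SE.getZero(Ty), SE.getOne(Ty), L, SCEV::FlagAnyWrap);
}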
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
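A hedged sketch of using the difference expression (constantDifference is a hypothetical helper; it assumes A and B are integer values of the same type):

#include <optional>
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

// Return B - A as a constant APInt when the difference folds to a
// SCEVConstant; otherwise return std::nullopt.
static std::optional<APInt> constantDifference(ScalarEvolution &SE,
                                               Value *A, Value *B) {
  const SCEV *Diff = SE.getMinusSCEV(SE.getSCEV(B), SE.getSCEV(A));
  if (const auto *C = dyn_cast<SCEVConstant>(Diff))
    return C->getAPInt();
  return std::nullopt;
}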
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
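Since this returns 0 when no small constant trip count is known, a typical client gate looks like the following sketch (hasSmallTripCount is a hypothetical helper):

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Decide whether L is a candidate for full unrolling: it must have a known,
// small constant trip count (0 means "no small constant is known").
static bool hasSmallTripCount(ScalarEvolution &SE, const Loop *L,
                              unsigned Threshold) {
  unsigned TC = SE.getSmallConstantTripCount(L);
  return TC != 0 && TC <= Threshold;
}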
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
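A sketch of choosing between these ExitCountKind values (exactOrConstantMaxBTC is a hypothetical helper): prefer the exact count and fall back to the constant maximum when the exact count is unknown.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Return the exact backedge-taken count if computable, otherwise the
// constant upper bound on it.
static const SCEV *exactOrConstantMaxBTC(ScalarEvolution &SE, const Loop *L) {
  const SCEV *Exact = SE.getBackedgeTakenCount(L, ScalarEvolution::Exact);
  if (!isa<SCEVCouldNotCompute>(Exact))
    return Exact;
  return SE.getBackedgeTakenCount(L, ScalarEvolution::ConstantMaximum);
}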
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
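A minimal sketch (guardedUnsignedMax is a hypothetical helper) showing how loop guards can tighten a range query for an expression used inside L:

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Re-query the unsigned maximum of Expr after folding in facts implied by
// guards dominating L's header (e.g. an "n != 0" check before the loop).
static APInt guardedUnsignedMax(ScalarEvolution &SE, const SCEV *Expr,
                                const Loop *L) {
  const SCEV *Guarded = SE.applyLoopGuards(Expr, L);
  return SE.getUnsignedRangeMax(Guarded);
}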
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
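A short sketch of building a canonical add from several values (sumOfValues is a hypothetical helper; it assumes a non-empty list of values sharing one integer type):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Sum a non-empty list of same-typed integer values as one canonical
// SCEV add expression.
static const SCEV *sumOfValues(ScalarEvolution &SE, ArrayRef<Value *> Vals) {
  SmallVector<const SCEV *, 4> Ops;
  for (Value *V : Vals)
    Ops.push_back(SE.getSCEV(V));
  return SE.getAddExpr(Ops); // Constant folding and reassociation happen here.
}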
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:95
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:346
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:367
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:441
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:502
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:290
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:571
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:600
TypeSize getSizeInBits() const
Definition: DataLayout.h:580
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:251
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:239
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:224
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5106
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2195
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2200
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2205
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2781
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2210
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
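A small sketch of this GCD helper (gcd32 is a hypothetical wrapper; both APInt operands must have the same bit width, and the inputs are assumed to fit in 32 bits):

#include <cstdint>
#include "llvm/ADT/APInt.h"
using namespace llvm;

// GCD of two 32-bit quantities computed via APInt.
static uint64_t gcd32(uint64_t A, uint64_t B) {
  APInt G = APIntOps::GreatestCommonDivisor(APInt(32, A), APInt(32, B));
  return G.getZExtValue();
}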
@ Entry
Definition: COFF.h:826
@ Exit
Definition: COFF.h:827
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1096
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:822
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:390
void stable_sort(R &&Range)
Definition: STLExtras.h:2020
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7151
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2098
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2090
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1150
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1140
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:120
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:260
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1921
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:1997
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:292
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1856
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1928
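A minimal sketch of this range-based wrapper (numConstantOperands is a hypothetical helper):

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Count how many operands of I are compile-time constants.
static unsigned numConstantOperands(const Instruction &I) {
  return (unsigned)count_if(I.operands(),
                            [](const Use &U) { return isa<Constant>(U); });
}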
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (eith...
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1886
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:290
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:97
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:185
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:134
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:118
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:94
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)