1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
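//
// For instance, the induction variable of a loop such as
//   for (i = 0; i != n; ++i) { ... }
// is a PHI node that this analysis represents as the recurrence
// {0,+,1}<%loop> (start 0, step 1 per iteration), while a PHI it cannot
// classify is simply wrapped in a SCEVUnknown.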
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
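//
// As a small illustrative sketch (the precise output is determined by the
// folding rules implemented below): asking the addition folder for
// %x + 2 + %x yields the canonical form (2 + (2 * %x)), with like terms
// combined and the constant ordered first.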
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification when -verify-scev is passed"));
165
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
314 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
325 case scSMinExpr:
326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
340 case scSequentialUMinExpr:
341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
373 case scCouldNotCompute:
374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
379
381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
400 case scSequentialUMinExpr:
401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
408 case scCouldNotCompute:
409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
432 case scSequentialUMinExpr:
433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
436 case scCouldNotCompute:
437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
468 // Return true if the value is negative, this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
474
476 return S->getSCEVType() == scCouldNotCompute;
477}
478
481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
585/// have been previously deemed to be "equally complex" by this routine. It is
586/// intended to avoid exponential time complexity in cases like:
587///
588/// %a = f(%x, %y)
589/// %b = f(%a, %a)
590/// %c = f(%b, %b)
591///
592/// %d = f(%x, %y)
593/// %e = f(%d, %d)
594/// %f = f(%e, %e)
595///
596/// CompareValueComplexity(%f, %c)
597///
598/// Since we do not continue running this routine on expression trees once we
599/// have seen unequal values, there is no need to track them in the cache.
600static int
602 const LoopInfo *const LI, Value *LV, Value *RV,
603 unsigned Depth) {
604 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
605 return 0;
606
607 // Order pointer values after integer values. This helps SCEVExpander form
608 // GEPs.
609 bool LIsPointer = LV->getType()->isPointerTy(),
610 RIsPointer = RV->getType()->isPointerTy();
611 if (LIsPointer != RIsPointer)
612 return (int)LIsPointer - (int)RIsPointer;
613
614 // Compare getValueID values.
615 unsigned LID = LV->getValueID(), RID = RV->getValueID();
616 if (LID != RID)
617 return (int)LID - (int)RID;
618
619 // Sort arguments by their position.
620 if (const auto *LA = dyn_cast<Argument>(LV)) {
621 const auto *RA = cast<Argument>(RV);
622 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
623 return (int)LArgNo - (int)RArgNo;
624 }
625
626 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
627 const auto *RGV = cast<GlobalValue>(RV);
628
629 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
630 auto LT = GV->getLinkage();
631 return !(GlobalValue::isPrivateLinkage(LT) ||
632 GlobalValue::isInternalLinkage(LT));
633 };
634
635 // Use the names to distinguish the two values, but only if the
636 // names are semantically important.
637 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
638 return LGV->getName().compare(RGV->getName());
639 }
640
641 // For instructions, compare their loop depth, and their operand count. This
642 // is pretty loose.
643 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
644 const auto *RInst = cast<Instruction>(RV);
645
646 // Compare loop depths.
647 const BasicBlock *LParent = LInst->getParent(),
648 *RParent = RInst->getParent();
649 if (LParent != RParent) {
650 unsigned LDepth = LI->getLoopDepth(LParent),
651 RDepth = LI->getLoopDepth(RParent);
652 if (LDepth != RDepth)
653 return (int)LDepth - (int)RDepth;
654 }
655
656 // Compare the number of operands.
657 unsigned LNumOps = LInst->getNumOperands(),
658 RNumOps = RInst->getNumOperands();
659 if (LNumOps != RNumOps)
660 return (int)LNumOps - (int)RNumOps;
661
662 for (unsigned Idx : seq(LNumOps)) {
663 int Result =
664 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
665 RInst->getOperand(Idx), Depth + 1);
666 if (Result != 0)
667 return Result;
668 }
669 }
670
671 EqCacheValue.unionSets(LV, RV);
672 return 0;
673}
674
675// Return negative, zero, or positive, if LHS is less than, equal to, or greater
676// than RHS, respectively. A three-way result allows recursive comparisons to be
677// more efficient.
678// If the max analysis depth was reached, return std::nullopt, assuming we do
679// not know if they are equivalent for sure.
680static std::optional<int>
683 const LoopInfo *const LI, const SCEV *LHS,
684 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
685 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
686 if (LHS == RHS)
687 return 0;
688
689 // Primarily, sort the SCEVs by their getSCEVType().
690 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
691 if (LType != RType)
692 return (int)LType - (int)RType;
693
694 if (EqCacheSCEV.isEquivalent(LHS, RHS))
695 return 0;
696
697 if (Depth > MaxSCEVCompareDepth)
698 return std::nullopt;
699
700 // Aside from the getSCEVType() ordering, the particular ordering
701 // isn't very important except that it's beneficial to be consistent,
702 // so that (a + b) and (b + a) don't end up as different expressions.
703 switch (LType) {
704 case scUnknown: {
705 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
706 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
707
708 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
709 RU->getValue(), Depth + 1);
710 if (X == 0)
711 EqCacheSCEV.unionSets(LHS, RHS);
712 return X;
713 }
714
715 case scConstant: {
716 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
717 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
718
719 // Compare constant values.
720 const APInt &LA = LC->getAPInt();
721 const APInt &RA = RC->getAPInt();
722 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
723 if (LBitWidth != RBitWidth)
724 return (int)LBitWidth - (int)RBitWidth;
725 return LA.ult(RA) ? -1 : 1;
726 }
727
728 case scVScale: {
729 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
730 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
731 return LTy->getBitWidth() - RTy->getBitWidth();
732 }
733
734 case scAddRecExpr: {
735 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
736 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
737
738 // There is always a dominance between two recs that are used by one SCEV,
739 // so we can safely sort recs by loop header dominance. We require such
740 // order in getAddExpr.
741 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
742 if (LLoop != RLoop) {
743 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
744 assert(LHead != RHead && "Two loops share the same header?");
745 if (DT.dominates(LHead, RHead))
746 return 1;
747 assert(DT.dominates(RHead, LHead) &&
748 "No dominance between recurrences used by one SCEV?");
749 return -1;
750 }
751
752 [[fallthrough]];
753 }
754
755 case scTruncate:
756 case scZeroExtend:
757 case scSignExtend:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
766 case scSequentialUMinExpr: {
767 ArrayRef<const SCEV *> LOps = LHS->operands();
768 ArrayRef<const SCEV *> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
777 ROps[i], DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 EqCacheSCEV.unionSets(LHS, RHS);
782 return 0;
783 }
784
785 case scCouldNotCompute:
786 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
787 }
788 llvm_unreachable("Unknown SCEV kind!");
789}
790
791/// Given a list of SCEV objects, order them by their complexity, and group
792/// objects of the same complexity together by value. When this routine is
793/// finished, we know that any duplicates in the vector are consecutive and that
794/// complexity is monotonically increasing.
795///
796/// Note that we take special precautions to ensure that we get deterministic
797/// results from this routine. In other words, we don't want the results of
798/// this to depend on where the addresses of various SCEV objects happened to
799/// land in memory.
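///
/// As a small illustrative sketch (operand order here is only schematic): an
/// operand list such as (%a, 2, %b, %a) is regrouped so the constant sorts
/// first and the duplicate %a's become adjacent, e.g. (2, %a, %a, %b), which
/// lets a caller such as the addition folder combine %a + %a into 2 * %a in a
/// single linear scan.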
801 LoopInfo *LI, DominatorTree &DT) {
802 if (Ops.size() < 2) return; // Noop
803
806
807 // Whether LHS has provably less complexity than RHS.
808 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
809 auto Complexity =
810 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
811 return Complexity && *Complexity < 0;
812 };
813 if (Ops.size() == 2) {
814 // This is the common case, which also happens to be trivially simple.
815 // Special case it.
816 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
817 if (IsLessComplex(RHS, LHS))
818 std::swap(LHS, RHS);
819 return;
820 }
821
822 // Do the rough sort by complexity.
823 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
824 return IsLessComplex(LHS, RHS);
825 });
826
827 // Now that we are sorted by complexity, group elements of the same
828 // complexity. Note that this is, at worst, N^2, but the vector is likely to
829 // be extremely short in practice. Note that we take this approach because we
830 // do not want to depend on the addresses of the objects we are grouping.
831 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
832 const SCEV *S = Ops[i];
833 unsigned Complexity = S->getSCEVType();
834
835 // If there are any objects of the same complexity and same value as this
836 // one, group them.
837 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
838 if (Ops[j] == S) { // Found a duplicate.
839 // Move it to immediately after i'th element.
840 std::swap(Ops[i+1], Ops[j]);
841 ++i; // no need to rescan it.
842 if (i == e-2) return; // Done!
843 }
844 }
845 }
846}
847
848/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
849/// least HugeExprThreshold nodes).
851 return any_of(Ops, [](const SCEV *S) {
853 });
854}
855
856//===----------------------------------------------------------------------===//
857// Simple SCEV method implementations
858//===----------------------------------------------------------------------===//
859
860/// Compute BC(It, K). The result has width W. Assume K > 0.
861static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
862 ScalarEvolution &SE,
863 Type *ResultTy) {
864 // Handle the simplest case efficiently.
865 if (K == 1)
866 return SE.getTruncateOrZeroExtend(It, ResultTy);
867
868 // We are using the following formula for BC(It, K):
869 //
870 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
871 //
872 // Suppose, W is the bitwidth of the return value. We must be prepared for
873 // overflow. Hence, we must assure that the result of our computation is
874 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
875 // safe in modular arithmetic.
876 //
877 // However, this code doesn't use exactly that formula; the formula it uses
878 // is something like the following, where T is the number of factors of 2 in
879 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
880 // exponentiation:
881 //
882 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
883 //
884 // This formula is trivially equivalent to the previous formula. However,
885 // this formula can be implemented much more efficiently. The trick is that
886 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
887 // arithmetic. To do exact division in modular arithmetic, all we have
888 // to do is multiply by the inverse. Therefore, this step can be done at
889 // width W.
890 //
891 // The next issue is how to safely do the division by 2^T. The way this
892 // is done is by doing the multiplication step at a width of at least W + T
893 // bits. This way, the bottom W+T bits of the product are accurate. Then,
894 // when we perform the division by 2^T (which is equivalent to a right shift
895 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
896 // truncated out after the division by 2^T.
897 //
898 // In comparison to just directly using the first formula, this technique
899 // is much more efficient; using the first formula requires W * K bits,
900 // but this formula requires less than W + K bits. Also, the first formula requires
901 // a division step, whereas this formula only requires multiplies and shifts.
902 //
903 // It doesn't matter whether the subtraction step is done in the calculation
904 // width or the input iteration count's width; if the subtraction overflows,
905 // the result must be zero anyway. We prefer here to do it in the width of
906 // the induction variable because it helps a lot for certain cases; CodeGen
907 // isn't smart enough to ignore the overflow, which leads to much less
908 // efficient code if the width of the subtraction is wider than the native
909 // register width.
910 //
911 // (It's possible to not widen at all by pulling out factors of 2 before
912 // the multiplication; for example, K=2 can be calculated as
913 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
914 // extra arithmetic, so it's not an obvious win, and it gets
915 // much more complicated for K > 3.)
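//
// A small worked instance of the scheme above (numbers only, not code): for
// K = 3 at W = 32, K! = 6 = 2^1 * 3, so T = 1 and the odd factorial part is
// 3. The product It*(It-1)*(It-2) is formed at W + T = 33 bits, divided by
// 2^T = 2, truncated back to 32 bits, and finally multiplied by the
// multiplicative inverse of 3 modulo 2^32 to finish the exact division by K!.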
916
917 // Protection from insane SCEVs; this bound is conservative,
918 // but it probably doesn't matter.
919 if (K > 1000)
920 return SE.getCouldNotCompute();
921
922 unsigned W = SE.getTypeSizeInBits(ResultTy);
923
924 // Calculate K! / 2^T and T; we divide out the factors of two before
925 // multiplying for calculating K! / 2^T to avoid overflow.
926 // Other overflow doesn't matter because we only care about the bottom
927 // W bits of the result.
928 APInt OddFactorial(W, 1);
929 unsigned T = 1;
930 for (unsigned i = 3; i <= K; ++i) {
931 unsigned TwoFactors = countr_zero(i);
932 T += TwoFactors;
933 OddFactorial *= (i >> TwoFactors);
934 }
935
936 // We need at least W + T bits for the multiplication step
937 unsigned CalculationBits = W + T;
938
939 // Calculate 2^T, at width T+W.
940 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
941
942 // Calculate the multiplicative inverse of K! / 2^T;
943 // this multiplication factor will perform the exact division by
944 // K! / 2^T.
945 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
946
947 // Calculate the product, at width T+W
948 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
949 CalculationBits);
950 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
951 for (unsigned i = 1; i != K; ++i) {
952 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
953 Dividend = SE.getMulExpr(Dividend,
954 SE.getTruncateOrZeroExtend(S, CalculationTy));
955 }
956
957 // Divide by 2^T
958 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
959
960 // Truncate the result, and divide by K! / 2^T.
961
962 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
963 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
964}
965
966/// Return the value of this chain of recurrences at the specified iteration
967/// number. We can evaluate this recurrence by multiplying each element in the
968/// chain by the binomial coefficient corresponding to it. In other words, we
969/// can evaluate {A,+,B,+,C,+,D} as:
970///
971/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
972///
973/// where BC(It, k) stands for binomial coefficient.
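///
/// A quick sanity check of the formula (illustrative only): the chain
/// {0,+,1,+,1} evaluates to 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2)
/// = It + It*(It-1)/2 = It*(It+1)/2, i.e. the sum 1 + 2 + ... + It.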
975 ScalarEvolution &SE) const {
976 return evaluateAtIteration(operands(), It, SE);
977}
978
979const SCEV *
981 const SCEV *It, ScalarEvolution &SE) {
982 assert(Operands.size() > 0);
983 const SCEV *Result = Operands[0];
984 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
985 // The computation is correct in the face of overflow provided that the
986 // multiplication is performed _after_ the evaluation of the binomial
987 // coefficient.
988 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
989 if (isa<SCEVCouldNotCompute>(Coeff))
990 return Coeff;
991
992 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
993 }
994 return Result;
995}
996
997//===----------------------------------------------------------------------===//
998// SCEV Expression folder implementations
999//===----------------------------------------------------------------------===//
1000
1002 unsigned Depth) {
1003 assert(Depth <= 1 &&
1004 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1005
1006 // We could be called with an integer-typed operand during SCEV rewrites.
1007 // Since the operand is an integer already, just perform zext/trunc/self cast.
1008 if (!Op->getType()->isPointerTy())
1009 return Op;
1010
1011 // What would be an ID for such a SCEV cast expression?
1013 ID.AddInteger(scPtrToInt);
1014 ID.AddPointer(Op);
1015
1016 void *IP = nullptr;
1017
1018 // Is there already an expression for such a cast?
1019 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1020 return S;
1021
1022 // It isn't legal for optimizations to construct new ptrtoint expressions
1023 // for non-integral pointers.
1024 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1025 return getCouldNotCompute();
1026
1027 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1028
1029 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1030 // is sufficiently wide to represent all possible pointer values.
1031 // We could theoretically teach SCEV to truncate wider pointers, but
1032 // that isn't implemented for now.
1034 getDataLayout().getTypeSizeInBits(IntPtrTy))
1035 return getCouldNotCompute();
1036
1037 // If not, is this expression something we can't reduce any further?
1038 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1039 // Perform some basic constant folding. If the operand of the ptr2int cast
1040 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1041 // left as-is), but produce a zero constant.
1042 // NOTE: We could handle a more general case, but lack motivational cases.
1043 if (isa<ConstantPointerNull>(U->getValue()))
1044 return getZero(IntPtrTy);
1045
1046 // Create an explicit cast node.
1047 // We can reuse the existing insert position since if we get here,
1048 // we won't have made any changes which would invalidate it.
1049 SCEV *S = new (SCEVAllocator)
1050 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1051 UniqueSCEVs.InsertNode(S, IP);
1052 registerUser(S, Op);
1053 return S;
1054 }
1055
1056 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1057 "non-SCEVUnknown's.");
1058
1059 // Otherwise, we've got some expression that is more complex than just a
1060 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1061 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1062 // only, and the expressions must otherwise be integer-typed.
1063 // So sink the cast down to the SCEVUnknown's.
1064
1065 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1066 /// which computes a pointer-typed value, and rewrites the whole expression
1067 /// tree so that *all* the computations are done on integers, and the only
1068 /// pointer-typed operands in the expression are SCEVUnknown.
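///
/// For instance (schematic): given the pointer-typed SCEV ((4 * %i) + %base),
/// the rewriter yields ((4 * %i) + (ptrtoint %base)); only the SCEVUnknown
/// %base ends up behind the cast, and the integer type used is the
/// DataLayout's pointer-sized integer type.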
1069 class SCEVPtrToIntSinkingRewriter
1070 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1072
1073 public:
1074 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1075
1076 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1077 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1078 return Rewriter.visit(Scev);
1079 }
1080
1081 const SCEV *visit(const SCEV *S) {
1082 Type *STy = S->getType();
1083 // If the expression is not pointer-typed, just keep it as-is.
1084 if (!STy->isPointerTy())
1085 return S;
1086 // Else, recursively sink the cast down into it.
1087 return Base::visit(S);
1088 }
1089
1090 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1092 bool Changed = false;
1093 for (const auto *Op : Expr->operands()) {
1094 Operands.push_back(visit(Op));
1095 Changed |= Op != Operands.back();
1096 }
1097 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1098 }
1099
1100 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1102 bool Changed = false;
1103 for (const auto *Op : Expr->operands()) {
1104 Operands.push_back(visit(Op));
1105 Changed |= Op != Operands.back();
1106 }
1107 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1108 }
1109
1110 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1111 assert(Expr->getType()->isPointerTy() &&
1112 "Should only reach pointer-typed SCEVUnknown's.");
1113 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1114 }
1115 };
1116
1117 // And actually perform the cast sinking.
1118 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1119 assert(IntOp->getType()->isIntegerTy() &&
1120 "We must have succeeded in sinking the cast, "
1121 "and ending up with an integer-typed expression!");
1122 return IntOp;
1123}
1124
1126 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1127
1128 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1129 if (isa<SCEVCouldNotCompute>(IntOp))
1130 return IntOp;
1131
1132 return getTruncateOrZeroExtend(IntOp, Ty);
1133}
1134
1136 unsigned Depth) {
1137 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1138 "This is not a truncating conversion!");
1139 assert(isSCEVable(Ty) &&
1140 "This is not a conversion to a SCEVable type!");
1141 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1142 Ty = getEffectiveSCEVType(Ty);
1143
1145 ID.AddInteger(scTruncate);
1146 ID.AddPointer(Op);
1147 ID.AddPointer(Ty);
1148 void *IP = nullptr;
1149 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1150
1151 // Fold if the operand is constant.
1152 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1153 return getConstant(
1154 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1155
1156 // trunc(trunc(x)) --> trunc(x)
1157 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1158 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1159
1160 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1161 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1162 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1163
1164 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1165 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1166 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1167
1168 if (Depth > MaxCastDepth) {
1169 SCEV *S =
1170 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1171 UniqueSCEVs.InsertNode(S, IP);
1172 registerUser(S, Op);
1173 return S;
1174 }
1175
1176 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1177 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1178 // if after transforming we have at most one truncate, not counting truncates
1179 // that replace other casts.
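// As one illustrative case (not exhaustive): truncating the i64 expression
// (5 + (zext i32 %a to i64)) to i32 distributes to (5 + %a) -- the constant
// folds directly and trunc(zext(%a)) collapses, so no truncate survives and
// the transform is a clear win.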
1180 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1181 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1183 unsigned numTruncs = 0;
1184 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1185 ++i) {
1186 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1187 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1188 isa<SCEVTruncateExpr>(S))
1189 numTruncs++;
1190 Operands.push_back(S);
1191 }
1192 if (numTruncs < 2) {
1193 if (isa<SCEVAddExpr>(Op))
1194 return getAddExpr(Operands);
1195 if (isa<SCEVMulExpr>(Op))
1196 return getMulExpr(Operands);
1197 llvm_unreachable("Unexpected SCEV type for Op.");
1198 }
1199 // Although we checked in the beginning that ID is not in the cache, it is
1200 // possible that, during recursion and other modifications, ID was inserted
1201 // into the cache. So if we find it, just return it.
1202 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1203 return S;
1204 }
1205
1206 // If the input value is a chrec scev, truncate the chrec's operands.
1207 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1209 for (const SCEV *Op : AddRec->operands())
1210 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1211 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1212 }
1213
1214 // Return zero if truncating to known zeros.
1215 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1216 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1217 return getZero(Ty);
1218
1219 // The cast wasn't folded; create an explicit cast node. We can reuse
1220 // the existing insert position since if we get here, we won't have
1221 // made any changes which would invalidate it.
1222 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1223 Op, Ty);
1224 UniqueSCEVs.InsertNode(S, IP);
1225 registerUser(S, Op);
1226 return S;
1227}
1228
1229// Get the limit of a recurrence such that incrementing by Step cannot cause
1230// signed overflow as long as the value of the recurrence within the
1231// loop does not exceed this limit before incrementing.
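//
// Concretely (a sketch of the two cases handled below): for a known-positive
// Step the limit is SINT_MAX - smax(Step) with predicate SLT, since any
// recurrence value V <s (SINT_MAX - smax(Step)) gives V + Step <=s SINT_MAX;
// for a known-negative Step it is SINT_MIN - smin(Step) with predicate SGT.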
1232static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1233 ICmpInst::Predicate *Pred,
1234 ScalarEvolution *SE) {
1235 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1236 if (SE->isKnownPositive(Step)) {
1237 *Pred = ICmpInst::ICMP_SLT;
1239 SE->getSignedRangeMax(Step));
1240 }
1241 if (SE->isKnownNegative(Step)) {
1242 *Pred = ICmpInst::ICMP_SGT;
1244 SE->getSignedRangeMin(Step));
1245 }
1246 return nullptr;
1247}
1248
1249// Get the limit of a recurrence such that incrementing by Step cannot cause
1250// unsigned overflow as long as the value of the recurrence within the loop does
1251// not exceed this limit before incrementing.
1253 ICmpInst::Predicate *Pred,
1254 ScalarEvolution *SE) {
1255 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1256 *Pred = ICmpInst::ICMP_ULT;
1257
1259 SE->getUnsignedRangeMax(Step));
1260}
1261
1262namespace {
1263
1264struct ExtendOpTraitsBase {
1265 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1266 unsigned);
1267};
1268
1269// Used to make code generic over signed and unsigned overflow.
1270template <typename ExtendOp> struct ExtendOpTraits {
1271 // Members present:
1272 //
1273 // static const SCEV::NoWrapFlags WrapType;
1274 //
1275 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1276 //
1277 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1278 // ICmpInst::Predicate *Pred,
1279 // ScalarEvolution *SE);
1280};
1281
1282template <>
1283struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1284 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1285
1286 static const GetExtendExprTy GetExtendExpr;
1287
1288 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1289 ICmpInst::Predicate *Pred,
1290 ScalarEvolution *SE) {
1291 return getSignedOverflowLimitForStep(Step, Pred, SE);
1292 }
1293};
1294
1295const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1297
1298template <>
1299struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1300 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1301
1302 static const GetExtendExprTy GetExtendExpr;
1303
1304 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1305 ICmpInst::Predicate *Pred,
1306 ScalarEvolution *SE) {
1307 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1308 }
1309};
1310
1311const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1313
1314} // end anonymous namespace
1315
1316// The recurrence AR has been shown to have no signed/unsigned wrap or something
1317// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1318// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1319// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1320 // Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1321// expression "Step + sext/zext(PreIncAR)" is congruent with
1322// "sext/zext(PostIncAR)"
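//
// A hypothetical instance of that normalization: zext'ing the i8 recurrence
// {%s + 1,+,1} to i32 can, when the checks below succeed for the pre-start
// value %s, be rewritten as {(1 + (zext i8 %s to i32)),+,1} instead of
// keeping the whole start expression behind the zext.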
1323template <typename ExtendOpTy>
1324static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1325 ScalarEvolution *SE, unsigned Depth) {
1326 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1327 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1328
1329 const Loop *L = AR->getLoop();
1330 const SCEV *Start = AR->getStart();
1331 const SCEV *Step = AR->getStepRecurrence(*SE);
1332
1333 // Check for a simple looking step prior to loop entry.
1334 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1335 if (!SA)
1336 return nullptr;
1337
1338 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1339 // subtraction is expensive. For this purpose, perform a quick and dirty
1340 // difference, by checking for Step in the operand list. Note, that
1341 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1343 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1344 if (*It == Step) {
1345 DiffOps.erase(It);
1346 break;
1347 }
1348
1349 if (DiffOps.size() == SA->getNumOperands())
1350 return nullptr;
1351
1352 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1353 // `Step`:
1354
1355 // 1. NSW/NUW flags on the step increment.
1356 auto PreStartFlags =
1358 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1359 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1360 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1361
1362 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1363 // "S+X does not sign/unsign-overflow".
1364 //
1365
1366 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1367 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1368 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1369 return PreStart;
1370
1371 // 2. Direct overflow check on the step operation's expression.
1372 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1373 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1374 const SCEV *OperandExtendedStart =
1375 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1376 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1377 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1378 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1379 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1380 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1381 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1382 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1383 }
1384 return PreStart;
1385 }
1386
1387 // 3. Loop precondition.
1389 const SCEV *OverflowLimit =
1390 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1391
1392 if (OverflowLimit &&
1393 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1394 return PreStart;
1395
1396 return nullptr;
1397}
1398
1399// Get the normalized zero or sign extended expression for this AddRec's Start.
1400template <typename ExtendOpTy>
1401static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1402 ScalarEvolution *SE,
1403 unsigned Depth) {
1404 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1405
1406 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1407 if (!PreStart)
1408 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1409
1410 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1411 Depth),
1412 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1413}
1414
1415// Try to prove away overflow by looking at "nearby" add recurrences. A
1416// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1417// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1418//
1419// Formally:
1420//
1421// {S,+,X} == {S-T,+,X} + T
1422// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1423//
1424// If ({S-T,+,X} + T) does not overflow ... (1)
1425//
1426// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1427//
1428// If {S-T,+,X} does not overflow ... (2)
1429//
1430// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1431// == {Ext(S-T)+Ext(T),+,Ext(X)}
1432//
1433// If (S-T)+T does not overflow ... (3)
1434//
1435// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1436// == {Ext(S),+,Ext(X)} == LHS
1437//
1438// Thus, if (1), (2) and (3) are true for some T, then
1439// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1440//
1441// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1442// does not overflow" restricted to the 0th iteration. Therefore we only need
1443// to check for (1) and (2).
1444//
1445// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1446// is `Delta` (defined below).
1447template <typename ExtendOpTy>
1448bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1449 const SCEV *Step,
1450 const Loop *L) {
1451 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1452
1453 // We restrict `Start` to a constant to prevent SCEV from spending too much
1454 // time here. It is correct (but more expensive) to continue with a
1455 // non-constant `Start` and do a general SCEV subtraction to compute
1456 // `PreStart` below.
1457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1458 if (!StartC)
1459 return false;
1460
1461 APInt StartAI = StartC->getAPInt();
1462
1463 for (unsigned Delta : {-2, -1, 1, 2}) {
1464 const SCEV *PreStart = getConstant(StartAI - Delta);
1465
1467 ID.AddInteger(scAddRecExpr);
1468 ID.AddPointer(PreStart);
1469 ID.AddPointer(Step);
1470 ID.AddPointer(L);
1471 void *IP = nullptr;
1472 const auto *PreAR =
1473 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1474
1475 // Give up if we don't already have the add recurrence we need because
1476 // actually constructing an add recurrence is relatively expensive.
1477 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1478 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1480 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1481 DeltaS, &Pred, this);
1482 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1483 return true;
1484 }
1485 }
1486
1487 return false;
1488}
1489
1490// Finds an integer D for an expression (C + x + y + ...) such that the top
1491// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1492// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1493// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1494// the (C + x + y + ...) expression is \p WholeAddExpr.
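//
// A worked instance (numbers only): if (x + y + ...) is known to have at
// least 3 trailing zero bits and C = 13 = 0b1101, then D = 0b101 = 5 and
// C - D = 8, so (C - D + x + y + ...) keeps 3 trailing zeros; adding D back
// only sets bit positions that are zero in the other term, so the top-level
// addition cannot carry and therefore cannot wrap.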
1496 const SCEVConstant *ConstantTerm,
1497 const SCEVAddExpr *WholeAddExpr) {
1498 const APInt &C = ConstantTerm->getAPInt();
1499 const unsigned BitWidth = C.getBitWidth();
1500 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1501 uint32_t TZ = BitWidth;
1502 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1503 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1504 if (TZ) {
1505 // Set D to be as many least significant bits of C as possible while still
1506 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1507 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1508 }
1509 return APInt(BitWidth, 0);
1510}
1511
1512// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1513// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1514// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1515// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1517 const APInt &ConstantStart,
1518 const SCEV *Step) {
1519 const unsigned BitWidth = ConstantStart.getBitWidth();
1520 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1521 if (TZ)
1522 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1523 : ConstantStart;
1524 return APInt(BitWidth, 0);
1525}
1526
1528 const ScalarEvolution::FoldID &ID, const SCEV *S,
1531 &FoldCacheUser) {
1532 auto I = FoldCache.insert({ID, S});
1533 if (!I.second) {
1534 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1535 // entry.
1536 auto &UserIDs = FoldCacheUser[I.first->second];
1537 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1538 for (unsigned I = 0; I != UserIDs.size(); ++I)
1539 if (UserIDs[I] == ID) {
1540 std::swap(UserIDs[I], UserIDs.back());
1541 break;
1542 }
1543 UserIDs.pop_back();
1544 I.first->second = S;
1545 }
1546 auto R = FoldCacheUser.insert({S, {}});
1547 R.first->second.push_back(ID);
1548}
1549
1550const SCEV *
1552 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1553 "This is not an extending conversion!");
1554 assert(isSCEVable(Ty) &&
1555 "This is not a conversion to a SCEVable type!");
1556 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1557 Ty = getEffectiveSCEVType(Ty);
1558
1559 FoldID ID(scZeroExtend, Op, Ty);
1560 auto Iter = FoldCache.find(ID);
1561 if (Iter != FoldCache.end())
1562 return Iter->second;
1563
1564 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1565 if (!isa<SCEVZeroExtendExpr>(S))
1566 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1567 return S;
1568}
1569
1571 unsigned Depth) {
1572 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1573 "This is not an extending conversion!");
1574 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1575 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1576
1577 // Fold if the operand is constant.
1578 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1579 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1580
1581 // zext(zext(x)) --> zext(x)
1582 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1583 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1584
1585 // Before doing any expensive analysis, check to see if we've already
1586 // computed a SCEV for this Op and Ty.
1588 ID.AddInteger(scZeroExtend);
1589 ID.AddPointer(Op);
1590 ID.AddPointer(Ty);
1591 void *IP = nullptr;
1592 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1593 if (Depth > MaxCastDepth) {
1594 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1595 Op, Ty);
1596 UniqueSCEVs.InsertNode(S, IP);
1597 registerUser(S, Op);
1598 return S;
1599 }
1600
1601 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1602 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1603 // It's possible the bits taken off by the truncate were all zero bits. If
1604 // so, we should be able to simplify this further.
1605 const SCEV *X = ST->getOperand();
1607 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1608 unsigned NewBits = getTypeSizeInBits(Ty);
1609 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1610 CR.zextOrTrunc(NewBits)))
1611 return getTruncateOrZeroExtend(X, Ty, Depth);
1612 }
1613
1614 // If the input value is a chrec scev, and we can prove that the value
1615 // did not overflow the old, smaller, value, we can zero extend all of the
1616 // operands (often constants). This allows analysis of something like
1617 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1618 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1619 if (AR->isAffine()) {
1620 const SCEV *Start = AR->getStart();
1621 const SCEV *Step = AR->getStepRecurrence(*this);
1622 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1623 const Loop *L = AR->getLoop();
1624
1625 // If we have special knowledge that this addrec won't overflow,
1626 // we don't need to do any further analysis.
1627 if (AR->hasNoUnsignedWrap()) {
1628 Start =
1629 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1630 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1631 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1632 }
1633
1634 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1635 // Note that this serves two purposes: It filters out loops that are
1636 // simply not analyzable, and it covers the case where this code is
1637 // being called from within backedge-taken count analysis, such that
1638 // attempting to ask for the backedge-taken count would likely result
1639 // in infinite recursion. In the latter case, the analysis code will
1640 // cope with a conservative value, and it will take care to purge
1641 // that value once it has finished.
1642 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1643 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1644 // Manually compute the final value for AR, checking for overflow.
1645
1646 // Check whether the backedge-taken count can be losslessly casted to
1647 // the addrec's type. The count is always unsigned.
1648 const SCEV *CastedMaxBECount =
1649 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1650 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1651 CastedMaxBECount, MaxBECount->getType(), Depth);
1652 if (MaxBECount == RecastedMaxBECount) {
1653 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1654 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1655 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1656 SCEV::FlagAnyWrap, Depth + 1);
1657 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1658 SCEV::FlagAnyWrap,
1659 Depth + 1),
1660 WideTy, Depth + 1);
1661 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1662 const SCEV *WideMaxBECount =
1663 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1664 const SCEV *OperandExtendedAdd =
1665 getAddExpr(WideStart,
1666 getMulExpr(WideMaxBECount,
1667 getZeroExtendExpr(Step, WideTy, Depth + 1),
1668 SCEV::FlagAnyWrap, Depth + 1),
1669 SCEV::FlagAnyWrap, Depth + 1);
1670 if (ZAdd == OperandExtendedAdd) {
1671 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1672 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1673 // Return the expression with the addrec on the outside.
1674 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1675 Depth + 1);
1676 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1677 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1678 }
1679 // Similar to above, only this time treat the step value as signed.
1680 // This covers loops that count down.
1681 OperandExtendedAdd =
1682 getAddExpr(WideStart,
1683 getMulExpr(WideMaxBECount,
1684 getSignExtendExpr(Step, WideTy, Depth + 1),
1685 SCEV::FlagAnyWrap, Depth + 1),
1686 SCEV::FlagAnyWrap, Depth + 1);
1687 if (ZAdd == OperandExtendedAdd) {
1688 // Cache knowledge of AR NW, which is propagated to this AddRec.
1689 // Negative step causes unsigned wrap, but it still can't self-wrap.
1690 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1691 // Return the expression with the addrec on the outside.
1692 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693 Depth + 1);
1694 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1695 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1696 }
1697 }
1698 }
1699
1700 // Normally, in the cases we can prove no-overflow via a
1701 // backedge guarding condition, we can also compute a backedge
1702 // taken count for the loop. The exceptions are assumptions and
1703 // guards present in the loop -- SCEV is not great at exploiting
1704 // these to compute max backedge taken counts, but can still use
1705 // these to prove lack of overflow. Use this fact to avoid
1706 // doing extra work that may not pay off.
1707 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1708 !AC.assumptions().empty()) {
1709
1710 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1711 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1712 if (AR->hasNoUnsignedWrap()) {
1713 // Same as nuw case above - duplicated here to avoid a compile time
1714 // issue. It's not clear that the order of checks matters, but
1715 // it's one of two possible causes of an issue that led to a change
1716 // being reverted. Be conservative for the moment.
1717 Start =
1718 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1719 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1720 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1721 }
1722
1723 // For a negative step, we can extend the operands iff doing so only
1724 // traverses values in the range zext([0,UINT_MAX]).
1725 if (isKnownNegative(Step)) {
1726 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1727 getSignedRangeMin(Step));
1728 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1729 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1730 // Cache knowledge of AR NW, which is propagated to this
1731 // AddRec. Negative step causes unsigned wrap, but it
1732 // still can't self-wrap.
1733 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1734 // Return the expression with the addrec on the outside.
1735 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1736 Depth + 1);
1737 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1738 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1739 }
1740 }
1741 }
1742
1743 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1744 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1745 // where D maximizes the number of trailing zeros of (C - D + Step * n)
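 // For instance (illustrative): zext({5,+,4}<%L>) over i8 can take D = 1,
 // the low bits of C = 5 below the step's two trailing zeros, giving
 // 1 + zext({4,+,4}<%L>), where the residual recurrence only ever produces
 // multiples of 4.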
1746 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1747 const APInt &C = SC->getAPInt();
1748 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1749 if (D != 0) {
1750 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1751 const SCEV *SResidual =
1752 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1753 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1754 return getAddExpr(SZExtD, SZExtR,
1755 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1756 Depth + 1);
1757 }
1758 }
1759
1760 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1761 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1762 Start =
1763 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1764 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1765 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1766 }
1767 }
1768
1769 // zext(A % B) --> zext(A) % zext(B)
1770 {
1771 const SCEV *LHS;
1772 const SCEV *RHS;
1773 if (matchURem(Op, LHS, RHS))
1774 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1775 getZeroExtendExpr(RHS, Ty, Depth + 1));
1776 }
1777
1778 // zext(A / B) --> zext(A) / zext(B).
1779 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1780 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1781 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1782
1783 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1784 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1785 if (SA->hasNoUnsignedWrap()) {
1786 // If the addition does not unsign overflow then we can, by definition,
1787 // commute the zero extension with the addition operation.
1788 SmallVector<const SCEV *, 4> Ops;
1789 for (const auto *Op : SA->operands())
1790 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1791 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1792 }
1793
1794 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1795 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1796 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1797 //
1798 // Address arithmetic often contains expressions like
1799 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1800 // This transformation is useful while proving that such expressions are
1801 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
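 // As a concrete illustration of the rewrite: zext(5 + 4 * X) over i8 can be
 // split with D = 1 into 1 + zext(4 + 4 * X), because 4 + 4 * X always has
 // two trailing zero bits, so adding the extracted 1 back cannot unsigned
 // wrap (hypothetical values matching the (zext (5 + (4 * X))) example above).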
1802 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1803 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1804 if (D != 0) {
1805 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1806 const SCEV *SResidual =
1807 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1808 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1809 return getAddExpr(SZExtD, SZExtR,
1810 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1811 Depth + 1);
1812 }
1813 }
1814 }
1815
1816 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1817 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1818 if (SM->hasNoUnsignedWrap()) {
1819 // If the multiply does not unsign overflow then we can, by definition,
1820 // commute the zero extension with the multiply operation.
1821 SmallVector<const SCEV *, 4> Ops;
1822 for (const auto *Op : SM->operands())
1823 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1824 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1825 }
1826
1827 // zext(2^K * (trunc X to iN)) to iM ->
1828 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1829 //
1830 // Proof:
1831 //
1832 // zext(2^K * (trunc X to iN)) to iM
1833 // = zext((trunc X to iN) << K) to iM
1834 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1835 // (because shl removes the top K bits)
1836 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1837 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1838 //
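 // A concrete instance of the identity above, with K = 2, N = 8, M = 32
 // (illustrative): zext(4 * (trunc X to i8)) to i32 equals
 // 4 * (zext(trunc X to i6) to i32), because the multiply by 4 shifts left
 // by two bits and discards the top two bits of the i8 truncation anyway.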
1839 if (SM->getNumOperands() == 2)
1840 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1841 if (MulLHS->getAPInt().isPowerOf2())
1842 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1843 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1844 MulLHS->getAPInt().logBase2();
1845 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1846 return getMulExpr(
1847 getZeroExtendExpr(MulLHS, Ty),
1848 getZeroExtendExpr(
1849 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1850 SCEV::FlagNUW, Depth + 1);
1851 }
1852 }
1853
1854 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1855 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1856 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1857 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1858 SmallVector<const SCEV *, 4> Operands;
1859 for (auto *Operand : MinMax->operands())
1860 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1861 if (isa<SCEVUMinExpr>(MinMax))
1862 return getUMinExpr(Operands);
1863 return getUMaxExpr(Operands);
1864 }
1865
1866 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1867 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1868 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1869 SmallVector<const SCEV *, 4> Operands;
1870 for (auto *Operand : MinMax->operands())
1871 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1872 return getUMinExpr(Operands, /*Sequential*/ true);
1873 }
1874
1875 // The cast wasn't folded; create an explicit cast node.
1876 // Recompute the insert position, as it may have been invalidated.
1877 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1878 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1879 Op, Ty);
1880 UniqueSCEVs.InsertNode(S, IP);
1881 registerUser(S, Op);
1882 return S;
1883}
1884
1885const SCEV *
1886ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1887 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1888 "This is not an extending conversion!");
1889 assert(isSCEVable(Ty) &&
1890 "This is not a conversion to a SCEVable type!");
1891 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1892 Ty = getEffectiveSCEVType(Ty);
1893
1894 FoldID ID(scSignExtend, Op, Ty);
1895 auto Iter = FoldCache.find(ID);
1896 if (Iter != FoldCache.end())
1897 return Iter->second;
1898
1899 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1900 if (!isa<SCEVSignExtendExpr>(S))
1901 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1902 return S;
1903}
1904
1905const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1906 unsigned Depth) {
1907 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1908 "This is not an extending conversion!");
1909 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1910 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1911 Ty = getEffectiveSCEVType(Ty);
1912
1913 // Fold if the operand is constant.
1914 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1915 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1916
1917 // sext(sext(x)) --> sext(x)
1918 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1919 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1920
1921 // sext(zext(x)) --> zext(x)
1922 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1923 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1924
1925 // Before doing any expensive analysis, check to see if we've already
1926 // computed a SCEV for this Op and Ty.
1927 FoldingSetNodeID ID;
1928 ID.AddInteger(scSignExtend);
1929 ID.AddPointer(Op);
1930 ID.AddPointer(Ty);
1931 void *IP = nullptr;
1932 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1933 // Limit recursion depth.
1934 if (Depth > MaxCastDepth) {
1935 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1936 Op, Ty);
1937 UniqueSCEVs.InsertNode(S, IP);
1938 registerUser(S, Op);
1939 return S;
1940 }
1941
1942 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1943 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1944 // It's possible the bits taken off by the truncate were all sign bits. If
1945 // so, we should be able to simplify this further.
1946 const SCEV *X = ST->getOperand();
1947 ConstantRange CR = getSignedRange(X);
1948 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1949 unsigned NewBits = getTypeSizeInBits(Ty);
1950 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1951 CR.sextOrTrunc(NewBits)))
1952 return getTruncateOrSignExtend(X, Ty, Depth);
1953 }
1954
1955 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1956 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1957 if (SA->hasNoSignedWrap()) {
1958 // If the addition does not sign overflow then we can, by definition,
1959 // commute the sign extension with the addition operation.
1960 SmallVector<const SCEV *, 4> Ops;
1961 for (const auto *Op : SA->operands())
1962 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1963 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1964 }
1965
1966 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1967 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1968 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1969 //
1970 // For instance, this will bring two seemingly different expressions:
1971 // 1 + sext(5 + 20 * %x + 24 * %y) and
1972 // sext(6 + 20 * %x + 24 * %y)
1973 // to the same form:
1974 // 2 + sext(4 + 20 * %x + 24 * %y)
1975 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1976 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1977 if (D != 0) {
1978 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1979 const SCEV *SResidual =
1980 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1981 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1982 return getAddExpr(SSExtD, SSExtR,
1983 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1984 Depth + 1);
1985 }
1986 }
1987 }
1988 // If the input value is a chrec scev, and we can prove that the value
1990 // did not overflow in the old, smaller type, we can sign extend all of the
1990 // operands (often constants). This allows analysis of something like
1991 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1992 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1993 if (AR->isAffine()) {
1994 const SCEV *Start = AR->getStart();
1995 const SCEV *Step = AR->getStepRecurrence(*this);
1996 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1997 const Loop *L = AR->getLoop();
1998
1999 // If we have special knowledge that this addrec won't overflow,
2000 // we don't need to do any further analysis.
2001 if (AR->hasNoSignedWrap()) {
2002 Start =
2003 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2004 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2005 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2006 }
2007
2008 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2009 // Note that this serves two purposes: It filters out loops that are
2010 // simply not analyzable, and it covers the case where this code is
2011 // being called from within backedge-taken count analysis, such that
2012 // attempting to ask for the backedge-taken count would likely result
2013 // in infinite recursion. In the latter case, the analysis code will
2014 // cope with a conservative value, and it will take care to purge
2015 // that value once it has finished.
2016 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2017 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2018 // Manually compute the final value for AR, checking for
2019 // overflow.
2020
2021 // Check whether the backedge-taken count can be losslessly casted to
2022 // the addrec's type. The count is always unsigned.
2023 const SCEV *CastedMaxBECount =
2024 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2025 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2026 CastedMaxBECount, MaxBECount->getType(), Depth);
2027 if (MaxBECount == RecastedMaxBECount) {
2028 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2029 // Check whether Start+Step*MaxBECount has no signed overflow.
2030 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2031 SCEV::FlagAnyWrap, Depth + 1);
2032 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2033 SCEV::FlagAnyWrap,
2034 Depth + 1),
2035 WideTy, Depth + 1);
2036 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2037 const SCEV *WideMaxBECount =
2038 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2039 const SCEV *OperandExtendedAdd =
2040 getAddExpr(WideStart,
2041 getMulExpr(WideMaxBECount,
2042 getSignExtendExpr(Step, WideTy, Depth + 1),
2043 SCEV::FlagAnyWrap, Depth + 1),
2044 SCEV::FlagAnyWrap, Depth + 1);
2045 if (SAdd == OperandExtendedAdd) {
2046 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2047 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2048 // Return the expression with the addrec on the outside.
2049 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2050 Depth + 1);
2051 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2052 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2053 }
2054 // Similar to above, only this time treat the step value as unsigned.
2055 // This covers loops that count up with an unsigned step.
2056 OperandExtendedAdd =
2057 getAddExpr(WideStart,
2058 getMulExpr(WideMaxBECount,
2059 getZeroExtendExpr(Step, WideTy, Depth + 1),
2060 SCEV::FlagAnyWrap, Depth + 1),
2061 SCEV::FlagAnyWrap, Depth + 1);
2062 if (SAdd == OperandExtendedAdd) {
2063 // If AR wraps around then
2064 //
2065 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2066 // => SAdd != OperandExtendedAdd
2067 //
2068 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2069 // (SAdd == OperandExtendedAdd => AR is NW)
2070
2071 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2072
2073 // Return the expression with the addrec on the outside.
2074 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2075 Depth + 1);
2076 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2077 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2078 }
2079 }
2080 }
2081
2082 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2083 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2084 if (AR->hasNoSignedWrap()) {
2085 // Same as nsw case above - duplicated here to avoid a compile time
2086 // issue. It's not clear that the order of checks matters, but
2087 // it's one of two possible causes of an issue that led to a change
2088 // being reverted. Be conservative for the moment.
2089 Start =
2090 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2091 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2092 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2093 }
2094
2095 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2096 // if D + (C - D + Step * n) could be proven to not signed wrap
2097 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2098 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2099 const APInt &C = SC->getAPInt();
2100 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2101 if (D != 0) {
2102 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2103 const SCEV *SResidual =
2104 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2105 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2106 return getAddExpr(SSExtD, SSExtR,
2107 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2108 Depth + 1);
2109 }
2110 }
2111
2112 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2113 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2114 Start =
2115 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2116 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2117 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2118 }
2119 }
2120
2121 // If the input value is provably positive and we could not simplify
2122 // away the sext, build a zext instead.
2123 if (isKnownNonNegative(Op))
2124 return getZeroExtendExpr(Op, Ty, Depth + 1);
2125
2126 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2127 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2128 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2129 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2130 SmallVector<const SCEV *, 4> Operands;
2131 for (auto *Operand : MinMax->operands())
2132 Operands.push_back(getSignExtendExpr(Operand, Ty));
2133 if (isa<SCEVSMinExpr>(MinMax))
2134 return getSMinExpr(Operands);
2135 return getSMaxExpr(Operands);
2136 }
2137
2138 // The cast wasn't folded; create an explicit cast node.
2139 // Recompute the insert position, as it may have been invalidated.
2140 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2141 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2142 Op, Ty);
2143 UniqueSCEVs.InsertNode(S, IP);
2144 registerUser(S, { Op });
2145 return S;
2146}
2147
2148const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2149 Type *Ty) {
2150 switch (Kind) {
2151 case scTruncate:
2152 return getTruncateExpr(Op, Ty);
2153 case scZeroExtend:
2154 return getZeroExtendExpr(Op, Ty);
2155 case scSignExtend:
2156 return getSignExtendExpr(Op, Ty);
2157 case scPtrToInt:
2158 return getPtrToIntExpr(Op, Ty);
2159 default:
2160 llvm_unreachable("Not a SCEV cast expression!");
2161 }
2162}
2163
2164/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2165/// unspecified bits out to the given type.
2166const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2167 Type *Ty) {
2168 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2169 "This is not an extending conversion!");
2170 assert(isSCEVable(Ty) &&
2171 "This is not a conversion to a SCEVable type!");
2172 Ty = getEffectiveSCEVType(Ty);
2173
2174 // Sign-extend negative constants.
2175 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2176 if (SC->getAPInt().isNegative())
2177 return getSignExtendExpr(Op, Ty);
2178
2179 // Peel off a truncate cast.
2180 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2181 const SCEV *NewOp = T->getOperand();
2182 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2183 return getAnyExtendExpr(NewOp, Ty);
2184 return getTruncateOrNoop(NewOp, Ty);
2185 }
2186
2187 // Next try a zext cast. If the cast is folded, use it.
2188 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2189 if (!isa<SCEVZeroExtendExpr>(ZExt))
2190 return ZExt;
2191
2192 // Next try a sext cast. If the cast is folded, use it.
2193 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2194 if (!isa<SCEVSignExtendExpr>(SExt))
2195 return SExt;
2196
2197 // Force the cast to be folded into the operands of an addrec.
2198 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2199 SmallVector<const SCEV *, 4> Ops;
2200 for (const SCEV *Op : AR->operands())
2201 Ops.push_back(getAnyExtendExpr(Op, Ty));
2202 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2203 }
2204
2205 // If the expression is obviously signed, use the sext cast value.
2206 if (isa<SCEVSMaxExpr>(Op))
2207 return SExt;
2208
2209 // Absent any other information, use the zext cast value.
2210 return ZExt;
2211}
2212
2213/// Process the given Ops list, which is a list of operands to be added under
2214/// the given scale, update the given map. This is a helper function for
2215/// getAddRecExpr. As an example of what it does, given a sequence of operands
2216/// that would form an add expression like this:
2217///
2218/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2219///
2220/// where A and B are constants, update the map with these values:
2221///
2222/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2223///
2224/// and add 13 + A*B*29 to AccumulatedConstant.
2225/// This will allow getAddRecExpr to produce this:
2226///
2227/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2228///
2229/// This form often exposes folding opportunities that are hidden in
2230/// the original operand list.
2231///
2232/// Return true iff it appears that any interesting folding opportunities
2233/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2234/// the common case where no interesting opportunities are present, and
2235/// is also used as a check to avoid infinite recursion.
2236static bool
2237CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2238 SmallVectorImpl<const SCEV *> &NewOps,
2239 APInt &AccumulatedConstant,
2240 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2241 ScalarEvolution &SE) {
2242 bool Interesting = false;
2243
2244 // Iterate over the add operands. They are sorted, with constants first.
2245 unsigned i = 0;
2246 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2247 ++i;
2248 // Pull a buried constant out to the outside.
2249 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2250 Interesting = true;
2251 AccumulatedConstant += Scale * C->getAPInt();
2252 }
2253
2254 // Next comes everything else. We're especially interested in multiplies
2255 // here, but they're in the middle, so just visit the rest with one loop.
2256 for (; i != Ops.size(); ++i) {
2257 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2258 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2259 APInt NewScale =
2260 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2261 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2262 // A multiplication of a constant with another add; recurse.
2263 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2264 Interesting |=
2265 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2266 Add->operands(), NewScale, SE);
2267 } else {
2268 // A multiplication of a constant with some other value. Update
2269 // the map.
2270 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2271 const SCEV *Key = SE.getMulExpr(MulOps);
2272 auto Pair = M.insert({Key, NewScale});
2273 if (Pair.second) {
2274 NewOps.push_back(Pair.first->first);
2275 } else {
2276 Pair.first->second += NewScale;
2277 // The map already had an entry for this value, which may indicate
2278 // a folding opportunity.
2279 Interesting = true;
2280 }
2281 }
2282 } else {
2283 // An ordinary operand. Update the map.
2284 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2285 M.insert({Ops[i], Scale});
2286 if (Pair.second) {
2287 NewOps.push_back(Pair.first->first);
2288 } else {
2289 Pair.first->second += Scale;
2290 // The map already had an entry for this value, which may indicate
2291 // a folding opportunity.
2292 Interesting = true;
2293 }
2294 }
2295 }
2296
2297 return Interesting;
2298}
2299
2300bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2301 const SCEV *LHS, const SCEV *RHS,
2302 const Instruction *CtxI) {
2303 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2304 SCEV::NoWrapFlags, unsigned);
2305 switch (BinOp) {
2306 default:
2307 llvm_unreachable("Unsupported binary op");
2308 case Instruction::Add:
2309 Operation = &ScalarEvolution::getAddExpr;
2310 break;
2311 case Instruction::Sub:
2312 Operation = &ScalarEvolution::getMinusSCEV;
2313 break;
2314 case Instruction::Mul:
2315 Operation = &ScalarEvolution::getMulExpr;
2316 break;
2317 }
2318
2319 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2320 Signed ? &ScalarEvolution::getSignExtendExpr
2321 : &ScalarEvolution::getZeroExtendExpr;
2322
2323 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2324 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2325 auto *WideTy =
2326 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2327
2328 const SCEV *A = (this->*Extension)(
2329 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2330 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2331 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2332 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2333 if (A == B)
2334 return true;
2335 // Can we use context to prove the fact we need?
2336 if (!CtxI)
2337 return false;
2338 // TODO: Support mul.
2339 if (BinOp == Instruction::Mul)
2340 return false;
2341 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342 // TODO: Lift this limitation.
2343 if (!RHSC)
2344 return false;
2345 APInt C = RHSC->getAPInt();
2346 unsigned NumBits = C.getBitWidth();
2347 bool IsSub = (BinOp == Instruction::Sub);
2348 bool IsNegativeConst = (Signed && C.isNegative());
2349 // Compute the direction and magnitude by which we need to check overflow.
2350 bool OverflowDown = IsSub ^ IsNegativeConst;
2351 APInt Magnitude = C;
2352 if (IsNegativeConst) {
2353 if (C == APInt::getSignedMinValue(NumBits))
2354 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2355 // want to deal with that.
2356 return false;
2357 Magnitude = -C;
2358 }
2359
2360 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2361 if (OverflowDown) {
2362 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2363 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2364 : APInt::getMinValue(NumBits);
2365 APInt Limit = Min + Magnitude;
2366 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2367 } else {
2368 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2369 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2370 : APInt::getMaxValue(NumBits);
2371 APInt Limit = Max - Magnitude;
2372 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2373 }
2374}
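// Worked example of the context check above (illustrative): for an unsigned
// i8 add "LHS + 20", Magnitude is 20 and OverflowDown is false, so the code
// asks whether LHS is known to be unsigned-<= 255 - 20 = 235 at CtxI; if so,
// the addition cannot wrap.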
2375
2376std::optional<SCEV::NoWrapFlags>
2377ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2378 const OverflowingBinaryOperator *OBO) {
2379 // It cannot be done any better.
2380 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2381 return std::nullopt;
2382
2383 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2384
2385 if (OBO->hasNoUnsignedWrap())
2386 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2387 if (OBO->hasNoSignedWrap())
2388 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2389
2390 bool Deduced = false;
2391
2392 if (OBO->getOpcode() != Instruction::Add &&
2393 OBO->getOpcode() != Instruction::Sub &&
2394 OBO->getOpcode() != Instruction::Mul)
2395 return std::nullopt;
2396
2397 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2398 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2399
2400 const Instruction *CtxI =
2401 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2402 if (!OBO->hasNoUnsignedWrap() &&
2403 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2404 /* Signed */ false, LHS, RHS, CtxI)) {
2405 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2406 Deduced = true;
2407 }
2408
2409 if (!OBO->hasNoSignedWrap() &&
2410 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2411 /* Signed */ true, LHS, RHS, CtxI)) {
2412 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2413 Deduced = true;
2414 }
2415
2416 if (Deduced)
2417 return Flags;
2418 return std::nullopt;
2419}
2420
2421// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2422// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2423// can't-overflow flags for the operation if possible.
2424static SCEV::NoWrapFlags
2425StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2426 const ArrayRef<const SCEV *> Ops,
2427 SCEV::NoWrapFlags Flags) {
2428 using namespace std::placeholders;
2429
2430 using OBO = OverflowingBinaryOperator;
2431
2432 bool CanAnalyze =
2433 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2434 (void)CanAnalyze;
2435 assert(CanAnalyze && "don't call from other places!");
2436
2437 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2438 SCEV::NoWrapFlags SignOrUnsignWrap =
2439 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2440
2441 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
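 // (Illustrative: when building (%n + 5)<nsw> where %n is known non-negative,
 // every operand is non-negative, so the nsw fact also implies nuw for the
 // same addition.)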
2442 auto IsKnownNonNegative = [&](const SCEV *S) {
2443 return SE->isKnownNonNegative(S);
2444 };
2445
2446 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2447 Flags =
2448 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2449
2450 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2451
2452 if (SignOrUnsignWrap != SignOrUnsignMask &&
2453 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2454 isa<SCEVConstant>(Ops[0])) {
2455
2456 auto Opcode = [&] {
2457 switch (Type) {
2458 case scAddExpr:
2459 return Instruction::Add;
2460 case scMulExpr:
2461 return Instruction::Mul;
2462 default:
2463 llvm_unreachable("Unexpected SCEV op.");
2464 }
2465 }();
2466
2467 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2468
2469 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2470 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2471 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2472 Opcode, C, OBO::NoSignedWrap);
2473 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2474 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2475 }
2476
2477 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2478 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2479 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2480 Opcode, C, OBO::NoUnsignedWrap);
2481 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2482 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2483 }
2484 }
2485
2486 // <0,+,nonnegative><nw> is also nuw
2487 // TODO: Add corresponding nsw case
2488 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2489 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2490 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2491 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2492
2493 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2494 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2495 Ops.size() == 2) {
2496 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2497 if (UDiv->getOperand(1) == Ops[1])
2498 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2499 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2500 if (UDiv->getOperand(1) == Ops[0])
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2502 }
2503
2504 return Flags;
2505}
2506
2507bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2508 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2509}
2510
2511/// Get a canonical add expression, or something simpler if possible.
2512const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2513 SCEV::NoWrapFlags OrigFlags,
2514 unsigned Depth) {
2515 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2516 "only nuw or nsw allowed");
2517 assert(!Ops.empty() && "Cannot get empty add!");
2518 if (Ops.size() == 1) return Ops[0];
2519#ifndef NDEBUG
2520 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2521 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2522 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2523 "SCEVAddExpr operand types don't match!");
2524 unsigned NumPtrs = count_if(
2525 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2526 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2527#endif
2528
2529 // Sort by complexity, this groups all similar expression types together.
2530 GroupByComplexity(Ops, &LI, DT);
2531
2532 // If there are any constants, fold them together.
2533 unsigned Idx = 0;
2534 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2535 ++Idx;
2536 assert(Idx < Ops.size());
2537 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2538 // We found two constants, fold them together!
2539 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2540 if (Ops.size() == 2) return Ops[0];
2541 Ops.erase(Ops.begin()+1); // Erase the folded element
2542 LHSC = cast<SCEVConstant>(Ops[0]);
2543 }
2544
2545 // If we are left with a constant zero being added, strip it off.
2546 if (LHSC->getValue()->isZero()) {
2547 Ops.erase(Ops.begin());
2548 --Idx;
2549 }
2550
2551 if (Ops.size() == 1) return Ops[0];
2552 }
2553
2554 // Delay expensive flag strengthening until necessary.
2555 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2556 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2557 };
2558
2559 // Limit recursion call depth.
2560 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2561 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2562
2563 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2564 // Don't strengthen flags if we have no new information.
2565 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2566 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2567 Add->setNoWrapFlags(ComputeFlags(Ops));
2568 return S;
2569 }
2570
2571 // Okay, check to see if the same value occurs in the operand list more than
2572 // once. If so, merge them together into a multiply expression. Since we
2573 // sorted the list, these values are required to be adjacent.
2574 Type *Ty = Ops[0]->getType();
2575 bool FoundMatch = false;
2576 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2577 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2578 // Scan ahead to count how many equal operands there are.
2579 unsigned Count = 2;
2580 while (i+Count != e && Ops[i+Count] == Ops[i])
2581 ++Count;
2582 // Merge the values into a multiply.
2583 const SCEV *Scale = getConstant(Ty, Count);
2584 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2585 if (Ops.size() == Count)
2586 return Mul;
2587 Ops[i] = Mul;
2588 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2589 --i; e -= Count - 1;
2590 FoundMatch = true;
2591 }
2592 if (FoundMatch)
2593 return getAddExpr(Ops, OrigFlags, Depth + 1);
2594
2595 // Check for truncates. If all the operands are truncated from the same
2596 // type, see if factoring out the truncate would permit the result to be
2597 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(n*x + m*y)
2598 // if the contents of the resulting outer trunc fold to something simple.
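 // For example (hypothetical operands): trunc(x) + trunc(y) + 3, where x and
 // y are i64 and the add is i32, is re-evaluated as x + y + 3 in i64; the
 // wide sum is only truncated back if it folds to a constant or a single
 // SCEVUnknown, per the check at the end of this block.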
2599 auto FindTruncSrcType = [&]() -> Type * {
2600 // We're ultimately looking to fold an addrec of truncs and muls of only
2601 // constants and truncs, so if we find any other types of SCEV
2602 // as operands of the addrec then we bail and return nullptr here.
2603 // Otherwise, we return the type of the operand of a trunc that we find.
2604 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2605 return T->getOperand()->getType();
2606 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2607 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2608 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2609 return T->getOperand()->getType();
2610 }
2611 return nullptr;
2612 };
2613 if (auto *SrcType = FindTruncSrcType()) {
2614 SmallVector<const SCEV *, 8> LargeOps;
2615 bool Ok = true;
2616 // Check all the operands to see if they can be represented in the
2617 // source type of the truncate.
2618 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2619 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2620 if (T->getOperand()->getType() != SrcType) {
2621 Ok = false;
2622 break;
2623 }
2624 LargeOps.push_back(T->getOperand());
2625 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2626 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2627 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2628 SmallVector<const SCEV *, 8> LargeMulOps;
2629 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2630 if (const SCEVTruncateExpr *T =
2631 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2632 if (T->getOperand()->getType() != SrcType) {
2633 Ok = false;
2634 break;
2635 }
2636 LargeMulOps.push_back(T->getOperand());
2637 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2638 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2639 } else {
2640 Ok = false;
2641 break;
2642 }
2643 }
2644 if (Ok)
2645 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2646 } else {
2647 Ok = false;
2648 break;
2649 }
2650 }
2651 if (Ok) {
2652 // Evaluate the expression in the larger type.
2653 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2654 // If it folds to something simple, use it. Otherwise, don't.
2655 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2656 return getTruncateExpr(Fold, Ty);
2657 }
2658 }
2659
2660 if (Ops.size() == 2) {
2661 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2662 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2663 // C1).
2664 const SCEV *A = Ops[0];
2665 const SCEV *B = Ops[1];
2666 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2667 auto *C = dyn_cast<SCEVConstant>(A);
2668 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2669 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2670 auto C2 = C->getAPInt();
2671 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2672
2673 APInt ConstAdd = C1 + C2;
2674 auto AddFlags = AddExpr->getNoWrapFlags();
2675 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2677 ConstAdd.ule(C1)) {
2678 PreservedFlags =
2679 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2680 }
2681
2682 // Adding a constant with the same sign and small magnitude is NSW, if the
2683 // original AddExpr was NSW.
2684 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2685 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2686 ConstAdd.abs().ule(C1.abs())) {
2687 PreservedFlags =
2688 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2689 }
2690
2691 if (PreservedFlags != SCEV::FlagAnyWrap) {
2692 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2693 NewOps[0] = getConstant(ConstAdd);
2694 return getAddExpr(NewOps, PreservedFlags);
2695 }
2696 }
2697 }
2698
2699 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
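 // Numeric sanity check for the identity (illustrative): X = 13, Y = 5 gives
 // (-1 * (13 % 5)) + 13 = -3 + 13 = 10 = 5 * (13 / 5).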
2700 if (Ops.size() == 2) {
2701 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2702 if (Mul && Mul->getNumOperands() == 2 &&
2703 Mul->getOperand(0)->isAllOnesValue()) {
2704 const SCEV *X;
2705 const SCEV *Y;
2706 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2707 return getMulExpr(Y, getUDivExpr(X, Y));
2708 }
2709 }
2710 }
2711
2712 // Skip past any other cast SCEVs.
2713 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2714 ++Idx;
2715
2716 // If there are add operands they would be next.
2717 if (Idx < Ops.size()) {
2718 bool DeletedAdd = false;
2719 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2720 // common NUW flag for expression after inlining. Other flags cannot be
2721 // preserved, because they may depend on the original order of operations.
2722 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2723 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2724 if (Ops.size() > AddOpsInlineThreshold ||
2725 Add->getNumOperands() > AddOpsInlineThreshold)
2726 break;
2727 // If we have an add, expand the add operands onto the end of the operands
2728 // list.
2729 Ops.erase(Ops.begin()+Idx);
2730 append_range(Ops, Add->operands());
2731 DeletedAdd = true;
2732 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2733 }
2734
2735 // If we deleted at least one add, we added operands to the end of the list,
2736 // and they are not necessarily sorted. Recurse to resort and resimplify
2737 // any operands we just acquired.
2738 if (DeletedAdd)
2739 return getAddExpr(Ops, CommonFlags, Depth + 1);
2740 }
2741
2742 // Skip over the add expression until we get to a multiply.
2743 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2744 ++Idx;
2745
2746 // Check to see if there are any folding opportunities present with
2747 // operands multiplied by constant values.
2748 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2749 uint64_t BitWidth = getTypeSizeInBits(Ty);
2750 DenseMap<const SCEV *, APInt> M;
2751 SmallVector<const SCEV *, 8> NewOps;
2752 APInt AccumulatedConstant(BitWidth, 0);
2753 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2754 Ops, APInt(BitWidth, 1), *this)) {
2755 struct APIntCompare {
2756 bool operator()(const APInt &LHS, const APInt &RHS) const {
2757 return LHS.ult(RHS);
2758 }
2759 };
2760
2761 // Some interesting folding opportunity is present, so it's worthwhile to
2762 // re-generate the operands list. Group the operands by constant scale,
2763 // to avoid multiplying by the same constant scale multiple times.
2764 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2765 for (const SCEV *NewOp : NewOps)
2766 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2767 // Re-generate the operands list.
2768 Ops.clear();
2769 if (AccumulatedConstant != 0)
2770 Ops.push_back(getConstant(AccumulatedConstant));
2771 for (auto &MulOp : MulOpLists) {
2772 if (MulOp.first == 1) {
2773 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2774 } else if (MulOp.first != 0) {
2775 Ops.push_back(getMulExpr(
2776 getConstant(MulOp.first),
2777 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2778 SCEV::FlagAnyWrap, Depth + 1));
2779 }
2780 }
2781 if (Ops.empty())
2782 return getZero(Ty);
2783 if (Ops.size() == 1)
2784 return Ops[0];
2785 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2786 }
2787 }
2788
2789 // If we are adding something to a multiply expression, make sure the
2790 // something is not already an operand of the multiply. If so, merge it into
2791 // the multiply.
2792 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2793 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2794 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2795 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2796 if (isa<SCEVConstant>(MulOpSCEV))
2797 continue;
2798 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2799 if (MulOpSCEV == Ops[AddOp]) {
2800 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2801 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2802 if (Mul->getNumOperands() != 2) {
2803 // If the multiply has more than two operands, we must get the
2804 // Y*Z term.
2805 SmallVector<const SCEV *, 4> MulOps(
2806 Mul->operands().take_front(MulOp));
2807 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2808 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2809 }
2810 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2811 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2812 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 SCEV::FlagAnyWrap, Depth + 1);
2814 if (Ops.size() == 2) return OuterMul;
2815 if (AddOp < Idx) {
2816 Ops.erase(Ops.begin()+AddOp);
2817 Ops.erase(Ops.begin()+Idx-1);
2818 } else {
2819 Ops.erase(Ops.begin()+Idx);
2820 Ops.erase(Ops.begin()+AddOp-1);
2821 }
2822 Ops.push_back(OuterMul);
2823 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2824 }
2825
2826 // Check this multiply against other multiplies being added together.
2827 for (unsigned OtherMulIdx = Idx+1;
2828 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2829 ++OtherMulIdx) {
2830 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 // If MulOp occurs in OtherMul, we can fold the two multiplies
2832 // together.
2833 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2834 OMulOp != e; ++OMulOp)
2835 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2836 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2837 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2838 if (Mul->getNumOperands() != 2) {
2839 SmallVector<const SCEV *, 4> MulOps(
2840 Mul->operands().take_front(MulOp));
2841 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2842 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2845 if (OtherMul->getNumOperands() != 2) {
2846 SmallVector<const SCEV *, 4> MulOps(
2847 OtherMul->operands().take_front(OMulOp));
2848 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2849 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2852 const SCEV *InnerMulSum =
2853 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2854 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 SCEV::FlagAnyWrap, Depth + 1);
2856 if (Ops.size() == 2) return OuterMul;
2857 Ops.erase(Ops.begin()+Idx);
2858 Ops.erase(Ops.begin()+OtherMulIdx-1);
2859 Ops.push_back(OuterMul);
2860 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2861 }
2862 }
2863 }
2864 }
2865
2866 // If there are any add recurrences in the operands list, see if any other
2867 // added values are loop invariant. If so, we can fold them into the
2868 // recurrence.
2869 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2870 ++Idx;
2871
2872 // Scan over all recurrences, trying to fold loop invariants into them.
2873 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2874 // Scan all of the other operands to this add and add them to the vector if
2875 // they are loop invariant w.r.t. the recurrence.
2876 SmallVector<const SCEV *, 8> LIOps;
2877 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2878 const Loop *AddRecLoop = AddRec->getLoop();
2879 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2880 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2881 LIOps.push_back(Ops[i]);
2882 Ops.erase(Ops.begin()+i);
2883 --i; --e;
2884 }
2885
2886 // If we found some loop invariants, fold them into the recurrence.
2887 if (!LIOps.empty()) {
2888 // Compute nowrap flags for the addition of the loop-invariant ops and
2889 // the addrec. Temporarily push it as an operand for that purpose. These
2890 // flags are valid in the scope of the addrec only.
2891 LIOps.push_back(AddRec);
2892 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2893 LIOps.pop_back();
2894
2895 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2896 LIOps.push_back(AddRec->getStart());
2897
2898 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2899
2900 // It is not in general safe to propagate flags valid on an add within
2901 // the addrec scope to one outside it. We must prove that the inner
2902 // scope is guaranteed to execute if the outer one does to be able to
2903 // safely propagate. We know the program is undefined if poison is
2904 // produced on the inner scoped addrec. We also know that *for this use*
2905 // the outer scoped add can't overflow (because of the flags we just
2906 // computed for the inner scoped add) without the program being undefined.
2907 // Proving that entry to the outer scope necessitates entry to the inner
2908 // scope, thus proves the program undefined if the flags would be violated
2909 // in the outer scope.
2910 SCEV::NoWrapFlags AddFlags = Flags;
2911 if (AddFlags != SCEV::FlagAnyWrap) {
2912 auto *DefI = getDefiningScopeBound(LIOps);
2913 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2914 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2915 AddFlags = SCEV::FlagAnyWrap;
2916 }
2917 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2918
2919 // Build the new addrec. Propagate the NUW and NSW flags if both the
2920 // outer add and the inner addrec are guaranteed to have no overflow.
2921 // Always propagate NW.
2922 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2923 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2924
2925 // If all of the other operands were loop invariant, we are done.
2926 if (Ops.size() == 1) return NewRec;
2927
2928 // Otherwise, add the folded AddRec by the non-invariant parts.
2929 for (unsigned i = 0;; ++i)
2930 if (Ops[i] == AddRec) {
2931 Ops[i] = NewRec;
2932 break;
2933 }
2934 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2935 }
2936
2937 // Okay, if there weren't any loop invariants to be folded, check to see if
2938 // there are multiple AddRec's with the same loop induction variable being
2939 // added together. If so, we can fold them.
2940 for (unsigned OtherIdx = Idx+1;
2941 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2942 ++OtherIdx) {
2943 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2944 // so that the 1st found AddRecExpr is dominated by all others.
2945 assert(DT.dominates(
2946 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2947 AddRec->getLoop()->getHeader()) &&
2948 "AddRecExprs are not sorted in reverse dominance order?");
2949 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2950 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2951 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2952 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2953 ++OtherIdx) {
2954 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 if (OtherAddRec->getLoop() == AddRecLoop) {
2956 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2957 i != e; ++i) {
2958 if (i >= AddRecOps.size()) {
2959 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2960 break;
2961 }
2962 SmallVector<const SCEV *, 2> TwoOps = {
2963 AddRecOps[i], OtherAddRec->getOperand(i)};
2964 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2965 }
2966 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2967 }
2968 }
2969 // Step size has changed, so we cannot guarantee no self-wraparound.
2970 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974
2975 // Otherwise couldn't fold anything into this recurrence. Move onto the
2976 // next one.
2977 }
2978
2979 // Okay, it looks like we really DO need an add expr. Check to see if we
2980 // already have one, otherwise create a new one.
2981 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2982}
2983
2984const SCEV *
2985ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2986 SCEV::NoWrapFlags Flags) {
2987 FoldingSetNodeID ID;
2988 ID.AddInteger(scAddExpr);
2989 for (const SCEV *Op : Ops)
2990 ID.AddPointer(Op);
2991 void *IP = nullptr;
2992 SCEVAddExpr *S =
2993 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2999 UniqueSCEVs.InsertNode(S, IP);
3000 registerUser(S, Ops);
3001 }
3002 S->setNoWrapFlags(Flags);
3003 return S;
3004}
3005
3006const SCEV *
3007ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3008 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 FoldingSetNodeID ID;
3010 ID.AddInteger(scAddRecExpr);
3011 for (const SCEV *Op : Ops)
3012 ID.AddPointer(Op);
3013 ID.AddPointer(L);
3014 void *IP = nullptr;
3015 SCEVAddRecExpr *S =
3016 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator)
3021 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3022 UniqueSCEVs.InsertNode(S, IP);
3023 LoopUsers[L].push_back(S);
3024 registerUser(S, Ops);
3025 }
3026 setNoWrapFlags(S, Flags);
3027 return S;
3028}
3029
3030const SCEV *
3031ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3032 SCEV::NoWrapFlags Flags) {
3033 FoldingSetNodeID ID;
3034 ID.AddInteger(scMulExpr);
3035 for (const SCEV *Op : Ops)
3036 ID.AddPointer(Op);
3037 void *IP = nullptr;
3038 SCEVMulExpr *S =
3039 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3040 if (!S) {
3041 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3042 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3043 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3044 O, Ops.size());
3045 UniqueSCEVs.InsertNode(S, IP);
3046 registerUser(S, Ops);
3047 }
3048 S->setNoWrapFlags(Flags);
3049 return S;
3050}
3051
3052static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3053 uint64_t k = i*j;
3054 if (j > 1 && k / j != i) Overflow = true;
3055 return k;
3056}
3057
3058/// Compute the result of "n choose k", the binomial coefficient. If an
3059/// intermediate computation overflows, Overflow will be set and the return will
3060/// be garbage. Overflow is not cleared on absence of overflow.
3061static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3062 // We use the multiplicative formula:
3063 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3064 // At each iteration, we take the i-th term of the numerator and divide by
3065 // the i-th term of the denominator (i.e. by i). This division will always produce an
3066 // integral result, and helps reduce the chance of overflow in the
3067 // intermediate computations. However, we can still overflow even when the
3068 // final result would fit.
3069
3070 if (n == 0 || n == k) return 1;
3071 if (k > n) return 0;
3072
3073 if (k > n/2)
3074 k = n-k;
3075
3076 uint64_t r = 1;
3077 for (uint64_t i = 1; i <= k; ++i) {
3078 r = umul_ov(r, n-(i-1), Overflow);
3079 r /= i;
3080 }
3081 return r;
3082}
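// Illustrative trace of Choose: n = 6, k = 2 runs r = 1*6/1 = 6, then
// r = 6*5/2 = 15, matching C(6,2); each division is exact because the running
// product spans i consecutive integers.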
3083
3084/// Determine if any of the operands in this SCEV are a constant or if
3085/// any of the add or multiply expressions in this SCEV contain a constant.
3086static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3087 struct FindConstantInAddMulChain {
3088 bool FoundConstant = false;
3089
3090 bool follow(const SCEV *S) {
3091 FoundConstant |= isa<SCEVConstant>(S);
3092 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3093 }
3094
3095 bool isDone() const {
3096 return FoundConstant;
3097 }
3098 };
3099
3100 FindConstantInAddMulChain F;
3101 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3102 ST.visitAll(StartExpr);
3103 return F.FoundConstant;
3104}
3105
3106/// Get a canonical multiply expression, or something simpler if possible.
3107const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3108 SCEV::NoWrapFlags OrigFlags,
3109 unsigned Depth) {
3110 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3111 "only nuw or nsw allowed");
3112 assert(!Ops.empty() && "Cannot get empty mul!");
3113 if (Ops.size() == 1) return Ops[0];
3114#ifndef NDEBUG
3115 Type *ETy = Ops[0]->getType();
3116 assert(!ETy->isPointerTy());
3117 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3118 assert(Ops[i]->getType() == ETy &&
3119 "SCEVMulExpr operand types don't match!");
3120#endif
3121
3122 // Sort by complexity, this groups all similar expression types together.
3123 GroupByComplexity(Ops, &LI, DT);
3124
3125 // If there are any constants, fold them together.
3126 unsigned Idx = 0;
3127 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3128 ++Idx;
3129 assert(Idx < Ops.size());
3130 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3131 // We found two constants, fold them together!
3132 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3133 if (Ops.size() == 2) return Ops[0];
3134 Ops.erase(Ops.begin()+1); // Erase the folded element
3135 LHSC = cast<SCEVConstant>(Ops[0]);
3136 }
3137
3138 // If we have a multiply of zero, it will always be zero.
3139 if (LHSC->getValue()->isZero())
3140 return LHSC;
3141
3142 // If we are left with a constant one being multiplied, strip it off.
3143 if (LHSC->getValue()->isOne()) {
3144 Ops.erase(Ops.begin());
3145 --Idx;
3146 }
3147
3148 if (Ops.size() == 1)
3149 return Ops[0];
3150 }
3151
3152 // Delay expensive flag strengthening until necessary.
3153 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3154 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3155 };
3156
3157 // Limit recursion calls depth.
3158 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3159 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3160
3161 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3162 // Don't strengthen flags if we have no new information.
3163 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3164 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3165 Mul->setNoWrapFlags(ComputeFlags(Ops));
3166 return S;
3167 }
3168
3169 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3170 if (Ops.size() == 2) {
3171 // C1*(C2+V) -> C1*C2 + C1*V
3172 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3173 // If any of Add's ops are Adds or Muls with a constant, apply this
3174 // transformation as well.
3175 //
3176 // TODO: There are some cases where this transformation is not
3177 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3178 // this transformation should be narrowed down.
3179 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3180 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3181 SCEV::FlagAnyWrap, Depth + 1);
3182 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3183 SCEV::FlagAnyWrap, Depth + 1);
3184 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3185 }
3186
3187 if (Ops[0]->isAllOnesValue()) {
3188 // If we have a mul by -1 of an add, try distributing the -1 among the
3189 // add operands.
3190 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3191 SmallVector<const SCEV *, 4> NewOps;
3192 bool AnyFolded = false;
3193 for (const SCEV *AddOp : Add->operands()) {
3194 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3195 Depth + 1);
3196 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3197 NewOps.push_back(Mul);
3198 }
3199 if (AnyFolded)
3200 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3201 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3202 // Negation preserves a recurrence's no self-wrap property.
3203 SmallVector<const SCEV *, 4> Operands;
3204 for (const SCEV *AddRecOp : AddRec->operands())
3205 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3206 Depth + 1));
3207 // Let M be the minimum representable signed value. AddRec with nsw
3208 // multiplied by -1 can have signed overflow if and only if it takes a
3209 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3210 // maximum signed value. In all other cases signed overflow is
3211 // impossible.
3212 auto FlagsMask = SCEV::FlagNW;
3213 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3214 auto MinInt =
3215 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3216 if (getSignedRangeMin(AddRec) != MinInt)
3217 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3218 }
3219 return getAddRecExpr(Operands, AddRec->getLoop(),
3220 AddRec->getNoWrapFlags(FlagsMask));
3221 }
3222 }
3223 }
3224 }
3225
3226 // Skip over the add expression until we get to a multiply.
3227 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3228 ++Idx;
3229
3230 // If there are mul operands inline them all into this expression.
3231 if (Idx < Ops.size()) {
3232 bool DeletedMul = false;
3233 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3234 if (Ops.size() > MulOpsInlineThreshold)
3235 break;
3236 // If we have a mul, expand the mul operands onto the end of the
3237 // operands list.
3238 Ops.erase(Ops.begin()+Idx);
3239 append_range(Ops, Mul->operands());
3240 DeletedMul = true;
3241 }
3242
3243 // If we deleted at least one mul, we added operands to the end of the
3244 // list, and they are not necessarily sorted. Recurse to resort and
3245 // resimplify any operands we just acquired.
3246 if (DeletedMul)
3247 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3248 }
3249
3250 // If there are any add recurrences in the operands list, see if any other
3251 // added values are loop invariant. If so, we can fold them into the
3252 // recurrence.
3253 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3254 ++Idx;
3255
3256 // Scan over all recurrences, trying to fold loop invariants into them.
3257 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3258 // Scan all of the other operands to this mul and add them to the vector
3259 // if they are loop invariant w.r.t. the recurrence.
3260 SmallVector<const SCEV *, 8> LIOps;
3261 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3262 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3263 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3264 LIOps.push_back(Ops[i]);
3265 Ops.erase(Ops.begin()+i);
3266 --i; --e;
3267 }
3268
3269 // If we found some loop invariants, fold them into the recurrence.
3270 if (!LIOps.empty()) {
3271 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3272 SmallVector<const SCEV *, 4> NewOps;
3273 NewOps.reserve(AddRec->getNumOperands());
3274 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3275
3276 // If both the mul and addrec are nuw, we can preserve nuw.
3277 // If both the mul and addrec are nsw, we can only preserve nsw if either
3278 // a) they are also nuw, or
3279 // b) all multiplications of addrec operands with scale are nsw.
3280 SCEV::NoWrapFlags Flags =
3281 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3282
3283 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3284 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3285 SCEV::FlagAnyWrap, Depth + 1));
3286
3287 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3288 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3289 Instruction::Mul, getSignedRange(Scale),
3290 OverflowingBinaryOperator::NoSignedWrap);
3291 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3292 Flags = clearFlags(Flags, SCEV::FlagNSW);
3293 }
3294 }
3295
3296 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3297
3298 // If all of the other operands were loop invariant, we are done.
3299 if (Ops.size() == 1) return NewRec;
3300
3301 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3302 for (unsigned i = 0;; ++i)
3303 if (Ops[i] == AddRec) {
3304 Ops[i] = NewRec;
3305 break;
3306 }
3307 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3308 }
3309
3310 // Okay, if there weren't any loop invariants to be folded, check to see
3311 // if there are multiple AddRec's with the same loop induction variable
3312 // being multiplied together. If so, we can fold them.
3313
3314 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3315 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3316 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3317 // ]]],+,...up to x=2n}.
3318 // Note that the arguments to choose() are always integers with values
3319 // known at compile time, never SCEV objects.
3320 //
3321 // The implementation avoids pointless extra computations when the two
3322 // addrec's are of different length (mathematically, it's equivalent to
3323 // an infinite stream of zeros on the right).
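// As a small worked example (illustrative, using the formula above):
// {1,+,1}<L> * {1,+,1}<L> folds to {1,+,3,+,2}<L>, i.e. (1 + i)^2 evaluated
// as 1 + 3*C(i,1) + 2*C(i,2) = i^2 + 2i + 1 at iteration i.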
3324 bool OpsModified = false;
3325 for (unsigned OtherIdx = Idx+1;
3326 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3327 ++OtherIdx) {
3328 const SCEVAddRecExpr *OtherAddRec =
3329 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3330 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3331 continue;
3332
3333 // Limit max number of arguments to avoid creation of unreasonably big
3334 // SCEVAddRecs with very complex operands.
3335 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3336 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3337 continue;
3338
3339 bool Overflow = false;
3340 Type *Ty = AddRec->getType();
3341 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3342 SmallVector<const SCEV *, 7> AddRecOps;
3343 for (int x = 0, xe = AddRec->getNumOperands() +
3344 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3345 SmallVector <const SCEV *, 7> SumOps;
3346 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3347 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3348 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3349 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3350 z < ze && !Overflow; ++z) {
3351 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3352 uint64_t Coeff;
3353 if (LargerThan64Bits)
3354 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3355 else
3356 Coeff = Coeff1*Coeff2;
3357 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3358 const SCEV *Term1 = AddRec->getOperand(y-z);
3359 const SCEV *Term2 = OtherAddRec->getOperand(z);
3360 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3361 SCEV::FlagAnyWrap, Depth + 1));
3362 }
3363 }
3364 if (SumOps.empty())
3365 SumOps.push_back(getZero(Ty));
3366 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 if (!Overflow) {
3369 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3370 SCEV::FlagAnyWrap);
3371 if (Ops.size() == 2) return NewAddRec;
3372 Ops[Idx] = NewAddRec;
3373 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3374 OpsModified = true;
3375 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3376 if (!AddRec)
3377 break;
3378 }
3379 }
3380 if (OpsModified)
3381 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3382
3383 // Otherwise couldn't fold anything into this recurrence. Move onto the
3384 // next one.
3385 }
3386
3387 // Okay, it looks like we really DO need a mul expr. Check to see if we
3388 // already have one, otherwise create a new one.
3389 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3390}
3391
3392/// Represents an unsigned remainder expression based on unsigned division.
3393const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3394 const SCEV *RHS) {
3395 assert(getEffectiveSCEVType(LHS->getType()) ==
3396 getEffectiveSCEVType(RHS->getType()) &&
3397 "SCEVURemExpr operand types don't match!");
3398
3399 // Short-circuit easy cases
3400 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3401 // If constant is one, the result is trivial
3402 if (RHSC->getValue()->isOne())
3403 return getZero(LHS->getType()); // X urem 1 --> 0
3404
3405 // If constant is a power of two, fold into a zext(trunc(LHS)).
3406 if (RHSC->getAPInt().isPowerOf2()) {
3407 Type *FullTy = LHS->getType();
3408 Type *TruncTy =
3409 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3410 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3411 }
3412 }
3413
3414 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3415 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3416 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3417 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3418}
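// For example (illustrative): for an i64 %x, getURemExpr(%x, 8) yields
// (zext i3 (trunc i64 %x to i3) to i64), while a non-power-of-two divisor such
// as 6 takes the generic path and is expressed as %x - 6 * (%x /u 6).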
3419
3420/// Get a canonical unsigned division expression, or something simpler if
3421/// possible.
3422const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3423 const SCEV *RHS) {
3424 assert(!LHS->getType()->isPointerTy() &&
3425 "SCEVUDivExpr operand can't be pointer!");
3426 assert(LHS->getType() == RHS->getType() &&
3427 "SCEVUDivExpr operand types don't match!");
3428
3429 FoldingSetNodeID ID;
3430 ID.AddInteger(scUDivExpr);
3431 ID.AddPointer(LHS);
3432 ID.AddPointer(RHS);
3433 void *IP = nullptr;
3434 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3435 return S;
3436
3437 // 0 udiv Y == 0
3438 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3439 if (LHSC->getValue()->isZero())
3440 return LHS;
3441
3442 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3443 if (RHSC->getValue()->isOne())
3444 return LHS; // X udiv 1 --> x
3445 // If the denominator is zero, the result of the udiv is undefined. Don't
3446 // try to analyze it, because the resolution chosen here may differ from
3447 // the resolution chosen in other parts of the compiler.
3448 if (!RHSC->getValue()->isZero()) {
3449 // Determine if the division can be folded into the operands of
3450 // its operands.
3451 // TODO: Generalize this to non-constants by using known-bits information.
3452 Type *Ty = LHS->getType();
3453 unsigned LZ = RHSC->getAPInt().countl_zero();
3454 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3455 // For non-power-of-two values, effectively round the value up to the
3456 // nearest power of two.
3457 if (!RHSC->getAPInt().isPowerOf2())
3458 ++MaxShiftAmt;
3459 IntegerType *ExtTy =
3460 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3461 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3462 if (const SCEVConstant *Step =
3463 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3464 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3465 const APInt &StepInt = Step->getAPInt();
3466 const APInt &DivInt = RHSC->getAPInt();
3467 if (!StepInt.urem(DivInt) &&
3468 getZeroExtendExpr(AR, ExtTy) ==
3469 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3470 getZeroExtendExpr(Step, ExtTy),
3471 AR->getLoop(), SCEV::FlagAnyWrap)) {
3472 SmallVector<const SCEV *, 4> Operands;
3473 for (const SCEV *Op : AR->operands())
3474 Operands.push_back(getUDivExpr(Op, RHS));
3475 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3476 }
3477 /// Get a canonical UDivExpr for a recurrence.
3478 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3479 // We can currently only fold X%N if X is constant.
3480 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3481 if (StartC && !DivInt.urem(StepInt) &&
3482 getZeroExtendExpr(AR, ExtTy) ==
3483 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3484 getZeroExtendExpr(Step, ExtTy),
3485 AR->getLoop(), SCEV::FlagAnyWrap)) {
3486 const APInt &StartInt = StartC->getAPInt();
3487 const APInt &StartRem = StartInt.urem(StepInt);
3488 if (StartRem != 0) {
3489 const SCEV *NewLHS =
3490 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3491 AR->getLoop(), SCEV::FlagNW);
3492 if (LHS != NewLHS) {
3493 LHS = NewLHS;
3494
3495 // Reset the ID to include the new LHS, and check if it is
3496 // already cached.
3497 ID.clear();
3498 ID.AddInteger(scUDivExpr);
3499 ID.AddPointer(LHS);
3500 ID.AddPointer(RHS);
3501 IP = nullptr;
3502 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3503 return S;
3504 }
3505 }
3506 }
3507 }
3508 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3509 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3510 SmallVector<const SCEV *, 4> Operands;
3511 for (const SCEV *Op : M->operands())
3512 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3513 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3514 // Find an operand that's safely divisible.
3515 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3516 const SCEV *Op = M->getOperand(i);
3517 const SCEV *Div = getUDivExpr(Op, RHSC);
3518 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3519 Operands = SmallVector<const SCEV *, 4>(M->operands());
3520 Operands[i] = Div;
3521 return getMulExpr(Operands);
3522 }
3523 }
3524 }
3525
3526 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3527 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3528 if (auto *DivisorConstant =
3529 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3530 bool Overflow = false;
3531 APInt NewRHS =
3532 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3533 if (Overflow) {
3534 return getConstant(RHSC->getType(), 0, false);
3535 }
3536 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3537 }
3538 }
3539
3540 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3541 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3542 SmallVector<const SCEV *, 4> Operands;
3543 for (const SCEV *Op : A->operands())
3544 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3545 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3546 Operands.clear();
3547 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3548 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3549 if (isa<SCEVUDivExpr>(Op) ||
3550 getMulExpr(Op, RHS) != A->getOperand(i))
3551 break;
3552 Operands.push_back(Op);
3553 }
3554 if (Operands.size() == A->getNumOperands())
3555 return getAddExpr(Operands);
3556 }
3557 }
3558
3559 // Fold if both operands are constant.
3560 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3561 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3562 }
3563 }
3564
3565 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3566 // changes). Make sure we get a new one.
3567 IP = nullptr;
3568 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3569 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3570 LHS, RHS);
3571 UniqueSCEVs.InsertNode(S, IP);
3572 registerUser(S, {LHS, RHS});
3573 return S;
3574}
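// A few of the folds above, sketched on concrete inputs (illustrative):
// {0,+,4}<L> /u 2 becomes {0,+,2}<L> when the zero-extended forms agree,
// (8 * %x) /u 4 becomes 2 * %x, and ((%x /u 2) /u 3) becomes %x /u 6.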
3575
3576APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3577 APInt A = C1->getAPInt().abs();
3578 APInt B = C2->getAPInt().abs();
3579 uint32_t ABW = A.getBitWidth();
3580 uint32_t BBW = B.getBitWidth();
3581
3582 if (ABW > BBW)
3583 B = B.zext(ABW);
3584 else if (ABW < BBW)
3585 A = A.zext(BBW);
3586
3587 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3588}
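// Note (illustrative): operands of different bit widths are first zero-extended
// to the wider width, so e.g. the gcd of an i8 6 and an i16 10 is the i16 value 2.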
3589
3590/// Get a canonical unsigned division expression, or something simpler if
3591/// possible. There is no representation for an exact udiv in SCEV IR, but we
3592/// can attempt to remove factors from the LHS and RHS. We can't do this when
3593/// it's not exact because the udiv may be clearing bits.
3594const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3595 const SCEV *RHS) {
3596 // TODO: we could try to find factors in all sorts of things, but for now we
3597 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3598 // end of this file for inspiration.
3599
3600 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3601 if (!Mul || !Mul->hasNoUnsignedWrap())
3602 return getUDivExpr(LHS, RHS);
3603
3604 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3605 // If the mulexpr multiplies by a constant, then that constant must be the
3606 // first element of the mulexpr.
3607 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3608 if (LHSCst == RHSCst) {
3609 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3610 return getMulExpr(Operands);
3611 }
3612
3613 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3614 // that there's a factor provided by one of the other terms. We need to
3615 // check.
3616 APInt Factor = gcd(LHSCst, RHSCst);
3617 if (!Factor.isIntN(1)) {
3618 LHSCst =
3619 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3620 RHSCst =
3621 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3622 SmallVector<const SCEV *, 2> Operands;
3623 Operands.push_back(LHSCst);
3624 append_range(Operands, Mul->operands().drop_front());
3625 LHS = getMulExpr(Operands);
3626 RHS = RHSCst;
3627 Mul = dyn_cast<SCEVMulExpr>(LHS);
3628 if (!Mul)
3629 return getUDivExactExpr(LHS, RHS);
3630 }
3631 }
3632 }
3633
3634 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3635 if (Mul->getOperand(i) == RHS) {
3636 SmallVector<const SCEV *, 2> Operands;
3637 append_range(Operands, Mul->operands().take_front(i));
3638 append_range(Operands, Mul->operands().drop_front(i + 1));
3639 return getMulExpr(Operands);
3640 }
3641 }
3642
3643 return getUDivExpr(LHS, RHS);
3644}
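// For instance (illustrative): with LHS = (4 * %x)<nuw> and RHS = 4, the two
// uniqued constants compare equal and the result is %x; with LHS = (6 * %x)<nuw>
// and RHS = 4, the shared factor 2 is divided out first, giving (3 * %x) /u 2.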
3645
3646/// Get an add recurrence expression for the specified loop. Simplify the
3647/// expression as much as possible.
3648const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3649 const Loop *L,
3650 SCEV::NoWrapFlags Flags) {
3651 SmallVector<const SCEV *, 4> Operands;
3652 Operands.push_back(Start);
3653 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3654 if (StepChrec->getLoop() == L) {
3655 append_range(Operands, StepChrec->operands());
3656 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3657 }
3658
3659 Operands.push_back(Step);
3660 return getAddRecExpr(Operands, L, Flags);
3661}
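// For example (illustrative): getAddRecExpr(%start, {1,+,2}<L>, L, Flags)
// flattens the nested step into {%start,+,1,+,2}<L>, keeping at most the
// no-self-wrap bit of Flags.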
3662
3663/// Get an add recurrence expression for the specified loop. Simplify the
3664/// expression as much as possible.
3665const SCEV *
3666ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3667 const Loop *L, SCEV::NoWrapFlags Flags) {
3668 if (Operands.size() == 1) return Operands[0];
3669#ifndef NDEBUG
3670 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3671 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3672 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3673 "SCEVAddRecExpr operand types don't match!");
3674 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3675 }
3676 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3677 assert(isAvailableAtLoopEntry(Operands[i], L) &&
3678 "SCEVAddRecExpr operand is not available at loop entry!");
3679#endif
3680
3681 if (Operands.back()->isZero()) {
3682 Operands.pop_back();
3683 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3684 }
3685
3686 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3687 // use that information to infer NUW and NSW flags. However, computing a
3688 // BE count requires calling getAddRecExpr, so we may not yet have a
3689 // meaningful BE count at this point (and if we don't, we'd be stuck
3690 // with a SCEVCouldNotCompute as the cached BE count).
3691
3692 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3693
3694 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3695 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3696 const Loop *NestedLoop = NestedAR->getLoop();
3697 if (L->contains(NestedLoop)
3698 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3699 : (!NestedLoop->contains(L) &&
3700 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3701 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3702 Operands[0] = NestedAR->getStart();
3703 // AddRecs require their operands be loop-invariant with respect to their
3704 // loops. Don't perform this transformation if it would break this
3705 // requirement.
3706 bool AllInvariant = all_of(
3707 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3708
3709 if (AllInvariant) {
3710 // Create a recurrence for the outer loop with the same step size.
3711 //
3712 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3713 // inner recurrence has the same property.
3714 SCEV::NoWrapFlags OuterFlags =
3715 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3716
3717 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3718 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3719 return isLoopInvariant(Op, NestedLoop);
3720 });
3721
3722 if (AllInvariant) {
3723 // Ok, both add recurrences are valid after the transformation.
3724 //
3725 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3726 // the outer recurrence has the same property.
3727 SCEV::NoWrapFlags InnerFlags =
3728 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3729 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3730 }
3731 }
3732 // Reset Operands to its original state.
3733 Operands[0] = NestedAR;
3734 }
3735 }
3736
3737 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3738 // already have one, otherwise create a new one.
3739 return getOrCreateAddRecExpr(Operands, L, Flags);
3740}
3741
3742const SCEV *
3743ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3744 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3745 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3746 // getSCEV(Base)->getType() has the same address space as Base->getType()
3747 // because SCEV::getType() preserves the address space.
3748 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3749 const bool AssumeInBoundsFlags = [&]() {
3750 if (!GEP->isInBounds())
3751 return false;
3752
3753 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3754 // but to do that, we have to ensure that said flag is valid in the entire
3755 // defined scope of the SCEV.
3756 auto *GEPI = dyn_cast<Instruction>(GEP);
3757 // TODO: non-instructions have global scope. We might be able to prove
3758 // some global scope cases
3759 return GEPI && isSCEVExprNeverPoison(GEPI);
3760 }();
3761
3762 SCEV::NoWrapFlags OffsetWrap =
3763 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3764
3765 Type *CurTy = GEP->getType();
3766 bool FirstIter = true;
3767 SmallVector<const SCEV *, 4> Offsets;
3768 for (const SCEV *IndexExpr : IndexExprs) {
3769 // Compute the (potentially symbolic) offset in bytes for this index.
3770 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3771 // For a struct, add the member offset.
3772 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3773 unsigned FieldNo = Index->getZExtValue();
3774 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3775 Offsets.push_back(FieldOffset);
3776
3777 // Update CurTy to the type of the field at Index.
3778 CurTy = STy->getTypeAtIndex(Index);
3779 } else {
3780 // Update CurTy to its element type.
3781 if (FirstIter) {
3782 assert(isa<PointerType>(CurTy) &&
3783 "The first index of a GEP indexes a pointer");
3784 CurTy = GEP->getSourceElementType();
3785 FirstIter = false;
3786 } else {
3787 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3788 }
3789 // For an array, add the element offset, explicitly scaled.
3790 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3791 // Getelementptr indices are signed.
3792 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3793
3794 // Multiply the index by the element size to compute the element offset.
3795 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3796 Offsets.push_back(LocalOffset);
3797 }
3798 }
3799
3800 // Handle degenerate case of GEP without offsets.
3801 if (Offsets.empty())
3802 return BaseExpr;
3803
3804 // Add the offsets together, assuming nsw if inbounds.
3805 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3806 // Add the base address and the offset. We cannot use the nsw flag, as the
3807 // base address is unsigned. However, if we know that the offset is
3808 // non-negative, we can use nuw.
3809 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3810 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3811 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3812 assert(BaseExpr->getType() == GEPExpr->getType() &&
3813 "GEP should not change type mid-flight.");
3814 return GEPExpr;
3815}
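// As a sketch (assuming a 64-bit index type and a struct of two i32 fields):
// a GEP with indices (%i, 1) into %base contributes offsets 8 * %i and 4, so the
// result is roughly (%base + 8 * %i + 4), with nuw on the final add only if the
// GEP is inbounds and the total offset is known non-negative.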
3816
3817SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3818 ArrayRef<const SCEV *> Ops) {
3819 FoldingSetNodeID ID;
3820 ID.AddInteger(SCEVType);
3821 for (const SCEV *Op : Ops)
3822 ID.AddPointer(Op);
3823 void *IP = nullptr;
3824 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3825}
3826
3827const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3828 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3829 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3830}
3831
3832const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3833 SmallVectorImpl<const SCEV *> &Ops) {
3834 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3835 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3836 if (Ops.size() == 1) return Ops[0];
3837#ifndef NDEBUG
3838 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3839 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3840 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3841 "Operand types don't match!");
3842 assert(Ops[0]->getType()->isPointerTy() ==
3843 Ops[i]->getType()->isPointerTy() &&
3844 "min/max should be consistently pointerish");
3845 }
3846#endif
3847
3848 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3849 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3850
3851 // Sort by complexity, this groups all similar expression types together.
3852 GroupByComplexity(Ops, &LI, DT);
3853
3854 // Check if we have created the same expression before.
3855 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3856 return S;
3857 }
3858
3859 // If there are any constants, fold them together.
3860 unsigned Idx = 0;
3861 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3862 ++Idx;
3863 assert(Idx < Ops.size());
3864 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3865 switch (Kind) {
3866 case scSMaxExpr:
3867 return APIntOps::smax(LHS, RHS);
3868 case scSMinExpr:
3869 return APIntOps::smin(LHS, RHS);
3870 case scUMaxExpr:
3871 return APIntOps::umax(LHS, RHS);
3872 case scUMinExpr:
3873 return APIntOps::umin(LHS, RHS);
3874 default:
3875 llvm_unreachable("Unknown SCEV min/max opcode");
3876 }
3877 };
3878
3879 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3880 // We found two constants, fold them together!
3881 ConstantInt *Fold = ConstantInt::get(
3882 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3883 Ops[0] = getConstant(Fold);
3884 Ops.erase(Ops.begin()+1); // Erase the folded element
3885 if (Ops.size() == 1) return Ops[0];
3886 LHSC = cast<SCEVConstant>(Ops[0]);
3887 }
3888
3889 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3890 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3891
3892 if (IsMax ? IsMinV : IsMaxV) {
3893 // If we are left with a constant minimum(/maximum)-int, strip it off.
3894 Ops.erase(Ops.begin());
3895 --Idx;
3896 } else if (IsMax ? IsMaxV : IsMinV) {
3897 // If we have a max(/min) with a constant maximum(/minimum)-int,
3898 // it will always be the extremum.
3899 return LHSC;
3900 }
3901
3902 if (Ops.size() == 1) return Ops[0];
3903 }
3904
3905 // Find the first operation of the same kind
3906 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3907 ++Idx;
3908
3909 // Check to see if one of the operands is of the same kind. If so, expand its
3910 // operands onto our operand list, and recurse to simplify.
3911 if (Idx < Ops.size()) {
3912 bool DeletedAny = false;
3913 while (Ops[Idx]->getSCEVType() == Kind) {
3914 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3915 Ops.erase(Ops.begin()+Idx);
3916 append_range(Ops, SMME->operands());
3917 DeletedAny = true;
3918 }
3919
3920 if (DeletedAny)
3921 return getMinMaxExpr(Kind, Ops);
3922 }
3923
3924 // Okay, check to see if the same value occurs in the operand list twice. If
3925 // so, delete one. Since we sorted the list, these values are required to
3926 // be adjacent.
3927 llvm::CmpInst::Predicate GEPred =
3928 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3929 llvm::CmpInst::Predicate LEPred =
3930 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3931 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3932 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3933 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3934 if (Ops[i] == Ops[i + 1] ||
3935 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3936 // X op Y op Y --> X op Y
3937 // X op Y --> X, if we know X, Y are ordered appropriately
3938 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3939 --i;
3940 --e;
3941 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3942 Ops[i + 1])) {
3943 // X op Y --> Y, if we know X, Y are ordered appropriately
3944 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3945 --i;
3946 --e;
3947 }
3948 }
3949
3950 if (Ops.size() == 1) return Ops[0];
3951
3952 assert(!Ops.empty() && "Reduced smax down to nothing!");
3953
3954 // Okay, it looks like we really DO need an expr. Check to see if we
3955 // already have one, otherwise create a new one.
3956 FoldingSetNodeID ID;
3957 ID.AddInteger(Kind);
3958 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3959 ID.AddPointer(Ops[i]);
3960 void *IP = nullptr;
3961 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3962 if (ExistingSCEV)
3963 return ExistingSCEV;
3964 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3965 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3966 SCEV *S = new (SCEVAllocator)
3967 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3968
3969 UniqueSCEVs.InsertNode(S, IP);
3970 registerUser(S, Ops);
3971 return S;
3972}
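// For instance (illustrative): getSMaxExpr({smax(%a, %b), 3, 5}) first folds the
// constants to 5 and then inlines the nested smax, yielding an smax over %a, %b
// and 5; a umax with a constant 0 operand would instead drop the 0 entirely.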
3973
3974namespace {
3975
3976class SCEVSequentialMinMaxDeduplicatingVisitor final
3977 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3978 std::optional<const SCEV *>> {
3979 using RetVal = std::optional<const SCEV *>;
3980 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3981
3982 ScalarEvolution &SE;
3983 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3984 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3985 SmallPtrSet<const SCEV *, 16> SeenOps;
3986
3987 bool canRecurseInto(SCEVTypes Kind) const {
3988 // We can only recurse into the SCEV expression of the same effective type
3989 // as the type of our root SCEV expression.
3990 return RootKind == Kind || NonSequentialRootKind == Kind;
3991 };
3992
3993 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3994 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3995 "Only for min/max expressions.");
3996 SCEVTypes Kind = S->getSCEVType();
3997
3998 if (!canRecurseInto(Kind))
3999 return S;
4000
4001 auto *NAry = cast<SCEVNAryExpr>(S);
4002 SmallVector<const SCEV *> NewOps;
4003 bool Changed = visit(Kind, NAry->operands(), NewOps);
4004
4005 if (!Changed)
4006 return S;
4007 if (NewOps.empty())
4008 return std::nullopt;
4009
4010 return isa<SCEVSequentialMinMaxExpr>(S)
4011 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4012 : SE.getMinMaxExpr(Kind, NewOps);
4013 }
4014
4015 RetVal visit(const SCEV *S) {
4016 // Has the whole operand been seen already?
4017 if (!SeenOps.insert(S).second)
4018 return std::nullopt;
4019 return Base::visit(S);
4020 }
4021
4022public:
4023 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4024 SCEVTypes RootKind)
4025 : SE(SE), RootKind(RootKind),
4026 NonSequentialRootKind(
4027 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4028 RootKind)) {}
4029
4030 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4031 SmallVectorImpl<const SCEV *> &NewOps) {
4032 bool Changed = false;
4033 SmallVector<const SCEV *> Ops;
4034 Ops.reserve(OrigOps.size());
4035
4036 for (const SCEV *Op : OrigOps) {
4037 RetVal NewOp = visit(Op);
4038 if (NewOp != Op)
4039 Changed = true;
4040 if (NewOp)
4041 Ops.emplace_back(*NewOp);
4042 }
4043
4044 if (Changed)
4045 NewOps = std::move(Ops);
4046 return Changed;
4047 }
4048
4049 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4050
4051 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4052
4053 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4054
4055 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4056
4057 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4058
4059 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4060
4061 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4062
4063 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4064
4065 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4066
4067 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4068
4069 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4070 return visitAnyMinMaxExpr(Expr);
4071 }
4072
4073 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4074 return visitAnyMinMaxExpr(Expr);
4075 }
4076
4077 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4078 return visitAnyMinMaxExpr(Expr);
4079 }
4080
4081 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4082 return visitAnyMinMaxExpr(Expr);
4083 }
4084
4085 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4086 return visitAnyMinMaxExpr(Expr);
4087 }
4088
4089 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4090
4091 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4092};
4093
4094} // namespace
4095
4096static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4097 switch (Kind) {
4098 case scConstant:
4099 case scVScale:
4100 case scTruncate:
4101 case scZeroExtend:
4102 case scSignExtend:
4103 case scPtrToInt:
4104 case scAddExpr:
4105 case scMulExpr:
4106 case scUDivExpr:
4107 case scAddRecExpr:
4108 case scUMaxExpr:
4109 case scSMaxExpr:
4110 case scUMinExpr:
4111 case scSMinExpr:
4112 case scUnknown:
4113 // If any operand is poison, the whole expression is poison.
4114 return true;
4115 case scSequentialUMinExpr:
4116 // FIXME: if the *first* operand is poison, the whole expression is poison.
4117 return false; // Pessimistically, say that it does not propagate poison.
4118 case scCouldNotCompute:
4119 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4120 }
4121 llvm_unreachable("Unknown SCEV kind!");
4122}
4123
4124namespace {
4125// The only way poison may be introduced in a SCEV expression is from a
4126// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4127// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4128// introduce poison -- they encode guaranteed, non-speculated knowledge.
4129//
4130// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4131// with the notable exception of umin_seq, where only poison from the first
4132// operand is (unconditionally) propagated.
4133struct SCEVPoisonCollector {
4134 bool LookThroughMaybePoisonBlocking;
4136 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4137 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4138
4139 bool follow(const SCEV *S) {
4140 if (!LookThroughMaybePoisonBlocking &&
4141 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4142 return false;
4143
4144 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4145 if (!isGuaranteedNotToBePoison(SU->getValue()))
4146 MaybePoison.insert(SU);
4147 }
4148 return true;
4149 }
4150 bool isDone() const { return false; }
4151};
4152} // namespace
4153
4154/// Return true if V is poison given that AssumedPoison is already poison.
4155static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4156 // First collect all SCEVs that might result in AssumedPoison to be poison.
4157 // We need to look through potentially poison-blocking operations here,
4158 // because we want to find all SCEVs that *might* result in poison, not only
4159 // those that are *required* to.
4160 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4161 visitAll(AssumedPoison, PC1);
4162
4163 // AssumedPoison is never poison. As the assumption is false, the implication
4164 // is true. Don't bother walking the other SCEV in this case.
4165 if (PC1.MaybePoison.empty())
4166 return true;
4167
4168 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4169 // as well. We cannot look through potentially poison-blocking operations
4170 // here, as their arguments only *may* make the result poison.
4171 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4172 visitAll(S, PC2);
4173
4174 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4175 // it will also make S poison by being part of PC2.MaybePoison.
4176 return all_of(PC1.MaybePoison, [&](const SCEVUnknown *S) {
4177 return PC2.MaybePoison.contains(S);
4178 });
4179}
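// For example (illustrative, for otherwise-unknown %x and %y): impliesPoison(
// zext(%x), %x + %y) is true, since any poison reaching zext(%x) must come from
// %x, which also poisons %x + %y; impliesPoison(%x umin_seq %y, %y) is not
// provable this way, because the sequential umin may block poison from %y.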
4180
4181void ScalarEvolution::getPoisonGeneratingValues(
4182 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4183 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4184 visitAll(S, PC);
4185 for (const SCEVUnknown *SU : PC.MaybePoison)
4186 Result.insert(SU->getValue());
4187}
4188
4189bool ScalarEvolution::canReuseInstruction(
4190 const SCEV *S, Instruction *I,
4191 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4192 // If the instruction cannot be poison, it's always safe to reuse.
4193 if (isGuaranteedNotToBePoison(I))
4194 return true;
4195
4196 // Otherwise, it is possible that I is more poisonous than S. Collect the
4197 // poison-contributors of S, and then check whether I has any additional
4198 // poison-contributors. Poison that is contributed through poison-generating
4199 // flags is handled by dropping those flags instead.
4200 SmallPtrSet<const Value *, 8> PoisonVals;
4201 getPoisonGeneratingValues(PoisonVals, S);
4202
4203 SmallVector<Value *> Worklist;
4204 SmallPtrSet<Value *, 8> Visited;
4205 Worklist.push_back(I);
4206 while (!Worklist.empty()) {
4207 Value *V = Worklist.pop_back_val();
4208 if (!Visited.insert(V).second)
4209 continue;
4210
4211 // Avoid walking large instruction graphs.
4212 if (Visited.size() > 16)
4213 return false;
4214
4215 // Either the value can't be poison, or the S would also be poison if it
4216 // is.
4217 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4218 continue;
4219
4220 auto *I = dyn_cast<Instruction>(V);
4221 if (!I)
4222 return false;
4223
4224 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4225 // can't replace an arbitrary add with disjoint or, even if we drop the
4226 // flag. We would need to convert the or into an add.
4227 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4228 if (PDI->isDisjoint())
4229 return false;
4230
4231 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4232 // because SCEV currently assumes it can't be poison. Remove this special
4233 // case once we properly model when vscale can be poison.
4234 if (auto *II = dyn_cast<IntrinsicInst>(I);
4235 II && II->getIntrinsicID() == Intrinsic::vscale)
4236 continue;
4237
4238 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4239 return false;
4240
4241 // If the instruction can't create poison, we can recurse to its operands.
4242 if (I->hasPoisonGeneratingFlagsOrMetadata())
4243 DropPoisonGeneratingInsts.push_back(I);
4244
4245 for (Value *Op : I->operands())
4246 Worklist.push_back(Op);
4247 }
4248 return true;
4249}
4250
4251const SCEV *
4252ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4253 SmallVectorImpl<const SCEV *> &Ops) {
4254 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4255 "Not a SCEVSequentialMinMaxExpr!");
4256 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4257 if (Ops.size() == 1)
4258 return Ops[0];
4259#ifndef NDEBUG
4260 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4261 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4262 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4263 "Operand types don't match!");
4264 assert(Ops[0]->getType()->isPointerTy() ==
4265 Ops[i]->getType()->isPointerTy() &&
4266 "min/max should be consistently pointerish");
4267 }
4268#endif
4269
4270 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4271 // so we can *NOT* do any kind of sorting of the expressions!
4272
4273 // Check if we have created the same expression before.
4274 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4275 return S;
4276
4277 // FIXME: there are *some* simplifications that we can do here.
4278
4279 // Keep only the first instance of an operand.
4280 {
4281 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4282 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4283 if (Changed)
4284 return getSequentialMinMaxExpr(Kind, Ops);
4285 }
4286
4287 // Check to see if one of the operands is of the same kind. If so, expand its
4288 // operands onto our operand list, and recurse to simplify.
4289 {
4290 unsigned Idx = 0;
4291 bool DeletedAny = false;
4292 while (Idx < Ops.size()) {
4293 if (Ops[Idx]->getSCEVType() != Kind) {
4294 ++Idx;
4295 continue;
4296 }
4297 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4298 Ops.erase(Ops.begin() + Idx);
4299 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4300 SMME->operands().end());
4301 DeletedAny = true;
4302 }
4303
4304 if (DeletedAny)
4305 return getSequentialMinMaxExpr(Kind, Ops);
4306 }
4307
4308 const SCEV *SaturationPoint;
4309 ICmpInst::Predicate Pred;
4310 switch (Kind) {
4311 case scSequentialUMinExpr:
4312 SaturationPoint = getZero(Ops[0]->getType());
4313 Pred = ICmpInst::ICMP_ULE;
4314 break;
4315 default:
4316 llvm_unreachable("Not a sequential min/max type.");
4317 }
4318
4319 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4320 // We can replace %x umin_seq %y with %x umin %y if either:
4321 // * %y being poison implies %x is also poison.
4322 // * %x cannot be the saturating value (e.g. zero for umin).
4323 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4324 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4325 SaturationPoint)) {
4326 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4327 Ops[i - 1] = getMinMaxExpr(
4328 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4329 SeqOps);
4330 Ops.erase(Ops.begin() + i);
4331 return getSequentialMinMaxExpr(Kind, Ops);
4332 }
4333 // Fold %x umin_seq %y to %x if %x ule %y.
4334 // TODO: We might be able to prove the predicate for a later operand.
4335 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4336 Ops.erase(Ops.begin() + i);
4337 return getSequentialMinMaxExpr(Kind, Ops);
4338 }
4339 }
4340
4341 // Okay, it looks like we really DO need an expr. Check to see if we
4342 // already have one, otherwise create a new one.
4343 FoldingSetNodeID ID;
4344 ID.AddInteger(Kind);
4345 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4346 ID.AddPointer(Ops[i]);
4347 void *IP = nullptr;
4348 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4349 if (ExistingSCEV)
4350 return ExistingSCEV;
4351
4352 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4353 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4354 SCEV *S = new (SCEVAllocator)
4355 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4356
4357 UniqueSCEVs.InsertNode(S, IP);
4358 registerUser(S, Ops);
4359 return S;
4360}
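// For instance (illustrative): %x umin_seq %x deduplicates to %x, and
// %x umin_seq %y becomes the plain %x umin %y when %x is known non-zero,
// because %x can then never be the saturating value of the sequential umin.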
4361
4362const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4363 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4364 return getSMaxExpr(Ops);
4365}
4366
4367const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4368 return getMinMaxExpr(scSMaxExpr, Ops);
4369}
4370
4371const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4373 return getUMaxExpr(Ops);
4374}
4375
4377 return getMinMaxExpr(scUMaxExpr, Ops);
4378}
4379
4381 const SCEV *RHS) {
4383 return getSMinExpr(Ops);
4384}
4385
4387 return getMinMaxExpr(scSMinExpr, Ops);
4388}
4389
4390const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4391 bool Sequential) {
4392 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4393 return getUMinExpr(Ops, Sequential);
4394}
4395
4396const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4397 bool Sequential) {
4398 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4399 : getMinMaxExpr(scUMinExpr, Ops);
4400}
4401
4402const SCEV *
4403ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4404 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4405 if (Size.isScalable())
4406 Res = getMulExpr(Res, getVScale(IntTy));
4407 return Res;
4408}
4409
4411 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4412}
4413
4415 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4416}
4417
4419 StructType *STy,
4420 unsigned FieldNo) {
4421 // We can bypass creating a target-independent constant expression and then
4422 // folding it back into a ConstantInt. This is just a compile-time
4423 // optimization.
4424 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4425 assert(!SL->getSizeInBits().isScalable() &&
4426 "Cannot get offset for structure containing scalable vector types");
4427 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4428}
4429
4431 // Don't attempt to do anything other than create a SCEVUnknown object
4432 // here. createSCEV only calls getUnknown after checking for all other
4433 // interesting possibilities, and any other code that calls getUnknown
4434 // is doing so in order to hide a value from SCEV canonicalization.
4435
4437 ID.AddInteger(scUnknown);
4438 ID.AddPointer(V);
4439 void *IP = nullptr;
4440 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4441 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4442 "Stale SCEVUnknown in uniquing map!");
4443 return S;
4444 }
4445 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4446 FirstUnknown);
4447 FirstUnknown = cast<SCEVUnknown>(S);
4448 UniqueSCEVs.InsertNode(S, IP);
4449 return S;
4450}
4451
4452//===----------------------------------------------------------------------===//
4453// Basic SCEV Analysis and PHI Idiom Recognition Code
4454//
4455
4456/// Test if values of the given type are analyzable within the SCEV
4457/// framework. This primarily includes integer types, and it can optionally
4458/// include pointer types if the ScalarEvolution class has access to
4459/// target-specific information.
4461 // Integers and pointers are always SCEVable.
4462 return Ty->isIntOrPtrTy();
4463}
4464
4465/// Return the size in bits of the specified type, for which isSCEVable must
4466/// return true.
4467uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4468 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4469 if (Ty->isPointerTy())
4470 return getDataLayout().getIndexTypeSizeInBits(Ty);
4471 return getDataLayout().getTypeSizeInBits(Ty);
4472}
4473
4474/// Return a type with the same bitwidth as the given type and which represents
4475/// how SCEV will treat the given type, for which isSCEVable must return
4476/// true. For pointer types, this is the pointer index sized integer type.
4478 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4479
4480 if (Ty->isIntegerTy())
4481 return Ty;
4482
4483 // The only other supported type is pointer.
4484 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4485 return getDataLayout().getIndexType(Ty);
4486}
4487
4489 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4490}
4491
4493 const SCEV *B) {
4494 /// For a valid use point to exist, the defining scope of one operand
4495 /// must dominate the other.
4496 bool PreciseA, PreciseB;
4497 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4498 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4499 if (!PreciseA || !PreciseB)
4500 // Can't tell.
4501 return false;
4502 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4503 DT.dominates(ScopeB, ScopeA);
4504}
4505
4507 return CouldNotCompute.get();
4508}
4509
4510bool ScalarEvolution::checkValidity(const SCEV *S) const {
4511 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4512 auto *SU = dyn_cast<SCEVUnknown>(S);
4513 return SU && SU->getValue() == nullptr;
4514 });
4515
4516 return !ContainsNulls;
4517}
4518
4520 HasRecMapType::iterator I = HasRecMap.find(S);
4521 if (I != HasRecMap.end())
4522 return I->second;
4523
4524 bool FoundAddRec =
4525 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4526 HasRecMap.insert({S, FoundAddRec});
4527 return FoundAddRec;
4528}
4529
4530/// Return the ValueOffsetPair set for \p S. \p S can be represented
4531/// by the value and offset from any ValueOffsetPair in the set.
4532ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4533 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4534 if (SI == ExprValueMap.end())
4535 return std::nullopt;
4536 return SI->second.getArrayRef();
4537}
4538
4539/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4540/// cannot be used separately. eraseValueFromMap should be used to remove
4541/// V from ValueExprMap and ExprValueMap at the same time.
4542void ScalarEvolution::eraseValueFromMap(Value *V) {
4543 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4544 if (I != ValueExprMap.end()) {
4545 auto EVIt = ExprValueMap.find(I->second);
4546 bool Removed = EVIt->second.remove(V);
4547 (void) Removed;
4548 assert(Removed && "Value not in ExprValueMap?");
4549 ValueExprMap.erase(I);
4550 }
4551}
4552
4553void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4554 // A recursive query may have already computed the SCEV. It should be
4555 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4556 // inferred nowrap flags.
4557 auto It = ValueExprMap.find_as(V);
4558 if (It == ValueExprMap.end()) {
4559 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4560 ExprValueMap[S].insert(V);
4561 }
4562}
4563
4564/// Return an existing SCEV if it exists, otherwise analyze the expression and
4565/// create a new one.
4566const SCEV *ScalarEvolution::getSCEV(Value *V) {
4567 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4568
4569 if (const SCEV *S = getExistingSCEV(V))
4570 return S;
4571 return createSCEVIter(V);
4572}
4573
4575 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4576
4577 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4578 if (I != ValueExprMap.end()) {
4579 const SCEV *S = I->second;
4580 assert(checkValidity(S) &&
4581 "existing SCEV has not been properly invalidated");
4582 return S;
4583 }
4584 return nullptr;
4585}
4586
4587/// Return a SCEV corresponding to -V = -1*V
4588const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4589 SCEV::NoWrapFlags Flags) {
4590 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4591 return getConstant(
4592 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4593
4594 Type *Ty = V->getType();
4595 Ty = getEffectiveSCEVType(Ty);
4596 return getMulExpr(V, getMinusOne(Ty), Flags);
4597}
4598
4599/// If Expr computes ~A, return A else return nullptr
4600static const SCEV *MatchNotExpr(const SCEV *Expr) {
4601 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4602 if (!Add || Add->getNumOperands() != 2 ||
4603 !Add->getOperand(0)->isAllOnesValue())
4604 return nullptr;
4605
4606 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4607 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4608 !AddRHS->getOperand(0)->isAllOnesValue())
4609 return nullptr;
4610
4611 return AddRHS->getOperand(1);
4612}
4613
4614/// Return a SCEV corresponding to ~V = -1-V
4615const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4616 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4617
4618 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4619 return getConstant(
4620 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4621
4622 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4623 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4624 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4625 SmallVector<const SCEV *, 2> MatchedOperands;
4626 for (const SCEV *Operand : MME->operands()) {
4627 const SCEV *Matched = MatchNotExpr(Operand);
4628 if (!Matched)
4629 return (const SCEV *)nullptr;
4630 MatchedOperands.push_back(Matched);
4631 }
4632 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4633 MatchedOperands);
4634 };
4635 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4636 return Replaced;
4637 }
4638
4639 Type *Ty = V->getType();
4640 Ty = getEffectiveSCEVType(Ty);
4641 return getMinusSCEV(getMinusOne(Ty), V);
4642}
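// For example (illustrative): getNotSCEV(%x) is built as (-1 + (-1 * %x)), and
// getNotSCEV(umin(~%a, ~%b)) folds to umax(%a, %b) via the pattern match above.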
4643
4644const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4645 assert(P->getType()->isPointerTy());
4646
4647 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4648 // The base of an AddRec is the first operand.
4649 SmallVector<const SCEV *> Ops{AddRec->operands()};
4650 Ops[0] = removePointerBase(Ops[0]);
4651 // Don't try to transfer nowrap flags for now. We could in some cases
4652 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4653 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4654 }
4655 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4656 // The base of an Add is the pointer operand.
4657 SmallVector<const SCEV *> Ops{Add->operands()};
4658 const SCEV **PtrOp = nullptr;
4659 for (const SCEV *&AddOp : Ops) {
4660 if (AddOp->getType()->isPointerTy()) {
4661 assert(!PtrOp && "Cannot have multiple pointer ops");
4662 PtrOp = &AddOp;
4663 }
4664 }
4665 *PtrOp = removePointerBase(*PtrOp);
4666 // Don't try to transfer nowrap flags for now. We could in some cases
4667 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4668 return getAddExpr(Ops);
4669 }
4670 // Any other expression must be a pointer base.
4671 return getZero(P->getType());
4672}
4673
4674const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4675 SCEV::NoWrapFlags Flags,
4676 unsigned Depth) {
4677 // Fast path: X - X --> 0.
4678 if (LHS == RHS)
4679 return getZero(LHS->getType());
4680
4681 // If we subtract two pointers with different pointer bases, bail.
4682 // Eventually, we're going to add an assertion to getMulExpr that we
4683 // can't multiply by a pointer.
4684 if (RHS->getType()->isPointerTy()) {
4685 if (!LHS->getType()->isPointerTy() ||
4686 getPointerBase(LHS) != getPointerBase(RHS))
4687 return getCouldNotCompute();
4688 LHS = removePointerBase(LHS);
4689 RHS = removePointerBase(RHS);
4690 }
4691
4692 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4693 // makes it so that we cannot make much use of NUW.
4694 auto AddFlags = SCEV::FlagAnyWrap;
4695 const bool RHSIsNotMinSigned =
4696 !getSignedRangeMin(RHS).isMinSignedValue();
4697 if (hasFlags(Flags, SCEV::FlagNSW)) {
4698 // Let M be the minimum representable signed value. Then (-1)*RHS
4699 // signed-wraps if and only if RHS is M. That can happen even for
4700 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4701 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4702 // (-1)*RHS, we need to prove that RHS != M.
4703 //
4704 // If LHS is non-negative and we know that LHS - RHS does not
4705 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4706 // either by proving that RHS > M or that LHS >= 0.
4707 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4708 AddFlags = SCEV::FlagNSW;
4709 }
4710 }
4711
4712 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4713 // RHS is NSW and LHS >= 0.
4714 //
4715 // The difficulty here is that the NSW flag may have been proven
4716 // relative to a loop that is to be found in a recurrence in LHS and
4717 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4718 // larger scope than intended.
4719 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4720
4721 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4722}
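// For instance (illustrative): getMinusSCEV(%a, %b) produces (%a + (-1 * %b));
// nsw is carried over only when %b provably differs from the minimum signed
// value (or %a is known non-negative), since negating INT_MIN itself wraps.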
4723
4724const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4725 unsigned Depth) {
4726 Type *SrcTy = V->getType();
4727 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4728 "Cannot truncate or zero extend with non-integer arguments!");
4729 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4730 return V; // No conversion
4731 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4732 return getTruncateExpr(V, Ty, Depth);
4733 return getZeroExtendExpr(V, Ty, Depth);
4734}
4735
4736const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4737 unsigned Depth) {
4738 Type *SrcTy = V->getType();
4739 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4740 "Cannot truncate or zero extend with non-integer arguments!");
4741 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4742 return V; // No conversion
4743 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4744 return getTruncateExpr(V, Ty, Depth);
4745 return getSignExtendExpr(V, Ty, Depth);
4746}
4747
4748const SCEV *
4750 Type *SrcTy = V->getType();
4751 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4752 "Cannot noop or zero extend with non-integer arguments!");
4754 "getNoopOrZeroExtend cannot truncate!");
4755 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4756 return V; // No conversion
4757 return getZeroExtendExpr(V, Ty);
4758}
4759
4760const SCEV *
4761ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4762 Type *SrcTy = V->getType();
4763 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4764 "Cannot noop or sign extend with non-integer arguments!");
4766 "getNoopOrSignExtend cannot truncate!");
4767 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4768 return V; // No conversion
4769 return getSignExtendExpr(V, Ty);
4770}
4771
4772const SCEV *
4773ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4774 Type *SrcTy = V->getType();
4775 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4776 "Cannot noop or any extend with non-integer arguments!");
4778 "getNoopOrAnyExtend cannot truncate!");
4779 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4780 return V; // No conversion
4781 return getAnyExtendExpr(V, Ty);
4782}
4783
4784const SCEV *
4785ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4786 Type *SrcTy = V->getType();
4787 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4788 "Cannot truncate or noop with non-integer arguments!");
4790 "getTruncateOrNoop cannot extend!");
4791 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4792 return V; // No conversion
4793 return getTruncateExpr(V, Ty);
4794}
4795
4796const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4797 const SCEV *RHS) {
4798 const SCEV *PromotedLHS = LHS;
4799 const SCEV *PromotedRHS = RHS;
4800
4801 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4802 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4803 else
4804 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4805
4806 return getUMaxExpr(PromotedLHS, PromotedRHS);
4807}
4808
4809const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4810 const SCEV *RHS,
4811 bool Sequential) {
4812 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4813 return getUMinFromMismatchedTypes(Ops, Sequential);
4814}
4815
4816const SCEV *
4817ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4818 bool Sequential) {
4819 assert(!Ops.empty() && "At least one operand must be!");
4820 // Trivial case.
4821 if (Ops.size() == 1)
4822 return Ops[0];
4823
4824 // Find the max type first.
4825 Type *MaxType = nullptr;
4826 for (const auto *S : Ops)
4827 if (MaxType)
4828 MaxType = getWiderType(MaxType, S->getType());
4829 else
4830 MaxType = S->getType();
4831 assert(MaxType && "Failed to find maximum type!");
4832
4833 // Extend all ops to max type.
4834 SmallVector<const SCEV *, 2> PromotedOps;
4835 for (const auto *S : Ops)
4836 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4837
4838 // Generate umin.
4839 return getUMinExpr(PromotedOps, Sequential);
4840}
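// For illustration (hypothetical operands): given {i16 %a, i32 %b, i64 %c},
// the widest type is i64, so the result is
//   umin(zext i16 %a to i64, zext i32 %b to i64, %c)
// and the same widening is applied before building a sequential umin.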
4841
4842const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4843 // A pointer operand may evaluate to a nonpointer expression, such as null.
4844 if (!V->getType()->isPointerTy())
4845 return V;
4846
4847 while (true) {
4848 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4849 V = AddRec->getStart();
4850 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4851 const SCEV *PtrOp = nullptr;
4852 for (const SCEV *AddOp : Add->operands()) {
4853 if (AddOp->getType()->isPointerTy()) {
4854 assert(!PtrOp && "Cannot have multiple pointer ops");
4855 PtrOp = AddOp;
4856 }
4857 }
4858 assert(PtrOp && "Must have pointer op");
4859 V = PtrOp;
4860 } else // Not something we can look further into.
4861 return V;
4862 }
4863}
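// Example walk (hypothetical SCEV): for V = {(4 + %base),+,8}<%loop> the loop
// above first replaces the AddRec by its start (4 + %base) and then picks the
// unique pointer operand of the add, returning %base.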
4864
4865/// Push users of the given Instruction onto the given Worklist.
4866static void PushDefUseChildren(Instruction *I,
4867 SmallVectorImpl<Instruction *> &Worklist,
4868 SmallPtrSetImpl<Instruction *> &Visited) {
4869 // Push the def-use children onto the Worklist stack.
4870 for (User *U : I->users()) {
4871 auto *UserInsn = cast<Instruction>(U);
4872 if (Visited.insert(UserInsn).second)
4873 Worklist.push_back(UserInsn);
4874 }
4875}
4876
4877namespace {
4878
4879/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4880/// use its start expression. If the sub-expression belongs to a different
4881/// loop, use the AddRec itself when IgnoreOtherLoops is true; otherwise the
4882/// rewrite cannot be done.
4883/// If the SCEV contains a non-invariant SCEVUnknown, the rewrite cannot be done.
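/// Example: rewriting {%start,+,%step}<%L> with respect to loop %L yields
/// %start; an AddRec of a different loop is either kept as-is
/// (IgnoreOtherLoops == true) or makes the whole rewrite fail
/// (IgnoreOtherLoops == false).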
4884class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4885public:
4886 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4887 bool IgnoreOtherLoops = true) {
4888 SCEVInitRewriter Rewriter(L, SE);
4889 const SCEV *Result = Rewriter.visit(S);
4890 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4891 return SE.getCouldNotCompute();
4892 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4893 ? SE.getCouldNotCompute()
4894 : Result;
4895 }
4896
4897 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4898 if (!SE.isLoopInvariant(Expr, L))
4899 SeenLoopVariantSCEVUnknown = true;
4900 return Expr;
4901 }
4902
4903 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4904 // Only re-write AddRecExprs for this loop.
4905 if (Expr->getLoop() == L)
4906 return Expr->getStart();
4907 SeenOtherLoops = true;
4908 return Expr;
4909 }
4910
4911 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4912
4913 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4914
4915private:
4916 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4917 : SCEVRewriteVisitor(SE), L(L) {}
4918
4919 const Loop *L;
4920 bool SeenLoopVariantSCEVUnknown = false;
4921 bool SeenOtherLoops = false;
4922};
4923
4924/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4925/// use its post-increment expression. AddRecs belonging to other loops are
4926/// kept as they are.
4927/// If the SCEV contains a non-invariant SCEVUnknown, the rewrite cannot be done.
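/// Example: rewriting {%start,+,%step}<%L> with respect to loop %L yields the
/// post-increment value {%start + %step,+,%step}<%L>.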
4928class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4929public:
4930 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4931 SCEVPostIncRewriter Rewriter(L, SE);
4932 const SCEV *Result = Rewriter.visit(S);
4933 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4934 ? SE.getCouldNotCompute()
4935 : Result;
4936 }
4937
4938 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4939 if (!SE.isLoopInvariant(Expr, L))
4940 SeenLoopVariantSCEVUnknown = true;
4941 return Expr;
4942 }
4943
4944 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4945 // Only re-write AddRecExprs for this loop.
4946 if (Expr->getLoop() == L)
4947 return Expr->getPostIncExpr(SE);
4948 SeenOtherLoops = true;
4949 return Expr;
4950 }
4951
4952 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4953
4954 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4955
4956private:
4957 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4958 : SCEVRewriteVisitor(SE), L(L) {}
4959
4960 const Loop *L;
4961 bool SeenLoopVariantSCEVUnknown = false;
4962 bool SeenOtherLoops = false;
4963};
4964
4965/// This class evaluates the compare condition by matching it against the
4966/// condition of the loop latch. If there is a match, we assume a true value
4967/// for the condition while building SCEV nodes.
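/// Example (hypothetical IR): with a latch `br i1 %c, label %header, label
/// %exit` that takes the backedge on the true edge, a value
/// `%v = select i1 %c, i32 %a, i32 %b` in the loop body is folded to the SCEV
/// of %a, since %c must be true whenever the backedge is taken.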
4968class SCEVBackedgeConditionFolder
4969 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4970public:
4971 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4972 ScalarEvolution &SE) {
4973 bool IsPosBECond = false;
4974 Value *BECond = nullptr;
4975 if (BasicBlock *Latch = L->getLoopLatch()) {
4976 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4977 if (BI && BI->isConditional()) {
4978 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4979 "Both outgoing branches should not target same header!");
4980 BECond = BI->getCondition();
4981 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4982 } else {
4983 return S;
4984 }
4985 }
4986 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4987 return Rewriter.visit(S);
4988 }
4989
4990 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4991 const SCEV *Result = Expr;
4992 bool InvariantF = SE.isLoopInvariant(Expr, L);
4993
4994 if (!InvariantF) {
4995 Instruction *I = cast<Instruction>(Expr->getValue());
4996 switch (I->getOpcode()) {
4997 case Instruction::Select: {
4998 SelectInst *SI = cast<SelectInst>(I);
4999 std::optional<const SCEV *> Res =
5000 compareWithBackedgeCondition(SI->getCondition());
5001 if (Res) {
5002 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5003 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5004 }
5005 break;
5006 }
5007 default: {
5008 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5009 if (Res)
5010 Result = *Res;
5011 break;
5012 }
5013 }
5014 }
5015 return Result;
5016 }
5017
5018private:
5019 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5020 bool IsPosBECond, ScalarEvolution &SE)
5021 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5022 IsPositiveBECond(IsPosBECond) {}
5023
5024 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5025
5026 const Loop *L;
5027 /// Loop back condition.
5028 Value *BackedgeCond = nullptr;
5029 /// Set to true if loop back is on positive branch condition.
5030 bool IsPositiveBECond;
5031};
5032
5033std::optional<const SCEV *>
5034SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5035
5036 // If value matches the backedge condition for loop latch,
5037 // then return a constant evolution node based on loopback
5038 // branch taken.
5039 if (BackedgeCond == IC)
5040 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5041 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5042 return std::nullopt;
5043}
5044
5045class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5046public:
5047 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5048 ScalarEvolution &SE) {
5049 SCEVShiftRewriter Rewriter(L, SE);
5050 const SCEV *Result = Rewriter.visit(S);
5051 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5052 }
5053
5054 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5055 // Only allow AddRecExprs for this loop.
5056 if (!SE.isLoopInvariant(Expr, L))
5057 Valid = false;
5058 return Expr;
5059 }
5060
5061 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5062 if (Expr->getLoop() == L && Expr->isAffine())
5063 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5064 Valid = false;
5065 return Expr;
5066 }
5067
5068 bool isValid() { return Valid; }
5069
5070private:
5071 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5072 : SCEVRewriteVisitor(SE), L(L) {}
5073
5074 const Loop *L;
5075 bool Valid = true;
5076};
5077
5078} // end anonymous namespace
5079
5080SCEV::NoWrapFlags
5081ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5082 if (!AR->isAffine())
5083 return SCEV::FlagAnyWrap;
5084
5085 using OBO = OverflowingBinaryOperator;
5086
5088
5089 if (!AR->hasNoSelfWrap()) {
5090 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5091 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5092 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5093 const APInt &BECountAP = BECountMax->getAPInt();
5094 unsigned NoOverflowBitWidth =
5095 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5096 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5097 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5098 }
5099 }
5100
5101 if (!AR->hasNoSignedWrap()) {
5102 ConstantRange AddRecRange = getSignedRange(AR);
5103 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5104
5105 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5106 Instruction::Add, IncRange, OBO::NoSignedWrap);
5107 if (NSWRegion.contains(AddRecRange))
5108 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5109 }
5110
5111 if (!AR->hasNoUnsignedWrap()) {
5112 ConstantRange AddRecRange = getUnsignedRange(AR);
5113 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5114
5115 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5116 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5117 if (NUWRegion.contains(AddRecRange))
5118 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5119 }
5120
5121 return Result;
5122}
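// A worked instance of the no-self-wrap argument above (made-up numbers): for
// an i32 AddRec with a constant max backedge-taken count of 1000 (10 active
// bits) and a step whose signed range fits in 8 bits, 10 + 8 = 18 <= 32, so
// the recurrence cannot revisit its start value and FlagNW is set.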
5123
5124SCEV::NoWrapFlags
5125ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5126 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5127
5128 if (AR->hasNoSignedWrap())
5129 return Result;
5130
5131 if (!AR->isAffine())
5132 return Result;
5133
5134 // This function can be expensive, only try to prove NSW once per AddRec.
5135 if (!SignedWrapViaInductionTried.insert(AR).second)
5136 return Result;
5137
5138 const SCEV *Step = AR->getStepRecurrence(*this);
5139 const Loop *L = AR->getLoop();
5140
5141 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5142 // Note that this serves two purposes: It filters out loops that are
5143 // simply not analyzable, and it covers the case where this code is
5144 // being called from within backedge-taken count analysis, such that
5145 // attempting to ask for the backedge-taken count would likely result
5146 // in infinite recursion. In the latter case, the analysis code will
5147 // cope with a conservative value, and it will take care to purge
5148 // that value once it has finished.
5149 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5150
5151 // Normally, in the cases we can prove no-overflow via a
5152 // backedge guarding condition, we can also compute a backedge
5153 // taken count for the loop. The exceptions are assumptions and
5154 // guards present in the loop -- SCEV is not great at exploiting
5155 // these to compute max backedge taken counts, but can still use
5156 // these to prove lack of overflow. Use this fact to avoid
5157 // doing extra work that may not pay off.
5158
5159 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5160 AC.assumptions().empty())
5161 return Result;
5162
5163 // If the backedge is guarded by a comparison with the pre-inc value the
5164 // addrec is safe. Also, if the entry is guarded by a comparison with the
5165 // start value and the backedge is guarded by a comparison with the post-inc
5166 // value, the addrec is safe.
5167 ICmpInst::Predicate Pred;
5168 const SCEV *OverflowLimit =
5169 getSignedOverflowLimitForStep(Step, &Pred, this);
5170 if (OverflowLimit &&
5171 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5172 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5173 Result = setFlags(Result, SCEV::FlagNSW);
5174 }
5175 return Result;
5176}
5177SCEV::NoWrapFlags
5178ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5179 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5180
5181 if (AR->hasNoUnsignedWrap())
5182 return Result;
5183
5184 if (!AR->isAffine())
5185 return Result;
5186
5187 // This function can be expensive, only try to prove NUW once per AddRec.
5188 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5189 return Result;
5190
5191 const SCEV *Step = AR->getStepRecurrence(*this);
5192 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5193 const Loop *L = AR->getLoop();
5194
5195 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5196 // Note that this serves two purposes: It filters out loops that are
5197 // simply not analyzable, and it covers the case where this code is
5198 // being called from within backedge-taken count analysis, such that
5199 // attempting to ask for the backedge-taken count would likely result
5200 // in infinite recursion. In the latter case, the analysis code will
5201 // cope with a conservative value, and it will take care to purge
5202 // that value once it has finished.
5203 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5204
5205 // Normally, in the cases we can prove no-overflow via a
5206 // backedge guarding condition, we can also compute a backedge
5207 // taken count for the loop. The exceptions are assumptions and
5208 // guards present in the loop -- SCEV is not great at exploiting
5209 // these to compute max backedge taken counts, but can still use
5210 // these to prove lack of overflow. Use this fact to avoid
5211 // doing extra work that may not pay off.
5212
5213 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5214 AC.assumptions().empty())
5215 return Result;
5216
5217 // If the backedge is guarded by a comparison with the pre-inc value the
5218 // addrec is safe. Also, if the entry is guarded by a comparison with the
5219 // start value and the backedge is guarded by a comparison with the post-inc
5220 // value, the addrec is safe.
5221 if (isKnownPositive(Step)) {
5222 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5223 getUnsignedRangeMax(Step));
5224 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5225 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5226 Result = setFlags(Result, SCEV::FlagNUW);
5227 }
5228 }
5229
5230 return Result;
5231}
5232
5233namespace {
5234
5235/// Represents an abstract binary operation. This may exist as a
5236/// normal instruction or constant expression, or may have been
5237/// derived from an expression tree.
5238struct BinaryOp {
5239 unsigned Opcode;
5240 Value *LHS;
5241 Value *RHS;
5242 bool IsNSW = false;
5243 bool IsNUW = false;
5244
5245 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5246 /// constant expression.
5247 Operator *Op = nullptr;
5248
5249 explicit BinaryOp(Operator *Op)
5250 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5251 Op(Op) {
5252 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5253 IsNSW = OBO->hasNoSignedWrap();
5254 IsNUW = OBO->hasNoUnsignedWrap();
5255 }
5256 }
5257
5258 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5259 bool IsNUW = false)
5260 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5261};
5262
5263} // end anonymous namespace
5264
5265/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5266static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5267 AssumptionCache &AC,
5268 const DominatorTree &DT,
5269 const Instruction *CxtI) {
5270 auto *Op = dyn_cast<Operator>(V);
5271 if (!Op)
5272 return std::nullopt;
5273
5274 // Implementation detail: all the cleverness here should happen without
5275 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5276 // SCEV expressions when possible, and we should not break that.
5277
5278 switch (Op->getOpcode()) {
5279 case Instruction::Add:
5280 case Instruction::Sub:
5281 case Instruction::Mul:
5282 case Instruction::UDiv:
5283 case Instruction::URem:
5284 case Instruction::And:
5285 case Instruction::AShr:
5286 case Instruction::Shl:
5287 return BinaryOp(Op);
5288
5289 case Instruction::Or: {
5290 // Convert or disjoint into add nuw nsw.
5291 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5292 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5293 /*IsNSW=*/true, /*IsNUW=*/true);
5294 return BinaryOp(Op);
5295 }
5296
5297 case Instruction::Xor:
5298 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5299 // If the RHS of the xor is a signmask, then this is just an add.
5300 // Instcombine turns add of signmask into xor as a strength reduction step.
5301 if (RHSC->getValue().isSignMask())
5302 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5303 // Binary `xor` is a bit-wise `add`.
5304 if (V->getType()->isIntegerTy(1))
5305 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5306 return BinaryOp(Op);
5307
5308 case Instruction::LShr:
5309 // Turn a logical shift right by a constant into an unsigned divide.
5310 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5311 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5312
5313 // If the shift count is not less than the bitwidth, the result of
5314 // the shift is undefined. Don't try to analyze it, because the
5315 // resolution chosen here may differ from the resolution chosen in
5316 // other parts of the compiler.
5317 if (SA->getValue().ult(BitWidth)) {
5318 Constant *X =
5319 ConstantInt::get(SA->getContext(),
5320 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5321 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5322 }
5323 }
5324 return BinaryOp(Op);
5325
5326 case Instruction::ExtractValue: {
5327 auto *EVI = cast<ExtractValueInst>(Op);
5328 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5329 break;
5330
5331 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5332 if (!WO)
5333 break;
5334
5335 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5336 bool Signed = WO->isSigned();
5337 // TODO: Should add nuw/nsw flags for mul as well.
5338 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5339 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5340
5341 // Now that we know that all uses of the arithmetic-result component of
5342 // CI are guarded by the overflow check, we can go ahead and pretend
5343 // that the arithmetic is non-overflowing.
5344 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5345 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5346 }
5347
5348 default:
5349 break;
5350 }
5351
5352 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5353 // semantics as a Sub, return a binary sub expression.
5354 if (auto *II = dyn_cast<IntrinsicInst>(V))
5355 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5356 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5357
5358 return std::nullopt;
5359}
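// Examples of the mappings above (hypothetical i8 instructions):
//   %o = or disjoint i8 %a, %b  --> BinaryOp(Add, %a, %b, /*IsNSW*/true, /*IsNUW*/true)
//   %x = xor i8 %a, -128        --> BinaryOp(Add, %a, -128)   (sign-mask xor)
//   %d = lshr i8 %a, 3          --> BinaryOp(UDiv, %a, 8)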
5360
5361/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5362/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5363/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5364/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5365/// follows one of the following patterns:
5366/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5367/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5368/// If the SCEV expression of \p Op conforms with one of the expected patterns
5369/// we return the type of the truncation operation, and indicate whether the
5370/// truncated type should be treated as signed/unsigned by setting
5371/// \p Signed to true/false, respectively.
5372static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5373 bool &Signed, ScalarEvolution &SE) {
5374 // The case where Op == SymbolicPHI (that is, with no type conversions on
5375 // the way) is handled by the regular add recurrence creating logic and
5376 // would have already been triggered in createAddRecForPHI. Reaching it here
5377 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5378 // because one of the other operands of the SCEVAddExpr updating this PHI is
5379 // not invariant).
5380 //
5381 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5382 // this case predicates that allow us to prove that Op == SymbolicPHI will
5383 // be added.
5384 if (Op == SymbolicPHI)
5385 return nullptr;
5386
5387 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5388 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5389 if (SourceBits != NewBits)
5390 return nullptr;
5391
5392 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5393 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5394 if (!SExt && !ZExt)
5395 return nullptr;
5396 const SCEVTruncateExpr *Trunc =
5397 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5398 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5399 if (!Trunc)
5400 return nullptr;
5401 const SCEV *X = Trunc->getOperand();
5402 if (X != SymbolicPHI)
5403 return nullptr;
5404 Signed = SExt != nullptr;
5405 return Trunc->getType();
5406}
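// Example: with an i64 %SymbolicPHI, the operand
//   (sext i32 (trunc i64 %SymbolicPHI to i32) to i64)
// matches the first pattern above, so the function returns i32 as the
// truncation type and sets Signed to true.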
5407
5408static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5409 if (!PN->getType()->isIntegerTy())
5410 return nullptr;
5411 const Loop *L = LI.getLoopFor(PN->getParent());
5412 if (!L || L->getHeader() != PN->getParent())
5413 return nullptr;
5414 return L;
5415}
5416
5417// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5418// computation that updates the phi follows the following pattern:
5419// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5420// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5421// If so, try to see if it can be rewritten as an AddRecExpr under some
5422// Predicates. If successful, return them as a pair. Also cache the results
5423// of the analysis.
5424//
5425// Example usage scenario:
5426// Say the Rewriter is called for the following SCEV:
5427// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5428// where:
5429// %X = phi i64 (%Start, %BEValue)
5430// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5431// and call this function with %SymbolicPHI = %X.
5432//
5433// The analysis will find that the value coming around the backedge has
5434// the following SCEV:
5435// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5436// Upon concluding that this matches the desired pattern, the function
5437// will return the pair {NewAddRec, SmallPredsVec} where:
5438// NewAddRec = {%Start,+,%Step}
5439// SmallPredsVec = {P1, P2, P3} as follows:
5440// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5441// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5442// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5443// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5444// under the predicates {P1,P2,P3}.
5445// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5446// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5447//
5448// TODO's:
5449//
5450// 1) Extend the Induction descriptor to also support inductions that involve
5451// casts: When needed (namely, when we are called in the context of the
5452// vectorizer induction analysis), a Set of cast instructions will be
5453// populated by this method, and provided back to isInductionPHI. This is
5454// needed to allow the vectorizer to properly record them to be ignored by
5455// the cost model and to avoid vectorizing them (otherwise these casts,
5456// which are redundant under the runtime overflow checks, will be
5457// vectorized, which can be costly).
5458//
5459// 2) Support additional induction/PHISCEV patterns: We also want to support
5460// inductions where the sext-trunc / zext-trunc operations (partly) occur
5461// after the induction update operation (the induction increment):
5462//
5463// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5464// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5465//
5466// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5467// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5468//
5469// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5470std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5471ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5472 SmallVector<const SCEVPredicate *, 3> Predicates;
5473
5474 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5475 // return an AddRec expression under some predicate.
5476
5477 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5478 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5479 assert(L && "Expecting an integer loop header phi");
5480
5481 // The loop may have multiple entrances or multiple exits; we can analyze
5482 // this phi as an addrec if it has a unique entry value and a unique
5483 // backedge value.
5484 Value *BEValueV = nullptr, *StartValueV = nullptr;
5485 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5486 Value *V = PN->getIncomingValue(i);
5487 if (L->contains(PN->getIncomingBlock(i))) {
5488 if (!BEValueV) {
5489 BEValueV = V;
5490 } else if (BEValueV != V) {
5491 BEValueV = nullptr;
5492 break;
5493 }
5494 } else if (!StartValueV) {
5495 StartValueV = V;
5496 } else if (StartValueV != V) {
5497 StartValueV = nullptr;
5498 break;
5499 }
5500 }
5501 if (!BEValueV || !StartValueV)
5502 return std::nullopt;
5503
5504 const SCEV *BEValue = getSCEV(BEValueV);
5505
5506 // If the value coming around the backedge is an add with the symbolic
5507 // value we just inserted, possibly with casts that we can ignore under
5508 // an appropriate runtime guard, then we found a simple induction variable!
5509 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5510 if (!Add)
5511 return std::nullopt;
5512
5513 // If there is a single occurrence of the symbolic value, possibly
5514 // casted, replace it with a recurrence.
5515 unsigned FoundIndex = Add->getNumOperands();
5516 Type *TruncTy = nullptr;
5517 bool Signed;
5518 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5519 if ((TruncTy =
5520 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5521 if (FoundIndex == e) {
5522 FoundIndex = i;
5523 break;
5524 }
5525
5526 if (FoundIndex == Add->getNumOperands())
5527 return std::nullopt;
5528
5529 // Create an add with everything but the specified operand.
5530 SmallVector<const SCEV *, 8> Ops;
5531 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5532 if (i != FoundIndex)
5533 Ops.push_back(Add->getOperand(i));
5534 const SCEV *Accum = getAddExpr(Ops);
5535
5536 // The runtime checks will not be valid if the step amount is
5537 // varying inside the loop.
5538 if (!isLoopInvariant(Accum, L))
5539 return std::nullopt;
5540
5541 // *** Part2: Create the predicates
5542
5543 // Analysis was successful: we have a phi-with-cast pattern for which we
5544 // can return an AddRec expression under the following predicates:
5545 //
5546 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5547 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5548 // P2: An Equal predicate that guarantees that
5549 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5550 // P3: An Equal predicate that guarantees that
5551 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5552 //
5553 // As we next prove, the above predicates guarantee that:
5554 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5555 //
5556 //
5557 // More formally, we want to prove that:
5558 // Expr(i+1) = Start + (i+1) * Accum
5559 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5560 //
5561 // Given that:
5562 // 1) Expr(0) = Start
5563 // 2) Expr(1) = Start + Accum
5564 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5565 // 3) Induction hypothesis (step i):
5566 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5567 //
5568 // Proof:
5569 // Expr(i+1) =
5570 // = Start + (i+1)*Accum
5571 // = (Start + i*Accum) + Accum
5572 // = Expr(i) + Accum
5573 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5574 // :: from step i
5575 //
5576 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5577 //
5578 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5579 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5580 // + Accum :: from P3
5581 //
5582 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5583 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5584 //
5585 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5586 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5587 //
5588 // By induction, the same applies to all iterations 1<=i<n:
5589 //
5590
5591 // Create a truncated addrec for which we will add a no overflow check (P1).
5592 const SCEV *StartVal = getSCEV(StartValueV);
5593 const SCEV *PHISCEV =
5594 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5595 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5596
5597 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5598 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5599 // will be constant.
5600 //
5601 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5602 // add P1.
5603 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5604 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5605 Signed ? SCEVWrapPredicate::IncrementNSSW
5606 : SCEVWrapPredicate::IncrementNUSW;
5607 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5608 Predicates.push_back(AddRecPred);
5609 }
5610
5611 // Create the Equal Predicates P2,P3:
5612
5613 // It is possible that the predicates P2 and/or P3 are computable at
5614 // compile time due to StartVal and/or Accum being constants.
5615 // If either one is, then we can check that now and escape if either P2
5616 // or P3 is false.
5617
5618 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5619 // for each of StartVal and Accum
5620 auto getExtendedExpr = [&](const SCEV *Expr,
5621 bool CreateSignExtend) -> const SCEV * {
5622 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5623 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5624 const SCEV *ExtendedExpr =
5625 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5626 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5627 return ExtendedExpr;
5628 };
5629
5630 // Given:
5631 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5632 // = getExtendedExpr(Expr)
5633 // Determine whether the predicate P: Expr == ExtendedExpr
5634 // is known to be false at compile time
5635 auto PredIsKnownFalse = [&](const SCEV *Expr,
5636 const SCEV *ExtendedExpr) -> bool {
5637 return Expr != ExtendedExpr &&
5638 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5639 };
5640
5641 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5642 if (PredIsKnownFalse(StartVal, StartExtended)) {
5643 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5644 return std::nullopt;
5645 }
5646
5647 // The Step is always Signed (because the overflow checks are either
5648 // NSSW or NUSW)
5649 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5650 if (PredIsKnownFalse(Accum, AccumExtended)) {
5651 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5652 return std::nullopt;
5653 }
5654
5655 auto AppendPredicate = [&](const SCEV *Expr,
5656 const SCEV *ExtendedExpr) -> void {
5657 if (Expr != ExtendedExpr &&
5658 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5659 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5660 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5661 Predicates.push_back(Pred);
5662 }
5663 };
5664
5665 AppendPredicate(StartVal, StartExtended);
5666 AppendPredicate(Accum, AccumExtended);
5667
5668 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5669 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5670 // into NewAR if it will also add the runtime overflow checks specified in
5671 // Predicates.
5672 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5673
5674 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5675 std::make_pair(NewAR, Predicates);
5676 // Remember the result of the analysis for this SCEV at this location.
5677 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5678 return PredRewrite;
5679}
5680
5681std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5682ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5683 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5684 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5685 if (!L)
5686 return std::nullopt;
5687
5688 // Check to see if we already analyzed this PHI.
5689 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5690 if (I != PredicatedSCEVRewrites.end()) {
5691 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5692 I->second;
5693 // Analysis was done before and failed to create an AddRec:
5694 if (Rewrite.first == SymbolicPHI)
5695 return std::nullopt;
5696 // Analysis was done before and succeeded in creating an AddRec under
5697 // a predicate:
5698 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5699 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5700 return Rewrite;
5701 }
5702
5703 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5704 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5705
5706 // Record in the cache that the analysis failed
5707 if (!Rewrite) {
5708 SmallVector<const SCEVPredicate *, 3> Predicates;
5709 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5710 return std::nullopt;
5711 }
5712
5713 return Rewrite;
5714}
5715
5716// FIXME: This utility is currently required because the Rewriter currently
5717// does not rewrite this expression:
5718// {0, +, (sext ix (trunc iy to ix) to iy)}
5719// into {0, +, %step},
5720// even when the following Equal predicate exists:
5721// "%step == (sext ix (trunc iy to ix) to iy)".
5722bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5723 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5724 if (AR1 == AR2)
5725 return true;
5726
5727 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5728 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5729 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5730 return false;
5731 return true;
5732 };
5733
5734 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5735 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5736 return false;
5737 return true;
5738}
5739
5740/// A helper function for createAddRecFromPHI to handle simple cases.
5741///
5742/// This function tries to find an AddRec expression for the simplest (yet most
5743/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5744/// If it fails, createAddRecFromPHI will use a more general, but slow,
5745/// technique for finding the AddRec expression.
5746const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5747 Value *BEValueV,
5748 Value *StartValueV) {
5749 const Loop *L = LI.getLoopFor(PN->getParent());
5750 assert(L && L->getHeader() == PN->getParent());
5751 assert(BEValueV && StartValueV);
5752
5753 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5754 if (!BO)
5755 return nullptr;
5756
5757 if (BO->Opcode != Instruction::Add)
5758 return nullptr;
5759
5760 const SCEV *Accum = nullptr;
5761 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5762 Accum = getSCEV(BO->RHS);
5763 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5764 Accum = getSCEV(BO->LHS);
5765
5766 if (!Accum)
5767 return nullptr;
5768
5769 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5770 if (BO->IsNUW)
5771 Flags = setFlags(Flags, SCEV::FlagNUW);
5772 if (BO->IsNSW)
5773 Flags = setFlags(Flags, SCEV::FlagNSW);
5774
5775 const SCEV *StartVal = getSCEV(StartValueV);
5776 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5777 insertValueToMap(PN, PHISCEV);
5778
5779 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5780 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5781 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5782 proveNoWrapViaConstantRanges(AR)));
5783 }
5784
5785 // We can add Flags to the post-inc expression only if we
5786 // know that it is *undefined behavior* for BEValueV to
5787 // overflow.
5788 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5789 assert(isLoopInvariant(Accum, L) &&
5790 "Accum is defined outside L, but is not invariant?");
5791 if (isAddRecNeverPoison(BEInst, L))
5792 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5793 }
5794
5795 return PHISCEV;
5796}
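// Example of the simple case handled above (hypothetical IR):
//   %iv      = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nuw nsw i32 %iv, 4
// maps %iv directly to the AddRec {0,+,4}<nuw><nsw> for %loop.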
5797
5798const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5799 const Loop *L = LI.getLoopFor(PN->getParent());
5800 if (!L || L->getHeader() != PN->getParent())
5801 return nullptr;
5802
5803 // The loop may have multiple entrances or multiple exits; we can analyze
5804 // this phi as an addrec if it has a unique entry value and a unique
5805 // backedge value.
5806 Value *BEValueV = nullptr, *StartValueV = nullptr;
5807 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5808 Value *V = PN->getIncomingValue(i);
5809 if (L->contains(PN->getIncomingBlock(i))) {
5810 if (!BEValueV) {
5811 BEValueV = V;
5812 } else if (BEValueV != V) {
5813 BEValueV = nullptr;
5814 break;
5815 }
5816 } else if (!StartValueV) {
5817 StartValueV = V;
5818 } else if (StartValueV != V) {
5819 StartValueV = nullptr;
5820 break;
5821 }
5822 }
5823 if (!BEValueV || !StartValueV)
5824 return nullptr;
5825
5826 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5827 "PHI node already processed?");
5828
5829 // First, try to find an AddRec expression without creating a fictitious symbolic
5830 // value for PN.
5831 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5832 return S;
5833
5834 // Handle PHI node value symbolically.
5835 const SCEV *SymbolicName = getUnknown(PN);
5836 insertValueToMap(PN, SymbolicName);
5837
5838 // Using this symbolic name for the PHI, analyze the value coming around
5839 // the back-edge.
5840 const SCEV *BEValue = getSCEV(BEValueV);
5841
5842 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5843 // has a special value for the first iteration of the loop.
5844
5845 // If the value coming around the backedge is an add with the symbolic
5846 // value we just inserted, then we found a simple induction variable!
5847 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5848 // If there is a single occurrence of the symbolic value, replace it
5849 // with a recurrence.
5850 unsigned FoundIndex = Add->getNumOperands();
5851 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5852 if (Add->getOperand(i) == SymbolicName)
5853 if (FoundIndex == e) {
5854 FoundIndex = i;
5855 break;
5856 }
5857
5858 if (FoundIndex != Add->getNumOperands()) {
5859 // Create an add with everything but the specified operand.
5860 SmallVector<const SCEV *, 8> Ops;
5861 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5862 if (i != FoundIndex)
5863 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5864 L, *this));
5865 const SCEV *Accum = getAddExpr(Ops);
5866
5867 // This is not a valid addrec if the step amount is varying each
5868 // loop iteration, but is not itself an addrec in this loop.
5869 if (isLoopInvariant(Accum, L) ||
5870 (isa<SCEVAddRecExpr>(Accum) &&
5871 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5872 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5873
5874 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5875 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5876 if (BO->IsNUW)
5877 Flags = setFlags(Flags, SCEV::FlagNUW);
5878 if (BO->IsNSW)
5879 Flags = setFlags(Flags, SCEV::FlagNSW);
5880 }
5881 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5882 // If the increment is an inbounds GEP, then we know the address
5883 // space cannot be wrapped around. We cannot make any guarantee
5884 // about signed or unsigned overflow because pointers are
5885 // unsigned but we may have a negative index from the base
5886 // pointer. We can guarantee that no unsigned wrap occurs if the
5887 // indices form a positive value.
5888 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5889 Flags = setFlags(Flags, SCEV::FlagNW);
5890 if (isKnownPositive(Accum))
5891 Flags = setFlags(Flags, SCEV::FlagNUW);
5892 }
5893
5894 // We cannot transfer nuw and nsw flags from subtraction
5895 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5896 // for instance.
5897 }
5898
5899 const SCEV *StartVal = getSCEV(StartValueV);
5900 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5901
5902 // Okay, for the entire analysis of this edge we assumed the PHI
5903 // to be symbolic. We now need to go back and purge all of the
5904 // entries for the scalars that use the symbolic expression.
5905 forgetMemoizedResults(SymbolicName);
5906 insertValueToMap(PN, PHISCEV);
5907
5908 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5909 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5910 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5911 proveNoWrapViaConstantRanges(AR)));
5912 }
5913
5914 // We can add Flags to the post-inc expression only if we
5915 // know that it is *undefined behavior* for BEValueV to
5916 // overflow.
5917 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5918 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5919 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5920
5921 return PHISCEV;
5922 }
5923 }
5924 } else {
5925 // Otherwise, this could be a loop like this:
5926 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5927 // In this case, j = {1,+,1} and BEValue is j.
5928 // Because the other in-value of i (0) fits the evolution of BEValue,
5929 // i really is an addrec evolution.
5930 //
5931 // We can generalize this saying that i is the shifted value of BEValue
5932 // by one iteration:
5933 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5934 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5935 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5936 if (Shifted != getCouldNotCompute() &&
5937 Start != getCouldNotCompute()) {
5938 const SCEV *StartVal = getSCEV(StartValueV);
5939 if (Start == StartVal) {
5940 // Okay, for the entire analysis of this edge we assumed the PHI
5941 // to be symbolic. We now need to go back and purge all of the
5942 // entries for the scalars that use the symbolic expression.
5943 forgetMemoizedResults(SymbolicName);
5944 insertValueToMap(PN, Shifted);
5945 return Shifted;
5946 }
5947 }
5948 }
5949
5950 // Remove the temporary PHI node SCEV that has been inserted while intending
5951 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5952 // as it would prevent later (possibly simpler) SCEV expressions from being
5953 // added to the ValueExprMap.
5954 eraseValueFromMap(PN);
5955
5956 return nullptr;
5957}
5958
5959// Try to match a control flow sequence that branches out at BI and merges back
5960// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5961// match.
5962static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5963 Value *&C, Value *&LHS, Value *&RHS) {
5964 C = BI->getCondition();
5965
5966 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5967 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5968
5969 if (!LeftEdge.isSingleEdge())
5970 return false;
5971
5972 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5973
5974 Use &LeftUse = Merge->getOperandUse(0);
5975 Use &RightUse = Merge->getOperandUse(1);
5976
5977 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5978 LHS = LeftUse;
5979 RHS = RightUse;
5980 return true;
5981 }
5982
5983 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5984 LHS = RightUse;
5985 RHS = LeftUse;
5986 return true;
5987 }
5988
5989 return false;
5990}
5991
5992const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5993 auto IsReachable =
5994 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5995 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5996 // Try to match
5997 //
5998 // br %cond, label %left, label %right
5999 // left:
6000 // br label %merge
6001 // right:
6002 // br label %merge
6003 // merge:
6004 // V = phi [ %x, %left ], [ %y, %right ]
6005 //
6006 // as "select %cond, %x, %y"
6007
6008 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6009 assert(IDom && "At least the entry block should dominate PN");
6010
6011 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6012 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6013
6014 if (BI && BI->isConditional() &&
6015 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6016 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6017 properlyDominates(getSCEV(RHS), PN->getParent()))
6018 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6019 }
6020
6021 return nullptr;
6022}
6023
6024const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6025 if (const SCEV *S = createAddRecFromPHI(PN))
6026 return S;
6027
6028 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6029 return getSCEV(V);
6030
6031 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6032 return S;
6033
6034 // If it's not a loop phi, we can't handle it yet.
6035 return getUnknown(PN);
6036}
6037
6038bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6039 SCEVTypes RootKind) {
6040 struct FindClosure {
6041 const SCEV *OperandToFind;
6042 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6043 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6044
6045 bool Found = false;
6046
6047 bool canRecurseInto(SCEVTypes Kind) const {
6048 // We can only recurse into the SCEV expression of the same effective type
6049 // as the type of our root SCEV expression, and into zero-extensions.
6050 return RootKind == Kind || NonSequentialRootKind == Kind ||
6051 scZeroExtend == Kind;
6052 };
6053
6054 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6055 : OperandToFind(OperandToFind), RootKind(RootKind),
6056 NonSequentialRootKind(
6057 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6058 RootKind)) {}
6059
6060 bool follow(const SCEV *S) {
6061 Found = S == OperandToFind;
6062
6063 return !isDone() && canRecurseInto(S->getSCEVType());
6064 }
6065
6066 bool isDone() const { return Found; }
6067 };
6068
6069 FindClosure FC(OperandToFind, RootKind);
6070 visitAll(Root, FC);
6071 return FC.Found;
6072}
6073
6074std::optional<const SCEV *>
6075ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6076 ICmpInst *Cond,
6077 Value *TrueVal,
6078 Value *FalseVal) {
6079 // Try to match some simple smax or umax patterns.
6080 auto *ICI = Cond;
6081
6082 Value *LHS = ICI->getOperand(0);
6083 Value *RHS = ICI->getOperand(1);
6084
6085 switch (ICI->getPredicate()) {
6086 case ICmpInst::ICMP_SLT:
6087 case ICmpInst::ICMP_SLE:
6088 case ICmpInst::ICMP_ULT:
6089 case ICmpInst::ICMP_ULE:
6090 std::swap(LHS, RHS);
6091 [[fallthrough]];
6092 case ICmpInst::ICMP_SGT:
6093 case ICmpInst::ICMP_SGE:
6094 case ICmpInst::ICMP_UGT:
6095 case ICmpInst::ICMP_UGE:
6096 // a > b ? a+x : b+x -> max(a, b)+x
6097 // a > b ? b+x : a+x -> min(a, b)+x
6098 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6099 bool Signed = ICI->isSigned();
6100 const SCEV *LA = getSCEV(TrueVal);
6101 const SCEV *RA = getSCEV(FalseVal);
6102 const SCEV *LS = getSCEV(LHS);
6103 const SCEV *RS = getSCEV(RHS);
6104 if (LA->getType()->isPointerTy()) {
6105 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6106 // Need to make sure we can't produce weird expressions involving
6107 // negated pointers.
6108 if (LA == LS && RA == RS)
6109 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6110 if (LA == RS && RA == LS)
6111 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6112 }
6113 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6114 if (Op->getType()->isPointerTy()) {
6115 Op = getLosslessPtrToIntExpr(Op);
6116 if (isa<SCEVCouldNotCompute>(Op))
6117 return Op;
6118 }
6119 if (Signed)
6120 Op = getNoopOrSignExtend(Op, Ty);
6121 else
6122 Op = getNoopOrZeroExtend(Op, Ty);
6123 return Op;
6124 };
6125 LS = CoerceOperand(LS);
6126 RS = CoerceOperand(RS);
6127 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6128 break;
6129 const SCEV *LDiff = getMinusSCEV(LA, LS);
6130 const SCEV *RDiff = getMinusSCEV(RA, RS);
6131 if (LDiff == RDiff)
6132 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6133 LDiff);
6134 LDiff = getMinusSCEV(LA, RS);
6135 RDiff = getMinusSCEV(RA, LS);
6136 if (LDiff == RDiff)
6137 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6138 LDiff);
6139 }
6140 break;
6141 case ICmpInst::ICMP_NE:
6142 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6143 std::swap(TrueVal, FalseVal);
6144 [[fallthrough]];
6145 case ICmpInst::ICMP_EQ:
6146 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6147 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6148 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6149 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6150 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6151 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6152 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6153 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6154 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6155 return getAddExpr(getUMaxExpr(X, C), Y);
6156 }
6157 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6158 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6159 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6160 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6161 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6162 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6163 const SCEV *X = getSCEV(LHS);
6164 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6165 X = ZExt->getOperand();
6166 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6167 const SCEV *FalseValExpr = getSCEV(FalseVal);
6168 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6169 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6170 /*Sequential=*/true);
6171 }
6172 }
6173 break;
6174 default:
6175 break;
6176 }
6177
6178 return std::nullopt;
6179}
6180
6181static std::optional<const SCEV *>
6182createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6183 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6184 assert(CondExpr->getType()->isIntegerTy(1) &&
6185 TrueExpr->getType() == FalseExpr->getType() &&
6186 TrueExpr->getType()->isIntegerTy(1) &&
6187 "Unexpected operands of a select.");
6188
6189 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6190 // --> C + (umin_seq cond, x - C)
6191 //
6192 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6193 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6194 // --> C + (umin_seq ~cond, x - C)
6195
6196 // FIXME: while we can't legally model the case where both of the hands
6197 // are fully variable, we only require that the *difference* is constant.
6198 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6199 return std::nullopt;
6200
6201 const SCEV *X, *C;
6202 if (isa<SCEVConstant>(TrueExpr)) {
6203 CondExpr = SE->getNotSCEV(CondExpr);
6204 X = FalseExpr;
6205 C = TrueExpr;
6206 } else {
6207 X = TrueExpr;
6208 C = FalseExpr;
6209 }
6210 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6211 /*Sequential=*/true));
6212}
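// Example of the i1 rewrite above: `select i1 %c, i1 %x, i1 false` has
// C = false, so it becomes 0 + (umin_seq %c, %x - 0) = umin_seq(%c, %x),
// i.e. the short-circuiting form of `%c && %x`.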
6213
6214static std::optional<const SCEV *>
6215createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6216 Value *FalseVal) {
6217 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6218 return std::nullopt;
6219
6220 const auto *SECond = SE->getSCEV(Cond);
6221 const auto *SETrue = SE->getSCEV(TrueVal);
6222 const auto *SEFalse = SE->getSCEV(FalseVal);
6223 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6224}
6225
6226const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6227 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6228 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6229 assert(TrueVal->getType() == FalseVal->getType() &&
6230 V->getType() == TrueVal->getType() &&
6231 "Types of select hands and of the result must match.");
6232
6233 // For now, only deal with i1-typed `select`s.
6234 if (!V->getType()->isIntegerTy(1))
6235 return getUnknown(V);
6236
6237 if (std::optional<const SCEV *> S =
6238 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6239 return *S;
6240
6241 return getUnknown(V);
6242}
6243
6244const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6245 Value *TrueVal,
6246 Value *FalseVal) {
6247 // Handle "constant" branch or select. This can occur for instance when a
6248 // loop pass transforms an inner loop and moves on to process the outer loop.
6249 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6250 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6251
6252 if (auto *I = dyn_cast<Instruction>(V)) {
6253 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6254 if (std::optional<const SCEV *> S =
6255 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6256 TrueVal, FalseVal))
6257 return *S;
6258 }
6259 }
6260
6261 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6262}
6263
6264/// Expand GEP instructions into add and multiply operations. This allows them
6265/// to be analyzed by regular SCEV code.
6266const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6267 assert(GEP->getSourceElementType()->isSized() &&
6268 "GEP source element type must be sized");
6269
6270 SmallVector<const SCEV *, 4> IndexExprs;
6271 for (Value *Index : GEP->indices())
6272 IndexExprs.push_back(getSCEV(Index));
6273 return getGEPExpr(GEP, IndexExprs);
6274}
6275
6276APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6277 uint32_t BitWidth = getTypeSizeInBits(S->getType());
6278 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6279 return TrailingZeros >= BitWidth
6280 ? APInt::getZero(BitWidth)
6281 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6282 };
6283 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6284 // The result is GCD of all operands results.
6285 APInt Res = getConstantMultiple(N->getOperand(0));
6286 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6287 Res = APIntOps::GreatestCommonDivisor(
6288 Res, getConstantMultiple(N->getOperand(I)));
6289 return Res;
6290 };
6291
6292 switch (S->getSCEVType()) {
6293 case scConstant:
6294 return cast<SCEVConstant>(S)->getAPInt();
6295 case scPtrToInt:
6296 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6297 case scUDivExpr:
6298 case scVScale:
6299 return APInt(BitWidth, 1);
6300 case scTruncate: {
6301 // Only multiples that are a power of 2 will hold after truncation.
6302 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6303 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6304 return GetShiftedByZeros(TZ);
6305 }
6306 case scZeroExtend: {
6307 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6308 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6309 }
6310 case scSignExtend: {
6311 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6312 return getConstantMultiple(E->getOperand()).sext(BitWidth);
6313 }
6314 case scMulExpr: {
6315 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6316 if (M->hasNoUnsignedWrap()) {
6317 // The result is the product of all operand results.
6318 APInt Res = getConstantMultiple(M->getOperand(0));
6319 for (const SCEV *Operand : M->operands().drop_front())
6320 Res = Res * getConstantMultiple(Operand);
6321 return Res;
6322 }
6323
6324    // If there are no wrap guarantees, find the trailing zeros, which is the
6325 // sum of trailing zeros for all its operands.
6326 uint32_t TZ = 0;
6327 for (const SCEV *Operand : M->operands())
6328 TZ += getMinTrailingZeros(Operand);
6329 return GetShiftedByZeros(TZ);
6330 }
6331 case scAddExpr:
6332 case scAddRecExpr: {
6333 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6334 if (N->hasNoUnsignedWrap())
6335 return GetGCDMultiple(N);
6336    // Find the minimum number of trailing zero bits, which is the minimum over its operands.
6337 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6338 for (const SCEV *Operand : N->operands().drop_front())
6339 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6340 return GetShiftedByZeros(TZ);
6341 }
6342 case scUMaxExpr:
6343 case scSMaxExpr:
6344 case scUMinExpr:
6345 case scSMinExpr:
6346  case scSequentialUMinExpr:
6347    return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6348 case scUnknown: {
6349 // ask ValueTracking for known bits
6350 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6351 unsigned Known =
6352 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6353 .countMinTrailingZeros();
6354 return GetShiftedByZeros(Known);
6355 }
6356 case scCouldNotCompute:
6357 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6358 }
6359 llvm_unreachable("Unknown SCEV kind!");
6360}
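// A small worked example of the cases above (illustrative). Assuming %x is an
// i32 about which nothing further is known:
//
//   getConstantMultiple( (8 + (4 * %x)<nuw>)<nuw> )
//     = GCD(getConstantMultiple(8), getConstantMultiple((4 * %x)<nuw>))
//     = GCD(8, 4 * getConstantMultiple(%x)) = GCD(8, 4) = 4
//
// Without <nuw> on the add, the scAddExpr case instead takes the minimum of
// the operands' trailing zeros, min(3, 2) = 2, and GetShiftedByZeros(2)
// yields the same multiple of 4.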
6361
6362 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6363  auto I = ConstantMultipleCache.find(S);
6364 if (I != ConstantMultipleCache.end())
6365 return I->second;
6366
6367 APInt Result = getConstantMultipleImpl(S);
6368 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6369 assert(InsertPair.second && "Should insert a new key");
6370 return InsertPair.first->second;
6371}
6372
6373 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6374  APInt Multiple = getConstantMultiple(S);
6375 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6376}
6377
6378 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6379  return std::min(getConstantMultiple(S).countTrailingZeros(),
6380 (unsigned)getTypeSizeInBits(S->getType()));
6381}
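// For illustration: if getConstantMultiple(S) is 24 (0b11000), the minimum
// number of trailing zero bits is 3; the std::min with the type size caps the
// result so it never exceeds the bit width of S's type (e.g. when the
// multiple is 0 and every bit counts as a trailing zero).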
6382
6383/// Helper method to assign a range to V from metadata present in the IR.
6384static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6385 if (Instruction *I = dyn_cast<Instruction>(V)) {
6386 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6387 return getConstantRangeFromMetadata(*MD);
6388 if (const auto *CB = dyn_cast<CallBase>(V))
6389 if (std::optional<ConstantRange> Range = CB->getRange())
6390 return Range;
6391 }
6392 if (auto *A = dyn_cast<Argument>(V))
6393 if (std::optional<ConstantRange> Range = A->getRange())
6394 return Range;
6395
6396 return std::nullopt;
6397}
6398
6399 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6400                                      SCEV::NoWrapFlags Flags) {
6401 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6402 AddRec->setNoWrapFlags(Flags);
6403 UnsignedRanges.erase(AddRec);
6404 SignedRanges.erase(AddRec);
6405 ConstantMultipleCache.erase(AddRec);
6406 }
6407}
6408
6409ConstantRange ScalarEvolution::
6410getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6411 const DataLayout &DL = getDataLayout();
6412
6413 unsigned BitWidth = getTypeSizeInBits(U->getType());
6414 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6415
6416 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6417 // use information about the trip count to improve our available range. Note
6418 // that the trip count independent cases are already handled by known bits.
6419 // WARNING: The definition of recurrence used here is subtly different than
6420 // the one used by AddRec (and thus most of this file). Step is allowed to
6421 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6422 // and other addrecs in the same loop (for non-affine addrecs). The code
6423 // below intentionally handles the case where step is not loop invariant.
6424 auto *P = dyn_cast<PHINode>(U->getValue());
6425 if (!P)
6426 return FullSet;
6427
6428 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6429 // even the values that are not available in these blocks may come from them,
6430  // and this leads to a false-positive recurrence test.
6431 for (auto *Pred : predecessors(P->getParent()))
6432 if (!DT.isReachableFromEntry(Pred))
6433 return FullSet;
6434
6435 BinaryOperator *BO;
6436 Value *Start, *Step;
6437 if (!matchSimpleRecurrence(P, BO, Start, Step))
6438 return FullSet;
6439
6440 // If we found a recurrence in reachable code, we must be in a loop. Note
6441 // that BO might be in some subloop of L, and that's completely okay.
6442 auto *L = LI.getLoopFor(P->getParent());
6443 assert(L && L->getHeader() == P->getParent());
6444 if (!L->contains(BO->getParent()))
6445 // NOTE: This bailout should be an assert instead. However, asserting
6446 // the condition here exposes a case where LoopFusion is querying SCEV
6447 // with malformed loop information during the midst of the transform.
6448 // There doesn't appear to be an obvious fix, so for the moment bailout
6449 // until the caller issue can be fixed. PR49566 tracks the bug.
6450 return FullSet;
6451
6452 // TODO: Extend to other opcodes such as mul, and div
6453 switch (BO->getOpcode()) {
6454 default:
6455 return FullSet;
6456 case Instruction::AShr:
6457 case Instruction::LShr:
6458 case Instruction::Shl:
6459 break;
6460 };
6461
6462 if (BO->getOperand(0) != P)
6463 // TODO: Handle the power function forms some day.
6464 return FullSet;
6465
6466 unsigned TC = getSmallConstantMaxTripCount(L);
6467 if (!TC || TC >= BitWidth)
6468 return FullSet;
6469
6470 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6471 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6472 assert(KnownStart.getBitWidth() == BitWidth &&
6473 KnownStep.getBitWidth() == BitWidth);
6474
6475 // Compute total shift amount, being careful of overflow and bitwidths.
6476 auto MaxShiftAmt = KnownStep.getMaxValue();
6477 APInt TCAP(BitWidth, TC-1);
6478 bool Overflow = false;
6479 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6480 if (Overflow)
6481 return FullSet;
6482
6483 switch (BO->getOpcode()) {
6484 default:
6485 llvm_unreachable("filtered out above");
6486 case Instruction::AShr: {
6487 // For each ashr, three cases:
6488 // shift = 0 => unchanged value
6489 // saturation => 0 or -1
6490 // other => a value closer to zero (of the same sign)
6491 // Thus, the end value is closer to zero than the start.
6492 auto KnownEnd = KnownBits::ashr(KnownStart,
6493 KnownBits::makeConstant(TotalShift));
6494 if (KnownStart.isNonNegative())
6495 // Analogous to lshr (simply not yet canonicalized)
6496 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6497 KnownStart.getMaxValue() + 1);
6498 if (KnownStart.isNegative())
6499 // End >=u Start && End <=s Start
6500 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6501 KnownEnd.getMaxValue() + 1);
6502 break;
6503 }
6504 case Instruction::LShr: {
6505 // For each lshr, three cases:
6506 // shift = 0 => unchanged value
6507 // saturation => 0
6508 // other => a smaller positive number
6509 // Thus, the low end of the unsigned range is the last value produced.
6510 auto KnownEnd = KnownBits::lshr(KnownStart,
6511 KnownBits::makeConstant(TotalShift));
6512 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6513 KnownStart.getMaxValue() + 1);
6514 }
6515 case Instruction::Shl: {
6516 // Iff no bits are shifted out, value increases on every shift.
6517 auto KnownEnd = KnownBits::shl(KnownStart,
6518 KnownBits::makeConstant(TotalShift));
6519 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6520 return ConstantRange(KnownStart.getMinValue(),
6521 KnownEnd.getMaxValue() + 1);
6522 break;
6523 }
6524 };
6525 return FullSet;
6526}
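// A small worked example of the logic above (illustrative): for the shift
// recurrence
//
//   %iv      = phi i8 [ %start, %preheader ], [ %iv.next, %latch ]
//   %iv.next = lshr i8 %iv, 1
//
// with a constant max trip count of 3 and known bits placing %start in
// [0, 64), the maximum total shift is 1 * (3 - 1) = 2, and the LShr case
// returns the ConstantRange [0, 64): the low end is the last (smallest) value
// produced and the high end is bounded by the largest possible start value.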
6527
6528const ConstantRange &
6529ScalarEvolution::getRangeRefIter(const SCEV *S,
6530 ScalarEvolution::RangeSignHint SignHint) {
6531  DenseMap<const SCEV *, ConstantRange> &Cache =
6532      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6533                                                        : SignedRanges;
6534  SmallVector<const SCEV *> WorkList;
6535  SmallPtrSet<const SCEV *, 8> Seen;
6536
6537 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6538 // SCEVUnknown PHI node.
6539 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6540 if (!Seen.insert(Expr).second)
6541 return;
6542 if (Cache.contains(Expr))
6543 return;
6544 switch (Expr->getSCEVType()) {
6545 case scUnknown:
6546 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6547 break;
6548 [[fallthrough]];
6549 case scConstant:
6550 case scVScale:
6551 case scTruncate:
6552 case scZeroExtend:
6553 case scSignExtend:
6554 case scPtrToInt:
6555 case scAddExpr:
6556 case scMulExpr:
6557 case scUDivExpr:
6558 case scAddRecExpr:
6559 case scUMaxExpr:
6560 case scSMaxExpr:
6561 case scUMinExpr:
6562 case scSMinExpr:
6563    case scSequentialUMinExpr:
6564      WorkList.push_back(Expr);
6565 break;
6566 case scCouldNotCompute:
6567 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6568 }
6569 };
6570 AddToWorklist(S);
6571
6572 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6573 for (unsigned I = 0; I != WorkList.size(); ++I) {
6574 const SCEV *P = WorkList[I];
6575 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6576 // If it is not a `SCEVUnknown`, just recurse into operands.
6577 if (!UnknownS) {
6578 for (const SCEV *Op : P->operands())
6579 AddToWorklist(Op);
6580 continue;
6581 }
6582 // `SCEVUnknown`'s require special treatment.
6583 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6584 if (!PendingPhiRangesIter.insert(P).second)
6585 continue;
6586 for (auto &Op : reverse(P->operands()))
6587 AddToWorklist(getSCEV(Op));
6588 }
6589 }
6590
6591 if (!WorkList.empty()) {
6592 // Use getRangeRef to compute ranges for items in the worklist in reverse
6593 // order. This will force ranges for earlier operands to be computed before
6594 // their users in most cases.
6595 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6596 getRangeRef(P, SignHint);
6597
6598 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6599 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6600 PendingPhiRangesIter.erase(P);
6601 }
6602 }
6603
6604 return getRangeRef(S, SignHint, 0);
6605}
6606
6607/// Determine the range for a particular SCEV. If SignHint is
6608/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6609/// with a "cleaner" unsigned (resp. signed) representation.
6610const ConstantRange &ScalarEvolution::getRangeRef(
6611 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6612  DenseMap<const SCEV *, ConstantRange> &Cache =
6613      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6614                                                        : SignedRanges;
6615  ConstantRange::PreferredRangeType RangeType =
6616      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6617                                                        : ConstantRange::Signed;
6618
6619  // See if we've computed this range already.
6620  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6621  if (I != Cache.end())
6622 return I->second;
6623
6624 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6625 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6626
6627 // Switch to iteratively computing the range for S, if it is part of a deeply
6628 // nested expression.
6629  if (Depth > RangeIterThreshold)
6630    return getRangeRefIter(S, SignHint);
6631
6632 unsigned BitWidth = getTypeSizeInBits(S->getType());
6633 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6634 using OBO = OverflowingBinaryOperator;
6635
6636 // If the value has known zeros, the maximum value will have those known zeros
6637 // as well.
6638 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6639 APInt Multiple = getNonZeroConstantMultiple(S);
6640 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6641 if (!Remainder.isZero())
6642 ConservativeResult =
6643          ConstantRange(APInt::getMinValue(BitWidth),
6644                        APInt::getMaxValue(BitWidth) - Remainder + 1);
6645 }
6646 else {
6647    uint32_t TZ = getMinTrailingZeros(S);
6648    if (TZ != 0) {
6649      ConservativeResult = ConstantRange(
6650          APInt::getSignedMinValue(BitWidth),
6651          APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6652 }
6653 }
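// For illustration, the unsigned branch above: if S is an i8 known to always
// be a multiple of 4, then Remainder = 255 urem 4 = 3 and the conservative
// range becomes [0, 255 - 3 + 1) = [0, 253), whose largest element, 252, is
// the greatest multiple of 4 that fits in 8 bits.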
6654
6655 switch (S->getSCEVType()) {
6656 case scConstant:
6657 llvm_unreachable("Already handled above.");
6658 case scVScale:
6659 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6660 case scTruncate: {
6661 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6662 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6663 return setRange(
6664 Trunc, SignHint,
6665 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6666 }
6667 case scZeroExtend: {
6668 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6669 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6670 return setRange(
6671 ZExt, SignHint,
6672 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6673 }
6674 case scSignExtend: {
6675 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6676 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6677 return setRange(
6678 SExt, SignHint,
6679 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6680 }
6681 case scPtrToInt: {
6682 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6683 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6684 return setRange(PtrToInt, SignHint, X);
6685 }
6686 case scAddExpr: {
6687 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6688 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6689 unsigned WrapType = OBO::AnyWrap;
6690 if (Add->hasNoSignedWrap())
6691 WrapType |= OBO::NoSignedWrap;
6692 if (Add->hasNoUnsignedWrap())
6693 WrapType |= OBO::NoUnsignedWrap;
6694 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6695 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
6696 WrapType, RangeType);
6697 return setRange(Add, SignHint,
6698 ConservativeResult.intersectWith(X, RangeType));
6699 }
6700 case scMulExpr: {
6701 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6702 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6703 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6704 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
6705 return setRange(Mul, SignHint,
6706 ConservativeResult.intersectWith(X, RangeType));
6707 }
6708 case scUDivExpr: {
6709 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6710 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6711 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6712 return setRange(UDiv, SignHint,
6713 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6714 }
6715 case scAddRecExpr: {
6716 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6717 // If there's no unsigned wrap, the value will never be less than its
6718 // initial value.
6719 if (AddRec->hasNoUnsignedWrap()) {
6720 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6721 if (!UnsignedMinValue.isZero())
6722 ConservativeResult = ConservativeResult.intersectWith(
6723 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6724 }
6725
6726 // If there's no signed wrap, and all the operands except initial value have
6727 // the same sign or zero, the value won't ever be:
6728    // 1: smaller than the initial value if the operands are non-negative,
6729    // 2: bigger than the initial value if the operands are non-positive.
6730    // In both cases, the value cannot cross the signed min/max boundary.
6731 if (AddRec->hasNoSignedWrap()) {
6732 bool AllNonNeg = true;
6733 bool AllNonPos = true;
6734 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6735 if (!isKnownNonNegative(AddRec->getOperand(i)))
6736 AllNonNeg = false;
6737 if (!isKnownNonPositive(AddRec->getOperand(i)))
6738 AllNonPos = false;
6739 }
6740 if (AllNonNeg)
6741 ConservativeResult = ConservativeResult.intersectWith(
6742            ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6743                                       APInt::getSignedMinValue(BitWidth)),
6744            RangeType);
6745 else if (AllNonPos)
6746 ConservativeResult = ConservativeResult.intersectWith(
6747            ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6748                                       getSignedRangeMax(AddRec->getStart()) +
6749 1),
6750 RangeType);
6751 }
6752
6753 // TODO: non-affine addrec
6754 if (AddRec->isAffine()) {
6755 const SCEV *MaxBEScev =
6756          getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6757      if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6758 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6759
6760 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6761 // MaxBECount's active bits are all <= AddRec's bit width.
6762 if (MaxBECount.getBitWidth() > BitWidth &&
6763 MaxBECount.getActiveBits() <= BitWidth)
6764 MaxBECount = MaxBECount.trunc(BitWidth);
6765 else if (MaxBECount.getBitWidth() < BitWidth)
6766 MaxBECount = MaxBECount.zext(BitWidth);
6767
6768 if (MaxBECount.getBitWidth() == BitWidth) {
6769 auto RangeFromAffine = getRangeForAffineAR(
6770 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6771 ConservativeResult =
6772 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6773
6774 auto RangeFromFactoring = getRangeViaFactoring(
6775 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6776 ConservativeResult =
6777 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6778 }
6779 }
6780
6781 // Now try symbolic BE count and more powerful methods.
6782      if (UseExpensiveRangeSharpening) {
6783        const SCEV *SymbolicMaxBECount =
6784            getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6785        if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6786 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6787 AddRec->hasNoSelfWrap()) {
6788 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6789 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6790 ConservativeResult =
6791 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6792 }
6793 }
6794 }
6795
6796 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6797 }
6798 case scUMaxExpr:
6799 case scSMaxExpr:
6800 case scUMinExpr:
6801 case scSMinExpr:
6802 case scSequentialUMinExpr: {
6803    Intrinsic::ID ID;
6804    switch (S->getSCEVType()) {
6805 case scUMaxExpr:
6806 ID = Intrinsic::umax;
6807 break;
6808 case scSMaxExpr:
6809 ID = Intrinsic::smax;
6810 break;
6811 case scUMinExpr:
6812    case scSequentialUMinExpr:
6813      ID = Intrinsic::umin;
6814 break;
6815 case scSMinExpr:
6816 ID = Intrinsic::smin;
6817 break;
6818 default:
6819 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6820 }
6821
6822 const auto *NAry = cast<SCEVNAryExpr>(S);
6823 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6824 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6825 X = X.intrinsic(
6826 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6827 return setRange(S, SignHint,
6828 ConservativeResult.intersectWith(X, RangeType));
6829 }
6830 case scUnknown: {
6831 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6832 Value *V = U->getValue();
6833
6834 // Check if the IR explicitly contains !range metadata.
6835 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6836 if (MDRange)
6837 ConservativeResult =
6838 ConservativeResult.intersectWith(*MDRange, RangeType);
6839
6840 // Use facts about recurrences in the underlying IR. Note that add
6841 // recurrences are AddRecExprs and thus don't hit this path. This
6842 // primarily handles shift recurrences.
6843 auto CR = getRangeForUnknownRecurrence(U);
6844 ConservativeResult = ConservativeResult.intersectWith(CR);
6845
6846 // See if ValueTracking can give us a useful range.
6847 const DataLayout &DL = getDataLayout();
6848 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6849 if (Known.getBitWidth() != BitWidth)
6850 Known = Known.zextOrTrunc(BitWidth);
6851
6852 // ValueTracking may be able to compute a tighter result for the number of
6853 // sign bits than for the value of those sign bits.
6854 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6855 if (U->getType()->isPointerTy()) {
6856 // If the pointer size is larger than the index size type, this can cause
6857 // NS to be larger than BitWidth. So compensate for this.
6858 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6859 int ptrIdxDiff = ptrSize - BitWidth;
6860 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6861 NS -= ptrIdxDiff;
6862 }
6863
6864 if (NS > 1) {
6865 // If we know any of the sign bits, we know all of the sign bits.
6866 if (!Known.Zero.getHiBits(NS).isZero())
6867 Known.Zero.setHighBits(NS);
6868 if (!Known.One.getHiBits(NS).isZero())
6869 Known.One.setHighBits(NS);
6870 }
6871
6872 if (Known.getMinValue() != Known.getMaxValue() + 1)
6873 ConservativeResult = ConservativeResult.intersectWith(
6874 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6875 RangeType);
6876 if (NS > 1)
6877 ConservativeResult = ConservativeResult.intersectWith(
6878          ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6879                        APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6880 RangeType);
6881
6882 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6883 // Strengthen the range if the underlying IR value is a
6884 // global/alloca/heap allocation using the size of the object.
6885 ObjectSizeOpts Opts;
6886 Opts.RoundToAlign = false;
6887 Opts.NullIsUnknownSize = true;
6888 uint64_t ObjSize;
6889 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6890 isAllocationFn(V, &TLI)) &&
6891 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6892 // The highest address the object can start is ObjSize bytes before the
6893 // end (unsigned max value). If this value is not a multiple of the
6894 // alignment, the last possible start value is the next lowest multiple
6895 // of the alignment. Note: The computations below cannot overflow,
6896 // because if they would there's no possible start address for the
6897 // object.
6898 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6899 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6900 uint64_t Rem = MaxVal.urem(Align);
6901 MaxVal -= APInt(BitWidth, Rem);
6902 APInt MinVal = APInt::getZero(BitWidth);
6903 if (llvm::isKnownNonZero(V, DL))
6904 MinVal = Align;
6905 ConservativeResult = ConservativeResult.intersectWith(
6906 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6907 }
6908 }
6909
6910 // A range of Phi is a subset of union of all ranges of its input.
6911 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6912 // Make sure that we do not run over cycled Phis.
6913 if (PendingPhiRanges.insert(Phi).second) {
6914 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6915
6916 for (const auto &Op : Phi->operands()) {
6917 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6918 RangeFromOps = RangeFromOps.unionWith(OpRange);
6919 // No point to continue if we already have a full set.
6920 if (RangeFromOps.isFullSet())
6921 break;
6922 }
6923 ConservativeResult =
6924 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6925 bool Erased = PendingPhiRanges.erase(Phi);
6926 assert(Erased && "Failed to erase Phi properly?");
6927 (void)Erased;
6928 }
6929 }
6930
6931 // vscale can't be equal to zero
6932 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6933 if (II->getIntrinsicID() == Intrinsic::vscale) {
6934        ConstantRange Disallowed = APInt::getZero(BitWidth);
6935        ConservativeResult = ConservativeResult.difference(Disallowed);
6936 }
6937
6938 return setRange(U, SignHint, std::move(ConservativeResult));
6939 }
6940 case scCouldNotCompute:
6941 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6942 }
6943
6944 return setRange(S, SignHint, std::move(ConservativeResult));
6945}
6946
6947// Given a StartRange, Step and MaxBECount for an expression compute a range of
6948// values that the expression can take. Initially, the expression has a value
6949// from StartRange and then is changed by Step up to MaxBECount times. Signed
6950// argument defines if we treat Step as signed or unsigned.
6951 static ConstantRange getRangeForAffineARHelper(APInt Step,
6952                                                const ConstantRange &StartRange,
6953 const APInt &MaxBECount,
6954 bool Signed) {
6955 unsigned BitWidth = Step.getBitWidth();
6956 assert(BitWidth == StartRange.getBitWidth() &&
6957 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6958 // If either Step or MaxBECount is 0, then the expression won't change, and we
6959 // just need to return the initial range.
6960 if (Step == 0 || MaxBECount == 0)
6961 return StartRange;
6962
6963 // If we don't know anything about the initial value (i.e. StartRange is
6964 // FullRange), then we don't know anything about the final range either.
6965 // Return FullRange.
6966 if (StartRange.isFullSet())
6967 return ConstantRange::getFull(BitWidth);
6968
6969 // If Step is signed and negative, then we use its absolute value, but we also
6970 // note that we're moving in the opposite direction.
6971 bool Descending = Signed && Step.isNegative();
6972
6973 if (Signed)
6974 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6975 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6976    // This equation holds true due to the well-defined wrap-around behavior of
6977 // APInt.
6978 Step = Step.abs();
6979
6980  // Check if Offset is more than the full span of BitWidth. If it is, the
6981 // expression is guaranteed to overflow.
6982 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6983 return ConstantRange::getFull(BitWidth);
6984
6985 // Offset is by how much the expression can change. Checks above guarantee no
6986 // overflow here.
6987 APInt Offset = Step * MaxBECount;
6988
6989 // Minimum value of the final range will match the minimal value of StartRange
6990 // if the expression is increasing and will be decreased by Offset otherwise.
6991 // Maximum value of the final range will match the maximal value of StartRange
6992 // if the expression is decreasing and will be increased by Offset otherwise.
6993 APInt StartLower = StartRange.getLower();
6994 APInt StartUpper = StartRange.getUpper() - 1;
6995 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6996 : (StartUpper + std::move(Offset));
6997
6998 // It's possible that the new minimum/maximum value will fall into the initial
6999 // range (due to wrap around). This means that the expression can take any
7000 // value in this bitwidth, and we have to return full range.
7001 if (StartRange.contains(MovedBoundary))
7002 return ConstantRange::getFull(BitWidth);
7003
7004 APInt NewLower =
7005 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7006 APInt NewUpper =
7007 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7008 NewUpper += 1;
7009
7010 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7011 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7012}
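// Worked example (illustrative): for an i8 expression with Step = 3,
// StartRange = [0, 10), MaxBECount = 5 and Signed = false, neither early exit
// triggers, 255 udiv 3 = 85 >= 5 so there is no overflow, Offset = 3 * 5 = 15,
// the moved boundary 9 + 15 = 24 does not wrap back into [0, 10), and the
// function returns [0, 9 + 15 + 1) = [0, 25).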
7013
7014ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7015 const SCEV *Step,
7016 const APInt &MaxBECount) {
7017 assert(getTypeSizeInBits(Start->getType()) ==
7018 getTypeSizeInBits(Step->getType()) &&
7019 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7020 "mismatched bit widths");
7021
7022 // First, consider step signed.
7023 ConstantRange StartSRange = getSignedRange(Start);
7024 ConstantRange StepSRange = getSignedRange(Step);
7025
7026 // If Step can be both positive and negative, we need to find ranges for the
7027 // maximum absolute step values in both directions and union them.
7028  ConstantRange SR = getRangeForAffineARHelper(
7029      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7030  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7031                                              StartSRange, MaxBECount,
7032                                              /* Signed = */ true));
7033
7034  // Next, consider step unsigned.
7035  ConstantRange UR = getRangeForAffineARHelper(
7036      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7037      /* Signed = */ false);
7038
7039  // Finally, intersect signed and unsigned ranges.
7040  return SR.intersectWith(UR, ConstantRange::Smallest);
7041}
7042
7043ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7044 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7045 ScalarEvolution::RangeSignHint SignHint) {
7046  assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7047 assert(AddRec->hasNoSelfWrap() &&
7048 "This only works for non-self-wrapping AddRecs!");
7049 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7050 const SCEV *Step = AddRec->getStepRecurrence(*this);
7051 // Only deal with constant step to save compile time.
7052 if (!isa<SCEVConstant>(Step))
7053 return ConstantRange::getFull(BitWidth);
7054 // Let's make sure that we can prove that we do not self-wrap during
7055 // MaxBECount iterations. We need this because MaxBECount is a maximum
7056 // iteration count estimate, and we might infer nw from some exit for which we
7057 // do not know max exit count (or any other side reasoning).
7058 // TODO: Turn into assert at some point.
7059 if (getTypeSizeInBits(MaxBECount->getType()) >
7060 getTypeSizeInBits(AddRec->getType()))
7061 return ConstantRange::getFull(BitWidth);
7062 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7063 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7064 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7065 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7066 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7067 MaxItersWithoutWrap))
7068 return ConstantRange::getFull(BitWidth);
7069
7070 ICmpInst::Predicate LEPred =
7071      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7072  ICmpInst::Predicate GEPred =
7073      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7074  const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7075
7076 // We know that there is no self-wrap. Let's take Start and End values and
7077 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7078 // the iteration. They either lie inside the range [Min(Start, End),
7079 // Max(Start, End)] or outside it:
7080 //
7081 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7082 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7083 //
7084 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7085 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7086 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7087 // Start <= End and step is positive, or Start >= End and step is negative.
7088 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7089 ConstantRange StartRange = getRangeRef(Start, SignHint);
7090 ConstantRange EndRange = getRangeRef(End, SignHint);
7091 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7092 // If they already cover full iteration space, we will know nothing useful
7093 // even if we prove what we want to prove.
7094 if (RangeBetween.isFullSet())
7095 return RangeBetween;
7096 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7097 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7098 : RangeBetween.isWrappedSet();
7099 if (IsWrappedSet)
7100 return ConstantRange::getFull(BitWidth);
7101
7102 if (isKnownPositive(Step) &&
7103 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7104 return RangeBetween;
7105 if (isKnownNegative(Step) &&
7106 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7107 return RangeBetween;
7108 return ConstantRange::getFull(BitWidth);
7109}
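// Sketch of the reasoning above on a concrete shape (illustrative): for
// {0,+,2}<nw> with a symbolic max backedge-taken count %n, Start is 0 and End
// is 2 * %n. If constant ranges prove %n <= ((-1) /u 2) (so no self-wrap can
// occur within the trip count) and, the step being positive, 0 <= End is
// provable, we are in "Case 1" and the returned range is simply the union of
// the ranges of Start and End.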
7110
7111ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7112 const SCEV *Step,
7113 const APInt &MaxBECount) {
7114 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7115 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7116
7117 unsigned BitWidth = MaxBECount.getBitWidth();
7118 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7119 getTypeSizeInBits(Step->getType()) == BitWidth &&
7120 "mismatched bit widths");
7121
7122 struct SelectPattern {
7123 Value *Condition = nullptr;
7124 APInt TrueValue;
7125 APInt FalseValue;
7126
7127 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7128 const SCEV *S) {
7129 std::optional<unsigned> CastOp;
7130 APInt Offset(BitWidth, 0);
7131
7132      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7133             "Should be!");
7134
7135 // Peel off a constant offset:
7136 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7137 // In the future we could consider being smarter here and handle
7138 // {Start+Step,+,Step} too.
7139 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7140 return;
7141
7142 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7143 S = SA->getOperand(1);
7144 }
7145
7146 // Peel off a cast operation
7147 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7148 CastOp = SCast->getSCEVType();
7149 S = SCast->getOperand();
7150 }
7151
7152 using namespace llvm::PatternMatch;
7153
7154 auto *SU = dyn_cast<SCEVUnknown>(S);
7155 const APInt *TrueVal, *FalseVal;
7156 if (!SU ||
7157 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7158 m_APInt(FalseVal)))) {
7159 Condition = nullptr;
7160 return;
7161 }
7162
7163 TrueValue = *TrueVal;
7164 FalseValue = *FalseVal;
7165
7166 // Re-apply the cast we peeled off earlier
7167 if (CastOp)
7168 switch (*CastOp) {
7169 default:
7170 llvm_unreachable("Unknown SCEV cast type!");
7171
7172 case scTruncate:
7173 TrueValue = TrueValue.trunc(BitWidth);
7174 FalseValue = FalseValue.trunc(BitWidth);
7175 break;
7176 case scZeroExtend:
7177 TrueValue = TrueValue.zext(BitWidth);
7178 FalseValue = FalseValue.zext(BitWidth);
7179 break;
7180 case scSignExtend:
7181 TrueValue = TrueValue.sext(BitWidth);
7182 FalseValue = FalseValue.sext(BitWidth);
7183 break;
7184 }
7185
7186 // Re-apply the constant offset we peeled off earlier
7187 TrueValue += Offset;
7188 FalseValue += Offset;
7189 }
7190
7191 bool isRecognized() { return Condition != nullptr; }
7192 };
7193
7194 SelectPattern StartPattern(*this, BitWidth, Start);
7195 if (!StartPattern.isRecognized())
7196 return ConstantRange::getFull(BitWidth);
7197
7198 SelectPattern StepPattern(*this, BitWidth, Step);
7199 if (!StepPattern.isRecognized())
7200 return ConstantRange::getFull(BitWidth);
7201
7202 if (StartPattern.Condition != StepPattern.Condition) {
7203 // We don't handle this case today; but we could, by considering four
7204 // possibilities below instead of two. I'm not sure if there are cases where
7205 // that will help over what getRange already does, though.
7206 return ConstantRange::getFull(BitWidth);
7207 }
7208
7209 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7210 // construct arbitrary general SCEV expressions here. This function is called
7211 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7212 // say) can end up caching a suboptimal value.
7213
7214 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7215 // C2352 and C2512 (otherwise it isn't needed).
7216
7217 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7218 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7219 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7220 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7221
7222 ConstantRange TrueRange =
7223 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7224 ConstantRange FalseRange =
7225 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7226
7227 return TrueRange.unionWith(FalseRange);
7228}
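// Illustrative example of the identity in the comment at the top of this
// function: if Start is the IR value "select i1 %c, i32 0, i32 8" and Step is
// "select i1 %c, i32 1, i32 2" (same condition), then with MaxBECount = 4 the
// result is range({0,+,1}) unioned with range({8,+,2}), i.e. [0, 5) unioned
// with [8, 17), for a conservative hull of [0, 17).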
7229
7230SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7231 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7232 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7233
7234 // Return early if there are no flags to propagate to the SCEV.
7235  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7236  if (BinOp->hasNoUnsignedWrap())
7237    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7238  if (BinOp->hasNoSignedWrap())
7239    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7240  if (Flags == SCEV::FlagAnyWrap)
7241 return SCEV::FlagAnyWrap;
7242
7243 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7244}
7245
7246const Instruction *
7247ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7248 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7249 return &*AddRec->getLoop()->getHeader()->begin();
7250 if (auto *U = dyn_cast<SCEVUnknown>(S))
7251 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7252 return I;
7253 return nullptr;
7254}
7255
7256const Instruction *
7257ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7258 bool &Precise) {
7259 Precise = true;
7260 // Do a bounded search of the def relation of the requested SCEVs.
7261  SmallSet<const SCEV *, 16> Visited;
7262  SmallVector<const SCEV *> Worklist;
7263  auto pushOp = [&](const SCEV *S) {
7264 if (!Visited.insert(S).second)
7265 return;
7266 // Threshold of 30 here is arbitrary.
7267 if (Visited.size() > 30) {
7268 Precise = false;
7269 return;
7270 }
7271 Worklist.push_back(S);
7272 };
7273
7274 for (const auto *S : Ops)
7275 pushOp(S);
7276
7277 const Instruction *Bound = nullptr;
7278 while (!Worklist.empty()) {
7279 auto *S = Worklist.pop_back_val();
7280 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7281 if (!Bound || DT.dominates(Bound, DefI))
7282 Bound = DefI;
7283 } else {
7284 for (const auto *Op : S->operands())
7285 pushOp(Op);
7286 }
7287 }
7288 return Bound ? Bound : &*F.getEntryBlock().begin();
7289}
7290
7291const Instruction *
7292ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7293 bool Discard;
7294 return getDefiningScopeBound(Ops, Discard);
7295}
7296
7297bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7298 const Instruction *B) {
7299 if (A->getParent() == B->getParent() &&
7300      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7301                                                 B->getIterator()))
7302 return true;
7303
7304 auto *BLoop = LI.getLoopFor(B->getParent());
7305 if (BLoop && BLoop->getHeader() == B->getParent() &&
7306 BLoop->getLoopPreheader() == A->getParent() &&
7307      isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7308                                                 A->getParent()->end()) &&
7309 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7310 B->getIterator()))
7311 return true;
7312 return false;
7313}
7314
7315
7316bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7317 // Only proceed if we can prove that I does not yield poison.
7318  if (!programUndefinedIfPoison(I))
7319    return false;
7320
7321 // At this point we know that if I is executed, then it does not wrap
7322 // according to at least one of NSW or NUW. If I is not executed, then we do
7323 // not know if the calculation that I represents would wrap. Multiple
7324 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7325 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7326 // derived from other instructions that map to the same SCEV. We cannot make
7327  // that guarantee for cases where I is not executed. So we need to find an
7328 // upper bound on the defining scope for the SCEV, and prove that I is
7329 // executed every time we enter that scope. When the bounding scope is a
7330 // loop (the common case), this is equivalent to proving I executes on every
7331 // iteration of that loop.
7332  SmallVector<const SCEV *> SCEVOps;
7333  for (const Use &Op : I->operands()) {
7334 // I could be an extractvalue from a call to an overflow intrinsic.
7335 // TODO: We can do better here in some cases.
7336 if (isSCEVable(Op->getType()))
7337 SCEVOps.push_back(getSCEV(Op));
7338 }
7339 auto *DefI = getDefiningScopeBound(SCEVOps);
7340 return isGuaranteedToTransferExecutionTo(DefI, I);
7341}
7342
7343bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7344 // If we know that \c I can never be poison period, then that's enough.
7345 if (isSCEVExprNeverPoison(I))
7346 return true;
7347
7348 // If the loop only has one exit, then we know that, if the loop is entered,
7349 // any instruction dominating that exit will be executed. If any such
7350 // instruction would result in UB, the addrec cannot be poison.
7351 //
7352 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7353 // also handles uses outside the loop header (they just need to dominate the
7354 // single exit).
7355
7356 auto *ExitingBB = L->getExitingBlock();
7357 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7358 return false;
7359
7360  SmallPtrSet<const Value *, 16> KnownPoison;
7361  SmallVector<const Instruction *, 8> Worklist;
7362
7363 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7364 // things that are known to be poison under that assumption go on the
7365 // Worklist.
7366 KnownPoison.insert(I);
7367 Worklist.push_back(I);
7368
7369 while (!Worklist.empty()) {
7370 const Instruction *Poison = Worklist.pop_back_val();
7371
7372 for (const Use &U : Poison->uses()) {
7373 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7374 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7375 DT.dominates(PoisonUser->getParent(), ExitingBB))
7376 return true;
7377
7378 if (propagatesPoison(U) && L->contains(PoisonUser))
7379 if (KnownPoison.insert(PoisonUser).second)
7380 Worklist.push_back(PoisonUser);
7381 }
7382 }
7383
7384 return false;
7385}
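// Illustrative scenario for the walk above: if the post-increment value is
// used by a "getelementptr" whose result is unconditionally loaded in a block
// that dominates the single exit, then assuming the addrec is poison makes
// the GEP poison (propagatesPoison) and the load UB (mustTriggerUB), so the
// function returns true and the caller may keep the inferred no-wrap flags.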
7386
7387ScalarEvolution::LoopProperties
7388ScalarEvolution::getLoopProperties(const Loop *L) {
7389 using LoopProperties = ScalarEvolution::LoopProperties;
7390
7391 auto Itr = LoopPropertiesCache.find(L);
7392 if (Itr == LoopPropertiesCache.end()) {
7393 auto HasSideEffects = [](Instruction *I) {
7394 if (auto *SI = dyn_cast<StoreInst>(I))
7395 return !SI->isSimple();
7396
7397 return I->mayThrow() || I->mayWriteToMemory();
7398 };
7399
7400 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7401 /*HasNoSideEffects*/ true};
7402
7403 for (auto *BB : L->getBlocks())
7404 for (auto &I : *BB) {
7405        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7406          LP.HasNoAbnormalExits = false;
7407 if (HasSideEffects(&I))
7408 LP.HasNoSideEffects = false;
7409 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7410 break; // We're already as pessimistic as we can get.
7411 }
7412
7413 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7414 assert(InsertPair.second && "We just checked!");
7415 Itr = InsertPair.first;
7416 }
7417
7418 return Itr->second;
7419}
7420
7421 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7422  // A mustprogress loop without side effects must be finite.
7423 // TODO: The check used here is very conservative. It's only *specific*
7424 // side effects which are well defined in infinite loops.
7425 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7426}
7427
7428const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7429 // Worklist item with a Value and a bool indicating whether all operands have
7430 // been visited already.
7431  using PointerTy = PointerIntPair<Value *, 1, bool>;
7432  SmallVector<PointerTy> Stack;
7433
7434 Stack.emplace_back(V, true);
7435 Stack.emplace_back(V, false);
7436 while (!Stack.empty()) {
7437 auto E = Stack.pop_back_val();
7438 Value *CurV = E.getPointer();
7439
7440 if (getExistingSCEV(CurV))
7441 continue;
7442
7443    SmallVector<Value *> Ops;
7444    const SCEV *CreatedSCEV = nullptr;
7445 // If all operands have been visited already, create the SCEV.
7446 if (E.getInt()) {
7447 CreatedSCEV = createSCEV(CurV);
7448 } else {
7449      // Otherwise get the operands we need to create SCEVs for before creating
7450 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7451 // just use it.
7452 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7453 }
7454
7455 if (CreatedSCEV) {
7456 insertValueToMap(CurV, CreatedSCEV);
7457 } else {
7458      // Queue CurV for SCEV creation, followed by its operands, which need to
7459 // be constructed first.
7460 Stack.emplace_back(CurV, true);
7461 for (Value *Op : Ops)
7462 Stack.emplace_back(Op, false);
7463 }
7464 }
7465
7466 return getExistingSCEV(V);
7467}
7468
7469const SCEV *
7470ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7471 if (!isSCEVable(V->getType()))
7472 return getUnknown(V);
7473
7474 if (Instruction *I = dyn_cast<Instruction>(V)) {
7475 // Don't attempt to analyze instructions in blocks that aren't
7476 // reachable. Such instructions don't matter, and they aren't required
7477 // to obey basic rules for definitions dominating uses which this
7478 // analysis depends on.
7479 if (!DT.isReachableFromEntry(I->getParent()))
7480 return getUnknown(PoisonValue::get(V->getType()));
7481 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7482 return getConstant(CI);
7483 else if (isa<GlobalAlias>(V))
7484 return getUnknown(V);
7485 else if (!isa<ConstantExpr>(V))
7486 return getUnknown(V);
7487
7488 Operator *U = cast<Operator>(V);
7489 if (auto BO =
7490 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7491 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7492 switch (BO->Opcode) {
7493 case Instruction::Add:
7494 case Instruction::Mul: {
7495 // For additions and multiplications, traverse add/mul chains for which we
7496 // can potentially create a single SCEV, to reduce the number of
7497 // get{Add,Mul}Expr calls.
7498 do {
7499 if (BO->Op) {
7500 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7501 Ops.push_back(BO->Op);
7502 break;
7503 }
7504 }
7505 Ops.push_back(BO->RHS);
7506 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7507 dyn_cast<Instruction>(V));
7508 if (!NewBO ||
7509 (BO->Opcode == Instruction::Add &&
7510 (NewBO->Opcode != Instruction::Add &&
7511 NewBO->Opcode != Instruction::Sub)) ||
7512 (BO->Opcode == Instruction::Mul &&
7513 NewBO->Opcode != Instruction::Mul)) {
7514 Ops.push_back(BO->LHS);
7515 break;
7516 }
7517 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7518 // requires a SCEV for the LHS.
7519 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7520 auto *I = dyn_cast<Instruction>(BO->Op);
7521 if (I && programUndefinedIfPoison(I)) {
7522 Ops.push_back(BO->LHS);
7523 break;
7524 }
7525 }
7526 BO = NewBO;
7527 } while (true);
7528 return nullptr;
7529 }
7530 case Instruction::Sub:
7531 case Instruction::UDiv:
7532 case Instruction::URem:
7533 break;
7534 case Instruction::AShr:
7535 case Instruction::Shl:
7536 case Instruction::Xor:
7537 if (!IsConstArg)
7538 return nullptr;
7539 break;
7540 case Instruction::And:
7541 case Instruction::Or:
7542 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7543 return nullptr;
7544 break;
7545 case Instruction::LShr:
7546 return getUnknown(V);
7547 default:
7548 llvm_unreachable("Unhandled binop");
7549 break;
7550 }
7551
7552 Ops.push_back(BO->LHS);
7553 Ops.push_back(BO->RHS);
7554 return nullptr;
7555 }
7556
7557 switch (U->getOpcode()) {
7558 case Instruction::Trunc:
7559 case Instruction::ZExt:
7560 case Instruction::SExt:
7561 case Instruction::PtrToInt:
7562 Ops.push_back(U->getOperand(0));
7563 return nullptr;
7564
7565 case Instruction::BitCast:
7566 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7567 Ops.push_back(U->getOperand(0));
7568 return nullptr;
7569 }
7570 return getUnknown(V);
7571
7572 case Instruction::SDiv:
7573 case Instruction::SRem:
7574 Ops.push_back(U->getOperand(0));
7575 Ops.push_back(U->getOperand(1));
7576 return nullptr;
7577
7578 case Instruction::GetElementPtr:
7579 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7580 "GEP source element type must be sized");
7581 for (Value *Index : U->operands())
7582 Ops.push_back(Index);
7583 return nullptr;
7584
7585 case Instruction::IntToPtr:
7586 return getUnknown(V);
7587
7588 case Instruction::PHI:
7589    // Keep constructing SCEVs for phis recursively for now.
7590 return nullptr;
7591
7592 case Instruction::Select: {
7593 // Check if U is a select that can be simplified to a SCEVUnknown.
7594 auto CanSimplifyToUnknown = [this, U]() {
7595 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7596 return false;
7597
7598 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7599 if (!ICI)
7600 return false;
7601 Value *LHS = ICI->getOperand(0);
7602 Value *RHS = ICI->getOperand(1);
7603 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7604 ICI->getPredicate() == CmpInst::ICMP_NE) {
7605 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7606 return true;
7607 } else if (getTypeSizeInBits(LHS->getType()) >
7608 getTypeSizeInBits(U->getType()))
7609 return true;
7610 return false;
7611 };
7612 if (CanSimplifyToUnknown())
7613 return getUnknown(U);
7614
7615 for (Value *Inc : U->operands())
7616 Ops.push_back(Inc);
7617 return nullptr;
7618 break;
7619 }
7620 case Instruction::Call:
7621 case Instruction::Invoke:
7622 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7623 Ops.push_back(RV);
7624 return nullptr;
7625 }
7626
7627 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7628 switch (II->getIntrinsicID()) {
7629 case Intrinsic::abs:
7630 Ops.push_back(II->getArgOperand(0));
7631 return nullptr;
7632 case Intrinsic::umax:
7633 case Intrinsic::umin:
7634 case Intrinsic::smax:
7635 case Intrinsic::smin:
7636 case Intrinsic::usub_sat:
7637 case Intrinsic::uadd_sat:
7638 Ops.push_back(II->getArgOperand(0));
7639 Ops.push_back(II->getArgOperand(1));
7640 return nullptr;
7641 case Intrinsic::start_loop_iterations:
7642 case Intrinsic::annotation:
7643 case Intrinsic::ptr_annotation:
7644 Ops.push_back(II->getArgOperand(0));
7645 return nullptr;
7646 default:
7647 break;
7648 }
7649 }
7650 break;
7651 }
7652
7653 return nullptr;
7654}
7655
7656const SCEV *ScalarEvolution::createSCEV(Value *V) {
7657 if (!isSCEVable(V->getType()))
7658 return getUnknown(V);
7659
7660 if (Instruction *I = dyn_cast<Instruction>(V)) {
7661 // Don't attempt to analyze instructions in blocks that aren't
7662 // reachable. Such instructions don't matter, and they aren't required
7663 // to obey basic rules for definitions dominating uses which this
7664 // analysis depends on.
7665 if (!DT.isReachableFromEntry(I->getParent()))
7666 return getUnknown(PoisonValue::get(V->getType()));
7667 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7668 return getConstant(CI);
7669 else if (isa<GlobalAlias>(V))
7670 return getUnknown(V);
7671 else if (!isa<ConstantExpr>(V))
7672 return getUnknown(V);
7673
7674 const SCEV *LHS;
7675 const SCEV *RHS;
7676
7677 Operator *U = cast<Operator>(V);
7678 if (auto BO =
7679 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7680 switch (BO->Opcode) {
7681 case Instruction::Add: {
7682 // The simple thing to do would be to just call getSCEV on both operands
7683 // and call getAddExpr with the result. However if we're looking at a
7684 // bunch of things all added together, this can be quite inefficient,
7685 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7686 // Instead, gather up all the operands and make a single getAddExpr call.
7687 // LLVM IR canonical form means we need only traverse the left operands.
7688      SmallVector<const SCEV *, 4> AddOps;
7689      do {
7690 if (BO->Op) {
7691 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7692 AddOps.push_back(OpSCEV);
7693 break;
7694 }
7695
7696 // If a NUW or NSW flag can be applied to the SCEV for this
7697 // addition, then compute the SCEV for this addition by itself
7698 // with a separate call to getAddExpr. We need to do that
7699 // instead of pushing the operands of the addition onto AddOps,
7700 // since the flags are only known to apply to this particular
7701 // addition - they may not apply to other additions that can be
7702 // formed with operands from AddOps.
7703 const SCEV *RHS = getSCEV(BO->RHS);
7704 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7705 if (Flags != SCEV::FlagAnyWrap) {
7706 const SCEV *LHS = getSCEV(BO->LHS);
7707 if (BO->Opcode == Instruction::Sub)
7708 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7709 else
7710 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7711 break;
7712 }
7713 }
7714
7715 if (BO->Opcode == Instruction::Sub)
7716 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7717 else
7718 AddOps.push_back(getSCEV(BO->RHS));
7719
7720 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7721 dyn_cast<Instruction>(V));
7722 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7723 NewBO->Opcode != Instruction::Sub)) {
7724 AddOps.push_back(getSCEV(BO->LHS));
7725 break;
7726 }
7727 BO = NewBO;
7728 } while (true);
7729
7730 return getAddExpr(AddOps);
7731 }
7732
7733 case Instruction::Mul: {
7734      SmallVector<const SCEV *, 4> MulOps;
7735      do {
7736 if (BO->Op) {
7737 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7738 MulOps.push_back(OpSCEV);
7739 break;
7740 }
7741
7742 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7743 if (Flags != SCEV::FlagAnyWrap) {
7744 LHS = getSCEV(BO->LHS);
7745 RHS = getSCEV(BO->RHS);
7746 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7747 break;
7748 }
7749 }
7750
7751 MulOps.push_back(getSCEV(BO->RHS));
7752 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7753 dyn_cast<Instruction>(V));
7754 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7755 MulOps.push_back(getSCEV(BO->LHS));
7756 break;
7757 }
7758 BO = NewBO;
7759 } while (true);
7760
7761 return getMulExpr(MulOps);
7762 }
7763 case Instruction::UDiv:
7764 LHS = getSCEV(BO->LHS);
7765 RHS = getSCEV(BO->RHS);
7766 return getUDivExpr(LHS, RHS);
7767 case Instruction::URem:
7768 LHS = getSCEV(BO->LHS);
7769 RHS = getSCEV(BO->RHS);
7770 return getURemExpr(LHS, RHS);
7771 case Instruction::Sub: {
7772      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7773      if (BO->Op)
7774 Flags = getNoWrapFlagsFromUB(BO->Op);
7775 LHS = getSCEV(BO->LHS);
7776 RHS = getSCEV(BO->RHS);
7777 return getMinusSCEV(LHS, RHS, Flags);
7778 }
7779 case Instruction::And:
7780 // For an expression like x&255 that merely masks off the high bits,
7781 // use zext(trunc(x)) as the SCEV expression.
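      // For instance (illustrative): "and i32 %x, 255" has LZ = 24 and TZ = 0,
      // giving EffectiveMask 0xFF, and the expression becomes
      // (zext i8 (trunc i32 %x to i8) to i32); a mask like 0xFF0 additionally
      // divides by 16 before the trunc and multiplies the zext back by 16.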
7782 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7783 if (CI->isZero())
7784 return getSCEV(BO->RHS);
7785 if (CI->isMinusOne())
7786 return getSCEV(BO->LHS);
7787 const APInt &A = CI->getValue();
7788
7789 // Instcombine's ShrinkDemandedConstant may strip bits out of
7790 // constants, obscuring what would otherwise be a low-bits mask.
7791 // Use computeKnownBits to compute what ShrinkDemandedConstant
7792 // knew about to reconstruct a low-bits mask value.
7793 unsigned LZ = A.countl_zero();
7794 unsigned TZ = A.countr_zero();
7795 unsigned BitWidth = A.getBitWidth();
7796 KnownBits Known(BitWidth);
7797 computeKnownBits(BO->LHS, Known, getDataLayout(),
7798 0, &AC, nullptr, &DT);
7799
7800 APInt EffectiveMask =
7801 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7802 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7803 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7804 const SCEV *LHS = getSCEV(BO->LHS);
7805 const SCEV *ShiftedLHS = nullptr;
7806 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7807 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7808 // For an expression like (x * 8) & 8, simplify the multiply.
7809 unsigned MulZeros = OpC->getAPInt().countr_zero();
7810 unsigned GCD = std::min(MulZeros, TZ);
7811 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7812            SmallVector<const SCEV *, 4> MulOps;
7813            MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7814 append_range(MulOps, LHSMul->operands().drop_front());
7815 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7816 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7817 }
7818 }
7819 if (!ShiftedLHS)
7820 ShiftedLHS = getUDivExpr(LHS, MulCount);
7821 return getMulExpr(
7822            getZeroExtendExpr(
7823                getTruncateExpr(ShiftedLHS,
7824 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7825 BO->LHS->getType()),
7826 MulCount);
7827 }
7828 }
7829 // Binary `and` is a bit-wise `umin`.
7830 if (BO->LHS->getType()->isIntegerTy(1)) {
7831 LHS = getSCEV(BO->LHS);
7832 RHS = getSCEV(BO->RHS);
7833 return getUMinExpr(LHS, RHS);
7834 }
7835 break;
7836
7837 case Instruction::Or:
7838 // Binary `or` is a bit-wise `umax`.
7839 if (BO->LHS->getType()->isIntegerTy(1)) {
7840 LHS = getSCEV(BO->LHS);
7841 RHS = getSCEV(BO->RHS);
7842 return getUMaxExpr(LHS, RHS);
7843 }
7844 break;
7845
7846 case Instruction::Xor:
7847 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7848 // If the RHS of xor is -1, then this is a not operation.
7849 if (CI->isMinusOne())
7850 return getNotSCEV(getSCEV(BO->LHS));
7851
7852 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7853 // This is a variant of the check for xor with -1, and it handles
7854 // the case where instcombine has trimmed non-demanded bits out
7855 // of an xor with -1.
7856 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7857 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7858 if (LBO->getOpcode() == Instruction::And &&
7859 LCI->getValue() == CI->getValue())
7860 if (const SCEVZeroExtendExpr *Z =
7861 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7862 Type *UTy = BO->LHS->getType();
7863 const SCEV *Z0 = Z->getOperand();
7864 Type *Z0Ty = Z0->getType();
7865 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7866
7867 // If C is a low-bits mask, the zero extend is serving to
7868 // mask off the high bits. Complement the operand and
7869 // re-apply the zext.
7870 if (CI->getValue().isMask(Z0TySize))
7871 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7872
7873 // If C is a single bit, it may be in the sign-bit position
7874 // before the zero-extend. In this case, represent the xor
7875 // using an add, which is equivalent, and re-apply the zext.
7876 APInt Trunc = CI->getValue().trunc(Z0TySize);
7877 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7878 Trunc.isSignMask())
7879 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7880 UTy);
7881 }
7882 }
7883 break;
7884
7885 case Instruction::Shl:
7886 // Turn shift left of a constant amount into a multiply.
7887 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7888 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7889
7890 // If the shift count is not less than the bitwidth, the result of
7891 // the shift is undefined. Don't try to analyze it, because the
7892 // resolution chosen here may differ from the resolution chosen in
7893 // other parts of the compiler.
7894 if (SA->getValue().uge(BitWidth))
7895 break;
7896
7897 // We can safely preserve the nuw flag in all cases. It's also safe to
7898 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7899 // requires special handling. It can be preserved as long as we're not
7900 // left shifting by bitwidth - 1.
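      // For instance, in i8 "shl nsw i8 -1, 7" is a well-defined -128, but the
      // equivalent multiply "mul nsw i8 -1, -128" (1 << 7 wraps to -128) would
      // signed-overflow to +128, so nsw is only carried over when the shift
      // amount is below bitwidth - 1 (or when nuw holds as well).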
7901 auto Flags = SCEV::FlagAnyWrap;
7902 if (BO->Op) {
7903 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7904 if ((MulFlags & SCEV::FlagNSW) &&
7905               ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7906             Flags = setFlags(Flags, SCEV::FlagNSW);
7907           if (MulFlags & SCEV::FlagNUW)
7908             Flags = setFlags(Flags, SCEV::FlagNUW);
7909         }
7910
7911 ConstantInt *X = ConstantInt::get(
7912 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7913 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7914 }
7915 break;
7916
7917 case Instruction::AShr:
7918 // AShr X, C, where C is a constant.
7919 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7920 if (!CI)
7921 break;
7922
7923       Type *OuterTy = BO->LHS->getType();
7924       unsigned BitWidth = getTypeSizeInBits(OuterTy);
7925 // If the shift count is not less than the bitwidth, the result of
7926 // the shift is undefined. Don't try to analyze it, because the
7927 // resolution chosen here may differ from the resolution chosen in
7928 // other parts of the compiler.
7929 if (CI->getValue().uge(BitWidth))
7930 break;
7931
7932 if (CI->isZero())
7933 return getSCEV(BO->LHS); // shift by zero --> noop
7934
7935 uint64_t AShrAmt = CI->getZExtValue();
7936 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7937
7938 Operator *L = dyn_cast<Operator>(BO->LHS);
7939 const SCEV *AddTruncateExpr = nullptr;
7940 ConstantInt *ShlAmtCI = nullptr;
7941 const SCEV *AddConstant = nullptr;
7942
7943 if (L && L->getOpcode() == Instruction::Add) {
7944 // X = Shl A, n
7945 // Y = Add X, c
7946 // Z = AShr Y, m
7947 // n, c and m are constants.
7948
7949 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7950 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7951 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7952 if (AddOperandCI) {
7953 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7954 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7955         // Since we truncate to TruncTy, the AddConstant must be of the
7956         // same type, so create a new constant with the same type as TruncTy.
7957         // Also, the add constant must be shifted right by the AShr amount.
7958 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7959 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7960         // We model the expression as sext(add(mul(trunc(A), 2^(n-m)), c >> m)).
7961         // Since the sext(trunc) part is already handled below, we only create
7962         // the truncate expression (AddTruncateExpr) here and use it later.
7963 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7964 }
7965 }
7966 } else if (L && L->getOpcode() == Instruction::Shl) {
7967 // X = Shl A, n
7968 // Y = AShr X, m
7969 // Both n and m are constant.
7970
7971 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7972 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7973 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7974 }
7975
7976 if (AddTruncateExpr && ShlAmtCI) {
7977 // We can merge the two given cases into a single SCEV statement,
7978     // in case n = m, the mul expression will be 2^0, so it gets resolved to
7979 // a simpler case. The following code handles the two cases:
7980 //
7981 // 1) For a two-shift sext-inreg, i.e. n = m,
7982 // use sext(trunc(x)) as the SCEV expression.
7983 //
7984     // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
7985     //    expression. We already checked that ShlAmt < BitWidth, so
7986     //    the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7987     //    ShlAmt - AShrAmt < BitWidth - AShrAmt.
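    // A concrete instance of case 2), in i32 with n = 5 and m = 3:
    //   ((A << 5) ashr 3)  becomes  sext(mul(trunc(A, i29), 4), i32),
    // since TruncTy is i29 and the multiplier is 1 << (5 - 3) = 4.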
7988 const APInt &ShlAmt = ShlAmtCI->getValue();
7989       if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7990         APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7991 ShlAmtCI->getZExtValue() - AShrAmt);
7992 const SCEV *CompositeExpr =
7993 getMulExpr(AddTruncateExpr, getConstant(Mul));
7994 if (L->getOpcode() != Instruction::Shl)
7995 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
7996
7997 return getSignExtendExpr(CompositeExpr, OuterTy);
7998 }
7999 }
8000 break;
8001 }
8002 }
8003
8004 switch (U->getOpcode()) {
8005 case Instruction::Trunc:
8006 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8007
8008 case Instruction::ZExt:
8009 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8010
8011 case Instruction::SExt:
8012 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8013 dyn_cast<Instruction>(V))) {
8014 // The NSW flag of a subtract does not always survive the conversion to
8015 // A + (-1)*B. By pushing sign extension onto its operands we are much
8016 // more likely to preserve NSW and allow later AddRec optimisations.
8017 //
8018 // NOTE: This is effectively duplicating this logic from getSignExtend:
8019 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8020 // but by that point the NSW information has potentially been lost.
8021 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8022 Type *Ty = U->getType();
8023 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8024 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8025 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8026 }
8027 }
8028 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8029
8030 case Instruction::BitCast:
8031 // BitCasts are no-op casts so we just eliminate the cast.
8032 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8033 return getSCEV(U->getOperand(0));
8034 break;
8035
8036 case Instruction::PtrToInt: {
8037     // Pointer to integer cast is straightforward, so do model it.
8038 const SCEV *Op = getSCEV(U->getOperand(0));
8039 Type *DstIntTy = U->getType();
8040 // But only if effective SCEV (integer) type is wide enough to represent
8041 // all possible pointer values.
8042 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8043 if (isa<SCEVCouldNotCompute>(IntOp))
8044 return getUnknown(V);
8045 return IntOp;
8046 }
8047 case Instruction::IntToPtr:
8048 // Just don't deal with inttoptr casts.
8049 return getUnknown(V);
8050
8051 case Instruction::SDiv:
8052 // If both operands are non-negative, this is just an udiv.
8053 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8054 isKnownNonNegative(getSCEV(U->getOperand(1))))
8055 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8056 break;
8057
8058 case Instruction::SRem:
8059 // If both operands are non-negative, this is just an urem.
8060 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8061 isKnownNonNegative(getSCEV(U->getOperand(1))))
8062 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8063 break;
8064
8065 case Instruction::GetElementPtr:
8066 return createNodeForGEP(cast<GEPOperator>(U));
8067
8068 case Instruction::PHI:
8069 return createNodeForPHI(cast<PHINode>(U));
8070
8071 case Instruction::Select:
8072 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8073 U->getOperand(2));
8074
8075 case Instruction::Call:
8076 case Instruction::Invoke:
8077 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8078 return getSCEV(RV);
8079
8080 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8081 switch (II->getIntrinsicID()) {
8082 case Intrinsic::abs:
8083 return getAbsExpr(
8084 getSCEV(II->getArgOperand(0)),
8085 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8086 case Intrinsic::umax:
8087 LHS = getSCEV(II->getArgOperand(0));
8088 RHS = getSCEV(II->getArgOperand(1));
8089 return getUMaxExpr(LHS, RHS);
8090 case Intrinsic::umin:
8091 LHS = getSCEV(II->getArgOperand(0));
8092 RHS = getSCEV(II->getArgOperand(1));
8093 return getUMinExpr(LHS, RHS);
8094 case Intrinsic::smax:
8095 LHS = getSCEV(II->getArgOperand(0));
8096 RHS = getSCEV(II->getArgOperand(1));
8097 return getSMaxExpr(LHS, RHS);
8098 case Intrinsic::smin:
8099 LHS = getSCEV(II->getArgOperand(0));
8100 RHS = getSCEV(II->getArgOperand(1));
8101 return getSMinExpr(LHS, RHS);
8102 case Intrinsic::usub_sat: {
8103 const SCEV *X = getSCEV(II->getArgOperand(0));
8104 const SCEV *Y = getSCEV(II->getArgOperand(1));
8105 const SCEV *ClampedY = getUMinExpr(X, Y);
8106 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8107 }
8108 case Intrinsic::uadd_sat: {
8109 const SCEV *X = getSCEV(II->getArgOperand(0));
8110 const SCEV *Y = getSCEV(II->getArgOperand(1));
8111 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8112 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8113 }
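      // To illustrate the two saturating cases above, in i8: usub.sat(5, 7)
      // is modeled as 5 - umin(5, 7) = 0, and uadd.sat(250, 10) as
      // umin(250, ~10 = 245) + 10 = 255; clamping X to ~Y first is what makes
      // the final add provably nuw.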
8114 case Intrinsic::start_loop_iterations:
8115 case Intrinsic::annotation:
8116 case Intrinsic::ptr_annotation:
8117       // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8118       // just equivalent to the first operand for SCEV purposes.
8119 return getSCEV(II->getArgOperand(0));
8120 case Intrinsic::vscale:
8121 return getVScale(II->getType());
8122 default:
8123 break;
8124 }
8125 }
8126 break;
8127 }
8128
8129 return getUnknown(V);
8130}
8131
8132//===----------------------------------------------------------------------===//
8133// Iteration Count Computation Code
8134//
8135
8137 if (isa<SCEVCouldNotCompute>(ExitCount))
8138 return getCouldNotCompute();
8139
8140 auto *ExitCountType = ExitCount->getType();
8141 assert(ExitCountType->isIntegerTy());
8142 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8143 1 + ExitCountType->getScalarSizeInBits());
8144 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8145}
8146
8147 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8148                                                       Type *EvalTy,
8149 const Loop *L) {
8150 if (isa<SCEVCouldNotCompute>(ExitCount))
8151 return getCouldNotCompute();
8152
8153 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8154 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8155
8156 auto CanAddOneWithoutOverflow = [&]() {
8157 ConstantRange ExitCountRange =
8158 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8159 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8160 return true;
8161
8162 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8163 getMinusOne(ExitCount->getType()));
8164 };
8165
8166 // If we need to zero extend the backedge count, check if we can add one to
8167 // it prior to zero extending without overflow. Provided this is safe, it
8168 // allows better simplification of the +1.
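  // For instance, with an i8 exit count and an i9 EvalTy: if the count can be
  // 255, zext(255 + 1) would wrap to 0 before the extension, while the
  // fallback below, zext(255) + 1 == 256, is the intended trip count. The
  // add-one-before-zext form is therefore only used when the count provably
  // cannot be the all-ones value.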
8169 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8170 return getZeroExtendExpr(
8171 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8172
8173 // Get the total trip count from the count by adding 1. This may wrap.
8174 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8175}
8176
8177static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8178 if (!ExitCount)
8179 return 0;
8180
8181 ConstantInt *ExitConst = ExitCount->getValue();
8182
8183 // Guard against huge trip counts.
8184 if (ExitConst->getValue().getActiveBits() > 32)
8185 return 0;
8186
8187 // In case of integer overflow, this returns 0, which is correct.
8188 return ((unsigned)ExitConst->getZExtValue()) + 1;
8189}
8190
8191 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8192   auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8193 return getConstantTripCount(ExitCount);
8194}
8195
8196unsigned
8197 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8198                                            const BasicBlock *ExitingBlock) {
8199 assert(ExitingBlock && "Must pass a non-null exiting block!");
8200 assert(L->isLoopExiting(ExitingBlock) &&
8201 "Exiting block must actually branch out of the loop!");
8202 const SCEVConstant *ExitCount =
8203 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8204 return getConstantTripCount(ExitCount);
8205}
8206
8207 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8208   const auto *MaxExitCount =
8209 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8210 return getConstantTripCount(MaxExitCount);
8211}
8212
8213 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8214   SmallVector<BasicBlock *, 8> ExitingBlocks;
8215 L->getExitingBlocks(ExitingBlocks);
8216
8217 std::optional<unsigned> Res;
8218 for (auto *ExitingBB : ExitingBlocks) {
8219 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8220 if (!Res)
8221 Res = Multiple;
8222 Res = (unsigned)std::gcd(*Res, Multiple);
8223 }
8224 return Res.value_or(1);
8225}
8226
8227 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8228                                                        const SCEV *ExitCount) {
8229 if (ExitCount == getCouldNotCompute())
8230 return 1;
8231
8232 // Get the trip count
8233 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8234
8235 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8236 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8237 // the greatest power of 2 divisor less than 2^32.
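  // For instance, if TCExpr is known to be a multiple of 2^40, the multiple's
  // active-bit count exceeds 32, and we report 2^31 (its largest power-of-2
  // divisor below 2^32) instead of truncating to a meaningless value.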
8238 return Multiple.getActiveBits() > 32
8239 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8240 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8241}
8242
8243/// Returns the largest constant divisor of the trip count of this loop as a
8244/// normal unsigned value, if possible. This means that the actual trip count is
8245/// always a multiple of the returned value (don't forget the trip count could
8246/// very well be zero as well!).
8247///
8248 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8249 /// multiple of a constant (which is also the case if the trip count is simply
8250 /// a constant; use getSmallConstantTripCount for that case). Will also return
8251 /// 1 if the trip count is very large (>= 2^32).
8252///
8253/// As explained in the comments for getSmallConstantTripCount, this assumes
8254/// that control exits the loop via ExitingBlock.
8255unsigned
8257 const BasicBlock *ExitingBlock) {
8258 assert(ExitingBlock && "Must pass a non-null exiting block!");
8259 assert(L->isLoopExiting(ExitingBlock) &&
8260 "Exiting block must actually branch out of the loop!");
8261 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8262 return getSmallConstantTripMultiple(L, ExitCount);
8263}
8264
8265 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8266                                           const BasicBlock *ExitingBlock,
8267 ExitCountKind Kind) {
8268 switch (Kind) {
8269 case Exact:
8270 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8271 case SymbolicMaximum:
8272 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8273 case ConstantMaximum:
8274 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8275 };
8276 llvm_unreachable("Invalid ExitCountKind!");
8277}
8278
8279const SCEV *
8282 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8283}
8284
8285 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8286                                                    ExitCountKind Kind) {
8287 switch (Kind) {
8288 case Exact:
8289 return getBackedgeTakenInfo(L).getExact(L, this);
8290 case ConstantMaximum:
8291 return getBackedgeTakenInfo(L).getConstantMax(this);
8292 case SymbolicMaximum:
8293 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8294 };
8295 llvm_unreachable("Invalid ExitCountKind!");
8296}
8297
8298 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8299   return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8300}
8301
8302/// Push PHI nodes in the header of the given loop onto the given Worklist.
8303static void PushLoopPHIs(const Loop *L,
8304                          SmallVectorImpl<Instruction *> &Worklist,
8305                          SmallPtrSetImpl<Instruction *> &Visited) {
8306   BasicBlock *Header = L->getHeader();
8307
8308 // Push all Loop-header PHIs onto the Worklist stack.
8309 for (PHINode &PN : Header->phis())
8310 if (Visited.insert(&PN).second)
8311 Worklist.push_back(&PN);
8312}
8313
8314const ScalarEvolution::BackedgeTakenInfo &
8315ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8316 auto &BTI = getBackedgeTakenInfo(L);
8317 if (BTI.hasFullInfo())
8318 return BTI;
8319
8320 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8321
8322 if (!Pair.second)
8323 return Pair.first->second;
8324
8325 BackedgeTakenInfo Result =
8326 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8327
8328 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8329}
8330
8331ScalarEvolution::BackedgeTakenInfo &
8332ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8333 // Initially insert an invalid entry for this loop. If the insertion
8334 // succeeds, proceed to actually compute a backedge-taken count and
8335 // update the value. The temporary CouldNotCompute value tells SCEV
8336 // code elsewhere that it shouldn't attempt to request a new
8337 // backedge-taken count, which could result in infinite recursion.
8338 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8339 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8340 if (!Pair.second)
8341 return Pair.first->second;
8342
8343 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8344 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8345 // must be cleared in this scope.
8346 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8347
8348 // Now that we know more about the trip count for this loop, forget any
8349 // existing SCEV values for PHI nodes in this loop since they are only
8350 // conservative estimates made without the benefit of trip count
8351 // information. This invalidation is not necessary for correctness, and is
8352 // only done to produce more precise results.
8353 if (Result.hasAnyInfo()) {
8354 // Invalidate any expression using an addrec in this loop.
8355     SmallVector<const SCEV *, 8> ToForget;
8356     auto LoopUsersIt = LoopUsers.find(L);
8357 if (LoopUsersIt != LoopUsers.end())
8358 append_range(ToForget, LoopUsersIt->second);
8359 forgetMemoizedResults(ToForget);
8360
8361 // Invalidate constant-evolved loop header phis.
8362 for (PHINode &PN : L->getHeader()->phis())
8363 ConstantEvolutionLoopExitValue.erase(&PN);
8364 }
8365
8366 // Re-lookup the insert position, since the call to
8367 // computeBackedgeTakenCount above could result in a
8368   // recursive call to getBackedgeTakenInfo (on a different
8369 // loop), which would invalidate the iterator computed
8370 // earlier.
8371 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8372}
8373
8375 // This method is intended to forget all info about loops. It should
8376 // invalidate caches as if the following happened:
8377 // - The trip counts of all loops have changed arbitrarily
8378 // - Every llvm::Value has been updated in place to produce a different
8379 // result.
8380 BackedgeTakenCounts.clear();
8381 PredicatedBackedgeTakenCounts.clear();
8382 BECountUsers.clear();
8383 LoopPropertiesCache.clear();
8384 ConstantEvolutionLoopExitValue.clear();
8385 ValueExprMap.clear();
8386 ValuesAtScopes.clear();
8387 ValuesAtScopesUsers.clear();
8388 LoopDispositions.clear();
8389 BlockDispositions.clear();
8390 UnsignedRanges.clear();
8391 SignedRanges.clear();
8392 ExprValueMap.clear();
8393 HasRecMap.clear();
8394 ConstantMultipleCache.clear();
8395 PredicatedSCEVRewrites.clear();
8396 FoldCache.clear();
8397 FoldCacheUser.clear();
8398}
8399void ScalarEvolution::visitAndClearUsers(
8400     SmallVectorImpl<Instruction *> &Worklist,
8401     SmallPtrSetImpl<Instruction *> &Visited,
8402     SmallVectorImpl<const SCEV *> &ToForget) {
8403   while (!Worklist.empty()) {
8404 Instruction *I = Worklist.pop_back_val();
8405 if (!isSCEVable(I->getType()))
8406 continue;
8407
8409 ValueExprMap.find_as(static_cast<Value *>(I));
8410 if (It != ValueExprMap.end()) {
8411 eraseValueFromMap(It->first);
8412 ToForget.push_back(It->second);
8413 if (PHINode *PN = dyn_cast<PHINode>(I))
8414 ConstantEvolutionLoopExitValue.erase(PN);
8415 }
8416
8417 PushDefUseChildren(I, Worklist, Visited);
8418 }
8419}
8420
8421 void ScalarEvolution::forgetLoop(const Loop *L) {
8422   SmallVector<const Loop *, 16> LoopWorklist(1, L);
8423   SmallVector<Instruction *, 32> Worklist;
8424   SmallPtrSet<Instruction *, 16> Visited;
8425   SmallVector<const SCEV *, 16> ToForget;
8426 
8427 // Iterate over all the loops and sub-loops to drop SCEV information.
8428 while (!LoopWorklist.empty()) {
8429 auto *CurrL = LoopWorklist.pop_back_val();
8430
8431 // Drop any stored trip count value.
8432 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8433 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8434
8435 // Drop information about predicated SCEV rewrites for this loop.
8436 for (auto I = PredicatedSCEVRewrites.begin();
8437 I != PredicatedSCEVRewrites.end();) {
8438 std::pair<const SCEV *, const Loop *> Entry = I->first;
8439 if (Entry.second == CurrL)
8440 PredicatedSCEVRewrites.erase(I++);
8441 else
8442 ++I;
8443 }
8444
8445 auto LoopUsersItr = LoopUsers.find(CurrL);
8446 if (LoopUsersItr != LoopUsers.end()) {
8447 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8448 LoopUsersItr->second.end());
8449 }
8450
8451 // Drop information about expressions based on loop-header PHIs.
8452 PushLoopPHIs(CurrL, Worklist, Visited);
8453 visitAndClearUsers(Worklist, Visited, ToForget);
8454
8455 LoopPropertiesCache.erase(CurrL);
8456 // Forget all contained loops too, to avoid dangling entries in the
8457 // ValuesAtScopes map.
8458 LoopWorklist.append(CurrL->begin(), CurrL->end());
8459 }
8460 forgetMemoizedResults(ToForget);
8461}
8462
8463 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8464   forgetLoop(L->getOutermostLoop());
8465}
8466
8467 void ScalarEvolution::forgetValue(Value *V) {
8468   Instruction *I = dyn_cast<Instruction>(V);
8469 if (!I) return;
8470
8471 // Drop information about expressions based on loop-header PHIs.
8475 Worklist.push_back(I);
8476 Visited.insert(I);
8477 visitAndClearUsers(Worklist, Visited, ToForget);
8478
8479 forgetMemoizedResults(ToForget);
8480}
8481
8483 if (!isSCEVable(V->getType()))
8484 return;
8485
8486 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8487 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8488 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8489 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8490 if (const SCEV *S = getExistingSCEV(V)) {
8491 struct InvalidationRootCollector {
8492 Loop *L;
8494
8495 InvalidationRootCollector(Loop *L) : L(L) {}
8496
8497 bool follow(const SCEV *S) {
8498 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8499 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8500 if (L->contains(I))
8501 Roots.push_back(S);
8502 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8503 if (L->contains(AddRec->getLoop()))
8504 Roots.push_back(S);
8505 }
8506 return true;
8507 }
8508 bool isDone() const { return false; }
8509 };
8510
8511 InvalidationRootCollector C(L);
8512 visitAll(S, C);
8513 forgetMemoizedResults(C.Roots);
8514 }
8515
8516 // Also perform the normal invalidation.
8517 forgetValue(V);
8518}
8519
8520void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8521
8522 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8523   // Unless a specific value is passed to invalidation, completely clear both
8524 // caches.
8525 if (!V) {
8526 BlockDispositions.clear();
8527 LoopDispositions.clear();
8528 return;
8529 }
8530
8531 if (!isSCEVable(V->getType()))
8532 return;
8533
8534 const SCEV *S = getExistingSCEV(V);
8535 if (!S)
8536 return;
8537
8538 // Invalidate the block and loop dispositions cached for S. Dispositions of
8539 // S's users may change if S's disposition changes (i.e. a user may change to
8540 // loop-invariant, if S changes to loop invariant), so also invalidate
8541 // dispositions of S's users recursively.
8542   SmallVector<const SCEV *, 8> Worklist = {S};
8543   SmallPtrSet<const SCEV *, 8> Seen = {S};
8544 while (!Worklist.empty()) {
8545 const SCEV *Curr = Worklist.pop_back_val();
8546 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8547 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8548 if (!LoopDispoRemoved && !BlockDispoRemoved)
8549 continue;
8550 auto Users = SCEVUsers.find(Curr);
8551 if (Users != SCEVUsers.end())
8552 for (const auto *User : Users->second)
8553 if (Seen.insert(User).second)
8554 Worklist.push_back(User);
8555 }
8556}
8557
8558/// Get the exact loop backedge taken count considering all loop exits. A
8559/// computable result can only be returned for loops with all exiting blocks
8560/// dominating the latch. howFarToZero assumes that the limit of each loop test
8561/// is never skipped. This is a valid assumption as long as the loop exits via
8562/// that test. For precise results, it is the caller's responsibility to specify
8563/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8564const SCEV *
8565ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8567 // If any exits were not computable, the loop is not computable.
8568 if (!isComplete() || ExitNotTaken.empty())
8569 return SE->getCouldNotCompute();
8570
8571 const BasicBlock *Latch = L->getLoopLatch();
8572 // All exiting blocks we have collected must dominate the only backedge.
8573 if (!Latch)
8574 return SE->getCouldNotCompute();
8575
8576   // All exiting blocks we have gathered dominate the loop's latch, so the
8577   // exact trip count is simply a minimum out of all these calculated exit counts.
8579 for (const auto &ENT : ExitNotTaken) {
8580 const SCEV *BECount = ENT.ExactNotTaken;
8581 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8582 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8583 "We should only have known counts for exiting blocks that dominate "
8584 "latch!");
8585
8586 Ops.push_back(BECount);
8587
8588 if (Preds)
8589 for (const auto *P : ENT.Predicates)
8590 Preds->push_back(P);
8591
8592 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8593 "Predicate should be always true!");
8594 }
8595
8596 // If an earlier exit exits on the first iteration (exit count zero), then
8597   // a later poison exit count should not propagate into the result. These are
8598 // exactly the semantics provided by umin_seq.
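  // For instance, if the first exit is taken on iteration zero while a later
  // exit's count is poison, umin_seq(0, poison) still folds to 0, whereas a
  // plain umin would let the poison operand contaminate the result.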
8599 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8600}
8601
8602/// Get the exact not taken count for this loop exit.
8603const SCEV *
8604ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8605 ScalarEvolution *SE) const {
8606 for (const auto &ENT : ExitNotTaken)
8607 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8608 return ENT.ExactNotTaken;
8609
8610 return SE->getCouldNotCompute();
8611}
8612
8613const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8614 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8615 for (const auto &ENT : ExitNotTaken)
8616 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8617 return ENT.ConstantMaxNotTaken;
8618
8619 return SE->getCouldNotCompute();
8620}
8621
8622const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8623 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8624 for (const auto &ENT : ExitNotTaken)
8625 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8626 return ENT.SymbolicMaxNotTaken;
8627
8628 return SE->getCouldNotCompute();
8629}
8630
8631/// getConstantMax - Get the constant max backedge taken count for the loop.
8632const SCEV *
8633ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8634 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8635 return !ENT.hasAlwaysTruePredicate();
8636 };
8637
8638 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8639 return SE->getCouldNotCompute();
8640
8641 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8642 isa<SCEVConstant>(getConstantMax())) &&
8643 "No point in having a non-constant max backedge taken count!");
8644 return getConstantMax();
8645}
8646
8647const SCEV *
8648ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8649 ScalarEvolution *SE) {
8650 if (!SymbolicMax)
8651 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8652 return SymbolicMax;
8653}
8654
8655bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8656 ScalarEvolution *SE) const {
8657 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8658 return !ENT.hasAlwaysTruePredicate();
8659 };
8660 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8661}
8662
8664 : ExitLimit(E, E, E, false, std::nullopt) {}
8665
8667 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8668 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8670 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8671 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8672 // If we prove the max count is zero, so is the symbolic bound. This happens
8673 // in practice due to differences in a) how context sensitive we've chosen
8674 // to be and b) how we reason about bounds implied by UB.
8675   if (ConstantMaxNotTaken->isZero()) {
8676     this->ExactNotTaken = E = ConstantMaxNotTaken;
8677 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8678 }
8679
8680 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8681 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8682 "Exact is not allowed to be less precise than Constant Max");
8683 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8684 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8685 "Exact is not allowed to be less precise than Symbolic Max");
8686 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8687 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8688 "Symbolic Max is not allowed to be less precise than Constant Max");
8689 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8690 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8691 "No point in having a non-constant max backedge taken count!");
8692 for (const auto *PredSet : PredSetList)
8693 for (const auto *P : *PredSet)
8694 addPredicate(P);
8695 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8696 "Backedge count should be int");
8697   assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8698           !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8699 "Max backedge count should be int");
8700}
8701
8703 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8704 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8706 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8707 { &PredSet }) {}
8708
8709/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8710/// computable exit into a persistent ExitNotTakenInfo array.
8711ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8713 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8714 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8715 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8716
8717 ExitNotTaken.reserve(ExitCounts.size());
8718 std::transform(ExitCounts.begin(), ExitCounts.end(),
8719 std::back_inserter(ExitNotTaken),
8720 [&](const EdgeExitInfo &EEI) {
8721 BasicBlock *ExitBB = EEI.first;
8722 const ExitLimit &EL = EEI.second;
8723 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8724 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8725 EL.Predicates);
8726 });
8727 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8728 isa<SCEVConstant>(ConstantMax)) &&
8729 "No point in having a non-constant max backedge taken count!");
8730}
8731
8732/// Compute the number of times the backedge of the specified loop will execute.
8733ScalarEvolution::BackedgeTakenInfo
8734ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8735 bool AllowPredicates) {
8736 SmallVector<BasicBlock *, 8> ExitingBlocks;
8737 L->getExitingBlocks(ExitingBlocks);
8738
8739 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8740
8742 bool CouldComputeBECount = true;
8743 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8744 const SCEV *MustExitMaxBECount = nullptr;
8745 const SCEV *MayExitMaxBECount = nullptr;
8746 bool MustExitMaxOrZero = false;
8747
8748 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8749 // and compute maxBECount.
8750 // Do a union of all the predicates here.
8751 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8752 BasicBlock *ExitBB = ExitingBlocks[i];
8753
8754 // We canonicalize untaken exits to br (constant), ignore them so that
8755 // proving an exit untaken doesn't negatively impact our ability to reason
8756     // about the loop as a whole.
8757 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8758 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8759 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8760 if (ExitIfTrue == CI->isZero())
8761 continue;
8762 }
8763
8764 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8765
8766 assert((AllowPredicates || EL.Predicates.empty()) &&
8767 "Predicated exit limit when predicates are not allowed!");
8768
8769 // 1. For each exit that can be computed, add an entry to ExitCounts.
8770 // CouldComputeBECount is true only if all exits can be computed.
8771 if (EL.ExactNotTaken != getCouldNotCompute())
8772 ++NumExitCountsComputed;
8773 else
8774 // We couldn't compute an exact value for this exit, so
8775 // we won't be able to compute an exact value for the loop.
8776 CouldComputeBECount = false;
8777 // Remember exit count if either exact or symbolic is known. Because
8778 // Exact always implies symbolic, only check symbolic.
8779 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8780 ExitCounts.emplace_back(ExitBB, EL);
8781 else {
8782 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8783 "Exact is known but symbolic isn't?");
8784 ++NumExitCountsNotComputed;
8785 }
8786
8787 // 2. Derive the loop's MaxBECount from each exit's max number of
8788 // non-exiting iterations. Partition the loop exits into two kinds:
8789 // LoopMustExits and LoopMayExits.
8790 //
8791 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8792 // is a LoopMayExit. If any computable LoopMustExit is found, then
8793 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8794 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8795 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8796 // any
8797 // computable EL.ConstantMaxNotTaken.
8798 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8799 DT.dominates(ExitBB, Latch)) {
8800 if (!MustExitMaxBECount) {
8801 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8802 MustExitMaxOrZero = EL.MaxOrZero;
8803 } else {
8804 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8805 EL.ConstantMaxNotTaken);
8806 }
8807 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8808 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8809 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8810 else {
8811 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8812 EL.ConstantMaxNotTaken);
8813 }
8814 }
8815 }
8816 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8817 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8818 // The loop backedge will be taken the maximum or zero times if there's
8819 // a single exit that must be taken the maximum or zero times.
8820 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8821
8822 // Remember which SCEVs are used in exit limits for invalidation purposes.
8823 // We only care about non-constant SCEVs here, so we can ignore
8824 // EL.ConstantMaxNotTaken
8825 // and MaxBECount, which must be SCEVConstant.
8826 for (const auto &Pair : ExitCounts) {
8827 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8828 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8829 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8830 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8831 {L, AllowPredicates});
8832 }
8833 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8834 MaxBECount, MaxOrZero);
8835}
8836
8837 ScalarEvolution::ExitLimit
8838 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8839 bool AllowPredicates) {
8840 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8841 // If our exiting block does not dominate the latch, then its connection with
8842   // the loop's exit limit may be far from trivial.
8843 const BasicBlock *Latch = L->getLoopLatch();
8844 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8845 return getCouldNotCompute();
8846
8847 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8848 Instruction *Term = ExitingBlock->getTerminator();
8849 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8850 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8851 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8852 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8853 "It should have one successor in loop and one exit block!");
8854 // Proceed to the next level to examine the exit condition expression.
8855 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8856 /*ControlsOnlyExit=*/IsOnlyExit,
8857 AllowPredicates);
8858 }
8859
8860 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8861 // For switch, make sure that there is a single exit from the loop.
8862 BasicBlock *Exit = nullptr;
8863 for (auto *SBB : successors(ExitingBlock))
8864 if (!L->contains(SBB)) {
8865 if (Exit) // Multiple exit successors.
8866 return getCouldNotCompute();
8867 Exit = SBB;
8868 }
8869 assert(Exit && "Exiting block must have at least one exit");
8870 return computeExitLimitFromSingleExitSwitch(
8871 L, SI, Exit,
8872 /*ControlsOnlyExit=*/IsOnlyExit);
8873 }
8874
8875 return getCouldNotCompute();
8876}
8877
8878 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8879     const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8880 bool AllowPredicates) {
8881 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8882 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8883 ControlsOnlyExit, AllowPredicates);
8884}
8885
8886std::optional<ScalarEvolution::ExitLimit>
8887ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8888 bool ExitIfTrue, bool ControlsOnlyExit,
8889 bool AllowPredicates) {
8890 (void)this->L;
8891 (void)this->ExitIfTrue;
8892 (void)this->AllowPredicates;
8893
8894 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8895 this->AllowPredicates == AllowPredicates &&
8896 "Variance in assumed invariant key components!");
8897 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8898 if (Itr == TripCountMap.end())
8899 return std::nullopt;
8900 return Itr->second;
8901}
8902
8903void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8904 bool ExitIfTrue,
8905 bool ControlsOnlyExit,
8906 bool AllowPredicates,
8907 const ExitLimit &EL) {
8908 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8909 this->AllowPredicates == AllowPredicates &&
8910 "Variance in assumed invariant key components!");
8911
8912 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8913 assert(InsertResult.second && "Expected successful insertion!");
8914 (void)InsertResult;
8915 (void)ExitIfTrue;
8916}
8917
8918ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8919 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8920 bool ControlsOnlyExit, bool AllowPredicates) {
8921
8922 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8923 AllowPredicates))
8924 return *MaybeEL;
8925
8926 ExitLimit EL = computeExitLimitFromCondImpl(
8927 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8928 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8929 return EL;
8930}
8931
8932ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8933 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8934 bool ControlsOnlyExit, bool AllowPredicates) {
8935 // Handle BinOp conditions (And, Or).
8936 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8937 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8938 return *LimitFromBinOp;
8939
8940 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8941 // Proceed to the next level to examine the icmp.
8942 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8943 ExitLimit EL =
8944 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8945 if (EL.hasFullInfo() || !AllowPredicates)
8946 return EL;
8947
8948 // Try again, but use SCEV predicates this time.
8949 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8950 ControlsOnlyExit,
8951 /*AllowPredicates=*/true);
8952 }
8953
8954 // Check for a constant condition. These are normally stripped out by
8955 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8956 // preserve the CFG and is temporarily leaving constant conditions
8957 // in place.
8958 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8959 if (ExitIfTrue == !CI->getZExtValue())
8960 // The backedge is always taken.
8961 return getCouldNotCompute();
8962 // The backedge is never taken.
8963 return getZero(CI->getType());
8964 }
8965
8966 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8967 // with a constant step, we can form an equivalent icmp predicate and figure
8968 // out how many iterations will be taken before we exit.
8969 const WithOverflowInst *WO;
8970 const APInt *C;
8971 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8972 match(WO->getRHS(), m_APInt(C))) {
8973 ConstantRange NWR =
8974         ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8975                                              WO->getNoWrapKind());
8976 CmpInst::Predicate Pred;
8977 APInt NewRHSC, Offset;
8978 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8979 if (!ExitIfTrue)
8980 Pred = ICmpInst::getInversePredicate(Pred);
8981 auto *LHS = getSCEV(WO->getLHS());
8982     if (Offset != 0)
8983       LHS = getAddExpr(LHS, getConstant(Offset));
8984 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8985 ControlsOnlyExit, AllowPredicates);
8986 if (EL.hasAnyInfo())
8987 return EL;
8988 }
8989
8990 // If it's not an integer or pointer comparison then compute it the hard way.
8991 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8992}
8993
8994std::optional<ScalarEvolution::ExitLimit>
8995ScalarEvolution::computeExitLimitFromCondFromBinOp(
8996 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8997 bool ControlsOnlyExit, bool AllowPredicates) {
8998 // Check if the controlling expression for this loop is an And or Or.
8999 Value *Op0, *Op1;
9000 bool IsAnd = false;
9001 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9002 IsAnd = true;
9003 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9004 IsAnd = false;
9005 else
9006 return std::nullopt;
9007
9008 // EitherMayExit is true in these two cases:
9009 // br (and Op0 Op1), loop, exit
9010 // br (or Op0 Op1), exit, loop
9011 bool EitherMayExit = IsAnd ^ ExitIfTrue;
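  // For instance, with "br i1 (and %c0, %c1), %loop, %exit" the loop keeps
  // running only while both conditions hold, so it exits as soon as either one
  // fails and the exit count is the (sequential) umin of the two limits; the
  // "br (or ...), %exit, %loop" form behaves the same way with the roles of
  // true and false swapped.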
9012 ExitLimit EL0 = computeExitLimitFromCondCached(
9013 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9014 AllowPredicates);
9015 ExitLimit EL1 = computeExitLimitFromCondCached(
9016 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9017 AllowPredicates);
9018
9019 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9020 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9021 if (isa<ConstantInt>(Op1))
9022 return Op1 == NeutralElement ? EL0 : EL1;
9023 if (isa<ConstantInt>(Op0))
9024 return Op0 == NeutralElement ? EL1 : EL0;
9025
9026 const SCEV *BECount = getCouldNotCompute();
9027 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9028 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9029 if (EitherMayExit) {
9030 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9031     // Both conditions must be the same for the loop to continue executing.
9032 // Choose the less conservative count.
9033 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9034 EL1.ExactNotTaken != getCouldNotCompute()) {
9035 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9036 UseSequentialUMin);
9037 }
9038 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9039 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9040 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9041 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9042 else
9043 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9044 EL1.ConstantMaxNotTaken);
9045 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9046 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9047 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9048 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9049 else
9050 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9051 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9052 } else {
9053     // Both conditions must be the same at the same time for the loop to exit.
9054 // For now, be conservative.
9055 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9056 BECount = EL0.ExactNotTaken;
9057 }
9058
9059 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9060 // to be more aggressive when computing BECount than when computing
9061 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9062 // and
9063 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9064 // EL1.ConstantMaxNotTaken to not.
9065 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9066 !isa<SCEVCouldNotCompute>(BECount))
9067 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9068 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9069 SymbolicMaxBECount =
9070 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9071 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9072 { &EL0.Predicates, &EL1.Predicates });
9073}
9074
9075ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9076 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9077 bool AllowPredicates) {
9078 // If the condition was exit on true, convert the condition to exit on false
9079   ICmpInst::Predicate Pred;
9080   if (!ExitIfTrue)
9081 Pred = ExitCond->getPredicate();
9082 else
9083 Pred = ExitCond->getInversePredicate();
9084 const ICmpInst::Predicate OriginalPred = Pred;
9085
9086 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9087 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9088
9089 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9090 AllowPredicates);
9091 if (EL.hasAnyInfo())
9092 return EL;
9093
9094 auto *ExhaustiveCount =
9095 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9096
9097 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9098 return ExhaustiveCount;
9099
9100 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9101 ExitCond->getOperand(1), L, OriginalPred);
9102}
9103ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9104 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9105 bool ControlsOnlyExit, bool AllowPredicates) {
9106
9107 // Try to evaluate any dependencies out of the loop.
9108 LHS = getSCEVAtScope(LHS, L);
9109 RHS = getSCEVAtScope(RHS, L);
9110
9111 // At this point, we would like to compute how many iterations of the
9112 // loop the predicate will return true for these inputs.
9113 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9114 // If there is a loop-invariant, force it into the RHS.
9115 std::swap(LHS, RHS);
9116 Pred = ICmpInst::getSwappedPredicate(Pred);
9117 }
9118
9119   bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9120                                loopIsFiniteByAssumption(L);
9121 // Simplify the operands before analyzing them.
9122 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9123
9124 // If we have a comparison of a chrec against a constant, try to use value
9125 // ranges to answer this query.
9126 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9127 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9128 if (AddRec->getLoop() == L) {
9129 // Form the constant range.
9130 ConstantRange CompRange =
9131 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9132
9133 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9134 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9135 }
9136
9137 // If this loop must exit based on this condition (or execute undefined
9138 // behaviour), and we can prove the test sequence produced must repeat
9139 // the same values on self-wrap of the IV, then we can infer that IV
9140 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9141 // loop.
9142 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9143 // TODO: We can peel off any functions which are invertible *in L*. Loop
9144 // invariant terms are effectively constants for our purposes here.
9145 auto *InnerLHS = LHS;
9146 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9147 InnerLHS = ZExt->getOperand();
9148 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9149 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9150 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9151 StrideC && StrideC->getAPInt().isPowerOf2()) {
9152 auto Flags = AR->getNoWrapFlags();
9153 Flags = setFlags(Flags, SCEV::FlagNW);
9156 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9157 }
9158 }
9159 }
9160
9161 switch (Pred) {
9162 case ICmpInst::ICMP_NE: { // while (X != Y)
9163 // Convert to: while (X-Y != 0)
9164     if (LHS->getType()->isPointerTy()) {
9165       LHS = getLosslessPtrToIntExpr(LHS);
9166 if (isa<SCEVCouldNotCompute>(LHS))
9167 return LHS;
9168 }
9169     if (RHS->getType()->isPointerTy()) {
9170       RHS = getLosslessPtrToIntExpr(RHS);
9171 if (isa<SCEVCouldNotCompute>(RHS))
9172 return RHS;
9173 }
9174 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9175 AllowPredicates);
9176 if (EL.hasAnyInfo())
9177 return EL;
9178 break;
9179 }
9180 case ICmpInst::ICMP_EQ: { // while (X == Y)
9181 // Convert to: while (X-Y == 0)
9182     if (LHS->getType()->isPointerTy()) {
9183       LHS = getLosslessPtrToIntExpr(LHS);
9184 if (isa<SCEVCouldNotCompute>(LHS))
9185 return LHS;
9186 }
9187     if (RHS->getType()->isPointerTy()) {
9188       RHS = getLosslessPtrToIntExpr(RHS);
9189 if (isa<SCEVCouldNotCompute>(RHS))
9190 return RHS;
9191 }
9192 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9193 if (EL.hasAnyInfo()) return EL;
9194 break;
9195 }
9196 case ICmpInst::ICMP_SLE:
9197 case ICmpInst::ICMP_ULE:
9198 // Since the loop is finite, an invariant RHS cannot include the boundary
9199 // value, otherwise it would loop forever.
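    // For instance, "while (i <= N)": if N could be the maximum value of its
    // type the loop would never terminate, so for a finite loop N + 1 cannot
    // wrap and the bound is safely rewritten as "while (i < N + 1)" below.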
9200 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9201 !isLoopInvariant(RHS, L))
9202 break;
9203 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9204 [[fallthrough]];
9205 case ICmpInst::ICMP_SLT:
9206 case ICmpInst::ICMP_ULT: { // while (X < Y)
9207 bool IsSigned = ICmpInst::isSigned(Pred);
9208 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9209 AllowPredicates);
9210 if (EL.hasAnyInfo())
9211 return EL;
9212 break;
9213 }
9214 case ICmpInst::ICMP_SGE:
9215 case ICmpInst::ICMP_UGE:
9216 // Since the loop is finite, an invariant RHS cannot include the boundary
9217 // value, otherwise it would loop forever.
9218 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9219 !isLoopInvariant(RHS, L))
9220 break;
9221 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9222 [[fallthrough]];
9223 case ICmpInst::ICMP_SGT:
9224 case ICmpInst::ICMP_UGT: { // while (X > Y)
9225 bool IsSigned = ICmpInst::isSigned(Pred);
9226 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9227 AllowPredicates);
9228 if (EL.hasAnyInfo())
9229 return EL;
9230 break;
9231 }
9232 default:
9233 break;
9234 }
9235
9236 return getCouldNotCompute();
9237}
9238
9239 ScalarEvolution::ExitLimit
9240 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9241 SwitchInst *Switch,
9242 BasicBlock *ExitingBlock,
9243 bool ControlsOnlyExit) {
9244 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9245
9246 // Give up if the exit is the default dest of a switch.
9247 if (Switch->getDefaultDest() == ExitingBlock)
9248 return getCouldNotCompute();
9249
9250 assert(L->contains(Switch->getDefaultDest()) &&
9251 "Default case must not exit the loop!");
9252 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9253 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9254
9255 // while (X != Y) --> while (X-Y != 0)
9256 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9257 if (EL.hasAnyInfo())
9258 return EL;
9259
9260 return getCouldNotCompute();
9261}
9262
9263static ConstantInt *
9264 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9265                                 ScalarEvolution &SE) {
9266 const SCEV *InVal = SE.getConstant(C);
9267 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9268 assert(isa<SCEVConstant>(Val) &&
9269 "Evaluation of SCEV at constant didn't fold correctly?");
9270 return cast<SCEVConstant>(Val)->getValue();
9271}
9272
9273ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9274 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9275 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9276 if (!RHS)
9277 return getCouldNotCompute();
9278
9279 const BasicBlock *Latch = L->getLoopLatch();
9280 if (!Latch)
9281 return getCouldNotCompute();
9282
9283 const BasicBlock *Predecessor = L->getLoopPredecessor();
9284 if (!Predecessor)
9285 return getCouldNotCompute();
9286
9287 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9288   // Return LHS in OutLHS and shift_op in OutOpCode.
9289 auto MatchPositiveShift =
9290 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9291
9292 using namespace PatternMatch;
9293
9294 ConstantInt *ShiftAmt;
9295 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9296 OutOpCode = Instruction::LShr;
9297 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9298 OutOpCode = Instruction::AShr;
9299 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9300 OutOpCode = Instruction::Shl;
9301 else
9302 return false;
9303
9304 return ShiftAmt->getValue().isStrictlyPositive();
9305 };
9306
9307 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9308 //
9309 // loop:
9310 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9311 // %iv.shifted = lshr i32 %iv, <positive constant>
9312 //
9313 // Return true on a successful match. Return the corresponding PHI node (%iv
9314 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9315 auto MatchShiftRecurrence =
9316 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9317 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9318
9319 {
9320       Instruction::BinaryOps OpC;
9321       Value *V;
9322
9323 // If we encounter a shift instruction, "peel off" the shift operation,
9324 // and remember that we did so. Later when we inspect %iv's backedge
9325 // value, we will make sure that the backedge value uses the same
9326 // operation.
9327 //
9328 // Note: the peeled shift operation does not have to be the same
9329 // instruction as the one feeding into the PHI's backedge value. We only
9330 // really care about it being the same *kind* of shift instruction --
9331 // that's all that is required for our later inferences to hold.
9332 if (MatchPositiveShift(LHS, V, OpC)) {
9333 PostShiftOpCode = OpC;
9334 LHS = V;
9335 }
9336 }
9337
9338 PNOut = dyn_cast<PHINode>(LHS);
9339 if (!PNOut || PNOut->getParent() != L->getHeader())
9340 return false;
9341
9342 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9343 Value *OpLHS;
9344
9345 return
9346 // The backedge value for the PHI node must be a shift by a positive
9347 // amount
9348 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9349
9350 // of the PHI node itself
9351 OpLHS == PNOut &&
9352
9353             // and the kind of shift should match the kind of shift we peeled
9354 // off, if any.
9355 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9356 };
9357
9358 PHINode *PN;
9359   Instruction::BinaryOps OpCode;
9360   if (!MatchShiftRecurrence(LHS, PN, OpCode))
9361 return getCouldNotCompute();
9362
9363 const DataLayout &DL = getDataLayout();
9364
9365 // The key rationale for this optimization is that for some kinds of shift
9366 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9367 // within a finite number of iterations. If the condition guarding the
9368 // backedge (in the sense that the backedge is taken if the condition is true)
9369 // is false for the value the shift recurrence stabilizes to, then we know
9370 // that the backedge is taken only a finite number of times.
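  // For instance, {K,lshr,1} in i32 reaches 0 after at most 32 iterations for
  // any starting K. If the backedge condition is something like "%iv > 5",
  // that condition is false once the recurrence hits 0, so the backedge can be
  // taken at most bitwidth(K) times even though the exact count is unknown.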
9371
9372 ConstantInt *StableValue = nullptr;
9373 switch (OpCode) {
9374 default:
9375 llvm_unreachable("Impossible case!");
9376
9377 case Instruction::AShr: {
9378 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9379 // bitwidth(K) iterations.
9380 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9381 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9382 Predecessor->getTerminator(), &DT);
9383 auto *Ty = cast<IntegerType>(RHS->getType());
9384 if (Known.isNonNegative())
9385 StableValue = ConstantInt::get(Ty, 0);
9386 else if (Known.isNegative())
9387 StableValue = ConstantInt::get(Ty, -1, true);
9388 else
9389 return getCouldNotCompute();
9390
9391 break;
9392 }
9393 case Instruction::LShr:
9394 case Instruction::Shl:
9395 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9396 // stabilize to 0 in at most bitwidth(K) iterations.
9397 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9398 break;
9399 }
9400
9401 auto *Result =
9402 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9403 assert(Result->getType()->isIntegerTy(1) &&
9404 "Otherwise cannot be an operand to a branch instruction");
9405
9406 if (Result->isZeroValue()) {
9407 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9408 const SCEV *UpperBound =
9410 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9411 }
9412
9413 return getCouldNotCompute();
9414}
9415
9416/// Return true if we can constant fold an instruction of the specified type,
9417/// assuming that all operands were constants.
9418static bool CanConstantFold(const Instruction *I) {
9419 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9420 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9421 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9422 return true;
9423
9424 if (const CallInst *CI = dyn_cast<CallInst>(I))
9425 if (const Function *F = CI->getCalledFunction())
9426 return canConstantFoldCallTo(CI, F);
9427 return false;
9428}
9429
9430/// Determine whether this instruction can constant evolve within this loop
9431/// assuming its operands can all constant evolve.
9432static bool canConstantEvolve(Instruction *I, const Loop *L) {
9433 // An instruction outside of the loop can't be derived from a loop PHI.
9434 if (!L->contains(I)) return false;
9435
9436 if (isa<PHINode>(I)) {
9437 // We don't currently keep track of the control flow needed to evaluate
9438 // PHIs, so we cannot handle PHIs inside of loops.
9439 return L->getHeader() == I->getParent();
9440 }
9441
9442 // If we won't be able to constant fold this expression even if the operands
9443 // are constants, bail early.
9444 return CanConstantFold(I);
9445}
9446
9447/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9448/// recursing through each instruction operand until reaching a loop header phi.
9449static PHINode *
9450getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9451 DenseMap<Instruction *, PHINode *> &PHIMap,
9452 unsigned Depth) {
9453 if (Depth > MaxConstantEvolvingDepth)
9454 return nullptr;
9455
9456 // Otherwise, we can evaluate this instruction if all of its operands are
9457 // constant or derived from a PHI node themselves.
9458 PHINode *PHI = nullptr;
9459 for (Value *Op : UseInst->operands()) {
9460 if (isa<Constant>(Op)) continue;
9461
9462 Instruction *OpInst = dyn_cast<Instruction>(Op);
9463 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9464
9465 PHINode *P = dyn_cast<PHINode>(OpInst);
9466 if (!P)
9467 // If this operand is already visited, reuse the prior result.
9468 // We may have P != PHI if this is the deepest point at which the
9469 // inconsistent paths meet.
9470 P = PHIMap.lookup(OpInst);
9471 if (!P) {
9472 // Recurse and memoize the results, whether a phi is found or not.
9473 // This recursive call invalidates pointers into PHIMap.
9474 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9475 PHIMap[OpInst] = P;
9476 }
9477 if (!P)
9478 return nullptr; // Not evolving from PHI
9479 if (PHI && PHI != P)
9480 return nullptr; // Evolving from multiple different PHIs.
9481 PHI = P;
9482 }
9483 // This is an expression evolving from a constant PHI!
9484 return PHI;
9485}
9486
9487/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9488/// in the loop that V is derived from. We allow arbitrary operations along the
9489/// way, but the operands of an operation must either be constants or a value
9490/// derived from a constant PHI. If this expression does not fit with these
9491/// constraints, return null.
9492static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9493 Instruction *I = dyn_cast<Instruction>(V);
9494 if (!I || !canConstantEvolve(I, L)) return nullptr;
9495
9496 if (PHINode *PN = dyn_cast<PHINode>(I))
9497 return PN;
9498
9499 // Record non-constant instructions contained by the loop.
9500 DenseMap<Instruction *, PHINode *> PHIMap;
9501 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9502}
9503
9504/// EvaluateExpression - Given an expression that passes the
9505/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI
9506/// nodes in the loop have the constant values given in Vals. If we can't
9507/// fold this expression for some reason, return null.
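/// For illustration (the names are hypothetical): with Vals = { %iv -> i32 3 }
/// and V defined as "%t = mul i32 %iv, 2", the recursion folds the multiply
/// and returns i32 6; if an operand is neither mapped nor foldable, nullptr
/// is returned instead.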
9508static Constant *EvaluateExpression(Value *V, const Loop *L,
9509 DenseMap<Instruction *, Constant *> &Vals,
9510 const DataLayout &DL,
9511 const TargetLibraryInfo *TLI) {
9512 // Convenient constant check, but redundant for recursive calls.
9513 if (Constant *C = dyn_cast<Constant>(V)) return C;
9514 Instruction *I = dyn_cast<Instruction>(V);
9515 if (!I) return nullptr;
9516
9517 if (Constant *C = Vals.lookup(I)) return C;
9518
9519 // An instruction inside the loop depends on a value outside the loop that we
9520 // weren't given a mapping for, or a value such as a call inside the loop.
9521 if (!canConstantEvolve(I, L)) return nullptr;
9522
9523 // An unmapped PHI can be due to a branch or another loop inside this loop,
9524 // or due to this not being the initial iteration through a loop where we
9525 // couldn't compute the evolution of this particular PHI last time.
9526 if (isa<PHINode>(I)) return nullptr;
9527
9528 std::vector<Constant*> Operands(I->getNumOperands());
9529
9530 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9531 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9532 if (!Operand) {
9533 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9534 if (!Operands[i]) return nullptr;
9535 continue;
9536 }
9537 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9538 Vals[Operand] = C;
9539 if (!C) return nullptr;
9540 Operands[i] = C;
9541 }
9542
9543 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9544}
9545
9546
9547// If every incoming value to PN except the one for BB is a specific Constant,
9548// return that, else return nullptr.
9549static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9550 Constant *IncomingVal = nullptr;
9551
9552 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9553 if (PN->getIncomingBlock(i) == BB)
9554 continue;
9555
9556 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9557 if (!CurrentVal)
9558 return nullptr;
9559
9560 if (IncomingVal != CurrentVal) {
9561 if (IncomingVal)
9562 return nullptr;
9563 IncomingVal = CurrentVal;
9564 }
9565 }
9566
9567 return IncomingVal;
9568}
9569
9570/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9571/// in the header of its containing loop, we know the loop executes a
9572/// constant number of times, and the PHI node is just a recurrence
9573/// involving constants, fold it.
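/// For illustration (the names are hypothetical): for a header PHI
/// %iv = phi [ 0, %preheader ], [ %iv.next, %latch ] with
/// %iv.next = add i32 %iv, 2 and BEs = 4, the symbolic execution below
/// produces the constant i32 8.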
9574Constant *
9575ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9576 const APInt &BEs,
9577 const Loop *L) {
9578 auto I = ConstantEvolutionLoopExitValue.find(PN);
9579 if (I != ConstantEvolutionLoopExitValue.end())
9580 return I->second;
9581
9582 if (ConstantEvolutionLoopExitValue.size() >= MaxBruteForceIterations)
9583 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9584
9585 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9586
9587 DenseMap<Instruction *, Constant *> CurrentIterVals;
9588 BasicBlock *Header = L->getHeader();
9589 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9590
9591 BasicBlock *Latch = L->getLoopLatch();
9592 if (!Latch)
9593 return nullptr;
9594
9595 for (PHINode &PHI : Header->phis()) {
9596 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9597 CurrentIterVals[&PHI] = StartCST;
9598 }
9599 if (!CurrentIterVals.count(PN))
9600 return RetVal = nullptr;
9601
9602 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9603
9604 // Execute the loop symbolically to determine the exit value.
9605 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9606 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9607
9608 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9609 unsigned IterationNum = 0;
9610 const DataLayout &DL = getDataLayout();
9611 for (; ; ++IterationNum) {
9612 if (IterationNum == NumIterations)
9613 return RetVal = CurrentIterVals[PN]; // Got exit value!
9614
9615 // Compute the value of the PHIs for the next iteration.
9616 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9617 DenseMap<Instruction *, Constant *> NextIterVals;
9618 Constant *NextPHI =
9619 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9620 if (!NextPHI)
9621 return nullptr; // Couldn't evaluate!
9622 NextIterVals[PN] = NextPHI;
9623
9624 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9625
9626 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9627 // cease to be able to evaluate one of them or if they stop evolving,
9628 // because that doesn't necessarily prevent us from computing PN.
9629 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9630 for (const auto &I : CurrentIterVals) {
9631 PHINode *PHI = dyn_cast<PHINode>(I.first);
9632 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9633 PHIsToCompute.emplace_back(PHI, I.second);
9634 }
9635 // We use two distinct loops because EvaluateExpression may invalidate any
9636 // iterators into CurrentIterVals.
9637 for (const auto &I : PHIsToCompute) {
9638 PHINode *PHI = I.first;
9639 Constant *&NextPHI = NextIterVals[PHI];
9640 if (!NextPHI) { // Not already computed.
9641 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9642 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9643 }
9644 if (NextPHI != I.second)
9645 StoppedEvolving = false;
9646 }
9647
9648 // If all entries in CurrentIterVals == NextIterVals then we can stop
9649 // iterating, the loop can't continue to change.
9650 if (StoppedEvolving)
9651 return RetVal = CurrentIterVals[PN];
9652
9653 CurrentIterVals.swap(NextIterVals);
9654 }
9655}
9656
9657const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9658 Value *Cond,
9659 bool ExitWhen) {
9660 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9661 if (!PN) return getCouldNotCompute();
9662
9663 // If the loop is canonicalized, the PHI will have exactly two entries.
9664 // That's the only form we support here.
9665 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9666
9667 DenseMap<Instruction *, Constant *> CurrentIterVals;
9668 BasicBlock *Header = L->getHeader();
9669 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9670
9671 BasicBlock *Latch = L->getLoopLatch();
9672 assert(Latch && "Should follow from NumIncomingValues == 2!");
9673
9674 for (PHINode &PHI : Header->phis()) {
9675 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9676 CurrentIterVals[&PHI] = StartCST;
9677 }
9678 if (!CurrentIterVals.count(PN))
9679 return getCouldNotCompute();
9680
9681 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9682 // the loop symbolically to determine when the condition gets a value of
9683 // "ExitWhen".
9684 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9685 const DataLayout &DL = getDataLayout();
9686 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9687 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9688 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9689
9690 // Couldn't symbolically evaluate.
9691 if (!CondVal) return getCouldNotCompute();
9692
9693 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9694 ++NumBruteForceTripCountsComputed;
9695 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9696 }
9697
9698 // Update all the PHI nodes for the next iteration.
9699 DenseMap<Instruction *, Constant *> NextIterVals;
9700
9701 // Create a list of which PHIs we need to compute. We want to do this before
9702 // calling EvaluateExpression on them because that may invalidate iterators
9703 // into CurrentIterVals.
9704 SmallVector<PHINode *, 8> PHIsToCompute;
9705 for (const auto &I : CurrentIterVals) {
9706 PHINode *PHI = dyn_cast<PHINode>(I.first);
9707 if (!PHI || PHI->getParent() != Header) continue;
9708 PHIsToCompute.push_back(PHI);
9709 }
9710 for (PHINode *PHI : PHIsToCompute) {
9711 Constant *&NextPHI = NextIterVals[PHI];
9712 if (NextPHI) continue; // Already computed!
9713
9714 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9715 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9716 }
9717 CurrentIterVals.swap(NextIterVals);
9718 }
9719
9720 // Too many iterations were needed to evaluate.
9721 return getCouldNotCompute();
9722}
9723
9724const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9725 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9726 ValuesAtScopes[V];
9727 // Check to see if we've folded this expression at this loop before.
9728 for (auto &LS : Values)
9729 if (LS.first == L)
9730 return LS.second ? LS.second : V;
9731
9732 Values.emplace_back(L, nullptr);
9733
9734 // Otherwise compute it.
9735 const SCEV *C = computeSCEVAtScope(V, L);
9736 for (auto &LS : reverse(ValuesAtScopes[V]))
9737 if (LS.first == L) {
9738 LS.second = C;
9739 if (!isa<SCEVConstant>(C))
9740 ValuesAtScopesUsers[C].push_back({L, V});
9741 break;
9742 }
9743 return C;
9744}
9745
9746/// This builds up a Constant using the ConstantExpr interface. That way, we
9747/// will return Constants for objects which aren't represented by a
9748/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9749/// Returns NULL if the SCEV isn't representable as a Constant.
9750static Constant *BuildConstantFromSCEV(const SCEV *V) {
9751 switch (V->getSCEVType()) {
9752 case scCouldNotCompute:
9753 case scAddRecExpr:
9754 case scVScale:
9755 return nullptr;
9756 case scConstant:
9757 return cast<SCEVConstant>(V)->getValue();
9758 case scUnknown:
9759 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9760 case scPtrToInt: {
9761 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9762 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9763 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9764
9765 return nullptr;
9766 }
9767 case scTruncate: {
9768 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9769 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9770 return ConstantExpr::getTrunc(CastOp, ST->getType());
9771 return nullptr;
9772 }
9773 case scAddExpr: {
9774 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9775 Constant *C = nullptr;
9776 for (const SCEV *Op : SA->operands()) {
9777 Constant *OpC = BuildConstantFromSCEV(Op);
9778 if (!OpC)
9779 return nullptr;
9780 if (!C) {
9781 C = OpC;
9782 continue;
9783 }
9784 assert(!C->getType()->isPointerTy() &&
9785 "Can only have one pointer, and it must be last");
9786 if (OpC->getType()->isPointerTy()) {
9787 // The offsets have been converted to bytes. We can add bytes using
9788 // an i8 GEP.
9789 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9790 OpC, C);
9791 } else {
9792 C = ConstantExpr::getAdd(C, OpC);
9793 }
9794 }
9795 return C;
9796 }
9797 case scMulExpr:
9798 case scSignExtend:
9799 case scZeroExtend:
9800 case scUDivExpr:
9801 case scSMaxExpr:
9802 case scUMaxExpr:
9803 case scSMinExpr:
9804 case scUMinExpr:
9805 case scSequentialUMinExpr:
9806 return nullptr;
9807 }
9808 llvm_unreachable("Unknown SCEV kind!");
9809}
9810
9811const SCEV *
9812ScalarEvolution::getWithOperands(const SCEV *S,
9813 ArrayRef<const SCEV *> NewOps) {
9814 switch (S->getSCEVType()) {
9815 case scTruncate:
9816 case scZeroExtend:
9817 case scSignExtend:
9818 case scPtrToInt:
9819 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9820 case scAddRecExpr: {
9821 auto *AddRec = cast<SCEVAddRecExpr>(S);
9822 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9823 }
9824 case scAddExpr:
9825 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9826 case scMulExpr:
9827 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9828 case scUDivExpr:
9829 return getUDivExpr(NewOps[0], NewOps[1]);
9830 case scUMaxExpr:
9831 case scSMaxExpr:
9832 case scUMinExpr:
9833 case scSMinExpr:
9834 return getMinMaxExpr(S->getSCEVType(), NewOps);
9835 case scSequentialUMinExpr:
9836 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9837 case scConstant:
9838 case scVScale:
9839 case scUnknown:
9840 return S;
9841 case scCouldNotCompute:
9842 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9843 }
9844 llvm_unreachable("Unknown SCEV kind!");
9845}
9846
9847const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9848 switch (V->getSCEVType()) {
9849 case scConstant:
9850 case scVScale:
9851 return V;
9852 case scAddRecExpr: {
9853 // If this is a loop recurrence for a loop that does not contain L, then we
9854 // are dealing with the final value computed by the loop.
9855 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9856 // First, attempt to evaluate each operand.
9857 // Avoid performing the look-up in the common case where the specified
9858 // expression has no loop-variant portions.
9859 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9860 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9861 if (OpAtScope == AddRec->getOperand(i))
9862 continue;
9863
9864 // Okay, at least one of these operands is loop variant but might be
9865 // foldable. Build a new instance of the folded commutative expression.
9866 SmallVector<const SCEV *, 8> NewOps;
9867 NewOps.reserve(AddRec->getNumOperands());
9868 append_range(NewOps, AddRec->operands().take_front(i));
9869 NewOps.push_back(OpAtScope);
9870 for (++i; i != e; ++i)
9871 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9872
9873 const SCEV *FoldedRec = getAddRecExpr(
9874 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9875 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9876 // The addrec may be folded to a nonrecurrence, for example, if the
9877 // induction variable is multiplied by zero after constant folding. Go
9878 // ahead and return the folded value.
9879 if (!AddRec)
9880 return FoldedRec;
9881 break;
9882 }
9883
9884 // If the scope is outside the addrec's loop, evaluate it by using the
9885 // loop exit value of the addrec.
9886 if (!AddRec->getLoop()->contains(L)) {
9887 // To evaluate this recurrence, we need to know how many times the AddRec
9888 // loop iterates. Compute this now.
9889 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9890 if (BackedgeTakenCount == getCouldNotCompute())
9891 return AddRec;
9892
9893 // Then, evaluate the AddRec.
9894 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9895 }
9896
9897 return AddRec;
9898 }
9899 case scTruncate:
9900 case scZeroExtend:
9901 case scSignExtend:
9902 case scPtrToInt:
9903 case scAddExpr:
9904 case scMulExpr:
9905 case scUDivExpr:
9906 case scUMaxExpr:
9907 case scSMaxExpr:
9908 case scUMinExpr:
9909 case scSMinExpr:
9910 case scSequentialUMinExpr: {
9911 ArrayRef<const SCEV *> Ops = V->operands();
9912 // Avoid performing the look-up in the common case where the specified
9913 // expression has no loop-variant portions.
9914 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9915 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9916 if (OpAtScope != Ops[i]) {
9917 // Okay, at least one of these operands is loop variant but might be
9918 // foldable. Build a new instance of the folded commutative expression.
9919 SmallVector<const SCEV *, 8> NewOps;
9920 NewOps.reserve(Ops.size());
9921 append_range(NewOps, Ops.take_front(i));
9922 NewOps.push_back(OpAtScope);
9923
9924 for (++i; i != e; ++i) {
9925 OpAtScope = getSCEVAtScope(Ops[i], L);
9926 NewOps.push_back(OpAtScope);
9927 }
9928
9929 return getWithOperands(V, NewOps);
9930 }
9931 }
9932 // If we got here, all operands are loop invariant.
9933 return V;
9934 }
9935 case scUnknown: {
9936 // If this instruction is evolved from a constant-evolving PHI, compute the
9937 // exit value from the loop without using SCEVs.
9938 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9939 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9940 if (!I)
9941 return V; // This is some other type of SCEVUnknown, just return it.
9942
9943 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9944 const Loop *CurrLoop = this->LI[I->getParent()];
9945 // Looking for loop exit value.
9946 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9947 PN->getParent() == CurrLoop->getHeader()) {
9948 // Okay, there is no closed form solution for the PHI node. Check
9949 // to see if the loop that contains it has a known backedge-taken
9950 // count. If so, we may be able to force computation of the exit
9951 // value.
9952 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9953 // This trivial case can show up in some degenerate cases where
9954 // the incoming IR has not yet been fully simplified.
9955 if (BackedgeTakenCount->isZero()) {
9956 Value *InitValue = nullptr;
9957 bool MultipleInitValues = false;
9958 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9959 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9960 if (!InitValue)
9961 InitValue = PN->getIncomingValue(i);
9962 else if (InitValue != PN->getIncomingValue(i)) {
9963 MultipleInitValues = true;
9964 break;
9965 }
9966 }
9967 }
9968 if (!MultipleInitValues && InitValue)
9969 return getSCEV(InitValue);
9970 }
9971 // Do we have a loop invariant value flowing around the backedge
9972 // for a loop which must execute the backedge?
9973 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9974 isKnownNonZero(BackedgeTakenCount) &&
9975 PN->getNumIncomingValues() == 2) {
9976
9977 unsigned InLoopPred =
9978 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9979 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9980 if (CurrLoop->isLoopInvariant(BackedgeVal))
9981 return getSCEV(BackedgeVal);
9982 }
9983 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9984 // Okay, we know how many times the containing loop executes. If
9985 // this is a constant evolving PHI node, get the final value at
9986 // the specified iteration number.
9987 Constant *RV =
9988 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
9989 if (RV)
9990 return getSCEV(RV);
9991 }
9992 }
9993 }
9994
9995 // Okay, this is an expression that we cannot symbolically evaluate
9996 // into a SCEV. Check to see if it's possible to symbolically evaluate
9997 // the arguments into constants, and if so, try to constant propagate the
9998 // result. This is particularly useful for computing loop exit values.
9999 if (!CanConstantFold(I))
10000 return V; // This is some other type of SCEVUnknown, just return it.
10001
10002 SmallVector<Constant *, 4> Operands;
10003 Operands.reserve(I->getNumOperands());
10004 bool MadeImprovement = false;
10005 for (Value *Op : I->operands()) {
10006 if (Constant *C = dyn_cast<Constant>(Op)) {
10007 Operands.push_back(C);
10008 continue;
10009 }
10010
10011 // If any of the operands is non-constant and if they are
10012 // non-integer and non-pointer, don't even try to analyze them
10013 // with scev techniques.
10014 if (!isSCEVable(Op->getType()))
10015 return V;
10016
10017 const SCEV *OrigV = getSCEV(Op);
10018 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10019 MadeImprovement |= OrigV != OpV;
10020
10021 Constant *C = BuildConstantFromSCEV(OpV);
10022 if (!C)
10023 return V;
10024 assert(C->getType() == Op->getType() && "Type mismatch");
10025 Operands.push_back(C);
10026 }
10027
10028 // Check to see if getSCEVAtScope actually made an improvement.
10029 if (!MadeImprovement)
10030 return V; // This is some other type of SCEVUnknown, just return it.
10031
10032 Constant *C = nullptr;
10033 const DataLayout &DL = getDataLayout();
10034 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
10035 if (!C)
10036 return V;
10037 return getSCEV(C);
10038 }
10039 case scCouldNotCompute:
10040 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10041 }
10042 llvm_unreachable("Unknown SCEV type!");
10043}
10044
10045const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10046 return getSCEVAtScope(getSCEV(V), L);
10047}
10048
10049const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10050 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10051 return stripInjectiveFunctions(ZExt->getOperand());
10052 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10053 return stripInjectiveFunctions(SExt->getOperand());
10054 return S;
10055}
10056
10057/// Finds the minimum unsigned root of the following equation:
10058///
10059/// A * X = B (mod N)
10060///
10061/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10062/// A and B isn't important.
10063///
10064/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
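/// For example, with BW = 4, A = 6 and B = 10: D = gcd(6, 16) = 2 divides 10,
/// the multiplicative inverse of A/D = 3 modulo N/D = 8 is 3, and
/// (3 * 10) mod 16 = 14, so X = 14 / 2 = 7, the smallest X with
/// 6 * X == 10 (mod 16).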
10065static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10066 ScalarEvolution &SE) {
10067 uint32_t BW = A.getBitWidth();
10068 assert(BW == SE.getTypeSizeInBits(B->getType()));
10069 assert(A != 0 && "A must be non-zero.");
10070
10071 // 1. D = gcd(A, N)
10072 //
10073 // The gcd of A and N may have only one prime factor: 2. The number of
10074 // trailing zeros in A is its multiplicity
10075 uint32_t Mult2 = A.countr_zero();
10076 // D = 2^Mult2
10077
10078 // 2. Check if B is divisible by D.
10079 //
10080 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10081 // is not less than the multiplicity of this prime factor for D.
10082 if (SE.getMinTrailingZeros(B) < Mult2)
10083 return SE.getCouldNotCompute();
10084
10085 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10086 // modulo (N / D).
10087 //
10088 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10089 // (N / D) in general. The inverse itself always fits into BW bits, though,
10090 // so we immediately truncate it.
10091 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10092 APInt I = AD.multiplicativeInverse().zext(BW);
10093
10094 // 4. Compute the minimum unsigned root of the equation:
10095 // I * (B / D) mod (N / D)
10096 // To simplify the computation, we factor out the divide by D:
10097 // (I * B mod N) / D
10098 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10099 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10100}
10101
10102/// For a given quadratic addrec, generate coefficients of the corresponding
10103/// quadratic equation, multiplied by a common value to ensure that they are
10104/// integers.
10105/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10106/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10107/// were multiplied by, and BitWidth is the bit width of the original addrec
10108/// coefficients.
10109/// This function returns std::nullopt if the addrec coefficients are not
10110/// compile-time constants.
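/// For example, for {-4,+,1,+,2} (L = -4, M = 1, N = 2) the accumulated value
/// after n iterations is -4 + n + n(n-1), so the function returns
/// A = 2, B = 2*1 - 2 = 0, C = 2*(-4) = -8, the multiplier 2, and the
/// original coefficient bit width.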
10111static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10112GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10113 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10114 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10115 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10116 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10117 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10118 << *AddRec << '\n');
10119
10120 // We currently can only solve this if the coefficients are constants.
10121 if (!LC || !MC || !NC) {
10122 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10123 return std::nullopt;
10124 }
10125
10126 APInt L = LC->getAPInt();
10127 APInt M = MC->getAPInt();
10128 APInt N = NC->getAPInt();
10129 assert(!N.isZero() && "This is not a quadratic addrec");
10130
10131 unsigned BitWidth = LC->getAPInt().getBitWidth();
10132 unsigned NewWidth = BitWidth + 1;
10133 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10134 << BitWidth << '\n');
10135 // The sign-extension (as opposed to a zero-extension) here matches the
10136 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10137 N = N.sext(NewWidth);
10138 M = M.sext(NewWidth);
10139 L = L.sext(NewWidth);
10140
10141 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10142 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10143 // L+M, L+2M+N, L+3M+3N, ...
10144 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10145 //
10146 // The equation Acc = 0 is then
10147 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10148 // In a quadratic form it becomes:
10149 // N n^2 + (2M-N) n + 2L = 0.
10150
10151 APInt A = N;
10152 APInt B = 2 * M - A;
10153 APInt C = 2 * L;
10154 APInt T = APInt(NewWidth, 2);
10155 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10156 << "x + " << C << ", coeff bw: " << NewWidth
10157 << ", multiplied by " << T << '\n');
10158 return std::make_tuple(A, B, C, T, BitWidth);
10159}
10160
10161/// Helper function to compare optional APInts:
10162/// (a) if X and Y both exist, return min(X, Y),
10163/// (b) if neither X nor Y exist, return std::nullopt,
10164/// (c) if exactly one of X and Y exists, return that value.
10165static std::optional<APInt> MinOptional(std::optional<APInt> X,
10166 std::optional<APInt> Y) {
10167 if (X && Y) {
10168 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10169 APInt XW = X->sext(W);
10170 APInt YW = Y->sext(W);
10171 return XW.slt(YW) ? *X : *Y;
10172 }
10173 if (!X && !Y)
10174 return std::nullopt;
10175 return X ? *X : *Y;
10176}
10177
10178/// Helper function to truncate an optional APInt to a given BitWidth.
10179/// When solving addrec-related equations, it is preferable to return a value
10180/// that has the same bit width as the original addrec's coefficients. If the
10181/// solution fits in the original bit width, truncate it (except for i1).
10182/// Returning a value of a different bit width may inhibit some optimizations.
10183///
10184/// In general, a solution to a quadratic equation generated from an addrec
10185/// may require BW+1 bits, where BW is the bit width of the addrec's
10186/// coefficients. The reason is that the coefficients of the quadratic
10187/// equation are BW+1 bits wide (to avoid truncation when converting from
10188/// the addrec to the equation).
10189static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10190 unsigned BitWidth) {
10191 if (!X)
10192 return std::nullopt;
10193 unsigned W = X->getBitWidth();
10194 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10195 return X->trunc(BitWidth);
10196 return X;
10197}
10198
10199/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10200/// iterations. The values L, M, N are assumed to be signed, and they
10201/// should all have the same bit widths.
10202/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10203/// where BW is the bit width of the addrec's coefficients.
10204/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10205/// returned as such, otherwise the bit width of the returned value may
10206/// be greater than BW.
10207///
10208/// This function returns std::nullopt if
10209/// (a) the addrec coefficients are not constant, or
10210/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10211/// like x^2 = 5, no integer solutions exist, in other cases an integer
10212/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10213static std::optional<APInt>
10214SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10215 APInt A, B, C, M;
10216 unsigned BitWidth;
10217 auto T = GetQuadraticEquation(AddRec);
10218 if (!T)
10219 return std::nullopt;
10220
10221 std::tie(A, B, C, M, BitWidth) = *T;
10222 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10223 std::optional<APInt> X =
10224 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10225 if (!X)
10226 return std::nullopt;
10227
10228 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10229 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10230 if (!V->isZero())
10231 return std::nullopt;
10232
10233 return TruncIfPossible(X, BitWidth);
10234}
10235
10236/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10237/// iterations. The values M, N are assumed to be signed, and they
10238/// should all have the same bit widths.
10239/// Find the least n such that c(n) does not belong to the given range,
10240/// while c(n-1) does.
10241///
10242/// This function returns std::nullopt if
10243/// (a) the addrec coefficients are not constant, or
10244/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10245/// bounds of the range.
10246static std::optional<APInt>
10247SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10248 const ConstantRange &Range, ScalarEvolution &SE) {
10249 assert(AddRec->getOperand(0)->isZero() &&
10250 "Starting value of addrec should be 0");
10251 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10252 << Range << ", addrec " << *AddRec << '\n');
10253 // This case is handled in getNumIterationsInRange. Here we can assume that
10254 // we start in the range.
10255 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10256 "Addrec's initial value should be in range");
10257
10258 APInt A, B, C, M;
10259 unsigned BitWidth;
10260 auto T = GetQuadraticEquation(AddRec);
10261 if (!T)
10262 return std::nullopt;
10263
10264 // Be careful about the return value: there can be two reasons for not
10265 // returning an actual number. First, if no solutions to the equations
10266 // were found, and second, if the solutions don't leave the given range.
10267 // The first case means that the actual solution is "unknown", the second
10268 // means that it's known, but not valid. If the solution is unknown, we
10269 // cannot make any conclusions.
10270 // Return a pair: the optional solution and a flag indicating if the
10271 // solution was found.
10272 auto SolveForBoundary =
10273 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10274 // Solve for signed overflow and unsigned overflow, pick the lower
10275 // solution.
10276 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10277 << Bound << " (before multiplying by " << M << ")\n");
10278 Bound *= M; // The quadratic equation multiplier.
10279
10280 std::optional<APInt> SO;
10281 if (BitWidth > 1) {
10282 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10283 "signed overflow\n");
10284 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10285 }
10286 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10287 "unsigned overflow\n");
10288 std::optional<APInt> UO =
10289 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10290
10291 auto LeavesRange = [&] (const APInt &X) {
10292 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10293 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10294 if (Range.contains(V0->getValue()))
10295 return false;
10296 // X should be at least 1, so X-1 is non-negative.
10297 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10298 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10299 if (Range.contains(V1->getValue()))
10300 return true;
10301 return false;
10302 };
10303
10304 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10305 // can be a solution, but the function failed to find it. We cannot treat it
10306 // as "no solution".
10307 if (!SO || !UO)
10308 return {std::nullopt, false};
10309
10310 // Check the smaller value first to see if it leaves the range.
10311 // At this point, both SO and UO must have values.
10312 std::optional<APInt> Min = MinOptional(SO, UO);
10313 if (LeavesRange(*Min))
10314 return { Min, true };
10315 std::optional<APInt> Max = Min == SO ? UO : SO;
10316 if (LeavesRange(*Max))
10317 return { Max, true };
10318
10319 // Solutions were found, but were eliminated, hence the "true".
10320 return {std::nullopt, true};
10321 };
10322
10323 std::tie(A, B, C, M, BitWidth) = *T;
10324 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10325 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10326 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10327 auto SL = SolveForBoundary(Lower);
10328 auto SU = SolveForBoundary(Upper);
10329 // If any of the solutions was unknown, no meaningful conclusions can
10330 // be made.
10331 if (!SL.second || !SU.second)
10332 return std::nullopt;
10333
10334 // Claim: The correct solution is not some value between Min and Max.
10335 //
10336 // Justification: Assuming that Min and Max are different values, one of
10337 // them is when the first signed overflow happens, the other is when the
10338 // first unsigned overflow happens. Crossing the range boundary is only
10339 // possible via an overflow (treating 0 as a special case of it, modeling
10340 // an overflow as crossing k*2^W for some k).
10341 //
10342 // The interesting case here is when Min was eliminated as an invalid
10343 // solution, but Max was not. The argument is that if there was another
10344 // overflow between Min and Max, it would also have been eliminated if
10345 // it was considered.
10346 //
10347 // For a given boundary, it is possible to have two overflows of the same
10348 // type (signed/unsigned) without having the other type in between: this
10349 // can happen when the vertex of the parabola is between the iterations
10350 // corresponding to the overflows. This is only possible when the two
10351 // overflows cross k*2^W for the same k. In such case, if the second one
10352 // left the range (and was the first one to do so), the first overflow
10353 // would have to enter the range, which would mean that either we had left
10354 // the range before or that we started outside of it. Both of these cases
10355 // are contradictions.
10356 //
10357 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10358 // solution is not some value between the Max for this boundary and the
10359 // Min of the other boundary.
10360 //
10361 // Justification: Assume that we had such Max_A and Min_B corresponding
10362 // to range boundaries A and B and such that Max_A < Min_B. If there was
10363 // a solution between Max_A and Min_B, it would have to be caused by an
10364 // overflow corresponding to either A or B. It cannot correspond to B,
10365 // since Min_B is the first occurrence of such an overflow. If it
10366 // corresponded to A, it would have to be either a signed or an unsigned
10367 // overflow that is larger than both eliminated overflows for A. But
10368 // between the eliminated overflows and this overflow, the values would
10369 // cover the entire value space, thus crossing the other boundary, which
10370 // is a contradiction.
10371
10372 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10373}
10374
10375ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10376 const Loop *L,
10377 bool ControlsOnlyExit,
10378 bool AllowPredicates) {
10379
10380 // This is only used for loops with a "x != y" exit test. The exit condition
10381 // is now expressed as a single expression, V = x-y. So the exit test is
10382 // effectively V != 0. We know and take advantage of the fact that this
10383 // expression is only used in a comparison-with-zero context.
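// For example, a loop exiting on "iv != n" with iv = {0,+,1} and a
// loop-invariant n reaches here with V = {(-n),+,1}, and the logic below
// computes an exact backedge-taken count of n.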
10384
10386 // If the value is a constant
10387 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10388 // If the value is already zero, the branch will execute zero times.
10389 if (C->getValue()->isZero()) return C;
10390 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10391 }
10392
10393 const SCEVAddRecExpr *AddRec =
10394 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10395
10396 if (!AddRec && AllowPredicates)
10397 // Try to make this an AddRec using runtime tests, in the first X
10398 // iterations of this loop, where X is the SCEV expression found by the
10399 // algorithm below.
10400 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10401
10402 if (!AddRec || AddRec->getLoop() != L)
10403 return getCouldNotCompute();
10404
10405 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10406 // the quadratic equation to solve it.
10407 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10408 // We can only use this value if the chrec ends up with an exact zero
10409 // value at this index. When solving for "X*X != 5", for example, we
10410 // should not accept a root of 2.
10411 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10412 const auto *R = cast<SCEVConstant>(getConstant(*S));
10413 return ExitLimit(R, R, R, false, Predicates);
10414 }
10415 return getCouldNotCompute();
10416 }
10417
10418 // Otherwise we can only handle this if it is affine.
10419 if (!AddRec->isAffine())
10420 return getCouldNotCompute();
10421
10422 // If this is an affine expression, the execution count of this branch is
10423 // the minimum unsigned root of the following equation:
10424 //
10425 // Start + Step*N = 0 (mod 2^BW)
10426 //
10427 // equivalent to:
10428 //
10429 // Step*N = -Start (mod 2^BW)
10430 //
10431 // where BW is the common bit width of Start and Step.
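// For example, with i8 operands, Start = 250 and Step = 3 give the equation
// 3*N == -250 == 6 (mod 256); its minimum unsigned solution is N = 2, and
// indeed the recurrence 250, 253, 0 reaches zero after two backedges.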
10432
10433 // Get the initial value for the loop.
10434 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10435 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10436
10437 // For now we handle only constant steps.
10438 //
10439 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
10440 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
10441 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
10442 // We have not yet seen any such cases.
10443 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10444 if (!StepC || StepC->getValue()->isZero())
10445 return getCouldNotCompute();
10446
10447 // For positive steps (counting up until unsigned overflow):
10448 // N = -Start/Step (as unsigned)
10449 // For negative steps (counting down to zero):
10450 // N = Start/-Step
10451 // First compute the unsigned distance from zero in the direction of Step.
10452 bool CountDown = StepC->getAPInt().isNegative();
10453 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10454
10455 // Handle unitary steps, which cannot wraparound.
10456 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10457 // N = Distance (as unsigned)
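// For example, {8,+,-1} counts 8, 7, ..., 1, 0, so a loop exiting on
// "iv != 0" takes its backedge exactly 8 times: Distance == Start == 8.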
10458 if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
10459 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10460 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10461
10462 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10463 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10464 // case, and see if we can improve the bound.
10465 //
10466 // Explicitly handling this here is necessary because getUnsignedRange
10467 // isn't context-sensitive; it doesn't know that we only care about the
10468 // range inside the loop.
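// For instance, if Distance is n - 1 for an i32 counter and the loop is only
// entered when n != 0, then Distance + 1 == n cannot be zero on entry, and
// the bound below improves to unsigned_max(n) - 1 (at most UINT32_MAX - 1)
// instead of UINT32_MAX.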
10469 const SCEV *Zero = getZero(Distance->getType());
10470 const SCEV *One = getOne(Distance->getType());
10471 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10472 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10473 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10474 // as "unsigned_max(Distance + 1) - 1".
10475 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10476 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10477 }
10478 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10479 Predicates);
10480 }
10481
10482 // If the condition controls loop exit (the loop exits only if the expression
10483 // is true) and the addition is no-wrap we can use unsigned divide to
10484 // compute the backedge count. In this case, the step may not divide the
10485 // distance, but we don't care because if the condition is "missed" the loop
10486 // will have undefined behavior due to wrapping.
10487 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10488 loopHasNoAbnormalExits(AddRec->getLoop())) {
10489 const SCEV *Exact =
10490 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10491 const SCEV *ConstantMax = getCouldNotCompute();
10492 if (Exact != getCouldNotCompute()) {
10493 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10494 ConstantMax =
10495 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10496 }
10497 const SCEV *SymbolicMax =
10498 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10499 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10500 }
10501
10502 // Solve the general equation.
10503 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10504 getNegativeSCEV(Start), *this);
10505
10506 const SCEV *M = E;
10507 if (E != getCouldNotCompute()) {
10508 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10509 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10510 }
10511 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10512 return ExitLimit(E, M, S, false, Predicates);
10513}
10514
10515ScalarEvolution::ExitLimit
10516ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10517 // Loops that look like: while (X == 0) are very strange indeed. We don't
10518 // handle them yet except for the trivial case. This could be expanded in the
10519 // future as needed.
10520
10521 // If the value is a constant, check to see if it is known to be non-zero
10522 // already. If so, the backedge will execute zero times.
10523 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10524 if (!C->getValue()->isZero())
10525 return getZero(C->getType());
10526 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10527 }
10528
10529 // We could implement others, but I really doubt anyone writes loops like
10530 // this, and if they did, they would already be constant folded.
10531 return getCouldNotCompute();
10532}
10533
10534std::pair<const BasicBlock *, const BasicBlock *>
10535ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10536 const {
10537 // If the block has a unique predecessor, then there is no path from the
10538 // predecessor to the block that does not go through the direct edge
10539 // from the predecessor to the block.
10540 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10541 return {Pred, BB};
10542
10543 // A loop's header is defined to be a block that dominates the loop.
10544 // If the header has a unique predecessor outside the loop, it must be
10545 // a block that has exactly one successor that can reach the loop.
10546 if (const Loop *L = LI.getLoopFor(BB))
10547 return {L->getLoopPredecessor(), L->getHeader()};
10548
10549 return {nullptr, nullptr};
10550}
10551
10552/// SCEV structural equivalence is usually sufficient for testing whether two
10553/// expressions are equal, however for the purposes of looking for a condition
10554/// guarding a loop, it can be useful to be a little more general, since a
10555/// front-end may have replicated the controlling expression.
10556static bool HasSameValue(const SCEV *A, const SCEV *B) {
10557 // Quick check to see if they are the same SCEV.
10558 if (A == B) return true;
10559
10560 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10561 // Not all instructions that are "identical" compute the same value. For
10562 // instance, two distinct alloca instructions allocating the same type are
10563 // identical and do not read memory, but compute distinct values.
10564 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10565 };
10566
10567 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10568 // two different instructions with the same value. Check for this case.
10569 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10570 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10571 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10572 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10573 if (ComputesEqualValues(AI, BI))
10574 return true;
10575
10576 // Otherwise assume they may have a different value.
10577 return false;
10578}
10579
10580static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10581 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10582 if (!Add || Add->getNumOperands() != 2)
10583 return false;
10584 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10585 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10586 LHS = Add->getOperand(1);
10587 RHS = ME->getOperand(1);
10588 return true;
10589 }
10590 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10591 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10592 LHS = Add->getOperand(0);
10593 RHS = ME->getOperand(1);
10594 return true;
10595 }
10596 return false;
10597}
10598
10599bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10600 const SCEV *&LHS, const SCEV *&RHS,
10601 unsigned Depth) {
10602 bool Changed = false;
10603 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10604 // '0 != 0'.
10605 auto TrivialCase = [&](bool TriviallyTrue) {
10606 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10607 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10608 return true;
10609 };
10610 // If we hit the max recursion limit bail out.
10611 if (Depth >= 3)
10612 return false;
10613
10614 // Canonicalize a constant to the right side.
10615 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10616 // Check for both operands constant.
10617 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10618 if (ConstantExpr::getICmp(Pred,
10619 LHSC->getValue(),
10620 RHSC->getValue())->isNullValue())
10621 return TrivialCase(false);
10622 return TrivialCase(true);
10623 }
10624 // Otherwise swap the operands to put the constant on the right.
10625 std::swap(LHS, RHS);
10626 Pred = ICmpInst::getSwappedPredicate(Pred);
10627 Changed = true;
10628 }
10629
10630 // If we're comparing an addrec with a value which is loop-invariant in the
10631 // addrec's loop, put the addrec on the left. Also make a dominance check,
10632 // as both operands could be addrecs loop-invariant in each other's loop.
10633 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10634 const Loop *L = AR->getLoop();
10635 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10636 std::swap(LHS, RHS);
10637 Pred = ICmpInst::getSwappedPredicate(Pred);
10638 Changed = true;
10639 }
10640 }
10641
10642 // If there's a constant operand, canonicalize comparisons with boundary
10643 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10644 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10645 const APInt &RA = RC->getAPInt();
10646
10647 bool SimplifiedByConstantRange = false;
10648
10649 if (!ICmpInst::isEquality(Pred)) {
10650 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10651 if (ExactCR.isFullSet())
10652 return TrivialCase(true);
10653 if (ExactCR.isEmptySet())
10654 return TrivialCase(false);
10655
10656 APInt NewRHS;
10657 CmpInst::Predicate NewPred;
10658 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10659 ICmpInst::isEquality(NewPred)) {
10660 // We were able to convert an inequality to an equality.
10661 Pred = NewPred;
10662 RHS = getConstant(NewRHS);
10663 Changed = SimplifiedByConstantRange = true;
10664 }
10665 }
10666
10667 if (!SimplifiedByConstantRange) {
10668 switch (Pred) {
10669 default:
10670 break;
10671 case ICmpInst::ICMP_EQ:
10672 case ICmpInst::ICMP_NE:
10673 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10674 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10675 Changed = true;
10676 break;
10677
10678 // The "Should have been caught earlier!" messages refer to the fact
10679 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10680 // should have fired on the corresponding cases, and canonicalized the
10681 // check to trivial case.
10682
10683 case ICmpInst::ICMP_UGE:
10684 assert(!RA.isMinValue() && "Should have been caught earlier!");
10685 Pred = ICmpInst::ICMP_UGT;
10686 RHS = getConstant(RA - 1);
10687 Changed = true;
10688 break;
10689 case ICmpInst::ICMP_ULE:
10690 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10691 Pred = ICmpInst::ICMP_ULT;
10692 RHS = getConstant(RA + 1);
10693 Changed = true;
10694 break;
10695 case ICmpInst::ICMP_SGE:
10696 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10697 Pred = ICmpInst::ICMP_SGT;
10698 RHS = getConstant(RA - 1);
10699 Changed = true;
10700 break;
10701 case ICmpInst::ICMP_SLE:
10702 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10703 Pred = ICmpInst::ICMP_SLT;
10704 RHS = getConstant(RA + 1);
10705 Changed = true;
10706 break;
10707 }
10708 }
10709 }
10710
10711 // Check for obvious equality.
10712 if (HasSameValue(LHS, RHS)) {
10713 if (ICmpInst::isTrueWhenEqual(Pred))
10714 return TrivialCase(true);
10715 if (ICmpInst::isFalseWhenEqual(Pred))
10716 return TrivialCase(false);
10717 }
10718
10719 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10720 // adding or subtracting 1 from one of the operands.
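// For example, "X <=s 7" becomes "X <s 8" (legal because 7 is not the
// maximum signed value for the type), and "X >=u 5" becomes "X >u 4"
// (legal because 5 is not the minimum unsigned value).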
10721 switch (Pred) {
10722 case ICmpInst::ICMP_SLE:
10723 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10724 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10725 SCEV::FlagNSW);
10726 Pred = ICmpInst::ICMP_SLT;
10727 Changed = true;
10728 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10729 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10730 SCEV::FlagNSW);
10731 Pred = ICmpInst::ICMP_SLT;
10732 Changed = true;
10733 }
10734 break;
10735 case ICmpInst::ICMP_SGE:
10736 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10737 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10738 SCEV::FlagNSW);
10739 Pred = ICmpInst::ICMP_SGT;
10740 Changed = true;
10741 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10742 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10743 SCEV::FlagNSW);
10744 Pred = ICmpInst::ICMP_SGT;
10745 Changed = true;
10746 }
10747 break;
10748 case ICmpInst::ICMP_ULE:
10749 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10750 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10751 SCEV::FlagNUW);
10752 Pred = ICmpInst::ICMP_ULT;
10753 Changed = true;
10754 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10755 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10756 Pred = ICmpInst::ICMP_ULT;
10757 Changed = true;
10758 }
10759 break;
10760 case ICmpInst::ICMP_UGE:
10761 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10762 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10763 Pred = ICmpInst::ICMP_UGT;
10764 Changed = true;
10765 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10766 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10767 SCEV::FlagNUW);
10768 Pred = ICmpInst::ICMP_UGT;
10769 Changed = true;
10770 }
10771 break;
10772 default:
10773 break;
10774 }
10775
10776 // TODO: More simplifications are possible here.
10777
10778 // Recursively simplify until we either hit a recursion limit or nothing
10779 // changes.
10780 if (Changed)
10781 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10782
10783 return Changed;
10784}
10785
10786bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10787 return getSignedRangeMax(S).isNegative();
10788}
10789
10790bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10791 return getSignedRangeMin(S).isStrictlyPositive();
10792}
10793
10794bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10795 return !getSignedRangeMin(S).isNegative();
10796}
10797
10798bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10799 return !getSignedRangeMax(S).isStrictlyPositive();
10800}
10801
10802bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10803 // Query push down for cases where the unsigned range is
10804 // less than sufficient.
10805 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10806 return isKnownNonZero(SExt->getOperand(0));
10807 return getUnsignedRangeMin(S) != 0;
10808}
10809
10810std::pair<const SCEV *, const SCEV *>
10811ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10812 // Compute SCEV on entry of loop L.
10813 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10814 if (Start == getCouldNotCompute())
10815 return { Start, Start };
10816 // Compute post increment SCEV for loop L.
10817 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10818 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10819 return { Start, PostInc };
10820}
10821
10822bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10823 const SCEV *LHS, const SCEV *RHS) {
10824 // First collect all loops.
10825 SmallPtrSet<const Loop *, 8> LoopsUsed;
10826 getUsedLoops(LHS, LoopsUsed);
10827 getUsedLoops(RHS, LoopsUsed);
10828
10829 if (LoopsUsed.empty())
10830 return false;
10831
10832 // Domination relationship must be a linear order on collected loops.
10833#ifndef NDEBUG
10834 for (const auto *L1 : LoopsUsed)
10835 for (const auto *L2 : LoopsUsed)
10836 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10837 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10838 "Domination relationship is not a linear order");
10839#endif
10840
10841 const Loop *MDL =
10842 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10843 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10844 });
10845
10846 // Get init and post increment value for LHS.
10847 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10849 // If LHS contains an unknown non-invariant SCEV, bail out.
10849 if (SplitLHS.first == getCouldNotCompute())
10850 return false;
10851 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10852 // Get init and post increment value for RHS.
10853 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10855 // If RHS contains an unknown non-invariant SCEV, bail out.
10855 if (SplitRHS.first == getCouldNotCompute())
10856 return false;
10857 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10858 // It is possible that init SCEV contains an invariant load but it does
10859 // not dominate MDL and is not available at MDL loop entry, so we should
10860 // check it here.
10861 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10862 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10863 return false;
10864
10865 // The backedge guard check tends to be faster than the entry check, so
10866 // checking it first can short-circuit the whole estimation in some cases.
10867 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10868 SplitRHS.second) &&
10869 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10870}
10871
10872bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10873 const SCEV *LHS, const SCEV *RHS) {
10874 // Canonicalize the inputs first.
10875 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10876
10877 if (isKnownViaInduction(Pred, LHS, RHS))
10878 return true;
10879
10880 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10881 return true;
10882
10883 // Otherwise see what can be done with some simple reasoning.
10884 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10885}
10886
10887std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10888 const SCEV *LHS,
10889 const SCEV *RHS) {
10890 if (isKnownPredicate(Pred, LHS, RHS))
10891 return true;
10893 return false;
10894 return std::nullopt;
10895}
10896
10897bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10898 const SCEV *LHS, const SCEV *RHS,
10899 const Instruction *CtxI) {
10900 // TODO: Analyze guards and assumes from Context's block.
10901 return isKnownPredicate(Pred, LHS, RHS) ||
10902 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10903}
10904
10905std::optional<bool>
10906ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10907 const SCEV *RHS, const Instruction *CtxI) {
10908 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10909 if (KnownWithoutContext)
10910 return KnownWithoutContext;
10911
10912 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10913 return true;
10914 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10915 ICmpInst::getInversePredicate(Pred),
10916 LHS, RHS))
10917 return false;
10918 return std::nullopt;
10919}
10920
10921bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10922 const SCEVAddRecExpr *LHS,
10923 const SCEV *RHS) {
10924 const Loop *L = LHS->getLoop();
10925 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10926 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10927}
10928
10929std::optional<ScalarEvolution::MonotonicPredicateType>
10930ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10931 ICmpInst::Predicate Pred) {
10932 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10933
10934#ifndef NDEBUG
10935 // Verify an invariant: inverting the predicate should turn a monotonically
10936 // increasing change to a monotonically decreasing one, and vice versa.
10937 if (Result) {
10938 auto ResultSwapped =
10939 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10940
10941 assert(*ResultSwapped != *Result &&
10942 "monotonicity should flip as we flip the predicate");
10943 }
10944#endif
10945
10946 return Result;
10947}
10948
10949std::optional<ScalarEvolution::MonotonicPredicateType>
10950ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
10951 ICmpInst::Predicate Pred) {
10952 // A zero step value for LHS means the induction variable is essentially a
10953 // loop invariant value. We don't really depend on the predicate actually
10954 // flipping from false to true (for increasing predicates, and the other way
10955 // around for decreasing predicates), all we care about is that *if* the
10956 // predicate changes then it only changes from false to true.
10957 //
10958 // A zero step value in itself is not very useful, but there may be places
10959 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
10960 // as general as possible.
10961
10962 // Only handle LE/LT/GE/GT predicates.
10963 if (!ICmpInst::isRelational(Pred))
10964 return std::nullopt;
10965
10966 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
10967 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
10968 "Should be greater or less!");
10969
10970 // Check that AR does not wrap.
10971 if (ICmpInst::isUnsigned(Pred)) {
10972 if (!LHS->hasNoUnsignedWrap())
10973 return std::nullopt;
10974 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10975 }
10976 assert(ICmpInst::isSigned(Pred) &&
10977 "Relational predicate is either signed or unsigned!");
10978 if (!LHS->hasNoSignedWrap())
10979 return std::nullopt;
10980
10981 const SCEV *Step = LHS->getStepRecurrence(*this);
10982
10983 if (isKnownNonNegative(Step))
10984 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10985
10986 if (isKnownNonPositive(Step))
10987 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10988
10989 return std::nullopt;
10990}
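// Worked examples (illustrative only, not part of the original source) for the
// classification above, assuming AR = {0,+,1}<nuw> so the step is known
// non-negative:
//   Pred = ICMP_UGT: "AR u> N" can only switch from false to true as the IV
//   grows, so the result is MonotonicallyIncreasing.
//   Pred = ICMP_ULT: "AR u< N" can only switch from true to false, so the
//   result is MonotonicallyDecreasing.
// A known non-positive step flips the two cases.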
10991
10992std::optional<ScalarEvolution::LoopInvariantPredicate>
10993ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
10994 const SCEV *LHS, const SCEV *RHS,
10995 const Loop *L,
10996 const Instruction *CtxI) {
10997 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10998 if (!isLoopInvariant(RHS, L)) {
10999 if (!isLoopInvariant(LHS, L))
11000 return std::nullopt;
11001
11002 std::swap(LHS, RHS);
11003 Pred = ICmpInst::getSwappedPredicate(Pred);
11004 }
11005
11006 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11007 if (!ArLHS || ArLHS->getLoop() != L)
11008 return std::nullopt;
11009
11010 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11011 if (!MonotonicType)
11012 return std::nullopt;
11013 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11014 // true as the loop iterates, and the backedge is control dependent on
11015 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11016 //
11017 // * if the predicate was false in the first iteration then the predicate
11018 // is never evaluated again, since the loop exits without taking the
11019 // backedge.
11020 // * if the predicate was true in the first iteration then it will
11021 // continue to be true for all future iterations since it is
11022 // monotonically increasing.
11023 //
11024 // For both the above possibilities, we can replace the loop varying
11025 // predicate with its value on the first iteration of the loop (which is
11026 // loop invariant).
11027 //
11028 // A similar reasoning applies for a monotonically decreasing predicate, by
11029 // replacing true with false and false with true in the above two bullets.
11030 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11031 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11032
11033 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11034 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11035 RHS);
11036
11037 if (!CtxI)
11038 return std::nullopt;
11039 // Try to prove via context.
11040 // TODO: Support other cases.
11041 switch (Pred) {
11042 default:
11043 break;
11044 case ICmpInst::ICMP_ULE:
11045 case ICmpInst::ICMP_ULT: {
11046 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11047 // Given preconditions
11048 // (1) ArLHS does not cross the border of positive and negative parts of
11049 // range because of:
11050 // - Positive step; (TODO: lift this limitation)
11051 // - nuw - does not cross zero boundary;
11052 // - nsw - does not cross SINT_MAX boundary;
11053 // (2) ArLHS <s RHS
11054 // (3) RHS >=s 0
11055 // we can replace the loop variant ArLHS <u RHS condition with loop
11056 // invariant Start(ArLHS) <u RHS.
11057 //
11058 // Because of (1) there are two options:
11059 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11060 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11061 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11062 // Because of (2) ArLHS <u RHS is trivially true.
11063 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11064 // We can strengthen this to Start(ArLHS) <u RHS.
11065 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11066 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11067 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11068 isKnownNonNegative(RHS) &&
11069 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11070 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11071 RHS);
11072 }
11073 }
11074
11075 return std::nullopt;
11076}
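// Worked example (illustrative only, not part of the original source): with
// ArLHS = {X,+,1}<nuw><nsw>, Pred = ICMP_ULT and a non-negative loop-invariant
// RHS, a context in which "ArLHS s< RHS" holds lets the loop-varying check
// "ArLHS u< RHS" be replaced by the loop-invariant check "X u< RHS", which is
// what the returned LoopInvariantPredicate encodes.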
11077
11078std::optional<ScalarEvolution::LoopInvariantPredicate>
11079ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11080 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11081 const Instruction *CtxI, const SCEV *MaxIter) {
11082 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11083 Pred, LHS, RHS, L, CtxI, MaxIter))
11084 return LIP;
11085 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11086 // Number of iterations expressed as UMIN isn't always great for expressing
11087 // the value on the last iteration. If the straightforward approach didn't
11088 // work, try the following trick: if a predicate is invariant for X, it
11089 // is also invariant for umin(X, ...). So try to find something that works
11090 // among subexpressions of MaxIter expressed as umin.
11091 for (auto *Op : UMin->operands())
11092 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11093 Pred, LHS, RHS, L, CtxI, Op))
11094 return LIP;
11095 return std::nullopt;
11096}
11097
11098std::optional<ScalarEvolution::LoopInvariantPredicate>
11099ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11100 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11101 const Instruction *CtxI, const SCEV *MaxIter) {
11102 // Try to prove the following set of facts:
11103 // - The predicate is monotonic in the iteration space.
11104 // - If the check does not fail on the 1st iteration:
11105 // - No overflow will happen during first MaxIter iterations;
11106 // - It will not fail on the MaxIter'th iteration.
11107 // If the check does fail on the 1st iteration, we leave the loop and no
11108 // other checks matter.
11109
11110 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11111 if (!isLoopInvariant(RHS, L)) {
11112 if (!isLoopInvariant(LHS, L))
11113 return std::nullopt;
11114
11115 std::swap(LHS, RHS);
11116 Pred = ICmpInst::getSwappedPredicate(Pred);
11117 }
11118
11119 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11120 if (!AR || AR->getLoop() != L)
11121 return std::nullopt;
11122
11123 // The predicate must be relational (i.e. <, <=, >=, >).
11124 if (!ICmpInst::isRelational(Pred))
11125 return std::nullopt;
11126
11127 // TODO: Support steps other than +/- 1.
11128 const SCEV *Step = AR->getStepRecurrence(*this);
11129 auto *One = getOne(Step->getType());
11130 auto *MinusOne = getNegativeSCEV(One);
11131 if (Step != One && Step != MinusOne)
11132 return std::nullopt;
11133
11134 // Type mismatch here means that MaxIter is potentially larger than max
11135 // unsigned value in start type, which means we cannot prove no wrap for the
11136 // indvar.
11137 if (AR->getType() != MaxIter->getType())
11138 return std::nullopt;
11139
11140 // Value of IV on suggested last iteration.
11141 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11142 // Does it still meet the requirement?
11143 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11144 return std::nullopt;
11145 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11146 // not exceed max unsigned value of this type), this effectively proves
11147 // that there is no wrap during the iteration. To prove that there is no
11148 // signed/unsigned wrap, we need to check that
11149 // Start <= Last for step = 1 or Start >= Last for step = -1.
11150 ICmpInst::Predicate NoOverflowPred =
11151 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11152 if (Step == MinusOne)
11153 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11154 const SCEV *Start = AR->getStart();
11155 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11156 return std::nullopt;
11157
11158 // Everything is fine.
11159 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11160}
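// Worked example (illustrative only, not part of the original source): with
// AR = {Start,+,1}, Pred = ICMP_SLT and MaxIter equal to the backedge-taken
// count, the code above computes Last = Start + MaxIter, requires the backedge
// to be guarded by "Last s< RHS", and then proves "Start s<= Last" to rule out
// wrap; if all of that holds, the exit check is invariant during the first
// MaxIter iterations and equal to "Start s< RHS".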
11161
11162bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11163 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11164 if (HasSameValue(LHS, RHS))
11165 return ICmpInst::isTrueWhenEqual(Pred);
11166
11167 // This code is split out from isKnownPredicate because it is called from
11168 // within isLoopEntryGuardedByCond.
11169
11170 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11171 const ConstantRange &RangeRHS) {
11172 return RangeLHS.icmp(Pred, RangeRHS);
11173 };
11174
11175 // The check at the top of the function catches the case where the values are
11176 // known to be equal.
11177 if (Pred == CmpInst::ICMP_EQ)
11178 return false;
11179
11180 if (Pred == CmpInst::ICMP_NE) {
11181 auto SL = getSignedRange(LHS);
11182 auto SR = getSignedRange(RHS);
11183 if (CheckRanges(SL, SR))
11184 return true;
11185 auto UL = getUnsignedRange(LHS);
11186 auto UR = getUnsignedRange(RHS);
11187 if (CheckRanges(UL, UR))
11188 return true;
11189 auto *Diff = getMinusSCEV(LHS, RHS);
11190 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11191 }
11192
11193 if (CmpInst::isSigned(Pred)) {
11194 auto SL = getSignedRange(LHS);
11195 auto SR = getSignedRange(RHS);
11196 return CheckRanges(SL, SR);
11197 }
11198
11199 auto UL = getUnsignedRange(LHS);
11200 auto UR = getUnsignedRange(RHS);
11201 return CheckRanges(UL, UR);
11202}
11203
11204bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11205 const SCEV *LHS,
11206 const SCEV *RHS) {
11207 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11208 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11209 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11210 // OutC1 and OutC2.
11211 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11212 APInt &OutC1, APInt &OutC2,
11213 SCEV::NoWrapFlags ExpectedFlags) {
11214 const SCEV *XNonConstOp, *XConstOp;
11215 const SCEV *YNonConstOp, *YConstOp;
11216 SCEV::NoWrapFlags XFlagsPresent;
11217 SCEV::NoWrapFlags YFlagsPresent;
11218
11219 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11220 XConstOp = getZero(X->getType());
11221 XNonConstOp = X;
11222 XFlagsPresent = ExpectedFlags;
11223 }
11224 if (!isa<SCEVConstant>(XConstOp) ||
11225 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11226 return false;
11227
11228 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11229 YConstOp = getZero(Y->getType());
11230 YNonConstOp = Y;
11231 YFlagsPresent = ExpectedFlags;
11232 }
11233
11234 if (!isa<SCEVConstant>(YConstOp) ||
11235 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11236 return false;
11237
11238 if (YNonConstOp != XNonConstOp)
11239 return false;
11240
11241 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11242 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11243
11244 return true;
11245 };
11246
11247 APInt C1;
11248 APInt C2;
11249
11250 switch (Pred) {
11251 default:
11252 break;
11253
11254 case ICmpInst::ICMP_SGE:
11255 std::swap(LHS, RHS);
11256 [[fallthrough]];
11257 case ICmpInst::ICMP_SLE:
11258 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11259 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11260 return true;
11261
11262 break;
11263
11264 case ICmpInst::ICMP_SGT:
11265 std::swap(LHS, RHS);
11266 [[fallthrough]];
11267 case ICmpInst::ICMP_SLT:
11268 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11269 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11270 return true;
11271
11272 break;
11273
11274 case ICmpInst::ICMP_UGE:
11275 std::swap(LHS, RHS);
11276 [[fallthrough]];
11277 case ICmpInst::ICMP_ULE:
11278 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11279 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
11280 return true;
11281
11282 break;
11283
11284 case ICmpInst::ICMP_UGT:
11285 std::swap(LHS, RHS);
11286 [[fallthrough]];
11287 case ICmpInst::ICMP_ULT:
11288 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11289 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
11290 return true;
11291 break;
11292 }
11293
11294 return false;
11295}
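// Worked examples (illustrative only, not part of the original source) of what
// the matcher above proves from constant offsets and wrap flags alone:
//   (A + 2)<nsw> s<= (A + 7)<nsw>   because 2 s<= 7
//   A u< (A + 1)<nuw>               treating the bare A as (A + 0)<nuw>
// No range computation is involved in either case.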
11296
11297bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11298 const SCEV *LHS,
11299 const SCEV *RHS) {
11300 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11301 return false;
11302
11303 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11304 // the stack can result in exponential time complexity.
11305 SaveAndRestore Restore(ProvingSplitPredicate, true);
11306
11307 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11308 //
11309 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11310 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11311 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11312 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11313 // use isKnownPredicate later if needed.
11314 return isKnownNonNegative(RHS) &&
11315 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11316 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11317}
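// Worked example (illustrative only, not part of the original source): to
// prove I u< L with L known non-negative, the split above asks for I s>= 0 and
// I s< L instead; e.g. for I = {0,+,1}<nsw> bounded by a signed trip count L,
// the two signed facts are often provable even when the unsigned ranges alone
// are not.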
11318
11319bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11320 ICmpInst::Predicate Pred,
11321 const SCEV *LHS, const SCEV *RHS) {
11322 // No need to even try if we know the module has no guards.
11323 if (!HasGuards)
11324 return false;
11325
11326 return any_of(*BB, [&](const Instruction &I) {
11327 using namespace llvm::PatternMatch;
11328
11329 Value *Condition;
11330 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11331 m_Value(Condition))) &&
11332 isImpliedCond(Pred, LHS, RHS, Condition, false);
11333 });
11334}
11335
11336/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11337/// protected by a conditional between LHS and RHS. This is used to
11338/// eliminate casts.
11339bool
11340ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11341 ICmpInst::Predicate Pred,
11342 const SCEV *LHS, const SCEV *RHS) {
11343 // Interpret a null as meaning no loop, where there is obviously no guard
11344 // (interprocedural conditions notwithstanding). Do not bother about
11345 // unreachable loops.
11346 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11347 return true;
11348
11349 if (VerifyIR)
11350 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11351 "This cannot be done on broken IR!");
11352
11353
11354 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11355 return true;
11356
11357 BasicBlock *Latch = L->getLoopLatch();
11358 if (!Latch)
11359 return false;
11360
11361 BranchInst *LoopContinuePredicate =
11362 dyn_cast<BranchInst>(Latch->getTerminator());
11363 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11364 isImpliedCond(Pred, LHS, RHS,
11365 LoopContinuePredicate->getCondition(),
11366 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11367 return true;
11368
11369 // We don't want more than one activation of the following loops on the stack
11370 // -- that can lead to O(n!) time complexity.
11371 if (WalkingBEDominatingConds)
11372 return false;
11373
11374 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11375
11376 // See if we can exploit a trip count to prove the predicate.
11377 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11378 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11379 if (LatchBECount != getCouldNotCompute()) {
11380 // We know that Latch branches back to the loop header exactly
11381 // LatchBECount times. This means the backedge condition at Latch is
11382 // equivalent to "{0,+,1} u< LatchBECount".
11383 Type *Ty = LatchBECount->getType();
11384 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11385 const SCEV *LoopCounter =
11386 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11387 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11388 LatchBECount))
11389 return true;
11390 }
11391
11392 // Check conditions due to any @llvm.assume intrinsics.
11393 for (auto &AssumeVH : AC.assumptions()) {
11394 if (!AssumeVH)
11395 continue;
11396 auto *CI = cast<CallInst>(AssumeVH);
11397 if (!DT.dominates(CI, Latch->getTerminator()))
11398 continue;
11399
11400 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11401 return true;
11402 }
11403
11404 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11405 return true;
11406
11407 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11408 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11409 assert(DTN && "should reach the loop header before reaching the root!");
11410
11411 BasicBlock *BB = DTN->getBlock();
11412 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11413 return true;
11414
11415 BasicBlock *PBB = BB->getSinglePredecessor();
11416 if (!PBB)
11417 continue;
11418
11419 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11420 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11421 continue;
11422
11423 Value *Condition = ContinuePredicate->getCondition();
11424
11425 // If we have an edge `E` within the loop body that dominates the only
11426 // latch, the condition guarding `E` also guards the backedge. This
11427 // reasoning works only for loops with a single latch.
11428
11429 BasicBlockEdge DominatingEdge(PBB, BB);
11430 if (DominatingEdge.isSingleEdge()) {
11431 // We're constructively (and conservatively) enumerating edges within the
11432 // loop body that dominate the latch. The dominator tree better agree
11433 // with us on this:
11434 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11435
11436 if (isImpliedCond(Pred, LHS, RHS, Condition,
11437 BB != ContinuePredicate->getSuccessor(0)))
11438 return true;
11439 }
11440 }
11441
11442 return false;
11443}
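// Worked example (illustrative only, not part of the original source) of the
// trip-count trick above: if the latch is known to branch back exactly 9
// times, the backedge condition is equivalent to "{0,+,1} u< 9", so any
// comparison implied by that fact (for instance "{0,+,1} u< 100") is
// established without inspecting the branch condition at the latch at all.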
11444
11445bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11446 ICmpInst::Predicate Pred,
11447 const SCEV *LHS,
11448 const SCEV *RHS) {
11449 // Do not bother proving facts for unreachable code.
11450 if (!DT.isReachableFromEntry(BB))
11451 return true;
11452 if (VerifyIR)
11453 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11454 "This cannot be done on broken IR!");
11455
11456 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11457 // the facts (a >= b && a != b) separately. A typical situation is when the
11458 // non-strict comparison is known from ranges and non-equality is known from
11459 // dominating predicates. If we are proving strict comparison, we always try
11460 // to prove non-equality and non-strict comparison separately.
11461 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11462 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11463 bool ProvedNonStrictComparison = false;
11464 bool ProvedNonEquality = false;
11465
11466 auto SplitAndProve =
11467 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11468 if (!ProvedNonStrictComparison)
11469 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11470 if (!ProvedNonEquality)
11471 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11472 if (ProvedNonStrictComparison && ProvedNonEquality)
11473 return true;
11474 return false;
11475 };
11476
11477 if (ProvingStrictComparison) {
11478 auto ProofFn = [&](ICmpInst::Predicate P) {
11479 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11480 };
11481 if (SplitAndProve(ProofFn))
11482 return true;
11483 }
11484
11485 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11486 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11487 const Instruction *CtxI = &BB->front();
11488 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11489 return true;
11490 if (ProvingStrictComparison) {
11491 auto ProofFn = [&](ICmpInst::Predicate P) {
11492 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11493 };
11494 if (SplitAndProve(ProofFn))
11495 return true;
11496 }
11497 return false;
11498 };
11499
11500 // Starting at the block's predecessor, climb up the predecessor chain, as long
11501 // as there are predecessors that can be found that have unique successors
11502 // leading to the original block.
11503 const Loop *ContainingLoop = LI.getLoopFor(BB);
11504 const BasicBlock *PredBB;
11505 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11506 PredBB = ContainingLoop->getLoopPredecessor();
11507 else
11508 PredBB = BB->getSinglePredecessor();
11509 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11510 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11511 const BranchInst *BlockEntryPredicate =
11512 dyn_cast<BranchInst>(Pair.first->getTerminator());
11513 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11514 continue;
11515
11516 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11517 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11518 return true;
11519 }
11520
11521 // Check conditions due to any @llvm.assume intrinsics.
11522 for (auto &AssumeVH : AC.assumptions()) {
11523 if (!AssumeVH)
11524 continue;
11525 auto *CI = cast<CallInst>(AssumeVH);
11526 if (!DT.dominates(CI, BB))
11527 continue;
11528
11529 if (ProveViaCond(CI->getArgOperand(0), false))
11530 return true;
11531 }
11532
11533 // Check conditions due to any @llvm.experimental.guard intrinsics.
11534 auto *GuardDecl = F.getParent()->getFunction(
11535 Intrinsic::getName(Intrinsic::experimental_guard));
11536 if (GuardDecl)
11537 for (const auto *GU : GuardDecl->users())
11538 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11539 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11540 if (ProveViaCond(Guard->getArgOperand(0), false))
11541 return true;
11542 return false;
11543}
11544
11545bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11546 ICmpInst::Predicate Pred,
11547 const SCEV *LHS,
11548 const SCEV *RHS) {
11549 // Interpret a null as meaning no loop, where there is obviously no guard
11550 // (interprocedural conditions notwithstanding).
11551 if (!L)
11552 return false;
11553
11554 // Both LHS and RHS must be available at loop entry.
11556 "LHS is not available at Loop Entry");
11558 "RHS is not available at Loop Entry");
11559
11560 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11561 return true;
11562
11563 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11564}
11565
11566bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11567 const SCEV *RHS,
11568 const Value *FoundCondValue, bool Inverse,
11569 const Instruction *CtxI) {
11570 // A false condition implies anything. Do not bother analyzing it further.
11571 if (FoundCondValue ==
11572 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11573 return true;
11574
11575 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11576 return false;
11577
11578 auto ClearOnExit =
11579 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11580
11581 // Recursively handle And and Or conditions.
11582 const Value *Op0, *Op1;
11583 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11584 if (!Inverse)
11585 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11586 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11587 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11588 if (Inverse)
11589 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11590 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11591 }
11592
11593 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11594 if (!ICI) return false;
11595
11596 // We have now found a conditional branch that dominates the loop or controls
11597 // the loop latch. Check to see if it is the comparison we are looking for.
11598 ICmpInst::Predicate FoundPred;
11599 if (Inverse)
11600 FoundPred = ICI->getInversePredicate();
11601 else
11602 FoundPred = ICI->getPredicate();
11603
11604 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11605 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11606
11607 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11608}
11609
11610bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11611 const SCEV *RHS,
11612 ICmpInst::Predicate FoundPred,
11613 const SCEV *FoundLHS, const SCEV *FoundRHS,
11614 const Instruction *CtxI) {
11615 // Balance the types.
11616 if (getTypeSizeInBits(LHS->getType()) <
11617 getTypeSizeInBits(FoundLHS->getType())) {
11618 // For unsigned and equality predicates, try to prove that both found
11619 // operands fit into narrow unsigned range. If so, try to prove facts in
11620 // narrow types.
11621 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11622 !FoundRHS->getType()->isPointerTy()) {
11623 auto *NarrowType = LHS->getType();
11624 auto *WideType = FoundLHS->getType();
11625 auto BitWidth = getTypeSizeInBits(NarrowType);
11626 const SCEV *MaxValue = getZeroExtendExpr(
11627 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11628 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11629 MaxValue) &&
11630 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11631 MaxValue)) {
11632 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11633 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11634 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11635 TruncFoundRHS, CtxI))
11636 return true;
11637 }
11638 }
11639
11640 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11641 return false;
11642 if (CmpInst::isSigned(Pred)) {
11643 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11644 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11645 } else {
11646 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11647 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11648 }
11649 } else if (getTypeSizeInBits(LHS->getType()) >
11650 getTypeSizeInBits(FoundLHS->getType())) {
11651 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11652 return false;
11653 if (CmpInst::isSigned(FoundPred)) {
11654 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11655 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11656 } else {
11657 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11658 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11659 }
11660 }
11661 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11662 FoundRHS, CtxI);
11663}
11664
11665bool ScalarEvolution::isImpliedCondBalancedTypes(
11666 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11667 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11668 const Instruction *CtxI) {
11669 assert(getTypeSizeInBits(LHS->getType()) ==
11670 getTypeSizeInBits(FoundLHS->getType()) &&
11671 "Types should be balanced!");
11672 // Canonicalize the query to match the way instcombine will have
11673 // canonicalized the comparison.
11674 if (SimplifyICmpOperands(Pred, LHS, RHS))
11675 if (LHS == RHS)
11676 return CmpInst::isTrueWhenEqual(Pred);
11677 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11678 if (FoundLHS == FoundRHS)
11679 return CmpInst::isFalseWhenEqual(FoundPred);
11680
11681 // Check to see if we can make the LHS or RHS match.
11682 if (LHS == FoundRHS || RHS == FoundLHS) {
11683 if (isa<SCEVConstant>(RHS)) {
11684 std::swap(FoundLHS, FoundRHS);
11685 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11686 } else {
11687 std::swap(LHS, RHS);
11688 Pred = ICmpInst::getSwappedPredicate(Pred);
11689 }
11690 }
11691
11692 // Check whether the found predicate is the same as the desired predicate.
11693 if (FoundPred == Pred)
11694 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11695
11696 // Check whether swapping the found predicate makes it the same as the
11697 // desired predicate.
11698 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11699 // We can write the implication
11700 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11701 // using one of the following ways:
11702 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11703 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11704 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11705 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11706 // Forms 1. and 2. require swapping the operands of one condition. Don't
11707 // do this if it would break canonical constant/addrec ordering.
11708 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11709 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11710 CtxI);
11711 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11712 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11713
11714 // There's no clear preference between forms 3. and 4., try both. Avoid
11715 // forming getNotSCEV of pointer values as the resulting subtract is
11716 // not legal.
11717 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11718 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11719 FoundLHS, FoundRHS, CtxI))
11720 return true;
11721
11722 if (!FoundLHS->getType()->isPointerTy() &&
11723 !FoundRHS->getType()->isPointerTy() &&
11724 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11725 getNotSCEV(FoundRHS), CtxI))
11726 return true;
11727
11728 return false;
11729 }
11730
11731 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11732 CmpInst::Predicate P2) {
11733 assert(P1 != P2 && "Handled earlier!");
11734 return CmpInst::isRelational(P2) &&
11735 CmpInst::getFlippedSignednessPredicate(P1) == P2;
11736 };
11737 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11738 // Unsigned comparison is the same as signed comparison when both the
11739 // operands are non-negative or negative.
11740 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11741 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11742 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11743 // Create local copies that we can freely swap and canonicalize our
11744 // conditions to "le/lt".
11745 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11746 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11747 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11748 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11749 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11750 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11751 std::swap(CanonicalLHS, CanonicalRHS);
11752 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11753 }
11754 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11755 "Must be!");
11756 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11757 ICmpInst::isLE(CanonicalFoundPred)) &&
11758 "Must be!");
11759 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11760 // Use implication:
11761 // x <u y && y >=s 0 --> x <s y.
11762 // If we can prove the left part, the right part is also proven.
11763 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11764 CanonicalRHS, CanonicalFoundLHS,
11765 CanonicalFoundRHS);
11766 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11767 // Use implication:
11768 // x <s y && y <s 0 --> x <u y.
11769 // If we can prove the left part, the right part is also proven.
11770 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11771 CanonicalRHS, CanonicalFoundLHS,
11772 CanonicalFoundRHS);
11773 }
11774
11775 // Check if we can make progress by sharpening ranges.
11776 if (FoundPred == ICmpInst::ICMP_NE &&
11777 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11778
11779 const SCEVConstant *C = nullptr;
11780 const SCEV *V = nullptr;
11781
11782 if (isa<SCEVConstant>(FoundLHS)) {
11783 C = cast<SCEVConstant>(FoundLHS);
11784 V = FoundRHS;
11785 } else {
11786 C = cast<SCEVConstant>(FoundRHS);
11787 V = FoundLHS;
11788 }
11789
11790 // The guarding predicate tells us that C != V. If the known range
11791 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11792 // range we consider has to correspond to same signedness as the
11793 // predicate we're interested in folding.
11794
11795 APInt Min = ICmpInst::isSigned(Pred) ?
11796 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11797
11798 if (Min == C->getAPInt()) {
11799 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11800 // This is true even if (Min + 1) wraps around -- in case of
11801 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11802
11803 APInt SharperMin = Min + 1;
11804
11805 switch (Pred) {
11806 case ICmpInst::ICMP_SGE:
11807 case ICmpInst::ICMP_UGE:
11808 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11809 // RHS, we're done.
11810 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11811 CtxI))
11812 return true;
11813 [[fallthrough]];
11814
11815 case ICmpInst::ICMP_SGT:
11816 case ICmpInst::ICMP_UGT:
11817 // We know from the range information that (V `Pred` Min ||
11818 // V == Min). We know from the guarding condition that !(V
11819 // == Min). This gives us
11820 //
11821 // V `Pred` Min || V == Min && !(V == Min)
11822 // => V `Pred` Min
11823 //
11824 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11825
11826 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11827 return true;
11828 break;
11829
11830 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11831 case ICmpInst::ICMP_SLE:
11832 case ICmpInst::ICMP_ULE:
11833 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11834 LHS, V, getConstant(SharperMin), CtxI))
11835 return true;
11836 [[fallthrough]];
11837
11838 case ICmpInst::ICMP_SLT:
11839 case ICmpInst::ICMP_ULT:
11840 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11841 LHS, V, getConstant(Min), CtxI))
11842 return true;
11843 break;
11844
11845 default:
11846 // No change
11847 break;
11848 }
11849 }
11850 }
11851
11852 // Check whether the actual condition is beyond sufficient.
11853 if (FoundPred == ICmpInst::ICMP_EQ)
11854 if (ICmpInst::isTrueWhenEqual(Pred))
11855 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11856 return true;
11857 if (Pred == ICmpInst::ICMP_NE)
11858 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11859 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11860 return true;
11861
11862 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11863 return true;
11864
11865 // Otherwise assume the worst.
11866 return false;
11867}
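// Worked example (illustrative only, not part of the original source) of the
// range-sharpening step above: if the dominating condition is "V != 0", the
// desired predicate is "V u>= 1", and the unsigned minimum of V's range is 0,
// then Min equals the excluded constant, so "V u>= Min" is sharpened to
// "V u>= SharperMin" with SharperMin = 1, which proves the goal.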
11868
11869bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11870 const SCEV *&L, const SCEV *&R,
11871 SCEV::NoWrapFlags &Flags) {
11872 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11873 if (!AE || AE->getNumOperands() != 2)
11874 return false;
11875
11876 L = AE->getOperand(0);
11877 R = AE->getOperand(1);
11878 Flags = AE->getNoWrapFlags();
11879 return true;
11880}
11881
11882std::optional<APInt>
11883ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11884 // We avoid subtracting expressions here because this function is usually
11885 // fairly deep in the call stack (i.e. is called many times).
11886
11887 // X - X = 0.
11888 if (More == Less)
11889 return APInt(getTypeSizeInBits(More->getType()), 0);
11890
11891 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11892 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11893 const auto *MAR = cast<SCEVAddRecExpr>(More);
11894
11895 if (LAR->getLoop() != MAR->getLoop())
11896 return std::nullopt;
11897
11898 // We look at affine expressions only; not for correctness but to keep
11899 // getStepRecurrence cheap.
11900 if (!LAR->isAffine() || !MAR->isAffine())
11901 return std::nullopt;
11902
11903 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11904 return std::nullopt;
11905
11906 Less = LAR->getStart();
11907 More = MAR->getStart();
11908
11909 // fall through
11910 }
11911
11912 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11913 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11914 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11915 return M - L;
11916 }
11917
11918 SCEV::NoWrapFlags Flags;
11919 const SCEV *LLess = nullptr, *RLess = nullptr;
11920 const SCEV *LMore = nullptr, *RMore = nullptr;
11921 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11922 // Compare (X + C1) vs X.
11923 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11924 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11925 if (RLess == More)
11926 return -(C1->getAPInt());
11927
11928 // Compare X vs (X + C2).
11929 if (splitBinaryAdd(More, LMore, RMore, Flags))
11930 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11931 if (RMore == Less)
11932 return C2->getAPInt();
11933
11934 // Compare (X + C1) vs (X + C2).
11935 if (C1 && C2 && RLess == RMore)
11936 return C2->getAPInt() - C1->getAPInt();
11937
11938 return std::nullopt;
11939}
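// Worked examples (illustrative only, not part of the original source) of the
// cheap difference computed above (More minus Less):
//   computeConstantDifference(X + 7, X)           == 7
//   computeConstantDifference(X, X + 7)           == -7
//   computeConstantDifference({5,+,S}, {2,+,S})   == 3   (same loop and step)
// Anything that would require building a new subtraction SCEV yields
// std::nullopt instead.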
11940
11941bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
11942 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11943 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11944 // Try to recognize the following pattern:
11945 //
11946 // FoundRHS = ...
11947 // ...
11948 // loop:
11949 // FoundLHS = {Start,+,W}
11950 // context_bb: // Basic block from the same loop
11951 // known(Pred, FoundLHS, FoundRHS)
11952 //
11953 // If some predicate is known in the context of a loop, it is also known on
11954 // each iteration of this loop, including the first iteration. Therefore, in
11955 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
11956 // prove the original pred using this fact.
11957 if (!CtxI)
11958 return false;
11959 const BasicBlock *ContextBB = CtxI->getParent();
11960 // Make sure AR varies in the context block.
11961 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
11962 const Loop *L = AR->getLoop();
11963 // Make sure that context belongs to the loop and executes on 1st iteration
11964 // (if it ever executes at all).
11965 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11966 return false;
11967 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
11968 return false;
11969 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
11970 }
11971
11972 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
11973 const Loop *L = AR->getLoop();
11974 // Make sure that context belongs to the loop and executes on 1st iteration
11975 // (if it ever executes at all).
11976 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11977 return false;
11978 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
11979 return false;
11980 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
11981 }
11982
11983 return false;
11984}
11985
11986bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
11987 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11988 const SCEV *FoundLHS, const SCEV *FoundRHS) {
11989 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
11990 return false;
11991
11992 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11993 if (!AddRecLHS)
11994 return false;
11995
11996 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
11997 if (!AddRecFoundLHS)
11998 return false;
11999
12000 // We'd like to let SCEV reason about control dependencies, so we constrain
12001 // both the inequalities to be about add recurrences on the same loop. This
12002 // way we can use isLoopEntryGuardedByCond later.
12003
12004 const Loop *L = AddRecFoundLHS->getLoop();
12005 if (L != AddRecLHS->getLoop())
12006 return false;
12007
12008 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12009 //
12010 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12011 // ... (2)
12012 //
12013 // Informal proof for (2), assuming (1) [*]:
12014 //
12015 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12016 //
12017 // Then
12018 //
12019 // FoundLHS s< FoundRHS s< INT_MIN - C
12020 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12021 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12022 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12023 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12024 // <=> FoundLHS + C s< FoundRHS + C
12025 //
12026 // [*]: (1) can be proved by ruling out overflow.
12027 //
12028 // [**]: This can be proved by analyzing all the four possibilities:
12029 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12030 // (A s>= 0, B s>= 0).
12031 //
12032 // Note:
12033 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12034 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12035 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12036 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12037 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12038 // C)".
12039
12040 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12041 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12042 if (!LDiff || !RDiff || *LDiff != *RDiff)
12043 return false;
12044
12045 if (LDiff->isMinValue())
12046 return true;
12047
12048 APInt FoundRHSLimit;
12049
12050 if (Pred == CmpInst::ICMP_ULT) {
12051 FoundRHSLimit = -(*RDiff);
12052 } else {
12053 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12054 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12055 }
12056
12057 // Try to prove (1) or (2), as needed.
12058 return isAvailableAtLoopEntry(FoundRHS, L) &&
12059 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12060 getConstant(FoundRHSLimit));
12061}
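// Worked example (illustrative only, not part of the original source) of rule
// (1) above: from "FoundLHS u< FoundRHS" and "FoundRHS u< -C" (so adding C
// cannot wrap), it follows that "FoundLHS + C u< FoundRHS + C". With C = 2 in
// an 8-bit type, "FoundRHS u< 254" is exactly the guard checked via
// isLoopEntryGuardedByCond before concluding "LHS u< RHS".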
12062
12063bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12064 const SCEV *LHS, const SCEV *RHS,
12065 const SCEV *FoundLHS,
12066 const SCEV *FoundRHS, unsigned Depth) {
12067 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12068
12069 auto ClearOnExit = make_scope_exit([&]() {
12070 if (LPhi) {
12071 bool Erased = PendingMerges.erase(LPhi);
12072 assert(Erased && "Failed to erase LPhi!");
12073 (void)Erased;
12074 }
12075 if (RPhi) {
12076 bool Erased = PendingMerges.erase(RPhi);
12077 assert(Erased && "Failed to erase RPhi!");
12078 (void)Erased;
12079 }
12080 });
12081
12082 // Find the respective Phis and check that they are not already pending.
12083 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12084 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12085 if (!PendingMerges.insert(Phi).second)
12086 return false;
12087 LPhi = Phi;
12088 }
12089 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12090 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12091 // If we detect a loop of Phi nodes being processed by this method, for
12092 // example:
12093 //
12094 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12095 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12096 //
12097 // we don't want to deal with a case that complex, so return conservative
12098 // answer false.
12099 if (!PendingMerges.insert(Phi).second)
12100 return false;
12101 RPhi = Phi;
12102 }
12103
12104 // If none of LHS, RHS is a Phi, nothing to do here.
12105 if (!LPhi && !RPhi)
12106 return false;
12107
12108 // If there is a SCEVUnknown Phi we are interested in, make it left.
12109 if (!LPhi) {
12110 std::swap(LHS, RHS);
12111 std::swap(FoundLHS, FoundRHS);
12112 std::swap(LPhi, RPhi);
12113 Pred = ICmpInst::getSwappedPredicate(Pred);
12114 }
12115
12116 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12117 const BasicBlock *LBB = LPhi->getParent();
12118 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12119
12120 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12121 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12122 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12123 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12124 };
12125
12126 if (RPhi && RPhi->getParent() == LBB) {
12127 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12128 // If we compare two Phis from the same block, and for each entry block
12129 // the predicate is true for incoming values from this block, then the
12130 // predicate is also true for the Phis.
12131 for (const BasicBlock *IncBB : predecessors(LBB)) {
12132 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12133 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12134 if (!ProvedEasily(L, R))
12135 return false;
12136 }
12137 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12138 // Case two: RHS is also a Phi from the same basic block, and it is an
12139 // AddRec. It means that there is a loop which has both AddRec and Unknown
12140 // PHIs, for it we can compare incoming values of AddRec from above the loop
12141 // and latch with their respective incoming values of LPhi.
12142 // TODO: Generalize to handle loops with many inputs in a header.
12143 if (LPhi->getNumIncomingValues() != 2) return false;
12144
12145 auto *RLoop = RAR->getLoop();
12146 auto *Predecessor = RLoop->getLoopPredecessor();
12147 assert(Predecessor && "Loop with AddRec with no predecessor?");
12148 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12149 if (!ProvedEasily(L1, RAR->getStart()))
12150 return false;
12151 auto *Latch = RLoop->getLoopLatch();
12152 assert(Latch && "Loop with AddRec with no latch?");
12153 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12154 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12155 return false;
12156 } else {
12157 // In all other cases go over inputs of LHS and compare each of them to RHS,
12158 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12159 // At this point RHS is either a non-Phi, or it is a Phi from some block
12160 // different from LBB.
12161 for (const BasicBlock *IncBB : predecessors(LBB)) {
12162 // Check that RHS is available in this block.
12163 if (!dominates(RHS, IncBB))
12164 return false;
12165 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12166 // Make sure L does not refer to a value from a potentially previous
12167 // iteration of a loop.
12168 if (!properlyDominates(L, LBB))
12169 return false;
12170 if (!ProvedEasily(L, RHS))
12171 return false;
12172 }
12173 }
12174 return true;
12175}
12176
12177bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12178 const SCEV *LHS,
12179 const SCEV *RHS,
12180 const SCEV *FoundLHS,
12181 const SCEV *FoundRHS) {
12182 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12183 // sure that we are dealing with same LHS.
12184 if (RHS == FoundRHS) {
12185 std::swap(LHS, RHS);
12186 std::swap(FoundLHS, FoundRHS);
12187 Pred = ICmpInst::getSwappedPredicate(Pred);
12188 }
12189 if (LHS != FoundLHS)
12190 return false;
12191
12192 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12193 if (!SUFoundRHS)
12194 return false;
12195
12196 Value *Shiftee, *ShiftValue;
12197
12198 using namespace PatternMatch;
12199 if (match(SUFoundRHS->getValue(),
12200 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12201 auto *ShifteeS = getSCEV(Shiftee);
12202 // Prove one of the following:
12203 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12204 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12205 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12206 // ---> LHS <s RHS
12207 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12208 // ---> LHS <=s RHS
12209 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12210 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12211 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12212 if (isKnownNonNegative(ShifteeS))
12213 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12214 }
12215
12216 return false;
12217}
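// Worked example (illustrative only, not part of the original source) of the
// shift rule above: knowing "LHS u< (X >> 3)" and proving "X u<= RHS" gives
// "LHS u< RHS", because a logical shift right never increases an unsigned
// value, so X >> 3 u<= X u<= RHS.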
12218
12219bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12220 const SCEV *LHS, const SCEV *RHS,
12221 const SCEV *FoundLHS,
12222 const SCEV *FoundRHS,
12223 const Instruction *CtxI) {
12224 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12225 return true;
12226
12227 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12228 return true;
12229
12230 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12231 return true;
12232
12233 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12234 CtxI))
12235 return true;
12236
12237 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12238 FoundLHS, FoundRHS);
12239}
12240
12241/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12242template <typename MinMaxExprType>
12243static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12244 const SCEV *Candidate) {
12245 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12246 if (!MinMaxExpr)
12247 return false;
12248
12249 return is_contained(MinMaxExpr->operands(), Candidate);
12250}
12251
12252static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12253 ICmpInst::Predicate Pred,
12254 const SCEV *LHS, const SCEV *RHS) {
12255 // If both sides are affine addrecs for the same loop, with equal
12256 // steps, and we know the recurrences don't wrap, then we only
12257 // need to check the predicate on the starting values.
12258
12259 if (!ICmpInst::isRelational(Pred))
12260 return false;
12261
12262 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12263 if (!LAR)
12264 return false;
12265 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12266 if (!RAR)
12267 return false;
12268 if (LAR->getLoop() != RAR->getLoop())
12269 return false;
12270 if (!LAR->isAffine() || !RAR->isAffine())
12271 return false;
12272
12273 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12274 return false;
12275
12276 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12277 SCEV::FlagNSW : SCEV::FlagNUW;
12278 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12279 return false;
12280
12281 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12282}
12283
12284/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12285/// expression?
12286static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12287 ICmpInst::Predicate Pred,
12288 const SCEV *LHS, const SCEV *RHS) {
12289 switch (Pred) {
12290 default:
12291 return false;
12292
12293 case ICmpInst::ICMP_SGE:
12294 std::swap(LHS, RHS);
12295 [[fallthrough]];
12296 case ICmpInst::ICMP_SLE:
12297 return
12298 // min(A, ...) <= A
12299 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12300 // A <= max(A, ...)
12301 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12302
12303 case ICmpInst::ICMP_UGE:
12304 std::swap(LHS, RHS);
12305 [[fallthrough]];
12306 case ICmpInst::ICMP_ULE:
12307 return
12308 // min(A, ...) <= A
12309 // FIXME: what about umin_seq?
12310 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12311 // A <= max(A, ...)
12312 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12313 }
12314
12315 llvm_unreachable("covered switch fell through?!");
12316}
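// Note (illustrative only, not part of the original source): the min/max facts
// used above are purely structural, e.g. umin(A, B) u<= A and A s<= smax(A, B)
// hold for arbitrary operands, so it is enough to find the candidate among the
// min/max operands; no range information is needed.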
12317
12318bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12319 const SCEV *LHS, const SCEV *RHS,
12320 const SCEV *FoundLHS,
12321 const SCEV *FoundRHS,
12322 unsigned Depth) {
12325 "LHS and RHS have different sizes?");
12326 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12327 getTypeSizeInBits(FoundRHS->getType()) &&
12328 "FoundLHS and FoundRHS have different sizes?");
12329 // We want to avoid hurting the compile time with analysis of too big trees.
12330 if (Depth > MaxSCEVOperationsImplicationDepth)
12331 return false;
12332
12333 // We only want to work with GT comparison so far.
12334 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12335 Pred = CmpInst::getSwappedPredicate(Pred);
12336 std::swap(LHS, RHS);
12337 std::swap(FoundLHS, FoundRHS);
12338 }
12339
12340 // For unsigned, try to reduce it to corresponding signed comparison.
12341 if (Pred == ICmpInst::ICMP_UGT)
12342 // We can replace unsigned predicate with its signed counterpart if all
12343 // involved values are non-negative.
12344 // TODO: We could have better support for unsigned.
12345 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12346 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12347 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12348 // use this fact to prove that LHS and RHS are non-negative.
12349 const SCEV *MinusOne = getMinusOne(LHS->getType());
12350 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12351 FoundRHS) &&
12352 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12353 FoundRHS))
12354 Pred = ICmpInst::ICMP_SGT;
12355 }
12356
12357 if (Pred != ICmpInst::ICMP_SGT)
12358 return false;
12359
12360 auto GetOpFromSExt = [&](const SCEV *S) {
12361 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12362 return Ext->getOperand();
12363 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12364 // the constant in some cases.
12365 return S;
12366 };
12367
12368 // Acquire values from extensions.
12369 auto *OrigLHS = LHS;
12370 auto *OrigFoundLHS = FoundLHS;
12371 LHS = GetOpFromSExt(LHS);
12372 FoundLHS = GetOpFromSExt(FoundLHS);
12373
12374 // Whether the SGT predicate can be proved trivially or using the found context.
12375 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12376 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12377 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12378 FoundRHS, Depth + 1);
12379 };
12380
12381 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12382 // We want to avoid creation of any new non-constant SCEV. Since we are
12383 // going to compare the operands to RHS, we should be certain that we don't
12384 // need any size extensions for this. So let's decline all cases when the
12385 // sizes of types of LHS and RHS do not match.
12386 // TODO: Maybe try to get RHS from sext to catch more cases?
12387 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12388 return false;
12389
12390 // Should not overflow.
12391 if (!LHSAddExpr->hasNoSignedWrap())
12392 return false;
12393
12394 auto *LL = LHSAddExpr->getOperand(0);
12395 auto *LR = LHSAddExpr->getOperand(1);
12396 auto *MinusOne = getMinusOne(RHS->getType());
12397
12398 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12399 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12400 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12401 };
12402 // Try to prove the following rule:
12403 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12404 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12405 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12406 return true;
12407 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12408 Value *LL, *LR;
12409 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12410
12411 using namespace llvm::PatternMatch;
12412
12413 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12414 // Rules for division.
12415 // We are going to perform some comparisons with Denominator and its
12416 // derivative expressions. In the general case, creating a SCEV for them may
12417 // trigger a complex analysis of the entire graph, and in particular it can
12418 // request a trip count recalculation for the same loop, which would then be
12419 // cached as SCEVCouldNotCompute to break the infinite recursion. To avoid
12420 // this, we only want to create SCEVs that are constants in this section.
12421 // So we bail if Denominator is not a constant.
12422 if (!isa<ConstantInt>(LR))
12423 return false;
12424
12425 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12426
12427 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12428 // then a SCEV for the numerator already exists and matches with FoundLHS.
12429 auto *Numerator = getExistingSCEV(LL);
12430 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12431 return false;
12432
12433 // Make sure that the numerator matches with FoundLHS and the denominator
12434 // is positive.
12435 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12436 return false;
12437
12438 auto *DTy = Denominator->getType();
12439 auto *FRHSTy = FoundRHS->getType();
12440 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12441 // One of types is a pointer and another one is not. We cannot extend
12442 // them properly to a wider type, so let us just reject this case.
12443 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12444 // to avoid this check.
12445 return false;
12446
12447 // Given that:
12448 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12449 auto *WTy = getWiderType(DTy, FRHSTy);
12450 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12451 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12452
12453 // Try to prove the following rule:
12454 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12455 // For example, if FoundRHS = 2, then FoundLHS > 2 means FoundLHS is at
12456 // least 3; dividing it by a Denominator < 4 leaves at least 1.
12457 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12458 if (isKnownNonPositive(RHS) &&
12459 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12460 return true;
12461
12462 // Try to prove the following rule:
12463 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12464 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12465 // If we divide it by Denominator > 2, then:
12466 // 1. If FoundLHS is negative, then the result is 0.
12467 // 2. If FoundLHS is non-negative, then the result is non-negative.
12468 // Either way, the result is non-negative.
12469 auto *MinusOne = getMinusOne(WTy);
12470 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12471 if (isKnownNegative(RHS) &&
12472 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12473 return true;
12474 }
12475 }
12476
12477 // If our expression contained SCEVUnknown Phis, and we split it down and now
12478 // need to prove something for them, try to prove the predicate for every
12479 // possible incoming value of those Phis.
12480 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12481 return true;
12482
12483 return false;
12484}
12485
12486bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12487                                                  const SCEV *LHS, const SCEV *RHS) {
12488 // zext x u<= sext x, sext x s<= zext x
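// For instance (illustrative): for an i8 value x = -1, zext(x) to i16 is 255
// while sext(x) is 0xFFFF (65535 unsigned, -1 signed), so zext(x) u<= sext(x)
// and sext(x) s<= zext(x) both hold; for non-negative x the two extensions are
// equal, so the idiom holds in general.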
12489 switch (Pred) {
12490 case ICmpInst::ICMP_SGE:
12491 std::swap(LHS, RHS);
12492 [[fallthrough]];
12493 case ICmpInst::ICMP_SLE: {
12494 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12495 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12496 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12497 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12498 return true;
12499 break;
12500 }
12501 case ICmpInst::ICMP_UGE:
12502 std::swap(LHS, RHS);
12503 [[fallthrough]];
12504 case ICmpInst::ICMP_ULE: {
12505 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12506 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12507 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12508 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12509 return true;
12510 break;
12511 }
12512 default:
12513 break;
12514 };
12515 return false;
12516}
12517
12518bool
12519ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12520 const SCEV *LHS, const SCEV *RHS) {
12521 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12522 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12523 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12524 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12525 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12526}
12527
12528bool
12529ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12530 const SCEV *LHS, const SCEV *RHS,
12531 const SCEV *FoundLHS,
12532 const SCEV *FoundRHS) {
12533 switch (Pred) {
12534 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12535 case ICmpInst::ICMP_EQ:
12536 case ICmpInst::ICMP_NE:
12537 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12538 return true;
12539 break;
12540 case ICmpInst::ICMP_SLT:
12541 case ICmpInst::ICMP_SLE:
12542 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12543 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12544 return true;
12545 break;
12546 case ICmpInst::ICMP_SGT:
12547 case ICmpInst::ICMP_SGE:
12548 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12549 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12550 return true;
12551 break;
12552 case ICmpInst::ICMP_ULT:
12553 case ICmpInst::ICMP_ULE:
12554 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12555 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12556 return true;
12557 break;
12558 case ICmpInst::ICMP_UGT:
12559 case ICmpInst::ICMP_UGE:
12560 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12561 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12562 return true;
12563 break;
12564 }
12565
12566 // Maybe it can be proved via operations?
12567 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12568 return true;
12569
12570 return false;
12571}
12572
12573bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12574 const SCEV *LHS,
12575 const SCEV *RHS,
12576 ICmpInst::Predicate FoundPred,
12577 const SCEV *FoundLHS,
12578 const SCEV *FoundRHS) {
12579 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12580 // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12581 // reduce the compile time impact of this optimization.
12582 return false;
12583
12584 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12585 if (!Addend)
12586 return false;
12587
12588 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12589
12590 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12591 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12592 ConstantRange FoundLHSRange =
12593 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12594
12595 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12596 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12597
12598 // We can also compute the range of values for `LHS` that satisfy the
12599 // consequent, "`LHS` `Pred` `RHS`":
12600 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12601 // The antecedent implies the consequent if every value of `LHS` that
12602 // satisfies the antecedent also satisfies the consequent.
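// Worked example (illustrative): if the antecedent is FoundLHS <u 10 and
// LHS = FoundLHS + 5 (so Addend = 5), then FoundLHSRange = [0, 10) and
// LHSRange = [5, 15); a consequent such as LHS <u 20 holds for every value in
// [5, 15), so the implication is established.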
12603 return LHSRange.icmp(Pred, ConstRHS);
12604}
12605
12606bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12607 bool IsSigned) {
12608 assert(isKnownPositive(Stride) && "Positive stride expected!");
12609
12610 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12611 const SCEV *One = getOne(Stride->getType());
12612
12613 if (IsSigned) {
12614 APInt MaxRHS = getSignedRangeMax(RHS);
12615 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12616 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12617
12618 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
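// Worked example (illustrative, i8 signed): if MaxRHS = 120 and the stride is
// 16 (so MaxStrideMinusOne = 15), then 127 - 15 = 112 <s 120, i.e. a value just
// below RHS plus (Stride - 1) can exceed SINT_MAX, so overflow is reported as
// possible.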
12619 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12620 }
12621
12622 APInt MaxRHS = getUnsignedRangeMax(RHS);
12623 APInt MaxValue = APInt::getMaxValue(BitWidth);
12624 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12625
12626 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12627 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12628}
12629
12630bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12631 bool IsSigned) {
12632
12633 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12634 const SCEV *One = getOne(Stride->getType());
12635
12636 if (IsSigned) {
12637 APInt MinRHS = getSignedRangeMin(RHS);
12638 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12639 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12640
12641 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12642 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12643 }
12644
12645 APInt MinRHS = getUnsignedRangeMin(RHS);
12646 APInt MinValue = APInt::getMinValue(BitWidth);
12647 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12648
12649 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12650 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12651}
12652
12653const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12654 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12655 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12656 // expression fixes the case of N=0.
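// For instance (illustrative): N = 7, D = 3 gives umin(7, 1) = 1 plus
// floor((7 - 1) / 3) = 2, i.e. 3 = ceil(7 / 3); N = 0 gives 0 + floor(0 / D) = 0.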
12657 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12658 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12659 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12660}
12661
12662const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12663 const SCEV *Stride,
12664 const SCEV *End,
12665 unsigned BitWidth,
12666 bool IsSigned) {
12667 // The logic in this function assumes we can represent a positive stride.
12668 // If we can't, the backedge-taken count must be zero.
12669 if (IsSigned && BitWidth == 1)
12670 return getZero(Stride->getType());
12671
12672 // The code below has only been closely audited for negative strides in the
12673 // unsigned comparison case; it may be correct for the signed comparison case,
12674 // but that needs to be established.
12675 if (IsSigned && isKnownNegative(Stride))
12676 return getCouldNotCompute();
12677
12678 // Calculate the maximum backedge count based on the range of values
12679 // permitted by Start, End, and Stride.
12680 APInt MinStart =
12681 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12682
12683 APInt MinStride =
12684 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12685
12686 // We assume either the stride is positive, or the backedge-taken count
12687 // is zero. So force StrideForMaxBECount to be at least one.
12688 APInt One(BitWidth, 1);
12689 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12690 : APIntOps::umax(One, MinStride);
12691
12692 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12693 : APInt::getMaxValue(BitWidth);
12694 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12695
12696 // Although End can be a MAX expression we estimate MaxEnd considering only
12697 // the case End = RHS of the loop termination condition. This is safe because
12698 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12699 // taken count.
12700 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12701 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12702
12703 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12704 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12705 : APIntOps::umax(MaxEnd, MinStart);
12706
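// Worked example (illustrative, unsigned i8): MinStart = 2, MaxEnd = 10 and
// StrideForMaxBECount = 3 yield ceil((10 - 2) / 3) = 3 as the maximum
// backedge-taken count.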
12707 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12708 getConstant(StrideForMaxBECount) /* Step */);
12709}
12710
12712ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12713 const Loop *L, bool IsSigned,
12714 bool ControlsOnlyExit, bool AllowPredicates) {
12715  SmallVector<const SCEVPredicate *, 4> Predicates;
12716
12717 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12718 bool PredicatedIV = false;
12719
12720 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12721 // Can we prove this loop *must* be UB if overflow of IV occurs?
12722 // Reasoning goes as follows:
12723 // * Suppose the IV did self wrap.
12724 // * If Stride evenly divides the iteration space, then once wrap
12725 // occurs, the loop must revisit the same values.
12726 // * We know that RHS is invariant, and that none of those values
12727 // caused this exit to be taken previously. Thus, this exit is
12728 // dynamically dead.
12729 // * If this is the sole exit, then a dead exit implies the loop
12730 // must be infinite if there are no abnormal exits.
12731 // * If the loop were infinite, then it must either not be mustprogress
12732 // or have side effects. Otherwise, it must be UB.
12733 // * It can't (by assumption), be UB so we have contradicted our
12734 // premise and can conclude the IV did not in fact self-wrap.
12735 if (!isLoopInvariant(RHS, L))
12736 return false;
12737
12738 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12739 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12740 return false;
12741
12742 if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
12743 return false;
12744
12745 return loopIsFiniteByAssumption(L);
12746 };
12747
12748 if (!IV) {
12749 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12750 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12751 if (AR && AR->getLoop() == L && AR->isAffine()) {
12752 auto canProveNUW = [&]() {
12753 // We can use the comparison to infer no-wrap flags only if it fully
12754 // controls the loop exit.
12755 if (!ControlsOnlyExit)
12756 return false;
12757
12758 if (!isLoopInvariant(RHS, L))
12759 return false;
12760
12761 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12762 // We need the sequence defined by AR to strictly increase in the
12763 // unsigned integer domain for the logic below to hold.
12764 return false;
12765
12766 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12767 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12768 // If RHS <=u Limit, then there must exist a value V in the sequence
12769 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12770 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12771 // overflow occurs. This limit also implies that a signed comparison
12772 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12773 // the high bits on both sides must be zero.
12774 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12775 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12776 Limit = Limit.zext(OuterBitWidth);
12777 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12778 };
12779 auto Flags = AR->getNoWrapFlags();
12780 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12781 Flags = setFlags(Flags, SCEV::FlagNUW);
12782
12783 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12784 if (AR->hasNoUnsignedWrap()) {
12785 // Emulate what getZeroExtendExpr would have done during construction
12786 // if we'd been able to infer the fact just above at that time.
12787 const SCEV *Step = AR->getStepRecurrence(*this);
12788 Type *Ty = ZExt->getType();
12789 auto *S = getAddRecExpr(
12790 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12791 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12792 IV = dyn_cast<SCEVAddRecExpr>(S);
12793 }
12794 }
12795 }
12796 }
12797
12798
12799 if (!IV && AllowPredicates) {
12800 // Try to make this an AddRec using runtime tests, in the first X
12801 // iterations of this loop, where X is the SCEV expression found by the
12802 // algorithm below.
12803 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12804 PredicatedIV = true;
12805 }
12806
12807 // Avoid weird loops
12808 if (!IV || IV->getLoop() != L || !IV->isAffine())
12809 return getCouldNotCompute();
12810
12811 // A precondition of this method is that the condition being analyzed
12812 // reaches an exiting branch which dominates the latch. Given that, we can
12813 // assume that an increment which violates the nowrap specification and
12814 // produces poison must cause undefined behavior when the resulting poison
12815 // value is branched upon and thus we can conclude that the backedge is
12816 // taken no more often than would be required to produce that poison value.
12817 // Note that a well defined loop can exit on the iteration which violates
12818 // the nowrap specification if there is another exit (either explicit or
12819 // implicit/exceptional) which causes the loop to execute before the
12820 // exiting instruction we're analyzing would trigger UB.
12821 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12822 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12823  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12824
12825 const SCEV *Stride = IV->getStepRecurrence(*this);
12826
12827 bool PositiveStride = isKnownPositive(Stride);
12828
12829 // Avoid negative or zero stride values.
12830 if (!PositiveStride) {
12831 // We can compute the correct backedge taken count for loops with unknown
12832 // strides if we can prove that the loop is not an infinite loop with side
12833 // effects. Here's the loop structure we are trying to handle -
12834 //
12835 // i = start
12836 // do {
12837 // A[i] = i;
12838 // i += s;
12839 // } while (i < end);
12840 //
12841 // The backedge taken count for such loops is evaluated as -
12842 // (max(end, start + stride) - start - 1) /u stride
12843 //
12844 // The additional preconditions that we need to check to prove correctness
12845 // of the above formula is as follows -
12846 //
12847 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12848 // NoWrap flag).
12849 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12850 // no side effects within the loop)
12851 // c) loop has a single static exit (with no abnormal exits)
12852 //
12853 // Precondition a) implies that if the stride is negative, this is a single
12854 // trip loop. The backedge taken count formula reduces to zero in this case.
12855 //
12856 // Preconditions b) and c) combine to imply that if RHS is invariant in L,
12857 // then a zero stride means the backedge can't be taken without executing
12858 // undefined behavior.
12859 //
12860 // The positive stride case is the same as isKnownPositive(Stride) returning
12861 // true (original behavior of the function).
12862 //
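// For instance (illustrative): with start = 0, stride = 3 and end = 10, the
// loop above takes the backedge when i reaches 3, 6 and 9, and
// (max(10, 0 + 3) - 0 - 1) /u 3 = 9 /u 3 = 3 matches that count.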
12863 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12864        !loopHasNoAbnormalExits(L))
12865      return getCouldNotCompute();
12866
12867 if (!isKnownNonZero(Stride)) {
12868 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12869 // if it might eventually be greater than start and if so, on which
12870 // iteration. We can't even produce a useful upper bound.
12871 if (!isLoopInvariant(RHS, L))
12872 return getCouldNotCompute();
12873
12874 // We allow a potentially zero stride, but we need to divide by stride
12875 // below. Since the loop can't be infinite and this check must control
12876 // the sole exit, we can infer the exit must be taken on the first
12877 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12878 // we know the numerator in the divides below must be zero, so we can
12879 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12880 // and produce the right result.
12881 // FIXME: Handle the case where Stride is poison?
12882 auto wouldZeroStrideBeUB = [&]() {
12883 // Proof by contradiction. Suppose the stride were zero. If we can
12884 // prove that the backedge *is* taken on the first iteration, then since
12885 // we know this condition controls the sole exit, we must have an
12886 // infinite loop. We can't have a (well defined) infinite loop per
12887 // check just above.
12888 // Note: The (Start - Stride) term is used to get the start' term from
12889 // (start' + stride,+,stride). Remember that we only care about the
12890 // result of this expression when stride == 0 at runtime.
12891 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12892 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12893 };
12894 if (!wouldZeroStrideBeUB()) {
12895 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12896 }
12897 }
12898 } else if (!Stride->isOne() && !NoWrap) {
12899 auto isUBOnWrap = [&]() {
12900      // From no-self-wrap, we then need to prove no-(un)signed-wrap. This
12901      // follows from the fact that every value that (un)signed-wraps, but does
12902      // not self-wrap, must be less than the last value before the (un)signed
12903      // wrap. Since we know that last value didn't cause an exit, neither will
12904      // any smaller one.
12905 return canAssumeNoSelfWrap(IV);
12906 };
12907
12908 // Avoid proven overflow cases: this will ensure that the backedge taken
12909 // count will not generate any unsigned overflow. Relaxed no-overflow
12910    // conditions exploit NoWrapFlags, allowing us to optimize in the presence
12911    // of undefined behavior, as in the C language.
12912 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12913 return getCouldNotCompute();
12914 }
12915
12916  // On all paths just preceding, we established the following invariant:
12917 // IV can be assumed not to overflow up to and including the exiting
12918 // iteration. We proved this in one of two ways:
12919 // 1) We can show overflow doesn't occur before the exiting iteration
12920  //      (1a) via canIVOverflowOnLT, or (1b) because the step is one
12921 // 2) We can show that if overflow occurs, the loop must execute UB
12922 // before any possible exit.
12923 // Note that we have not yet proved RHS invariant (in general).
12924
12925 const SCEV *Start = IV->getStart();
12926
12927 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12928 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12929 // Use integer-typed versions for actual computation; we can't subtract
12930 // pointers in general.
12931 const SCEV *OrigStart = Start;
12932 const SCEV *OrigRHS = RHS;
12933 if (Start->getType()->isPointerTy()) {
12934 Start = getLosslessPtrToIntExpr(Start);
12935 if (isa<SCEVCouldNotCompute>(Start))
12936 return Start;
12937 }
12938 if (RHS->getType()->isPointerTy()) {
12939    RHS = getLosslessPtrToIntExpr(RHS);
12940    if (isa<SCEVCouldNotCompute>(RHS))
12941 return RHS;
12942 }
12943
12944 // When the RHS is not invariant, we do not know the end bound of the loop and
12945 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
12946 // calculate the MaxBECount, given the start, stride and max value for the end
12947 // bound of the loop (RHS), and the fact that IV does not overflow (which is
12948 // checked above).
12949 if (!isLoopInvariant(RHS, L)) {
12950 const SCEV *MaxBECount = computeMaxBECountForLT(
12951 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12952 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
12953 MaxBECount, false /*MaxOrZero*/, Predicates);
12954 }
12955
12956 // We use the expression (max(End,Start)-Start)/Stride to describe the
12957 // backedge count, as if the backedge is taken at least once max(End,Start)
12958 // is End and so the result is as above, and if not max(End,Start) is Start
12959 // so we get a backedge count of zero.
12960 const SCEV *BECount = nullptr;
12961 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
12962 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
12963 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
12964 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
12965  // Can we prove max(RHS,Start) > Start - Stride?
12966 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
12967 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
12968 // In this case, we can use a refined formula for computing backedge taken
12969 // count. The general formula remains:
12970 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
12971 // We want to use the alternate formula:
12972 // "((End - 1) - (Start - Stride)) /u Stride"
12973 // Let's do a quick case analysis to show these are equivalent under
12974 // our precondition that max(RHS,Start) > Start - Stride.
12975 // * For RHS <= Start, the backedge-taken count must be zero.
12976 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12977 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
12978 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
12979 // of Stride. For 0 stride, we've use umin(1,Stride) above, reducing
12980 // this to the stride of 1 case.
12981 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
12982 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12983 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
12984 // "((RHS - (Start - Stride) - 1) /u Stride".
12985 // Our preconditions trivially imply no overflow in that form.
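// Worked example (illustrative): Start = 2, Stride = 3, RHS = End = 10 gives
// ((10 - 1) - (2 - 3)) /u 3 = 10 /u 3 = 3, the same count as
// ceil((10 - 2) / 3) from the general formula.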
12986 const SCEV *MinusOne = getMinusOne(Stride->getType());
12987 const SCEV *Numerator =
12988 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
12989 BECount = getUDivExpr(Numerator, Stride);
12990 }
12991
12992 const SCEV *BECountIfBackedgeTaken = nullptr;
12993 if (!BECount) {
12994 auto canProveRHSGreaterThanEqualStart = [&]() {
12995 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
12996 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
12997 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
12998
12999 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13000 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13001 return true;
13002
13003 // (RHS > Start - 1) implies RHS >= Start.
13004 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13005 // "Start - 1" doesn't overflow.
13006 // * For signed comparison, if Start - 1 does overflow, it's equal
13007 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13008 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13009 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13010 //
13011 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13012 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13013 auto *StartMinusOne = getAddExpr(OrigStart,
13014 getMinusOne(OrigStart->getType()));
13015 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13016 };
13017
13018 // If we know that RHS >= Start in the context of loop, then we know that
13019 // max(RHS, Start) = RHS at this point.
13020 const SCEV *End;
13021 if (canProveRHSGreaterThanEqualStart()) {
13022 End = RHS;
13023 } else {
13024 // If RHS < Start, the backedge will be taken zero times. So in
13025 // general, we can write the backedge-taken count as:
13026 //
13027 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13028 //
13029 // We convert it to the following to make it more convenient for SCEV:
13030 //
13031 // ceil(max(RHS, Start) - Start) / Stride
13032 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13033
13034 // See what would happen if we assume the backedge is taken. This is
13035 // used to compute MaxBECount.
13036 BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13037 }
13038
13039 // At this point, we know:
13040 //
13041 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13042 // 2. The index variable doesn't overflow.
13043 //
13044 // Therefore, we know N exists such that
13045 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13046 // doesn't overflow.
13047 //
13048 // Using this information, try to prove whether the addition in
13049 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13050 const SCEV *One = getOne(Stride->getType());
13051 bool MayAddOverflow = [&] {
13052 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13053 if (StrideC->getAPInt().isPowerOf2()) {
13054 // Suppose Stride is a power of two, and Start/End are unsigned
13055 // integers. Let UMAX be the largest representable unsigned
13056 // integer.
13057 //
13058 // By the preconditions of this function, we know
13059 // "(Start + Stride * N) >= End", and this doesn't overflow.
13060 // As a formula:
13061 //
13062 // End <= (Start + Stride * N) <= UMAX
13063 //
13064 // Subtracting Start from all the terms:
13065 //
13066 // End - Start <= Stride * N <= UMAX - Start
13067 //
13068 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13069 //
13070 // End - Start <= Stride * N <= UMAX
13071 //
13072 // Stride * N is a multiple of Stride. Therefore,
13073 //
13074 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13075 //
13076 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
13077 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
13078 //
13079        //   End - Start <= Stride * N <= UMAX - Stride + 1
13080 //
13081 // Dropping the middle term:
13082 //
13083        //   End - Start <= UMAX - Stride + 1
13084 //
13085 // Adding Stride - 1 to both sides:
13086 //
13087 // (End - Start) + (Stride - 1) <= UMAX
13088 //
13089 // In other words, the addition doesn't have unsigned overflow.
13090 //
13091 // A similar proof works if we treat Start/End as signed values.
13092 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
13093 // use signed max instead of unsigned max. Note that we're trying
13094 // to prove a lack of unsigned overflow in either case.
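// For instance (illustrative, i8): Stride = 4 gives UMAX mod Stride = 255 mod 4
// = 3 = Stride - 1, so End - Start <= 252, and adding Stride - 1 = 3 cannot
// exceed UMAX = 255.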
13095 return false;
13096 }
13097 }
13098 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13099 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
13100 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
13101 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
13102 //
13103 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
13104 return false;
13105 }
13106 return true;
13107 }();
13108
13109 const SCEV *Delta = getMinusSCEV(End, Start);
13110 if (!MayAddOverflow) {
13111 // floor((D + (S - 1)) / S)
13112 // We prefer this formulation if it's legal because it's fewer operations.
13113 BECount =
13114 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13115 } else {
13116 BECount = getUDivCeilSCEV(Delta, Stride);
13117 }
13118 }
13119
13120 const SCEV *ConstantMaxBECount;
13121 bool MaxOrZero = false;
13122 if (isa<SCEVConstant>(BECount)) {
13123 ConstantMaxBECount = BECount;
13124 } else if (BECountIfBackedgeTaken &&
13125 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13126 // If we know exactly how many times the backedge will be taken if it's
13127 // taken at least once, then the backedge count will either be that or
13128 // zero.
13129 ConstantMaxBECount = BECountIfBackedgeTaken;
13130 MaxOrZero = true;
13131 } else {
13132 ConstantMaxBECount = computeMaxBECountForLT(
13133 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13134 }
13135
13136 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13137 !isa<SCEVCouldNotCompute>(BECount))
13138 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13139
13140 const SCEV *SymbolicMaxBECount =
13141 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13142 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13143 Predicates);
13144}
13145
13146ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13147 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13148 bool ControlsOnlyExit, bool AllowPredicates) {
13149  SmallVector<const SCEVPredicate *, 4> Predicates;
13150  // We handle only IV > Invariant
13151 if (!isLoopInvariant(RHS, L))
13152 return getCouldNotCompute();
13153
13154 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13155 if (!IV && AllowPredicates)
13156 // Try to make this an AddRec using runtime tests, in the first X
13157 // iterations of this loop, where X is the SCEV expression found by the
13158 // algorithm below.
13159 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13160
13161 // Avoid weird loops
13162 if (!IV || IV->getLoop() != L || !IV->isAffine())
13163 return getCouldNotCompute();
13164
13165 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13166 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13167  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13168
13169 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13170
13171 // Avoid negative or zero stride values
13172 if (!isKnownPositive(Stride))
13173 return getCouldNotCompute();
13174
13175 // Avoid proven overflow cases: this will ensure that the backedge taken count
13176 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13177 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
13178 // behavior, as in the C language.
13179 if (!Stride->isOne() && !NoWrap)
13180 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13181 return getCouldNotCompute();
13182
13183 const SCEV *Start = IV->getStart();
13184 const SCEV *End = RHS;
13185 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13186 // If we know that Start >= RHS in the context of loop, then we know that
13187 // min(RHS, Start) = RHS at this point.
13188    if (isLoopEntryGuardedByCond(
13189            L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13190 End = RHS;
13191 else
13192 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13193 }
13194
13195 if (Start->getType()->isPointerTy()) {
13196 Start = getLosslessPtrToIntExpr(Start);
13197 if (isa<SCEVCouldNotCompute>(Start))
13198 return Start;
13199 }
13200 if (End->getType()->isPointerTy()) {
13201    End = getLosslessPtrToIntExpr(End);
13202    if (isa<SCEVCouldNotCompute>(End))
13203 return End;
13204 }
13205
13206 // Compute ((Start - End) + (Stride - 1)) / Stride.
13207 // FIXME: This can overflow. Holding off on fixing this for now;
13208 // howManyGreaterThans will hopefully be gone soon.
13209 const SCEV *One = getOne(Stride->getType());
13210 const SCEV *BECount = getUDivExpr(
13211 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13212
13213 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13214 : getUnsignedRangeMax(Start);
13215
13216 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13217 : getUnsignedRangeMin(Stride);
13218
13219 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13220 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13221 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13222
13223 // Although End can be a MIN expression we estimate MinEnd considering only
13224 // the case End = RHS. This is safe because in the other case (Start - End)
13225 // is zero, leading to a zero maximum backedge taken count.
13226 APInt MinEnd =
13227 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13228 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13229
13230 const SCEV *ConstantMaxBECount =
13231 isa<SCEVConstant>(BECount)
13232 ? BECount
13233 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13234 getConstant(MinStride));
13235
13236 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13237 ConstantMaxBECount = BECount;
13238 const SCEV *SymbolicMaxBECount =
13239 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13240
13241 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13242 Predicates);
13243}
13244
13245const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13246                                                    ScalarEvolution &SE) const {
13247 if (Range.isFullSet()) // Infinite loop.
13248 return SE.getCouldNotCompute();
13249
13250 // If the start is a non-zero constant, shift the range to simplify things.
13251 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13252 if (!SC->getValue()->isZero()) {
13253      SmallVector<const SCEV *, 4> Operands(operands());
13254      Operands[0] = SE.getZero(SC->getType());
13255 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13256 getNoWrapFlags(FlagNW));
13257 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13258 return ShiftedAddRec->getNumIterationsInRange(
13259 Range.subtract(SC->getAPInt()), SE);
13260 // This is strange and shouldn't happen.
13261 return SE.getCouldNotCompute();
13262 }
13263
13264 // The only time we can solve this is when we have all constant indices.
13265 // Otherwise, we cannot determine the overflow conditions.
13266 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13267 return SE.getCouldNotCompute();
13268
13269 // Okay at this point we know that all elements of the chrec are constants and
13270 // that the start element is zero.
13271
13272 // First check to see if the range contains zero. If not, the first
13273 // iteration exits.
13274 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13275 if (!Range.contains(APInt(BitWidth, 0)))
13276 return SE.getZero(getType());
13277
13278 if (isAffine()) {
13279 // If this is an affine expression then we have this situation:
13280 // Solve {0,+,A} in Range === Ax in Range
13281
13282 // We know that zero is in the range. If A is positive then we know that
13283 // the upper value of the range must be the first possible exit value.
13284 // If A is negative then the lower of the range is the last possible loop
13285 // value. Also note that we already checked for a full range.
13286 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13287 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13288
13289 // The exit value should be (End+A)/A.
13290 APInt ExitVal = (End + A).udiv(A);
13291 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13292
13293 // Evaluate at the exit value. If we really did fall out of the valid
13294 // range, then we computed our trip count, otherwise wrap around or other
13295 // things must have happened.
13296 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13297 if (Range.contains(Val->getValue()))
13298 return SE.getCouldNotCompute(); // Something strange happened
13299
13300 // Ensure that the previous value is in the range.
13301 assert(Range.contains(
13302           EvaluateConstantChrecAtConstant(this,
13303               ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13304 "Linear scev computation is off in a bad way!");
13305 return SE.getConstant(ExitValue);
13306 }
13307
13308 if (isQuadratic()) {
13309 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13310 return SE.getConstant(*S);
13311 }
13312
13313 return SE.getCouldNotCompute();
13314}
13315
13316const SCEVAddRecExpr *
13317SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13318  assert(getNumOperands() > 1 && "AddRec with zero step?");
13319 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13320 // but in this case we cannot guarantee that the value returned will be an
13321 // AddRec because SCEV does not have a fixed point where it stops
13322 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13323 // may happen if we reach arithmetic depth limit while simplifying. So we
13324 // construct the returned value explicitly.
13325  SmallVector<const SCEV *, 3> Ops;
13326  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13327 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13328 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13329 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13330 // We know that the last operand is not a constant zero (otherwise it would
13331 // have been popped out earlier). This guarantees us that if the result has
13332 // the same last operand, then it will also not be popped out, meaning that
13333 // the returned value will be an AddRec.
13334 const SCEV *Last = getOperand(getNumOperands() - 1);
13335 assert(!Last->isZero() && "Recurrency with zero step?");
13336 Ops.push_back(Last);
13337 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13338                                               SCEV::FlagAnyWrap));
13339}
13340
13341// Return true when S contains at least an undef value.
13342bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13343  return SCEVExprContains(S, [](const SCEV *S) {
13344 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13345 return isa<UndefValue>(SU->getValue());
13346 return false;
13347 });
13348}
13349
13350// Return true when S contains a value that is a nullptr.
13351bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13352  return SCEVExprContains(S, [](const SCEV *S) {
13353 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13354 return SU->getValue() == nullptr;
13355 return false;
13356 });
13357}
13358
13359/// Return the size of an element read or written by Inst.
13360const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13361  Type *Ty;
13362 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13363 Ty = Store->getValueOperand()->getType();
13364 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13365 Ty = Load->getType();
13366 else
13367 return nullptr;
13368
13370 return getSizeOfExpr(ETy, Ty);
13371}
13372
13373//===----------------------------------------------------------------------===//
13374// SCEVCallbackVH Class Implementation
13375//===----------------------------------------------------------------------===//
13376
13377void ScalarEvolution::SCEVCallbackVH::deleted() {
13378 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13379 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13380 SE->ConstantEvolutionLoopExitValue.erase(PN);
13381 SE->eraseValueFromMap(getValPtr());
13382 // this now dangles!
13383}
13384
13385void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13386 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13387
13388 // Forget all the expressions associated with users of the old value,
13389 // so that future queries will recompute the expressions using the new
13390 // value.
13391 SE->forgetValue(getValPtr());
13392 // this now dangles!
13393}
13394
13395ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13396 : CallbackVH(V), SE(se) {}
13397
13398//===----------------------------------------------------------------------===//
13399// ScalarEvolution Class Implementation
13400//===----------------------------------------------------------------------===//
13401
13404 LoopInfo &LI)
13405 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13406 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13407 LoopDispositions(64), BlockDispositions(64) {
13408 // To use guards for proving predicates, we need to scan every instruction in
13409 // relevant basic blocks, and not just terminators. Doing this is a waste of
13410 // time if the IR does not actually contain any calls to
13411 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13412 //
13413 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13414 // to _add_ guards to the module when there weren't any before, and wants
13415 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13416 // efficient in lieu of being smart in that rather obscure case.
13417
13418 auto *GuardDecl = F.getParent()->getFunction(
13419 Intrinsic::getName(Intrinsic::experimental_guard));
13420 HasGuards = GuardDecl && !GuardDecl->use_empty();
13421}
13422
13423ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13424    : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13425 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13426 ValueExprMap(std::move(Arg.ValueExprMap)),
13427 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13428 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13429 PendingMerges(std::move(Arg.PendingMerges)),
13430 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13431 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13432 PredicatedBackedgeTakenCounts(
13433 std::move(Arg.PredicatedBackedgeTakenCounts)),
13434 BECountUsers(std::move(Arg.BECountUsers)),
13435 ConstantEvolutionLoopExitValue(
13436 std::move(Arg.ConstantEvolutionLoopExitValue)),
13437 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13438 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13439 LoopDispositions(std::move(Arg.LoopDispositions)),
13440 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13441 BlockDispositions(std::move(Arg.BlockDispositions)),
13442 SCEVUsers(std::move(Arg.SCEVUsers)),
13443 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13444 SignedRanges(std::move(Arg.SignedRanges)),
13445 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13446 UniquePreds(std::move(Arg.UniquePreds)),
13447 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13448 LoopUsers(std::move(Arg.LoopUsers)),
13449 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13450 FirstUnknown(Arg.FirstUnknown) {
13451 Arg.FirstUnknown = nullptr;
13452}
13453
13454ScalarEvolution::~ScalarEvolution() {
13455  // Iterate through all the SCEVUnknown instances and call their
13456 // destructors, so that they release their references to their values.
13457 for (SCEVUnknown *U = FirstUnknown; U;) {
13458 SCEVUnknown *Tmp = U;
13459 U = U->Next;
13460 Tmp->~SCEVUnknown();
13461 }
13462 FirstUnknown = nullptr;
13463
13464 ExprValueMap.clear();
13465 ValueExprMap.clear();
13466 HasRecMap.clear();
13467 BackedgeTakenCounts.clear();
13468 PredicatedBackedgeTakenCounts.clear();
13469
13470 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13471 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13472 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13473 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13474 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13475}
13476
13477bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13478  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13479}
13480
13481/// When printing a top-level SCEV for trip counts, it's helpful to include
13482/// a type for constants which are otherwise hard to disambiguate.
13483static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13484 if (isa<SCEVConstant>(S))
13485 OS << *S->getType() << " ";
13486 OS << *S;
13487}
13488
13489static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13490                          const Loop *L) {
13491 // Print all inner loops first
13492 for (Loop *I : *L)
13493 PrintLoopInfo(OS, SE, I);
13494
13495 OS << "Loop ";
13496 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13497 OS << ": ";
13498
13499 SmallVector<BasicBlock *, 8> ExitingBlocks;
13500 L->getExitingBlocks(ExitingBlocks);
13501 if (ExitingBlocks.size() != 1)
13502 OS << "<multiple exits> ";
13503
13504 auto *BTC = SE->getBackedgeTakenCount(L);
13505 if (!isa<SCEVCouldNotCompute>(BTC)) {
13506 OS << "backedge-taken count is ";
13507    PrintSCEVWithTypeHint(OS, BTC);
13508  } else
13509 OS << "Unpredictable backedge-taken count.";
13510 OS << "\n";
13511
13512 if (ExitingBlocks.size() > 1)
13513 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13514 OS << " exit count for " << ExitingBlock->getName() << ": ";
13515 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13516 OS << "\n";
13517 }
13518
13519 OS << "Loop ";
13520 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13521 OS << ": ";
13522
13523 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13524 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13525 OS << "constant max backedge-taken count is ";
13526 PrintSCEVWithTypeHint(OS, ConstantBTC);
13527    if (SE->isBackedgeTakenCountMaxOrZero(L))
13528      OS << ", actual taken count either this or zero.";
13529 } else {
13530 OS << "Unpredictable constant max backedge-taken count. ";
13531 }
13532
13533 OS << "\n"
13534 "Loop ";
13535 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13536 OS << ": ";
13537
13538 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13539 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13540 OS << "symbolic max backedge-taken count is ";
13541 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13542    if (SE->isBackedgeTakenCountMaxOrZero(L))
13543      OS << ", actual taken count either this or zero.";
13544 } else {
13545 OS << "Unpredictable symbolic max backedge-taken count. ";
13546 }
13547 OS << "\n";
13548
13549 if (ExitingBlocks.size() > 1)
13550 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13551 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13552 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13553                                       ScalarEvolution::SymbolicMaximum);
13554      PrintSCEVWithTypeHint(OS, ExitBTC);
13555 OS << "\n";
13556 }
13557
13558  SmallVector<const SCEVPredicate *, 4> Preds;
13559  auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13560 if (PBT != BTC || !Preds.empty()) {
13561 OS << "Loop ";
13562 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13563 OS << ": ";
13564 if (!isa<SCEVCouldNotCompute>(PBT)) {
13565 OS << "Predicated backedge-taken count is ";
13566      PrintSCEVWithTypeHint(OS, PBT);
13567    } else
13568 OS << "Unpredictable predicated backedge-taken count.";
13569 OS << "\n";
13570 OS << " Predicates:\n";
13571 for (const auto *P : Preds)
13572 P->print(OS, 4);
13573 }
13574
13576 OS << "Loop ";
13577 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13578 OS << ": ";
13579 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13580 }
13581}
13582
13583namespace llvm {
13584raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13585  switch (LD) {
13586  case ScalarEvolution::LoopVariant:
13587    OS << "Variant";
13588 break;
13589  case ScalarEvolution::LoopInvariant:
13590    OS << "Invariant";
13591 break;
13592  case ScalarEvolution::LoopComputable:
13593    OS << "Computable";
13594 break;
13595 }
13596 return OS;
13597}
13598
13599raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13600  switch (BD) {
13601  case ScalarEvolution::DoesNotDominateBlock:
13602    OS << "DoesNotDominate";
13603 break;
13604  case ScalarEvolution::DominatesBlock:
13605    OS << "Dominates";
13606 break;
13607  case ScalarEvolution::ProperlyDominatesBlock:
13608    OS << "ProperlyDominates";
13609 break;
13610 }
13611 return OS;
13612}
13613}
13614
13615void ScalarEvolution::print(raw_ostream &OS) const {
13616  // ScalarEvolution's implementation of the print method is to print
13617 // out SCEV values of all instructions that are interesting. Doing
13618 // this potentially causes it to create new SCEV objects though,
13619 // which technically conflicts with the const qualifier. This isn't
13620 // observable from outside the class though, so casting away the
13621 // const isn't dangerous.
13622 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13623
13624 if (ClassifyExpressions) {
13625 OS << "Classifying expressions for: ";
13626 F.printAsOperand(OS, /*PrintType=*/false);
13627 OS << "\n";
13628 for (Instruction &I : instructions(F))
13629 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13630 OS << I << '\n';
13631 OS << " --> ";
13632 const SCEV *SV = SE.getSCEV(&I);
13633 SV->print(OS);
13634 if (!isa<SCEVCouldNotCompute>(SV)) {
13635 OS << " U: ";
13636 SE.getUnsignedRange(SV).print(OS);
13637 OS << " S: ";
13638 SE.getSignedRange(SV).print(OS);
13639 }
13640
13641 const Loop *L = LI.getLoopFor(I.getParent());
13642
13643 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13644 if (AtUse != SV) {
13645 OS << " --> ";
13646 AtUse->print(OS);
13647 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13648 OS << " U: ";
13649 SE.getUnsignedRange(AtUse).print(OS);
13650 OS << " S: ";
13651 SE.getSignedRange(AtUse).print(OS);
13652 }
13653 }
13654
13655 if (L) {
13656 OS << "\t\t" "Exits: ";
13657 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13658 if (!SE.isLoopInvariant(ExitValue, L)) {
13659 OS << "<<Unknown>>";
13660 } else {
13661 OS << *ExitValue;
13662 }
13663
13664 bool First = true;
13665 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13666 if (First) {
13667 OS << "\t\t" "LoopDispositions: { ";
13668 First = false;
13669 } else {
13670 OS << ", ";
13671 }
13672
13673 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13674 OS << ": " << SE.getLoopDisposition(SV, Iter);
13675 }
13676
13677 for (const auto *InnerL : depth_first(L)) {
13678 if (InnerL == L)
13679 continue;
13680 if (First) {
13681 OS << "\t\t" "LoopDispositions: { ";
13682 First = false;
13683 } else {
13684 OS << ", ";
13685 }
13686
13687 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13688 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13689 }
13690
13691 OS << " }";
13692 }
13693
13694 OS << "\n";
13695 }
13696 }
13697
13698 OS << "Determining loop execution counts for: ";
13699 F.printAsOperand(OS, /*PrintType=*/false);
13700 OS << "\n";
13701 for (Loop *I : LI)
13702 PrintLoopInfo(OS, &SE, I);
13703}
13704
13705ScalarEvolution::LoopDisposition
13706ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13707  auto &Values = LoopDispositions[S];
13708 for (auto &V : Values) {
13709 if (V.getPointer() == L)
13710 return V.getInt();
13711 }
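// A conservative (L, LoopVariant) entry is inserted up front, and the bucket is
// re-looked-up below because the recursive computeLoopDisposition call may grow
// LoopDispositions and invalidate the Values reference.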
13712 Values.emplace_back(L, LoopVariant);
13713 LoopDisposition D = computeLoopDisposition(S, L);
13714 auto &Values2 = LoopDispositions[S];
13715 for (auto &V : llvm::reverse(Values2)) {
13716 if (V.getPointer() == L) {
13717 V.setInt(D);
13718 break;
13719 }
13720 }
13721 return D;
13722}
13723
13724ScalarEvolution::LoopDisposition
13725ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13726 switch (S->getSCEVType()) {
13727 case scConstant:
13728 case scVScale:
13729 return LoopInvariant;
13730 case scAddRecExpr: {
13731 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13732
13733 // If L is the addrec's loop, it's computable.
13734 if (AR->getLoop() == L)
13735 return LoopComputable;
13736
13737 // Add recurrences are never invariant in the function-body (null loop).
13738 if (!L)
13739 return LoopVariant;
13740
13741 // Everything that is not defined at loop entry is variant.
13742 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13743 return LoopVariant;
13744 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13745 " dominate the contained loop's header?");
13746
13747 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13748 if (AR->getLoop()->contains(L))
13749 return LoopInvariant;
13750
13751 // This recurrence is variant w.r.t. L if any of its operands
13752 // are variant.
13753 for (const auto *Op : AR->operands())
13754 if (!isLoopInvariant(Op, L))
13755 return LoopVariant;
13756
13757 // Otherwise it's loop-invariant.
13758 return LoopInvariant;
13759 }
13760 case scTruncate:
13761 case scZeroExtend:
13762 case scSignExtend:
13763 case scPtrToInt:
13764 case scAddExpr:
13765 case scMulExpr:
13766 case scUDivExpr:
13767 case scUMaxExpr:
13768 case scSMaxExpr:
13769 case scUMinExpr:
13770 case scSMinExpr:
13771 case scSequentialUMinExpr: {
13772 bool HasVarying = false;
13773 for (const auto *Op : S->operands()) {
13774      LoopDisposition D = getLoopDisposition(Op, L);
13775      if (D == LoopVariant)
13776 return LoopVariant;
13777 if (D == LoopComputable)
13778 HasVarying = true;
13779 }
13780 return HasVarying ? LoopComputable : LoopInvariant;
13781 }
13782 case scUnknown:
13783 // All non-instruction values are loop invariant. All instructions are loop
13784 // invariant if they are not contained in the specified loop.
13785 // Instructions are never considered invariant in the function body
13786 // (null loop) because they are defined within the "loop".
13787 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13788 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13789 return LoopInvariant;
13790 case scCouldNotCompute:
13791 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13792 }
13793 llvm_unreachable("Unknown SCEV kind!");
13794}
13795
13796bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13797  return getLoopDisposition(S, L) == LoopInvariant;
13798}
13799
13800bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13801  return getLoopDisposition(S, L) == LoopComputable;
13802}
13803
13804ScalarEvolution::BlockDisposition
13805ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13806  auto &Values = BlockDispositions[S];
13807 for (auto &V : Values) {
13808 if (V.getPointer() == BB)
13809 return V.getInt();
13810 }
13811 Values.emplace_back(BB, DoesNotDominateBlock);
13812 BlockDisposition D = computeBlockDisposition(S, BB);
13813 auto &Values2 = BlockDispositions[S];
13814 for (auto &V : llvm::reverse(Values2)) {
13815 if (V.getPointer() == BB) {
13816 V.setInt(D);
13817 break;
13818 }
13819 }
13820 return D;
13821}
13822
13823ScalarEvolution::BlockDisposition
13824ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13825 switch (S->getSCEVType()) {
13826 case scConstant:
13827 case scVScale:
13828    return ProperlyDominatesBlock;
13829  case scAddRecExpr: {
13830 // This uses a "dominates" query instead of "properly dominates" query
13831 // to test for proper dominance too, because the instruction which
13832 // produces the addrec's value is a PHI, and a PHI effectively properly
13833 // dominates its entire containing block.
13834 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13835 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13836 return DoesNotDominateBlock;
13837
13838 // Fall through into SCEVNAryExpr handling.
13839 [[fallthrough]];
13840 }
13841 case scTruncate:
13842 case scZeroExtend:
13843 case scSignExtend:
13844 case scPtrToInt:
13845 case scAddExpr:
13846 case scMulExpr:
13847 case scUDivExpr:
13848 case scUMaxExpr:
13849 case scSMaxExpr:
13850 case scUMinExpr:
13851 case scSMinExpr:
13852 case scSequentialUMinExpr: {
13853 bool Proper = true;
13854 for (const SCEV *NAryOp : S->operands()) {
13855      BlockDisposition D = getBlockDisposition(NAryOp, BB);
13856      if (D == DoesNotDominateBlock)
13857 return DoesNotDominateBlock;
13858 if (D == DominatesBlock)
13859 Proper = false;
13860 }
13861 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13862 }
13863 case scUnknown:
13864 if (Instruction *I =
13865 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13866 if (I->getParent() == BB)
13867 return DominatesBlock;
13868 if (DT.properlyDominates(I->getParent(), BB))
13869 return ProperlyDominatesBlock;
13870 return DoesNotDominateBlock;
13871 }
13872 return ProperlyDominatesBlock;
13873 case scCouldNotCompute:
13874 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13875 }
13876 llvm_unreachable("Unknown SCEV kind!");
13877}
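// Illustrative example (hypothetical IR): for %v = add i32 ... defined in
// block %bb, getBlockDisposition(%v, %bb) is DominatesBlock but not
// ProperlyDominatesBlock, while the SCEVUnknown of a value defined in a strict
// dominator of %bb is ProperlyDominatesBlock.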
13878
13879bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13880 return getBlockDisposition(S, BB) >= DominatesBlock;
13881}
13882
13883bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13884 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13885}
13886
13887bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13888 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13889}
13890
13891void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13892 bool Predicated) {
13893 auto &BECounts =
13894 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13895 auto It = BECounts.find(L);
13896 if (It != BECounts.end()) {
13897 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13898 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
13899 if (!isa<SCEVConstant>(S)) {
13900 auto UserIt = BECountUsers.find(S);
13901 assert(UserIt != BECountUsers.end());
13902 UserIt->second.erase({L, Predicated});
13903 }
13904 }
13905 }
13906 BECounts.erase(It);
13907 }
13908}
13909
13910void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13911 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13912 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13913
13914 while (!Worklist.empty()) {
13915 const SCEV *Curr = Worklist.pop_back_val();
13916 auto Users = SCEVUsers.find(Curr);
13917 if (Users != SCEVUsers.end())
13918 for (const auto *User : Users->second)
13919 if (ToForget.insert(User).second)
13920 Worklist.push_back(User);
13921 }
13922
13923 for (const auto *S : ToForget)
13924 forgetMemoizedResultsImpl(S);
13925
13926 for (auto I = PredicatedSCEVRewrites.begin();
13927 I != PredicatedSCEVRewrites.end();) {
13928 std::pair<const SCEV *, const Loop *> Entry = I->first;
13929 if (ToForget.count(Entry.first))
13930 PredicatedSCEVRewrites.erase(I++);
13931 else
13932 ++I;
13933 }
13934}
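// Forgetting is transitive over SCEVUsers: dropping the SCEV for %n, for
// example, also drops dependent expressions such as (1 + %n) and
// {0,+,%n}<%loop>, together with any cached backedge-taken counts or
// predicated rewrites built on top of them.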
13935
13936void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13937 LoopDispositions.erase(S);
13938 BlockDispositions.erase(S);
13939 UnsignedRanges.erase(S);
13940 SignedRanges.erase(S);
13941 HasRecMap.erase(S);
13942 ConstantMultipleCache.erase(S);
13943
13944 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
13945 UnsignedWrapViaInductionTried.erase(AR);
13946 SignedWrapViaInductionTried.erase(AR);
13947 }
13948
13949 auto ExprIt = ExprValueMap.find(S);
13950 if (ExprIt != ExprValueMap.end()) {
13951 for (Value *V : ExprIt->second) {
13952 auto ValueIt = ValueExprMap.find_as(V);
13953 if (ValueIt != ValueExprMap.end())
13954 ValueExprMap.erase(ValueIt);
13955 }
13956 ExprValueMap.erase(ExprIt);
13957 }
13958
13959 auto ScopeIt = ValuesAtScopes.find(S);
13960 if (ScopeIt != ValuesAtScopes.end()) {
13961 for (const auto &Pair : ScopeIt->second)
13962 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13963 llvm::erase(ValuesAtScopesUsers[Pair.second],
13964 std::make_pair(Pair.first, S));
13965 ValuesAtScopes.erase(ScopeIt);
13966 }
13967
13968 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13969 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13970 for (const auto &Pair : ScopeUserIt->second)
13971 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13972 ValuesAtScopesUsers.erase(ScopeUserIt);
13973 }
13974
13975 auto BEUsersIt = BECountUsers.find(S);
13976 if (BEUsersIt != BECountUsers.end()) {
13977 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13978 auto Copy = BEUsersIt->second;
13979 for (const auto &Pair : Copy)
13980 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13981 BECountUsers.erase(BEUsersIt);
13982 }
13983
13984 auto FoldUser = FoldCacheUser.find(S);
13985 if (FoldUser != FoldCacheUser.end())
13986 for (auto &KV : FoldUser->second)
13987 FoldCache.erase(KV);
13988 FoldCacheUser.erase(S);
13989}
13990
13991void
13992ScalarEvolution::getUsedLoops(const SCEV *S,
13993 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13994 struct FindUsedLoops {
13995 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13996 : LoopsUsed(LoopsUsed) {}
13997 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13998 bool follow(const SCEV *S) {
13999 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14000 LoopsUsed.insert(AR->getLoop());
14001 return true;
14002 }
14003
14004 bool isDone() const { return false; }
14005 };
14006
14007 FindUsedLoops F(LoopsUsed);
14008 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14009}
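// For instance, for the expression ({0,+,1}<%L1> + {0,+,2}<%L2>) this collects
// {%L1, %L2}; expressions without add recurrences leave LoopsUsed untouched.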
14010
14011void ScalarEvolution::getReachableBlocks(
14012 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14013 SmallVector<BasicBlock *> Worklist;
14014 Worklist.push_back(&F.getEntryBlock());
14015 while (!Worklist.empty()) {
14016 BasicBlock *BB = Worklist.pop_back_val();
14017 if (!Reachable.insert(BB).second)
14018 continue;
14019
14020 Value *Cond;
14021 BasicBlock *TrueBB, *FalseBB;
14022 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14023 m_BasicBlock(FalseBB)))) {
14024 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14025 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14026 continue;
14027 }
14028
14029 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14030 const SCEV *L = getSCEV(Cmp->getOperand(0));
14031 const SCEV *R = getSCEV(Cmp->getOperand(1));
14032 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14033 Worklist.push_back(TrueBB);
14034 continue;
14035 }
14036 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14037 R)) {
14038 Worklist.push_back(FalseBB);
14039 continue;
14040 }
14041 }
14042 }
14043
14044 append_range(Worklist, successors(BB));
14045 }
14046}
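// Example of the pruning above (hypothetical CFG): if a block ends in a
// conditional branch whose compare SCEV can prove via constant ranges (always
// true or always false), only the provably taken successor is enqueued, so
// blocks reachable solely through the refuted edge are treated as unreachable
// by the verifier.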
14047
14048void ScalarEvolution::verify() const {
14049 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14050 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14051
14052 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14053
14054 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14055 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14056 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14057
14058 const SCEV *visitConstant(const SCEVConstant *Constant) {
14059 return SE.getConstant(Constant->getAPInt());
14060 }
14061
14062 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14063 return SE.getUnknown(Expr->getValue());
14064 }
14065
14066 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14067 return SE.getCouldNotCompute();
14068 }
14069 };
14070
14071 SCEVMapper SCM(SE2);
14072 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14073 SE2.getReachableBlocks(ReachableBlocks, F);
14074
14075 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14076 if (containsUndefs(Old) || containsUndefs(New)) {
14077 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14078 // not propagate undef aggressively). This means we can (and do) fail
14079 // verification in cases where a transform makes a value go from "undef"
14080 // to "undef+1" (say). The transform is fine, since in both cases the
14081 // result is "undef", but SCEV thinks the value increased by 1.
14082 return nullptr;
14083 }
14084
14085 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14086 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14087 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14088 return nullptr;
14089
14090 return Delta;
14091 };
14092
14093 while (!LoopStack.empty()) {
14094 auto *L = LoopStack.pop_back_val();
14095 llvm::append_range(LoopStack, *L);
14096
14097 // Only verify BECounts in reachable loops. For an unreachable loop,
14098 // any BECount is legal.
14099 if (!ReachableBlocks.contains(L->getHeader()))
14100 continue;
14101
14102 // Only verify cached BECounts. Computing new BECounts may change the
14103 // results of subsequent SCEV uses.
14104 auto It = BackedgeTakenCounts.find(L);
14105 if (It == BackedgeTakenCounts.end())
14106 continue;
14107
14108 auto *CurBECount =
14109 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14110 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14111
14112 if (CurBECount == SE2.getCouldNotCompute() ||
14113 NewBECount == SE2.getCouldNotCompute()) {
14114 // NB! This situation is legal, but is very suspicious -- whatever pass
14115 // changed the loop to make a trip count go from could not compute to
14116 // computable or vice-versa *should have* invalidated SCEV. However, we
14117 // choose not to assert here (for now) since we don't want false
14118 // positives.
14119 continue;
14120 }
14121
14122 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14123 SE.getTypeSizeInBits(NewBECount->getType()))
14124 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14125 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14126 SE.getTypeSizeInBits(NewBECount->getType()))
14127 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14128
14129 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14130 if (Delta && !Delta->isZero()) {
14131 dbgs() << "Trip Count for " << *L << " Changed!\n";
14132 dbgs() << "Old: " << *CurBECount << "\n";
14133 dbgs() << "New: " << *NewBECount << "\n";
14134 dbgs() << "Delta: " << *Delta << "\n";
14135 std::abort();
14136 }
14137 }
14138
14139 // Collect all valid loops currently in LoopInfo.
14140 SmallPtrSet<Loop *, 32> ValidLoops;
14141 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14142 while (!Worklist.empty()) {
14143 Loop *L = Worklist.pop_back_val();
14144 if (ValidLoops.insert(L).second)
14145 Worklist.append(L->begin(), L->end());
14146 }
14147 for (const auto &KV : ValueExprMap) {
14148#ifndef NDEBUG
14149 // Check for SCEV expressions referencing invalid/deleted loops.
14150 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14151 assert(ValidLoops.contains(AR->getLoop()) &&
14152 "AddRec references invalid loop");
14153 }
14154#endif
14155
14156 // Check that the value is also part of the reverse map.
14157 auto It = ExprValueMap.find(KV.second);
14158 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14159 dbgs() << "Value " << *KV.first
14160 << " is in ValueExprMap but not in ExprValueMap\n";
14161 std::abort();
14162 }
14163
14164 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14165 if (!ReachableBlocks.contains(I->getParent()))
14166 continue;
14167 const SCEV *OldSCEV = SCM.visit(KV.second);
14168 const SCEV *NewSCEV = SE2.getSCEV(I);
14169 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14170 if (Delta && !Delta->isZero()) {
14171 dbgs() << "SCEV for value " << *I << " changed!\n"
14172 << "Old: " << *OldSCEV << "\n"
14173 << "New: " << *NewSCEV << "\n"
14174 << "Delta: " << *Delta << "\n";
14175 std::abort();
14176 }
14177 }
14178 }
14179
14180 for (const auto &KV : ExprValueMap) {
14181 for (Value *V : KV.second) {
14182 auto It = ValueExprMap.find_as(V);
14183 if (It == ValueExprMap.end()) {
14184 dbgs() << "Value " << *V
14185 << " is in ExprValueMap but not in ValueExprMap\n";
14186 std::abort();
14187 }
14188 if (It->second != KV.first) {
14189 dbgs() << "Value " << *V << " mapped to " << *It->second
14190 << " rather than " << *KV.first << "\n";
14191 std::abort();
14192 }
14193 }
14194 }
14195
14196 // Verify integrity of SCEV users.
14197 for (const auto &S : UniqueSCEVs) {
14198 for (const auto *Op : S.operands()) {
14199 // We do not store dependencies of constants.
14200 if (isa<SCEVConstant>(Op))
14201 continue;
14202 auto It = SCEVUsers.find(Op);
14203 if (It != SCEVUsers.end() && It->second.count(&S))
14204 continue;
14205 dbgs() << "Use of operand " << *Op << " by user " << S
14206 << " is not being tracked!\n";
14207 std::abort();
14208 }
14209 }
14210
14211 // Verify integrity of ValuesAtScopes users.
14212 for (const auto &ValueAndVec : ValuesAtScopes) {
14213 const SCEV *Value = ValueAndVec.first;
14214 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14215 const Loop *L = LoopAndValueAtScope.first;
14216 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14217 if (!isa<SCEVConstant>(ValueAtScope)) {
14218 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14219 if (It != ValuesAtScopesUsers.end() &&
14220 is_contained(It->second, std::make_pair(L, Value)))
14221 continue;
14222 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14223 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14224 std::abort();
14225 }
14226 }
14227 }
14228
14229 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14230 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14231 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14232 const Loop *L = LoopAndValue.first;
14233 const SCEV *Value = LoopAndValue.second;
14234 assert(!isa<SCEVConstant>(Value));
14235 auto It = ValuesAtScopes.find(Value);
14236 if (It != ValuesAtScopes.end() &&
14237 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14238 continue;
14239 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14240 << *ValueAtScope << " missing in ValuesAtScopes\n";
14241 std::abort();
14242 }
14243 }
14244
14245 // Verify integrity of BECountUsers.
14246 auto VerifyBECountUsers = [&](bool Predicated) {
14247 auto &BECounts =
14248 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14249 for (const auto &LoopAndBEInfo : BECounts) {
14250 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14251 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14252 if (!isa<SCEVConstant>(S)) {
14253 auto UserIt = BECountUsers.find(S);
14254 if (UserIt != BECountUsers.end() &&
14255 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14256 continue;
14257 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14258 << " missing from BECountUsers\n";
14259 std::abort();
14260 }
14261 }
14262 }
14263 }
14264 };
14265 VerifyBECountUsers(/* Predicated */ false);
14266 VerifyBECountUsers(/* Predicated */ true);
14267
14268 // Verify integrity of the loop disposition cache.
14269 for (auto &[S, Values] : LoopDispositions) {
14270 for (auto [Loop, CachedDisposition] : Values) {
14271 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14272 if (CachedDisposition != RecomputedDisposition) {
14273 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14274 << " is incorrect: cached " << CachedDisposition << ", actual "
14275 << RecomputedDisposition << "\n";
14276 std::abort();
14277 }
14278 }
14279 }
14280
14281 // Verify integrity of the block disposition cache.
14282 for (auto &[S, Values] : BlockDispositions) {
14283 for (auto [BB, CachedDisposition] : Values) {
14284 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14285 if (CachedDisposition != RecomputedDisposition) {
14286 dbgs() << "Cached disposition of " << *S << " for block %"
14287 << BB->getName() << " is incorrect: cached " << CachedDisposition
14288 << ", actual " << RecomputedDisposition << "\n";
14289 std::abort();
14290 }
14291 }
14292 }
14293
14294 // Verify FoldCache/FoldCacheUser caches.
14295 for (auto [FoldID, Expr] : FoldCache) {
14296 auto I = FoldCacheUser.find(Expr);
14297 if (I == FoldCacheUser.end()) {
14298 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14299 << "!\n";
14300 std::abort();
14301 }
14302 if (!is_contained(I->second, FoldID)) {
14303 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14304 std::abort();
14305 }
14306 }
14307 for (auto [Expr, IDs] : FoldCacheUser) {
14308 for (auto &FoldID : IDs) {
14309 auto I = FoldCache.find(FoldID);
14310 if (I == FoldCache.end()) {
14311 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14312 << "!\n";
14313 std::abort();
14314 }
14315 if (I->second != Expr) {
14316 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14317 << *I->second << " != " << *Expr << "!\n";
14318 std::abort();
14319 }
14320 }
14321 }
14322
14323 // Verify that ConstantMultipleCache computations are correct. We check that
14324 // cached multiples and recomputed multiples are multiples of each other to
14325 // verify correctness. It is possible that a recomputed multiple is different
14326 // from the cached multiple due to strengthened no wrap flags or changes in
14327 // KnownBits computations.
14328 for (auto [S, Multiple] : ConstantMultipleCache) {
14329 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14330 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14331 Multiple.urem(RecomputedMultiple) != 0 &&
14332 RecomputedMultiple.urem(Multiple) != 0)) {
14333 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14334 << *S << " : Computed " << RecomputedMultiple
14335 << " but cache contains " << Multiple << "!\n";
14336 std::abort();
14337 }
14338 }
14339}
14340
14341bool ScalarEvolution::invalidate(
14342 Function &F, const PreservedAnalyses &PA,
14343 FunctionAnalysisManager::Invalidator &Inv) {
14344 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14345 // of its dependencies is invalidated.
14346 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14347 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14348 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14349 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14350 Inv.invalidate<LoopAnalysis>(F, PA);
14351}
14352
14353AnalysisKey ScalarEvolutionAnalysis::Key;
14354
14355ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14356 FunctionAnalysisManager &AM) {
14357 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14358 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14359 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14360 auto &LI = AM.getResult<LoopAnalysis>(F);
14361 return ScalarEvolution(F, TLI, AC, DT, LI);
14362}
14363
14364PreservedAnalyses
14365ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14366 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14367 return PreservedAnalyses::all();
14368}
14369
14370PreservedAnalyses
14371ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14372 // For compatibility with opt's -analyze feature under legacy pass manager
14373 // which was not ported to NPM. This keeps tests using
14374 // update_analyze_test_checks.py working.
14375 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14376 << F.getName() << "':\n";
14377 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14378 return PreservedAnalyses::all();
14379}
14380
14381INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14382 "Scalar Evolution Analysis", false, true)
14383INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14384INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14385INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14386INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14387INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14388 "Scalar Evolution Analysis", false, true)
14389
14390char ScalarEvolutionWrapperPass::ID = 0;
14391
14392ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14393 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14394}
14395
14396bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14397 SE.reset(new ScalarEvolution(
14398 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14399 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14400 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14401 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14402 return false;
14403}
14404
14405void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14406
14407void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14408 SE->print(OS);
14409}
14410
14411void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14412 if (!VerifySCEV)
14413 return;
14414
14415 SE->verify();
14416}
14417
14418void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14419 AU.setPreservesAll();
14420 AU.addRequiredTransitive<AssumptionCacheTracker>();
14421 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14422 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14423 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14424}
14425
14426const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14427 const SCEV *RHS) {
14428 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14429}
14430
14431const SCEVPredicate *
14432ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14433 const SCEV *LHS, const SCEV *RHS) {
14434 FoldingSetNodeID ID;
14435 assert(LHS->getType() == RHS->getType() &&
14436 "Type mismatch between LHS and RHS");
14437 // Unique this node based on the arguments
14438 ID.AddInteger(SCEVPredicate::P_Compare);
14439 ID.AddInteger(Pred);
14440 ID.AddPointer(LHS);
14441 ID.AddPointer(RHS);
14442 void *IP = nullptr;
14443 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14444 return S;
14445 SCEVComparePredicate *Eq = new (SCEVAllocator)
14446 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14447 UniquePreds.InsertNode(Eq, IP);
14448 return Eq;
14449}
14450
14451const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14452 const SCEVAddRecExpr *AR,
14453 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14454 FoldingSetNodeID ID;
14455 // Unique this node based on the arguments
14456 ID.AddInteger(SCEVPredicate::P_Wrap);
14457 ID.AddPointer(AR);
14458 ID.AddInteger(AddedFlags);
14459 void *IP = nullptr;
14460 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14461 return S;
14462 auto *OF = new (SCEVAllocator)
14463 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14464 UniquePreds.InsertNode(OF, IP);
14465 return OF;
14466}
14467
14468namespace {
14469
14470class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14471public:
14472
14473 /// Rewrites \p S in the context of a loop L and the SCEV predication
14474 /// infrastructure.
14475 ///
14476 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14477 /// equivalences present in \p Pred.
14478 ///
14479 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14480 /// \p NewPreds such that the result will be an AddRecExpr.
14481 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14482 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14483 const SCEVPredicate *Pred) {
14484 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14485 return Rewriter.visit(S);
14486 }
14487
14488 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14489 if (Pred) {
14490 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14491 for (const auto *Pred : U->getPredicates())
14492 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14493 if (IPred->getLHS() == Expr &&
14494 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14495 return IPred->getRHS();
14496 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14497 if (IPred->getLHS() == Expr &&
14498 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14499 return IPred->getRHS();
14500 }
14501 }
14502 return convertToAddRecWithPreds(Expr);
14503 }
14504
14505 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14506 const SCEV *Operand = visit(Expr->getOperand());
14507 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14508 if (AR && AR->getLoop() == L && AR->isAffine()) {
14509 // This couldn't be folded because the operand didn't have the nuw
14510 // flag. Add the nusw flag as an assumption that we could make.
14511 const SCEV *Step = AR->getStepRecurrence(SE);
14512 Type *Ty = Expr->getType();
14513 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14514 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14515 SE.getSignExtendExpr(Step, Ty), L,
14516 AR->getNoWrapFlags());
14517 }
14518 return SE.getZeroExtendExpr(Operand, Expr->getType());
14519 }
14520
14521 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14522 const SCEV *Operand = visit(Expr->getOperand());
14523 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14524 if (AR && AR->getLoop() == L && AR->isAffine()) {
14525 // This couldn't be folded because the operand didn't have the nsw
14526 // flag. Add the nssw flag as an assumption that we could make.
14527 const SCEV *Step = AR->getStepRecurrence(SE);
14528 Type *Ty = Expr->getType();
14529 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14530 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14531 SE.getSignExtendExpr(Step, Ty), L,
14532 AR->getNoWrapFlags());
14533 }
14534 return SE.getSignExtendExpr(Operand, Expr->getType());
14535 }
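// A rough example of the two visitors above (hypothetical SCEVs): for
// (sext i32 {%start,+,%step}<%L> to i64), which does not fold without <nsw>,
// the rewriter can record an <nssw> wrap predicate for the recurrence and
// return an add recurrence of the sign-extended start and step in the wider
// type; the zext case works the same way under an <nusw> assumption, zero
// extending the start while sign extending the step.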
14536
14537private:
14538 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14539 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14540 const SCEVPredicate *Pred)
14541 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14542
14543 bool addOverflowAssumption(const SCEVPredicate *P) {
14544 if (!NewPreds) {
14545 // Check if we've already made this assumption.
14546 return Pred && Pred->implies(P);
14547 }
14548 NewPreds->insert(P);
14549 return true;
14550 }
14551
14552 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14553 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14554 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14555 return addOverflowAssumption(A);
14556 }
14557
14558 // If \p Expr represents a PHINode, we try to see if it can be represented
14559 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14560 // to add this predicate as a runtime overflow check, we return the AddRec.
14561 // If \p Expr does not meet these conditions (is not a PHI node, or we
14562 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14563 // return \p Expr.
14564 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14565 if (!isa<PHINode>(Expr->getValue()))
14566 return Expr;
14567 std::optional<
14568 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14569 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14570 if (!PredicatedRewrite)
14571 return Expr;
14572 for (const auto *P : PredicatedRewrite->second){
14573 // Wrap predicates from outer loops are not supported.
14574 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14575 if (L != WP->getExpr()->getLoop())
14576 return Expr;
14577 }
14578 if (!addOverflowAssumption(P))
14579 return Expr;
14580 }
14581 return PredicatedRewrite->first;
14582 }
14583
14584 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14585 const SCEVPredicate *Pred;
14586 const Loop *L;
14587};
14588
14589} // end anonymous namespace
14590
14591const SCEV *
14592ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14593 const SCEVPredicate &Preds) {
14594 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14595}
14596
14597const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14598 const SCEV *S, const Loop *L,
14599 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14600 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14601 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14602 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14603
14604 if (!AddRec)
14605 return nullptr;
14606
14607 // Since the transformation was successful, we can now transfer the SCEV
14608 // predicates.
14609 for (const auto *P : TransformPreds)
14610 Preds.insert(P);
14611
14612 return AddRec;
14613}
14614
14615/// SCEV predicates
14616SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14617 SCEVPredicateKind Kind)
14618 : FastID(ID), Kind(Kind) {}
14619
14620SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14621 const ICmpInst::Predicate Pred,
14622 const SCEV *LHS, const SCEV *RHS)
14623 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14624 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14625 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14626}
14627
14628bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14629 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14630
14631 if (!Op)
14632 return false;
14633
14634 if (Pred != ICmpInst::ICMP_EQ)
14635 return false;
14636
14637 return Op->LHS == LHS && Op->RHS == RHS;
14638}
14639
14640bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14641
14642void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14643 if (Pred == ICmpInst::ICMP_EQ)
14644 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14645 else
14646 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14647 << *RHS << "\n";
14648
14649}
14650
14651SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14652 const SCEVAddRecExpr *AR,
14653 IncrementWrapFlags Flags)
14654 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14655
14656const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14657
14658bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14659 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14660
14661 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14662}
14663
14664bool SCEVWrapPredicate::isAlwaysTrue() const {
14665 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14666 IncrementWrapFlags IFlags = Flags;
14667
14668 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14669 IFlags = clearFlags(IFlags, IncrementNSSW);
14670
14671 return IFlags == IncrementAnyWrap;
14672}
14673
14674void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14675 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14676 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14677 OS << "<nusw>";
14678 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14679 OS << "<nssw>";
14680 OS << "\n";
14681}
14682
14685 ScalarEvolution &SE) {
14686 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14687 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14688
14689 // We can safely transfer the NSW flag as NSSW.
14690 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14691 ImpliedFlags = IncrementNSSW;
14692
14693 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14694 // If the increment is positive, the SCEV NUW flag will also imply the
14695 // WrapPredicate NUSW flag.
14696 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14697 if (Step->getValue()->getValue().isNonNegative())
14698 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14699 }
14700
14701 return ImpliedFlags;
14702}
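// For example, an AddRec that already carries SCEV::FlagNSW implies
// IncrementNSSW here, and one with SCEV::FlagNUW plus a constant non-negative
// step additionally implies IncrementNUSW; callers such as setNoOverflow can
// then strip the implied bits before emitting a runtime predicate.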
14703
14704/// Union predicates don't get cached so create a dummy set ID for it.
14705SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14706 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14707 for (const auto *P : Preds)
14708 add(P);
14709}
14710
14711bool SCEVUnionPredicate::isAlwaysTrue() const {
14712 return all_of(Preds,
14713 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14714}
14715
14716bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14717 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14718 return all_of(Set->Preds,
14719 [this](const SCEVPredicate *I) { return this->implies(I); });
14720
14721 return any_of(Preds,
14722 [N](const SCEVPredicate *I) { return I->implies(N); });
14723}
14724
14725void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14726 for (const auto *Pred : Preds)
14727 Pred->print(OS, Depth);
14728}
14729
14730void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14731 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14732 for (const auto *Pred : Set->Preds)
14733 add(Pred);
14734 return;
14735 }
14736
14737 Preds.push_back(N);
14738}
14739
14740PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14741 Loop &L)
14742 : SE(SE), L(L) {
14743 SmallVector<const SCEVPredicate*, 4> Empty;
14744 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14745}
14746
14749 for (const auto *Op : Ops)
14750 // We do not expect that forgetting cached data for SCEVConstants will ever
14751 // open any prospects for sharpening or introduce any correctness issues,
14752 // so we don't bother storing their dependencies.
14753 if (!isa<SCEVConstant>(Op))
14754 SCEVUsers[Op].insert(User);
14755}
14756
14757const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14758 const SCEV *Expr = SE.getSCEV(V);
14759 RewriteEntry &Entry = RewriteMap[Expr];
14760
14761 // If we already have an entry and the version matches, return it.
14762 if (Entry.second && Generation == Entry.first)
14763 return Entry.second;
14764
14765 // We found an entry but it's stale. Rewrite the stale entry
14766 // according to the current predicate.
14767 if (Entry.second)
14768 Expr = Entry.second;
14769
14770 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14771 Entry = {Generation, NewSCEV};
14772
14773 return NewSCEV;
14774}
14775
14776const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14777 if (!BackedgeCount) {
14778 SmallVector<const SCEVPredicate *, 4> Preds;
14779 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14780 for (const auto *P : Preds)
14781 addPredicate(*P);
14782 }
14783 return BackedgeCount;
14784}
14785
14786void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14787 if (Preds->implies(&Pred))
14788 return;
14789
14790 auto &OldPreds = Preds->getPredicates();
14791 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14792 NewPreds.push_back(&Pred);
14793 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14794 updateGeneration();
14795}
14796
14797const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14798 return *Preds;
14799}
14800
14801void PredicatedScalarEvolution::updateGeneration() {
14802 // If the generation number wrapped recompute everything.
14803 if (++Generation == 0) {
14804 for (auto &II : RewriteMap) {
14805 const SCEV *Rewritten = II.second.second;
14806 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14807 }
14808 }
14809}
14810
14811void PredicatedScalarEvolution::setNoOverflow(
14812 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14813 const SCEV *Expr = getSCEV(V);
14814 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14815
14816 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14817
14818 // Clear the statically implied flags.
14819 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14820 addPredicate(*SE.getWrapPredicate(AR, Flags));
14821
14822 auto II = FlagsMap.insert({V, Flags});
14823 if (!II.second)
14824 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14825}
14826
14827bool PredicatedScalarEvolution::hasNoOverflow(
14828 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14829 const SCEV *Expr = getSCEV(V);
14830 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14831
14832 Flags = SCEVWrapPredicate::clearFlags(
14833 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14834
14835 auto II = FlagsMap.find(V);
14836
14837 if (II != FlagsMap.end())
14838 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14839
14840 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14841}
14842
14843const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14844 const SCEV *Expr = this->getSCEV(V);
14845 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14846 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14847
14848 if (!New)
14849 return nullptr;
14850
14851 for (const auto *P : NewPreds)
14852 addPredicate(*P);
14853
14854 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14855 return New;
14856}
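// A minimal client-side sketch (hypothetical loop L and value V already known
// to SE):
//   PredicatedScalarEvolution PSE(SE, *L);
//   if (const SCEVAddRecExpr *AR = PSE.getAsAddRec(V))
//     ; // AR is only valid under PSE.getPredicate(), which the client must
//       // check or version on at runtime.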
14857
14860 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14861 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
14862 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
14863 for (auto I : Init.FlagsMap)
14864 FlagsMap.insert(I);
14865}
14866
14867void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
14868 // For each block.
14869 for (auto *BB : L.getBlocks())
14870 for (auto &I : *BB) {
14871 if (!SE.isSCEVable(I.getType()))
14872 continue;
14873
14874 auto *Expr = SE.getSCEV(&I);
14875 auto II = RewriteMap.find(Expr);
14876
14877 if (II == RewriteMap.end())
14878 continue;
14879
14880 // Don't print things that are not interesting.
14881 if (II->second.second == Expr)
14882 continue;
14883
14884 OS.indent(Depth) << "[PSE]" << I << ":\n";
14885 OS.indent(Depth + 2) << *Expr << "\n";
14886 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
14887 }
14888}
14889
14890// Match the mathematical pattern A - (A / B) * B, where A and B can be
14891// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
14892// for URem with constant power-of-2 second operands.
14893// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
14894// 4, A / B becomes X / 8).
14895bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
14896 const SCEV *&RHS) {
14897 // Try to match 'zext (trunc A to iB) to iY', which is used
14898 // for URem with constant power-of-2 second operands. Make sure the size of
14899 // the operand A matches the size of the whole expression.
14900 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14901 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14902 LHS = Trunc->getOperand();
14903 // Bail out if the type of the LHS is larger than the type of the
14904 // expression for now.
14905 if (getTypeSizeInBits(LHS->getType()) >
14906 getTypeSizeInBits(Expr->getType()))
14907 return false;
14908 if (LHS->getType() != Expr->getType())
14909 LHS = getZeroExtendExpr(LHS, Expr->getType());
14910 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14911 << getTypeSizeInBits(Trunc->getType()));
14912 return true;
14913 }
14914 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14915 if (Add == nullptr || Add->getNumOperands() != 2)
14916 return false;
14917
14918 const SCEV *A = Add->getOperand(1);
14919 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14920
14921 if (Mul == nullptr)
14922 return false;
14923
14924 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14925 // (SomeExpr + (-(SomeExpr / B) * B)).
14926 if (Expr == getURemExpr(A, B)) {
14927 LHS = A;
14928 RHS = B;
14929 return true;
14930 }
14931 return false;
14932 };
14933
14934 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14935 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14936 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14937 MatchURemWithDivisor(Mul->getOperand(2));
14938
14939 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14940 if (Mul->getNumOperands() == 2)
14941 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14942 MatchURemWithDivisor(Mul->getOperand(0)) ||
14943 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14944 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14945 return false;
14946}
14947
14948const SCEV *
14949ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14950 SmallVector<BasicBlock*, 16> ExitingBlocks;
14951 L->getExitingBlocks(ExitingBlocks);
14952
14953 // Form an expression for the maximum exit count possible for this loop. We
14954 // merge the max and exact information to approximate a version of
14955 // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
14956 SmallVector<const SCEV*, 4> ExitCounts;
14957 for (BasicBlock *ExitingBB : ExitingBlocks) {
14958 const SCEV *ExitCount =
14959 getExitCount(L, ExitingBB, ScalarEvolution::SymbolicMaximum);
14960 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
14961 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
14962 "We should only have known counts for exiting blocks that "
14963 "dominate latch!");
14964 ExitCounts.push_back(ExitCount);
14965 }
14966 }
14967 if (ExitCounts.empty())
14968 return getCouldNotCompute();
14969 return getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
14970}
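// For instance, a loop with two computable exits whose symbolic exit counts
// are %n and (-1 + %m) yields (%n umin_seq (-1 + %m)); if no exiting block has
// a computable count, the result is SCEVCouldNotCompute.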
14971
14972/// A rewriter to replace SCEV expressions in Map with the corresponding entry
14973/// in the map. It skips AddRecExpr because we cannot guarantee that the
14974/// replacement is loop invariant in the loop of the AddRec.
14975class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
14976 const DenseMap<const SCEV *, const SCEV *> &Map;
14977
14978public:
14979 SCEVLoopGuardRewriter(ScalarEvolution &SE,
14980 DenseMap<const SCEV *, const SCEV *> &M)
14981 : SCEVRewriteVisitor(SE), Map(M) {}
14982
14983 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
14984
14985 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14986 auto I = Map.find(Expr);
14987 if (I == Map.end())
14988 return Expr;
14989 return I->second;
14990 }
14991
14992 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14993 auto I = Map.find(Expr);
14994 if (I == Map.end()) {
14995 // If we didn't find the exact ZExt expr in the map, check if there's an
14996 // entry for a smaller ZExt we can use instead.
14997 Type *Ty = Expr->getType();
14998 const SCEV *Op = Expr->getOperand(0);
14999 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15000 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15001 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15002 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15003 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15004 auto I = Map.find(NarrowExt);
15005 if (I != Map.end())
15006 return SE.getZeroExtendExpr(I->second, Ty);
15007 Bitwidth = Bitwidth / 2;
15008 }
15009
15010 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15011 Expr);
15012 }
15013 return I->second;
15014 }
15015
15016 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15017 auto I = Map.find(Expr);
15018 if (I == Map.end())
15019 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15020 Expr);
15021 return I->second;
15022 }
15023
15024 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15025 auto I = Map.find(Expr);
15026 if (I == Map.end())
15027 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15028 return I->second;
15029 }
15030
15031 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15032 auto I = Map.find(Expr);
15033 if (I == Map.end())
15034 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15035 return I->second;
15036 }
15037};
15038
15039const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15040 SmallVector<const SCEV *> ExprsToRewrite;
15041 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15042 const SCEV *RHS,
15043 DenseMap<const SCEV *, const SCEV *>
15044 &RewriteMap) {
15045 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15046 // replacement SCEV which isn't directly implied by the structure of that
15047 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15048 // legal. See the scoping rules for flags in the header to understand why.
15049
15050 // If LHS is a constant, apply information to the other expression.
15051 if (isa<SCEVConstant>(LHS)) {
15052 std::swap(LHS, RHS);
15053 Predicate = CmpInst::getSwappedPredicate(Predicate);
15054 }
15055
15056 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15057 // create this form when combining two checks of the form (X u< C2 + C1) and
15058 // (X >=u C1).
15059 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
15060 &ExprsToRewrite]() {
15061 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15062 if (!AddExpr || AddExpr->getNumOperands() != 2)
15063 return false;
15064
15065 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15066 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15067 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15068 if (!C1 || !C2 || !LHSUnknown)
15069 return false;
15070
15071 auto ExactRegion =
15072 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15073 .sub(C1->getAPInt());
15074
15075 // Bail out, unless we have a non-wrapping, monotonic range.
15076 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15077 return false;
15078 auto I = RewriteMap.find(LHSUnknown);
15079 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15080 RewriteMap[LHSUnknown] = getUMaxExpr(
15081 getConstant(ExactRegion.getUnsignedMin()),
15082 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
15083 ExprsToRewrite.push_back(LHSUnknown);
15084 return true;
15085 };
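// Worked example for the idiom above (hypothetical values): the guard
// ((-4 + %x) u< 8) describes the region [0, 8) shifted by 4, i.e. %x in
// [4, 12), so %x is rewritten to (4 umax (%x umin 11)).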
15086 if (MatchRangeCheckIdiom())
15087 return;
15088
15089 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15090 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15091 // the non-constant operand and in \p LHS the constant operand.
15092 auto IsMinMaxSCEVWithNonNegativeConstant =
15093 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15094 const SCEV *&RHS) {
15095 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15096 if (MinMax->getNumOperands() != 2)
15097 return false;
15098 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15099 if (C->getAPInt().isNegative())
15100 return false;
15101 SCTy = MinMax->getSCEVType();
15102 LHS = MinMax->getOperand(0);
15103 RHS = MinMax->getOperand(1);
15104 return true;
15105 }
15106 }
15107 return false;
15108 };
15109
15110 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15111 // constant, and returns their APInt in ExprVal and in DivisorVal.
15112 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15113 APInt &ExprVal, APInt &DivisorVal) {
15114 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15115 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15116 if (!ConstExpr || !ConstDivisor)
15117 return false;
15118 ExprVal = ConstExpr->getAPInt();
15119 DivisorVal = ConstDivisor->getAPInt();
15120 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15121 };
15122
15123 // Return a new SCEV that modifies \p Expr to the closest number that is
15124 // divisible by \p Divisor and greater than or equal to \p Expr.
15125 // For now, only handle constant Expr and Divisor.
15126 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15127 const SCEV *Divisor) {
15128 APInt ExprVal;
15129 APInt DivisorVal;
15130 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15131 return Expr;
15132 APInt Rem = ExprVal.urem(DivisorVal);
15133 if (!Rem.isZero())
15134 // return the SCEV: Expr + Divisor - Expr % Divisor
15135 return getConstant(ExprVal + DivisorVal - Rem);
15136 return Expr;
15137 };
15138
15139 // Return a new SCEV that modifies \p Expr to the closest number that is
15140 // divisible by \p Divisor and less than or equal to \p Expr.
15141 // For now, only handle constant Expr and Divisor.
15142 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15143 const SCEV *Divisor) {
15144 APInt ExprVal;
15145 APInt DivisorVal;
15146 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15147 return Expr;
15148 APInt Rem = ExprVal.urem(DivisorVal);
15149 // return the SCEV: Expr - Expr % Divisor
15150 return getConstant(ExprVal - Rem);
15151 };
15152
15153 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15154 // recursively. This is done by aligning up/down the constant value to the
15155 // Divisor.
15156 std::function<const SCEV *(const SCEV *, const SCEV *)>
15157 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15158 const SCEV *Divisor) {
15159 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15160 SCEVTypes SCTy;
15161 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15162 MinMaxRHS))
15163 return MinMaxExpr;
15164 auto IsMin =
15165 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15166 assert(isKnownNonNegative(MinMaxLHS) &&
15167 "Expected non-negative operand!");
15168 auto *DivisibleExpr =
15169 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15170 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15171 SmallVector<const SCEV *> Ops = {
15172 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15173 return getMinMaxExpr(SCTy, Ops);
15174 };
15175
15176 // If we have LHS == 0, check if LHS is computing a property of some unknown
15177 // SCEV %v which we can rewrite %v to express explicitly.
15178 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15179 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15180 RHSC->getValue()->isNullValue()) {
15181 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15182 // explicitly express that.
15183 const SCEV *URemLHS = nullptr;
15184 const SCEV *URemRHS = nullptr;
15185 if (matchURem(LHS, URemLHS, URemRHS)) {
15186 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15187 auto I = RewriteMap.find(LHSUnknown);
15188 const SCEV *RewrittenLHS =
15189 I != RewriteMap.end() ? I->second : LHSUnknown;
15190 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15191 const auto *Multiple =
15192 getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15193 RewriteMap[LHSUnknown] = Multiple;
15194 ExprsToRewrite.push_back(LHSUnknown);
15195 return;
15196 }
15197 }
15198 }
15199
15200 // Do not apply information for constants or if RHS contains an AddRec.
15201 if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
15202 return;
15203
15204 // If RHS is SCEVUnknown, make sure the information is applied to it.
15205 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15206 std::swap(LHS, RHS);
15207 Predicate = CmpInst::getSwappedPredicate(Predicate);
15208 }
15209
15210 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15211 // and \p FromRewritten are the same (i.e. there has been no rewrite
15212 // registered for \p From), then puts this value in the list of rewritten
15213 // expressions.
15214 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15215 const SCEV *To) {
15216 if (From == FromRewritten)
15217 ExprsToRewrite.push_back(From);
15218 RewriteMap[From] = To;
15219 };
15220
15221 // Checks whether \p S has already been rewritten. In that case returns the
15222 // existing rewrite because we want to chain further rewrites onto the
15223 // already rewritten value. Otherwise returns \p S.
15224 auto GetMaybeRewritten = [&](const SCEV *S) {
15225 auto I = RewriteMap.find(S);
15226 return I != RewriteMap.end() ? I->second : S;
15227 };
15228
15229 // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15230 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15231 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15232 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15233 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15234 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15235 // DividesBy.
15236 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15237 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15238 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15239 if (Mul->getNumOperands() != 2)
15240 return false;
15241 auto *MulLHS = Mul->getOperand(0);
15242 auto *MulRHS = Mul->getOperand(1);
15243 if (isa<SCEVConstant>(MulLHS))
15244 std::swap(MulLHS, MulRHS);
15245 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15246 if (Div->getOperand(1) == MulRHS) {
15247 DividesBy = MulRHS;
15248 return true;
15249 }
15250 }
15251 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15252 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15253 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15254 return false;
15255 };
15256
15257 // Return true if \p Expr is known to be divisible by \p DividesBy.
15258 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15259 [&](const SCEV *Expr, const SCEV *DividesBy) {
15260 if (getURemExpr(Expr, DividesBy)->isZero())
15261 return true;
15262 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15263 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15264 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15265 return false;
15266 };
15267
15268 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15269 const SCEV *DividesBy = nullptr;
15270 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15271 // Check that the whole expression is divided by DividesBy
15272 DividesBy =
15273 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15274
15275 // Collect rewrites for LHS and its transitive operands based on the
15276 // condition.
15277 // For min/max expressions, also apply the guard to its operands:
15278 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15279 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15280 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15281 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15282
15283 // We cannot express strict predicates in SCEV, so instead we replace them
15284 // with non-strict ones against plus or minus one of RHS depending on the
15285 // predicate.
15286 const SCEV *One = getOne(RHS->getType());
15287 switch (Predicate) {
15288 case CmpInst::ICMP_ULT:
15289 if (RHS->getType()->isPointerTy())
15290 return;
15291 RHS = getUMaxExpr(RHS, One);
15292 [[fallthrough]];
15293 case CmpInst::ICMP_SLT: {
15294 RHS = getMinusSCEV(RHS, One);
15295 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15296 break;
15297 }
15298 case CmpInst::ICMP_UGT:
15299 case CmpInst::ICMP_SGT:
15300 RHS = getAddExpr(RHS, One);
15301 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15302 break;
15303 case CmpInst::ICMP_ULE:
15304 case CmpInst::ICMP_SLE:
15305 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15306 break;
15307 case CmpInst::ICMP_UGE:
15308 case CmpInst::ICMP_SGE:
15309 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15310 break;
15311 default:
15312 break;
15313 }
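// E.g. a guard (%x u< %n) is handled as (%x u<= (umax(%n, 1) - 1)), and
// (%x s> %c) as (%x s>= (%c + 1)); when a divisor is known, the adjusted bound
// is additionally aligned down or up with GetPreviousSCEVDividesByDivisor or
// GetNextSCEVDividesByDivisor.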
15314
15315 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15316 SmallPtrSet<const SCEV *, 16> Visited;
15317
15318 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15319 append_range(Worklist, S->operands());
15320 };
15321
15322 while (!Worklist.empty()) {
15323 const SCEV *From = Worklist.pop_back_val();
15324 if (isa<SCEVConstant>(From))
15325 continue;
15326 if (!Visited.insert(From).second)
15327 continue;
15328 const SCEV *FromRewritten = GetMaybeRewritten(From);
15329 const SCEV *To = nullptr;
15330
15331 switch (Predicate) {
15332 case CmpInst::ICMP_ULT:
15333 case CmpInst::ICMP_ULE:
15334 To = getUMinExpr(FromRewritten, RHS);
15335 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15336 EnqueueOperands(UMax);
15337 break;
15338 case CmpInst::ICMP_SLT:
15339 case CmpInst::ICMP_SLE:
15340 To = getSMinExpr(FromRewritten, RHS);
15341 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15342 EnqueueOperands(SMax);
15343 break;
15344 case CmpInst::ICMP_UGT:
15345 case CmpInst::ICMP_UGE:
15346 To = getUMaxExpr(FromRewritten, RHS);
15347 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15348 EnqueueOperands(UMin);
15349 break;
15350 case CmpInst::ICMP_SGT:
15351 case CmpInst::ICMP_SGE:
15352 To = getSMaxExpr(FromRewritten, RHS);
15353 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15354 EnqueueOperands(SMin);
15355 break;
15356 case CmpInst::ICMP_EQ:
15357 if (isa<SCEVConstant>(RHS))
15358 To = RHS;
15359 break;
15360 case CmpInst::ICMP_NE:
15361 if (isa<SCEVConstant>(RHS) &&
15362 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15363 const SCEV *OneAlignedUp =
15364 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15365 To = getUMaxExpr(FromRewritten, OneAlignedUp);
15366 }
15367 break;
15368 default:
15369 break;
15370 }
15371
15372 if (To)
15373 AddRewrite(From, FromRewritten, To);
15374 }
15375 };
15376
15377 BasicBlock *Header = L->getHeader();
15378 SmallVector<std::pair<Value *, bool>> Terms;
15379 // First, collect information from assumptions dominating the loop.
15380 for (auto &AssumeVH : AC.assumptions()) {
15381 if (!AssumeVH)
15382 continue;
15383 auto *AssumeI = cast<CallInst>(AssumeVH);
15384 if (!DT.dominates(AssumeI, Header))
15385 continue;
15386 Terms.emplace_back(AssumeI->getOperand(0), true);
15387 }
15388
15389 // Second, collect information from llvm.experimental.guards dominating the loop.
15390 auto *GuardDecl = F.getParent()->getFunction(
15391 Intrinsic::getName(Intrinsic::experimental_guard));
15392 if (GuardDecl)
15393 for (const auto *GU : GuardDecl->users())
15394 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15395 if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
15396 Terms.emplace_back(Guard->getArgOperand(0), true);
15397
15398 // Third, collect conditions from dominating branches. Starting at the loop
15399 // predecessor, climb up the predecessor chain, as long as there are
15400 // predecessors that can be found that have unique successors leading to the
15401 // original header.
15402 // TODO: share this logic with isLoopEntryGuardedByCond.
15403 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15404 L->getLoopPredecessor(), Header);
15405 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15406
15407 const BranchInst *LoopEntryPredicate =
15408 dyn_cast<BranchInst>(Pair.first->getTerminator());
15409 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15410 continue;
15411
15412 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15413 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15414 }
15415
15416 // Now apply the information from the collected conditions to RewriteMap.
15417 // Conditions are processed in reverse order, so the earliest condition is
15418 // processed first. This ensures the SCEVs with the shortest dependency chains
15419 // are constructed first.
15420 DenseMap<const SCEV *, const SCEV *> RewriteMap;
15421 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15422 SmallVector<Value *, 8> Worklist;
15423 SmallPtrSet<Value *, 8> Visited;
15424 Worklist.push_back(Term);
15425 while (!Worklist.empty()) {
15426 Value *Cond = Worklist.pop_back_val();
15427 if (!Visited.insert(Cond).second)
15428 continue;
15429
15430 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15431 auto Predicate =
15432 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15433 const auto *LHS = getSCEV(Cmp->getOperand(0));
15434 const auto *RHS = getSCEV(Cmp->getOperand(1));
15435 CollectCondition(Predicate, LHS, RHS, RewriteMap);
15436 continue;
15437 }
15438
15439 Value *L, *R;
15440 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15441 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15442 Worklist.push_back(L);
15443 Worklist.push_back(R);
15444 }
15445 }
15446 }
15447
15448 if (RewriteMap.empty())
15449 return Expr;
15450
15451 // Now that all rewrite information has been collected, rewrite the collected
15452 // expressions with the information in the map. This applies information to
15453 // sub-expressions.
15454 if (ExprsToRewrite.size() > 1) {
15455 for (const SCEV *Expr : ExprsToRewrite) {
15456 const SCEV *RewriteTo = RewriteMap[Expr];
15457 RewriteMap.erase(Expr);
15458 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15459 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15460 }
15461 }
15462
15463 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15464 return Rewriter.visit(Expr);
15465}
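// A small usage sketch (hypothetical loop and values): if the loop is entered
// under the dominating guard (%n u> 0) and Expr is the SCEV of %n, then
// applyLoopGuards(Expr, L) returns (1 umax %n), which downstream code such as
// trip-count logic can use where %n alone would not be provably non-zero.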
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static int CompareValueComplexity(EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2)
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
This defines the Use class.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
A rewriter to replace SCEV expressions in Map with the corresponding entry in the map.
const SCEV * visitAddRecExpr(const SCEVAddRecExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
SCEVLoopGuardRewriter(ScalarEvolution &SE, DenseMap< const SCEV *, const SCEV * > &M)
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
Class for arbitrary precision integers.
Definition: APInt.h:76
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1941
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1543
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:401
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1491
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1370
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1463
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:184
APInt abs() const
Get the absolute value.
Definition: APInt.h:1737
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1160
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:358
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:444
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1636
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1439
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1089
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:187
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:194
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:307
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1144
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:197
unsigned countTrailingZeros() const
Definition: APInt.h:1597
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:334
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:805
APInt multiplicativeInverse() const
Definition: APInt.cpp:1244
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1128
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:851
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:284
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:319
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1108
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:178
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:410
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:217
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1199
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:47
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:360
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:378
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:321
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:473
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:430
const Instruction & front() const
Definition: BasicBlock.h:453
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:452
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:206
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:221
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:513
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1320
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:993
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:1022
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:1023
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:1017
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:1016
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:1020
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:1018
@ ICMP_EQ
equal
Definition: InstrTypes.h:1014
@ ICMP_NE
not equal
Definition: InstrTypes.h:1015
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:1021
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:1019
bool isSigned() const
Definition: InstrTypes.h:1265
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:1167
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1314
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:1211
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:1129
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:1105
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1308
bool isUnsigned() const
Definition: InstrTypes.h:1271
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1261
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2529
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2112
static Constant * getICmp(unsigned short pred, Constant *LHS, Constant *RHS, bool OnlyIfReduced=false)
get* - Return some common constants without having to specify the full Instruction::OPCODE identifier...
Definition: Constants.cpp:2402
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, bool InBounds=false, std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1200
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2535
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2523
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2098
This is the shared class of boolean and integer constants.
Definition: Constants.h:80
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:217
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:211
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:205
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:856
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:154
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:145
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:863
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:41
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:720
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:878
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:774
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:905
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:672
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:329
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:180
bool empty() const
Definition: DenseMap.h:98
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:145
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:320
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
const BasicBlock & getEntryBlock() const
Definition: Function.h:783
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
const BasicBlock * getParent() const
Definition: Instruction.h:152
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
An instruction for reading from memory.
Definition: Instructions.h:184
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:60
Metadata node.
Definition: Metadata.h:1067
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:191
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:31
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:41
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:75
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:108
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:102
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1827
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
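A small sketch of how a transform might gate itself on this query; isTinyCountableLoop and the bound of 8 are illustrative assumptions.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"

using namespace llvm;

// Hypothetical gate for a transform that only fires on small, countable loops.
static bool isTinyCountableLoop(ScalarEvolution &SE, const Loop *L,
                                unsigned MaxTrip = 8) {
  // Returns 0 when the trip count is unknown or not a small constant.
  unsigned TripCount = SE.getSmallConstantTripCount(L);
  return TripCount != 0 && TripCount <= MaxTrip;
}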
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
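A hedged sketch of combining applyLoopGuards with the range queries above; guardedUnsignedRange is a hypothetical helper.

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/ConstantRange.h"

using namespace llvm;

// Hypothetical helper: tighten the unsigned range of Expr using the
// conditions that guard entry to loop L (e.g. a dominating "n > 0" check).
static ConstantRange guardedUnsignedRange(ScalarEvolution &SE,
                                          const SCEV *Expr, const Loop *L) {
  const SCEV *Guarded = SE.applyLoopGuards(Expr, L);
  return SE.getUnsignedRange(Guarded);
}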
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
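A minimal sketch contrasting this with getMinusSCEV: when only a constant answer matters, computeConstantDifference avoids building a new SCEV; differByConstant is a hypothetical helper.

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include <optional>

using namespace llvm;

// Hypothetical helper: true if LHS and RHS differ by a compile-time constant,
// returning that constant through DiffOut.
static bool differByConstant(ScalarEvolution &SE, const SCEV *LHS,
                             const SCEV *RHS, APInt &DiffOut) {
  if (std::optional<APInt> Diff = SE.computeConstantDifference(LHS, RHS)) {
    DiffOut = *Diff;
    return true;
  }
  return false;
}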
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Rewrites the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:366
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
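The SmallPtrSet entries above correspond to the visited-set idiom used throughout expression traversals; a minimal sketch (countUniqueNodes is a hypothetical helper) is:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"

using namespace llvm;

// Hypothetical helper: count distinct nodes in a SCEV expression DAG.
// insert() reports via .second whether the element was newly added, which
// keeps the worklist from revisiting shared subexpressions.
static unsigned countUniqueNodes(const SCEV *Root) {
  SmallPtrSet<const SCEV *, 16> Visited;
  SmallVector<const SCEV *, 16> Worklist;
  Worklist.push_back(Root);
  unsigned N = 0;
  while (!Worklist.empty()) {
    const SCEV *S = Worklist.pop_back_val();
    if (!Visited.insert(S).second)
      continue; // Already seen.
    ++N;
    append_range(Worklist, S->operands());
  }
  return N;
}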
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
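A tiny sketch of the SmallVector operations listed above (makeSquares is illustrative): the first few elements live inline, and reserve() pre-grows the storage when an upper bound on the final size is known.

#include "llvm/ADT/SmallVector.h"
#include <utility>

using namespace llvm;

// Illustrative helper: build (i, i*i) pairs without heap traffic for small N.
static SmallVector<std::pair<unsigned, unsigned>, 8> makeSquares(unsigned N) {
  SmallVector<std::pair<unsigned, unsigned>, 8> Out;
  Out.reserve(N);               // Single growth step instead of many.
  for (unsigned I = 0; I < N; ++I)
    Out.emplace_back(I, I * I); // Constructs the pair in place.
  return Out;
}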
An instruction for storing to memory.
Definition: Instructions.h:317
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:622
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:651
TypeSize getSizeInBits() const
Definition: DataLayout.h:631
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:243
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5079
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1074
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2178
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2183
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2188
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2781
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2193
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
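A short sketch of the APIntOps helpers listed above; both helpers are illustrative, and the operands are assumed to share a bit width (an APInt precondition).

#include "llvm/ADT/APInt.h"

using namespace llvm;

// Illustrative helper: a conservative stride that divides every offset
// reachable with either of the two known strides.
static APInt commonStride(const APInt &A, const APInt &B) {
  return APIntOps::GreatestCommonDivisor(A, B);
}

// Illustrative helper: width of the unsigned interval spanned by Lo and Hi.
static APInt unsignedSpan(const APInt &Lo, const APInt &Hi) {
  return APIntOps::umax(Lo, Hi) - APIntOps::umin(Lo, Hi);
}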
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1023
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:163
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:771
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:294
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:184
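A minimal sketch of the PatternMatch combinators listed above; matchShiftByConstant is a hypothetical helper.

#include "llvm/ADT/APInt.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Hypothetical helper: recognise "Base << C" and hand back the base value
// and the constant shift amount.
static bool matchShiftByConstant(Value *V, Value *&Base, const APInt *&ShAmt) {
  // m_Value captures an arbitrary operand; m_APInt binds a constant (or
  // splatted constant vector) shift amount when one is present.
  return match(V, m_Shl(m_Value(Base), m_APInt(ShAmt)));
}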
@ ReallyHidden
Definition: CommandLine.h:139
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:470
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:31
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:456
void stable_sort(R &&Range)
Definition: STLExtras.h:1995
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
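A brief sketch of the range-based wrappers above applied to a SCEV operand list; both helpers are hypothetical.

#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Support/Casting.h"

using namespace llvm;

// Hypothetical helper: every operand is a compile-time constant.
static bool allOperandsAreConstant(const SCEV *S) {
  return all_of(S->operands(),
                [](const SCEV *Op) { return isa<SCEVConstant>(Op); });
}

// Hypothetical helper: at least one operand is an add recurrence.
static bool anyOperandIsAddRec(const SCEV *S) {
  return any_of(S->operands(),
                [](const SCEV *Op) { return isa<SCEVAddRecExpr>(Op); });
}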
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7041
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2073
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2059
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1729
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:419
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1118
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1108
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:116
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:233
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1914
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
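A minimal sketch of feeding computeKnownBits into the KnownBits queries listed further below; knownUnsignedMax is a hypothetical helper and assumes V has integer type.

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

// Hypothetical helper: conservative unsigned upper bound for V
// (assumes V is an integer or vector-of-integer value).
static APInt knownUnsignedMax(const Value *V, const DataLayout &DL) {
  KnownBits Known(V->getType()->getScalarSizeInBits());
  // Fills Known.Zero/Known.One with the bits proven constant for V.
  computeKnownBits(V, Known, DL);
  // getMaxValue() assumes every unknown bit is one.
  return Known.getMaxValue();
}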
DWARFExpression::Operation Op
auto max_element(R &&Range)
Definition: STLExtras.h:1986
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:293
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1849
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1921
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (eith...
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1879
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:26
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:297
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:104
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:422
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:364
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:192
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:141
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:125
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:101
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:279
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)