LLVM 20.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
85#include "llvm/Config/llvm-config.h"
86#include "llvm/IR/Argument.h"
87#include "llvm/IR/BasicBlock.h"
88#include "llvm/IR/CFG.h"
89#include "llvm/IR/Constant.h"
91#include "llvm/IR/Constants.h"
92#include "llvm/IR/DataLayout.h"
94#include "llvm/IR/Dominators.h"
95#include "llvm/IR/Function.h"
96#include "llvm/IR/GlobalAlias.h"
97#include "llvm/IR/GlobalValue.h"
99#include "llvm/IR/InstrTypes.h"
100#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Instructions.h"
103#include "llvm/IR/Intrinsics.h"
104#include "llvm/IR/LLVMContext.h"
105#include "llvm/IR/Operator.h"
106#include "llvm/IR/PatternMatch.h"
107#include "llvm/IR/Type.h"
108#include "llvm/IR/Use.h"
109#include "llvm/IR/User.h"
110#include "llvm/IR/Value.h"
111#include "llvm/IR/Verifier.h"
113#include "llvm/Pass.h"
114#include "llvm/Support/Casting.h"
117#include "llvm/Support/Debug.h"
122#include <algorithm>
123#include <cassert>
124#include <climits>
125#include <cstdint>
126#include <cstdlib>
127#include <map>
128#include <memory>
129#include <numeric>
130#include <optional>
131#include <tuple>
132#include <utility>
133#include <vector>
134
135using namespace llvm;
136using namespace PatternMatch;
137using namespace SCEVPatternMatch;
138
139#define DEBUG_TYPE "scalar-evolution"
140
141STATISTIC(NumExitCountsComputed,
142 "Number of loop exits with predictable exit counts");
143STATISTIC(NumExitCountsNotComputed,
144 "Number of loop exits without predictable exit counts");
145STATISTIC(NumBruteForceTripCountsComputed,
146 "Number of loops with trip counts computed by force");
147
148#ifdef EXPENSIVE_CHECKS
149bool llvm::VerifySCEV = true;
150#else
151bool llvm::VerifySCEV = false;
152#endif
153
155 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
156 cl::desc("Maximum number of iterations SCEV will "
157 "symbolically execute a constant "
158 "derived loop"),
159 cl::init(100));
160
162 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
163 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
165 "verify-scev-strict", cl::Hidden,
166 cl::desc("Enable stricter verification with -verify-scev is passed"));
167
169 "scev-verify-ir", cl::Hidden,
170 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
171 cl::init(false));
172
174 "scev-mulops-inline-threshold", cl::Hidden,
175 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
176 cl::init(32));
177
179 "scev-addops-inline-threshold", cl::Hidden,
180 cl::desc("Threshold for inlining addition operands into a SCEV"),
181 cl::init(500));
182
184 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
185 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
186 cl::init(32));
187
189 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
190 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
191 cl::init(2));
192
194 "scalar-evolution-max-value-compare-depth", cl::Hidden,
195 cl::desc("Maximum depth of recursive value complexity comparisons"),
196 cl::init(2));
197
199 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
200 cl::desc("Maximum depth of recursive arithmetics"),
201 cl::init(32));
202
204 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
205 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
206
208 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
209 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
210 cl::init(8));
211
213 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
214 cl::desc("Max coefficients in AddRec during evolving"),
215 cl::init(8));
216
218 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
219 cl::desc("Size of the expression which is considered huge"),
220 cl::init(4096));
221
223 "scev-range-iter-threshold", cl::Hidden,
224 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
225 cl::init(32));
226
228 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
229 cl::desc("Maximum depth for recrusive loop guard collection"), cl::init(1));
230
231static cl::opt<bool>
232ClassifyExpressions("scalar-evolution-classify-expressions",
233 cl::Hidden, cl::init(true),
234 cl::desc("When printing analysis, include information on every instruction"));
235
237 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
238 cl::init(false),
239 cl::desc("Use more powerful methods of sharpening expression ranges. May "
240 "be costly in terms of compile time"));
241
243 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
244 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
245 "Phi strongly connected components"),
246 cl::init(8));
247
248static cl::opt<bool>
249 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
250 cl::desc("Handle <= and >= in finite loops"),
251 cl::init(true));
252
254 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
255 cl::desc("Infer nuw/nsw flags using context where suitable"),
256 cl::init(true));
257
258//===----------------------------------------------------------------------===//
259// SCEV class definitions
260//===----------------------------------------------------------------------===//
261
262//===----------------------------------------------------------------------===//
263// Implementation of the SCEV class.
264//
265
266#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
268 print(dbgs());
269 dbgs() << '\n';
270}
271#endif
272
274 switch (getSCEVType()) {
275 case scConstant:
276 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
277 return;
278 case scVScale:
279 OS << "vscale";
280 return;
281 case scPtrToInt: {
282 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
283 const SCEV *Op = PtrToInt->getOperand();
284 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
285 << *PtrToInt->getType() << ")";
286 return;
287 }
288 case scTruncate: {
289 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
290 const SCEV *Op = Trunc->getOperand();
291 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
292 << *Trunc->getType() << ")";
293 return;
294 }
295 case scZeroExtend: {
296 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
297 const SCEV *Op = ZExt->getOperand();
298 OS << "(zext " << *Op->getType() << " " << *Op << " to "
299 << *ZExt->getType() << ")";
300 return;
301 }
302 case scSignExtend: {
303 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
304 const SCEV *Op = SExt->getOperand();
305 OS << "(sext " << *Op->getType() << " " << *Op << " to "
306 << *SExt->getType() << ")";
307 return;
308 }
309 case scAddRecExpr: {
310 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
311 OS << "{" << *AR->getOperand(0);
312 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
313 OS << ",+," << *AR->getOperand(i);
314 OS << "}<";
315 if (AR->hasNoUnsignedWrap())
316 OS << "nuw><";
317 if (AR->hasNoSignedWrap())
318 OS << "nsw><";
319 if (AR->hasNoSelfWrap() &&
321 OS << "nw><";
322 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
323 OS << ">";
324 return;
325 }
326 case scAddExpr:
327 case scMulExpr:
328 case scUMaxExpr:
329 case scSMaxExpr:
330 case scUMinExpr:
331 case scSMinExpr:
333 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
334 const char *OpStr = nullptr;
335 switch (NAry->getSCEVType()) {
336 case scAddExpr: OpStr = " + "; break;
337 case scMulExpr: OpStr = " * "; break;
338 case scUMaxExpr: OpStr = " umax "; break;
339 case scSMaxExpr: OpStr = " smax "; break;
340 case scUMinExpr:
341 OpStr = " umin ";
342 break;
343 case scSMinExpr:
344 OpStr = " smin ";
345 break;
347 OpStr = " umin_seq ";
348 break;
349 default:
350 llvm_unreachable("There are no other nary expression types.");
351 }
352 OS << "(";
353 ListSeparator LS(OpStr);
354 for (const SCEV *Op : NAry->operands())
355 OS << LS << *Op;
356 OS << ")";
357 switch (NAry->getSCEVType()) {
358 case scAddExpr:
359 case scMulExpr:
360 if (NAry->hasNoUnsignedWrap())
361 OS << "<nuw>";
362 if (NAry->hasNoSignedWrap())
363 OS << "<nsw>";
364 break;
365 default:
366 // Nothing to print for other nary expressions.
367 break;
368 }
369 return;
370 }
371 case scUDivExpr: {
372 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
373 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
374 return;
375 }
376 case scUnknown:
377 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
378 return;
380 OS << "***COULDNOTCOMPUTE***";
381 return;
382 }
383 llvm_unreachable("Unknown SCEV kind!");
384}
385
387 switch (getSCEVType()) {
388 case scConstant:
389 return cast<SCEVConstant>(this)->getType();
390 case scVScale:
391 return cast<SCEVVScale>(this)->getType();
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToInt:
427 case scTruncate:
428 case scZeroExtend:
429 case scSignExtend:
430 return cast<SCEVCastExpr>(this)->operands();
431 case scAddRecExpr:
432 case scAddExpr:
433 case scMulExpr:
434 case scUMaxExpr:
435 case scSMaxExpr:
436 case scUMinExpr:
437 case scSMinExpr:
439 return cast<SCEVNAryExpr>(this)->operands();
440 case scUDivExpr:
441 return cast<SCEVUDivExpr>(this)->operands();
443 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
444 }
445 llvm_unreachable("Unknown SCEV kind!");
446}
447
448bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
449
450bool SCEV::isOne() const { return match(this, m_scev_One()); }
451
452bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
453
455 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
456 if (!Mul) return false;
457
458 // If there is a constant factor, it will be first.
459 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
460 if (!SC) return false;
461
462 // Return true if the value is negative, this matches things like (-42 * V).
463 return SC->getAPInt().isNegative();
464}
465
468
470 return S->getSCEVType() == scCouldNotCompute;
471}
472
475 ID.AddInteger(scConstant);
476 ID.AddPointer(V);
477 void *IP = nullptr;
478 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
479 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
480 UniqueSCEVs.InsertNode(S, IP);
481 return S;
482}
483
485 return getConstant(ConstantInt::get(getContext(), Val));
486}
487
488const SCEV *
490 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
491 return getConstant(ConstantInt::get(ITy, V, isSigned));
492}
493
496 ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
507 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
508 if (EC.isScalable())
509 Res = getMulExpr(Res, getVScale(Ty));
510 return Res;
511}
512
514 const SCEV *op, Type *ty)
515 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
516
517SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
518 Type *ITy)
519 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
521 "Must be a non-bit-width-changing pointer-to-integer cast!");
522}
523
525 SCEVTypes SCEVTy, const SCEV *op,
526 Type *ty)
527 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
528
529SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
530 Type *ty)
532 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
533 "Cannot truncate non-integer value!");
534}
535
536SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
537 const SCEV *op, Type *ty)
539 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
540 "Cannot zero extend non-integer value!");
541}
542
543SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
544 const SCEV *op, Type *ty)
546 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
547 "Cannot sign extend non-integer value!");
548}
549
550void SCEVUnknown::deleted() {
551 // Clear this SCEVUnknown from various maps.
552 SE->forgetMemoizedResults(this);
553
554 // Remove this SCEVUnknown from the uniquing map.
555 SE->UniqueSCEVs.RemoveNode(this);
556
557 // Release the value.
558 setValPtr(nullptr);
559}
560
561void SCEVUnknown::allUsesReplacedWith(Value *New) {
562 // Clear this SCEVUnknown from various maps.
563 SE->forgetMemoizedResults(this);
564
565 // Remove this SCEVUnknown from the uniquing map.
566 SE->UniqueSCEVs.RemoveNode(this);
567
568 // Replace the value pointer in case someone is still using this SCEVUnknown.
569 setValPtr(New);
570}
571
572//===----------------------------------------------------------------------===//
573// SCEV Utilities
574//===----------------------------------------------------------------------===//
575
576/// Compare the two values \p LV and \p RV in terms of their "complexity" where
577/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
578/// operands in SCEV expressions.
579static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
580 Value *RV, unsigned Depth) {
582 return 0;
583
584 // Order pointer values after integer values. This helps SCEVExpander form
585 // GEPs.
586 bool LIsPointer = LV->getType()->isPointerTy(),
587 RIsPointer = RV->getType()->isPointerTy();
588 if (LIsPointer != RIsPointer)
589 return (int)LIsPointer - (int)RIsPointer;
590
591 // Compare getValueID values.
592 unsigned LID = LV->getValueID(), RID = RV->getValueID();
593 if (LID != RID)
594 return (int)LID - (int)RID;
595
596 // Sort arguments by their position.
597 if (const auto *LA = dyn_cast<Argument>(LV)) {
598 const auto *RA = cast<Argument>(RV);
599 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
600 return (int)LArgNo - (int)RArgNo;
601 }
602
603 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
604 const auto *RGV = cast<GlobalValue>(RV);
605
606 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
607 auto LT = GV->getLinkage();
608 return !(GlobalValue::isPrivateLinkage(LT) ||
610 };
611
612 // Use the names to distinguish the two values, but only if the
613 // names are semantically important.
614 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
615 return LGV->getName().compare(RGV->getName());
616 }
617
618 // For instructions, compare their loop depth, and their operand count. This
619 // is pretty loose.
620 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
621 const auto *RInst = cast<Instruction>(RV);
622
623 // Compare loop depths.
624 const BasicBlock *LParent = LInst->getParent(),
625 *RParent = RInst->getParent();
626 if (LParent != RParent) {
627 unsigned LDepth = LI->getLoopDepth(LParent),
628 RDepth = LI->getLoopDepth(RParent);
629 if (LDepth != RDepth)
630 return (int)LDepth - (int)RDepth;
631 }
632
633 // Compare the number of operands.
634 unsigned LNumOps = LInst->getNumOperands(),
635 RNumOps = RInst->getNumOperands();
636 if (LNumOps != RNumOps)
637 return (int)LNumOps - (int)RNumOps;
638
639 for (unsigned Idx : seq(LNumOps)) {
640 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
641 RInst->getOperand(Idx), Depth + 1);
642 if (Result != 0)
643 return Result;
644 }
645 }
646
647 return 0;
648}
649
650// Return negative, zero, or positive, if LHS is less than, equal to, or greater
651// than RHS, respectively. A three-way result allows recursive comparisons to be
652// more efficient.
653// If the max analysis depth was reached, return std::nullopt, assuming we do
654// not know if they are equivalent for sure.
655static std::optional<int>
657 const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (EqCacheSCEV.isEquivalent(LHS, RHS))
669 return 0;
670
672 return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 if (X == 0)
685 EqCacheSCEV.unionSets(LHS, RHS);
686 return X;
687 }
688
689 case scConstant: {
690 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
691 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
692
693 // Compare constant values.
694 const APInt &LA = LC->getAPInt();
695 const APInt &RA = RC->getAPInt();
696 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
697 if (LBitWidth != RBitWidth)
698 return (int)LBitWidth - (int)RBitWidth;
699 return LA.ult(RA) ? -1 : 1;
700 }
701
702 case scVScale: {
703 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
704 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
705 return LTy->getBitWidth() - RTy->getBitWidth();
706 }
707
708 case scAddRecExpr: {
709 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
710 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
711
712 // There is always a dominance between two recs that are used by one SCEV,
713 // so we can safely sort recs by loop header dominance. We require such
714 // order in getAddExpr.
715 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
716 if (LLoop != RLoop) {
717 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
718 assert(LHead != RHead && "Two loops share the same header?");
719 if (DT.dominates(LHead, RHead))
720 return 1;
721 assert(DT.dominates(RHead, LHead) &&
722 "No dominance between recurrences used by one SCEV?");
723 return -1;
724 }
725
726 [[fallthrough]];
727 }
728
729 case scTruncate:
730 case scZeroExtend:
731 case scSignExtend:
732 case scPtrToInt:
733 case scAddExpr:
734 case scMulExpr:
735 case scUDivExpr:
736 case scSMaxExpr:
737 case scUMaxExpr:
738 case scSMinExpr:
739 case scUMinExpr:
741 ArrayRef<const SCEV *> LOps = LHS->operands();
742 ArrayRef<const SCEV *> ROps = RHS->operands();
743
744 // Lexicographically compare n-ary-like expressions.
745 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
746 if (LNumOps != RNumOps)
747 return (int)LNumOps - (int)RNumOps;
748
749 for (unsigned i = 0; i != LNumOps; ++i) {
750 auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
751 Depth + 1);
752 if (X != 0)
753 return X;
754 }
755 EqCacheSCEV.unionSets(LHS, RHS);
756 return 0;
757 }
758
760 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
761 }
762 llvm_unreachable("Unknown SCEV kind!");
763}
764
765/// Given a list of SCEV objects, order them by their complexity, and group
766/// objects of the same complexity together by value. When this routine is
767/// finished, we know that any duplicates in the vector are consecutive and that
768/// complexity is monotonically increasing.
769///
/// Note that we take special precautions to ensure that we get deterministic
771/// results from this routine. In other words, we don't want the results of
772/// this to depend on where the addresses of various SCEV objects happened to
773/// land in memory.
775 LoopInfo *LI, DominatorTree &DT) {
776 if (Ops.size() < 2) return; // Noop
777
779
780 // Whether LHS has provably less complexity than RHS.
781 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
782 auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
783 return Complexity && *Complexity < 0;
784 };
785 if (Ops.size() == 2) {
786 // This is the common case, which also happens to be trivially simple.
787 // Special case it.
788 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
789 if (IsLessComplex(RHS, LHS))
790 std::swap(LHS, RHS);
791 return;
792 }
793
794 // Do the rough sort by complexity.
795 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
796 return IsLessComplex(LHS, RHS);
797 });
798
799 // Now that we are sorted by complexity, group elements of the same
800 // complexity. Note that this is, at worst, N^2, but the vector is likely to
801 // be extremely short in practice. Note that we take this approach because we
802 // do not want to depend on the addresses of the objects we are grouping.
803 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
804 const SCEV *S = Ops[i];
805 unsigned Complexity = S->getSCEVType();
806
807 // If there are any objects of the same complexity and same value as this
808 // one, group them.
809 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
810 if (Ops[j] == S) { // Found a duplicate.
811 // Move it to immediately after i'th element.
812 std::swap(Ops[i+1], Ops[j]);
813 ++i; // no need to rescan it.
814 if (i == e-2) return; // Done!
815 }
816 }
817 }
818}
819
820/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
821/// least HugeExprThreshold nodes).
823 return any_of(Ops, [](const SCEV *S) {
825 });
826}
827
828/// Performs a number of common optimizations on the passed \p Ops. If the
829/// whole expression reduces down to a single operand, it will be returned.
830///
831/// The following optimizations are performed:
832/// * Fold constants using the \p Fold function.
833/// * Remove identity constants satisfying \p IsIdentity.
834/// * If a constant satisfies \p IsAbsorber, return it.
835/// * Sort operands by complexity.
836template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
837static const SCEV *
839 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
840 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
841 const SCEVConstant *Folded = nullptr;
842 for (unsigned Idx = 0; Idx < Ops.size();) {
843 const SCEV *Op = Ops[Idx];
844 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
845 if (!Folded)
846 Folded = C;
847 else
848 Folded = cast<SCEVConstant>(
849 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
850 Ops.erase(Ops.begin() + Idx);
851 continue;
852 }
853 ++Idx;
854 }
855
856 if (Ops.empty()) {
857 assert(Folded && "Must have folded value");
858 return Folded;
859 }
860
861 if (Folded && IsAbsorber(Folded->getAPInt()))
862 return Folded;
863
864 GroupByComplexity(Ops, &LI, DT);
865 if (Folded && !IsIdentity(Folded->getAPInt()))
866 Ops.insert(Ops.begin(), Folded);
867
868 return Ops.size() == 1 ? Ops[0] : nullptr;
869}
870
871//===----------------------------------------------------------------------===//
872// Simple SCEV method implementations
873//===----------------------------------------------------------------------===//
874
875/// Compute BC(It, K). The result has width W. Assume, K > 0.
876static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
877 ScalarEvolution &SE,
878 Type *ResultTy) {
879 // Handle the simplest case efficiently.
880 if (K == 1)
881 return SE.getTruncateOrZeroExtend(It, ResultTy);
882
883 // We are using the following formula for BC(It, K):
884 //
885 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
886 //
887 // Suppose, W is the bitwidth of the return value. We must be prepared for
888 // overflow. Hence, we must assure that the result of our computation is
889 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
890 // safe in modular arithmetic.
891 //
892 // However, this code doesn't use exactly that formula; the formula it uses
893 // is something like the following, where T is the number of factors of 2 in
894 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
895 // exponentiation:
896 //
897 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
898 //
899 // This formula is trivially equivalent to the previous formula. However,
900 // this formula can be implemented much more efficiently. The trick is that
901 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
902 // arithmetic. To do exact division in modular arithmetic, all we have
903 // to do is multiply by the inverse. Therefore, this step can be done at
904 // width W.
905 //
906 // The next issue is how to safely do the division by 2^T. The way this
907 // is done is by doing the multiplication step at a width of at least W + T
908 // bits. This way, the bottom W+T bits of the product are accurate. Then,
909 // when we perform the division by 2^T (which is equivalent to a right shift
910 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
911 // truncated out after the division by 2^T.
912 //
913 // In comparison to just directly using the first formula, this technique
914 // is much more efficient; using the first formula requires W * K bits,
915 // but this formula less than W + K bits. Also, the first formula requires
916 // a division step, whereas this formula only requires multiplies and shifts.
917 //
918 // It doesn't matter whether the subtraction step is done in the calculation
919 // width or the input iteration count's width; if the subtraction overflows,
920 // the result must be zero anyway. We prefer here to do it in the width of
921 // the induction variable because it helps a lot for certain cases; CodeGen
922 // isn't smart enough to ignore the overflow, which leads to much less
923 // efficient code if the width of the subtraction is wider than the native
924 // register width.
925 //
926 // (It's possible to not widen at all by pulling out factors of 2 before
927 // the multiplication; for example, K=2 can be calculated as
928 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
929 // extra arithmetic, so it's not an obvious win, and it gets
930 // much more complicated for K > 3.)
931
932 // Protection from insane SCEVs; this bound is conservative,
933 // but it probably doesn't matter.
934 if (K > 1000)
935 return SE.getCouldNotCompute();
936
937 unsigned W = SE.getTypeSizeInBits(ResultTy);
938
939 // Calculate K! / 2^T and T; we divide out the factors of two before
940 // multiplying for calculating K! / 2^T to avoid overflow.
941 // Other overflow doesn't matter because we only care about the bottom
942 // W bits of the result.
943 APInt OddFactorial(W, 1);
944 unsigned T = 1;
945 for (unsigned i = 3; i <= K; ++i) {
946 unsigned TwoFactors = countr_zero(i);
947 T += TwoFactors;
948 OddFactorial *= (i >> TwoFactors);
949 }
950
951 // We need at least W + T bits for the multiplication step
952 unsigned CalculationBits = W + T;
953
954 // Calculate 2^T, at width T+W.
955 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
956
957 // Calculate the multiplicative inverse of K! / 2^T;
958 // this multiplication factor will perform the exact division by
959 // K! / 2^T.
960 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
961
962 // Calculate the product, at width T+W
963 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
964 CalculationBits);
965 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
966 for (unsigned i = 1; i != K; ++i) {
967 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
968 Dividend = SE.getMulExpr(Dividend,
969 SE.getTruncateOrZeroExtend(S, CalculationTy));
970 }
971
972 // Divide by 2^T
973 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
974
975 // Truncate the result, and divide by K! / 2^T.
976
977 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
978 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
979}
980
981/// Return the value of this chain of recurrences at the specified iteration
982/// number. We can evaluate this recurrence by multiplying each element in the
983/// chain by the binomial coefficient corresponding to it. In other words, we
984/// can evaluate {A,+,B,+,C,+,D} as:
985///
986/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
987///
988/// where BC(It, k) stands for the binomial coefficient "It choose k".
///
/// This member form just forwards this recurrence's operands() to the
/// overload that takes an explicit operand list. The result may be
/// SCEVCouldNotCompute if one of the binomial coefficients cannot be computed.
990 ScalarEvolution &SE) const {
991 return evaluateAtIteration(operands(), It, SE);
992}
993
// Evaluate the chain of recurrences {Operands[0],+,Operands[1],+,...} at
// iteration It, i.e. compute
//   Operands[0] + Operands[1]*BC(It, 1) + ... + Operands[N-1]*BC(It, N-1)
// where BC(It, k) is the binomial coefficient "It choose k".
// Operands must be non-empty (asserted). If any coefficient cannot be
// computed, the SCEVCouldNotCompute returned by BinomialCoefficient is
// returned as-is instead of a partial sum.
994const SCEV *
996 const SCEV *It, ScalarEvolution &SE) {
997 assert(Operands.size() > 0);
998 const SCEV *Result = Operands[0];
999 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1000 // The computation is correct in the face of overflow provided that the
1001 // multiplication is performed _after_ the evaluation of the binomial
1002 // coefficient.
// The coefficient is requested in the type of the running sum, which starts
// out as the type of Operands[0].
1003 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1004 if (isa<SCEVCouldNotCompute>(Coeff))
1005 return Coeff;
1006
1007 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1008 }
1009 return Result;
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SCEV Expression folder implementations
1014//===----------------------------------------------------------------------===//
1015
1017 unsigned Depth) {
// Convert the pointer-typed expression Op into an integer-typed SCEV without
// dropping any bits. Integer-typed operands are returned unchanged. A pointer
// SCEVUnknown becomes a SCEVPtrToIntExpr (or the constant 0 for a null
// pointer). Any other pointer expression is rewritten so that ptrtoint casts
// appear only directly on its SCEVUnknown leaves. Returns SCEVCouldNotCompute
// for non-integral pointer types, or when SCEV's effective integer type is not
// wide enough to hold the pointer (see the width check below). Depth is 0 for
// external callers and 1 for the single self-recursion performed by
// SCEVPtrToIntSinkingRewriter::visitUnknown.
1018 assert(Depth <= 1 &&
1019 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1020
1021 // We could be called with an integer-typed operand during SCEV rewrites.
1022 // Since the operand is an integer already, return it as-is (no cast needed);
// any zext/trunc to a particular width is left to the caller.
1023 if (!Op->getType()->isPointerTy())
1024 return Op;
1025
1026 // What would be an ID for such a SCEV cast expression?
1028 ID.AddInteger(scPtrToInt);
1029 ID.AddPointer(Op);
1030
1031 void *IP = nullptr;
1032
1033 // Is there already an expression for such a cast?
1034 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1035 return S;
1036
1037 // It isn't legal for optimizations to construct new ptrtoint expressions
1038 // for non-integral pointers.
1039 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1040 return getCouldNotCompute();
1041
1042 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1043
1044 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1045 // is sufficiently wide to represent all possible pointer values.
1046 // We could theoretically teach SCEV to truncate wider pointers, but
1047 // that isn't implemented for now.
1049 getDataLayout().getTypeSizeInBits(IntPtrTy))
1050 return getCouldNotCompute();
1051
1052 // If not, is this expression something we can't reduce any further?
1053 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1054 // Perform some basic constant folding. If the operand of the ptr2int cast
1055 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1056 // left as-is), but produce a zero constant.
1057 // NOTE: We could handle a more general case, but lack motivational cases.
1058 if (isa<ConstantPointerNull>(U->getValue()))
1059 return getZero(IntPtrTy);
1060
1061 // Create an explicit cast node.
1062 // We can reuse the existing insert position since if we get here,
1063 // we won't have made any changes which would invalidate it.
1064 SCEV *S = new (SCEVAllocator)
1065 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1066 UniqueSCEVs.InsertNode(S, IP);
1067 registerUser(S, Op);
1068 return S;
1069 }
1070
1071 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1072 "non-SCEVUnknown's.");
1073
1074 // Otherwise, we've got some expression that is more complex than just a
1075 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1076 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1077 // only, and the expressions must otherwise be integer-typed.
1078 // So sink the cast down to the SCEVUnknown's.
1079
1080 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1081 /// which computes a pointer-typed value, and rewrites the whole expression
1082 /// tree so that *all* the computations are done on integers, and the only
1083 /// pointer-typed operands in the expression are SCEVUnknown.
1084 class SCEVPtrToIntSinkingRewriter
1085 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1087
1088 public:
1089 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1090
1091 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1092 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1093 return Rewriter.visit(Scev);
1094 }
1095
1096 const SCEV *visit(const SCEV *S) {
1097 Type *STy = S->getType();
1098 // If the expression is not pointer-typed, just keep it as-is.
1099 if (!STy->isPointerTy())
1100 return S;
1101 // Else, recursively sink the cast down into it.
1102 return Base::visit(S);
1103 }
1104
// Rewrite each operand; rebuild the add (keeping its original no-wrap flags)
// only if at least one operand actually changed.
1105 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1107 bool Changed = false;
1108 for (const auto *Op : Expr->operands()) {
1109 Operands.push_back(visit(Op));
1110 Changed |= Op != Operands.back();
1111 }
1112 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1113 }
1114
// Same as visitAddExpr, for multiplications.
1115 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1117 bool Changed = false;
1118 for (const auto *Op : Expr->operands()) {
1119 Operands.push_back(visit(Op));
1120 Changed |= Op != Operands.back();
1121 }
1122 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1123 }
1124
// Leaves: cast each pointer SCEVUnknown via the single permitted
// self-recursion (Depth == 1).
1125 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1126 assert(Expr->getType()->isPointerTy() &&
1127 "Should only reach pointer-typed SCEVUnknown's.");
1128 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1129 }
1130 };
1131
1132 // And actually perform the cast sinking.
1133 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1134 assert(IntOp->getType()->isIntegerTy() &&
1135 "We must have succeeded in sinking the cast, "
1136 "and ending up with an integer-typed expression!");
1137 return IntOp;
1138}
1139
// Convert Op to the integer type Ty in two steps:
// (1) losslessly convert it to the pointer's integer width via
//     getLosslessPtrToIntExpr, propagating SCEVCouldNotCompute on failure;
// (2) truncate or zero-extend that result to Ty.
1141 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1142
1143 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1144 if (isa<SCEVCouldNotCompute>(IntOp))
1145 return IntOp;
1146
1147 return getTruncateOrZeroExtend(IntOp, Ty);
1148}
1149
1151 unsigned Depth) {
// Build (or fold) trunc(Op) to the effective SCEV type of Ty.
// Preconditions (asserted): Ty is strictly narrower than Op's type, Ty is
// SCEVable, and Op is not a pointer.
// Depth bounds recursion. Once it exceeds MaxCastDepth, an explicit
// SCEVTruncateExpr is created without attempting the distributive add/mul and
// addrec folds below.
1152 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1153 "This is not a truncating conversion!");
1154 assert(isSCEVable(Ty) &&
1155 "This is not a conversion to a SCEVable type!");
1156 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1157 Ty = getEffectiveSCEVType(Ty);
1158
1160 ID.AddInteger(scTruncate);
1161 ID.AddPointer(Op);
1162 ID.AddPointer(Ty);
1163 void *IP = nullptr;
1164 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1165
1166 // Fold if the operand is constant.
1167 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1168 return getConstant(
1169 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1170
1171 // trunc(trunc(x)) --> trunc(x)
1172 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1173 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1174
1175 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1176 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1178
1179 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1180 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1182
// Recursion budget exhausted: materialize the cast node as-is.
1183 if (Depth > MaxCastDepth) {
1184 SCEV *S =
1185 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1186 UniqueSCEVs.InsertNode(S, IP);
1187 registerUser(S, Op);
1188 return S;
1189 }
1190
1191 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1192 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1193 // if after transforming we have at most one truncate, not counting truncates
1194 // that replace other casts.
1195 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1196 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1198 unsigned numTruncs = 0;
// Stop early once a second "new" truncate appears; the fold is abandoned then.
1199 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1200 ++i) {
1201 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1202 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1203 isa<SCEVTruncateExpr>(S))
1204 numTruncs++;
1205 Operands.push_back(S);
1206 }
1207 if (numTruncs < 2) {
1208 if (isa<SCEVAddExpr>(Op))
1209 return getAddExpr(Operands);
1210 if (isa<SCEVMulExpr>(Op))
1211 return getMulExpr(Operands);
1212 llvm_unreachable("Unexpected SCEV type for Op.");
1213 }
1214 // Although we checked at the beginning that ID is not in the cache, the
1215 // recursive truncations above may have inserted it in the meantime. If so,
1216 // just return it; otherwise this lookup also refreshes IP for the insertion below.
1217 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1218 return S;
1219 }
1220
1221 // If the input value is a chrec scev, truncate the chrec's operands.
// The truncated recurrence carries no wrap flags (FlagAnyWrap).
1222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1224 for (const SCEV *Op : AddRec->operands())
1225 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1226 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1227 }
1228
1229 // Return zero if truncating to known zeros, i.e. every bit that survives the
// truncation is known to be zero.
1230 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1231 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1232 return getZero(Ty);
1233
1234 // The cast wasn't folded; create an explicit cast node. We can reuse
1235 // the existing insert position since if we get here, we won't have
1236 // made any changes which would invalidate it.
// (If the add/mul branch recursed and fell through, IP was refreshed by its
// second lookup.)
1237 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1238 Op, Ty);
1239 UniqueSCEVs.InsertNode(S, IP);
1240 registerUser(S, Op);
1241 return S;
1242}
1243
1244// Get the limit of a recurrence such that incrementing by Step cannot cause
1245// signed overflow as long as the value of the recurrence within the
1246// loop does not exceed this limit before incrementing.
//
// If Step is known positive, *Pred is set to ICMP_SLT and the limit is derived
// from getSignedRangeMax(Step). If Step is known negative, *Pred is set to
// ICMP_SGT and the limit is derived from getSignedRangeMin(Step). If the sign
// of Step is unknown, nullptr is returned and *Pred is left unwritten.
1247static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1248 ICmpInst::Predicate *Pred,
1249 ScalarEvolution *SE) {
1250 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1251 if (SE->isKnownPositive(Step)) {
1252 *Pred = ICmpInst::ICMP_SLT;
1254 SE->getSignedRangeMax(Step));
1255 }
1256 if (SE->isKnownNegative(Step)) {
1257 *Pred = ICmpInst::ICMP_SGT;
1259 SE->getSignedRangeMin(Step));
1260 }
1261 return nullptr;
1262}
1263
1264// Get the limit of a recurrence such that incrementing by Step cannot cause
1265// unsigned overflow as long as the value of the recurrence within the loop does
1266// not exceed this limit before incrementing.
//
// Unlike the signed variant, this always produces a limit and always sets
// *Pred to ICMP_ULT. The limit is derived from getUnsignedRangeMax(Step).
1268 ICmpInst::Predicate *Pred,
1269 ScalarEvolution *SE) {
1270 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1271 *Pred = ICmpInst::ICMP_ULT;
1272
1274 SE->getUnsignedRangeMax(Step));
1275}
1276
1277namespace {
1278
// Common base: GetExtendExprTy is a pointer to a ScalarEvolution member
// function with the shape (const SCEV *Op, Type *Ty, unsigned Depth)
// -> const SCEV *, i.e. the signature shared by the sign/zero extension
// builders.
1279struct ExtendOpTraitsBase {
1280 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1281 unsigned);
1282};
1283
1284// Used to make code generic over signed and unsigned overflow.
// The primary template is intentionally empty; only the SCEVSignExtendExpr
// and SCEVZeroExtendExpr specializations below provide the members.
1285template <typename ExtendOp> struct ExtendOpTraits {
1286 // Members present:
1287 //
1288 // static const SCEV::NoWrapFlags WrapType;
1289 //
1290 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1291 //
1292 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1293 // ICmpInst::Predicate *Pred,
1294 // ScalarEvolution *SE);
1295};
1296
// Signed flavour: no-signed-wrap flag and the signed overflow limit.
1297template <>
1298struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1299 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1300
// Defined out of line below.
1301 static const GetExtendExprTy GetExtendExpr;
1302
1303 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1304 ICmpInst::Predicate *Pred,
1305 ScalarEvolution *SE) {
1306 return getSignedOverflowLimitForStep(Step, Pred, SE);
1307 }
1308};
1309
1310const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1312
// Unsigned flavour: no-unsigned-wrap flag and the unsigned overflow limit.
1313template <>
1314struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1315 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1316
// Defined out of line below.
1317 static const GetExtendExprTy GetExtendExpr;
1318
1319 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1320 ICmpInst::Predicate *Pred,
1321 ScalarEvolution *SE) {
1322 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1323 }
1324};
1325
1326const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1328
1329} // end anonymous namespace
1330
1331// The recurrence AR has been shown to have no signed/unsigned wrap or something
1332// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1333// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1334// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1335// PreStart),+,Step} => {sext/zext(Step) + sext/zext(PreStart),+,Step}. As a result, the
1336// expression "Step + sext/zext(PreIncAR)" is congruent with
1337// "sext/zext(PostIncAR)"
//
// Returns PreStart, which is AR's start (an add expression) with one
// occurrence of Step removed, if PreStart + Step can be shown not to wrap in
// the WrapType (NSW/NUW) sense by one of three checks below. Returns nullptr
// if the start is not an add containing Step, or if no check succeeds.
// NOTE(review): Ty is not referenced in this function body.
1338template <typename ExtendOpTy>
1339static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1340 ScalarEvolution *SE, unsigned Depth) {
1341 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1342 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1343
1344 const Loop *L = AR->getLoop();
1345 const SCEV *Start = AR->getStart();
1346 const SCEV *Step = AR->getStepRecurrence(*SE);
1347
1348 // Check for a simple looking step prior to loop entry.
1349 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1350 if (!SA)
1351 return nullptr;
1352
1353 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1354 // subtraction is expensive. For this purpose, perform a quick and dirty
1355 // difference, by checking for Step in the operand list. Note, that
1356 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1358 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1359 if (*It == Step) {
1360 DiffOps.erase(It);
1361 break;
1362 }
1363
// Nothing was removed: Step is not a direct operand of Start.
1364 if (DiffOps.size() == SA->getNumOperands())
1365 return nullptr;
1366
1367 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1368 // `Step`:
1369
1370 // 1. NSW/NUW flags on the step increment.
1371 auto PreStartFlags =
1373 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1374 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1375 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1376
1377 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1378 // "S+X does not sign/unsign-overflow".
1379 //
1380
1381 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1382 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1383 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1384 return PreStart;
1385
1386 // 2. Direct overflow check on the step operation's expression.
// Extend both sides to twice the bit width: if Ext(Start) equals
// Ext(PreStart) + Ext(Step) there, the narrow addition cannot have wrapped.
1387 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1388 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1389 const SCEV *OperandExtendedStart =
1390 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1391 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1392 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1393 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1394 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1395 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1396 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1397 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1398 }
1399 return PreStart;
1400 }
1401
1402 // 3. Loop precondition.
// The loop entry guard must establish PreStart `Pred` OverflowLimit, where
// the limit/predicate pair comes from the signed/unsigned traits.
1404 const SCEV *OverflowLimit =
1405 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1406
1407 if (OverflowLimit &&
1408 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1409 return PreStart;
1410
1411 return nullptr;
1412}
1413
1414// Get the normalized zero or sign extended expression for this AddRec's Start.
1415template <typename ExtendOpTy>
1416static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1417 ScalarEvolution *SE,
1418 unsigned Depth) {
1419 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1420
1421 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1422 if (!PreStart)
1423 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1424
1425 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1426 Depth),
1427 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1428}
1429
1430// Try to prove away overflow by looking at "nearby" add recurrences. A
1431// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1432// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1433//
1434// Formally:
1435//
1436// {S,+,X} == {S-T,+,X} + T
1437// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1438//
1439// If ({S-T,+,X} + T) does not overflow ... (1)
1440//
1441// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1442//
1443// If {S-T,+,X} does not overflow ... (2)
1444//
1445// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1446// == {Ext(S-T)+Ext(T),+,Ext(X)}
1447//
1448// If (S-T)+T does not overflow ... (3)
1449//
1450// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1451// == {Ext(S),+,Ext(X)} == LHS
1452//
1453// Thus, if (1), (2) and (3) are true for some T, then
1454// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1455//
1456// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1457// does not overflow" restricted to the 0th iteration. Therefore we only need
1458// to check for (1) and (2).
1459//
1460// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1461// is `Delta` (defined below).
//
// Returns true as soon as one candidate Delta proves both (1) and (2); returns
// false if Start is not a constant or no candidate succeeds.
1462template <typename ExtendOpTy>
1463bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1464 const SCEV *Step,
1465 const Loop *L) {
1466 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1467
1468 // We restrict `Start` to a constant to prevent SCEV from spending too much
1469 // time here. It is correct (but more expensive) to continue with a
1470 // non-constant `Start` and do a general SCEV subtraction to compute
1471 // `PreStart` below.
1472 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1473 if (!StartC)
1474 return false;
1475
1476 APInt StartAI = StartC->getAPInt();
1477
1478 for (unsigned Delta : {-2, -1, 1, 2}) {
// NOTE(review): Delta is declared `unsigned`, so the literals -2 and -1 become
// 0xFFFFFFFE and 0xFFFFFFFF. The same converted value feeds both PreStart
// (StartAI - Delta) and DeltaS, so the argument stays self-consistent. However,
// for integer types wider than 32 bits, the offsets actually tried are
// presumably not -2/-1. Confirm whether a signed Delta was intended.
1479 const SCEV *PreStart = getConstant(StartAI - Delta);
1480
// Look up {PreStart,+,Step}<L> in the uniquing table only; IP is unused
// because the recurrence is never created here.
1482 ID.AddInteger(scAddRecExpr);
1483 ID.AddPointer(PreStart);
1484 ID.AddPointer(Step);
1485 ID.AddPointer(L);
1486 void *IP = nullptr;
1487 const auto *PreAR =
1488 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1489
1490 // Give up if we don't already have the add recurrence we need because
1491 // actually constructing an add recurrence is relatively expensive.
1492 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1493 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1495 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1496 DeltaS, &Pred, this);
1497 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1498 return true;
1499 }
1500 }
1501
1502 return false;
1503}
1504
1505// Finds an integer D for an expression (C + x + y + ...) such that the top
1506// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1507// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1508// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1509// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// Operand 0 of WholeAddExpr is taken to be C: the scan below starts at
// operand 1. D is the low TZ bits of C, where TZ is the minimum number of
// known trailing zeros among the other operands. Returns 0 when TZ is 0.
1511 const SCEVConstant *ConstantTerm,
1512 const SCEVAddExpr *WholeAddExpr) {
1513 const APInt &C = ConstantTerm->getAPInt();
1514 const unsigned BitWidth = C.getBitWidth();
1515 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1516 uint32_t TZ = BitWidth;
1517 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1518 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1519 if (TZ) {
1520 // Set D to be as many least significant bits of C as possible while still
1521 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
// (C - D) and (x + y + ...) both have their low TZ bits clear, so adding D
// (which fits in those bits) produces no carry.
1522 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1523 }
1524 return APInt(BitWidth, 0);
1525}
1526
1527// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1528// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1529// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1530// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// D is the low TZ bits of ConstantStart, where TZ is the known trailing zeros
// of Step (all of ConstantStart if TZ covers the whole width). Returns 0 when
// Step has no known trailing zeros.
1532 const APInt &ConstantStart,
1533 const SCEV *Step) {
1534 const unsigned BitWidth = ConstantStart.getBitWidth();
1535 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1536 if (TZ)
1537 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1538 : ConstantStart;
1539 return APInt(BitWidth, 0);
1540}
1541
1543 const ScalarEvolution::FoldID &ID, const SCEV *S,
1546 &FoldCacheUser) {
// Record ID -> S in FoldCache and keep the reverse index FoldCacheUser
// (S -> list of IDs folding to S) in sync. If ID was already mapped to some
// other SCEV, that SCEV's reverse entry for ID is removed first. The removal
// swaps with the last element and pops it, so the order of UserIDs is not
// preserved.
1547 auto I = FoldCache.insert({ID, S});
1548 if (!I.second) {
1549 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1550 // entry.
1551 auto &UserIDs = FoldCacheUser[I.first->second];
1552 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
// Note: the loop index `I` shadows the insert result `I`. After the loop,
// `I` refers to the insert result again.
1553 for (unsigned I = 0; I != UserIDs.size(); ++I)
1554 if (UserIDs[I] == ID) {
1555 std::swap(UserIDs[I], UserIDs.back());
1556 break;
1557 }
// NOTE(review): pop_back is unconditional. It relies on the assert above that
// ID occurs exactly once; if ID were absent, the last entry would be dropped.
1558 UserIDs.pop_back();
1559 I.first->second = S;
1560 }
1561 FoldCacheUser[S].push_back(ID);
1562}
1563
// Zero-extension entry point. It validates the arguments (Ty strictly wider
// than Op's type, SCEVable, Op not a pointer), normalizes Ty to its effective
// SCEV type, and consults FoldCache keyed on (scZeroExtend, Op, Ty). On a miss
// it delegates to getZeroExtendExprImpl.
// Only results that folded to something other than a plain SCEVZeroExtendExpr
// are cached here. Presumably unfolded zext nodes are already found through
// UniqueSCEVs (see the ID lookup in getZeroExtendExprImpl) — confirm.
1564const SCEV *
1566 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1567 "This is not an extending conversion!");
1568 assert(isSCEVable(Ty) &&
1569 "This is not a conversion to a SCEVable type!");
1570 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1571 Ty = getEffectiveSCEVType(Ty);
1572
1573 FoldID ID(scZeroExtend, Op, Ty);
1574 auto Iter = FoldCache.find(ID);
1575 if (Iter != FoldCache.end())
1576 return Iter->second;
1577
1578 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1579 if (!isa<SCEVZeroExtendExpr>(S))
1580 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1581 return S;
1582}
1583
1585 unsigned Depth) {
1586 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1587 "This is not an extending conversion!");
1588 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1589 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1590
1591 // Fold if the operand is constant.
1592 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1593 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1594
1595 // zext(zext(x)) --> zext(x)
1596 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1597 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1598
1599 // Before doing any expensive analysis, check to see if we've already
1600 // computed a SCEV for this Op and Ty.
1602 ID.AddInteger(scZeroExtend);
1603 ID.AddPointer(Op);
1604 ID.AddPointer(Ty);
1605 void *IP = nullptr;
1606 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1607 if (Depth > MaxCastDepth) {
1608 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1609 Op, Ty);
1610 UniqueSCEVs.InsertNode(S, IP);
1611 registerUser(S, Op);
1612 return S;
1613 }
1614
1615 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1616 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1617 // It's possible the bits taken off by the truncate were all zero bits. If
1618 // so, we should be able to simplify this further.
1619 const SCEV *X = ST->getOperand();
1621 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1622 unsigned NewBits = getTypeSizeInBits(Ty);
1623 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1624 CR.zextOrTrunc(NewBits)))
1625 return getTruncateOrZeroExtend(X, Ty, Depth);
1626 }
1627
1628 // If the input value is a chrec scev, and we can prove that the value
1629 // did not overflow the old, smaller, value, we can zero extend all of the
1630 // operands (often constants). This allows analysis of something like
1631 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1632 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633 if (AR->isAffine()) {
1634 const SCEV *Start = AR->getStart();
1635 const SCEV *Step = AR->getStepRecurrence(*this);
1636 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637 const Loop *L = AR->getLoop();
1638
1639 // If we have special knowledge that this addrec won't overflow,
1640 // we don't need to do any further analysis.
1641 if (AR->hasNoUnsignedWrap()) {
1642 Start =
1643 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1644 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1645 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1646 }
1647
1648 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1649 // Note that this serves two purposes: It filters out loops that are
1650 // simply not analyzable, and it covers the case where this code is
1651 // being called from within backedge-taken count analysis, such that
1652 // attempting to ask for the backedge-taken count would likely result
1653 // in infinite recursion. In the later case, the analysis code will
1654 // cope with a conservative value, and it will take care to purge
1655 // that value once it has finished.
1656 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1657 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1658 // Manually compute the final value for AR, checking for overflow.
1659
1660 // Check whether the backedge-taken count can be losslessly casted to
1661 // the addrec's type. The count is always unsigned.
1662 const SCEV *CastedMaxBECount =
1663 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1664 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1665 CastedMaxBECount, MaxBECount->getType(), Depth);
1666 if (MaxBECount == RecastedMaxBECount) {
1667 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1668 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1669 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1671 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1673 Depth + 1),
1674 WideTy, Depth + 1);
1675 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1676 const SCEV *WideMaxBECount =
1677 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1678 const SCEV *OperandExtendedAdd =
1679 getAddExpr(WideStart,
1680 getMulExpr(WideMaxBECount,
1681 getZeroExtendExpr(Step, WideTy, Depth + 1),
1684 if (ZAdd == OperandExtendedAdd) {
1685 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1686 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1687 // Return the expression with the addrec on the outside.
1688 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1689 Depth + 1);
1690 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1691 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1692 }
1693 // Similar to above, only this time treat the step value as signed.
1694 // This covers loops that count down.
1695 OperandExtendedAdd =
1696 getAddExpr(WideStart,
1697 getMulExpr(WideMaxBECount,
1698 getSignExtendExpr(Step, WideTy, Depth + 1),
1701 if (ZAdd == OperandExtendedAdd) {
1702 // Cache knowledge of AR NW, which is propagated to this AddRec.
1703 // Negative step causes unsigned wrap, but it still can't self-wrap.
1704 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1705 // Return the expression with the addrec on the outside.
1706 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1707 Depth + 1);
1708 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1709 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1710 }
1711 }
1712 }
1713
1714 // Normally, in the cases we can prove no-overflow via a
1715 // backedge guarding condition, we can also compute a backedge
1716 // taken count for the loop. The exceptions are assumptions and
1717 // guards present in the loop -- SCEV is not great at exploiting
1718 // these to compute max backedge taken counts, but can still use
1719 // these to prove lack of overflow. Use this fact to avoid
1720 // doing extra work that may not pay off.
1721 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1722 !AC.assumptions().empty()) {
1723
1724 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1725 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1726 if (AR->hasNoUnsignedWrap()) {
1727 // Same as nuw case above - duplicated here to avoid a compile time
1728 // issue. It's not clear that the order of checks does matter, but
1729 // it's one of two issue possible causes for a change which was
1730 // reverted. Be conservative for the moment.
1731 Start =
1732 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1733 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1734 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1735 }
1736
1737 // For a negative step, we can extend the operands iff doing so only
1738 // traverses values in the range zext([0,UINT_MAX]).
1739 if (isKnownNegative(Step)) {
1741 getSignedRangeMin(Step));
1744 // Cache knowledge of AR NW, which is propagated to this
1745 // AddRec. Negative step causes unsigned wrap, but it
1746 // still can't self-wrap.
1747 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1748 // Return the expression with the addrec on the outside.
1749 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1750 Depth + 1);
1751 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1752 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1753 }
1754 }
1755 }
1756
1757 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1758 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1759 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1760 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1761 const APInt &C = SC->getAPInt();
1762 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1763 if (D != 0) {
1764 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1765 const SCEV *SResidual =
1766 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1767 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1768 return getAddExpr(SZExtD, SZExtR,
1770 Depth + 1);
1771 }
1772 }
1773
1774 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1775 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1776 Start =
1777 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1778 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1779 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1780 }
1781 }
1782
1783 // zext(A % B) --> zext(A) % zext(B)
1784 {
1785 const SCEV *LHS;
1786 const SCEV *RHS;
1787 if (matchURem(Op, LHS, RHS))
1788 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1789 getZeroExtendExpr(RHS, Ty, Depth + 1));
1790 }
1791
1792 // zext(A / B) --> zext(A) / zext(B).
1793 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1794 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1795 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1796
1797 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1798 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1799 if (SA->hasNoUnsignedWrap()) {
1800 // If the addition does not unsign overflow then we can, by definition,
1801 // commute the zero extension with the addition operation.
1803 for (const auto *Op : SA->operands())
1804 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1805 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1806 }
1807
1808 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1809 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1810 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1811 //
1812 // Often address arithmetics contain expressions like
1813 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1814 // This transformation is useful while proving that such expressions are
1815 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1816 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1817 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1818 if (D != 0) {
1819 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1820 const SCEV *SResidual =
1822 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1823 return getAddExpr(SZExtD, SZExtR,
1825 Depth + 1);
1826 }
1827 }
1828 }
1829
1830 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1831 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1832 if (SM->hasNoUnsignedWrap()) {
1833 // If the multiply does not unsign overflow then we can, by definition,
1834 // commute the zero extension with the multiply operation.
1836 for (const auto *Op : SM->operands())
1837 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1838 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1839 }
1840
1841 // zext(2^K * (trunc X to iN)) to iM ->
1842 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1843 //
1844 // Proof:
1845 //
1846 // zext(2^K * (trunc X to iN)) to iM
1847 // = zext((trunc X to iN) << K) to iM
1848 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1849 // (because shl removes the top K bits)
1850 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1851 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1852 //
1853 if (SM->getNumOperands() == 2)
1854 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1855 if (MulLHS->getAPInt().isPowerOf2())
1856 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1857 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1858 MulLHS->getAPInt().logBase2();
1859 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1860 return getMulExpr(
1861 getZeroExtendExpr(MulLHS, Ty),
1863 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1864 SCEV::FlagNUW, Depth + 1);
1865 }
1866 }
1867
1868 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1869 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1870 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1871 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1873 for (auto *Operand : MinMax->operands())
1874 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1875 if (isa<SCEVUMinExpr>(MinMax))
1876 return getUMinExpr(Operands);
1877 return getUMaxExpr(Operands);
1878 }
1879
1880 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1881 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1882 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1884 for (auto *Operand : MinMax->operands())
1885 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1886 return getUMinExpr(Operands, /*Sequential*/ true);
1887 }
1888
1889 // The cast wasn't folded; create an explicit cast node.
1890 // Recompute the insert position, as it may have been invalidated.
1891 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1892 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1893 Op, Ty);
1894 UniqueSCEVs.InsertNode(S, IP);
1895 registerUser(S, Op);
1896 return S;
1897}
1898
// Return a SCEV for Op sign-extended to Ty, memoized through FoldCache.
// The operand/type preconditions are asserted, and Ty is normalized with
// getEffectiveSCEVType before the cache lookup. On a miss, the folding work
// is delegated to getSignExtendExprImpl.
1899const SCEV *
1901 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1902 "This is not an extending conversion!");
1903 assert(isSCEVable(Ty) &&
1904 "This is not a conversion to a SCEVable type!");
1905 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1906 Ty = getEffectiveSCEVType(Ty);
1907
1908 FoldID ID(scSignExtend, Op, Ty);
1909 auto Iter = FoldCache.find(ID);
1910 if (Iter != FoldCache.end())
1911 return Iter->second;
1912
1913 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
// Only results that are not themselves SCEVSignExtendExpr nodes are cached.
// Presumably this is because explicit sext nodes are already uniqued in
// UniqueSCEVs by the Impl (TODO confirm the rationale).
1914 if (!isa<SCEVSignExtendExpr>(S))
1915 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1916 return S;
1917}
1918
1920 unsigned Depth) {
// Fold sext(Op) to Ty into a simpler or more canonical SCEV when one of the
// rules below applies. Otherwise, create (or reuse) an explicit
// SCEVSignExtendExpr uniqued in UniqueSCEVs.
// Rules, in order:
//   - constant operand
//   - sext(sext x) and sext(zext x)
//   - sext(trunc x)
//   - add: either nsw, or a split-off constant
//   - affine addrec, via several no-wrap proofs
//   - zext for a provably positive operand
//   - smin/smax
// Depth counts nested cast folding; past MaxCastDepth the explicit node is
// built without further analysis.
// NOTE(review): several lines of this function are absent from this listing,
// including the first signature line and the declarations of ID, CR, Ops and
// Operands. The comments below describe only the visible code.
1921 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1922 "This is not an extending conversion!");
1923 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1924 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1925 Ty = getEffectiveSCEVType(Ty);
1926
1927 // Fold if the operand is constant.
1928 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1929 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1930
1931 // sext(sext(x)) --> sext(x)
1932 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1933 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1934
1935 // sext(zext(x)) --> zext(x)
// A strictly widening zext leaves the top bit clear, so sign-extending its
// result only adds more zero bits.
1936 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1937 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1938
1939 // Before doing any expensive analysis, check to see if we've already
1940 // computed a SCEV for this Op and Ty.
1942 ID.AddInteger(scSignExtend);
1943 ID.AddPointer(Op);
1944 ID.AddPointer(Ty);
1945 void *IP = nullptr;
1946 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1947 // Limit recursion depth.
// Past MaxCastDepth, skip every folding rule and create the explicit node.
1948 if (Depth > MaxCastDepth) {
1949 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1950 Op, Ty);
1951 UniqueSCEVs.InsertNode(S, IP);
1952 registerUser(S, Op);
1953 return S;
1954 }
1955
1956 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1957 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1958 // It's possible the bits taken off by the truncate were all sign bits. If
1959 // so, we should be able to simplify this further.
1960 const SCEV *X = ST->getOperand();
// CR is declared on a line absent from this listing; it is presumably the
// signed range of X. Truncate CR and sign-extend it back to NewBits. If the
// result still contains CR resized to NewBits, only sign bits were dropped,
// and X can be resized to Ty directly.
1962 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1963 unsigned NewBits = getTypeSizeInBits(Ty);
1964 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1965 CR.sextOrTrunc(NewBits)))
1966 return getTruncateOrSignExtend(X, Ty, Depth);
1967 }
1968
1969 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1970 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1971 if (SA->hasNoSignedWrap()) {
1972 // If the addition does not sign overflow then we can, by definition,
1973 // commute the sign extension with the addition operation.
// (The loop variable Op shadows the function parameter Op.)
1975 for (const auto *Op : SA->operands())
1976 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1977 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1978 }
1979
1980 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1981 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1982 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1983 //
1984 // For instance, this will bring two seemingly different expressions:
1985 // 1 + sext(5 + 20 * %x + 24 * %y) and
1986 // sext(6 + 20 * %x + 24 * %y)
1987 // to the same form:
1988 // 2 + sext(4 + 20 * %x + 24 * %y)
1989 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1990 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1991 if (D != 0) {
1992 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1993 const SCEV *SResidual =
1995 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1996 return getAddExpr(SSExtD, SSExtR,
1998 Depth + 1);
1999 }
2000 }
2001 }
2002 // If the input value is a chrec scev, and we can prove that the value
2003 // did not overflow the old, smaller, value, we can sign extend all of the
2004 // operands (often constants). This allows analysis of something like
2005 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2006 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
2007 if (AR->isAffine()) {
2008 const SCEV *Start = AR->getStart();
2009 const SCEV *Step = AR->getStepRecurrence(*this);
2010 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2011 const Loop *L = AR->getLoop();
2012
2013 // If we have special knowledge that this addrec won't overflow,
2014 // we don't need to do any further analysis.
2015 if (AR->hasNoSignedWrap()) {
2016 Start =
2017 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2018 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2019 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2020 }
2021
2022 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2023 // Note that this serves two purposes: It filters out loops that are
2024 // simply not analyzable, and it covers the case where this code is
2025 // being called from within backedge-taken count analysis, such that
2026 // attempting to ask for the backedge-taken count would likely result
2027 // in infinite recursion. In the latter case, the analysis code will
2028 // cope with a conservative value, and it will take care to purge
2029 // that value once it has finished.
2030 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2031 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2032 // Manually compute the final value for AR, checking for
2033 // overflow.
2034
2035 // Check whether the backedge-taken count can be losslessly cast to
2036 // the addrec's type. The count is always unsigned.
2037 const SCEV *CastedMaxBECount =
2038 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2039 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2040 CastedMaxBECount, MaxBECount->getType(), Depth);
2041 if (MaxBECount == RecastedMaxBECount) {
2042 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2043 // Check whether Start+Step*MaxBECount has no signed overflow.
// SAdd computes Start + MaxBECount*Step in the narrow type and then
// sign-extends it to WideTy (twice the bit width). OperandExtendedAdd
// computes the same sum from already-widened operands. If the two SCEVs
// are identical, the narrow computation is treated as non-wrapping.
2044 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2046 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2048 Depth + 1),
2049 WideTy, Depth + 1);
2050 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2051 const SCEV *WideMaxBECount =
2052 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2053 const SCEV *OperandExtendedAdd =
2054 getAddExpr(WideStart,
2055 getMulExpr(WideMaxBECount,
2056 getSignExtendExpr(Step, WideTy, Depth + 1),
2059 if (SAdd == OperandExtendedAdd) {
2060 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2061 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2062 // Return the expression with the addrec on the outside.
2063 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2064 Depth + 1);
2065 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2066 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2067 }
2068 // Similar to above, only this time treat the step value as unsigned.
2069 // This covers loops that count up with an unsigned step.
2070 OperandExtendedAdd =
2071 getAddExpr(WideStart,
2072 getMulExpr(WideMaxBECount,
2073 getZeroExtendExpr(Step, WideTy, Depth + 1),
2076 if (SAdd == OperandExtendedAdd) {
2077 // If AR wraps around then
2078 //
2079 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2080 // => SAdd != OperandExtendedAdd
2081 //
2082 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2083 // (SAdd == OperandExtendedAdd => AR is NW)
2084
2085 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2086
2087 // Return the expression with the addrec on the outside.
2088 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2089 Depth + 1);
2090 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2091 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2092 }
2093 }
2094 }
2095
// Try to prove nsw by induction; any flags found are recorded on AR.
2096 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2098 if (AR->hasNoSignedWrap()) {
2099 // Same as nsw case above - duplicated here to avoid a compile time
2100 // issue. It's not clear that the order of checks does matter, but
2101 // it's one of two possible causes of an issue with a change that was
2102 // reverted. Be conservative for the moment.
2103 Start =
2104 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2105 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2106 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2107 }
2108
2109 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2110 // if D + (C - D + Step * n) could be proven to not signed wrap
2111 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2112 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2113 const APInt &C = SC->getAPInt();
2114 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2115 if (D != 0) {
2116 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2117 const SCEV *SResidual =
2118 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2119 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2120 return getAddExpr(SSExtD, SSExtR,
2122 Depth + 1);
2123 }
2124 }
2125
2126 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2127 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2128 Start =
2129 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2130 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2131 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2132 }
2133 }
2134
2135 // If the input value is provably positive and we could not simplify
2136 // away the sext, build a zext instead.
// NOTE(review): the `if` guarding the return below is absent from this
// listing. As shown, the return looks unconditional, which would make the
// smin/smax rule and the explicit-node code after it unreachable.
2138 return getZeroExtendExpr(Op, Ty, Depth + 1);
2139
2140 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2141 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2142 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2143 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2145 for (auto *Operand : MinMax->operands())
2146 Operands.push_back(getSignExtendExpr(Operand, Ty));
2147 if (isa<SCEVSMinExpr>(MinMax))
2148 return getSMinExpr(Operands);
2149 return getSMaxExpr(Operands);
2150 }
2151
2152 // The cast wasn't folded; create an explicit cast node.
2153 // Recompute the insert position, as it may have been invalidated.
2154 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2155 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2156 Op, Ty);
2157 UniqueSCEVs.InsertNode(S, IP);
2158 registerUser(S, { Op });
2159 return S;
2160}
2161
2163 Type *Ty) {
// Build the cast of Op to Ty selected by Kind by dispatching to the
// matching builder: truncate, zero-extend, sign-extend or ptrtoint.
// The builders are called without an explicit Depth argument.
// Any other Kind is a programming error and hits llvm_unreachable.
2164 switch (Kind) {
2165 case scTruncate:
2166 return getTruncateExpr(Op, Ty);
2167 case scZeroExtend:
2168 return getZeroExtendExpr(Op, Ty);
2169 case scSignExtend:
2170 return getSignExtendExpr(Op, Ty);
2171 case scPtrToInt:
2172 return getPtrToIntExpr(Op, Ty);
2173 default:
2174 llvm_unreachable("Not a SCEV cast expression!");
2175 }
2176}
2177
2178/// Return a SCEV for Op extended to the wider type Ty, leaving the new high
2179/// bits unspecified (a zero- or a sign-extension is an acceptable result).
/// The steps below are tried in order:
///   - Negative constants are sign-extended.
///   - A truncate is looked through.
///   - Otherwise a zext, then a sext, is used if it folds away the cast node.
///   - An addrec is rebuilt (with FlagNW) from any-extended operands.
///   - An smax uses the sext result.
///   - Anything else falls back to the zext result.
2181 Type *Ty) {
2182 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2183 "This is not an extending conversion!");
2184 assert(isSCEVable(Ty) &&
2185 "This is not a conversion to a SCEVable type!");
2186 Ty = getEffectiveSCEVType(Ty);
2187
2188 // Sign-extend negative constants.
2189 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2190 if (SC->getAPInt().isNegative())
2191 return getSignExtendExpr(Op, Ty);
2192
2193 // Peel off a truncate cast.
// The truncate's source may be narrower or wider than Ty: any-extend it in
// the first case, otherwise truncate it (or no-op when the sizes match).
2194 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2195 const SCEV *NewOp = T->getOperand();
2196 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2197 return getAnyExtendExpr(NewOp, Ty);
2198 return getTruncateOrNoop(NewOp, Ty);
2199 }
2200
2201 // Next try a zext cast. If the cast is folded, use it.
2202 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2203 if (!isa<SCEVZeroExtendExpr>(ZExt))
2204 return ZExt;
2205
2206 // Next try a sext cast. If the cast is folded, use it.
2207 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2208 if (!isa<SCEVSignExtendExpr>(SExt))
2209 return SExt;
2210
2211 // Force the cast to be folded into the operands of an addrec.
// (Ops is declared on a line absent from this listing; the loop variable Op
// shadows the function parameter Op.)
2212 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2214 for (const SCEV *Op : AR->operands())
2215 Ops.push_back(getAnyExtendExpr(Op, Ty));
2216 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2217 }
2218
2219 // If the expression is obviously signed, use the sext cast value.
2220 if (isa<SCEVSMaxExpr>(Op))
2221 return SExt;
2222
2223 // Absent any other information, use the zext cast value.
2224 return ZExt;
2225}
2226
2227/// Process the given Ops list, which is a list of operands to be added under
2228/// the given scale, and update the given map. This is a helper function for
2229/// getAddRecExpr. As an example of what it does, given a sequence of operands
2230/// that would form an add expression like this:
2231///
2232/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2233///
2234/// where A and B are constants, update the map with these values:
2235///
2236/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2237///
2238/// and add 13 + A*B*29 to AccumulatedConstant.
2239/// This will allow getAddRecExpr to produce this:
2240///
2241/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2242///
2243/// This form often exposes folding opportunities that are hidden in
2244/// the original operand list.
2245///
2246/// Return true iff it appears that any interesting folding opportunities
2247/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2248/// the common case where no interesting opportunities are present, and
2249/// is also used as a check to avoid infinite recursion.
///
/// Parameters:
/// - M maps each non-constant term to its accumulated scale.
/// - NewOps receives each distinct term once, in the order it is first seen.
/// - AccumulatedConstant has Scale * C added to it for every constant C found.
/// - Ops are the operands to process, each multiplied by Scale.
/// - SE builds the product key for `constant * term` operands.
2250static bool
2253 APInt &AccumulatedConstant,
2254 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2255 ScalarEvolution &SE) {
2256 bool Interesting = false;
2257
2258 // Iterate over the add operands. They are sorted, with constants first.
// NOTE(review): this loop has no `i < Ops.size()` check. It relies on Ops
// containing at least one non-constant operand after its leading constants.
// That is presumably guaranteed for canonical add operand lists; TODO confirm
// for every caller.
2259 unsigned i = 0;
2260 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2261 ++i;
2262 // Pull a buried constant out to the outside.
// Flag as interesting when any of the following holds:
//   - the constant is scaled;
//   - a constant has already been accumulated;
//   - the constant is zero.
2263 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2264 Interesting = true;
2265 AccumulatedConstant += Scale * C->getAPInt();
2266 }
2267
2268 // Next comes everything else. We're especially interested in multiplies
2269 // here, but they're in the middle, so just visit the rest with one loop.
2270 for (; i != Ops.size(); ++i) {
2271 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2272 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2273 APInt NewScale =
2274 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2275 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2276 // A multiplication of a constant with another add; recurse.
2277 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2278 Interesting |=
2279 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2280 Add->operands(), NewScale, SE);
2281 } else {
2282 // A multiplication of a constant with some other value. Update
2283 // the map.
// The key is the product of the non-constant factors.
2284 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2285 const SCEV *Key = SE.getMulExpr(MulOps);
2286 auto Pair = M.insert({Key, NewScale});
2287 if (Pair.second) {
2288 NewOps.push_back(Pair.first->first);
2289 } else {
2290 Pair.first->second += NewScale;
2291 // The map already had an entry for this value, which may indicate
2292 // a folding opportunity.
2293 Interesting = true;
2294 }
2295 }
2296 } else {
2297 // An ordinary operand. Update the map.
2298 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2299 M.insert({Ops[i], Scale});
2300 if (Pair.second) {
2301 NewOps.push_back(Pair.first->first);
2302 } else {
2303 Pair.first->second += Scale;
2304 // The map already had an entry for this value, which may indicate
2305 // a folding opportunity.
2306 Interesting = true;
2307 }
2308 }
2309 }
2310
2311 return Interesting;
2312}
2313
2315 const SCEV *LHS, const SCEV *RHS,
2316 const Instruction *CtxI) {
// Return true if `LHS BinOp RHS` is proven not to overflow, in the signed or
// unsigned sense selected by Signed. BinOp must be Add, Sub or Mul; any other
// opcode is unreachable.
//
// First check (purely structural): compute the operation, extend the result
// to a type twice as wide (A), and compare it with the same operation on
// operands extended to that type (B). Identical SCEVs mean no overflow.
// Operation and Extension are assigned on lines absent from this listing;
// presumably Extension is sext when Signed and zext otherwise (TODO confirm).
//
// Second check: only when CtxI is given, BinOp is not Mul and RHS is a
// constant. LHS is then compared against a limit at CtxI.
2317 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2318 SCEV::NoWrapFlags, unsigned);
2319 switch (BinOp) {
2320 default:
2321 llvm_unreachable("Unsupported binary op");
2322 case Instruction::Add:
2324 break;
2325 case Instruction::Sub:
2327 break;
2328 case Instruction::Mul:
2330 break;
2331 }
2332
2333 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2336
2337 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2338 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2339 auto *WideTy =
2340 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2341
2342 const SCEV *A = (this->*Extension)(
2343 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2344 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2345 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2346 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2347 if (A == B)
2348 return true;
2349 // Can we use context to prove the fact we need?
2350 if (!CtxI)
2351 return false;
2352 // TODO: Support mul.
2353 if (BinOp == Instruction::Mul)
2354 return false;
2355 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2356 // TODO: Lift this limitation.
2357 if (!RHSC)
2358 return false;
2359 APInt C = RHSC->getAPInt();
2360 unsigned NumBits = C.getBitWidth();
2361 bool IsSub = (BinOp == Instruction::Sub);
2362 bool IsNegativeConst = (Signed && C.isNegative());
2363 // Compute the direction and magnitude by which we need to check overflow.
// Two cases can only overflow downwards: subtracting a non-negative constant
// (always the case when unsigned), or adding a signed-negative one. Every
// other combination can only overflow upwards.
2364 bool OverflowDown = IsSub ^ IsNegativeConst;
2365 APInt Magnitude = C;
2366 if (IsNegativeConst) {
2367 if (C == APInt::getSignedMinValue(NumBits))
2368 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2369 // want to deal with that.
2370 return false;
2371 Magnitude = -C;
2372 }
2373
// Pred is declared on a line absent from this listing; it is presumably a
// signed or unsigned "less than or equal" comparison (TODO confirm).
2375 if (OverflowDown) {
2376 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2377 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2378 : APInt::getMinValue(NumBits);
2379 APInt Limit = Min + Magnitude;
2380 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2381 } else {
2382 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2383 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2384 : APInt::getMaxValue(NumBits);
2385 APInt Limit = Max - Magnitude;
2386 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2387 }
2388}
2389
// Try to prove nuw and/or nsw for the add/sub/mul instruction OBO beyond
// the flags it already carries.
// Returns std::nullopt in three cases:
//   - OBO already has both nuw and nsw;
//   - its opcode is not Add, Sub or Mul;
//   - nothing new could be deduced.
// Otherwise it returns Flags. Flags is declared and seeded on lines absent
// from this listing; presumably it holds OBO's existing flags plus the newly
// proven ones (TODO confirm).
// OBO is used as the context instruction for the proofs only when
// UseContextForNoWrapFlagInference is set.
2390std::optional<SCEV::NoWrapFlags>
2392 const OverflowingBinaryOperator *OBO) {
2393 // It cannot be done any better.
2394 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2395 return std::nullopt;
2396
// NOTE(review): the statements controlled by the two `if`s below are absent
// from this listing. They are not nested in the original source.
2398
2399 if (OBO->hasNoUnsignedWrap())
2401 if (OBO->hasNoSignedWrap())
2403
2404 bool Deduced = false;
2405
2406 if (OBO->getOpcode() != Instruction::Add &&
2407 OBO->getOpcode() != Instruction::Sub &&
2408 OBO->getOpcode() != Instruction::Mul)
2409 return std::nullopt;
2410
2411 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2412 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2413
2414 const Instruction *CtxI =
2415 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
// Each missing flag is tried separately: first unsigned (nuw), then signed
// (nsw). The callee's first argument line is absent from this listing.
2416 if (!OBO->hasNoUnsignedWrap() &&
2418 /* Signed */ false, LHS, RHS, CtxI)) {
2420 Deduced = true;
2421 }
2422
2423 if (!OBO->hasNoSignedWrap() &&
2425 /* Signed */ true, LHS, RHS, CtxI)) {
2427 Deduced = true;
2428 }
2429
2430 if (Deduced)
2431 return Flags;
2432 return std::nullopt;
2433}
2434
2435// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2436// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2437// can't-overflow flags for the operation if possible.
// Returns Flags with any newly provable NUW/NSW bits set. Never clears a
// flag that was passed in.
2438static SCEV::NoWrapFlags
2440 const ArrayRef<const SCEV *> Ops,
2441 SCEV::NoWrapFlags Flags) {
// NOTE(review): std::placeholders does not appear to be used in the visible
// body.
2442 using namespace std::placeholders;
2443
2444 using OBO = OverflowingBinaryOperator;
2445
2446 bool CanAnalyze =
2448 (void)CanAnalyze;
2449 assert(CanAnalyze && "don't call from other places!");
2450
2451 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2452 SCEV::NoWrapFlags SignOrUnsignWrap =
2453 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2456 auto IsKnownNonNegative = [&](const SCEV *S) {
2457 return SE->isKnownNonNegative(S);
2458 };
2459
2460 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2461 Flags =
2462 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2463
// Refresh the NUW/NSW view, since Flags may have just changed.
2464 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2465
// Range-based check, used only when all of the following hold:
//   - at least one of NUW/NSW is still missing;
//   - the node is a two-operand add or mul;
//   - the first operand is a constant.
2466 if (SignOrUnsignWrap != SignOrUnsignMask &&
2467 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2468 isa<SCEVConstant>(Ops[0])) {
2469
2470 auto Opcode = [&] {
2471 switch (Type) {
2472 case scAddExpr:
2473 return Instruction::Add;
2474 case scMulExpr:
2475 return Instruction::Mul;
2476 default:
2477 llvm_unreachable("Unexpected SCEV op.");
2478 }
2479 }();
2480
2481 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2482
2483 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
// The flag is set if A's signed range lies inside the guaranteed no-wrap
// region for `<opcode> C`.
2484 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2486 Opcode, C, OBO::NoSignedWrap);
2487 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2489 }
2490
2491 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2492 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2494 Opcode, C, OBO::NoUnsignedWrap);
2495 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2497 }
2498 }
2499
2500 // <0,+,nonnegative><nw> is also nuw
2501 // TODO: Add corresponding nsw case
2503 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2504 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2506
2507 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2509 Ops.size() == 2) {
2510 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2511 if (UDiv->getOperand(1) == Ops[1])
2513 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2514 if (UDiv->getOperand(1) == Ops[0])
2516 }
2517
2518 return Flags;
2519}
2520
// True when S is invariant in loop L and properly dominates L's header block.
2522 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2523}
2524
2525/// Get a canonical add expression, or something simpler if possible.
2527 SCEV::NoWrapFlags OrigFlags,
2528 unsigned Depth) {
2529 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2530 "only nuw or nsw allowed");
2531 assert(!Ops.empty() && "Cannot get empty add!");
2532 if (Ops.size() == 1) return Ops[0];
2533#ifndef NDEBUG
2534 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2535 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2536 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2537 "SCEVAddExpr operand types don't match!");
2538 unsigned NumPtrs = count_if(
2539 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2540 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2541#endif
2542
2543 const SCEV *Folded = constantFoldAndGroupOps(
2544 *this, LI, DT, Ops,
2545 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2546 [](const APInt &C) { return C.isZero(); }, // identity
2547 [](const APInt &C) { return false; }); // absorber
2548 if (Folded)
2549 return Folded;
2550
2551 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2552
2553 // Delay expensive flag strengthening until necessary.
2554 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2555 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2556 };
2557
2558 // Limit recursion calls depth.
2560 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2561
2562 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2563 // Don't strengthen flags if we have no new information.
2564 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2565 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2566 Add->setNoWrapFlags(ComputeFlags(Ops));
2567 return S;
2568 }
2569
2570 // Okay, check to see if the same value occurs in the operand list more than
2571 // once. If so, merge them together into an multiply expression. Since we
2572 // sorted the list, these values are required to be adjacent.
2573 Type *Ty = Ops[0]->getType();
2574 bool FoundMatch = false;
2575 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2576 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2577 // Scan ahead to count how many equal operands there are.
2578 unsigned Count = 2;
2579 while (i+Count != e && Ops[i+Count] == Ops[i])
2580 ++Count;
2581 // Merge the values into a multiply.
2582 const SCEV *Scale = getConstant(Ty, Count);
2583 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2584 if (Ops.size() == Count)
2585 return Mul;
2586 Ops[i] = Mul;
2587 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2588 --i; e -= Count - 1;
2589 FoundMatch = true;
2590 }
2591 if (FoundMatch)
2592 return getAddExpr(Ops, OrigFlags, Depth + 1);
2593
2594 // Check for truncates. If all the operands are truncated from the same
2595 // type, see if factoring out the truncate would permit the result to be
2596 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2597 // if the contents of the resulting outer trunc fold to something simple.
2598 auto FindTruncSrcType = [&]() -> Type * {
2599 // We're ultimately looking to fold an addrec of truncs and muls of only
2600 // constants and truncs, so if we find any other types of SCEV
2601 // as operands of the addrec then we bail and return nullptr here.
2602 // Otherwise, we return the type of the operand of a trunc that we find.
2603 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2604 return T->getOperand()->getType();
2605 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2606 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2607 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2608 return T->getOperand()->getType();
2609 }
2610 return nullptr;
2611 };
2612 if (auto *SrcType = FindTruncSrcType()) {
2614 bool Ok = true;
2615 // Check all the operands to see if they can be represented in the
2616 // source type of the truncate.
2617 for (const SCEV *Op : Ops) {
2618 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2619 if (T->getOperand()->getType() != SrcType) {
2620 Ok = false;
2621 break;
2622 }
2623 LargeOps.push_back(T->getOperand());
2624 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2625 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2626 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2627 SmallVector<const SCEV *, 8> LargeMulOps;
2628 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2629 if (const SCEVTruncateExpr *T =
2630 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2631 if (T->getOperand()->getType() != SrcType) {
2632 Ok = false;
2633 break;
2634 }
2635 LargeMulOps.push_back(T->getOperand());
2636 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2637 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2638 } else {
2639 Ok = false;
2640 break;
2641 }
2642 }
2643 if (Ok)
2644 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2645 } else {
2646 Ok = false;
2647 break;
2648 }
2649 }
2650 if (Ok) {
2651 // Evaluate the expression in the larger type.
2652 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2653 // If it folds to something simple, use it. Otherwise, don't.
2654 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2655 return getTruncateExpr(Fold, Ty);
2656 }
2657 }
2658
2659 if (Ops.size() == 2) {
2660 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2661 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2662 // C1).
2663 const SCEV *A = Ops[0];
2664 const SCEV *B = Ops[1];
2665 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2666 auto *C = dyn_cast<SCEVConstant>(A);
2667 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2668 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2669 auto C2 = C->getAPInt();
2670 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2671
2672 APInt ConstAdd = C1 + C2;
2673 auto AddFlags = AddExpr->getNoWrapFlags();
2674 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 ConstAdd.ule(C1)) {
2677 PreservedFlags =
2679 }
2680
2681 // Adding a constant with the same sign and small magnitude is NSW, if the
2682 // original AddExpr was NSW.
2684 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2685 ConstAdd.abs().ule(C1.abs())) {
2686 PreservedFlags =
2688 }
2689
2690 if (PreservedFlags != SCEV::FlagAnyWrap) {
2691 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2692 NewOps[0] = getConstant(ConstAdd);
2693 return getAddExpr(NewOps, PreservedFlags);
2694 }
2695 }
2696 }
2697
2698 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2699 if (Ops.size() == 2) {
2700 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2701 if (Mul && Mul->getNumOperands() == 2 &&
2702 Mul->getOperand(0)->isAllOnesValue()) {
2703 const SCEV *X;
2704 const SCEV *Y;
2705 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2706 return getMulExpr(Y, getUDivExpr(X, Y));
2707 }
2708 }
2709 }
2710
2711 // Skip past any other cast SCEVs.
2712 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2713 ++Idx;
2714
2715 // If there are add operands they would be next.
2716 if (Idx < Ops.size()) {
2717 bool DeletedAdd = false;
2718 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2719 // common NUW flag for expression after inlining. Other flags cannot be
2720 // preserved, because they may depend on the original order of operations.
2721 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2722 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2723 if (Ops.size() > AddOpsInlineThreshold ||
2724 Add->getNumOperands() > AddOpsInlineThreshold)
2725 break;
2726 // If we have an add, expand the add operands onto the end of the operands
2727 // list.
2728 Ops.erase(Ops.begin()+Idx);
2729 append_range(Ops, Add->operands());
2730 DeletedAdd = true;
2731 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2732 }
2733
2734 // If we deleted at least one add, we added operands to the end of the list,
2735 // and they are not necessarily sorted. Recurse to resort and resimplify
2736 // any operands we just acquired.
2737 if (DeletedAdd)
2738 return getAddExpr(Ops, CommonFlags, Depth + 1);
2739 }
2740
2741 // Skip over the add expression until we get to a multiply.
2742 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2743 ++Idx;
2744
2745 // Check to see if there are any folding opportunities present with
2746 // operands multiplied by constant values.
2747 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2751 APInt AccumulatedConstant(BitWidth, 0);
2752 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2753 Ops, APInt(BitWidth, 1), *this)) {
2754 struct APIntCompare {
2755 bool operator()(const APInt &LHS, const APInt &RHS) const {
2756 return LHS.ult(RHS);
2757 }
2758 };
2759
2760 // Some interesting folding opportunity is present, so its worthwhile to
2761 // re-generate the operands list. Group the operands by constant scale,
2762 // to avoid multiplying by the same constant scale multiple times.
2763 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2764 for (const SCEV *NewOp : NewOps)
2765 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2766 // Re-generate the operands list.
2767 Ops.clear();
2768 if (AccumulatedConstant != 0)
2769 Ops.push_back(getConstant(AccumulatedConstant));
2770 for (auto &MulOp : MulOpLists) {
2771 if (MulOp.first == 1) {
2772 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2773 } else if (MulOp.first != 0) {
2775 getConstant(MulOp.first),
2776 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2777 SCEV::FlagAnyWrap, Depth + 1));
2778 }
2779 }
2780 if (Ops.empty())
2781 return getZero(Ty);
2782 if (Ops.size() == 1)
2783 return Ops[0];
2784 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2785 }
2786 }
2787
2788 // If we are adding something to a multiply expression, make sure the
2789 // something is not already an operand of the multiply. If so, merge it into
2790 // the multiply.
2791 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2792 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2793 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2794 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2795 if (isa<SCEVConstant>(MulOpSCEV))
2796 continue;
2797 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2798 if (MulOpSCEV == Ops[AddOp]) {
2799 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2800 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2801 if (Mul->getNumOperands() != 2) {
2802 // If the multiply has more than two operands, we must get the
2803 // Y*Z term.
2805 Mul->operands().take_front(MulOp));
2806 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2807 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2808 }
2809 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2810 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2811 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 if (Ops.size() == 2) return OuterMul;
2814 if (AddOp < Idx) {
2815 Ops.erase(Ops.begin()+AddOp);
2816 Ops.erase(Ops.begin()+Idx-1);
2817 } else {
2818 Ops.erase(Ops.begin()+Idx);
2819 Ops.erase(Ops.begin()+AddOp-1);
2820 }
2821 Ops.push_back(OuterMul);
2822 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2823 }
2824
2825 // Check this multiply against other multiplies being added together.
2826 for (unsigned OtherMulIdx = Idx+1;
2827 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 ++OtherMulIdx) {
2829 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2830 // If MulOp occurs in OtherMul, we can fold the two multiplies
2831 // together.
2832 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2833 OMulOp != e; ++OMulOp)
2834 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2835 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2836 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2837 if (Mul->getNumOperands() != 2) {
2839 Mul->operands().take_front(MulOp));
2840 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2841 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2842 }
2843 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2844 if (OtherMul->getNumOperands() != 2) {
2846 OtherMul->operands().take_front(OMulOp));
2847 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2848 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2849 }
2850 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2851 const SCEV *InnerMulSum =
2852 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2853 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 if (Ops.size() == 2) return OuterMul;
2856 Ops.erase(Ops.begin()+Idx);
2857 Ops.erase(Ops.begin()+OtherMulIdx-1);
2858 Ops.push_back(OuterMul);
2859 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2860 }
2861 }
2862 }
2863 }
2864
2865 // If there are any add recurrences in the operands list, see if any other
2866 // added values are loop invariant. If so, we can fold them into the
2867 // recurrence.
2868 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2869 ++Idx;
2870
2871 // Scan over all recurrences, trying to fold loop invariants into them.
2872 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2873 // Scan all of the other operands to this add and add them to the vector if
2874 // they are loop invariant w.r.t. the recurrence.
2876 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2877 const Loop *AddRecLoop = AddRec->getLoop();
2878 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2879 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2880 LIOps.push_back(Ops[i]);
2881 Ops.erase(Ops.begin()+i);
2882 --i; --e;
2883 }
2884
2885 // If we found some loop invariants, fold them into the recurrence.
2886 if (!LIOps.empty()) {
2887 // Compute nowrap flags for the addition of the loop-invariant ops and
2888 // the addrec. Temporarily push it as an operand for that purpose. These
2889 // flags are valid in the scope of the addrec only.
2890 LIOps.push_back(AddRec);
2891 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2892 LIOps.pop_back();
2893
2894 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2895 LIOps.push_back(AddRec->getStart());
2896
2897 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2898
2899 // It is not in general safe to propagate flags valid on an add within
2900 // the addrec scope to one outside it. We must prove that the inner
2901 // scope is guaranteed to execute if the outer one does to be able to
2902 // safely propagate. We know the program is undefined if poison is
2903 // produced on the inner scoped addrec. We also know that *for this use*
2904 // the outer scoped add can't overflow (because of the flags we just
2905 // computed for the inner scoped add) without the program being undefined.
2906 // Proving that entry to the outer scope neccesitates entry to the inner
2907 // scope, thus proves the program undefined if the flags would be violated
2908 // in the outer scope.
2909 SCEV::NoWrapFlags AddFlags = Flags;
2910 if (AddFlags != SCEV::FlagAnyWrap) {
2911 auto *DefI = getDefiningScopeBound(LIOps);
2912 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2913 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2914 AddFlags = SCEV::FlagAnyWrap;
2915 }
2916 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2917
2918 // Build the new addrec. Propagate the NUW and NSW flags if both the
2919 // outer add and the inner addrec are guaranteed to have no overflow.
2920 // Always propagate NW.
2921 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2922 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2923
2924 // If all of the other operands were loop invariant, we are done.
2925 if (Ops.size() == 1) return NewRec;
2926
2927 // Otherwise, add the folded AddRec by the non-invariant parts.
2928 for (unsigned i = 0;; ++i)
2929 if (Ops[i] == AddRec) {
2930 Ops[i] = NewRec;
2931 break;
2932 }
2933 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2934 }
2935
2936 // Okay, if there weren't any loop invariants to be folded, check to see if
2937 // there are multiple AddRec's with the same loop induction variable being
2938 // added together. If so, we can fold them.
2939 for (unsigned OtherIdx = Idx+1;
2940 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2941 ++OtherIdx) {
2942 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2943 // so that the 1st found AddRecExpr is dominated by all others.
2944 assert(DT.dominates(
2945 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2946 AddRec->getLoop()->getHeader()) &&
2947 "AddRecExprs are not sorted in reverse dominance order?");
2948 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2949 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2950 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2951 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 ++OtherIdx) {
2953 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2954 if (OtherAddRec->getLoop() == AddRecLoop) {
2955 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2956 i != e; ++i) {
2957 if (i >= AddRecOps.size()) {
2958 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2959 break;
2960 }
2962 AddRecOps[i], OtherAddRec->getOperand(i)};
2963 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2966 }
2967 }
2968 // Step size has changed, so we cannot guarantee no self-wraparound.
2969 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2970 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 }
2973
2974 // Otherwise couldn't fold anything into this recurrence. Move onto the
2975 // next one.
2976 }
2977
2978 // Okay, it looks like we really DO need an add expr. Check to see if we
2979 // already have one, otherwise create a new one.
2980 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2981}
2982
2983const SCEV *
2984ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2985 SCEV::NoWrapFlags Flags) {
2987 ID.AddInteger(scAddExpr);
2988 for (const SCEV *Op : Ops)
2989 ID.AddPointer(Op);
2990 void *IP = nullptr;
2991 SCEVAddExpr *S =
2992 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2993 if (!S) {
2994 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2995 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2996 S = new (SCEVAllocator)
2997 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2998 UniqueSCEVs.InsertNode(S, IP);
2999 registerUser(S, Ops);
3000 }
3001 S->setNoWrapFlags(Flags);
3002 return S;
3003}
3004
3005const SCEV *
3006ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3007 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 ID.AddInteger(scAddRecExpr);
3010 for (const SCEV *Op : Ops)
3011 ID.AddPointer(Op);
3012 ID.AddPointer(L);
3013 void *IP = nullptr;
3014 SCEVAddRecExpr *S =
3015 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3016 if (!S) {
3017 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3018 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3019 S = new (SCEVAllocator)
3020 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3021 UniqueSCEVs.InsertNode(S, IP);
3022 LoopUsers[L].push_back(S);
3023 registerUser(S, Ops);
3024 }
3025 setNoWrapFlags(S, Flags);
3026 return S;
3027}
3028
3029const SCEV *
3030ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3031 SCEV::NoWrapFlags Flags) {
3033 ID.AddInteger(scMulExpr);
3034 for (const SCEV *Op : Ops)
3035 ID.AddPointer(Op);
3036 void *IP = nullptr;
3037 SCEVMulExpr *S =
3038 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3039 if (!S) {
3040 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3041 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3042 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3043 O, Ops.size());
3044 UniqueSCEVs.InsertNode(S, IP);
3045 registerUser(S, Ops);
3046 }
3047 S->setNoWrapFlags(Flags);
3048 return S;
3049}
3050
/// Multiply \p i by \p j modulo 2^64. If the true product does not fit in 64
/// bits, \p Overflow is set to true; it is never cleared, so callers can
/// accumulate overflow across several calls.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i * j;
  // Multiplying by 0 or 1 cannot overflow; otherwise dividing the wrapped
  // product back by j must recover i.
  if (j > 1 && k / j != i)
    Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
///
/// Returns 0 when k > n (including n == 0 with k > 0) and 1 when k == 0 or
/// k == n.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At step i we multiply by the i-th factor of the numerator, (n-(i-1)), and
  // divide by i. The running value is then exactly C(n, i), so every division
  // is exact and intermediate values stay as small as possible. However, the
  // multiplication can still overflow even when the final result would fit.

  // Check k > n first: C(0, k) is 0 for every k > 0, not 1.
  if (k > n)
    return 0;
  if (k == 0 || k == n)
    return 1;

  // C(n, k) == C(n, n-k); use the smaller k to shorten the loop and keep the
  // intermediate products small.
  if (k > n / 2)
    k = n - k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    // r == C(n, i-1) here, so r * (n-(i-1)) / i == C(n, i).
    r = umul_ov(r, n - (i - 1), Overflow);
    r /= i;
  }
  return r;
}
3082
3083/// Determine if any of the operands in this SCEV are a constant or if
3084/// any of the add or multiply expressions in this SCEV contain a constant.
3085static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3086 struct FindConstantInAddMulChain {
3087 bool FoundConstant = false;
3088
3089 bool follow(const SCEV *S) {
3090 FoundConstant |= isa<SCEVConstant>(S);
3091 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3092 }
3093
3094 bool isDone() const {
3095 return FoundConstant;
3096 }
3097 };
3098
3099 FindConstantInAddMulChain F;
3101 ST.visitAll(StartExpr);
3102 return F.FoundConstant;
3103}
3104
3105/// Get a canonical multiply expression, or something simpler if possible.
3107 SCEV::NoWrapFlags OrigFlags,
3108 unsigned Depth) {
3109 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3110 "only nuw or nsw allowed");
3111 assert(!Ops.empty() && "Cannot get empty mul!");
3112 if (Ops.size() == 1) return Ops[0];
3113#ifndef NDEBUG
3114 Type *ETy = Ops[0]->getType();
3115 assert(!ETy->isPointerTy());
3116 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3117 assert(Ops[i]->getType() == ETy &&
3118 "SCEVMulExpr operand types don't match!");
3119#endif
3120
3121 const SCEV *Folded = constantFoldAndGroupOps(
3122 *this, LI, DT, Ops,
3123 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3124 [](const APInt &C) { return C.isOne(); }, // identity
3125 [](const APInt &C) { return C.isZero(); }); // absorber
3126 if (Folded)
3127 return Folded;
3128
3129 // Delay expensive flag strengthening until necessary.
3130 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3131 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3132 };
3133
3134 // Limit recursion calls depth.
3136 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3137
3138 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3139 // Don't strengthen flags if we have no new information.
3140 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3141 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3142 Mul->setNoWrapFlags(ComputeFlags(Ops));
3143 return S;
3144 }
3145
3146 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3147 if (Ops.size() == 2) {
3148 // C1*(C2+V) -> C1*C2 + C1*V
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3157 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3159 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3161 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3169 bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3181 for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184 // Let M be the minimum representable signed value. AddRec with nsw
3185 // multiplied by -1 can have signed overflow if and only if it takes a
3186 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3187 // maximum signed value. In all other cases signed overflow is
3188 // impossible.
3189 auto FlagsMask = SCEV::FlagNW;
3190 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3191 auto MinInt =
3192 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3193 if (getSignedRangeMin(AddRec) != MinInt)
3194 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3195 }
3196 return getAddRecExpr(Operands, AddRec->getLoop(),
3197 AddRec->getNoWrapFlags(FlagsMask));
3198 }
3199 }
3200 }
3201 }
3202
3203 // Skip over the add expression until we get to a multiply.
3204 unsigned Idx = 0;
3205 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3206 ++Idx;
3207
3208 // If there are mul operands inline them all into this expression.
3209 if (Idx < Ops.size()) {
3210 bool DeletedMul = false;
3211 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3212 if (Ops.size() > MulOpsInlineThreshold)
3213 break;
3214 // If we have an mul, expand the mul operands onto the end of the
3215 // operands list.
3216 Ops.erase(Ops.begin()+Idx);
3217 append_range(Ops, Mul->operands());
3218 DeletedMul = true;
3219 }
3220
3221 // If we deleted at least one mul, we added operands to the end of the
3222 // list, and they are not necessarily sorted. Recurse to resort and
3223 // resimplify any operands we just acquired.
3224 if (DeletedMul)
3225 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3226 }
3227
3228 // If there are any add recurrences in the operands list, see if any other
3229 // added values are loop invariant. If so, we can fold them into the
3230 // recurrence.
3231 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3232 ++Idx;
3233
3234 // Scan over all recurrences, trying to fold loop invariants into them.
3235 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3236 // Scan all of the other operands to this mul and add them to the vector
3237 // if they are loop invariant w.r.t. the recurrence.
3239 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3240 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3241 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3242 LIOps.push_back(Ops[i]);
3243 Ops.erase(Ops.begin()+i);
3244 --i; --e;
3245 }
3246
3247 // If we found some loop invariants, fold them into the recurrence.
3248 if (!LIOps.empty()) {
3249 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3251 NewOps.reserve(AddRec->getNumOperands());
3252 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3253
3254 // If both the mul and addrec are nuw, we can preserve nuw.
3255 // If both the mul and addrec are nsw, we can only preserve nsw if either
3256 // a) they are also nuw, or
3257 // b) all multiplications of addrec operands with scale are nsw.
3258 SCEV::NoWrapFlags Flags =
3259 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3260
3261 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3262 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3263 SCEV::FlagAnyWrap, Depth + 1));
3264
3265 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3267 Instruction::Mul, getSignedRange(Scale),
3269 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3270 Flags = clearFlags(Flags, SCEV::FlagNSW);
3271 }
3272 }
3273
3274 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3275
3276 // If all of the other operands were loop invariant, we are done.
3277 if (Ops.size() == 1) return NewRec;
3278
3279 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3280 for (unsigned i = 0;; ++i)
3281 if (Ops[i] == AddRec) {
3282 Ops[i] = NewRec;
3283 break;
3284 }
3285 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 // Okay, if there weren't any loop invariants to be folded, check to see
3289 // if there are multiple AddRec's with the same loop induction variable
3290 // being multiplied together. If so, we can fold them.
3291
3292 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3293 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3294 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3295 // ]]],+,...up to x=2n}.
3296 // Note that the arguments to choose() are always integers with values
3297 // known at compile time, never SCEV objects.
3298 //
3299 // The implementation avoids pointless extra computations when the two
3300 // addrec's are of different length (mathematically, it's equivalent to
3301 // an infinite stream of zeros on the right).
3302 bool OpsModified = false;
3303 for (unsigned OtherIdx = Idx+1;
3304 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3305 ++OtherIdx) {
3306 const SCEVAddRecExpr *OtherAddRec =
3307 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3308 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3309 continue;
3310
3311 // Limit max number of arguments to avoid creation of unreasonably big
3312 // SCEVAddRecs with very complex operands.
3313 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3314 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3315 continue;
3316
3317 bool Overflow = false;
3318 Type *Ty = AddRec->getType();
3319 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3321 for (int x = 0, xe = AddRec->getNumOperands() +
3322 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3323 SmallVector <const SCEV *, 7> SumOps;
3324 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3325 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3326 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3327 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3328 z < ze && !Overflow; ++z) {
3329 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3330 uint64_t Coeff;
3331 if (LargerThan64Bits)
3332 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3333 else
3334 Coeff = Coeff1*Coeff2;
3335 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3336 const SCEV *Term1 = AddRec->getOperand(y-z);
3337 const SCEV *Term2 = OtherAddRec->getOperand(z);
3338 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3339 SCEV::FlagAnyWrap, Depth + 1));
3340 }
3341 }
3342 if (SumOps.empty())
3343 SumOps.push_back(getZero(Ty));
3344 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3345 }
3346 if (!Overflow) {
3347 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3349 if (Ops.size() == 2) return NewAddRec;
3350 Ops[Idx] = NewAddRec;
3351 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3352 OpsModified = true;
3353 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3354 if (!AddRec)
3355 break;
3356 }
3357 }
3358 if (OpsModified)
3359 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3360
3361 // Otherwise couldn't fold anything into this recurrence. Move onto the
3362 // next one.
3363 }
3364
3365 // Okay, it looks like we really DO need an mul expr. Check to see if we
3366 // already have one, otherwise create a new one.
3367 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3368}
3369
3370/// Represents an unsigned remainder expression based on unsigned division.
3372 const SCEV *RHS) {
3375 "SCEVURemExpr operand types don't match!");
3376
3377 // Short-circuit easy cases
3378 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3379 // If constant is one, the result is trivial
3380 if (RHSC->getValue()->isOne())
3381 return getZero(LHS->getType()); // X urem 1 --> 0
3382
3383 // If constant is a power of two, fold into a zext(trunc(LHS)).
3384 if (RHSC->getAPInt().isPowerOf2()) {
3385 Type *FullTy = LHS->getType();
3386 Type *TruncTy =
3387 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3388 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3389 }
3390 }
3391
3392 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3393 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3394 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3395 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3396}
3397
3398/// Get a canonical unsigned division expression, or something simpler if
3399/// possible.
3401 const SCEV *RHS) {
3402 assert(!LHS->getType()->isPointerTy() &&
3403 "SCEVUDivExpr operand can't be pointer!");
3404 assert(LHS->getType() == RHS->getType() &&
3405 "SCEVUDivExpr operand types don't match!");
3406
3408 ID.AddInteger(scUDivExpr);
3409 ID.AddPointer(LHS);
3410 ID.AddPointer(RHS);
3411 void *IP = nullptr;
3412 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3413 return S;
3414
3415 // 0 udiv Y == 0
3416 if (match(LHS, m_scev_Zero()))
3417 return LHS;
3418
3419 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3420 if (RHSC->getValue()->isOne())
3421 return LHS; // X udiv 1 --> x
3422 // If the denominator is zero, the result of the udiv is undefined. Don't
3423 // try to analyze it, because the resolution chosen here may differ from
3424 // the resolution chosen in other parts of the compiler.
3425 if (!RHSC->getValue()->isZero()) {
3426 // Determine if the division can be folded into the operands of
3427 // its operands.
3428 // TODO: Generalize this to non-constants by using known-bits information.
3429 Type *Ty = LHS->getType();
3430 unsigned LZ = RHSC->getAPInt().countl_zero();
3431 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3432 // For non-power-of-two values, effectively round the value up to the
3433 // nearest power of two.
3434 if (!RHSC->getAPInt().isPowerOf2())
3435 ++MaxShiftAmt;
3436 IntegerType *ExtTy =
3437 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3438 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3439 if (const SCEVConstant *Step =
3440 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3441 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3442 const APInt &StepInt = Step->getAPInt();
3443 const APInt &DivInt = RHSC->getAPInt();
3444 if (!StepInt.urem(DivInt) &&
3445 getZeroExtendExpr(AR, ExtTy) ==
3446 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3447 getZeroExtendExpr(Step, ExtTy),
3448 AR->getLoop(), SCEV::FlagAnyWrap)) {
3450 for (const SCEV *Op : AR->operands())
3451 Operands.push_back(getUDivExpr(Op, RHS));
3452 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3453 }
3454 /// Get a canonical UDivExpr for a recurrence.
3455 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3456 // We can currently only fold X%N if X is constant.
3457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3458 if (StartC && !DivInt.urem(StepInt) &&
3459 getZeroExtendExpr(AR, ExtTy) ==
3460 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3461 getZeroExtendExpr(Step, ExtTy),
3462 AR->getLoop(), SCEV::FlagAnyWrap)) {
3463 const APInt &StartInt = StartC->getAPInt();
3464 const APInt &StartRem = StartInt.urem(StepInt);
3465 if (StartRem != 0) {
3466 const SCEV *NewLHS =
3467 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3468 AR->getLoop(), SCEV::FlagNW);
3469 if (LHS != NewLHS) {
3470 LHS = NewLHS;
3471
3472 // Reset the ID to include the new LHS, and check if it is
3473 // already cached.
3474 ID.clear();
3475 ID.AddInteger(scUDivExpr);
3476 ID.AddPointer(LHS);
3477 ID.AddPointer(RHS);
3478 IP = nullptr;
3479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3480 return S;
3481 }
3482 }
3483 }
3484 }
3485 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3486 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3488 for (const SCEV *Op : M->operands())
3489 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3490 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3491 // Find an operand that's safely divisible.
3492 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3493 const SCEV *Op = M->getOperand(i);
3494 const SCEV *Div = getUDivExpr(Op, RHSC);
3495 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3496 Operands = SmallVector<const SCEV *, 4>(M->operands());
3497 Operands[i] = Div;
3498 return getMulExpr(Operands);
3499 }
3500 }
3501 }
3502
3503 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3504 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3505 if (auto *DivisorConstant =
3506 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3507 bool Overflow = false;
3508 APInt NewRHS =
3509 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3510 if (Overflow) {
3511 return getConstant(RHSC->getType(), 0, false);
3512 }
3513 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3514 }
3515 }
3516
3517 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3518 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3520 for (const SCEV *Op : A->operands())
3521 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3522 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3523 Operands.clear();
3524 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3525 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3526 if (isa<SCEVUDivExpr>(Op) ||
3527 getMulExpr(Op, RHS) != A->getOperand(i))
3528 break;
3529 Operands.push_back(Op);
3530 }
3531 if (Operands.size() == A->getNumOperands())
3532 return getAddExpr(Operands);
3533 }
3534 }
3535
3536 // Fold if both operands are constant.
3537 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3538 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3539 }
3540 }
3541
3542 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3543 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3544 AE && AE->getNumOperands() == 2) {
3545 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3546 const APInt &NegC = VC->getAPInt();
3547 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3548 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3549 if (MME && MME->getNumOperands() == 2 &&
3550 isa<SCEVConstant>(MME->getOperand(0)) &&
3551 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3552 MME->getOperand(1) == RHS)
3553 return getZero(LHS->getType());
3554 }
3555 }
3556 }
3557
3558 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3559 // changes). Make sure we get a new one.
3560 IP = nullptr;
3561 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3562 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3563 LHS, RHS);
3564 UniqueSCEVs.InsertNode(S, IP);
3565 registerUser(S, {LHS, RHS});
3566 return S;
3567}
3568
3569APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3570 APInt A = C1->getAPInt().abs();
3571 APInt B = C2->getAPInt().abs();
3572 uint32_t ABW = A.getBitWidth();
3573 uint32_t BBW = B.getBitWidth();
3574
3575 if (ABW > BBW)
3576 B = B.zext(ABW);
3577 else if (ABW < BBW)
3578 A = A.zext(BBW);
3579
3580 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3581}
3582
3583/// Get a canonical unsigned division expression, or something simpler if
3584/// possible. There is no representation for an exact udiv in SCEV IR, but we
3585/// can attempt to remove factors from the LHS and RHS. We can't do this when
3586/// it's not exact because the udiv may be clearing bits.
3588 const SCEV *RHS) {
3589 // TODO: we could try to find factors in all sorts of things, but for now we
3590 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3591 // end of this file for inspiration.
3592
3593 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3594 if (!Mul || !Mul->hasNoUnsignedWrap())
3595 return getUDivExpr(LHS, RHS);
3596
3597 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3598 // If the mulexpr multiplies by a constant, then that constant must be the
3599 // first element of the mulexpr.
3600 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3601 if (LHSCst == RHSCst) {
3603 return getMulExpr(Operands);
3604 }
3605
3606 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3607 // that there's a factor provided by one of the other terms. We need to
3608 // check.
3609 APInt Factor = gcd(LHSCst, RHSCst);
3610 if (!Factor.isIntN(1)) {
3611 LHSCst =
3612 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3613 RHSCst =
3614 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3616 Operands.push_back(LHSCst);
3617 append_range(Operands, Mul->operands().drop_front());
3619 RHS = RHSCst;
3620 Mul = dyn_cast<SCEVMulExpr>(LHS);
3621 if (!Mul)
3622 return getUDivExactExpr(LHS, RHS);
3623 }
3624 }
3625 }
3626
3627 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3628 if (Mul->getOperand(i) == RHS) {
3630 append_range(Operands, Mul->operands().take_front(i));
3631 append_range(Operands, Mul->operands().drop_front(i + 1));
3632 return getMulExpr(Operands);
3633 }
3634 }
3635
3636 return getUDivExpr(LHS, RHS);
3637}
3638
3639/// Get an add recurrence expression for the specified loop. Simplify the
3640/// expression as much as possible.
3641const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3642 const Loop *L,
3643 SCEV::NoWrapFlags Flags) {
3645 Operands.push_back(Start);
3646 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3647 if (StepChrec->getLoop() == L) {
3648 append_range(Operands, StepChrec->operands());
3649 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3650 }
3651
3652 Operands.push_back(Step);
3653 return getAddRecExpr(Operands, L, Flags);
3654}
3655
3656/// Get an add recurrence expression for the specified loop. Simplify the
3657/// expression as much as possible.
3658const SCEV *
3660 const Loop *L, SCEV::NoWrapFlags Flags) {
3661 if (Operands.size() == 1) return Operands[0];
3662#ifndef NDEBUG
3664 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3665 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3666 "SCEVAddRecExpr operand types don't match!");
3667 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3668 }
3669 for (const SCEV *Op : Operands)
3671 "SCEVAddRecExpr operand is not available at loop entry!");
3672#endif
3673
3674 if (Operands.back()->isZero()) {
3675 Operands.pop_back();
3676 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3677 }
3678
3679 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3680 // use that information to infer NUW and NSW flags. However, computing a
3681 // BE count requires calling getAddRecExpr, so we may not yet have a
3682 // meaningful BE count at this point (and if we don't, we'd be stuck
3683 // with a SCEVCouldNotCompute as the cached BE count).
3684
3685 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3686
3687 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3688 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3689 const Loop *NestedLoop = NestedAR->getLoop();
3690 if (L->contains(NestedLoop)
3691 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3692 : (!NestedLoop->contains(L) &&
3693 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3694 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3695 Operands[0] = NestedAR->getStart();
3696 // AddRecs require their operands be loop-invariant with respect to their
3697 // loops. Don't perform this transformation if it would break this
3698 // requirement.
3699 bool AllInvariant = all_of(
3700 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3701
3702 if (AllInvariant) {
3703 // Create a recurrence for the outer loop with the same step size.
3704 //
3705 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3706 // inner recurrence has the same property.
3707 SCEV::NoWrapFlags OuterFlags =
3708 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3709
3710 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3711 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3712 return isLoopInvariant(Op, NestedLoop);
3713 });
3714
3715 if (AllInvariant) {
3716 // Ok, both add recurrences are valid after the transformation.
3717 //
3718 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3719 // the outer recurrence has the same property.
3720 SCEV::NoWrapFlags InnerFlags =
3721 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3722 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3723 }
3724 }
3725 // Reset Operands to its original state.
3726 Operands[0] = NestedAR;
3727 }
3728 }
3729
3730 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3731 // already have one, otherwise create a new one.
3732 return getOrCreateAddRecExpr(Operands, L, Flags);
3733}
3734
3735const SCEV *
3737 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3738 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3739 // getSCEV(Base)->getType() has the same address space as Base->getType()
3740 // because SCEV::getType() preserves the address space.
3741 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3742 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3743 if (NW != GEPNoWrapFlags::none()) {
3744 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3745 // but to do that, we have to ensure that said flag is valid in the entire
3746 // defined scope of the SCEV.
3747 // TODO: non-instructions have global scope. We might be able to prove
3748 // some global scope cases
3749 auto *GEPI = dyn_cast<Instruction>(GEP);
3750 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3751 NW = GEPNoWrapFlags::none();
3752 }
3753
3755 if (NW.hasNoUnsignedSignedWrap())
3756 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3757 if (NW.hasNoUnsignedWrap())
3758 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3759
3760 Type *CurTy = GEP->getType();
3761 bool FirstIter = true;
3763 for (const SCEV *IndexExpr : IndexExprs) {
3764 // Compute the (potentially symbolic) offset in bytes for this index.
3765 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3766 // For a struct, add the member offset.
3767 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3768 unsigned FieldNo = Index->getZExtValue();
3769 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3770 Offsets.push_back(FieldOffset);
3771
3772 // Update CurTy to the type of the field at Index.
3773 CurTy = STy->getTypeAtIndex(Index);
3774 } else {
3775 // Update CurTy to its element type.
3776 if (FirstIter) {
3777 assert(isa<PointerType>(CurTy) &&
3778 "The first index of a GEP indexes a pointer");
3779 CurTy = GEP->getSourceElementType();
3780 FirstIter = false;
3781 } else {
3783 }
3784 // For an array, add the element offset, explicitly scaled.
3785 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3786 // Getelementptr indices are signed.
3787 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3788
3789 // Multiply the index by the element size to compute the element offset.
3790 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3791 Offsets.push_back(LocalOffset);
3792 }
3793 }
3794
3795 // Handle degenerate case of GEP without offsets.
3796 if (Offsets.empty())
3797 return BaseExpr;
3798
3799 // Add the offsets together, assuming nsw if inbounds.
3800 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3801 // Add the base address and the offset. We cannot use the nsw flag, as the
3802 // base address is unsigned. However, if we know that the offset is
3803 // non-negative, we can use nuw.
3804 bool NUW = NW.hasNoUnsignedWrap() ||
3807 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3808 assert(BaseExpr->getType() == GEPExpr->getType() &&
3809 "GEP should not change type mid-flight.");
3810 return GEPExpr;
3811}
3812
3813SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3816 ID.AddInteger(SCEVType);
3817 for (const SCEV *Op : Ops)
3818 ID.AddPointer(Op);
3819 void *IP = nullptr;
3820 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3821}
3822
3823const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3825 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3826}
3827
3830 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3831 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3832 if (Ops.size() == 1) return Ops[0];
3833#ifndef NDEBUG
3834 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3835 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3836 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3837 "Operand types don't match!");
3838 assert(Ops[0]->getType()->isPointerTy() ==
3839 Ops[i]->getType()->isPointerTy() &&
3840 "min/max should be consistently pointerish");
3841 }
3842#endif
3843
3844 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3845 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3846
3847 const SCEV *Folded = constantFoldAndGroupOps(
3848 *this, LI, DT, Ops,
3849 [&](const APInt &C1, const APInt &C2) {
3850 switch (Kind) {
3851 case scSMaxExpr:
3852 return APIntOps::smax(C1, C2);
3853 case scSMinExpr:
3854 return APIntOps::smin(C1, C2);
3855 case scUMaxExpr:
3856 return APIntOps::umax(C1, C2);
3857 case scUMinExpr:
3858 return APIntOps::umin(C1, C2);
3859 default:
3860 llvm_unreachable("Unknown SCEV min/max opcode");
3861 }
3862 },
3863 [&](const APInt &C) {
3864 // identity
3865 if (IsMax)
3866 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3867 else
3868 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3869 },
3870 [&](const APInt &C) {
3871 // absorber
3872 if (IsMax)
3873 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3874 else
3875 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3876 });
3877 if (Folded)
3878 return Folded;
3879
3880 // Check if we have created the same expression before.
3881 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3882 return S;
3883 }
3884
3885 // Find the first operation of the same kind
3886 unsigned Idx = 0;
3887 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3888 ++Idx;
3889
3890 // Check to see if one of the operands is of the same kind. If so, expand its
3891 // operands onto our operand list, and recurse to simplify.
3892 if (Idx < Ops.size()) {
3893 bool DeletedAny = false;
3894 while (Ops[Idx]->getSCEVType() == Kind) {
3895 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3896 Ops.erase(Ops.begin()+Idx);
3897 append_range(Ops, SMME->operands());
3898 DeletedAny = true;
3899 }
3900
3901 if (DeletedAny)
3902 return getMinMaxExpr(Kind, Ops);
3903 }
3904
3905 // Okay, check to see if the same value occurs in the operand list twice. If
3906 // so, delete one. Since we sorted the list, these values are required to
3907 // be adjacent.
3912 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3913 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3914 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3915 if (Ops[i] == Ops[i + 1] ||
3916 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3917 // X op Y op Y --> X op Y
3918 // X op Y --> X, if we know X, Y are ordered appropriately
3919 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3920 --i;
3921 --e;
3922 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3923 Ops[i + 1])) {
3924 // X op Y --> Y, if we know X, Y are ordered appropriately
3925 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3926 --i;
3927 --e;
3928 }
3929 }
3930
3931 if (Ops.size() == 1) return Ops[0];
3932
3933 assert(!Ops.empty() && "Reduced smax down to nothing!");
3934
3935 // Okay, it looks like we really DO need an expr. Check to see if we
3936 // already have one, otherwise create a new one.
3938 ID.AddInteger(Kind);
3939 for (const SCEV *Op : Ops)
3940 ID.AddPointer(Op);
3941 void *IP = nullptr;
3942 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3943 if (ExistingSCEV)
3944 return ExistingSCEV;
3945 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3946 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3947 SCEV *S = new (SCEVAllocator)
3948 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3949
3950 UniqueSCEVs.InsertNode(S, IP);
3951 registerUser(S, Ops);
3952 return S;
3953}
3954
3955namespace {
3956
3957class SCEVSequentialMinMaxDeduplicatingVisitor final
3958 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3959 std::optional<const SCEV *>> {
3960 using RetVal = std::optional<const SCEV *>;
3962
3963 ScalarEvolution &SE;
3964 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3965 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3967
3968 bool canRecurseInto(SCEVTypes Kind) const {
3969 // We can only recurse into the SCEV expression of the same effective type
3970 // as the type of our root SCEV expression.
3971 return RootKind == Kind || NonSequentialRootKind == Kind;
3972 };
3973
3974 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3975 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3976 "Only for min/max expressions.");
3977 SCEVTypes Kind = S->getSCEVType();
3978
3979 if (!canRecurseInto(Kind))
3980 return S;
3981
3982 auto *NAry = cast<SCEVNAryExpr>(S);
3984 bool Changed = visit(Kind, NAry->operands(), NewOps);
3985
3986 if (!Changed)
3987 return S;
3988 if (NewOps.empty())
3989 return std::nullopt;
3990
3991 return isa<SCEVSequentialMinMaxExpr>(S)
3992 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3993 : SE.getMinMaxExpr(Kind, NewOps);
3994 }
3995
3996 RetVal visit(const SCEV *S) {
3997 // Has the whole operand been seen already?
3998 if (!SeenOps.insert(S).second)
3999 return std::nullopt;
4000 return Base::visit(S);
4001 }
4002
4003public:
4004 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4005 SCEVTypes RootKind)
4006 : SE(SE), RootKind(RootKind),
4007 NonSequentialRootKind(
4008 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4009 RootKind)) {}
4010
4011 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4013 bool Changed = false;
4015 Ops.reserve(OrigOps.size());
4016
4017 for (const SCEV *Op : OrigOps) {
4018 RetVal NewOp = visit(Op);
4019 if (NewOp != Op)
4020 Changed = true;
4021 if (NewOp)
4022 Ops.emplace_back(*NewOp);
4023 }
4024
4025 if (Changed)
4026 NewOps = std::move(Ops);
4027 return Changed;
4028 }
4029
4030 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4031
4032 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4033
4034 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4035
4036 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4037
4038 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4039
4040 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4041
4042 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4043
4044 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4045
4046 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4047
4048 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4049
4050 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4051 return visitAnyMinMaxExpr(Expr);
4052 }
4053
4054 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4055 return visitAnyMinMaxExpr(Expr);
4056 }
4057
4058 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4059 return visitAnyMinMaxExpr(Expr);
4060 }
4061
4062 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4063 return visitAnyMinMaxExpr(Expr);
4064 }
4065
4066 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4067 return visitAnyMinMaxExpr(Expr);
4068 }
4069
4070 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4071
4072 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4073};
4074
4075} // namespace
4076
4078 switch (Kind) {
4079 case scConstant:
4080 case scVScale:
4081 case scTruncate:
4082 case scZeroExtend:
4083 case scSignExtend:
4084 case scPtrToInt:
4085 case scAddExpr:
4086 case scMulExpr:
4087 case scUDivExpr:
4088 case scAddRecExpr:
4089 case scUMaxExpr:
4090 case scSMaxExpr:
4091 case scUMinExpr:
4092 case scSMinExpr:
4093 case scUnknown:
4094 // If any operand is poison, the whole expression is poison.
4095 return true;
4097 // FIXME: if the *first* operand is poison, the whole expression is poison.
4098 return false; // Pessimistically, say that it does not propagate poison.
4099 case scCouldNotCompute:
4100 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4101 }
4102 llvm_unreachable("Unknown SCEV kind!");
4103}
4104
4105namespace {
4106// The only way poison may be introduced in a SCEV expression is from a
4107// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4108// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4109// introduce poison -- they encode guaranteed, non-speculated knowledge.
4110//
4111// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4112// with the notable exception of umin_seq, where only poison from the first
4113// operand is (unconditionally) propagated.
4114struct SCEVPoisonCollector {
4115 bool LookThroughMaybePoisonBlocking;
4117 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4118 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4119
4120 bool follow(const SCEV *S) {
4121 if (!LookThroughMaybePoisonBlocking &&
4123 return false;
4124
4125 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4126 if (!isGuaranteedNotToBePoison(SU->getValue()))
4127 MaybePoison.insert(SU);
4128 }
4129 return true;
4130 }
4131 bool isDone() const { return false; }
4132};
4133} // namespace
4134
4135/// Return true if V is poison given that AssumedPoison is already poison.
4136static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4137 // First collect all SCEVs that might result in AssumedPoison to be poison.
4138 // We need to look through potentially poison-blocking operations here,
4139 // because we want to find all SCEVs that *might* result in poison, not only
4140 // those that are *required* to.
4141 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4142 visitAll(AssumedPoison, PC1);
4143
4144 // AssumedPoison is never poison. As the assumption is false, the implication
4145 // is true. Don't bother walking the other SCEV in this case.
4146 if (PC1.MaybePoison.empty())
4147 return true;
4148
4149 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4150 // as well. We cannot look through potentially poison-blocking operations
4151 // here, as their arguments only *may* make the result poison.
4152 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4153 visitAll(S, PC2);
4154
4155 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4156 // it will also make S poison by being part of PC2.MaybePoison.
4157 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4158}
4159
4161 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4162 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4163 visitAll(S, PC);
4164 for (const SCEVUnknown *SU : PC.MaybePoison)
4165 Result.insert(SU->getValue());
4166}
4167
4169 const SCEV *S, Instruction *I,
4170 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4171 // If the instruction cannot be poison, it's always safe to reuse.
4173 return true;
4174
4175 // Otherwise, it is possible that I is more poisonous that S. Collect the
4176 // poison-contributors of S, and then check whether I has any additional
4177 // poison-contributors. Poison that is contributed through poison-generating
4178 // flags is handled by dropping those flags instead.
4180 getPoisonGeneratingValues(PoisonVals, S);
4181
4182 SmallVector<Value *> Worklist;
4184 Worklist.push_back(I);
4185 while (!Worklist.empty()) {
4186 Value *V = Worklist.pop_back_val();
4187 if (!Visited.insert(V).second)
4188 continue;
4189
4190 // Avoid walking large instruction graphs.
4191 if (Visited.size() > 16)
4192 return false;
4193
4194 // Either the value can't be poison, or the S would also be poison if it
4195 // is.
4196 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4197 continue;
4198
4199 auto *I = dyn_cast<Instruction>(V);
4200 if (!I)
4201 return false;
4202
4203 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4204 // can't replace an arbitrary add with disjoint or, even if we drop the
4205 // flag. We would need to convert the or into an add.
4206 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4207 if (PDI->isDisjoint())
4208 return false;
4209
4210 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4211 // because SCEV currently assumes it can't be poison. Remove this special
4212 // case once we proper model when vscale can be poison.
4213 if (auto *II = dyn_cast<IntrinsicInst>(I);
4214 II && II->getIntrinsicID() == Intrinsic::vscale)
4215 continue;
4216
4217 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4218 return false;
4219
4220 // If the instruction can't create poison, we can recurse to its operands.
4221 if (I->hasPoisonGeneratingAnnotations())
4222 DropPoisonGeneratingInsts.push_back(I);
4223
4224 for (Value *Op : I->operands())
4225 Worklist.push_back(Op);
4226 }
4227 return true;
4228}
4229
4230const SCEV *
4233 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4234 "Not a SCEVSequentialMinMaxExpr!");
4235 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4236 if (Ops.size() == 1)
4237 return Ops[0];
4238#ifndef NDEBUG
4239 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4240 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4241 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4242 "Operand types don't match!");
4243 assert(Ops[0]->getType()->isPointerTy() ==
4244 Ops[i]->getType()->isPointerTy() &&
4245 "min/max should be consistently pointerish");
4246 }
4247#endif
4248
4249 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4250 // so we can *NOT* do any kind of sorting of the expressions!
4251
4252 // Check if we have created the same expression before.
4253 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4254 return S;
4255
4256 // FIXME: there are *some* simplifications that we can do here.
4257
4258 // Keep only the first instance of an operand.
4259 {
4260 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4261 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4262 if (Changed)
4263 return getSequentialMinMaxExpr(Kind, Ops);
4264 }
4265
4266 // Check to see if one of the operands is of the same kind. If so, expand its
4267 // operands onto our operand list, and recurse to simplify.
4268 {
4269 unsigned Idx = 0;
4270 bool DeletedAny = false;
4271 while (Idx < Ops.size()) {
4272 if (Ops[Idx]->getSCEVType() != Kind) {
4273 ++Idx;
4274 continue;
4275 }
4276 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4277 Ops.erase(Ops.begin() + Idx);
4278 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4279 SMME->operands().end());
4280 DeletedAny = true;
4281 }
4282
4283 if (DeletedAny)
4284 return getSequentialMinMaxExpr(Kind, Ops);
4285 }
4286
4287 const SCEV *SaturationPoint;
4289 switch (Kind) {
4291 SaturationPoint = getZero(Ops[0]->getType());
4292 Pred = ICmpInst::ICMP_ULE;
4293 break;
4294 default:
4295 llvm_unreachable("Not a sequential min/max type.");
4296 }
4297
4298 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4299 if (!isGuaranteedNotToCauseUB(Ops[i]))
4300 continue;
4301 // We can replace %x umin_seq %y with %x umin %y if either:
4302 // * %y being poison implies %x is also poison.
4303 // * %x cannot be the saturating value (e.g. zero for umin).
4304 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4305 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4306 SaturationPoint)) {
4307 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4308 Ops[i - 1] = getMinMaxExpr(
4310 SeqOps);
4311 Ops.erase(Ops.begin() + i);
4312 return getSequentialMinMaxExpr(Kind, Ops);
4313 }
4314 // Fold %x umin_seq %y to %x if %x ule %y.
4315 // TODO: We might be able to prove the predicate for a later operand.
4316 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4317 Ops.erase(Ops.begin() + i);
4318 return getSequentialMinMaxExpr(Kind, Ops);
4319 }
4320 }
4321
4322 // Okay, it looks like we really DO need an expr. Check to see if we
4323 // already have one, otherwise create a new one.
4325 ID.AddInteger(Kind);
4326 for (const SCEV *Op : Ops)
4327 ID.AddPointer(Op);
4328 void *IP = nullptr;
4329 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4330 if (ExistingSCEV)
4331 return ExistingSCEV;
4332
4333 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4334 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4335 SCEV *S = new (SCEVAllocator)
4336 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4337
4338 UniqueSCEVs.InsertNode(S, IP);
4339 registerUser(S, Ops);
4340 return S;
4341}
4342
4343const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4345 return getSMaxExpr(Ops);
4346}
4347
4349 return getMinMaxExpr(scSMaxExpr, Ops);
4350}
4351
4352const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4354 return getUMaxExpr(Ops);
4355}
4356
4358 return getMinMaxExpr(scUMaxExpr, Ops);
4359}
4360
4362 const SCEV *RHS) {
4364 return getSMinExpr(Ops);
4365}
4366
4368 return getMinMaxExpr(scSMinExpr, Ops);
4369}
4370
4371const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4372 bool Sequential) {
4374 return getUMinExpr(Ops, Sequential);
4375}
4376
4378 bool Sequential) {
4379 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4380 : getMinMaxExpr(scUMinExpr, Ops);
4381}
4382
4383const SCEV *
4385 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4386 if (Size.isScalable())
4387 Res = getMulExpr(Res, getVScale(IntTy));
4388 return Res;
4389}
4390
4392 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4393}
4394
4396 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4397}
4398
4400 StructType *STy,
4401 unsigned FieldNo) {
4402 // We can bypass creating a target-independent constant expression and then
4403 // folding it back into a ConstantInt. This is just a compile-time
4404 // optimization.
4405 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4406 assert(!SL->getSizeInBits().isScalable() &&
4407 "Cannot get offset for structure containing scalable vector types");
4408 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4409}
4410
4412 // Don't attempt to do anything other than create a SCEVUnknown object
4413 // here. createSCEV only calls getUnknown after checking for all other
4414 // interesting possibilities, and any other code that calls getUnknown
4415 // is doing so in order to hide a value from SCEV canonicalization.
4416
4418 ID.AddInteger(scUnknown);
4419 ID.AddPointer(V);
4420 void *IP = nullptr;
4421 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4422 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4423 "Stale SCEVUnknown in uniquing map!");
4424 return S;
4425 }
4426 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4427 FirstUnknown);
4428 FirstUnknown = cast<SCEVUnknown>(S);
4429 UniqueSCEVs.InsertNode(S, IP);
4430 return S;
4431}
4432
4433//===----------------------------------------------------------------------===//
4434// Basic SCEV Analysis and PHI Idiom Recognition Code
4435//
4436
4437/// Test if values of the given type are analyzable within the SCEV
4438/// framework. This primarily includes integer types, and it can optionally
4439/// include pointer types if the ScalarEvolution class has access to
4440/// target-specific information.
4442 // Integers and pointers are always SCEVable.
4443 return Ty->isIntOrPtrTy();
4444}
4445
4446/// Return the size in bits of the specified type, for which isSCEVable must
4447/// return true.
4449 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4450 if (Ty->isPointerTy())
4452 return getDataLayout().getTypeSizeInBits(Ty);
4453}
4454
4455/// Return a type with the same bitwidth as the given type and which represents
4456/// how SCEV will treat the given type, for which isSCEVable must return
4457/// true. For pointer types, this is the pointer index sized integer type.
4459 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4460
4461 if (Ty->isIntegerTy())
4462 return Ty;
4463
4464 // The only other support type is pointer.
4465 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4466 return getDataLayout().getIndexType(Ty);
4467}
4468
4470 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4471}
4472
4474 const SCEV *B) {
4475 /// For a valid use point to exist, the defining scope of one operand
4476 /// must dominate the other.
4477 bool PreciseA, PreciseB;
4478 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4479 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4480 if (!PreciseA || !PreciseB)
4481 // Can't tell.
4482 return false;
4483 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4484 DT.dominates(ScopeB, ScopeA);
4485}
4486
4488 return CouldNotCompute.get();
4489}
4490
4491bool ScalarEvolution::checkValidity(const SCEV *S) const {
4492 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4493 auto *SU = dyn_cast<SCEVUnknown>(S);
4494 return SU && SU->getValue() == nullptr;
4495 });
4496
4497 return !ContainsNulls;
4498}
4499
4501 HasRecMapType::iterator I = HasRecMap.find(S);
4502 if (I != HasRecMap.end())
4503 return I->second;
4504
4505 bool FoundAddRec =
4506 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4507 HasRecMap.insert({S, FoundAddRec});
4508 return FoundAddRec;
4509}
4510
4511/// Return the ValueOffsetPair set for \p S. \p S can be represented
4512/// by the value and offset from any ValueOffsetPair in the set.
4513ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4514 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4515 if (SI == ExprValueMap.end())
4516 return {};
4517 return SI->second.getArrayRef();
4518}
4519
4520/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4521/// cannot be used separately. eraseValueFromMap should be used to remove
4522/// V from ValueExprMap and ExprValueMap at the same time.
4523void ScalarEvolution::eraseValueFromMap(Value *V) {
4524 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4525 if (I != ValueExprMap.end()) {
4526 auto EVIt = ExprValueMap.find(I->second);
4527 bool Removed = EVIt->second.remove(V);
4528 (void) Removed;
4529 assert(Removed && "Value not in ExprValueMap?");
4530 ValueExprMap.erase(I);
4531 }
4532}
4533
4534void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4535 // A recursive query may have already computed the SCEV. It should be
4536 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4537 // inferred nowrap flags.
4538 auto It = ValueExprMap.find_as(V);
4539 if (It == ValueExprMap.end()) {
4540 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4541 ExprValueMap[S].insert(V);
4542 }
4543}
4544
4545/// Return an existing SCEV if it exists, otherwise analyze the expression and
4546/// create a new one.
4548 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4549
4550 if (const SCEV *S = getExistingSCEV(V))
4551 return S;
4552 return createSCEVIter(V);
4553}
4554
4556 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4557
4558 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4559 if (I != ValueExprMap.end()) {
4560 const SCEV *S = I->second;
4561 assert(checkValidity(S) &&
4562 "existing SCEV has not been properly invalidated");
4563 return S;
4564 }
4565 return nullptr;
4566}
4567
4568/// Return a SCEV corresponding to -V = -1*V
4570 SCEV::NoWrapFlags Flags) {
4571 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4572 return getConstant(
4573 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4574
4575 Type *Ty = V->getType();
4576 Ty = getEffectiveSCEVType(Ty);
4577 return getMulExpr(V, getMinusOne(Ty), Flags);
4578}
4579
4580/// If Expr computes ~A, return A else return nullptr
4581static const SCEV *MatchNotExpr(const SCEV *Expr) {
4582 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4583 if (!Add || Add->getNumOperands() != 2 ||
4584 !Add->getOperand(0)->isAllOnesValue())
4585 return nullptr;
4586
4587 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4588 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4589 !AddRHS->getOperand(0)->isAllOnesValue())
4590 return nullptr;
4591
4592 return AddRHS->getOperand(1);
4593}
4594
4595/// Return a SCEV corresponding to ~V = -1-V
4597 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4598
4599 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4600 return getConstant(
4601 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4602
4603 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4604 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4605 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4606 SmallVector<const SCEV *, 2> MatchedOperands;
4607 for (const SCEV *Operand : MME->operands()) {
4608 const SCEV *Matched = MatchNotExpr(Operand);
4609 if (!Matched)
4610 return (const SCEV *)nullptr;
4611 MatchedOperands.push_back(Matched);
4612 }
4613 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4614 MatchedOperands);
4615 };
4616 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4617 return Replaced;
4618 }
4619
4620 Type *Ty = V->getType();
4621 Ty = getEffectiveSCEVType(Ty);
4622 return getMinusSCEV(getMinusOne(Ty), V);
4623}
4624
4626 assert(P->getType()->isPointerTy());
4627
4628 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4629 // The base of an AddRec is the first operand.
4630 SmallVector<const SCEV *> Ops{AddRec->operands()};
4631 Ops[0] = removePointerBase(Ops[0]);
4632 // Don't try to transfer nowrap flags for now. We could in some cases
4633 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4634 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4635 }
4636 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4637 // The base of an Add is the pointer operand.
4638 SmallVector<const SCEV *> Ops{Add->operands()};
4639 const SCEV **PtrOp = nullptr;
4640 for (const SCEV *&AddOp : Ops) {
4641 if (AddOp->getType()->isPointerTy()) {
4642 assert(!PtrOp && "Cannot have multiple pointer ops");
4643 PtrOp = &AddOp;
4644 }
4645 }
4646 *PtrOp = removePointerBase(*PtrOp);
4647 // Don't try to transfer nowrap flags for now. We could in some cases
4648 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4649 return getAddExpr(Ops);
4650 }
4651 // Any other expression must be a pointer base.
4652 return getZero(P->getType());
4653}
4654
4655const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4656 SCEV::NoWrapFlags Flags,
4657 unsigned Depth) {
4658 // Fast path: X - X --> 0.
4659 if (LHS == RHS)
4660 return getZero(LHS->getType());
4661
4662 // If we subtract two pointers with different pointer bases, bail.
4663 // Eventually, we're going to add an assertion to getMulExpr that we
4664 // can't multiply by a pointer.
4665 if (RHS->getType()->isPointerTy()) {
4666 if (!LHS->getType()->isPointerTy() ||
4668 return getCouldNotCompute();
4671 }
4672
4673 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4674 // makes it so that we cannot make much use of NUW.
4675 auto AddFlags = SCEV::FlagAnyWrap;
4676 const bool RHSIsNotMinSigned =
4678 if (hasFlags(Flags, SCEV::FlagNSW)) {
4679 // Let M be the minimum representable signed value. Then (-1)*RHS
4680 // signed-wraps if and only if RHS is M. That can happen even for
4681 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4682 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4683 // (-1)*RHS, we need to prove that RHS != M.
4684 //
4685 // If LHS is non-negative and we know that LHS - RHS does not
4686 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4687 // either by proving that RHS > M or that LHS >= 0.
4688 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4689 AddFlags = SCEV::FlagNSW;
4690 }
4691 }
4692
4693 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4694 // RHS is NSW and LHS >= 0.
4695 //
4696 // The difficulty here is that the NSW flag may have been proven
4697 // relative to a loop that is to be found in a recurrence in LHS and
4698 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4699 // larger scope than intended.
4700 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4701
4702 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4703}
4704
4706 unsigned Depth) {
4707 Type *SrcTy = V->getType();
4708 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4709 "Cannot truncate or zero extend with non-integer arguments!");
4710 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4711 return V; // No conversion
4712 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4713 return getTruncateExpr(V, Ty, Depth);
4714 return getZeroExtendExpr(V, Ty, Depth);
4715}
4716
4718 unsigned Depth) {
4719 Type *SrcTy = V->getType();
4720 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4721 "Cannot truncate or zero extend with non-integer arguments!");
4722 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4723 return V; // No conversion
4724 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4725 return getTruncateExpr(V, Ty, Depth);
4726 return getSignExtendExpr(V, Ty, Depth);
4727}
4728
4729const SCEV *
4731 Type *SrcTy = V->getType();
4732 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4733 "Cannot noop or zero extend with non-integer arguments!");
4735 "getNoopOrZeroExtend cannot truncate!");
4736 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4737 return V; // No conversion
4738 return getZeroExtendExpr(V, Ty);
4739}
4740
4741const SCEV *
4743 Type *SrcTy = V->getType();
4744 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4745 "Cannot noop or sign extend with non-integer arguments!");
4747 "getNoopOrSignExtend cannot truncate!");
4748 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4749 return V; // No conversion
4750 return getSignExtendExpr(V, Ty);
4751}
4752
4753const SCEV *
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot noop or any extend with non-integer arguments!");
4759 "getNoopOrAnyExtend cannot truncate!");
4760 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4761 return V; // No conversion
4762 return getAnyExtendExpr(V, Ty);
4763}
4764
4765const SCEV *
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot truncate or noop with non-integer arguments!");
4771 "getTruncateOrNoop cannot extend!");
4772 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4773 return V; // No conversion
4774 return getTruncateExpr(V, Ty);
4775}
4776
4778 const SCEV *RHS) {
4779 const SCEV *PromotedLHS = LHS;
4780 const SCEV *PromotedRHS = RHS;
4781
4783 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4784 else
4785 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4786
4787 return getUMaxExpr(PromotedLHS, PromotedRHS);
4788}
4789
4791 const SCEV *RHS,
4792 bool Sequential) {
4794 return getUMinFromMismatchedTypes(Ops, Sequential);
4795}
4796
4797const SCEV *
4799 bool Sequential) {
4800 assert(!Ops.empty() && "At least one operand must be!");
4801 // Trivial case.
4802 if (Ops.size() == 1)
4803 return Ops[0];
4804
4805 // Find the max type first.
4806 Type *MaxType = nullptr;
4807 for (const auto *S : Ops)
4808 if (MaxType)
4809 MaxType = getWiderType(MaxType, S->getType());
4810 else
4811 MaxType = S->getType();
4812 assert(MaxType && "Failed to find maximum type!");
4813
4814 // Extend all ops to max type.
4815 SmallVector<const SCEV *, 2> PromotedOps;
4816 for (const auto *S : Ops)
4817 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4818
4819 // Generate umin.
4820 return getUMinExpr(PromotedOps, Sequential);
4821}
4822
4824 // A pointer operand may evaluate to a nonpointer expression, such as null.
4825 if (!V->getType()->isPointerTy())
4826 return V;
4827
4828 while (true) {
4829 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4830 V = AddRec->getStart();
4831 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4832 const SCEV *PtrOp = nullptr;
4833 for (const SCEV *AddOp : Add->operands()) {
4834 if (AddOp->getType()->isPointerTy()) {
4835 assert(!PtrOp && "Cannot have multiple pointer ops");
4836 PtrOp = AddOp;
4837 }
4838 }
4839 assert(PtrOp && "Must have pointer op");
4840 V = PtrOp;
4841 } else // Not something we can look further into.
4842 return V;
4843 }
4844}
4845
4846/// Push users of the given Instruction onto the given Worklist.
4850 // Push the def-use children onto the Worklist stack.
4851 for (User *U : I->users()) {
4852 auto *UserInsn = cast<Instruction>(U);
4853 if (Visited.insert(UserInsn).second)
4854 Worklist.push_back(UserInsn);
4855 }
4856}
4857
4858namespace {
4859
4860/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4861/// expression in case its Loop is L. If it is not L then
4862/// if IgnoreOtherLoops is true then use AddRec itself
4863/// otherwise rewrite cannot be done.
4864/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4865class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4866public:
4867 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4868 bool IgnoreOtherLoops = true) {
4869 SCEVInitRewriter Rewriter(L, SE);
4870 const SCEV *Result = Rewriter.visit(S);
4871 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4872 return SE.getCouldNotCompute();
4873 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4874 ? SE.getCouldNotCompute()
4875 : Result;
4876 }
4877
4878 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4879 if (!SE.isLoopInvariant(Expr, L))
4880 SeenLoopVariantSCEVUnknown = true;
4881 return Expr;
4882 }
4883
4884 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4885 // Only re-write AddRecExprs for this loop.
4886 if (Expr->getLoop() == L)
4887 return Expr->getStart();
4888 SeenOtherLoops = true;
4889 return Expr;
4890 }
4891
4892 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4893
4894 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4895
4896private:
4897 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4898 : SCEVRewriteVisitor(SE), L(L) {}
4899
4900 const Loop *L;
4901 bool SeenLoopVariantSCEVUnknown = false;
4902 bool SeenOtherLoops = false;
4903};
4904
4905/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
4906/// increment expression in case its Loop is L. If it is not L then
4907/// use AddRec itself.
4908/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
4909class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4910public:
4911 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4912 SCEVPostIncRewriter Rewriter(L, SE);
4913 const SCEV *Result = Rewriter.visit(S);
4914 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4915 ? SE.getCouldNotCompute()
4916 : Result;
4917 }
4918
4919 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4920 if (!SE.isLoopInvariant(Expr, L))
4921 SeenLoopVariantSCEVUnknown = true;
4922 return Expr;
4923 }
4924
4925 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4926 // Only re-write AddRecExprs for this loop.
4927 if (Expr->getLoop() == L)
4928 return Expr->getPostIncExpr(SE);
4929 SeenOtherLoops = true;
4930 return Expr;
4931 }
4932
4933 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4934
4935 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4936
4937private:
4938 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4939 : SCEVRewriteVisitor(SE), L(L) {}
4940
4941 const Loop *L;
4942 bool SeenLoopVariantSCEVUnknown = false;
4943 bool SeenOtherLoops = false;
4944};
4945
4946/// This class evaluates the compare condition by matching it against the
4947/// condition of loop latch. If there is a match we assume a true value
4948/// for the condition while building SCEV nodes.
4949class SCEVBackedgeConditionFolder
4950 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4951public:
4952 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4953 ScalarEvolution &SE) {
4954 bool IsPosBECond = false;
4955 Value *BECond = nullptr;
4956 if (BasicBlock *Latch = L->getLoopLatch()) {
4957 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4958 if (BI && BI->isConditional()) {
4959 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4960 "Both outgoing branches should not target same header!");
4961 BECond = BI->getCondition();
4962 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4963 } else {
4964 return S;
4965 }
4966 }
4967 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4968 return Rewriter.visit(S);
4969 }
4970
4971 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4972 const SCEV *Result = Expr;
4973 bool InvariantF = SE.isLoopInvariant(Expr, L);
4974
4975 if (!InvariantF) {
4976 Instruction *I = cast<Instruction>(Expr->getValue());
4977 switch (I->getOpcode()) {
4978 case Instruction::Select: {
4979 SelectInst *SI = cast<SelectInst>(I);
4980 std::optional<const SCEV *> Res =
4981 compareWithBackedgeCondition(SI->getCondition());
4982 if (Res) {
4983 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4984 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4985 }
4986 break;
4987 }
4988 default: {
4989 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4990 if (Res)
4991 Result = *Res;
4992 break;
4993 }
4994 }
4995 }
4996 return Result;
4997 }
4998
4999private:
5000 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5001 bool IsPosBECond, ScalarEvolution &SE)
5002 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5003 IsPositiveBECond(IsPosBECond) {}
5004
5005 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5006
5007 const Loop *L;
5008 /// Loop back condition.
5009 Value *BackedgeCond = nullptr;
5010 /// Set to true if loop back is on positive branch condition.
5011 bool IsPositiveBECond;
5012};
5013
5014std::optional<const SCEV *>
5015SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5016
5017 // If value matches the backedge condition for loop latch,
5018 // then return a constant evolution node based on loopback
5019 // branch taken.
5020 if (BackedgeCond == IC)
5021 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5023 return std::nullopt;
5024}
5025
5026class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5027public:
5028 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5029 ScalarEvolution &SE) {
5030 SCEVShiftRewriter Rewriter(L, SE);
5031 const SCEV *Result = Rewriter.visit(S);
5032 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5033 }
5034
5035 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5036 // Only allow AddRecExprs for this loop.
5037 if (!SE.isLoopInvariant(Expr, L))
5038 Valid = false;
5039 return Expr;
5040 }
5041
5042 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5043 if (Expr->getLoop() == L && Expr->isAffine())
5044 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5045 Valid = false;
5046 return Expr;
5047 }
5048
5049 bool isValid() { return Valid; }
5050
5051private:
5052 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5053 : SCEVRewriteVisitor(SE), L(L) {}
5054
5055 const Loop *L;
5056 bool Valid = true;
5057};
5058
5059} // end anonymous namespace
5060
5062ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5063 if (!AR->isAffine())
5064 return SCEV::FlagAnyWrap;
5065
5066 using OBO = OverflowingBinaryOperator;
5067
5069
5070 if (!AR->hasNoSelfWrap()) {
5071 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5072 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5073 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5074 const APInt &BECountAP = BECountMax->getAPInt();
5075 unsigned NoOverflowBitWidth =
5076 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5077 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5079 }
5080 }
5081
5082 if (!AR->hasNoSignedWrap()) {
5083 ConstantRange AddRecRange = getSignedRange(AR);
5084 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5085
5087 Instruction::Add, IncRange, OBO::NoSignedWrap);
5088 if (NSWRegion.contains(AddRecRange))
5090 }
5091
5092 if (!AR->hasNoUnsignedWrap()) {
5093 ConstantRange AddRecRange = getUnsignedRange(AR);
5094 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5095
5097 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5098 if (NUWRegion.contains(AddRecRange))
5100 }
5101
5102 return Result;
5103}
5104
5106ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5108
5109 if (AR->hasNoSignedWrap())
5110 return Result;
5111
5112 if (!AR->isAffine())
5113 return Result;
5114
5115 // This function can be expensive, only try to prove NSW once per AddRec.
5116 if (!SignedWrapViaInductionTried.insert(AR).second)
5117 return Result;
5118
5119 const SCEV *Step = AR->getStepRecurrence(*this);
5120 const Loop *L = AR->getLoop();
5121
5122 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5123 // Note that this serves two purposes: It filters out loops that are
5124 // simply not analyzable, and it covers the case where this code is
5125 // being called from within backedge-taken count analysis, such that
5126 // attempting to ask for the backedge-taken count would likely result
5127 // in infinite recursion. In the later case, the analysis code will
5128 // cope with a conservative value, and it will take care to purge
5129 // that value once it has finished.
5130 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5131
5132 // Normally, in the cases we can prove no-overflow via a
5133 // backedge guarding condition, we can also compute a backedge
5134 // taken count for the loop. The exceptions are assumptions and
5135 // guards present in the loop -- SCEV is not great at exploiting
5136 // these to compute max backedge taken counts, but can still use
5137 // these to prove lack of overflow. Use this fact to avoid
5138 // doing extra work that may not pay off.
5139
5140 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5141 AC.assumptions().empty())
5142 return Result;
5143
5144 // If the backedge is guarded by a comparison with the pre-inc value the
5145 // addrec is safe. Also, if the entry is guarded by a comparison with the
5146 // start value and the backedge is guarded by a comparison with the post-inc
5147 // value, the addrec is safe.
5149 const SCEV *OverflowLimit =
5150 getSignedOverflowLimitForStep(Step, &Pred, this);
5151 if (OverflowLimit &&
5152 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5153 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5154 Result = setFlags(Result, SCEV::FlagNSW);
5155 }
5156 return Result;
5157}
5159ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5161
5162 if (AR->hasNoUnsignedWrap())
5163 return Result;
5164
5165 if (!AR->isAffine())
5166 return Result;
5167
5168 // This function can be expensive, only try to prove NUW once per AddRec.
5169 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5170 return Result;
5171
5172 const SCEV *Step = AR->getStepRecurrence(*this);
5173 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5174 const Loop *L = AR->getLoop();
5175
5176 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5177 // Note that this serves two purposes: It filters out loops that are
5178 // simply not analyzable, and it covers the case where this code is
5179 // being called from within backedge-taken count analysis, such that
5180 // attempting to ask for the backedge-taken count would likely result
5181 // in infinite recursion. In the later case, the analysis code will
5182 // cope with a conservative value, and it will take care to purge
5183 // that value once it has finished.
5184 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5185
5186 // Normally, in the cases we can prove no-overflow via a
5187 // backedge guarding condition, we can also compute a backedge
5188 // taken count for the loop. The exceptions are assumptions and
5189 // guards present in the loop -- SCEV is not great at exploiting
5190 // these to compute max backedge taken counts, but can still use
5191 // these to prove lack of overflow. Use this fact to avoid
5192 // doing extra work that may not pay off.
5193
5194 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5195 AC.assumptions().empty())
5196 return Result;
5197
5198 // If the backedge is guarded by a comparison with the pre-inc value the
5199 // addrec is safe. Also, if the entry is guarded by a comparison with the
5200 // start value and the backedge is guarded by a comparison with the post-inc
5201 // value, the addrec is safe.
5202 if (isKnownPositive(Step)) {
5204 getUnsignedRangeMax(Step));
5207 Result = setFlags(Result, SCEV::FlagNUW);
5208 }
5209 }
5210
5211 return Result;
5212}
5213
5214namespace {
5215
5216/// Represents an abstract binary operation. This may exist as a
5217/// normal instruction or constant expression, or may have been
5218/// derived from an expression tree.
5219struct BinaryOp {
5220 unsigned Opcode;
5221 Value *LHS;
5222 Value *RHS;
5223 bool IsNSW = false;
5224 bool IsNUW = false;
5225
5226 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5227 /// constant expression.
5228 Operator *Op = nullptr;
5229
5230 explicit BinaryOp(Operator *Op)
5231 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5232 Op(Op) {
5233 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5234 IsNSW = OBO->hasNoSignedWrap();
5235 IsNUW = OBO->hasNoUnsignedWrap();
5236 }
5237 }
5238
5239 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5240 bool IsNUW = false)
5241 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5242};
5243
5244} // end anonymous namespace
5245
5246/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5247static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5248 AssumptionCache &AC,
5249 const DominatorTree &DT,
5250 const Instruction *CxtI) {
5251 auto *Op = dyn_cast<Operator>(V);
5252 if (!Op)
5253 return std::nullopt;
5254
5255 // Implementation detail: all the cleverness here should happen without
5256 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5257 // SCEV expressions when possible, and we should not break that.
5258
5259 switch (Op->getOpcode()) {
5260 case Instruction::Add:
5261 case Instruction::Sub:
5262 case Instruction::Mul:
5263 case Instruction::UDiv:
5264 case Instruction::URem:
5265 case Instruction::And:
5266 case Instruction::AShr:
5267 case Instruction::Shl:
5268 return BinaryOp(Op);
5269
5270 case Instruction::Or: {
5271 // Convert or disjoint into add nuw nsw.
5272 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5273 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5274 /*IsNSW=*/true, /*IsNUW=*/true);
5275 return BinaryOp(Op);
5276 }
5277
5278 case Instruction::Xor:
5279 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5280 // If the RHS of the xor is a signmask, then this is just an add.
5281 // Instcombine turns add of signmask into xor as a strength reduction step.
5282 if (RHSC->getValue().isSignMask())
5283 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5284 // Binary `xor` is a bit-wise `add`.
5285 if (V->getType()->isIntegerTy(1))
5286 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5287 return BinaryOp(Op);
5288
5289 case Instruction::LShr:
5290 // Turn logical shift right of a constant into a unsigned divide.
5291 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5292 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5293
5294 // If the shift count is not less than the bitwidth, the result of
5295 // the shift is undefined. Don't try to analyze it, because the
5296 // resolution chosen here may differ from the resolution chosen in
5297 // other parts of the compiler.
5298 if (SA->getValue().ult(BitWidth)) {
5299 Constant *X =
5300 ConstantInt::get(SA->getContext(),
5301 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5302 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5303 }
5304 }
5305 return BinaryOp(Op);
5306
5307 case Instruction::ExtractValue: {
5308 auto *EVI = cast<ExtractValueInst>(Op);
5309 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5310 break;
5311
5312 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5313 if (!WO)
5314 break;
5315
5316 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5317 bool Signed = WO->isSigned();
5318 // TODO: Should add nuw/nsw flags for mul as well.
5319 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5320 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5321
5322 // Now that we know that all uses of the arithmetic-result component of
5323 // CI are guarded by the overflow check, we can go ahead and pretend
5324 // that the arithmetic is non-overflowing.
5325 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5326 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5327 }
5328
5329 default:
5330 break;
5331 }
5332
5333 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5334 // semantics as a Sub, return a binary sub expression.
5335 if (auto *II = dyn_cast<IntrinsicInst>(V))
5336 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5337 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5338
5339 return std::nullopt;
5340}
5341
5342/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5343/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5344/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5345/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5346/// follows one of the following patterns:
5347/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5348/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5349/// If the SCEV expression of \p Op conforms with one of the expected patterns
5350/// we return the type of the truncation operation, and indicate whether the
5351/// truncated type should be treated as signed/unsigned by setting
5352/// \p Signed to true/false, respectively.
5353static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5354 bool &Signed, ScalarEvolution &SE) {
5355 // The case where Op == SymbolicPHI (that is, with no type conversions on
5356 // the way) is handled by the regular add recurrence creating logic and
5357 // would have already been triggered in createAddRecForPHI. Reaching it here
5358 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5359 // because one of the other operands of the SCEVAddExpr updating this PHI is
5360 // not invariant).
5361 //
5362 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5363 // this case predicates that allow us to prove that Op == SymbolicPHI will
5364 // be added.
5365 if (Op == SymbolicPHI)
5366 return nullptr;
5367
5368 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5369 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5370 if (SourceBits != NewBits)
5371 return nullptr;
5372
5373 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5374 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5375 if (!SExt && !ZExt)
5376 return nullptr;
5377 const SCEVTruncateExpr *Trunc =
5378 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5379 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5380 if (!Trunc)
5381 return nullptr;
5382 const SCEV *X = Trunc->getOperand();
5383 if (X != SymbolicPHI)
5384 return nullptr;
5385 Signed = SExt != nullptr;
5386 return Trunc->getType();
5387}
5388
5389static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5390 if (!PN->getType()->isIntegerTy())
5391 return nullptr;
5392 const Loop *L = LI.getLoopFor(PN->getParent());
5393 if (!L || L->getHeader() != PN->getParent())
5394 return nullptr;
5395 return L;
5396}
5397
5398// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5399// computation that updates the phi follows the following pattern:
5400// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5401// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5402// If so, try to see if it can be rewritten as an AddRecExpr under some
5403// Predicates. If successful, return them as a pair. Also cache the results
5404// of the analysis.
5405//
5406// Example usage scenario:
5407// Say the Rewriter is called for the following SCEV:
5408// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5409// where:
5410// %X = phi i64 (%Start, %BEValue)
5411// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5412// and call this function with %SymbolicPHI = %X.
5413//
5414// The analysis will find that the value coming around the backedge has
5415// the following SCEV:
5416// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5417// Upon concluding that this matches the desired pattern, the function
5418// will return the pair {NewAddRec, SmallPredsVec} where:
5419// NewAddRec = {%Start,+,%Step}
5420// SmallPredsVec = {P1, P2, P3} as follows:
5421// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5422// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5423// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5424// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5425// under the predicates {P1,P2,P3}.
5426// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5427// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5428//
5429// TODO's:
5430//
5431// 1) Extend the Induction descriptor to also support inductions that involve
5432// casts: When needed (namely, when we are called in the context of the
5433// vectorizer induction analysis), a Set of cast instructions will be
5434// populated by this method, and provided back to isInductionPHI. This is
5435// needed to allow the vectorizer to properly record them to be ignored by
5436// the cost model and to avoid vectorizing them (otherwise these casts,
5437// which are redundant under the runtime overflow checks, will be
5438// vectorized, which can be costly).
5439//
5440// 2) Support additional induction/PHISCEV patterns: We also want to support
5441// inductions where the sext-trunc / zext-trunc operations (partly) occur
5442// after the induction update operation (the induction increment):
5443//
5444// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5445// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5446//
5447// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5448// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5449//
5450// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5451std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5452ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5454
5455 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5456 // return an AddRec expression under some predicate.
5457
5458 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5459 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5460 assert(L && "Expecting an integer loop header phi");
5461
5462 // The loop may have multiple entrances or multiple exits; we can analyze
5463 // this phi as an addrec if it has a unique entry value and a unique
5464 // backedge value.
5465 Value *BEValueV = nullptr, *StartValueV = nullptr;
5466 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5467 Value *V = PN->getIncomingValue(i);
5468 if (L->contains(PN->getIncomingBlock(i))) {
5469 if (!BEValueV) {
5470 BEValueV = V;
5471 } else if (BEValueV != V) {
5472 BEValueV = nullptr;
5473 break;
5474 }
5475 } else if (!StartValueV) {
5476 StartValueV = V;
5477 } else if (StartValueV != V) {
5478 StartValueV = nullptr;
5479 break;
5480 }
5481 }
5482 if (!BEValueV || !StartValueV)
5483 return std::nullopt;
5484
5485 const SCEV *BEValue = getSCEV(BEValueV);
5486
5487 // If the value coming around the backedge is an add with the symbolic
5488 // value we just inserted, possibly with casts that we can ignore under
5489 // an appropriate runtime guard, then we found a simple induction variable!
5490 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5491 if (!Add)
5492 return std::nullopt;
5493
5494 // If there is a single occurrence of the symbolic value, possibly
5495 // casted, replace it with a recurrence.
5496 unsigned FoundIndex = Add->getNumOperands();
5497 Type *TruncTy = nullptr;
5498 bool Signed;
5499 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5500 if ((TruncTy =
5501 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5502 if (FoundIndex == e) {
5503 FoundIndex = i;
5504 break;
5505 }
5506
5507 if (FoundIndex == Add->getNumOperands())
5508 return std::nullopt;
5509
5510 // Create an add with everything but the specified operand.
5512 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5513 if (i != FoundIndex)
5514 Ops.push_back(Add->getOperand(i));
5515 const SCEV *Accum = getAddExpr(Ops);
5516
5517 // The runtime checks will not be valid if the step amount is
5518 // varying inside the loop.
5519 if (!isLoopInvariant(Accum, L))
5520 return std::nullopt;
5521
5522 // *** Part2: Create the predicates
5523
5524 // Analysis was successful: we have a phi-with-cast pattern for which we
5525 // can return an AddRec expression under the following predicates:
5526 //
5527 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5528 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5529 // P2: An Equal predicate that guarantees that
5530 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5531 // P3: An Equal predicate that guarantees that
5532 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5533 //
5534 // As we next prove, the above predicates guarantee that:
5535 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5536 //
5537 //
5538 // More formally, we want to prove that:
5539 // Expr(i+1) = Start + (i+1) * Accum
5540 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5541 //
5542 // Given that:
5543 // 1) Expr(0) = Start
5544 // 2) Expr(1) = Start + Accum
5545 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5546 // 3) Induction hypothesis (step i):
5547 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5548 //
5549 // Proof:
5550 // Expr(i+1) =
5551 // = Start + (i+1)*Accum
5552 // = (Start + i*Accum) + Accum
5553 // = Expr(i) + Accum
5554 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5555 // :: from step i
5556 //
5557 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5558 //
5559 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5560 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5561 // + Accum :: from P3
5562 //
5563 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5564 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5565 //
5566 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5567 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5568 //
5569 // By induction, the same applies to all iterations 1<=i<n:
5570 //
5571
5572 // Create a truncated addrec for which we will add a no overflow check (P1).
5573 const SCEV *StartVal = getSCEV(StartValueV);
5574 const SCEV *PHISCEV =
5575 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5576 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5577
5578 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5579 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5580 // will be constant.
5581 //
5582 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5583 // add P1.
5584 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5588 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5589 Predicates.push_back(AddRecPred);
5590 }
5591
5592 // Create the Equal Predicates P2,P3:
5593
5594 // It is possible that the predicates P2 and/or P3 are computable at
5595 // compile time due to StartVal and/or Accum being constants.
5596 // If either one is, then we can check that now and escape if either P2
5597 // or P3 is false.
5598
5599 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5600 // for each of StartVal and Accum
5601 auto getExtendedExpr = [&](const SCEV *Expr,
5602 bool CreateSignExtend) -> const SCEV * {
5603 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5604 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5605 const SCEV *ExtendedExpr =
5606 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5607 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5608 return ExtendedExpr;
5609 };
5610
5611 // Given:
5612 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5613 // = getExtendedExpr(Expr)
5614 // Determine whether the predicate P: Expr == ExtendedExpr
5615 // is known to be false at compile time
5616 auto PredIsKnownFalse = [&](const SCEV *Expr,
5617 const SCEV *ExtendedExpr) -> bool {
5618 return Expr != ExtendedExpr &&
5619 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5620 };
5621
5622 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5623 if (PredIsKnownFalse(StartVal, StartExtended)) {
5624 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5625 return std::nullopt;
5626 }
5627
5628 // The Step is always Signed (because the overflow checks are either
5629 // NSSW or NUSW)
5630 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5631 if (PredIsKnownFalse(Accum, AccumExtended)) {
5632 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5633 return std::nullopt;
5634 }
5635
5636 auto AppendPredicate = [&](const SCEV *Expr,
5637 const SCEV *ExtendedExpr) -> void {
5638 if (Expr != ExtendedExpr &&
5639 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5640 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5641 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5642 Predicates.push_back(Pred);
5643 }
5644 };
5645
5646 AppendPredicate(StartVal, StartExtended);
5647 AppendPredicate(Accum, AccumExtended);
5648
5649 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5650 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5651 // into NewAR if it will also add the runtime overflow checks specified in
5652 // Predicates.
5653 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5654
5655 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5656 std::make_pair(NewAR, Predicates);
5657 // Remember the result of the analysis for this SCEV at this locayyytion.
5658 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5659 return PredRewrite;
5660}
5661
5662std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5664 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5665 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5666 if (!L)
5667 return std::nullopt;
5668
5669 // Check to see if we already analyzed this PHI.
5670 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5671 if (I != PredicatedSCEVRewrites.end()) {
5672 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5673 I->second;
5674 // Analysis was done before and failed to create an AddRec:
5675 if (Rewrite.first == SymbolicPHI)
5676 return std::nullopt;
5677 // Analysis was done before and succeeded to create an AddRec under
5678 // a predicate:
5679 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5680 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5681 return Rewrite;
5682 }
5683
5684 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5686
5687 // Record in the cache that the analysis failed
5688 if (!Rewrite) {
5690 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5691 return std::nullopt;
5692 }
5693
5694 return Rewrite;
5695}
5696
5697// FIXME: This utility is currently required because the Rewriter currently
5698// does not rewrite this expression:
5699// {0, +, (sext ix (trunc iy to ix) to iy)}
5700// into {0, +, %step},
5701// even when the following Equal predicate exists:
5702// "%step == (sext ix (trunc iy to ix) to iy)".
5704 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5705 if (AR1 == AR2)
5706 return true;
5707
5708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709 if (Expr1 != Expr2 &&
5710 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5712 return false;
5713 return true;
5714 };
5715
5716 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5717 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5718 return false;
5719 return true;
5720}
5721
5722/// A helper function for createAddRecFromPHI to handle simple cases.
5723///
5724/// This function tries to find an AddRec expression for the simplest (yet most
5725/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5726/// If it fails, createAddRecFromPHI will use a more general, but slow,
5727/// technique for finding the AddRec expression.
5728const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5729 Value *BEValueV,
5730 Value *StartValueV) {
5731 const Loop *L = LI.getLoopFor(PN->getParent());
5732 assert(L && L->getHeader() == PN->getParent());
5733 assert(BEValueV && StartValueV);
5734
5735 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5736 if (!BO)
5737 return nullptr;
5738
5739 if (BO->Opcode != Instruction::Add)
5740 return nullptr;
5741
5742 const SCEV *Accum = nullptr;
5743 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5744 Accum = getSCEV(BO->RHS);
5745 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5746 Accum = getSCEV(BO->LHS);
5747
5748 if (!Accum)
5749 return nullptr;
5750
5752 if (BO->IsNUW)
5753 Flags = setFlags(Flags, SCEV::FlagNUW);
5754 if (BO->IsNSW)
5755 Flags = setFlags(Flags, SCEV::FlagNSW);
5756
5757 const SCEV *StartVal = getSCEV(StartValueV);
5758 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5759 insertValueToMap(PN, PHISCEV);
5760
5761 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5762 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5764 proveNoWrapViaConstantRanges(AR)));
5765 }
5766
5767 // We can add Flags to the post-inc expression only if we
5768 // know that it is *undefined behavior* for BEValueV to
5769 // overflow.
5770 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5771 assert(isLoopInvariant(Accum, L) &&
5772 "Accum is defined outside L, but is not invariant?");
5773 if (isAddRecNeverPoison(BEInst, L))
5774 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5775 }
5776
5777 return PHISCEV;
5778}
5779
5780const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5781 const Loop *L = LI.getLoopFor(PN->getParent());
5782 if (!L || L->getHeader() != PN->getParent())
5783 return nullptr;
5784
5785 // The loop may have multiple entrances or multiple exits; we can analyze
5786 // this phi as an addrec if it has a unique entry value and a unique
5787 // backedge value.
5788 Value *BEValueV = nullptr, *StartValueV = nullptr;
5789 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5790 Value *V = PN->getIncomingValue(i);
5791 if (L->contains(PN->getIncomingBlock(i))) {
5792 if (!BEValueV) {
5793 BEValueV = V;
5794 } else if (BEValueV != V) {
5795 BEValueV = nullptr;
5796 break;
5797 }
5798 } else if (!StartValueV) {
5799 StartValueV = V;
5800 } else if (StartValueV != V) {
5801 StartValueV = nullptr;
5802 break;
5803 }
5804 }
5805 if (!BEValueV || !StartValueV)
5806 return nullptr;
5807
5808 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5809 "PHI node already processed?");
5810
5811 // First, try to find AddRec expression without creating a fictituos symbolic
5812 // value for PN.
5813 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5814 return S;
5815
5816 // Handle PHI node value symbolically.
5817 const SCEV *SymbolicName = getUnknown(PN);
5818 insertValueToMap(PN, SymbolicName);
5819
5820 // Using this symbolic name for the PHI, analyze the value coming around
5821 // the back-edge.
5822 const SCEV *BEValue = getSCEV(BEValueV);
5823
5824 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5825 // has a special value for the first iteration of the loop.
5826
5827 // If the value coming around the backedge is an add with the symbolic
5828 // value we just inserted, then we found a simple induction variable!
5829 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5830 // If there is a single occurrence of the symbolic value, replace it
5831 // with a recurrence.
5832 unsigned FoundIndex = Add->getNumOperands();
5833 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5834 if (Add->getOperand(i) == SymbolicName)
5835 if (FoundIndex == e) {
5836 FoundIndex = i;
5837 break;
5838 }
5839
5840 if (FoundIndex != Add->getNumOperands()) {
5841 // Create an add with everything but the specified operand.
5843 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5844 if (i != FoundIndex)
5845 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5846 L, *this));
5847 const SCEV *Accum = getAddExpr(Ops);
5848
5849 // This is not a valid addrec if the step amount is varying each
5850 // loop iteration, but is not itself an addrec in this loop.
5851 if (isLoopInvariant(Accum, L) ||
5852 (isa<SCEVAddRecExpr>(Accum) &&
5853 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5855
5856 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5857 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5858 if (BO->IsNUW)
5859 Flags = setFlags(Flags, SCEV::FlagNUW);
5860 if (BO->IsNSW)
5861 Flags = setFlags(Flags, SCEV::FlagNSW);
5862 }
5863 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5864 if (GEP->getOperand(0) == PN) {
5865 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5866 // If the increment has any nowrap flags, then we know the address
5867 // space cannot be wrapped around.
5868 if (NW != GEPNoWrapFlags::none())
5869 Flags = setFlags(Flags, SCEV::FlagNW);
5870 // If the GEP is nuw or nusw with non-negative offset, we know that
5871 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5872 // offset is treated as signed, while the base is unsigned.
5873 if (NW.hasNoUnsignedWrap() ||
5875 Flags = setFlags(Flags, SCEV::FlagNUW);
5876 }
5877
5878 // We cannot transfer nuw and nsw flags from subtraction
5879 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5880 // for instance.
5881 }
5882
5883 const SCEV *StartVal = getSCEV(StartValueV);
5884 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5885
5886 // Okay, for the entire analysis of this edge we assumed the PHI
5887 // to be symbolic. We now need to go back and purge all of the
5888 // entries for the scalars that use the symbolic expression.
5889 forgetMemoizedResults(SymbolicName);
5890 insertValueToMap(PN, PHISCEV);
5891
5892 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5893 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5895 proveNoWrapViaConstantRanges(AR)));
5896 }
5897
5898 // We can add Flags to the post-inc expression only if we
5899 // know that it is *undefined behavior* for BEValueV to
5900 // overflow.
5901 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5902 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5903 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5904
5905 return PHISCEV;
5906 }
5907 }
5908 } else {
5909 // Otherwise, this could be a loop like this:
5910 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5911 // In this case, j = {1,+,1} and BEValue is j.
5912 // Because the other in-value of i (0) fits the evolution of BEValue
5913 // i really is an addrec evolution.
5914 //
5915 // We can generalize this saying that i is the shifted value of BEValue
5916 // by one iteration:
5917 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5918
5919 // Do not allow refinement in rewriting of BEValue.
5920 if (isGuaranteedNotToCauseUB(BEValue)) {
5921 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5922 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5923 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5924 ::impliesPoison(BEValue, Start)) {
5925 const SCEV *StartVal = getSCEV(StartValueV);
5926 if (Start == StartVal) {
5927 // Okay, for the entire analysis of this edge we assumed the PHI
5928 // to be symbolic. We now need to go back and purge all of the
5929 // entries for the scalars that use the symbolic expression.
5930 forgetMemoizedResults(SymbolicName);
5931 insertValueToMap(PN, Shifted);
5932 return Shifted;
5933 }
5934 }
5935 }
5936 }
5937
5938 // Remove the temporary PHI node SCEV that has been inserted while intending
5939 // to create an AddRecExpr for this PHI node. We can not keep this temporary
5940 // as it will prevent later (possibly simpler) SCEV expressions to be added
5941 // to the ValueExprMap.
5942 eraseValueFromMap(PN);
5943
5944 return nullptr;
5945}
5946
5947// Try to match a control flow sequence that branches out at BI and merges back
5948// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5949// match.
5951 Value *&C, Value *&LHS, Value *&RHS) {
5952 C = BI->getCondition();
5953
5954 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5955 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5956
5957 if (!LeftEdge.isSingleEdge())
5958 return false;
5959
5960 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5961
5962 Use &LeftUse = Merge->getOperandUse(0);
5963 Use &RightUse = Merge->getOperandUse(1);
5964
5965 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5966 LHS = LeftUse;
5967 RHS = RightUse;
5968 return true;
5969 }
5970
5971 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5972 LHS = RightUse;
5973 RHS = LeftUse;
5974 return true;
5975 }
5976
5977 return false;
5978}
5979
5980const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5981 auto IsReachable =
5982 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5983 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5984 // Try to match
5985 //
5986 // br %cond, label %left, label %right
5987 // left:
5988 // br label %merge
5989 // right:
5990 // br label %merge
5991 // merge:
5992 // V = phi [ %x, %left ], [ %y, %right ]
5993 //
5994 // as "select %cond, %x, %y"
5995
5996 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5997 assert(IDom && "At least the entry block should dominate PN");
5998
5999 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6000 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6001
6002 if (BI && BI->isConditional() &&
6003 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6004 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6005 properlyDominates(getSCEV(RHS), PN->getParent()))
6006 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6007 }
6008
6009 return nullptr;
6010}
6011
6012/// Returns SCEV for the first operand of a phi if all phi operands have
6013/// identical opcodes and operands
6014/// eg.
6015/// a: %add = %a + %b
6016/// br %c
6017/// b: %add1 = %a + %b
6018/// br %c
6019/// c: %phi = phi [%add, a], [%add1, b]
6020/// scev(%phi) => scev(%add)
6021const SCEV *
6022ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6023 BinaryOperator *CommonInst = nullptr;
6024 // Check if instructions are identical.
6025 for (Value *Incoming : PN->incoming_values()) {
6026 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6027 if (!IncomingInst)
6028 return nullptr;
6029 if (CommonInst) {
6030 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6031 return nullptr; // Not identical, give up
6032 } else {
6033 // Remember binary operator
6034 CommonInst = IncomingInst;
6035 }
6036 }
6037 if (!CommonInst)
6038 return nullptr;
6039
6040 // Check if SCEV exprs for instructions are identical.
6041 const SCEV *CommonSCEV = getSCEV(CommonInst);
6042 bool SCEVExprsIdentical =
6044 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6045 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6046}
6047
6048const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6049 if (const SCEV *S = createAddRecFromPHI(PN))
6050 return S;
6051
6052 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6053 // phi node for X.
6054 if (Value *V = simplifyInstruction(
6055 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6056 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6057 return getSCEV(V);
6058
6059 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6060 return S;
6061
6062 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6063 return S;
6064
6065 // If it's not a loop phi, we can't handle it yet.
6066 return getUnknown(PN);
6067}
6068
6069bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6070 SCEVTypes RootKind) {
6071 struct FindClosure {
6072 const SCEV *OperandToFind;
6073 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6074 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6075
6076 bool Found = false;
6077
6078 bool canRecurseInto(SCEVTypes Kind) const {
6079 // We can only recurse into the SCEV expression of the same effective type
6080 // as the type of our root SCEV expression, and into zero-extensions.
6081 return RootKind == Kind || NonSequentialRootKind == Kind ||
6082 scZeroExtend == Kind;
6083 };
6084
6085 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6086 : OperandToFind(OperandToFind), RootKind(RootKind),
6087 NonSequentialRootKind(
6089 RootKind)) {}
6090
6091 bool follow(const SCEV *S) {
6092 Found = S == OperandToFind;
6093
6094 return !isDone() && canRecurseInto(S->getSCEVType());
6095 }
6096
6097 bool isDone() const { return Found; }
6098 };
6099
6100 FindClosure FC(OperandToFind, RootKind);
6101 visitAll(Root, FC);
6102 return FC.Found;
6103}
6104
6105std::optional<const SCEV *>
6106ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6107 ICmpInst *Cond,
6108 Value *TrueVal,
6109 Value *FalseVal) {
6110 // Try to match some simple smax or umax patterns.
6111 auto *ICI = Cond;
6112
6113 Value *LHS = ICI->getOperand(0);
6114 Value *RHS = ICI->getOperand(1);
6115
6116 switch (ICI->getPredicate()) {
6117 case ICmpInst::ICMP_SLT:
6118 case ICmpInst::ICMP_SLE:
6119 case ICmpInst::ICMP_ULT:
6120 case ICmpInst::ICMP_ULE:
6121 std::swap(LHS, RHS);
6122 [[fallthrough]];
6123 case ICmpInst::ICMP_SGT:
6124 case ICmpInst::ICMP_SGE:
6125 case ICmpInst::ICMP_UGT:
6126 case ICmpInst::ICMP_UGE:
6127 // a > b ? a+x : b+x -> max(a, b)+x
6128 // a > b ? b+x : a+x -> min(a, b)+x
6130 bool Signed = ICI->isSigned();
6131 const SCEV *LA = getSCEV(TrueVal);
6132 const SCEV *RA = getSCEV(FalseVal);
6133 const SCEV *LS = getSCEV(LHS);
6134 const SCEV *RS = getSCEV(RHS);
6135 if (LA->getType()->isPointerTy()) {
6136 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6137 // Need to make sure we can't produce weird expressions involving
6138 // negated pointers.
6139 if (LA == LS && RA == RS)
6140 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6141 if (LA == RS && RA == LS)
6142 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6143 }
6144 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6145 if (Op->getType()->isPointerTy()) {
6147 if (isa<SCEVCouldNotCompute>(Op))
6148 return Op;
6149 }
6150 if (Signed)
6151 Op = getNoopOrSignExtend(Op, Ty);
6152 else
6153 Op = getNoopOrZeroExtend(Op, Ty);
6154 return Op;
6155 };
6156 LS = CoerceOperand(LS);
6157 RS = CoerceOperand(RS);
6158 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6159 break;
6160 const SCEV *LDiff = getMinusSCEV(LA, LS);
6161 const SCEV *RDiff = getMinusSCEV(RA, RS);
6162 if (LDiff == RDiff)
6163 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6164 LDiff);
6165 LDiff = getMinusSCEV(LA, RS);
6166 RDiff = getMinusSCEV(RA, LS);
6167 if (LDiff == RDiff)
6168 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6169 LDiff);
6170 }
6171 break;
6172 case ICmpInst::ICMP_NE:
6173 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6174 std::swap(TrueVal, FalseVal);
6175 [[fallthrough]];
6176 case ICmpInst::ICMP_EQ:
6177 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6179 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6180 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6181 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6182 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6183 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6184 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6185 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6186 return getAddExpr(getUMaxExpr(X, C), Y);
6187 }
6188 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6189 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6190 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6191 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6192 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6193 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6194 const SCEV *X = getSCEV(LHS);
6195 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6196 X = ZExt->getOperand();
6197 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6198 const SCEV *FalseValExpr = getSCEV(FalseVal);
6199 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6200 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6201 /*Sequential=*/true);
6202 }
6203 }
6204 break;
6205 default:
6206 break;
6207 }
6208
6209 return std::nullopt;
6210}
6211
6212static std::optional<const SCEV *>
6214 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6215 assert(CondExpr->getType()->isIntegerTy(1) &&
6216 TrueExpr->getType() == FalseExpr->getType() &&
6217 TrueExpr->getType()->isIntegerTy(1) &&
6218 "Unexpected operands of a select.");
6219
6220 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6221 // --> C + (umin_seq cond, x - C)
6222 //
6223 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6224 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6225 // --> C + (umin_seq ~cond, x - C)
6226
6227 // FIXME: while we can't legally model the case where both of the hands
6228 // are fully variable, we only require that the *difference* is constant.
6229 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6230 return std::nullopt;
6231
6232 const SCEV *X, *C;
6233 if (isa<SCEVConstant>(TrueExpr)) {
6234 CondExpr = SE->getNotSCEV(CondExpr);
6235 X = FalseExpr;
6236 C = TrueExpr;
6237 } else {
6238 X = TrueExpr;
6239 C = FalseExpr;
6240 }
6241 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6242 /*Sequential=*/true));
6243}
6244
6245static std::optional<const SCEV *>
6247 Value *FalseVal) {
6248 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6249 return std::nullopt;
6250
6251 const auto *SECond = SE->getSCEV(Cond);
6252 const auto *SETrue = SE->getSCEV(TrueVal);
6253 const auto *SEFalse = SE->getSCEV(FalseVal);
6254 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6255}
6256
6257const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6258 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6259 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6260 assert(TrueVal->getType() == FalseVal->getType() &&
6261 V->getType() == TrueVal->getType() &&
6262 "Types of select hands and of the result must match.");
6263
6264 // For now, only deal with i1-typed `select`s.
6265 if (!V->getType()->isIntegerTy(1))
6266 return getUnknown(V);
6267
6268 if (std::optional<const SCEV *> S =
6269 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6270 return *S;
6271
6272 return getUnknown(V);
6273}
6274
6275const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6276 Value *TrueVal,
6277 Value *FalseVal) {
6278 // Handle "constant" branch or select. This can occur for instance when a
6279 // loop pass transforms an inner loop and moves on to process the outer loop.
6280 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6281 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6282
6283 if (auto *I = dyn_cast<Instruction>(V)) {
6284 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6285 if (std::optional<const SCEV *> S =
6286 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6287 TrueVal, FalseVal))
6288 return *S;
6289 }
6290 }
6291
6292 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6293}
6294
6295/// Expand GEP instructions into add and multiply operations. This allows them
6296/// to be analyzed by regular SCEV code.
6297const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6298 assert(GEP->getSourceElementType()->isSized() &&
6299 "GEP source element type must be sized");
6300
6302 for (Value *Index : GEP->indices())
6303 IndexExprs.push_back(getSCEV(Index));
6304 return getGEPExpr(GEP, IndexExprs);
6305}
6306
6307APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6309 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6310 return TrailingZeros >= BitWidth
6312 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6313 };
6314 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6315 // The result is GCD of all operands results.
6316 APInt Res = getConstantMultiple(N->getOperand(0));
6317 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6319 Res, getConstantMultiple(N->getOperand(I)));
6320 return Res;
6321 };
6322
6323 switch (S->getSCEVType()) {
6324 case scConstant:
6325 return cast<SCEVConstant>(S)->getAPInt();
6326 case scPtrToInt:
6327 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6328 case scUDivExpr:
6329 case scVScale:
6330 return APInt(BitWidth, 1);
6331 case scTruncate: {
6332 // Only multiples that are a power of 2 will hold after truncation.
6333 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6334 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6335 return GetShiftedByZeros(TZ);
6336 }
6337 case scZeroExtend: {
6338 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6339 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6340 }
6341 case scSignExtend: {
6342 // Only multiples that are a power of 2 will hold after sext.
6343 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6345 return GetShiftedByZeros(TZ);
6346 }
6347 case scMulExpr: {
6348 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6349 if (M->hasNoUnsignedWrap()) {
6350 // The result is the product of all operand results.
6351 APInt Res = getConstantMultiple(M->getOperand(0));
6352 for (const SCEV *Operand : M->operands().drop_front())
6353 Res = Res * getConstantMultiple(Operand);
6354 return Res;
6355 }
6356
6357 // If there are no wrap guarentees, find the trailing zeros, which is the
6358 // sum of trailing zeros for all its operands.
6359 uint32_t TZ = 0;
6360 for (const SCEV *Operand : M->operands())
6361 TZ += getMinTrailingZeros(Operand);
6362 return GetShiftedByZeros(TZ);
6363 }
6364 case scAddExpr:
6365 case scAddRecExpr: {
6366 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6367 if (N->hasNoUnsignedWrap())
6368 return GetGCDMultiple(N);
6369 // Find the trailing bits, which is the minimum of its operands.
6370 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6371 for (const SCEV *Operand : N->operands().drop_front())
6372 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6373 return GetShiftedByZeros(TZ);
6374 }
6375 case scUMaxExpr:
6376 case scSMaxExpr:
6377 case scUMinExpr:
6378 case scSMinExpr:
6380 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6381 case scUnknown: {
6382 // ask ValueTracking for known bits
6383 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6384 unsigned Known =
6385 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6386 .countMinTrailingZeros();
6387 return GetShiftedByZeros(Known);
6388 }
6389 case scCouldNotCompute:
6390 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6391 }
6392 llvm_unreachable("Unknown SCEV kind!");
6393}
6394
6396 auto I = ConstantMultipleCache.find(S);
6397 if (I != ConstantMultipleCache.end())
6398 return I->second;
6399
6400 APInt Result = getConstantMultipleImpl(S);
6401 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6402 assert(InsertPair.second && "Should insert a new key");
6403 return InsertPair.first->second;
6404}
6405
6407 APInt Multiple = getConstantMultiple(S);
6408 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6409}
6410
6412 return std::min(getConstantMultiple(S).countTrailingZeros(),
6413 (unsigned)getTypeSizeInBits(S->getType()));
6414}
6415
6416/// Helper method to assign a range to V from metadata present in the IR.
6417static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6418 if (Instruction *I = dyn_cast<Instruction>(V)) {
6419 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6420 return getConstantRangeFromMetadata(*MD);
6421 if (const auto *CB = dyn_cast<CallBase>(V))
6422 if (std::optional<ConstantRange> Range = CB->getRange())
6423 return Range;
6424 }
6425 if (auto *A = dyn_cast<Argument>(V))
6426 if (std::optional<ConstantRange> Range = A->getRange())
6427 return Range;
6428
6429 return std::nullopt;
6430}
6431
6433 SCEV::NoWrapFlags Flags) {
6434 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6435 AddRec->setNoWrapFlags(Flags);
6436 UnsignedRanges.erase(AddRec);
6437 SignedRanges.erase(AddRec);
6438 ConstantMultipleCache.erase(AddRec);
6439 }
6440}
6441
6442ConstantRange ScalarEvolution::
6443getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6444 const DataLayout &DL = getDataLayout();
6445
6446 unsigned BitWidth = getTypeSizeInBits(U->getType());
6447 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6448
6449 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6450 // use information about the trip count to improve our available range. Note
6451 // that the trip count independent cases are already handled by known bits.
6452 // WARNING: The definition of recurrence used here is subtly different than
6453 // the one used by AddRec (and thus most of this file). Step is allowed to
6454 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6455 // and other addrecs in the same loop (for non-affine addrecs). The code
6456 // below intentionally handles the case where step is not loop invariant.
6457 auto *P = dyn_cast<PHINode>(U->getValue());
6458 if (!P)
6459 return FullSet;
6460
6461 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6462 // even the values that are not available in these blocks may come from them,
6463 // and this leads to false-positive recurrence test.
6464 for (auto *Pred : predecessors(P->getParent()))
6465 if (!DT.isReachableFromEntry(Pred))
6466 return FullSet;
6467
6468 BinaryOperator *BO;
6469 Value *Start, *Step;
6470 if (!matchSimpleRecurrence(P, BO, Start, Step))
6471 return FullSet;
6472
6473 // If we found a recurrence in reachable code, we must be in a loop. Note
6474 // that BO might be in some subloop of L, and that's completely okay.
6475 auto *L = LI.getLoopFor(P->getParent());
6476 assert(L && L->getHeader() == P->getParent());
6477 if (!L->contains(BO->getParent()))
6478 // NOTE: This bailout should be an assert instead. However, asserting
6479 // the condition here exposes a case where LoopFusion is querying SCEV
6480 // with malformed loop information during the midst of the transform.
6481 // There doesn't appear to be an obvious fix, so for the moment bailout
6482 // until the caller issue can be fixed. PR49566 tracks the bug.
6483 return FullSet;
6484
6485 // TODO: Extend to other opcodes such as mul, and div
6486 switch (BO->getOpcode()) {
6487 default:
6488 return FullSet;
6489 case Instruction::AShr:
6490 case Instruction::LShr:
6491 case Instruction::Shl:
6492 break;
6493 };
6494
6495 if (BO->getOperand(0) != P)
6496 // TODO: Handle the power function forms some day.
6497 return FullSet;
6498
6499 unsigned TC = getSmallConstantMaxTripCount(L);
6500 if (!TC || TC >= BitWidth)
6501 return FullSet;
6502
6503 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6504 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6505 assert(KnownStart.getBitWidth() == BitWidth &&
6506 KnownStep.getBitWidth() == BitWidth);
6507
6508 // Compute total shift amount, being careful of overflow and bitwidths.
6509 auto MaxShiftAmt = KnownStep.getMaxValue();
6510 APInt TCAP(BitWidth, TC-1);
6511 bool Overflow = false;
6512 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6513 if (Overflow)
6514 return FullSet;
6515
6516 switch (BO->getOpcode()) {
6517 default:
6518 llvm_unreachable("filtered out above");
6519 case Instruction::AShr: {
6520 // For each ashr, three cases:
6521 // shift = 0 => unchanged value
6522 // saturation => 0 or -1
6523 // other => a value closer to zero (of the same sign)
6524 // Thus, the end value is closer to zero than the start.
6525 auto KnownEnd = KnownBits::ashr(KnownStart,
6526 KnownBits::makeConstant(TotalShift));
6527 if (KnownStart.isNonNegative())
6528 // Analogous to lshr (simply not yet canonicalized)
6529 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6530 KnownStart.getMaxValue() + 1);
6531 if (KnownStart.isNegative())
6532 // End >=u Start && End <=s Start
6533 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6534 KnownEnd.getMaxValue() + 1);
6535 break;
6536 }
6537 case Instruction::LShr: {
6538 // For each lshr, three cases:
6539 // shift = 0 => unchanged value
6540 // saturation => 0
6541 // other => a smaller positive number
6542 // Thus, the low end of the unsigned range is the last value produced.
6543 auto KnownEnd = KnownBits::lshr(KnownStart,
6544 KnownBits::makeConstant(TotalShift));
6545 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6546 KnownStart.getMaxValue() + 1);
6547 }
6548 case Instruction::Shl: {
6549 // Iff no bits are shifted out, value increases on every shift.
6550 auto KnownEnd = KnownBits::shl(KnownStart,
6551 KnownBits::makeConstant(TotalShift));
6552 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6553 return ConstantRange(KnownStart.getMinValue(),
6554 KnownEnd.getMaxValue() + 1);
6555 break;
6556 }
6557 };
6558 return FullSet;
6559}
6560
6561const ConstantRange &
6562ScalarEvolution::getRangeRefIter(const SCEV *S,
6563 ScalarEvolution::RangeSignHint SignHint) {
6565 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6566 : SignedRanges;
6569
6570 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6571 // SCEVUnknown PHI node.
6572 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6573 if (!Seen.insert(Expr).second)
6574 return;
6575 if (Cache.contains(Expr))
6576 return;
6577 switch (Expr->getSCEVType()) {
6578 case scUnknown:
6579 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6580 break;
6581 [[fallthrough]];
6582 case scConstant:
6583 case scVScale:
6584 case scTruncate:
6585 case scZeroExtend:
6586 case scSignExtend:
6587 case scPtrToInt:
6588 case scAddExpr:
6589 case scMulExpr:
6590 case scUDivExpr:
6591 case scAddRecExpr:
6592 case scUMaxExpr:
6593 case scSMaxExpr:
6594 case scUMinExpr:
6595 case scSMinExpr:
6597 WorkList.push_back(Expr);
6598 break;
6599 case scCouldNotCompute:
6600 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6601 }
6602 };
6603 AddToWorklist(S);
6604
6605 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6606 for (unsigned I = 0; I != WorkList.size(); ++I) {
6607 const SCEV *P = WorkList[I];
6608 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6609 // If it is not a `SCEVUnknown`, just recurse into operands.
6610 if (!UnknownS) {
6611 for (const SCEV *Op : P->operands())
6612 AddToWorklist(Op);
6613 continue;
6614 }
6615 // `SCEVUnknown`'s require special treatment.
6616 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6617 if (!PendingPhiRangesIter.insert(P).second)
6618 continue;
6619 for (auto &Op : reverse(P->operands()))
6620 AddToWorklist(getSCEV(Op));
6621 }
6622 }
6623
6624 if (!WorkList.empty()) {
6625 // Use getRangeRef to compute ranges for items in the worklist in reverse
6626 // order. This will force ranges for earlier operands to be computed before
6627 // their users in most cases.
6628 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6629 getRangeRef(P, SignHint);
6630
6631 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6632 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6633 PendingPhiRangesIter.erase(P);
6634 }
6635 }
6636
6637 return getRangeRef(S, SignHint, 0);
6638}
6639
6640/// Determine the range for a particular SCEV. If SignHint is
6641/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6642/// with a "cleaner" unsigned (resp. signed) representation.
6643const ConstantRange &ScalarEvolution::getRangeRef(
6644 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6646 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6647 : SignedRanges;
6649 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6651
6652 // See if we've computed this range already.
6654 if (I != Cache.end())
6655 return I->second;
6656
6657 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6658 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6659
6660 // Switch to iteratively computing the range for S, if it is part of a deeply
6661 // nested expression.
6663 return getRangeRefIter(S, SignHint);
6664
6665 unsigned BitWidth = getTypeSizeInBits(S->getType());
6666 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6667 using OBO = OverflowingBinaryOperator;
6668
6669 // If the value has known zeros, the maximum value will have those known zeros
6670 // as well.
6671 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6672 APInt Multiple = getNonZeroConstantMultiple(S);
6673 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6674 if (!Remainder.isZero())
6675 ConservativeResult =
6677 APInt::getMaxValue(BitWidth) - Remainder + 1);
6678 }
6679 else {
6681 if (TZ != 0) {
6682 ConservativeResult = ConstantRange(
6684 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6685 }
6686 }
6687
6688 switch (S->getSCEVType()) {
6689 case scConstant:
6690 llvm_unreachable("Already handled above.");
6691 case scVScale:
6692 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6693 case scTruncate: {
6694 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6695 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6696 return setRange(
6697 Trunc, SignHint,
6698 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6699 }
6700 case scZeroExtend: {
6701 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6702 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6703 return setRange(
6704 ZExt, SignHint,
6705 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6706 }
6707 case scSignExtend: {
6708 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6709 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6710 return setRange(
6711 SExt, SignHint,
6712 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6713 }
6714 case scPtrToInt: {
6715 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6716 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6717 return setRange(PtrToInt, SignHint, X);
6718 }
6719 case scAddExpr: {
6720 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6721 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6722 unsigned WrapType = OBO::AnyWrap;
6723 if (Add->hasNoSignedWrap())
6724 WrapType |= OBO::NoSignedWrap;
6725 if (Add->hasNoUnsignedWrap())
6726 WrapType |= OBO::NoUnsignedWrap;
6727 for (const SCEV *Op : drop_begin(Add->operands()))
6728 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6729 RangeType);
6730 return setRange(Add, SignHint,
6731 ConservativeResult.intersectWith(X, RangeType));
6732 }
6733 case scMulExpr: {
6734 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6735 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6736 for (const SCEV *Op : drop_begin(Mul->operands()))
6737 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6738 return setRange(Mul, SignHint,
6739 ConservativeResult.intersectWith(X, RangeType));
6740 }
6741 case scUDivExpr: {
6742 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6743 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6744 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6745 return setRange(UDiv, SignHint,
6746 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6747 }
6748 case scAddRecExpr: {
6749 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6750 // If there's no unsigned wrap, the value will never be less than its
6751 // initial value.
6752 if (AddRec->hasNoUnsignedWrap()) {
6753 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6754 if (!UnsignedMinValue.isZero())
6755 ConservativeResult = ConservativeResult.intersectWith(
6756 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6757 }
6758
6759 // If there's no signed wrap, and all the operands except initial value have
6760 // the same sign or zero, the value won't ever be:
6761 // 1: smaller than initial value if operands are non negative,
6762 // 2: bigger than initial value if operands are non positive.
6763 // For both cases, value can not cross signed min/max boundary.
6764 if (AddRec->hasNoSignedWrap()) {
6765 bool AllNonNeg = true;
6766 bool AllNonPos = true;
6767 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6768 if (!isKnownNonNegative(AddRec->getOperand(i)))
6769 AllNonNeg = false;
6770 if (!isKnownNonPositive(AddRec->getOperand(i)))
6771 AllNonPos = false;
6772 }
6773 if (AllNonNeg)
6774 ConservativeResult = ConservativeResult.intersectWith(
6777 RangeType);
6778 else if (AllNonPos)
6779 ConservativeResult = ConservativeResult.intersectWith(
6781 getSignedRangeMax(AddRec->getStart()) +
6782 1),
6783 RangeType);
6784 }
6785
6786 // TODO: non-affine addrec
6787 if (AddRec->isAffine()) {
6788 const SCEV *MaxBEScev =
6790 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6791 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6792
6793 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6794 // MaxBECount's active bits are all <= AddRec's bit width.
6795 if (MaxBECount.getBitWidth() > BitWidth &&
6796 MaxBECount.getActiveBits() <= BitWidth)
6797 MaxBECount = MaxBECount.trunc(BitWidth);
6798 else if (MaxBECount.getBitWidth() < BitWidth)
6799 MaxBECount = MaxBECount.zext(BitWidth);
6800
6801 if (MaxBECount.getBitWidth() == BitWidth) {
6802 auto RangeFromAffine = getRangeForAffineAR(
6803 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6804 ConservativeResult =
6805 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6806
6807 auto RangeFromFactoring = getRangeViaFactoring(
6808 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6809 ConservativeResult =
6810 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6811 }
6812 }
6813
6814 // Now try symbolic BE count and more powerful methods.
6816 const SCEV *SymbolicMaxBECount =
6818 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6819 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6820 AddRec->hasNoSelfWrap()) {
6821 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6822 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6823 ConservativeResult =
6824 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6825 }
6826 }
6827 }
6828
6829 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6830 }
6831 case scUMaxExpr:
6832 case scSMaxExpr:
6833 case scUMinExpr:
6834 case scSMinExpr:
6835 case scSequentialUMinExpr: {
6837 switch (S->getSCEVType()) {
6838 case scUMaxExpr:
6839 ID = Intrinsic::umax;
6840 break;
6841 case scSMaxExpr:
6842 ID = Intrinsic::smax;
6843 break;
6844 case scUMinExpr:
6846 ID = Intrinsic::umin;
6847 break;
6848 case scSMinExpr:
6849 ID = Intrinsic::smin;
6850 break;
6851 default:
6852 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6853 }
6854
6855 const auto *NAry = cast<SCEVNAryExpr>(S);
6856 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6857 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6858 X = X.intrinsic(
6859 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6860 return setRange(S, SignHint,
6861 ConservativeResult.intersectWith(X, RangeType));
6862 }
6863 case scUnknown: {
6864 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6865 Value *V = U->getValue();
6866
6867 // Check if the IR explicitly contains !range metadata.
6868 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6869 if (MDRange)
6870 ConservativeResult =
6871 ConservativeResult.intersectWith(*MDRange, RangeType);
6872
6873 // Use facts about recurrences in the underlying IR. Note that add
6874 // recurrences are AddRecExprs and thus don't hit this path. This
6875 // primarily handles shift recurrences.
6876 auto CR = getRangeForUnknownRecurrence(U);
6877 ConservativeResult = ConservativeResult.intersectWith(CR);
6878
6879 // See if ValueTracking can give us a useful range.
6880 const DataLayout &DL = getDataLayout();
6881 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6882 if (Known.getBitWidth() != BitWidth)
6883 Known = Known.zextOrTrunc(BitWidth);
6884
6885 // ValueTracking may be able to compute a tighter result for the number of
6886 // sign bits than for the value of those sign bits.
6887 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6888 if (U->getType()->isPointerTy()) {
6889 // If the pointer size is larger than the index size type, this can cause
6890 // NS to be larger than BitWidth. So compensate for this.
6891 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6892 int ptrIdxDiff = ptrSize - BitWidth;
6893 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6894 NS -= ptrIdxDiff;
6895 }
6896
6897 if (NS > 1) {
6898 // If we know any of the sign bits, we know all of the sign bits.
6899 if (!Known.Zero.getHiBits(NS).isZero())
6900 Known.Zero.setHighBits(NS);
6901 if (!Known.One.getHiBits(NS).isZero())
6902 Known.One.setHighBits(NS);
6903 }
6904
6905 if (Known.getMinValue() != Known.getMaxValue() + 1)
6906 ConservativeResult = ConservativeResult.intersectWith(
6907 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6908 RangeType);
6909 if (NS > 1)
6910 ConservativeResult = ConservativeResult.intersectWith(
6912 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6913 RangeType);
6914
6915 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6916 // Strengthen the range if the underlying IR value is a
6917 // global/alloca/heap allocation using the size of the object.
6918 bool CanBeNull, CanBeFreed;
6919 uint64_t DerefBytes =
6920 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6921 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6922 // The highest address the object can start is DerefBytes bytes before
6923 // the end (unsigned max value). If this value is not a multiple of the
6924 // alignment, the last possible start value is the next lowest multiple
6925 // of the alignment. Note: The computations below cannot overflow,
6926 // because if they would there's no possible start address for the
6927 // object.
6928 APInt MaxVal =
6929 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6930 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6931 uint64_t Rem = MaxVal.urem(Align);
6932 MaxVal -= APInt(BitWidth, Rem);
6933 APInt MinVal = APInt::getZero(BitWidth);
6934 if (llvm::isKnownNonZero(V, DL))
6935 MinVal = Align;
6936 ConservativeResult = ConservativeResult.intersectWith(
6937 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6938 }
6939 }
6940
6941 // A range of Phi is a subset of union of all ranges of its input.
6942 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6943 // Make sure that we do not run over cycled Phis.
6944 if (PendingPhiRanges.insert(Phi).second) {
6945 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6946
6947 for (const auto &Op : Phi->operands()) {
6948 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6949 RangeFromOps = RangeFromOps.unionWith(OpRange);
6950 // No point to continue if we already have a full set.
6951 if (RangeFromOps.isFullSet())
6952 break;
6953 }
6954 ConservativeResult =
6955 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6956 bool Erased = PendingPhiRanges.erase(Phi);
6957 assert(Erased && "Failed to erase Phi properly?");
6958 (void)Erased;
6959 }
6960 }
6961
6962 // vscale can't be equal to zero
6963 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6964 if (II->getIntrinsicID() == Intrinsic::vscale) {
6966 ConservativeResult = ConservativeResult.difference(Disallowed);
6967 }
6968
6969 return setRange(U, SignHint, std::move(ConservativeResult));
6970 }
6971 case scCouldNotCompute:
6972 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6973 }
6974
6975 return setRange(S, SignHint, std::move(ConservativeResult));
6976}
6977
6978// Given a StartRange, Step and MaxBECount for an expression compute a range of
6979// values that the expression can take. Initially, the expression has a value
6980// from StartRange and then is changed by Step up to MaxBECount times. Signed
6981// argument defines if we treat Step as signed or unsigned.
6983 const ConstantRange &StartRange,
6984 const APInt &MaxBECount,
6985 bool Signed) {
6986 unsigned BitWidth = Step.getBitWidth();
6987 assert(BitWidth == StartRange.getBitWidth() &&
6988 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6989 // If either Step or MaxBECount is 0, then the expression won't change, and we
6990 // just need to return the initial range.
6991 if (Step == 0 || MaxBECount == 0)
6992 return StartRange;
6993
6994 // If we don't know anything about the initial value (i.e. StartRange is
6995 // FullRange), then we don't know anything about the final range either.
6996 // Return FullRange.
6997 if (StartRange.isFullSet())
6998 return ConstantRange::getFull(BitWidth);
6999
7000 // If Step is signed and negative, then we use its absolute value, but we also
7001 // note that we're moving in the opposite direction.
7002 bool Descending = Signed && Step.isNegative();
7003
7004 if (Signed)
7005 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7006 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7007 // This equations hold true due to the well-defined wrap-around behavior of
7008 // APInt.
7009 Step = Step.abs();
7010
7011 // Check if Offset is more than full span of BitWidth. If it is, the
7012 // expression is guaranteed to overflow.
7013 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7014 return ConstantRange::getFull(BitWidth);
7015
7016 // Offset is by how much the expression can change. Checks above guarantee no
7017 // overflow here.
7018 APInt Offset = Step * MaxBECount;
7019
7020 // Minimum value of the final range will match the minimal value of StartRange
7021 // if the expression is increasing and will be decreased by Offset otherwise.
7022 // Maximum value of the final range will match the maximal value of StartRange
7023 // if the expression is decreasing and will be increased by Offset otherwise.
7024 APInt StartLower = StartRange.getLower();
7025 APInt StartUpper = StartRange.getUpper() - 1;
7026 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7027 : (StartUpper + std::move(Offset));
7028
7029 // It's possible that the new minimum/maximum value will fall into the initial
7030 // range (due to wrap around). This means that the expression can take any
7031 // value in this bitwidth, and we have to return full range.
7032 if (StartRange.contains(MovedBoundary))
7033 return ConstantRange::getFull(BitWidth);
7034
7035 APInt NewLower =
7036 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7037 APInt NewUpper =
7038 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7039 NewUpper += 1;
7040
7041 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7042 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7043}
7044
7045ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7046 const SCEV *Step,
7047 const APInt &MaxBECount) {
7048 assert(getTypeSizeInBits(Start->getType()) ==
7049 getTypeSizeInBits(Step->getType()) &&
7050 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7051 "mismatched bit widths");
7052
7053 // First, consider step signed.
7054 ConstantRange StartSRange = getSignedRange(Start);
7055 ConstantRange StepSRange = getSignedRange(Step);
7056
7057 // If Step can be both positive and negative, we need to find ranges for the
7058 // maximum absolute step values in both directions and union them.
7060 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7062 StartSRange, MaxBECount,
7063 /* Signed = */ true));
7064
7065 // Next, consider step unsigned.
7067 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7068 /* Signed = */ false);
7069
7070 // Finally, intersect signed and unsigned ranges.
7072}
7073
7074ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7075 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7076 ScalarEvolution::RangeSignHint SignHint) {
7077 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7078 assert(AddRec->hasNoSelfWrap() &&
7079 "This only works for non-self-wrapping AddRecs!");
7080 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7081 const SCEV *Step = AddRec->getStepRecurrence(*this);
7082 // Only deal with constant step to save compile time.
7083 if (!isa<SCEVConstant>(Step))
7084 return ConstantRange::getFull(BitWidth);
7085 // Let's make sure that we can prove that we do not self-wrap during
7086 // MaxBECount iterations. We need this because MaxBECount is a maximum
7087 // iteration count estimate, and we might infer nw from some exit for which we
7088 // do not know max exit count (or any other side reasoning).
7089 // TODO: Turn into assert at some point.
7090 if (getTypeSizeInBits(MaxBECount->getType()) >
7091 getTypeSizeInBits(AddRec->getType()))
7092 return ConstantRange::getFull(BitWidth);
7093 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7094 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7095 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7096 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7097 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7098 MaxItersWithoutWrap))
7099 return ConstantRange::getFull(BitWidth);
7100
7101 ICmpInst::Predicate LEPred =
7103 ICmpInst::Predicate GEPred =
7105 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7106
7107 // We know that there is no self-wrap. Let's take Start and End values and
7108 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7109 // the iteration. They either lie inside the range [Min(Start, End),
7110 // Max(Start, End)] or outside it:
7111 //
7112 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7113 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7114 //
7115 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7116 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7117 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7118 // Start <= End and step is positive, or Start >= End and step is negative.
7119 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7120 ConstantRange StartRange = getRangeRef(Start, SignHint);
7121 ConstantRange EndRange = getRangeRef(End, SignHint);
7122 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7123 // If they already cover full iteration space, we will know nothing useful
7124 // even if we prove what we want to prove.
7125 if (RangeBetween.isFullSet())
7126 return RangeBetween;
7127 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7128 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7129 : RangeBetween.isWrappedSet();
7130 if (IsWrappedSet)
7131 return ConstantRange::getFull(BitWidth);
7132
7133 if (isKnownPositive(Step) &&
7134 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7135 return RangeBetween;
7136 if (isKnownNegative(Step) &&
7137 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7138 return RangeBetween;
7139 return ConstantRange::getFull(BitWidth);
7140}
7141
7142ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7143 const SCEV *Step,
7144 const APInt &MaxBECount) {
7145 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7146 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7147
7148 unsigned BitWidth = MaxBECount.getBitWidth();
7149 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7150 getTypeSizeInBits(Step->getType()) == BitWidth &&
7151 "mismatched bit widths");
7152
7153 struct SelectPattern {
7154 Value *Condition = nullptr;
7155 APInt TrueValue;
7156 APInt FalseValue;
7157
7158 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7159 const SCEV *S) {
7160 std::optional<unsigned> CastOp;
7161 APInt Offset(BitWidth, 0);
7162
7164 "Should be!");
7165
7166 // Peel off a constant offset:
7167 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7168 // In the future we could consider being smarter here and handle
7169 // {Start+Step,+,Step} too.
7170 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7171 return;
7172
7173 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7174 S = SA->getOperand(1);
7175 }
7176
7177 // Peel off a cast operation
7178 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7179 CastOp = SCast->getSCEVType();
7180 S = SCast->getOperand();
7181 }
7182
7183 using namespace llvm::PatternMatch;
7184
7185 auto *SU = dyn_cast<SCEVUnknown>(S);
7186 const APInt *TrueVal, *FalseVal;
7187 if (!SU ||
7188 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7189 m_APInt(FalseVal)))) {
7190 Condition = nullptr;
7191 return;
7192 }
7193
7194 TrueValue = *TrueVal;
7195 FalseValue = *FalseVal;
7196
7197 // Re-apply the cast we peeled off earlier
7198 if (CastOp)
7199 switch (*CastOp) {
7200 default:
7201 llvm_unreachable("Unknown SCEV cast type!");
7202
7203 case scTruncate:
7204 TrueValue = TrueValue.trunc(BitWidth);
7205 FalseValue = FalseValue.trunc(BitWidth);
7206 break;
7207 case scZeroExtend:
7208 TrueValue = TrueValue.zext(BitWidth);
7209 FalseValue = FalseValue.zext(BitWidth);
7210 break;
7211 case scSignExtend:
7212 TrueValue = TrueValue.sext(BitWidth);
7213 FalseValue = FalseValue.sext(BitWidth);
7214 break;
7215 }
7216
7217 // Re-apply the constant offset we peeled off earlier
7218 TrueValue += Offset;
7219 FalseValue += Offset;
7220 }
7221
7222 bool isRecognized() { return Condition != nullptr; }
7223 };
7224
7225 SelectPattern StartPattern(*this, BitWidth, Start);
7226 if (!StartPattern.isRecognized())
7227 return ConstantRange::getFull(BitWidth);
7228
7229 SelectPattern StepPattern(*this, BitWidth, Step);
7230 if (!StepPattern.isRecognized())
7231 return ConstantRange::getFull(BitWidth);
7232
7233 if (StartPattern.Condition != StepPattern.Condition) {
7234 // We don't handle this case today; but we could, by considering four
7235 // possibilities below instead of two. I'm not sure if there are cases where
7236 // that will help over what getRange already does, though.
7237 return ConstantRange::getFull(BitWidth);
7238 }
7239
7240 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7241 // construct arbitrary general SCEV expressions here. This function is called
7242 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7243 // say) can end up caching a suboptimal value.
7244
7245 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7246 // C2352 and C2512 (otherwise it isn't needed).
7247
7248 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7249 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7250 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7251 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7252
7253 ConstantRange TrueRange =
7254 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7255 ConstantRange FalseRange =
7256 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7257
7258 return TrueRange.unionWith(FalseRange);
7259}
7260
7261SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7262 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7263 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7264
7265 // Return early if there are no flags to propagate to the SCEV.
7267 if (BinOp->hasNoUnsignedWrap())
7269 if (BinOp->hasNoSignedWrap())
7271 if (Flags == SCEV::FlagAnyWrap)
7272 return SCEV::FlagAnyWrap;
7273
7274 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7275}
7276
7277const Instruction *
7278ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7279 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7280 return &*AddRec->getLoop()->getHeader()->begin();
7281 if (auto *U = dyn_cast<SCEVUnknown>(S))
7282 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7283 return I;
7284 return nullptr;
7285}
7286
7287const Instruction *
7288ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7289 bool &Precise) {
7290 Precise = true;
7291 // Do a bounded search of the def relation of the requested SCEVs.
7294 auto pushOp = [&](const SCEV *S) {
7295 if (!Visited.insert(S).second)
7296 return;
7297 // Threshold of 30 here is arbitrary.
7298 if (Visited.size() > 30) {
7299 Precise = false;
7300 return;
7301 }
7302 Worklist.push_back(S);
7303 };
7304
7305 for (const auto *S : Ops)
7306 pushOp(S);
7307
7308 const Instruction *Bound = nullptr;
7309 while (!Worklist.empty()) {
7310 auto *S = Worklist.pop_back_val();
7311 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7312 if (!Bound || DT.dominates(Bound, DefI))
7313 Bound = DefI;
7314 } else {
7315 for (const auto *Op : S->operands())
7316 pushOp(Op);
7317 }
7318 }
7319 return Bound ? Bound : &*F.getEntryBlock().begin();
7320}
7321
7322const Instruction *
7323ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7324 bool Discard;
7325 return getDefiningScopeBound(Ops, Discard);
7326}
7327
7328bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7329 const Instruction *B) {
7330 if (A->getParent() == B->getParent() &&
7332 B->getIterator()))
7333 return true;
7334
7335 auto *BLoop = LI.getLoopFor(B->getParent());
7336 if (BLoop && BLoop->getHeader() == B->getParent() &&
7337 BLoop->getLoopPreheader() == A->getParent() &&
7339 A->getParent()->end()) &&
7340 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7341 B->getIterator()))
7342 return true;
7343 return false;
7344}
7345
7346bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7347 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7348 visitAll(Op, PC);
7349 return PC.MaybePoison.empty();
7350}
7351
7352bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7353 return !SCEVExprContains(Op, [this](const SCEV *S) {
7354 auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7355 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7356 // is a non-zero constant, we have to assume the UDiv may be UB.
7357 return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7358 !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7359 });
7360}
7361
7362bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7363 // Only proceed if we can prove that I does not yield poison.
7365 return false;
7366
7367 // At this point we know that if I is executed, then it does not wrap
7368 // according to at least one of NSW or NUW. If I is not executed, then we do
7369 // not know if the calculation that I represents would wrap. Multiple
7370 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7371 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7372 // derived from other instructions that map to the same SCEV. We cannot make
7373 // that guarantee for cases where I is not executed. So we need to find a
7374 // upper bound on the defining scope for the SCEV, and prove that I is
7375 // executed every time we enter that scope. When the bounding scope is a
7376 // loop (the common case), this is equivalent to proving I executes on every
7377 // iteration of that loop.
7379 for (const Use &Op : I->operands()) {
7380 // I could be an extractvalue from a call to an overflow intrinsic.
7381 // TODO: We can do better here in some cases.
7382 if (isSCEVable(Op->getType()))
7383 SCEVOps.push_back(getSCEV(Op));
7384 }
7385 auto *DefI = getDefiningScopeBound(SCEVOps);
7386 return isGuaranteedToTransferExecutionTo(DefI, I);
7387}
7388
7389bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7390 // If we know that \c I can never be poison period, then that's enough.
7391 if (isSCEVExprNeverPoison(I))
7392 return true;
7393
7394 // If the loop only has one exit, then we know that, if the loop is entered,
7395 // any instruction dominating that exit will be executed. If any such
7396 // instruction would result in UB, the addrec cannot be poison.
7397 //
7398 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7399 // also handles uses outside the loop header (they just need to dominate the
7400 // single exit).
7401
7402 auto *ExitingBB = L->getExitingBlock();
7403 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7404 return false;
7405
7408
7409 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7410 // things that are known to be poison under that assumption go on the
7411 // Worklist.
7412 KnownPoison.insert(I);
7413 Worklist.push_back(I);
7414
7415 while (!Worklist.empty()) {
7416 const Instruction *Poison = Worklist.pop_back_val();
7417
7418 for (const Use &U : Poison->uses()) {
7419 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7420 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7421 DT.dominates(PoisonUser->getParent(), ExitingBB))
7422 return true;
7423
7424 if (propagatesPoison(U) && L->contains(PoisonUser))
7425 if (KnownPoison.insert(PoisonUser).second)
7426 Worklist.push_back(PoisonUser);
7427 }
7428 }
7429
7430 return false;
7431}
7432
7433ScalarEvolution::LoopProperties
7434ScalarEvolution::getLoopProperties(const Loop *L) {
7435 using LoopProperties = ScalarEvolution::LoopProperties;
7436
7437 auto Itr = LoopPropertiesCache.find(L);
7438 if (Itr == LoopPropertiesCache.end()) {
7439 auto HasSideEffects = [](Instruction *I) {
7440 if (auto *SI = dyn_cast<StoreInst>(I))
7441 return !SI->isSimple();
7442
7443 return I->mayThrow() || I->mayWriteToMemory();
7444 };
7445
7446 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7447 /*HasNoSideEffects*/ true};
7448
7449 for (auto *BB : L->getBlocks())
7450 for (auto &I : *BB) {
7452 LP.HasNoAbnormalExits = false;
7453 if (HasSideEffects(&I))
7454 LP.HasNoSideEffects = false;
7455 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7456 break; // We're already as pessimistic as we can get.
7457 }
7458
7459 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7460 assert(InsertPair.second && "We just checked!");
7461 Itr = InsertPair.first;
7462 }
7463
7464 return Itr->second;
7465}
7466
7468 // A mustprogress loop without side effects must be finite.
7469 // TODO: The check used here is very conservative. It's only *specific*
7470 // side effects which are well defined in infinite loops.
7471 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7472}
7473
7474const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7475 // Worklist item with a Value and a bool indicating whether all operands have
7476 // been visited already.
7479
7480 Stack.emplace_back(V, true);
7481 Stack.emplace_back(V, false);
7482 while (!Stack.empty()) {
7483 auto E = Stack.pop_back_val();
7484 Value *CurV = E.getPointer();
7485
7486 if (getExistingSCEV(CurV))
7487 continue;
7488
7490 const SCEV *CreatedSCEV = nullptr;
7491 // If all operands have been visited already, create the SCEV.
7492 if (E.getInt()) {
7493 CreatedSCEV = createSCEV(CurV);
7494 } else {
7495 // Otherwise get the operands we need to create SCEV's for before creating
7496 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7497 // just use it.
7498 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7499 }
7500
7501 if (CreatedSCEV) {
7502 insertValueToMap(CurV, CreatedSCEV);
7503 } else {
7504 // Queue CurV for SCEV creation, followed by its's operands which need to
7505 // be constructed first.
7506 Stack.emplace_back(CurV, true);
7507 for (Value *Op : Ops)
7508 Stack.emplace_back(Op, false);
7509 }
7510 }
7511
7512 return getExistingSCEV(V);
7513}
7514
7515const SCEV *
7516ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7517 if (!isSCEVable(V->getType()))
7518 return getUnknown(V);
7519
7520 if (Instruction *I = dyn_cast<Instruction>(V)) {
7521 // Don't attempt to analyze instructions in blocks that aren't
7522 // reachable. Such instructions don't matter, and they aren't required
7523 // to obey basic rules for definitions dominating uses which this
7524 // analysis depends on.
7525 if (!DT.isReachableFromEntry(I->getParent()))
7526 return getUnknown(PoisonValue::get(V->getType()));
7527 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7528 return getConstant(CI);
7529 else if (isa<GlobalAlias>(V))
7530 return getUnknown(V);
7531 else if (!isa<ConstantExpr>(V))
7532 return getUnknown(V);
7533
7534 Operator *U = cast<Operator>(V);
7535 if (auto BO =
7536 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7537 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7538 switch (BO->Opcode) {
7539 case Instruction::Add:
7540 case Instruction::Mul: {
7541 // For additions and multiplications, traverse add/mul chains for which we
7542 // can potentially create a single SCEV, to reduce the number of
7543 // get{Add,Mul}Expr calls.
7544 do {
7545 if (BO->Op) {
7546 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7547 Ops.push_back(BO->Op);
7548 break;
7549 }
7550 }
7551 Ops.push_back(BO->RHS);
7552 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7553 dyn_cast<Instruction>(V));
7554 if (!NewBO ||
7555 (BO->Opcode == Instruction::Add &&
7556 (NewBO->Opcode != Instruction::Add &&
7557 NewBO->Opcode != Instruction::Sub)) ||
7558 (BO->Opcode == Instruction::Mul &&
7559 NewBO->Opcode != Instruction::Mul)) {
7560 Ops.push_back(BO->LHS);
7561 break;
7562 }
7563 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7564 // requires a SCEV for the LHS.
7565 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7566 auto *I = dyn_cast<Instruction>(BO->Op);
7567 if (I && programUndefinedIfPoison(I)) {
7568 Ops.push_back(BO->LHS);
7569 break;
7570 }
7571 }
7572 BO = NewBO;
7573 } while (true);
7574 return nullptr;
7575 }
7576 case Instruction::Sub:
7577 case Instruction::UDiv:
7578 case Instruction::URem:
7579 break;
7580 case Instruction::AShr:
7581 case Instruction::Shl:
7582 case Instruction::Xor:
7583 if (!IsConstArg)
7584 return nullptr;
7585 break;
7586 case Instruction::And:
7587 case Instruction::Or:
7588 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7589 return nullptr;
7590 break;
7591 case Instruction::LShr:
7592 return getUnknown(V);
7593 default:
7594 llvm_unreachable("Unhandled binop");
7595 break;
7596 }
7597
7598 Ops.push_back(BO->LHS);
7599 Ops.push_back(BO->RHS);
7600 return nullptr;
7601 }
7602
7603 switch (U->getOpcode()) {
7604 case Instruction::Trunc:
7605 case Instruction::ZExt:
7606 case Instruction::SExt:
7607 case Instruction::PtrToInt:
7608 Ops.push_back(U->getOperand(0));
7609 return nullptr;
7610
7611 case Instruction::BitCast:
7612 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7613 Ops.push_back(U->getOperand(0));
7614 return nullptr;
7615 }
7616 return getUnknown(V);
7617
7618 case Instruction::SDiv:
7619 case Instruction::SRem:
7620 Ops.push_back(U->getOperand(0));
7621 Ops.push_back(U->getOperand(1));
7622 return nullptr;
7623
7624 case Instruction::GetElementPtr:
7625 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7626 "GEP source element type must be sized");
7627 for (Value *Index : U->operands())
7628 Ops.push_back(Index);
7629 return nullptr;
7630
7631 case Instruction::IntToPtr:
7632 return getUnknown(V);
7633
7634 case Instruction::PHI:
7635 // Keep constructing SCEVs' for phis recursively for now.
7636 return nullptr;
7637
7638 case Instruction::Select: {
7639 // Check if U is a select that can be simplified to a SCEVUnknown.
7640 auto CanSimplifyToUnknown = [this, U]() {
7641 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7642 return false;
7643
7644 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7645 if (!ICI)
7646 return false;
7647 Value *LHS = ICI->getOperand(0);
7648 Value *RHS = ICI->getOperand(1);
7649 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7650 ICI->getPredicate() == CmpInst::ICMP_NE) {
7651 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7652 return true;
7653 } else if (getTypeSizeInBits(LHS->getType()) >
7654 getTypeSizeInBits(U->getType()))
7655 return true;
7656 return false;
7657 };
7658 if (CanSimplifyToUnknown())
7659 return getUnknown(U);
7660
7661 for (Value *Inc : U->operands())
7662 Ops.push_back(Inc);
7663 return nullptr;
7664 break;
7665 }
7666 case Instruction::Call:
7667 case Instruction::Invoke:
7668 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7669 Ops.push_back(RV);
7670 return nullptr;
7671 }
7672
7673 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7674 switch (II->getIntrinsicID()) {
7675 case Intrinsic::abs:
7676 Ops.push_back(II->getArgOperand(0));
7677 return nullptr;
7678 case Intrinsic::umax:
7679 case Intrinsic::umin:
7680 case Intrinsic::smax:
7681 case Intrinsic::smin:
7682 case Intrinsic::usub_sat:
7683 case Intrinsic::uadd_sat:
7684 Ops.push_back(II->getArgOperand(0));
7685 Ops.push_back(II->getArgOperand(1));
7686 return nullptr;
7687 case Intrinsic::start_loop_iterations:
7688 case Intrinsic::annotation:
7689 case Intrinsic::ptr_annotation:
7690 Ops.push_back(II->getArgOperand(0));
7691 return nullptr;
7692 default:
7693 break;
7694 }
7695 }
7696 break;
7697 }
7698
7699 return nullptr;
7700}
7701
7702const SCEV *ScalarEvolution::createSCEV(Value *V) {
7703 if (!isSCEVable(V->getType()))
7704 return getUnknown(V);
7705
7706 if (Instruction *I = dyn_cast<Instruction>(V)) {
7707 // Don't attempt to analyze instructions in blocks that aren't
7708 // reachable. Such instructions don't matter, and they aren't required
7709 // to obey basic rules for definitions dominating uses which this
7710 // analysis depends on.
7711 if (!DT.isReachableFromEntry(I->getParent()))
7712 return getUnknown(PoisonValue::get(V->getType()));
7713 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7714 return getConstant(CI);
7715 else if (isa<GlobalAlias>(V))
7716 return getUnknown(V);
7717 else if (!isa<ConstantExpr>(V))
7718 return getUnknown(V);
7719
7720 const SCEV *LHS;
7721 const SCEV *RHS;
7722
7723 Operator *U = cast<Operator>(V);
7724 if (auto BO =
7725 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7726 switch (BO->Opcode) {
7727 case Instruction::Add: {
7728 // The simple thing to do would be to just call getSCEV on both operands
7729 // and call getAddExpr with the result. However if we're looking at a
7730 // bunch of things all added together, this can be quite inefficient,
7731 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7732 // Instead, gather up all the operands and make a single getAddExpr call.
7733 // LLVM IR canonical form means we need only traverse the left operands.
7735 do {
7736 if (BO->Op) {
7737 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7738 AddOps.push_back(OpSCEV);
7739 break;
7740 }
7741
7742 // If a NUW or NSW flag can be applied to the SCEV for this
7743 // addition, then compute the SCEV for this addition by itself
7744 // with a separate call to getAddExpr. We need to do that
7745 // instead of pushing the operands of the addition onto AddOps,
7746 // since the flags are only known to apply to this particular
7747 // addition - they may not apply to other additions that can be
7748 // formed with operands from AddOps.
7749 const SCEV *RHS = getSCEV(BO->RHS);
7750 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7751 if (Flags != SCEV::FlagAnyWrap) {
7752 const SCEV *LHS = getSCEV(BO->LHS);
7753 if (BO->Opcode == Instruction::Sub)
7754 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7755 else
7756 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7757 break;
7758 }
7759 }
7760
7761 if (BO->Opcode == Instruction::Sub)
7762 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7763 else
7764 AddOps.push_back(getSCEV(BO->RHS));
7765
7766 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7767 dyn_cast<Instruction>(V));
7768 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7769 NewBO->Opcode != Instruction::Sub)) {
7770 AddOps.push_back(getSCEV(BO->LHS));
7771 break;
7772 }
7773 BO = NewBO;
7774 } while (true);
7775
7776 return getAddExpr(AddOps);
7777 }
7778
7779 case Instruction::Mul: {
7781 do {
7782 if (BO->Op) {
7783 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7784 MulOps.push_back(OpSCEV);
7785 break;
7786 }
7787
7788 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7789 if (Flags != SCEV::FlagAnyWrap) {
7790 LHS = getSCEV(BO->LHS);
7791 RHS = getSCEV(BO->RHS);
7792 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7793 break;
7794 }
7795 }
7796
7797 MulOps.push_back(getSCEV(BO->RHS));
7798 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7799 dyn_cast<Instruction>(V));
7800 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7801 MulOps.push_back(getSCEV(BO->LHS));
7802 break;
7803 }
7804 BO = NewBO;
7805 } while (true);
7806
7807 return getMulExpr(MulOps);
7808 }
7809 case Instruction::UDiv:
7810 LHS = getSCEV(BO->LHS);
7811 RHS = getSCEV(BO->RHS);
7812 return getUDivExpr(LHS, RHS);
7813 case Instruction::URem:
7814 LHS = getSCEV(BO->LHS);
7815 RHS = getSCEV(BO->RHS);
7816 return getURemExpr(LHS, RHS);
7817 case Instruction::Sub: {
7819 if (BO->Op)
7820 Flags = getNoWrapFlagsFromUB(BO->Op);
7821 LHS = getSCEV(BO->LHS);
7822 RHS = getSCEV(BO->RHS);
7823 return getMinusSCEV(LHS, RHS, Flags);
7824 }
7825 case Instruction::And:
7826 // For an expression like x&255 that merely masks off the high bits,
7827 // use zext(trunc(x)) as the SCEV expression.
7828 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7829 if (CI->isZero())
7830 return getSCEV(BO->RHS);
7831 if (CI->isMinusOne())
7832 return getSCEV(BO->LHS);
7833 const APInt &A = CI->getValue();
7834
7835 // Instcombine's ShrinkDemandedConstant may strip bits out of
7836 // constants, obscuring what would otherwise be a low-bits mask.
7837 // Use computeKnownBits to compute what ShrinkDemandedConstant
7838 // knew about to reconstruct a low-bits mask value.
7839 unsigned LZ = A.countl_zero();
7840 unsigned TZ = A.countr_zero();
7841 unsigned BitWidth = A.getBitWidth();
7842 KnownBits Known(BitWidth);
7843 computeKnownBits(BO->LHS, Known, getDataLayout(),
7844 0, &AC, nullptr, &DT);
7845
7846 APInt EffectiveMask =
7847 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7848 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7849 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7850 const SCEV *LHS = getSCEV(BO->LHS);
7851 const SCEV *ShiftedLHS = nullptr;
7852 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7853 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7854 // For an expression like (x * 8) & 8, simplify the multiply.
7855 unsigned MulZeros = OpC->getAPInt().countr_zero();
7856 unsigned GCD = std::min(MulZeros, TZ);
7857 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7859 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7860 append_range(MulOps, LHSMul->operands().drop_front());
7861 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7862 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7863 }
7864 }
7865 if (!ShiftedLHS)
7866 ShiftedLHS = getUDivExpr(LHS, MulCount);
7867 return getMulExpr(
7869 getTruncateExpr(ShiftedLHS,
7870 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7871 BO->LHS->getType()),
7872 MulCount);
7873 }
7874 }
7875 // Binary `and` is a bit-wise `umin`.
7876 if (BO->LHS->getType()->isIntegerTy(1)) {
7877 LHS = getSCEV(BO->LHS);
7878 RHS = getSCEV(BO->RHS);
7879 return getUMinExpr(LHS, RHS);
7880 }
7881 break;
7882
7883 case Instruction::Or:
7884 // Binary `or` is a bit-wise `umax`.
7885 if (BO->LHS->getType()->isIntegerTy(1)) {
7886 LHS = getSCEV(BO->LHS);
7887 RHS = getSCEV(BO->RHS);
7888 return getUMaxExpr(LHS, RHS);
7889 }
7890 break;
7891
7892 case Instruction::Xor:
7893 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7894 // If the RHS of xor is -1, then this is a not operation.
7895 if (CI->isMinusOne())
7896 return getNotSCEV(getSCEV(BO->LHS));
7897
7898 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7899 // This is a variant of the check for xor with -1, and it handles
7900 // the case where instcombine has trimmed non-demanded bits out
7901 // of an xor with -1.
7902 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7903 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7904 if (LBO->getOpcode() == Instruction::And &&
7905 LCI->getValue() == CI->getValue())
7906 if (const SCEVZeroExtendExpr *Z =
7907 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7908 Type *UTy = BO->LHS->getType();
7909 const SCEV *Z0 = Z->getOperand();
7910 Type *Z0Ty = Z0->getType();
7911 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7912
7913 // If C is a low-bits mask, the zero extend is serving to
7914 // mask off the high bits. Complement the operand and
7915 // re-apply the zext.
7916 if (CI->getValue().isMask(Z0TySize))
7917 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7918
7919 // If C is a single bit, it may be in the sign-bit position
7920 // before the zero-extend. In this case, represent the xor
7921 // using an add, which is equivalent, and re-apply the zext.
7922 APInt Trunc = CI->getValue().trunc(Z0TySize);
7923 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7924 Trunc.isSignMask())
7925 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7926 UTy);
7927 }
7928 }
7929 break;
7930
7931 case Instruction::Shl:
7932 // Turn shift left of a constant amount into a multiply.
7933 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7934 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7935
7936 // If the shift count is not less than the bitwidth, the result of
7937 // the shift is undefined. Don't try to analyze it, because the
7938 // resolution chosen here may differ from the resolution chosen in
7939 // other parts of the compiler.
7940 if (SA->getValue().uge(BitWidth))
7941 break;
7942
7943 // We can safely preserve the nuw flag in all cases. It's also safe to
7944 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7945 // requires special handling. It can be preserved as long as we're not
7946 // left shifting by bitwidth - 1.
7947 auto Flags = SCEV::FlagAnyWrap;
7948 if (BO->Op) {
7949 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7950 if ((MulFlags & SCEV::FlagNSW) &&
7951 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7953 if (MulFlags & SCEV::FlagNUW)
7955 }
7956
7957 ConstantInt *X = ConstantInt::get(
7958 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7959 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7960 }
7961 break;
7962
7963 case Instruction::AShr:
7964 // AShr X, C, where C is a constant.
7965 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7966 if (!CI)
7967 break;
7968
7969 Type *OuterTy = BO->LHS->getType();
7971 // If the shift count is not less than the bitwidth, the result of
7972 // the shift is undefined. Don't try to analyze it, because the
7973 // resolution chosen here may differ from the resolution chosen in
7974 // other parts of the compiler.
7975 if (CI->getValue().uge(BitWidth))
7976 break;
7977
7978 if (CI->isZero())
7979 return getSCEV(BO->LHS); // shift by zero --> noop
7980
7981 uint64_t AShrAmt = CI->getZExtValue();
7982 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7983
7984 Operator *L = dyn_cast<Operator>(BO->LHS);
7985 const SCEV *AddTruncateExpr = nullptr;
7986 ConstantInt *ShlAmtCI = nullptr;
7987 const SCEV *AddConstant = nullptr;
7988
7989 if (L && L->getOpcode() == Instruction::Add) {
7990 // X = Shl A, n
7991 // Y = Add X, c
7992 // Z = AShr Y, m
7993 // n, c and m are constants.
7994
7995 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7996 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7997 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7998 if (AddOperandCI) {
7999 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8000 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8001 // since we truncate to TruncTy, the AddConstant should be of the
8002 // same type, so create a new Constant with type same as TruncTy.
8003 // Also, the Add constant should be shifted right by AShr amount.
8004 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8005 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8006 // we model the expression as sext(add(trunc(A), c << n)), since the
8007 // sext(trunc) part is already handled below, we create a
8008 // AddExpr(TruncExp) which will be used later.
8009 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8010 }
8011 }
8012 } else if (L && L->getOpcode() == Instruction::Shl) {
8013 // X = Shl A, n
8014 // Y = AShr X, m
8015 // Both n and m are constant.
8016
8017 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8018 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8019 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8020 }
8021
8022 if (AddTruncateExpr && ShlAmtCI) {
8023 // We can merge the two given cases into a single SCEV statement,
8024 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8025 // a simpler case. The following code handles the two cases:
8026 //
8027 // 1) For a two-shift sext-inreg, i.e. n = m,
8028 // use sext(trunc(x)) as the SCEV expression.
8029 //
8030 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8031 // expression. We already checked that ShlAmt < BitWidth, so
8032 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8033 // ShlAmt - AShrAmt < Amt.
8034 const APInt &ShlAmt = ShlAmtCI->getValue();
8035 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8037 ShlAmtCI->getZExtValue() - AShrAmt);
8038 const SCEV *CompositeExpr =
8039 getMulExpr(AddTruncateExpr, getConstant(Mul));
8040 if (L->getOpcode() != Instruction::Shl)
8041 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8042
8043 return getSignExtendExpr(CompositeExpr, OuterTy);
8044 }
8045 }
8046 break;
8047 }
8048 }
8049
8050 switch (U->getOpcode()) {
8051 case Instruction::Trunc:
8052 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8053
8054 case Instruction::ZExt:
8055 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8056
8057 case Instruction::SExt:
8058 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8059 dyn_cast<Instruction>(V))) {
8060 // The NSW flag of a subtract does not always survive the conversion to
8061 // A + (-1)*B. By pushing sign extension onto its operands we are much
8062 // more likely to preserve NSW and allow later AddRec optimisations.
8063 //
8064 // NOTE: This is effectively duplicating this logic from getSignExtend:
8065 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8066 // but by that point the NSW information has potentially been lost.
8067 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8068 Type *Ty = U->getType();
8069 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8070 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8071 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8072 }
8073 }
8074 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8075
8076 case Instruction::BitCast:
8077 // BitCasts are no-op casts so we just eliminate the cast.
8078 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8079 return getSCEV(U->getOperand(0));
8080 break;
8081
8082 case Instruction::PtrToInt: {
8083 // Pointer to integer cast is straight-forward, so do model it.
8084 const SCEV *Op = getSCEV(U->getOperand(0));
8085 Type *DstIntTy = U->getType();
8086 // But only if effective SCEV (integer) type is wide enough to represent
8087 // all possible pointer values.
8088 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8089 if (isa<SCEVCouldNotCompute>(IntOp))
8090 return getUnknown(V);
8091 return IntOp;
8092 }
8093 case Instruction::IntToPtr:
8094 // Just don't deal with inttoptr casts.
8095 return getUnknown(V);
8096
8097 case Instruction::SDiv:
8098 // If both operands are non-negative, this is just an udiv.
8099 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8100 isKnownNonNegative(getSCEV(U->getOperand(1))))
8101 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8102 break;
8103
8104 case Instruction::SRem:
8105 // If both operands are non-negative, this is just an urem.
8106 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8107 isKnownNonNegative(getSCEV(U->getOperand(1))))
8108 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8109 break;
8110
8111 case Instruction::GetElementPtr:
8112 return createNodeForGEP(cast<GEPOperator>(U));
8113
8114 case Instruction::PHI:
8115 return createNodeForPHI(cast<PHINode>(U));
8116
8117 case Instruction::Select:
8118 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8119 U->getOperand(2));
8120
8121 case Instruction::Call:
8122 case Instruction::Invoke:
8123 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8124 return getSCEV(RV);
8125
8126 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8127 switch (II->getIntrinsicID()) {
8128 case Intrinsic::abs:
8129 return getAbsExpr(
8130 getSCEV(II->getArgOperand(0)),
8131 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8132 case Intrinsic::umax:
8133 LHS = getSCEV(II->getArgOperand(0));
8134 RHS = getSCEV(II->getArgOperand(1));
8135 return getUMaxExpr(LHS, RHS);
8136 case Intrinsic::umin:
8137 LHS = getSCEV(II->getArgOperand(0));
8138 RHS = getSCEV(II->getArgOperand(1));
8139 return getUMinExpr(LHS, RHS);
8140 case Intrinsic::smax:
8141 LHS = getSCEV(II->getArgOperand(0));
8142 RHS = getSCEV(II->getArgOperand(1));
8143 return getSMaxExpr(LHS, RHS);
8144 case Intrinsic::smin:
8145 LHS = getSCEV(II->getArgOperand(0));
8146 RHS = getSCEV(II->getArgOperand(1));
8147 return getSMinExpr(LHS, RHS);
8148 case Intrinsic::usub_sat: {
8149 const SCEV *X = getSCEV(II->getArgOperand(0));
8150 const SCEV *Y = getSCEV(II->getArgOperand(1));
8151 const SCEV *ClampedY = getUMinExpr(X, Y);
8152 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8153 }
8154 case Intrinsic::uadd_sat: {
8155 const SCEV *X = getSCEV(II->getArgOperand(0));
8156 const SCEV *Y = getSCEV(II->getArgOperand(1));
8157 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8158 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8159 }
8160 case Intrinsic::start_loop_iterations:
8161 case Intrinsic::annotation:
8162 case Intrinsic::ptr_annotation:
8163 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8164 // just eqivalent to the first operand for SCEV purposes.
8165 return getSCEV(II->getArgOperand(0));
8166 case Intrinsic::vscale:
8167 return getVScale(II->getType());
8168 default:
8169 break;
8170 }
8171 }
8172 break;
8173 }
8174
8175 return getUnknown(V);
8176}
8177
8178//===----------------------------------------------------------------------===//
8179// Iteration Count Computation Code
8180//
8181
// Trip count for ExitCount, i.e. ExitCount + 1. The sum is formed in an integer
// type one bit wider than ExitCount's type, so the +1 cannot wrap, and no loop
// context is supplied (L == nullptr). CouldNotCompute is propagated unchanged.
8183 if (isa<SCEVCouldNotCompute>(ExitCount))
8184 return getCouldNotCompute();
8185
8186 auto *ExitCountType = ExitCount->getType();
8187 assert(ExitCountType->isIntegerTy());
8188 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8189 1 + ExitCountType->getScalarSizeInBits());
8190 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8191}
8192
// Trip count (ExitCount + 1) expressed in EvalTy. When L is non-null, guards
// on entry into L may be used to prove that ExitCount + 1 does not wrap.
8194 Type *EvalTy,
8195 const Loop *L) {
8196 if (isa<SCEVCouldNotCompute>(ExitCount))
8197 return getCouldNotCompute();
8198
8199 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8200 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8201
// ExitCount + 1 cannot wrap in ExitCount's own type if the unsigned range of
// ExitCount excludes the all-ones value, or if entry into L is guarded by
// ExitCount != -1.
8202 auto CanAddOneWithoutOverflow = [&]() {
8203 ConstantRange ExitCountRange =
8204 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8205 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8206 return true;
8207
8208 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8209 getMinusOne(ExitCount->getType()));
8210 };
8211
8212 // If we need to zero extend the backedge count, check if we can add one to
8213 // it prior to zero extending without overflow. Provided this is safe, it
8214 // allows better simplification of the +1.
8215 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8216 return getZeroExtendExpr(
8217 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8218
8219 // Get the total trip count from the count by adding 1. This may wrap.
8220 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8221}
8222
8223static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8224 if (!ExitCount)
8225 return 0;
8226
8227 ConstantInt *ExitConst = ExitCount->getValue();
8228
8229 // Guard against huge trip counts.
8230 if (ExitConst->getValue().getActiveBits() > 32)
8231 return 0;
8232
8233 // In case of integer overflow, this returns 0, which is correct.
8234 return ((unsigned)ExitConst->getZExtValue()) + 1;
8235}
8236
// Constant trip count of L, derived from its Exact backedge-taken count.
// Returns 0 when that count is not a small constant (see getConstantTripCount).
8238 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8239 return getConstantTripCount(ExitCount);
8240}
8241
// Constant trip count for the exit through ExitingBlock, derived from
// getExitCount(L, ExitingBlock). Returns 0 when that count is not a small
// constant. ExitingBlock must be non-null and must exit L.
8242unsigned
8244 const BasicBlock *ExitingBlock) {
8245 assert(ExitingBlock && "Must pass a non-null exiting block!");
8246 assert(L->isLoopExiting(ExitingBlock) &&
8247 "Exiting block must actually branch out of the loop!");
8248 const SCEVConstant *ExitCount =
8249 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8250 return getConstantTripCount(ExitCount);
8251}
8252
// Constant trip count derived from the constant maximum backedge-taken count.
// When Predicates is non-null, the predicated constant max is used and the SCEV
// predicates it relies on are collected into *Predicates.
8254 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8255
8256 const auto *MaxExitCount =
8257 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8259 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8260}
8261
// Trip multiple that holds for every exit of L: the GCD of the per-exiting-block
// trip multiples. Returns 1 if L has no exiting blocks.
8263 SmallVector<BasicBlock *, 8> ExitingBlocks;
8264 L->getExitingBlocks(ExitingBlocks);
8265
8266 std::optional<unsigned> Res;
8267 for (auto *ExitingBB : ExitingBlocks) {
8268 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
// The first exit seeds Res; gcd(x, x) == x, so the gcd below is a no-op then.
8269 if (!Res)
8270 Res = Multiple;
8271 Res = (unsigned)std::gcd(*Res, Multiple);
8272 }
8273 return Res.value_or(1);
8274}
8275
// Trip multiple for a given exit count: the constant multiple that
// getNonZeroConstantMultiple reports for ExitCount + 1 (after loop guards are
// applied), clamped to fit in 32 bits. Returns 1 for CouldNotCompute.
8277 const SCEV *ExitCount) {
8278 if (ExitCount == getCouldNotCompute())
8279 return 1;
8280
8281 // Trip count = ExitCount + 1, with ExitCount first refined by loop guards.
8282 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8283
8284 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8285 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8286 // the greatest power of 2 divisor less than 2^32.
8287 return Multiple.getActiveBits() > 32
8288 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8289 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8290}
8291
8292/// Returns the largest constant divisor of the trip count of this loop as a
8293/// normal unsigned value, if possible. This means that the actual trip count is
8294/// always a multiple of the returned value (don't forget the trip count could
8295/// very well be zero as well!).
8296///
8297/// Returns 1 if the trip count is unknown or not guaranteed to be the
8298/// multiple of a constant. If the multiple does not fit in 32 bits, the largest
8299/// power-of-two divisor of it below 2^32 is returned instead. To obtain a
8300/// constant trip count itself, use getSmallConstantTripCount.
8301///
8302/// As explained in the comments for getSmallConstantTripCount, this assumes
8303/// that control exits the loop via ExitingBlock.
8304unsigned
8306 const BasicBlock *ExitingBlock) {
8307 assert(ExitingBlock && "Must pass a non-null exiting block!");
8308 assert(L->isLoopExiting(ExitingBlock) &&
8309 "Exiting block must actually branch out of the loop!");
8310 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8311 return getSmallConstantTripMultiple(L, ExitCount);
8312}
8313
// Exit count of the requested Kind for the exit through ExitingBlock, taken
// from the unpredicated BackedgeTakenInfo cached for L.
8315 const BasicBlock *ExitingBlock,
8316 ExitCountKind Kind) {
8317 switch (Kind) {
8318 case Exact:
8319 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8320 case SymbolicMaximum:
8321 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8322 case ConstantMaximum:
8323 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8324 };
8325 llvm_unreachable("Invalid ExitCountKind!");
8326}
8327
// Like getExitCount, but queries the predicated BackedgeTakenInfo of L and
// forwards Predicates so callers can collect the SCEV predicates the result
// depends on.
8329 const Loop *L, const BasicBlock *ExitingBlock,
8331 switch (Kind) {
8332 case Exact:
8333 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8334 Predicates);
8335 case SymbolicMaximum:
8336 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8337 Predicates);
8338 case ConstantMaximum:
8339 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8340 Predicates);
8341 };
8342 llvm_unreachable("Invalid ExitCountKind!");
8343}
8344
// Exact loop-wide backedge-taken count from the predicated info; the SCEV
// predicates it relies on are appended to Preds.
8347 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8348}
8349
// Loop-wide backedge-taken count of the requested Kind, taken from the
// unpredicated BackedgeTakenInfo cached for L.
8351 ExitCountKind Kind) {
8352 switch (Kind) {
8353 case Exact:
8354 return getBackedgeTakenInfo(L).getExact(L, this);
8355 case ConstantMaximum:
8356 return getBackedgeTakenInfo(L).getConstantMax(this);
8357 case SymbolicMaximum:
8358 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8359 };
8360 llvm_unreachable("Invalid ExitCountKind!");
8361}
8362
// Symbolic maximum backedge-taken count from the predicated info; the SCEV
// predicates it relies on are passed back through Preds.
8365 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8366}
8367
// Constant maximum backedge-taken count from the predicated info; the SCEV
// predicates it relies on are appended to Preds.
8370 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8371}
8372
// True if L's backedge is taken either the constant maximum number of times or
// zero times (see BackedgeTakenInfo::isConstantMaxOrZero).
8374 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8375}
8376
8377/// Push each PHI node in the header of loop L onto Worklist, unless it is
/// already in Visited. Newly pushed PHIs are recorded in Visited.
8378static void PushLoopPHIs(const Loop *L,
8381 BasicBlock *Header = L->getHeader();
8382
8383 // Push all Loop-header PHIs onto the Worklist stack.
8384 for (PHINode &PN : Header->phis())
8385 if (Visited.insert(&PN).second)
8386 Worklist.push_back(&PN);
8387}
8388
8389ScalarEvolution::BackedgeTakenInfo &
8390ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8391 auto &BTI = getBackedgeTakenInfo(L);
8392 if (BTI.hasFullInfo())
8393 return BTI;
8394
8395 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8396
8397 if (!Pair.second)
8398 return Pair.first->second;
8399
8400 BackedgeTakenInfo Result =
8401 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8402
8403 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8404}
8405
/// Return the unpredicated BackedgeTakenInfo for L, computing it and caching it
/// in BackedgeTakenCounts on the first request.
8406ScalarEvolution::BackedgeTakenInfo &
8407ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8408 // Initially insert an invalid entry for this loop. If the insertion
8409 // succeeds, proceed to actually compute a backedge-taken count and
8410 // update the value. The temporary CouldNotCompute value tells SCEV
8411 // code elsewhere that it shouldn't attempt to request a new
8412 // backedge-taken count, which could result in infinite recursion.
8413 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8414 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8415 if (!Pair.second)
8416 return Pair.first->second;
8417
8418 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8419 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8420 // must be cleared in this scope.
8421 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8422
8423 // Now that we know more about the trip count for this loop, forget any
8424 // existing SCEV values for PHI nodes in this loop since they are only
8425 // conservative estimates made without the benefit of trip count
8426 // information. This invalidation is not necessary for correctness, and is
8427 // only done to produce more precise results.
8428 if (Result.hasAnyInfo()) {
8429 // Invalidate any expression using an addrec in this loop.
8431 auto LoopUsersIt = LoopUsers.find(L);
8432 if (LoopUsersIt != LoopUsers.end())
8433 append_range(ToForget, LoopUsersIt->second);
8434 forgetMemoizedResults(ToForget);
8435
8436 // Invalidate constant-evolved loop header phis.
8437 for (PHINode &PN : L->getHeader()->phis())
8438 ConstantEvolutionLoopExitValue.erase(&PN);
8439 }
8440
8441 // Re-lookup the insert position, since the call to
8442 // computeBackedgeTakenCount above could result in a
8443 // recursive call to getBackedgeTakenInfo (on a different
8444 // loop), which would invalidate the iterator computed
8445 // earlier.
8446 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8447}
8448
8450 // This method is intended to forget all info about loops. It should
8451 // invalidate caches as if the following happened:
8452 // - The trip counts of all loops have changed arbitrarily
8453 // - Every llvm::Value has been updated in place to produce a different
8454 // result.
// This is achieved by clearing each of the caches below in its entirety.
8455 BackedgeTakenCounts.clear();
8456 PredicatedBackedgeTakenCounts.clear();
8457 BECountUsers.clear();
8458 LoopPropertiesCache.clear();
8459 ConstantEvolutionLoopExitValue.clear();
8460 ValueExprMap.clear();
8461 ValuesAtScopes.clear();
8462 ValuesAtScopesUsers.clear();
8463 LoopDispositions.clear();
8464 BlockDispositions.clear();
8465 UnsignedRanges.clear();
8466 SignedRanges.clear();
8467 ExprValueMap.clear();
8468 HasRecMap.clear();
8469 ConstantMultipleCache.clear();
8470 PredicatedSCEVRewrites.clear();
8471 FoldCache.clear();
8472 FoldCacheUser.clear();
8473}
// Drain Worklist, together with whatever PushDefUseChildren adds for each
// visited instruction (presumably its users; confirm). For each instruction
// that is SCEVable or a WithOverflowInst:
//  - its ValueExprMap entry is erased and the SCEV is appended to ToForget, so
//    the caller can invalidate dependent results;
//  - PHIs also drop their ConstantEvolutionLoopExitValue entry.
// Instructions that are neither are skipped, and their children are not pushed.
8474void ScalarEvolution::visitAndClearUsers(
8478 while (!Worklist.empty()) {
8479 Instruction *I = Worklist.pop_back_val();
8480 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8481 continue;
8482
8484 ValueExprMap.find_as(static_cast<Value *>(I));
8485 if (It != ValueExprMap.end()) {
// NOTE(review): It->second is read after eraseValueFromMap(It->first). If
// eraseValueFromMap erases this ValueExprMap entry, It no longer refers to a
// live element at that point; reading It->second before the erase would be
// safer. Confirm against eraseValueFromMap.
8486 eraseValueFromMap(It->first);
8487 ToForget.push_back(It->second);
8488 if (PHINode *PN = dyn_cast<PHINode>(I))
8489 ConstantEvolutionLoopExitValue.erase(PN);
8490 }
8491
8492 PushDefUseChildren(I, Worklist, Visited);
8493 }
8494}
8495
// Forget cached information for L and, transitively, all of its sub-loops:
//  - predicated and unpredicated backedge-taken counts;
//  - predicated SCEV rewrites keyed on the loop;
//  - SCEVs recorded in LoopUsers for the loop;
//  - SCEVs derived from header PHIs;
//  - loop properties.
// The collected SCEVs are invalidated together at the end.
8497 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8501
8502 // Iterate over all the loops and sub-loops to drop SCEV information.
8503 while (!LoopWorklist.empty()) {
8504 auto *CurrL = LoopWorklist.pop_back_val();
8505
8506 // Drop any stored trip count value.
8507 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8508 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8509
8510 // Drop information about predicated SCEV rewrites for this loop.
8511 for (auto I = PredicatedSCEVRewrites.begin();
8512 I != PredicatedSCEVRewrites.end();) {
8513 std::pair<const SCEV *, const Loop *> Entry = I->first;
8514 if (Entry.second == CurrL)
8515 PredicatedSCEVRewrites.erase(I++);
8516 else
8517 ++I;
8518 }
8519
// Expressions recorded as users of CurrL in LoopUsers must be forgotten too.
8520 auto LoopUsersItr = LoopUsers.find(CurrL);
8521 if (LoopUsersItr != LoopUsers.end()) {
8522 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8523 LoopUsersItr->second.end());
8524 }
8525
8526 // Drop information about expressions based on loop-header PHIs.
8527 PushLoopPHIs(CurrL, Worklist, Visited);
8528 visitAndClearUsers(Worklist, Visited, ToForget);
8529
8530 LoopPropertiesCache.erase(CurrL);
8531 // Forget all contained loops too, to avoid dangling entries in the
8532 // ValuesAtScopes map.
8533 LoopWorklist.append(CurrL->begin(), CurrL->end());
8534 }
8535 forgetMemoizedResults(ToForget);
8536}
8537
// Forget the outermost loop containing L. Because forgetLoop also walks all
// sub-loops, this forgets L and every loop nested with it.
8539 forgetLoop(L->getOutermostLoop());
8540}
8541
// Forget the cached SCEV of V and of everything visitAndClearUsers reaches from
// it, then invalidate the collected SCEVs. Non-instruction values are ignored.
8543 Instruction *I = dyn_cast<Instruction>(V);
8544 if (!I) return;
8545
8546 // Drop information about expressions based on I.
8550 Worklist.push_back(I);
8551 Visited.insert(I);
8552 visitAndClearUsers(Worklist, Visited, ToForget);
8553
8554 forgetMemoizedResults(ToForget);
8555}
8556
// If V already has a SCEV, forget every cached result built on a SCEVUnknown
// (of an instruction inside L) or an AddRec (of L or a loop inside it) that
// appears in that SCEV. Then perform the normal forgetValue(V).
8558 if (!isSCEVable(V->getType()))
8559 return;
8560
8561 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8562 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8563 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8564 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8565 if (const SCEV *S = getExistingSCEV(V)) {
8566 struct InvalidationRootCollector {
8567 Loop *L;
8569
8570 InvalidationRootCollector(Loop *L) : L(L) {}
8571
8572 bool follow(const SCEV *S) {
8573 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8574 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8575 if (L->contains(I))
8576 Roots.push_back(S);
8577 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8578 if (L->contains(AddRec->getLoop()))
8579 Roots.push_back(S);
8580 }
8581 return true;
8582 }
8583 bool isDone() const { return false; }
8584 };
8585
8586 InvalidationRootCollector C(L);
8587 visitAll(S, C);
8588 forgetMemoizedResults(C.Roots);
8589 }
8590
8591 // Also perform the normal invalidation.
8592 forgetValue(V);
8593}
8594
8595void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8596
// Invalidate cached block and loop dispositions. With V == nullptr both caches
// are cleared entirely. Otherwise only the entries for V's existing SCEV, and
// transitively for its SCEV users, are removed.
8598 // Unless a specific value is passed to invalidation, completely clear both
8599 // caches.
8600 if (!V) {
8601 BlockDispositions.clear();
8602 LoopDispositions.clear();
8603 return;
8604 }
8605
8606 if (!isSCEVable(V->getType()))
8607 return;
8608
8609 const SCEV *S = getExistingSCEV(V);
8610 if (!S)
8611 return;
8612
8613 // Invalidate the block and loop dispositions cached for S. Dispositions of
8614 // S's users may change if S's disposition changes (e.g. a user may change to
8615 // loop-invariant, if S changes to loop invariant), so also invalidate
8616 // dispositions of S's users recursively.
8617 SmallVector<const SCEV *, 8> Worklist = {S};
8619 while (!Worklist.empty()) {
8620 const SCEV *Curr = Worklist.pop_back_val();
8621 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8622 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
// The walk does not continue past a SCEV that had no cached disposition.
8623 if (!LoopDispoRemoved && !BlockDispoRemoved)
8624 continue;
8625 auto Users = SCEVUsers.find(Curr);
8626 if (Users != SCEVUsers.end())
8627 for (const auto *User : Users->second)
8628 if (Seen.insert(User).second)
8629 Worklist.push_back(User);
8630 }
8631}
8632
8633/// Get the exact loop backedge taken count considering all loop exits. A
8634/// computable result can only be returned for loops with all exiting blocks
8635/// dominating the latch. howFarToZero assumes that the limit of each loop test
8636/// is never skipped. This is a valid assumption as long as the loop exits via
8637/// that test. For precise results, it is the caller's responsibility to specify
8638/// the relevant loop exiting block using getExact(ExitingBlock, SE).
/// If Preds is non-null, the SCEV predicates of every exit are appended to it;
/// otherwise every exit is asserted to have an always-true predicate.
8639const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8640 const Loop *L, ScalarEvolution *SE,
8642 // If any exits were not computable, the loop is not computable.
8643 if (!isComplete() || ExitNotTaken.empty())
8644 return SE->getCouldNotCompute();
8645
8646 const BasicBlock *Latch = L->getLoopLatch();
8647 // All exiting blocks we have collected must dominate the only backedge.
8648 if (!Latch)
8649 return SE->getCouldNotCompute();
8650
8651 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8652 // count is simply a minimum out of all these calculated exit counts.
8654 for (const auto &ENT : ExitNotTaken) {
8655 const SCEV *BECount = ENT.ExactNotTaken;
8656 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8657 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8658 "We should only have known counts for exiting blocks that dominate "
8659 "latch!");
8660
8661 Ops.push_back(BECount);
8662
8663 if (Preds)
8664 append_range(*Preds, ENT.Predicates);
8665
8666 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8667 "Predicate should be always true!");
8668 }
8669
8670 // If an earlier exit exits on the first iteration (exit count zero), then
8671 // a later poison exit count should not propagate into the result. These are
8672 // exactly the semantics provided by umin_seq.
8673 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8674}
8675
8676const ScalarEvolution::ExitNotTakenInfo *
8677ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8678 const BasicBlock *ExitingBlock,
8679 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8680 for (const auto &ENT : ExitNotTaken)
8681 if (ENT.ExitingBlock == ExitingBlock) {
8682 if (ENT.hasAlwaysTruePredicate())
8683 return &ENT;
8684 else if (Predicates) {
8685 append_range(*Predicates, ENT.Predicates);
8686 return &ENT;
8687 }
8688 }
8689
8690 return nullptr;
8691}
8692
8693/// getConstantMax - Get the constant max backedge taken count for the loop.
8694const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8695 ScalarEvolution *SE,
8696 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8697 if (!getConstantMax())
8698 return SE->getCouldNotCompute();
8699
8700 for (const auto &ENT : ExitNotTaken)
8701 if (!ENT.hasAlwaysTruePredicate()) {
8702 if (!Predicates)
8703 return SE->getCouldNotCompute();
8704 append_range(*Predicates, ENT.Predicates);
8705 }
8706
8707 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8708 isa<SCEVConstant>(getConstantMax())) &&
8709 "No point in having a non-constant max backedge taken count!");
8710 return getConstantMax();
8711}
8712
/// Symbolic maximum backedge-taken count: the sequential umin of every
/// computable per-exit SymbolicMaxNotTaken, or CouldNotCompute if none is
/// computable. The result is cached in SymbolicMax after the first call.
/// NOTE(review): predicates are appended to *Predicates only on the call that
/// fills the cache. A later call with Predicates != nullptr returns the cached
/// value without reporting the predicates it depends on. Confirm that no caller
/// relies on collecting them from a second call.
8713const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8714 const Loop *L, ScalarEvolution *SE,
8716 if (!SymbolicMax) {
8717 // Form an expression for the maximum exit count possible for this loop. We
8718 // merge the max and exact information to approximate a version of
8719 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8720 // constants.
8722
8723 for (const auto &ENT : ExitNotTaken) {
8724 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8725 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8726 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8727 "We should only have known counts for exiting blocks that "
8728 "dominate latch!");
8729 ExitCounts.push_back(ExitCount);
8730 if (Predicates)
8731 append_range(*Predicates, ENT.Predicates);
8732
8733 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8734 "Predicate should be always true!");
8735 }
8736 }
8737 if (ExitCounts.empty())
8738 SymbolicMax = SE->getCouldNotCompute();
8739 else
8740 SymbolicMax =
8741 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8742 }
8743 return SymbolicMax;
8744}
8745
8746bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8747 ScalarEvolution *SE) const {
8748 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8749 return !ENT.hasAlwaysTruePredicate();
8750 };
8751 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8752}
8753
// Single-count exit limit: E is used as the exact, constant-max and
// symbolic-max count, with MaxOrZero = false.
8755 : ExitLimit(E, E, E, false) {}
8756
// Main ExitLimit constructor. It records the exact, constant-max and
// symbolic-max not-taken counts and MaxOrZero, then flattens PredLists into
// Predicates, dropping duplicates. Several parameters shadow members of the
// same name; the assignment below updates both so the asserts see the
// adjusted values.
8758 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8759 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8761 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8762 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8763 // If we prove the max count is zero, so is the symbolic bound. This happens
8764 // in practice due to differences in a) how context sensitive we've chosen
8765 // to be and b) how we reason about bounds implied by UB.
8766 if (ConstantMaxNotTaken->isZero()) {
8768 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8769 }
8770
8771 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8772 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8773 "Exact is not allowed to be less precise than Constant Max");
8774 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8775 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8776 "Exact is not allowed to be less precise than Symbolic Max");
8777 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8778 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8779 "Symbolic Max is not allowed to be less precise than Constant Max");
8780 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8781 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8782 "No point in having a non-constant max backedge taken count!");
// Flatten the predicate lists in order, skipping predicates already seen.
8784 for (const auto PredList : PredLists)
8785 for (const auto *P : PredList) {
8786 if (SeenPreds.contains(P))
8787 continue;
8788 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8789 SeenPreds.insert(P);
8790 Predicates.push_back(P);
8791 }
8792 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8793 "Backedge count should be int");
8794 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8796 "Max backedge count should be int");
8797}
8798
// Convenience overload taking a single predicate list. It wraps PredList in a
// one-element list and delegates to the constructor taking a list of lists.
8800 const SCEV *ConstantMaxNotTaken,
8801 const SCEV *SymbolicMaxNotTaken,
8802 bool MaxOrZero,
8804 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8805 ArrayRef({PredList})) {}
8806
8807/// Build a BackedgeTakenInfo from per-exit limits: one ExitNotTakenInfo (exit
8808/// block, exact/constant-max/symbolic-max counts and predicates) per entry of
/// ExitCounts, plus the loop-wide ConstantMax, IsComplete and MaxOrZero.
8809ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8811 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8812 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8813 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8814
8815 ExitNotTaken.reserve(ExitCounts.size());
8816 std::transform(ExitCounts.begin(), ExitCounts.end(),
8817 std::back_inserter(ExitNotTaken),
8818 [&](const EdgeExitInfo &EEI) {
8819 BasicBlock *ExitBB = EEI.first;
8820 const ExitLimit &EL = EEI.second;
8821 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8822 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8823 EL.Predicates);
8824 });
8825 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8826 isa<SCEVConstant>(ConstantMax)) &&
8827 "No point in having a non-constant max backedge taken count!");
8828}
8829
8830/// Compute the number of times the backedge of the specified loop will execute.
/// The result is built from the ExitLimit of each exiting block. When
/// AllowPredicates is true, exit limits may depend on SCEV predicates;
/// otherwise they are asserted not to.
8831ScalarEvolution::BackedgeTakenInfo
8832ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8833 bool AllowPredicates) {
8834 SmallVector<BasicBlock *, 8> ExitingBlocks;
8835 L->getExitingBlocks(ExitingBlocks);
8836
8837 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8838
8840 bool CouldComputeBECount = true;
8841 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8842 const SCEV *MustExitMaxBECount = nullptr;
8843 const SCEV *MayExitMaxBECount = nullptr;
8844 bool MustExitMaxOrZero = false;
8845 bool IsOnlyExit = ExitingBlocks.size() == 1;
8846
8847 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8848 // and compute maxBECount.
8849 // Do a union of all the predicates here.
8850 for (BasicBlock *ExitBB : ExitingBlocks) {
8851 // We canonicalize untaken exits to br (constant); ignore them so that
8852 // proving an exit untaken doesn't negatively impact our ability to reason
8853 // about the loop as a whole.
8854 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8855 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8856 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8857 if (ExitIfTrue == CI->isZero())
8858 continue;
8859 }
8860
8861 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8862
8863 assert((AllowPredicates || EL.Predicates.empty()) &&
8864 "Predicated exit limit when predicates are not allowed!");
8865
8866 // 1. For each exit that can be computed, add an entry to ExitCounts.
8867 // CouldComputeBECount is true only if all exits can be computed.
8868 if (EL.ExactNotTaken != getCouldNotCompute())
8869 ++NumExitCountsComputed;
8870 else
8871 // We couldn't compute an exact value for this exit, so
8872 // we won't be able to compute an exact value for the loop.
8873 CouldComputeBECount = false;
8874 // Remember exit count if either exact or symbolic is known. Because
8875 // Exact always implies symbolic, only check symbolic.
8876 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8877 ExitCounts.emplace_back(ExitBB, EL);
8878 else {
8879 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8880 "Exact is known but symbolic isn't?");
8881 ++NumExitCountsNotComputed;
8882 }
8883
8884 // 2. Derive the loop's MaxBECount from each exit's max number of
8885 // non-exiting iterations. Partition the loop exits into two kinds:
8886 // LoopMustExits and LoopMayExits.
8887 //
8888 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8889 // is a LoopMayExit. If any computable LoopMustExit is found, then
8890 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8891 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8892 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8893 // any
8894 // computable EL.ConstantMaxNotTaken.
8895 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8896 DT.dominates(ExitBB, Latch)) {
8897 if (!MustExitMaxBECount) {
8898 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8899 MustExitMaxOrZero = EL.MaxOrZero;
8900 } else {
8901 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8902 EL.ConstantMaxNotTaken);
8903 }
// Once a may-exit maximum becomes CouldNotCompute it stays that way.
8904 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8905 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8906 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8907 else {
8908 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8909 EL.ConstantMaxNotTaken);
8910 }
8911 }
8912 }
8913 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8914 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8915 // The loop backedge will be taken the maximum or zero times if there's
8916 // a single exit that must be taken the maximum or zero times.
8917 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8918
8919 // Remember which SCEVs are used in exit limits for invalidation purposes.
8920 // We only care about non-constant SCEVs here, so we can ignore
8921 // EL.ConstantMaxNotTaken
8922 // and MaxBECount, which must be SCEVConstant.
8923 for (const auto &Pair : ExitCounts) {
8924 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8925 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8926 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8927 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8928 {L, AllowPredicates});
8929 }
8930 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8931 MaxBECount, MaxOrZero);
8932}
8933
8935ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8936 bool IsOnlyExit, bool AllowPredicates) {
8937 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8938 // If our exiting block does not dominate the latch, then its connection with
8939 // loop's exit limit may be far from trivial.
8940 const BasicBlock *Latch = L->getLoopLatch();
8941 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8942 return getCouldNotCompute();
8943
8944 Instruction *Term = ExitingBlock->getTerminator();
8945 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8946 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8947 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8948 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8949 "It should have one successor in loop and one exit block!");
8950 // Proceed to the next level to examine the exit condition expression.
8951 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8952 /*ControlsOnlyExit=*/IsOnlyExit,
8953 AllowPredicates);
8954 }
8955
8956 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8957 // For switch, make sure that there is a single exit from the loop.
8958 BasicBlock *Exit = nullptr;
8959 for (auto *SBB : successors(ExitingBlock))
8960 if (!L->contains(SBB)) {
8961 if (Exit) // Multiple exit successors.
8962 return getCouldNotCompute();
8963 Exit = SBB;
8964 }
8965 assert(Exit && "Exiting block must have at least one exit");
8966 return computeExitLimitFromSingleExitSwitch(
8967 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8968 }
8969
8970 return getCouldNotCompute();
8971}
8972
8974 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8975 bool AllowPredicates) {
8976 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8977 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8978 ControlsOnlyExit, AllowPredicates);
8979}
8980
8981std::optional<ScalarEvolution::ExitLimit>
8982ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8983 bool ExitIfTrue, bool ControlsOnlyExit,
8984 bool AllowPredicates) {
8985 (void)this->L;
8986 (void)this->ExitIfTrue;
8987 (void)this->AllowPredicates;
8988
8989 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8990 this->AllowPredicates == AllowPredicates &&
8991 "Variance in assumed invariant key components!");
8992 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8993 if (Itr == TripCountMap.end())
8994 return std::nullopt;
8995 return Itr->second;
8996}
8997
8998void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8999 bool ExitIfTrue,
9000 bool ControlsOnlyExit,
9001 bool AllowPredicates,
9002 const ExitLimit &EL) {
9003 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9004 this->AllowPredicates == AllowPredicates &&
9005 "Variance in assumed invariant key components!");
9006
9007 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9008 assert(InsertResult.second && "Expected successful insertion!");
9009 (void)InsertResult;
9010 (void)ExitIfTrue;
9011}
9012
9013ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9014 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9015 bool ControlsOnlyExit, bool AllowPredicates) {
9016
9017 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9018 AllowPredicates))
9019 return *MaybeEL;
9020
9021 ExitLimit EL = computeExitLimitFromCondImpl(
9022 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9023 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9024 return EL;
9025}
9026
9027ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9028 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9029 bool ControlsOnlyExit, bool AllowPredicates) {
9030 // Handle BinOp conditions (And, Or).
9031 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9032 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9033 return *LimitFromBinOp;
9034
9035 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9036 // Proceed to the next level to examine the icmp.
9037 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9038 ExitLimit EL =
9039 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9040 if (EL.hasFullInfo() || !AllowPredicates)
9041 return EL;
9042
9043 // Try again, but use SCEV predicates this time.
9044 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9045 ControlsOnlyExit,
9046 /*AllowPredicates=*/true);
9047 }
9048
9049 // Check for a constant condition. These are normally stripped out by
9050 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9051 // preserve the CFG and is temporarily leaving constant conditions
9052 // in place.
9053 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9054 if (ExitIfTrue == !CI->getZExtValue())
9055 // The backedge is always taken.
9056 return getCouldNotCompute();
9057 // The backedge is never taken.
9058 return getZero(CI->getType());
9059 }
9060
9061 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9062 // with a constant step, we can form an equivalent icmp predicate and figure
9063 // out how many iterations will be taken before we exit.
9064 const WithOverflowInst *WO;
9065 const APInt *C;
9066 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9067 match(WO->getRHS(), m_APInt(C))) {
9068 ConstantRange NWR =
9070 WO->getNoWrapKind());
9071 CmpInst::Predicate Pred;
9072 APInt NewRHSC, Offset;
9073 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9074 if (!ExitIfTrue)
9075 Pred = ICmpInst::getInversePredicate(Pred);
9076 auto *LHS = getSCEV(WO->getLHS());
9077 if (Offset != 0)
9079 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9080 ControlsOnlyExit, AllowPredicates);
9081 if (EL.hasAnyInfo())
9082 return EL;
9083 }
9084
9085 // If it's not an integer or pointer comparison then compute it the hard way.
9086 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9087}
9088
9089std::optional<ScalarEvolution::ExitLimit>
9090ScalarEvolution::computeExitLimitFromCondFromBinOp(
9091 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9092 bool ControlsOnlyExit, bool AllowPredicates) {
9093 // Check if the controlling expression for this loop is an And or Or.
9094 Value *Op0, *Op1;
9095 bool IsAnd = false;
9096 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9097 IsAnd = true;
9098 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9099 IsAnd = false;
9100 else
9101 return std::nullopt;
9102
9103 // EitherMayExit is true in these two cases:
9104 // br (and Op0 Op1), loop, exit
9105 // br (or Op0 Op1), exit, loop
9106 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9107 ExitLimit EL0 = computeExitLimitFromCondCached(
9108 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9109 AllowPredicates);
9110 ExitLimit EL1 = computeExitLimitFromCondCached(
9111 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9112 AllowPredicates);
9113
9114 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9115 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9116 if (isa<ConstantInt>(Op1))
9117 return Op1 == NeutralElement ? EL0 : EL1;
9118 if (isa<ConstantInt>(Op0))
9119 return Op0 == NeutralElement ? EL1 : EL0;
9120
9121 const SCEV *BECount = getCouldNotCompute();
9122 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9123 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9124 if (EitherMayExit) {
9125 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9126 // Both conditions must be same for the loop to continue executing.
9127 // Choose the less conservative count.
9128 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9129 EL1.ExactNotTaken != getCouldNotCompute()) {
9130 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9131 UseSequentialUMin);
9132 }
9133 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9134 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9135 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9136 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9137 else
9138 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9139 EL1.ConstantMaxNotTaken);
9140 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9141 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9142 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9143 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9144 else
9145 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9146 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9147 } else {
9148 // Both conditions must be same at the same time for the loop to exit.
9149 // For now, be conservative.
9150 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9151 BECount = EL0.ExactNotTaken;
9152 }
9153
9154 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9155 // to be more aggressive when computing BECount than when computing
9156 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9157 // and
9158 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9159 // EL1.ConstantMaxNotTaken to not.
9160 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9161 !isa<SCEVCouldNotCompute>(BECount))
9162 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9163 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9164 SymbolicMaxBECount =
9165 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9166 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9167 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9168}
9169
9170ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9171 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9172 bool AllowPredicates) {
9173 // If the condition was exit on true, convert the condition to exit on false
9175 if (!ExitIfTrue)
9176 Pred = ExitCond->getPredicate();
9177 else
9178 Pred = ExitCond->getInversePredicate();
9179 const ICmpInst::Predicate OriginalPred = Pred;
9180
9181 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9182 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9183
9184 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9185 AllowPredicates);
9186 if (EL.hasAnyInfo())
9187 return EL;
9188
9189 auto *ExhaustiveCount =
9190 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9191
9192 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9193 return ExhaustiveCount;
9194
9195 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9196 ExitCond->getOperand(1), L, OriginalPred);
9197}
9198ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9199 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9200 bool ControlsOnlyExit, bool AllowPredicates) {
9201
9202 // Try to evaluate any dependencies out of the loop.
9203 LHS = getSCEVAtScope(LHS, L);
9204 RHS = getSCEVAtScope(RHS, L);
9205
9206 // At this point, we would like to compute how many iterations of the
9207 // loop the predicate will return true for these inputs.
9208 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9209 // If there is a loop-invariant, force it into the RHS.
9210 std::swap(LHS, RHS);
9211 Pred = ICmpInst::getSwappedPredicate(Pred);
9212 }
9213
9214 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9216 // Simplify the operands before analyzing them.
9217 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9218
9219 // If we have a comparison of a chrec against a constant, try to use value
9220 // ranges to answer this query.
9221 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9223 if (AddRec->getLoop() == L) {
9224 // Form the constant range.
9225 ConstantRange CompRange =
9226 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9227
9228 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9229 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9230 }
9231
9232 // If this loop must exit based on this condition (or execute undefined
9233 // behaviour), see if we can improve wrap flags. This is essentially
9234 // a must execute style proof.
9235 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9236 // If we can prove the test sequence produced must repeat the same values
9237 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9238 // because if it did, we'd have an infinite (undefined) loop.
9239 // TODO: We can peel off any functions which are invertible *in L*. Loop
9240 // invariant terms are effectively constants for our purposes here.
9241 auto *InnerLHS = LHS;
9242 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9243 InnerLHS = ZExt->getOperand();
9244 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9245 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9246 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9247 /*OrNegative=*/true)) {
9248 auto Flags = AR->getNoWrapFlags();
9249 Flags = setFlags(Flags, SCEV::FlagNW);
9252 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9253 }
9254
9255 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9256 // From no-self-wrap, this follows trivially from the fact that every
9257 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9258 // last value before (un)signed wrap. Since we know that last value
9259 // didn't exit, nor will any smaller one.
9260 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9261 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9262 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9263 AR && AR->getLoop() == L && AR->isAffine() &&
9264 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9265 isKnownPositive(AR->getStepRecurrence(*this))) {
9266 auto Flags = AR->getNoWrapFlags();
9267 Flags = setFlags(Flags, WrapType);
9270 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9271 }
9272 }
9273 }
9274
9275 switch (Pred) {
9276 case ICmpInst::ICMP_NE: { // while (X != Y)
9277 // Convert to: while (X-Y != 0)
9278 if (LHS->getType()->isPointerTy()) {
9280 if (isa<SCEVCouldNotCompute>(LHS))
9281 return LHS;
9282 }
9283 if (RHS->getType()->isPointerTy()) {
9285 if (isa<SCEVCouldNotCompute>(RHS))
9286 return RHS;
9287 }
9288 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9289 AllowPredicates);
9290 if (EL.hasAnyInfo())
9291 return EL;
9292 break;
9293 }
9294 case ICmpInst::ICMP_EQ: { // while (X == Y)
9295 // Convert to: while (X-Y == 0)
9296 if (LHS->getType()->isPointerTy()) {
9298 if (isa<SCEVCouldNotCompute>(LHS))
9299 return LHS;
9300 }
9301 if (RHS->getType()->isPointerTy()) {
9303 if (isa<SCEVCouldNotCompute>(RHS))
9304 return RHS;
9305 }
9306 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9307 if (EL.hasAnyInfo()) return EL;
9308 break;
9309 }
9310 case ICmpInst::ICMP_SLE:
9311 case ICmpInst::ICMP_ULE:
9312 // Since the loop is finite, an invariant RHS cannot include the boundary
9313 // value, otherwise it would loop forever.
9314 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9315 !isLoopInvariant(RHS, L)) {
9316 // Otherwise, perform the addition in a wider type, to avoid overflow.
9317 // If the LHS is an addrec with the appropriate nowrap flag, the
9318 // extension will be sunk into it and the exit count can be analyzed.
9319 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9320 if (!OldType)
9321 break;
9322 // Prefer doubling the bitwidth over adding a single bit to make it more
9323 // likely that we use a legal type.
9324 auto *NewType =
9325 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9326 if (ICmpInst::isSigned(Pred)) {
9327 LHS = getSignExtendExpr(LHS, NewType);
9328 RHS = getSignExtendExpr(RHS, NewType);
9329 } else {
9330 LHS = getZeroExtendExpr(LHS, NewType);
9331 RHS = getZeroExtendExpr(RHS, NewType);
9332 }
9333 }
9334 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9335 [[fallthrough]];
9336 case ICmpInst::ICMP_SLT:
9337 case ICmpInst::ICMP_ULT: { // while (X < Y)
9338 bool IsSigned = ICmpInst::isSigned(Pred);
9339 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9340 AllowPredicates);
9341 if (EL.hasAnyInfo())
9342 return EL;
9343 break;
9344 }
9345 case ICmpInst::ICMP_SGE:
9346 case ICmpInst::ICMP_UGE:
9347 // Since the loop is finite, an invariant RHS cannot include the boundary
9348 // value, otherwise it would loop forever.
9349 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9350 !isLoopInvariant(RHS, L))
9351 break;
9352 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9353 [[fallthrough]];
9354 case ICmpInst::ICMP_SGT:
9355 case ICmpInst::ICMP_UGT: { // while (X > Y)
9356 bool IsSigned = ICmpInst::isSigned(Pred);
9357 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9358 AllowPredicates);
9359 if (EL.hasAnyInfo())
9360 return EL;
9361 break;
9362 }
9363 default:
9364 break;
9365 }
9366
9367 return getCouldNotCompute();
9368}
9369
9371ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9372 SwitchInst *Switch,
9373 BasicBlock *ExitingBlock,
9374 bool ControlsOnlyExit) {
9375 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9376
9377 // Give up if the exit is the default dest of a switch.
9378 if (Switch->getDefaultDest() == ExitingBlock)
9379 return getCouldNotCompute();
9380
9381 assert(L->contains(Switch->getDefaultDest()) &&
9382 "Default case must not exit the loop!");
9383 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9384 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9385
9386 // while (X != Y) --> while (X-Y != 0)
9387 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9388 if (EL.hasAnyInfo())
9389 return EL;
9390
9391 return getCouldNotCompute();
9392}
9393
9394static ConstantInt *
9396 ScalarEvolution &SE) {
9397 const SCEV *InVal = SE.getConstant(C);
9398 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9399 assert(isa<SCEVConstant>(Val) &&
9400 "Evaluation of SCEV at constant didn't fold correctly?");
9401 return cast<SCEVConstant>(Val)->getValue();
9402}
9403
9404ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9405 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9406 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9407 if (!RHS)
9408 return getCouldNotCompute();
9409
9410 const BasicBlock *Latch = L->getLoopLatch();
9411 if (!Latch)
9412 return getCouldNotCompute();
9413
9414 const BasicBlock *Predecessor = L->getLoopPredecessor();
9415 if (!Predecessor)
9416 return getCouldNotCompute();
9417
9418 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9419 // Return LHS in OutLHS and shift_opt in OutOpCode.
9420 auto MatchPositiveShift =
9421 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9422
9423 using namespace PatternMatch;
9424
9425 ConstantInt *ShiftAmt;
9426 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9427 OutOpCode = Instruction::LShr;
9428 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9429 OutOpCode = Instruction::AShr;
9430 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9431 OutOpCode = Instruction::Shl;
9432 else
9433 return false;
9434
9435 return ShiftAmt->getValue().isStrictlyPositive();
9436 };
9437
9438 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9439 //
9440 // loop:
9441 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9442 // %iv.shifted = lshr i32 %iv, <positive constant>
9443 //
9444 // Return true on a successful match. Return the corresponding PHI node (%iv
9445 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9446 auto MatchShiftRecurrence =
9447 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9448 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9449
9450 {
9452 Value *V;
9453
9454 // If we encounter a shift instruction, "peel off" the shift operation,
9455 // and remember that we did so. Later when we inspect %iv's backedge
9456 // value, we will make sure that the backedge value uses the same
9457 // operation.
9458 //
9459 // Note: the peeled shift operation does not have to be the same
9460 // instruction as the one feeding into the PHI's backedge value. We only
9461 // really care about it being the same *kind* of shift instruction --
9462 // that's all that is required for our later inferences to hold.
9463 if (MatchPositiveShift(LHS, V, OpC)) {
9464 PostShiftOpCode = OpC;
9465 LHS = V;
9466 }
9467 }
9468
9469 PNOut = dyn_cast<PHINode>(LHS);
9470 if (!PNOut || PNOut->getParent() != L->getHeader())
9471 return false;
9472
9473 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9474 Value *OpLHS;
9475
9476 return
9477 // The backedge value for the PHI node must be a shift by a positive
9478 // amount
9479 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9480
9481 // of the PHI node itself
9482 OpLHS == PNOut &&
9483
9484 // and the kind of shift should be match the kind of shift we peeled
9485 // off, if any.
9486 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9487 };
9488
9489 PHINode *PN;
9491 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9492 return getCouldNotCompute();
9493
9494 const DataLayout &DL = getDataLayout();
9495
9496 // The key rationale for this optimization is that for some kinds of shift
9497 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9498 // within a finite number of iterations. If the condition guarding the
9499 // backedge (in the sense that the backedge is taken if the condition is true)
9500 // is false for the value the shift recurrence stabilizes to, then we know
9501 // that the backedge is taken only a finite number of times.
9502
9503 ConstantInt *StableValue = nullptr;
9504 switch (OpCode) {
9505 default:
9506 llvm_unreachable("Impossible case!");
9507
9508 case Instruction::AShr: {
9509 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9510 // bitwidth(K) iterations.
9511 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9512 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9513 Predecessor->getTerminator(), &DT);
9514 auto *Ty = cast<IntegerType>(RHS->getType());
9515 if (Known.isNonNegative())
9516 StableValue = ConstantInt::get(Ty, 0);
9517 else if (Known.isNegative())
9518 StableValue = ConstantInt::get(Ty, -1, true);
9519 else
9520 return getCouldNotCompute();
9521
9522 break;
9523 }
9524 case Instruction::LShr:
9525 case Instruction::Shl:
9526 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9527 // stabilize to 0 in at most bitwidth(K) iterations.
9528 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9529 break;
9530 }
9531
9532 auto *Result =
9533 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9534 assert(Result->getType()->isIntegerTy(1) &&
9535 "Otherwise cannot be an operand to a branch instruction");
9536
9537 if (Result->isZeroValue()) {
9538 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9539 const SCEV *UpperBound =
9541 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9542 }
9543
9544 return getCouldNotCompute();
9545}
9546
9547/// Return true if we can constant fold an instruction of the specified type,
9548/// assuming that all operands were constants.
9549static bool CanConstantFold(const Instruction *I) {
9550 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9551 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9552 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9553 return true;
9554
9555 if (const CallInst *CI = dyn_cast<CallInst>(I))
9556 if (const Function *F = CI->getCalledFunction())
9557 return canConstantFoldCallTo(CI, F);
9558 return false;
9559}
9560
9561/// Determine whether this instruction can constant evolve within this loop
9562/// assuming its operands can all constant evolve.
9563static bool canConstantEvolve(Instruction *I, const Loop *L) {
9564 // An instruction outside of the loop can't be derived from a loop PHI.
9565 if (!L->contains(I)) return false;
9566
9567 if (isa<PHINode>(I)) {
9568 // We don't currently keep track of the control flow needed to evaluate
9569 // PHIs, so we cannot handle PHIs inside of loops.
9570 return L->getHeader() == I->getParent();
9571 }
9572
9573 // If we won't be able to constant fold this expression even if the operands
9574 // are constants, bail early.
9575 return CanConstantFold(I);
9576}
9577
9578/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9579/// recursing through each instruction operand until reaching a loop header phi.
9580static PHINode *
9583 unsigned Depth) {
9585 return nullptr;
9586
9587 // Otherwise, we can evaluate this instruction if all of its operands are
9588 // constant or derived from a PHI node themselves.
9589 PHINode *PHI = nullptr;
9590 for (Value *Op : UseInst->operands()) {
9591 if (isa<Constant>(Op)) continue;
9592
9593 Instruction *OpInst = dyn_cast<Instruction>(Op);
9594 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9595
9596 PHINode *P = dyn_cast<PHINode>(OpInst);
9597 if (!P)
9598 // If this operand is already visited, reuse the prior result.
9599 // We may have P != PHI if this is the deepest point at which the
9600 // inconsistent paths meet.
9601 P = PHIMap.lookup(OpInst);
9602 if (!P) {
9603 // Recurse and memoize the results, whether a phi is found or not.
9604 // This recursive call invalidates pointers into PHIMap.
9605 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9606 PHIMap[OpInst] = P;
9607 }
9608 if (!P)
9609 return nullptr; // Not evolving from PHI
9610 if (PHI && PHI != P)
9611 return nullptr; // Evolving from multiple different PHIs.
9612 PHI = P;
9613 }
9614 // This is a expression evolving from a constant PHI!
9615 return PHI;
9616}
9617
9618/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9619/// in the loop that V is derived from. We allow arbitrary operations along the
9620/// way, but the operands of an operation must either be constants or a value
9621/// derived from a constant PHI. If this expression does not fit with these
9622/// constraints, return null.
9624 Instruction *I = dyn_cast<Instruction>(V);
9625 if (!I || !canConstantEvolve(I, L)) return nullptr;
9626
9627 if (PHINode *PN = dyn_cast<PHINode>(I))
9628 return PN;
9629
9630 // Record non-constant instructions contained by the loop.
9632 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9633}
9634
9635/// EvaluateExpression - Given an expression that passes the
9636/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9637/// in the loop has the value PHIVal. If we can't fold this expression for some
9638/// reason, return null.
9641 const DataLayout &DL,
9642 const TargetLibraryInfo *TLI) {
9643 // Convenient constant check, but redundant for recursive calls.
9644 if (Constant *C = dyn_cast<Constant>(V)) return C;
9645 Instruction *I = dyn_cast<Instruction>(V);
9646 if (!I) return nullptr;
9647
9648 if (Constant *C = Vals.lookup(I)) return C;
9649
9650 // An instruction inside the loop depends on a value outside the loop that we
9651 // weren't given a mapping for, or a value such as a call inside the loop.
9652 if (!canConstantEvolve(I, L)) return nullptr;
9653
9654 // An unmapped PHI can be due to a branch or another loop inside this loop,
9655 // or due to this not being the initial iteration through a loop where we
9656 // couldn't compute the evolution of this particular PHI last time.
9657 if (isa<PHINode>(I)) return nullptr;
9658
9659 std::vector<Constant*> Operands(I->getNumOperands());
9660
9661 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9662 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9663 if (!Operand) {
9664 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9665 if (!Operands[i]) return nullptr;
9666 continue;
9667 }
9668 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9669 Vals[Operand] = C;
9670 if (!C) return nullptr;
9671 Operands[i] = C;
9672 }
9673
9674 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9675 /*AllowNonDeterministic=*/false);
9676}
9677
9678
9679// If every incoming value to PN except the one for BB is a specific Constant,
9680// return that, else return nullptr.
9682 Constant *IncomingVal = nullptr;
9683
9684 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9685 if (PN->getIncomingBlock(i) == BB)
9686 continue;
9687
9688 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9689 if (!CurrentVal)
9690 return nullptr;
9691
9692 if (IncomingVal != CurrentVal) {
9693 if (IncomingVal)
9694 return nullptr;
9695 IncomingVal = CurrentVal;
9696 }
9697 }
9698
9699 return IncomingVal;
9700}
9701
9702/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9703/// in the header of its containing loop, we know the loop executes a
9704/// constant number of times, and the PHI node is just a recurrence
9705/// involving constants, fold it.
9706Constant *
9707ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9708 const APInt &BEs,
9709 const Loop *L) {
9710 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9711 if (!Inserted)
9712 return I->second;
9713
9715 return nullptr; // Not going to evaluate it.
9716
9717 Constant *&RetVal = I->second;
9718
9720 BasicBlock *Header = L->getHeader();
9721 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9722
9723 BasicBlock *Latch = L->getLoopLatch();
9724 if (!Latch)
9725 return nullptr;
9726
9727 for (PHINode &PHI : Header->phis()) {
9728 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9729 CurrentIterVals[&PHI] = StartCST;
9730 }
9731 if (!CurrentIterVals.count(PN))
9732 return RetVal = nullptr;
9733
9734 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9735
9736 // Execute the loop symbolically to determine the exit value.
9737 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9738 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9739
9740 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9741 unsigned IterationNum = 0;
9742 const DataLayout &DL = getDataLayout();
9743 for (; ; ++IterationNum) {
9744 if (IterationNum == NumIterations)
9745 return RetVal = CurrentIterVals[PN]; // Got exit value!
9746
9747 // Compute the value of the PHIs for the next iteration.
9748 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9750 Constant *NextPHI =
9751 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9752 if (!NextPHI)
9753 return nullptr; // Couldn't evaluate!
9754 NextIterVals[PN] = NextPHI;
9755
9756 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9757
9758 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9759 // cease to be able to evaluate one of them or if they stop evolving,
9760 // because that doesn't necessarily prevent us from computing PN.
9762 for (const auto &I : CurrentIterVals) {
9763 PHINode *PHI = dyn_cast<PHINode>(I.first);
9764 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9765 PHIsToCompute.emplace_back(PHI, I.second);
9766 }
9767 // We use two distinct loops because EvaluateExpression may invalidate any
9768 // iterators into CurrentIterVals.
9769 for (const auto &I : PHIsToCompute) {
9770 PHINode *PHI = I.first;
9771 Constant *&NextPHI = NextIterVals[PHI];
9772 if (!NextPHI) { // Not already computed.
9773 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9774 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9775 }
9776 if (NextPHI != I.second)
9777 StoppedEvolving = false;
9778 }
9779
9780 // If all entries in CurrentIterVals == NextIterVals then we can stop
9781 // iterating, the loop can't continue to change.
9782 if (StoppedEvolving)
9783 return RetVal = CurrentIterVals[PN];
9784
9785 CurrentIterVals.swap(NextIterVals);
9786 }
9787}
9788
9789const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9790 Value *Cond,
9791 bool ExitWhen) {
9793 if (!PN) return getCouldNotCompute();
9794
9795 // If the loop is canonicalized, the PHI will have exactly two entries.
9796 // That's the only form we support here.
9797 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9798
9800 BasicBlock *Header = L->getHeader();
9801 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9802
9803 BasicBlock *Latch = L->getLoopLatch();
9804 assert(Latch && "Should follow from NumIncomingValues == 2!");
9805
9806 for (PHINode &PHI : Header->phis()) {
9807 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9808 CurrentIterVals[&PHI] = StartCST;
9809 }
9810 if (!CurrentIterVals.count(PN))
9811 return getCouldNotCompute();
9812
9813 // Okay, we find a PHI node that defines the trip count of this loop. Execute
9814 // the loop symbolically to determine when the condition gets a value of
9815 // "ExitWhen".
9816 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9817 const DataLayout &DL = getDataLayout();
9818 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9819 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9820 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9821
9822 // Couldn't symbolically evaluate.
9823 if (!CondVal) return getCouldNotCompute();
9824
9825 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9826 ++NumBruteForceTripCountsComputed;
9827 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9828 }
9829
9830 // Update all the PHI nodes for the next iteration.
9832
9833 // Create a list of which PHIs we need to compute. We want to do this before
9834 // calling EvaluateExpression on them because that may invalidate iterators
9835 // into CurrentIterVals.
9836 SmallVector<PHINode *, 8> PHIsToCompute;
9837 for (const auto &I : CurrentIterVals) {
9838 PHINode *PHI = dyn_cast<PHINode>(I.first);
9839 if (!PHI || PHI->getParent() != Header) continue;
9840 PHIsToCompute.push_back(PHI);
9841 }
9842 for (PHINode *PHI : PHIsToCompute) {
9843 Constant *&NextPHI = NextIterVals[PHI];
9844 if (NextPHI) continue; // Already computed!
9845
9846 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9847 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9848 }
9849 CurrentIterVals.swap(NextIterVals);
9850 }
9851
9852 // Too many iterations were needed to evaluate.
9853 return getCouldNotCompute();
9854}
9855
9856const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9858 ValuesAtScopes[V];
9859 // Check to see if we've folded this expression at this loop before.
9860 for (auto &LS : Values)
9861 if (LS.first == L)
9862 return LS.second ? LS.second : V;
9863
9864 Values.emplace_back(L, nullptr);
9865
9866 // Otherwise compute it.
9867 const SCEV *C = computeSCEVAtScope(V, L);
9868 for (auto &LS : reverse(ValuesAtScopes[V]))
9869 if (LS.first == L) {
9870 LS.second = C;
9871 if (!isa<SCEVConstant>(C))
9872 ValuesAtScopesUsers[C].push_back({L, V});
9873 break;
9874 }
9875 return C;
9876}
9877
9878/// This builds up a Constant using the ConstantExpr interface. That way, we
9879/// will return Constants for objects which aren't represented by a
9880/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9881/// Returns NULL if the SCEV isn't representable as a Constant.
9883 switch (V->getSCEVType()) {
9884 case scCouldNotCompute:
9885 case scAddRecExpr:
9886 case scVScale:
9887 return nullptr;
9888 case scConstant:
9889 return cast<SCEVConstant>(V)->getValue();
9890 case scUnknown:
9891 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9892 case scPtrToInt: {
9893 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9894 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9895 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9896
9897 return nullptr;
9898 }
9899 case scTruncate: {
9900 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9901 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9902 return ConstantExpr::getTrunc(CastOp, ST->getType());
9903 return nullptr;
9904 }
9905 case scAddExpr: {
9906 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9907 Constant *C = nullptr;
9908 for (const SCEV *Op : SA->operands()) {
9910 if (!OpC)
9911 return nullptr;
9912 if (!C) {
9913 C = OpC;
9914 continue;
9915 }
9916 assert(!C->getType()->isPointerTy() &&
9917 "Can only have one pointer, and it must be last");
9918 if (OpC->getType()->isPointerTy()) {
9919 // The offsets have been converted to bytes. We can add bytes using
9920 // an i8 GEP.
9922 OpC, C);
9923 } else {
9924 C = ConstantExpr::getAdd(C, OpC);
9925 }
9926 }
9927 return C;
9928 }
9929 case scMulExpr:
9930 case scSignExtend:
9931 case scZeroExtend:
9932 case scUDivExpr:
9933 case scSMaxExpr:
9934 case scUMaxExpr:
9935 case scSMinExpr:
9936 case scUMinExpr:
9938 return nullptr;
9939 }
9940 llvm_unreachable("Unknown SCEV kind!");
9941}
9942
9943const SCEV *
9944ScalarEvolution::getWithOperands(const SCEV *S,
9946 switch (S->getSCEVType()) {
9947 case scTruncate:
9948 case scZeroExtend:
9949 case scSignExtend:
9950 case scPtrToInt:
9951 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9952 case scAddRecExpr: {
9953 auto *AddRec = cast<SCEVAddRecExpr>(S);
9954 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9955 }
9956 case scAddExpr:
9957 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9958 case scMulExpr:
9959 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9960 case scUDivExpr:
9961 return getUDivExpr(NewOps[0], NewOps[1]);
9962 case scUMaxExpr:
9963 case scSMaxExpr:
9964 case scUMinExpr:
9965 case scSMinExpr:
9966 return getMinMaxExpr(S->getSCEVType(), NewOps);
9968 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9969 case scConstant:
9970 case scVScale:
9971 case scUnknown:
9972 return S;
9973 case scCouldNotCompute:
9974 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9975 }
9976 llvm_unreachable("Unknown SCEV kind!");
9977}
9978
9979const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9980 switch (V->getSCEVType()) {
9981 case scConstant:
9982 case scVScale:
9983 return V;
9984 case scAddRecExpr: {
9985 // If this is a loop recurrence for a loop that does not contain L, then we
9986 // are dealing with the final value computed by the loop.
9987 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9988 // First, attempt to evaluate each operand.
9989 // Avoid performing the look-up in the common case where the specified
9990 // expression has no loop-variant portions.
9991 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9992 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9993 if (OpAtScope == AddRec->getOperand(i))
9994 continue;
9995
9996 // Okay, at least one of these operands is loop variant but might be
9997 // foldable. Build a new instance of the folded commutative expression.
9999 NewOps.reserve(AddRec->getNumOperands());
10000 append_range(NewOps, AddRec->operands().take_front(i));
10001 NewOps.push_back(OpAtScope);
10002 for (++i; i != e; ++i)
10003 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10004
10005 const SCEV *FoldedRec = getAddRecExpr(
10006 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10007 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10008 // The addrec may be folded to a nonrecurrence, for example, if the
10009 // induction variable is multiplied by zero after constant folding. Go
10010 // ahead and return the folded value.
10011 if (!AddRec)
10012 return FoldedRec;
10013 break;
10014 }
10015
10016 // If the scope is outside the addrec's loop, evaluate it by using the
10017 // loop exit value of the addrec.
10018 if (!AddRec->getLoop()->contains(L)) {
10019 // To evaluate this recurrence, we need to know how many times the AddRec
10020 // loop iterates. Compute this now.
10021 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10022 if (BackedgeTakenCount == getCouldNotCompute())
10023 return AddRec;
10024
10025 // Then, evaluate the AddRec.
10026 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10027 }
10028
10029 return AddRec;
10030 }
10031 case scTruncate:
10032 case scZeroExtend:
10033 case scSignExtend:
10034 case scPtrToInt:
10035 case scAddExpr:
10036 case scMulExpr:
10037 case scUDivExpr:
10038 case scUMaxExpr:
10039 case scSMaxExpr:
10040 case scUMinExpr:
10041 case scSMinExpr:
10042 case scSequentialUMinExpr: {
10043 ArrayRef<const SCEV *> Ops = V->operands();
10044 // Avoid performing the look-up in the common case where the specified
10045 // expression has no loop-variant portions.
10046 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10047 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10048 if (OpAtScope != Ops[i]) {
10049 // Okay, at least one of these operands is loop variant but might be
10050 // foldable. Build a new instance of the folded commutative expression.
10052 NewOps.reserve(Ops.size());
10053 append_range(NewOps, Ops.take_front(i));
10054 NewOps.push_back(OpAtScope);
10055
10056 for (++i; i != e; ++i) {
10057 OpAtScope = getSCEVAtScope(Ops[i], L);
10058 NewOps.push_back(OpAtScope);
10059 }
10060
10061 return getWithOperands(V, NewOps);
10062 }
10063 }
10064 // If we got here, all operands are loop invariant.
10065 return V;
10066 }
10067 case scUnknown: {
10068 // If this instruction is evolved from a constant-evolving PHI, compute the
10069 // exit value from the loop without using SCEVs.
10070 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10071 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10072 if (!I)
10073 return V; // This is some other type of SCEVUnknown, just return it.
10074
10075 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10076 const Loop *CurrLoop = this->LI[I->getParent()];
10077 // Looking for loop exit value.
10078 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10079 PN->getParent() == CurrLoop->getHeader()) {
10080 // Okay, there is no closed form solution for the PHI node. Check
10081 // to see if the loop that contains it has a known backedge-taken
10082 // count. If so, we may be able to force computation of the exit
10083 // value.
10084 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10085 // This trivial case can show up in some degenerate cases where
10086 // the incoming IR has not yet been fully simplified.
10087 if (BackedgeTakenCount->isZero()) {
10088 Value *InitValue = nullptr;
10089 bool MultipleInitValues = false;
10090 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10091 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10092 if (!InitValue)
10093 InitValue = PN->getIncomingValue(i);
10094 else if (InitValue != PN->getIncomingValue(i)) {
10095 MultipleInitValues = true;
10096 break;
10097 }
10098 }
10099 }
10100 if (!MultipleInitValues && InitValue)
10101 return getSCEV(InitValue);
10102 }
10103 // Do we have a loop invariant value flowing around the backedge
10104 // for a loop which must execute the backedge?
10105 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10106 isKnownNonZero(BackedgeTakenCount) &&
10107 PN->getNumIncomingValues() == 2) {
10108
10109 unsigned InLoopPred =
10110 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10111 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10112 if (CurrLoop->isLoopInvariant(BackedgeVal))
10113 return getSCEV(BackedgeVal);
10114 }
10115 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10116 // Okay, we know how many times the containing loop executes. If
10117 // this is a constant evolving PHI node, get the final value at
10118 // the specified iteration number.
10119 Constant *RV =
10120 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10121 if (RV)
10122 return getSCEV(RV);
10123 }
10124 }
10125 }
10126
10127 // Okay, this is an expression that we cannot symbolically evaluate
10128 // into a SCEV. Check to see if it's possible to symbolically evaluate
10129 // the arguments into constants, and if so, try to constant propagate the
10130 // result. This is particularly useful for computing loop exit values.
10131 if (!CanConstantFold(I))
10132 return V; // This is some other type of SCEVUnknown, just return it.
10133
10135 Operands.reserve(I->getNumOperands());
10136 bool MadeImprovement = false;
10137 for (Value *Op : I->operands()) {
10138 if (Constant *C = dyn_cast<Constant>(Op)) {
10139 Operands.push_back(C);
10140 continue;
10141 }
10142
10143 // If any of the operands is non-constant and if they are
10144 // non-integer and non-pointer, don't even try to analyze them
10145 // with scev techniques.
10146 if (!isSCEVable(Op->getType()))
10147 return V;
10148
10149 const SCEV *OrigV = getSCEV(Op);
10150 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10151 MadeImprovement |= OrigV != OpV;
10152
10154 if (!C)
10155 return V;
10156 assert(C->getType() == Op->getType() && "Type mismatch");
10157 Operands.push_back(C);
10158 }
10159
10160 // Check to see if getSCEVAtScope actually made an improvement.
10161 if (!MadeImprovement)
10162 return V; // This is some other type of SCEVUnknown, just return it.
10163
10164 Constant *C = nullptr;
10165 const DataLayout &DL = getDataLayout();
10166 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10167 /*AllowNonDeterministic=*/false);
10168 if (!C)
10169 return V;
10170 return getSCEV(C);
10171 }
10172 case scCouldNotCompute:
10173 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10174 }
10175 llvm_unreachable("Unknown SCEV type!");
10176}
10177
10179 return getSCEVAtScope(getSCEV(V), L);
10180}
10181
10182const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10183 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10184 return stripInjectiveFunctions(ZExt->getOperand());
10185 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10186 return stripInjectiveFunctions(SExt->getOperand());
10187 return S;
10188}
10189
10190/// Finds the minimum unsigned root of the following equation:
10191///
10192/// A * X = B (mod N)
10193///
10194/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10195/// A and B isn't important.
10196///
10197/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10198static const SCEV *
10201
10202 ScalarEvolution &SE) {
10203 uint32_t BW = A.getBitWidth();
10204 assert(BW == SE.getTypeSizeInBits(B->getType()));
10205 assert(A != 0 && "A must be non-zero.");
10206
10207 // 1. D = gcd(A, N)
10208 //
10209 // The gcd of A and N may have only one prime factor: 2. The number of
10210 // trailing zeros in A is its multiplicity
10211 uint32_t Mult2 = A.countr_zero();
10212 // D = 2^Mult2
10213
10214 // 2. Check if B is divisible by D.
10215 //
10216 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10217 // is not less than multiplicity of this prime factor for D.
10218 if (SE.getMinTrailingZeros(B) < Mult2) {
10219 // Check if we can prove there's no remainder using URem.
10220 const SCEV *URem =
10221 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10222 const SCEV *Zero = SE.getZero(B->getType());
10223 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10224 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10225 if (!Predicates)
10226 return SE.getCouldNotCompute();
10227
10228 // Avoid adding a predicate that is known to be false.
10229 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10230 return SE.getCouldNotCompute();
10231 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10232 }
10233 }
10234
10235 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10236 // modulo (N / D).
10237 //
10238 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10239 // (N / D) in general. The inverse itself always fits into BW bits, though,
10240 // so we immediately truncate it.
10241 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10242 APInt I = AD.multiplicativeInverse().zext(BW);
10243
10244 // 4. Compute the minimum unsigned root of the equation:
10245 // I * (B / D) mod (N / D)
10246 // To simplify the computation, we factor out the divide by D:
10247 // (I * B mod N) / D
10248 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10249 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10250}
10251
10252/// For a given quadratic addrec, generate coefficients of the corresponding
10253/// quadratic equation, multiplied by a common value to ensure that they are
10254/// integers.
10255/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10256/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10257/// were multiplied by, and BitWidth is the bit width of the original addrec
10258/// coefficients.
10259/// This function returns std::nullopt if the addrec coefficients are not
10260/// compile- time constants.
10261static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10263 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10264 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10265 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10266 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10267 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10268 << *AddRec << '\n');
10269
10270 // We currently can only solve this if the coefficients are constants.
10271 if (!LC || !MC || !NC) {
10272 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10273 return std::nullopt;
10274 }
10275
10276 APInt L = LC->getAPInt();
10277 APInt M = MC->getAPInt();
10278 APInt N = NC->getAPInt();
10279 assert(!N.isZero() && "This is not a quadratic addrec");
10280
10281 unsigned BitWidth = LC->getAPInt().getBitWidth();
10282 unsigned NewWidth = BitWidth + 1;
10283 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10284 << BitWidth << '\n');
10285 // The sign-extension (as opposed to a zero-extension) here matches the
10286 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10287 N = N.sext(NewWidth);
10288 M = M.sext(NewWidth);
10289 L = L.sext(NewWidth);
10290
10291 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10292 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10293 // L+M, L+2M+N, L+3M+3N, ...
10294 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10295 //
10296 // The equation Acc = 0 is then
10297 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10298 // In a quadratic form it becomes:
10299 // N n^2 + (2M-N) n + 2L = 0.
10300
10301 APInt A = N;
10302 APInt B = 2 * M - A;
10303 APInt C = 2 * L;
10304 APInt T = APInt(NewWidth, 2);
10305 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10306 << "x + " << C << ", coeff bw: " << NewWidth
10307 << ", multiplied by " << T << '\n');
10308 return std::make_tuple(A, B, C, T, BitWidth);
10309}
10310
10311/// Helper function to compare optional APInts:
10312/// (a) if X and Y both exist, return min(X, Y),
10313/// (b) if neither X nor Y exist, return std::nullopt,
10314/// (c) if exactly one of X and Y exists, return that value.
10315static std::optional<APInt> MinOptional(std::optional<APInt> X,
10316 std::optional<APInt> Y) {
10317 if (X && Y) {
10318 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10319 APInt XW = X->sext(W);
10320 APInt YW = Y->sext(W);
10321 return XW.slt(YW) ? *X : *Y;
10322 }
10323 if (!X && !Y)
10324 return std::nullopt;
10325 return X ? *X : *Y;
10326}
10327
10328/// Helper function to truncate an optional APInt to a given BitWidth.
10329/// When solving addrec-related equations, it is preferable to return a value
10330/// that has the same bit width as the original addrec's coefficients. If the
10331/// solution fits in the original bit width, truncate it (except for i1).
10332/// Returning a value of a different bit width may inhibit some optimizations.
10333///
10334/// In general, a solution to a quadratic equation generated from an addrec
10335/// may require BW+1 bits, where BW is the bit width of the addrec's
10336/// coefficients. The reason is that the coefficients of the quadratic
10337/// equation are BW+1 bits wide (to avoid truncation when converting from
10338/// the addrec to the equation).
10339static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10340 unsigned BitWidth) {
10341 if (!X)
10342 return std::nullopt;
10343 unsigned W = X->getBitWidth();
10344 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10345 return X->trunc(BitWidth);
10346 return X;
10347}
10348
10349/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10350/// iterations. The values L, M, N are assumed to be signed, and they
10351/// should all have the same bit widths.
10352/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10353/// where BW is the bit width of the addrec's coefficients.
10354/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10355/// returned as such, otherwise the bit width of the returned value may
10356/// be greater than BW.
10357///
10358/// This function returns std::nullopt if
10359/// (a) the addrec coefficients are not constant, or
10360/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10361/// like x^2 = 5, no integer solutions exist, in other cases an integer
10362/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10363static std::optional<APInt>
10365 APInt A, B, C, M;
10366 unsigned BitWidth;
10367 auto T = GetQuadraticEquation(AddRec);
10368 if (!T)
10369 return std::nullopt;
10370
10371 std::tie(A, B, C, M, BitWidth) = *T;
10372 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10373 std::optional<APInt> X =
10375 if (!X)
10376 return std::nullopt;
10377
10378 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10379 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10380 if (!V->isZero())
10381 return std::nullopt;
10382
10383 return TruncIfPossible(X, BitWidth);
10384}
10385
10386/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10387/// iterations. The values M, N are assumed to be signed, and they
10388/// should all have the same bit widths.
10389/// Find the least n such that c(n) does not belong to the given range,
10390/// while c(n-1) does.
10391///
10392/// This function returns std::nullopt if
10393/// (a) the addrec coefficients are not constant, or
10394/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10395/// bounds of the range.
10396static std::optional<APInt>
10398 const ConstantRange &Range, ScalarEvolution &SE) {
10399 assert(AddRec->getOperand(0)->isZero() &&
10400 "Starting value of addrec should be 0");
10401 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10402 << Range << ", addrec " << *AddRec << '\n');
10403 // This case is handled in getNumIterationsInRange. Here we can assume that
10404 // we start in the range.
10405 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10406 "Addrec's initial value should be in range");
10407
10408 APInt A, B, C, M;
10409 unsigned BitWidth;
10410 auto T = GetQuadraticEquation(AddRec);
10411 if (!T)
10412 return std::nullopt;
10413
10414 // Be careful about the return value: there can be two reasons for not
10415 // returning an actual number. First, if no solutions to the equations
10416 // were found, and second, if the solutions don't leave the given range.
10417 // The first case means that the actual solution is "unknown", the second
10418 // means that it's known, but not valid. If the solution is unknown, we
10419 // cannot make any conclusions.
10420 // Return a pair: the optional solution and a flag indicating if the
10421 // solution was found.
10422 auto SolveForBoundary =
10423 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10424 // Solve for signed overflow and unsigned overflow, pick the lower
10425 // solution.
10426 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10427 << Bound << " (before multiplying by " << M << ")\n");
10428 Bound *= M; // The quadratic equation multiplier.
10429
10430 std::optional<APInt> SO;
10431 if (BitWidth > 1) {
10432 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10433 "signed overflow\n");
10435 }
10436 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10437 "unsigned overflow\n");
10438 std::optional<APInt> UO =
10440
10441 auto LeavesRange = [&] (const APInt &X) {
10442 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10443 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10444 if (Range.contains(V0->getValue()))
10445 return false;
10446 // X should be at least 1, so X-1 is non-negative.
10447 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10448 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10449 if (Range.contains(V1->getValue()))
10450 return true;
10451 return false;
10452 };
10453
10454 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10455 // can be a solution, but the function failed to find it. We cannot treat it
10456 // as "no solution".
10457 if (!SO || !UO)
10458 return {std::nullopt, false};
10459
10460 // Check the smaller value first to see if it leaves the range.
10461 // At this point, both SO and UO must have values.
10462 std::optional<APInt> Min = MinOptional(SO, UO);
10463 if (LeavesRange(*Min))
10464 return { Min, true };
10465 std::optional<APInt> Max = Min == SO ? UO : SO;
10466 if (LeavesRange(*Max))
10467 return { Max, true };
10468
10469 // Solutions were found, but were eliminated, hence the "true".
10470 return {std::nullopt, true};
10471 };
10472
10473 std::tie(A, B, C, M, BitWidth) = *T;
10474 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10475 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10476 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10477 auto SL = SolveForBoundary(Lower);
10478 auto SU = SolveForBoundary(Upper);
10479 // If any of the solutions was unknown, no meaninigful conclusions can
10480 // be made.
10481 if (!SL.second || !SU.second)
10482 return std::nullopt;
10483
10484 // Claim: The correct solution is not some value between Min and Max.
10485 //
10486 // Justification: Assuming that Min and Max are different values, one of
10487 // them is when the first signed overflow happens, the other is when the
10488 // first unsigned overflow happens. Crossing the range boundary is only
10489 // possible via an overflow (treating 0 as a special case of it, modeling
10490 // an overflow as crossing k*2^W for some k).
10491 //
10492 // The interesting case here is when Min was eliminated as an invalid
10493 // solution, but Max was not. The argument is that if there was another
10494 // overflow between Min and Max, it would also have been eliminated if
10495 // it was considered.
10496 //
10497 // For a given boundary, it is possible to have two overflows of the same
10498 // type (signed/unsigned) without having the other type in between: this
10499 // can happen when the vertex of the parabola is between the iterations
10500 // corresponding to the overflows. This is only possible when the two
10501 // overflows cross k*2^W for the same k. In such case, if the second one
10502 // left the range (and was the first one to do so), the first overflow
10503 // would have to enter the range, which would mean that either we had left
10504 // the range before or that we started outside of it. Both of these cases
10505 // are contradictions.
10506 //
10507 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10508 // solution is not some value between the Max for this boundary and the
10509 // Min of the other boundary.
10510 //
10511 // Justification: Assume that we had such Max_A and Min_B corresponding
10512 // to range boundaries A and B and such that Max_A < Min_B. If there was
10513 // a solution between Max_A and Min_B, it would have to be caused by an
10514 // overflow corresponding to either A or B. It cannot correspond to B,
10515 // since Min_B is the first occurrence of such an overflow. If it
10516 // corresponded to A, it would have to be either a signed or an unsigned
10517 // overflow that is larger than both eliminated overflows for A. But
10518 // between the eliminated overflows and this overflow, the values would
10519 // cover the entire value space, thus crossing the other boundary, which
10520 // is a contradiction.
10521
10522 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10523}
10524
10525ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10526 const Loop *L,
10527 bool ControlsOnlyExit,
10528 bool AllowPredicates) {
10529
10530 // This is only used for loops with a "x != y" exit test. The exit condition
10531 // is now expressed as a single expression, V = x-y. So the exit test is
10532 // effectively V != 0. We know and take advantage of the fact that this
10533 // expression only being used in a comparison by zero context.
10534
10536 // If the value is a constant
10537 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10538 // If the value is already zero, the branch will execute zero times.
10539 if (C->getValue()->isZero()) return C;
10540 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10541 }
10542
10543 const SCEVAddRecExpr *AddRec =
10544 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10545
10546 if (!AddRec && AllowPredicates)
10547 // Try to make this an AddRec using runtime tests, in the first X
10548 // iterations of this loop, where X is the SCEV expression found by the
10549 // algorithm below.
10550 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10551
10552 if (!AddRec || AddRec->getLoop() != L)
10553 return getCouldNotCompute();
10554
10555 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10556 // the quadratic equation to solve it.
10557 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10558 // We can only use this value if the chrec ends up with an exact zero
10559 // value at this index. When solving for "X*X != 5", for example, we
10560 // should not accept a root of 2.
10561 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10562 const auto *R = cast<SCEVConstant>(getConstant(*S));
10563 return ExitLimit(R, R, R, false, Predicates);
10564 }
10565 return getCouldNotCompute();
10566 }
10567
10568 // Otherwise we can only handle this if it is affine.
10569 if (!AddRec->isAffine())
10570 return getCouldNotCompute();
10571
10572 // If this is an affine expression, the execution count of this branch is
10573 // the minimum unsigned root of the following equation:
10574 //
10575 // Start + Step*N = 0 (mod 2^BW)
10576 //
10577 // equivalent to:
10578 //
10579 // Step*N = -Start (mod 2^BW)
10580 //
10581 // where BW is the common bit width of Start and Step.
10582
10583 // Get the initial value for the loop.
10584 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10585 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10586
10587 if (!isLoopInvariant(Step, L))
10588 return getCouldNotCompute();
10589
10590 LoopGuards Guards = LoopGuards::collect(L, *this);
10591 // Specialize step for this loop so we get context sensitive facts below.
10592 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10593
10594 // For positive steps (counting up until unsigned overflow):
10595 // N = -Start/Step (as unsigned)
10596 // For negative steps (counting down to zero):
10597 // N = Start/-Step
10598 // First compute the unsigned distance from zero in the direction of Step.
10599 bool CountDown = isKnownNegative(StepWLG);
10600 if (!CountDown && !isKnownNonNegative(StepWLG))
10601 return getCouldNotCompute();
10602
10603 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10604 // Handle unitary steps, which cannot wraparound.
10605 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10606 // N = Distance (as unsigned)
10607
10608 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10609 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10610 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10611
10612 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10613 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10614 // case, and see if we can improve the bound.
10615 //
10616 // Explicitly handling this here is necessary because getUnsignedRange
10617 // isn't context-sensitive; it doesn't know that we only care about the
10618 // range inside the loop.
10619 const SCEV *Zero = getZero(Distance->getType());
10620 const SCEV *One = getOne(Distance->getType());
10621 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10622 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10623 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10624 // as "unsigned_max(Distance + 1) - 1".
10625 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10626 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10627 }
10628 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10629 Predicates);
10630 }
10631
10632 // If the condition controls loop exit (the loop exits only if the expression
10633 // is true) and the addition is no-wrap we can use unsigned divide to
10634 // compute the backedge count. In this case, the step may not divide the
10635 // distance, but we don't care because if the condition is "missed" the loop
10636 // will have undefined behavior due to wrapping.
10637 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10638 loopHasNoAbnormalExits(AddRec->getLoop())) {
10639
10640 // If the stride is zero, the loop must be infinite. In C++, most loops
10641 // are finite by assumption, in which case the step being zero implies
10642 // UB must execute if the loop is entered.
10643 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10644 return getCouldNotCompute();
10645
10646 const SCEV *Exact =
10647 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10648 const SCEV *ConstantMax = getCouldNotCompute();
10649 if (Exact != getCouldNotCompute()) {
10651 ConstantMax =
10653 }
10654 const SCEV *SymbolicMax =
10655 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10656 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10657 }
10658
10659 // Solve the general equation.
10660 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10661 if (!StepC || StepC->getValue()->isZero())
10662 return getCouldNotCompute();
10664 StepC->getAPInt(), getNegativeSCEV(Start),
10665 AllowPredicates ? &Predicates : nullptr, *this);
10666
10667 const SCEV *M = E;
10668 if (E != getCouldNotCompute()) {
10669 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10670 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10671 }
10672 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10673 return ExitLimit(E, M, S, false, Predicates);
10674}
10675
10677ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10678 // Loops that look like: while (X == 0) are very strange indeed. We don't
10679 // handle them yet except for the trivial case. This could be expanded in the
10680 // future as needed.
10681
10682 // If the value is a constant, check to see if it is known to be non-zero
10683 // already. If so, the backedge will execute zero times.
10684 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10685 if (!C->getValue()->isZero())
10686 return getZero(C->getType());
10687 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10688 }
10689
10690 // We could implement others, but I really doubt anyone writes loops like
10691 // this, and if they did, they would already be constant folded.
10692 return getCouldNotCompute();
10693}
10694
10695std::pair<const BasicBlock *, const BasicBlock *>
10696ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10697 const {
10698 // If the block has a unique predecessor, then there is no path from the
10699 // predecessor to the block that does not go through the direct edge
10700 // from the predecessor to the block.
10701 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10702 return {Pred, BB};
10703
10704 // A loop's header is defined to be a block that dominates the loop.
10705 // If the header has a unique predecessor outside the loop, it must be
10706 // a block that has exactly one successor that can reach the loop.
10707 if (const Loop *L = LI.getLoopFor(BB))
10708 return {L->getLoopPredecessor(), L->getHeader()};
10709
10710 return {nullptr, BB};
10711}
10712
10713/// SCEV structural equivalence is usually sufficient for testing whether two
10714/// expressions are equal, however for the purposes of looking for a condition
10715/// guarding a loop, it can be useful to be a little more general, since a
10716/// front-end may have replicated the controlling expression.
10717static bool HasSameValue(const SCEV *A, const SCEV *B) {
10718 // Quick check to see if they are the same SCEV.
10719 if (A == B) return true;
10720
10721 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10722 // Not all instructions that are "identical" compute the same value. For
10723 // instance, two distinct alloca instructions allocating the same type are
10724 // identical and do not read memory; but compute distinct values.
10725 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10726 };
10727
10728 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10729 // two different instructions with the same value. Check for this case.
10730 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10731 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10732 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10733 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10734 if (ComputesEqualValues(AI, BI))
10735 return true;
10736
10737 // Otherwise assume they may have a different value.
10738 return false;
10739}
10740
10741static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10742 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10743 if (!Add || Add->getNumOperands() != 2)
10744 return false;
10745 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10746 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10747 LHS = Add->getOperand(1);
10748 RHS = ME->getOperand(1);
10749 return true;
10750 }
10751 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10752 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10753 LHS = Add->getOperand(0);
10754 RHS = ME->getOperand(1);
10755 return true;
10756 }
10757 return false;
10758}
10759
10761 const SCEV *&LHS, const SCEV *&RHS,
10762 unsigned Depth) {
10763 bool Changed = false;
10764 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10765 // '0 != 0'.
10766 auto TrivialCase = [&](bool TriviallyTrue) {
10768 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10769 return true;
10770 };
10771 // If we hit the max recursion limit bail out.
10772 if (Depth >= 3)
10773 return false;
10774
10775 // Canonicalize a constant to the right side.
10776 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10777 // Check for both operands constant.
10778 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10779 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10780 return TrivialCase(false);
10781 return TrivialCase(true);
10782 }
10783 // Otherwise swap the operands to put the constant on the right.
10784 std::swap(LHS, RHS);
10785 Pred = ICmpInst::getSwappedPredicate(Pred);
10786 Changed = true;
10787 }
10788
10789 // If we're comparing an addrec with a value which is loop-invariant in the
10790 // addrec's loop, put the addrec on the left. Also make a dominance check,
10791 // as both operands could be addrecs loop-invariant in each other's loop.
10792 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10793 const Loop *L = AR->getLoop();
10794 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10795 std::swap(LHS, RHS);
10796 Pred = ICmpInst::getSwappedPredicate(Pred);
10797 Changed = true;
10798 }
10799 }
10800
10801 // If there's a constant operand, canonicalize comparisons with boundary
10802 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10803 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10804 const APInt &RA = RC->getAPInt();
10805
10806 bool SimplifiedByConstantRange = false;
10807
10808 if (!ICmpInst::isEquality(Pred)) {
10810 if (ExactCR.isFullSet())
10811 return TrivialCase(true);
10812 if (ExactCR.isEmptySet())
10813 return TrivialCase(false);
10814
10815 APInt NewRHS;
10816 CmpInst::Predicate NewPred;
10817 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10818 ICmpInst::isEquality(NewPred)) {
10819 // We were able to convert an inequality to an equality.
10820 Pred = NewPred;
10821 RHS = getConstant(NewRHS);
10822 Changed = SimplifiedByConstantRange = true;
10823 }
10824 }
10825
10826 if (!SimplifiedByConstantRange) {
10827 switch (Pred) {
10828 default:
10829 break;
10830 case ICmpInst::ICMP_EQ:
10831 case ICmpInst::ICMP_NE:
10832 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10833 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10834 Changed = true;
10835 break;
10836
10837 // The "Should have been caught earlier!" messages refer to the fact
10838 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10839 // should have fired on the corresponding cases, and canonicalized the
10840 // check to trivial case.
10841
10842 case ICmpInst::ICMP_UGE:
10843 assert(!RA.isMinValue() && "Should have been caught earlier!");
10844 Pred = ICmpInst::ICMP_UGT;
10845 RHS = getConstant(RA - 1);
10846 Changed = true;
10847 break;
10848 case ICmpInst::ICMP_ULE:
10849 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10850 Pred = ICmpInst::ICMP_ULT;
10851 RHS = getConstant(RA + 1);
10852 Changed = true;
10853 break;
10854 case ICmpInst::ICMP_SGE:
10855 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10856 Pred = ICmpInst::ICMP_SGT;
10857 RHS = getConstant(RA - 1);
10858 Changed = true;
10859 break;
10860 case ICmpInst::ICMP_SLE:
10861 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10862 Pred = ICmpInst::ICMP_SLT;
10863 RHS = getConstant(RA + 1);
10864 Changed = true;
10865 break;
10866 }
10867 }
10868 }
10869
10870 // Check for obvious equality.
10871 if (HasSameValue(LHS, RHS)) {
10872 if (ICmpInst::isTrueWhenEqual(Pred))
10873 return TrivialCase(true);
10875 return TrivialCase(false);
10876 }
10877
10878 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10879 // adding or subtracting 1 from one of the operands.
10880 switch (Pred) {
10881 case ICmpInst::ICMP_SLE:
10882 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10883 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10885 Pred = ICmpInst::ICMP_SLT;
10886 Changed = true;
10887 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10888 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10890 Pred = ICmpInst::ICMP_SLT;
10891 Changed = true;
10892 }
10893 break;
10894 case ICmpInst::ICMP_SGE:
10895 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10896 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10898 Pred = ICmpInst::ICMP_SGT;
10899 Changed = true;
10900 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10901 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10903 Pred = ICmpInst::ICMP_SGT;
10904 Changed = true;
10905 }
10906 break;
10907 case ICmpInst::ICMP_ULE:
10908 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10909 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10911 Pred = ICmpInst::ICMP_ULT;
10912 Changed = true;
10913 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10914 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10915 Pred = ICmpInst::ICMP_ULT;
10916 Changed = true;
10917 }
10918 break;
10919 case ICmpInst::ICMP_UGE:
10920 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10921 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10922 Pred = ICmpInst::ICMP_UGT;
10923 Changed = true;
10924 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10925 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10927 Pred = ICmpInst::ICMP_UGT;
10928 Changed = true;
10929 }
10930 break;
10931 default:
10932 break;
10933 }
10934
10935 // TODO: More simplifications are possible here.
10936
10937 // Recursively simplify until we either hit a recursion limit or nothing
10938 // changes.
10939 if (Changed)
10940 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10941
10942 return Changed;
10943}
10944
10946 return getSignedRangeMax(S).isNegative();
10947}
10948
10951}
10952
10954 return !getSignedRangeMin(S).isNegative();
10955}
10956
10959}
10960
10962 // Query push down for cases where the unsigned range is
10963 // less than sufficient.
10964 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10965 return isKnownNonZero(SExt->getOperand(0));
10966 return getUnsignedRangeMin(S) != 0;
10967}
10968
10970 bool OrNegative) {
10971 auto NonRecursive = [this, OrNegative](const SCEV *S) {
10972 if (auto *C = dyn_cast<SCEVConstant>(S))
10973 return C->getAPInt().isPowerOf2() ||
10974 (OrNegative && C->getAPInt().isNegatedPowerOf2());
10975
10976 // The vscale_range indicates vscale is a power-of-two.
10977 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10978 };
10979
10980 if (NonRecursive(S))
10981 return true;
10982
10983 auto *Mul = dyn_cast<SCEVMulExpr>(S);
10984 if (!Mul)
10985 return false;
10986 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10987}
10988
10989std::pair<const SCEV *, const SCEV *>
10991 // Compute SCEV on entry of loop L.
10992 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10993 if (Start == getCouldNotCompute())
10994 return { Start, Start };
10995 // Compute post increment SCEV for loop L.
10996 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10997 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10998 return { Start, PostInc };
10999}
11000
11002 const SCEV *LHS, const SCEV *RHS) {
11003 // First collect all loops.
11005 getUsedLoops(LHS, LoopsUsed);
11006 getUsedLoops(RHS, LoopsUsed);
11007
11008 if (LoopsUsed.empty())
11009 return false;
11010
11011 // Domination relationship must be a linear order on collected loops.
11012#ifndef NDEBUG
11013 for (const auto *L1 : LoopsUsed)
11014 for (const auto *L2 : LoopsUsed)
11015 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11016 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11017 "Domination relationship is not a linear order");
11018#endif
11019
11020 const Loop *MDL =
11021 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11022 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11023 });
11024
11025 // Get init and post increment value for LHS.
11026 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11027 // if LHS contains unknown non-invariant SCEV then bail out.
11028 if (SplitLHS.first == getCouldNotCompute())
11029 return false;
11030 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11031 // Get init and post increment value for RHS.
11032 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11033 // if RHS contains unknown non-invariant SCEV then bail out.
11034 if (SplitRHS.first == getCouldNotCompute())
11035 return false;
11036 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11037 // It is possible that init SCEV contains an invariant load but it does
11038 // not dominate MDL and is not available at MDL loop entry, so we should
11039 // check it here.
11040 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11041 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11042 return false;
11043
11044 // It seems backedge guard check is faster than entry one so in some cases
11045 // it can speed up whole estimation by short circuit
11046 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11047 SplitRHS.second) &&
11048 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11049}
11050
11052 const SCEV *LHS, const SCEV *RHS) {
11053 // Canonicalize the inputs first.
11054 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11055
11056 if (isKnownViaInduction(Pred, LHS, RHS))
11057 return true;
11058
11059 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11060 return true;
11061
11062 // Otherwise see what can be done with some simple reasoning.
11063 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11064}
11065
11067 const SCEV *LHS,
11068 const SCEV *RHS) {
11069 if (isKnownPredicate(Pred, LHS, RHS))
11070 return true;
11072 return false;
11073 return std::nullopt;
11074}
11075
11077 const SCEV *LHS, const SCEV *RHS,
11078 const Instruction *CtxI) {
11079 // TODO: Analyze guards and assumes from Context's block.
11080 return isKnownPredicate(Pred, LHS, RHS) ||
11082}
11083
11084std::optional<bool>
11086 const SCEV *RHS, const Instruction *CtxI) {
11087 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11088 if (KnownWithoutContext)
11089 return KnownWithoutContext;
11090
11091 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11092 return true;
11095 LHS, RHS))
11096 return false;
11097 return std::nullopt;
11098}
11099
11101 const SCEVAddRecExpr *LHS,
11102 const SCEV *RHS) {
11103 const Loop *L = LHS->getLoop();
11104 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11105 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11106}
11107
11108std::optional<ScalarEvolution::MonotonicPredicateType>
11110 ICmpInst::Predicate Pred) {
11111 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11112
11113#ifndef NDEBUG
11114 // Verify an invariant: inverting the predicate should turn a monotonically
11115 // increasing change to a monotonically decreasing one, and vice versa.
11116 if (Result) {
11117 auto ResultSwapped =
11118 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11119
11120 assert(*ResultSwapped != *Result &&
11121 "monotonicity should flip as we flip the predicate");
11122 }
11123#endif
11124
11125 return Result;
11126}
11127
11128std::optional<ScalarEvolution::MonotonicPredicateType>
11129ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11130 ICmpInst::Predicate Pred) {
11131 // A zero step value for LHS means the induction variable is essentially a
11132 // loop invariant value. We don't really depend on the predicate actually
11133 // flipping from false to true (for increasing predicates, and the other way
11134 // around for decreasing predicates), all we care about is that *if* the
11135 // predicate changes then it only changes from false to true.
11136 //
11137 // A zero step value in itself is not very useful, but there may be places
11138 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11139 // as general as possible.
11140
11141 // Only handle LE/LT/GE/GT predicates.
11142 if (!ICmpInst::isRelational(Pred))
11143 return std::nullopt;
11144
11145 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11146 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11147 "Should be greater or less!");
11148
11149 // Check that AR does not wrap.
11150 if (ICmpInst::isUnsigned(Pred)) {
11151 if (!LHS->hasNoUnsignedWrap())
11152 return std::nullopt;
11154 }
11155 assert(ICmpInst::isSigned(Pred) &&
11156 "Relational predicate is either signed or unsigned!");
11157 if (!LHS->hasNoSignedWrap())
11158 return std::nullopt;
11159
11160 const SCEV *Step = LHS->getStepRecurrence(*this);
11161
11162 if (isKnownNonNegative(Step))
11164
11165 if (isKnownNonPositive(Step))
11167
11168 return std::nullopt;
11169}
11170
11171std::optional<ScalarEvolution::LoopInvariantPredicate>
11173 const SCEV *LHS, const SCEV *RHS,
11174 const Loop *L,
11175 const Instruction *CtxI) {
11176 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11177 if (!isLoopInvariant(RHS, L)) {
11178 if (!isLoopInvariant(LHS, L))
11179 return std::nullopt;
11180
11181 std::swap(LHS, RHS);
11182 Pred = ICmpInst::getSwappedPredicate(Pred);
11183 }
11184
11185 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11186 if (!ArLHS || ArLHS->getLoop() != L)
11187 return std::nullopt;
11188
11189 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11190 if (!MonotonicType)
11191 return std::nullopt;
11192 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11193 // true as the loop iterates, and the backedge is control dependent on
11194 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11195 //
11196 // * if the predicate was false in the first iteration then the predicate
11197 // is never evaluated again, since the loop exits without taking the
11198 // backedge.
11199 // * if the predicate was true in the first iteration then it will
11200 // continue to be true for all future iterations since it is
11201 // monotonically increasing.
11202 //
11203 // For both the above possibilities, we can replace the loop varying
11204 // predicate with its value on the first iteration of the loop (which is
11205 // loop invariant).
11206 //
11207 // A similar reasoning applies for a monotonically decreasing predicate, by
11208 // replacing true with false and false with true in the above two bullets.
11210 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11211
11214 RHS);
11215
11216 if (!CtxI)
11217 return std::nullopt;
11218 // Try to prove via context.
11219 // TODO: Support other cases.
11220 switch (Pred) {
11221 default:
11222 break;
11223 case ICmpInst::ICMP_ULE:
11224 case ICmpInst::ICMP_ULT: {
11225 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11226 // Given preconditions
11227 // (1) ArLHS does not cross the border of positive and negative parts of
11228 // range because of:
11229 // - Positive step; (TODO: lift this limitation)
11230 // - nuw - does not cross zero boundary;
11231 // - nsw - does not cross SINT_MAX boundary;
11232 // (2) ArLHS <s RHS
11233 // (3) RHS >=s 0
11234 // we can replace the loop variant ArLHS <u RHS condition with loop
11235 // invariant Start(ArLHS) <u RHS.
11236 //
11237 // Because of (1) there are two options:
11238 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11239 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11240 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11241 // Because of (2) ArLHS <u RHS is trivially true.
11242 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11243 // We can strengthen this to Start(ArLHS) <u RHS.
11244 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11245 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11246 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11248 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11250 RHS);
11251 }
11252 }
11253
11254 return std::nullopt;
11255}
11256
11257std::optional<ScalarEvolution::LoopInvariantPredicate>
11259 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11260 const Instruction *CtxI, const SCEV *MaxIter) {
11262 Pred, LHS, RHS, L, CtxI, MaxIter))
11263 return LIP;
11264 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11265 // Number of iterations expressed as UMIN isn't always great for expressing
11266 // the value on the last iteration. If the straightforward approach didn't
11267 // work, try the following trick: if the a predicate is invariant for X, it
11268 // is also invariant for umin(X, ...). So try to find something that works
11269 // among subexpressions of MaxIter expressed as umin.
11270 for (auto *Op : UMin->operands())
11272 Pred, LHS, RHS, L, CtxI, Op))
11273 return LIP;
11274 return std::nullopt;
11275}
11276
11277std::optional<ScalarEvolution::LoopInvariantPredicate>
11279 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11280 const Instruction *CtxI, const SCEV *MaxIter) {
11281 // Try to prove the following set of facts:
11282 // - The predicate is monotonic in the iteration space.
11283 // - If the check does not fail on the 1st iteration:
11284 // - No overflow will happen during first MaxIter iterations;
11285 // - It will not fail on the MaxIter'th iteration.
11286 // If the check does fail on the 1st iteration, we leave the loop and no
11287 // other checks matter.
11288
11289 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11290 if (!isLoopInvariant(RHS, L)) {
11291 if (!isLoopInvariant(LHS, L))
11292 return std::nullopt;
11293
11294 std::swap(LHS, RHS);
11295 Pred = ICmpInst::getSwappedPredicate(Pred);
11296 }
11297
11298 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11299 if (!AR || AR->getLoop() != L)
11300 return std::nullopt;
11301
11302 // The predicate must be relational (i.e. <, <=, >=, >).
11303 if (!ICmpInst::isRelational(Pred))
11304 return std::nullopt;
11305
11306 // TODO: Support steps other than +/- 1.
11307 const SCEV *Step = AR->getStepRecurrence(*this);
11308 auto *One = getOne(Step->getType());
11309 auto *MinusOne = getNegativeSCEV(One);
11310 if (Step != One && Step != MinusOne)
11311 return std::nullopt;
11312
11313 // Type mismatch here means that MaxIter is potentially larger than max
11314 // unsigned value in start type, which mean we cannot prove no wrap for the
11315 // indvar.
11316 if (AR->getType() != MaxIter->getType())
11317 return std::nullopt;
11318
11319 // Value of IV on suggested last iteration.
11320 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11321 // Does it still meet the requirement?
11322 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11323 return std::nullopt;
11324 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11325 // not exceed max unsigned value of this type), this effectively proves
11326 // that there is no wrap during the iteration. To prove that there is no
11327 // signed/unsigned wrap, we need to check that
11328 // Start <= Last for step = 1 or Start >= Last for step = -1.
11329 ICmpInst::Predicate NoOverflowPred =
11331 if (Step == MinusOne)
11332 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11333 const SCEV *Start = AR->getStart();
11334 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11335 return std::nullopt;
11336
11337 // Everything is fine.
11338 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11339}
11340
11341bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11342 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11343 if (HasSameValue(LHS, RHS))
11344 return ICmpInst::isTrueWhenEqual(Pred);
11345
11346 // This code is split out from isKnownPredicate because it is called from
11347 // within isLoopEntryGuardedByCond.
11348
11349 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11350 const ConstantRange &RangeRHS) {
11351 return RangeLHS.icmp(Pred, RangeRHS);
11352 };
11353
11354 // The check at the top of the function catches the case where the values are
11355 // known to be equal.
11356 if (Pred == CmpInst::ICMP_EQ)
11357 return false;
11358
11359 if (Pred == CmpInst::ICMP_NE) {
11360 auto SL = getSignedRange(LHS);
11361 auto SR = getSignedRange(RHS);
11362 if (CheckRanges(SL, SR))
11363 return true;
11364 auto UL = getUnsignedRange(LHS);
11365 auto UR = getUnsignedRange(RHS);
11366 if (CheckRanges(UL, UR))
11367 return true;
11368 auto *Diff = getMinusSCEV(LHS, RHS);
11369 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11370 }
11371
11372 if (CmpInst::isSigned(Pred)) {
11373 auto SL = getSignedRange(LHS);
11374 auto SR = getSignedRange(RHS);
11375 return CheckRanges(SL, SR);
11376 }
11377
11378 auto UL = getUnsignedRange(LHS);
11379 auto UR = getUnsignedRange(RHS);
11380 return CheckRanges(UL, UR);
11381}
11382
11383bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11384 const SCEV *LHS,
11385 const SCEV *RHS) {
11386 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11387 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11388 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11389 // OutC1 and OutC2.
11390 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11391 APInt &OutC1, APInt &OutC2,
11392 SCEV::NoWrapFlags ExpectedFlags) {
11393 const SCEV *XNonConstOp, *XConstOp;
11394 const SCEV *YNonConstOp, *YConstOp;
11395 SCEV::NoWrapFlags XFlagsPresent;
11396 SCEV::NoWrapFlags YFlagsPresent;
11397
11398 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11399 XConstOp = getZero(X->getType());
11400 XNonConstOp = X;
11401 XFlagsPresent = ExpectedFlags;
11402 }
11403 if (!isa<SCEVConstant>(XConstOp) ||
11404 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11405 return false;
11406
11407 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11408 YConstOp = getZero(Y->getType());
11409 YNonConstOp = Y;
11410 YFlagsPresent = ExpectedFlags;
11411 }
11412
11413 if (!isa<SCEVConstant>(YConstOp) ||
11414 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11415 return false;
11416
11417 if (YNonConstOp != XNonConstOp)
11418 return false;
11419
11420 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11421 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11422
11423 return true;
11424 };
11425
11426 APInt C1;
11427 APInt C2;
11428
11429 switch (Pred) {
11430 default:
11431 break;
11432
11433 case ICmpInst::ICMP_SGE:
11434 std::swap(LHS, RHS);
11435 [[fallthrough]];
11436 case ICmpInst::ICMP_SLE:
11437 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11438 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11439 return true;
11440
11441 break;
11442
11443 case ICmpInst::ICMP_SGT:
11444 std::swap(LHS, RHS);
11445 [[fallthrough]];
11446 case ICmpInst::ICMP_SLT:
11447 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11448 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11449 return true;
11450
11451 break;
11452
11453 case ICmpInst::ICMP_UGE:
11454 std::swap(LHS, RHS);
11455 [[fallthrough]];
11456 case ICmpInst::ICMP_ULE:
11457 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11458 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11459 return true;
11460
11461 break;
11462
11463 case ICmpInst::ICMP_UGT:
11464 std::swap(LHS, RHS);
11465 [[fallthrough]];
11466 case ICmpInst::ICMP_ULT:
11467 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11468 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11469 return true;
11470 break;
11471 }
11472
11473 return false;
11474}
11475
/// Try to prove (LHS u< RHS) by splitting it into signed facts, using the
/// identity documented below, which holds once RHS is known non-negative.
/// Only ICMP_ULT is handled. The ProvingSplitPredicate flag prevents nested
/// activations of this function.
bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
                                                   const SCEV *LHS,
                                                   const SCEV *RHS) {
  // Bail out for any other predicate, or if we are already inside a split
  // proof further up the stack.
  if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
    return false;

  // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
  // the stack can result in exponential time complexity.
  // The flag is restored to its previous value when this function returns.
  SaveAndRestore Restore(ProvingSplitPredicate, true);

  // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
  //
  // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
  // isKnownPredicate. isKnownPredicate is more powerful, but also more
  // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
  // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
  // use isKnownPredicate later if needed.
  return isKnownNonNegative(RHS) &&
}
11497
/// Returns true if some @llvm.experimental.guard call in \p BB has a
/// condition that implies (LHS Pred RHS). Returns false immediately when the
/// HasGuards flag is not set.
bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
                                        const SCEV *LHS, const SCEV *RHS) {
  // No need to even try if we know the module has no guards.
  if (!HasGuards)
    return false;

  // Scan every instruction of BB for a guard whose (non-inverted) condition
  // implies the predicate.
  return any_of(*BB, [&](const Instruction &I) {
    using namespace llvm::PatternMatch;

    Value *Condition;
    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
                         m_Value(Condition))) &&
           isImpliedCond(Pred, LHS, RHS, Condition, false);
  });
}
11514
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// eliminate casts.
///
/// Returns true trivially for a null loop or for a loop whose header is
/// unreachable from the entry block. Otherwise the following are tried, in
/// order:
///   - non-recursive reasoning;
///   - the latch's conditional branch;
///   - the latch's exact backedge-taken count;
///   - @llvm.assume calls dominating the latch terminator;
///   - guards in the latch;
///   - guards and single-edge branch conditions on the blocks that dominate
///     the latch, walking up to the header.
/// Returns false if the loop has no unique latch and non-recursive reasoning
/// failed.
bool
                              const SCEV *LHS, const SCEV *RHS) {
  // Interpret a null as meaning no loop, where there is obviously no guard
  // (interprocedural conditions notwithstanding). Do not bother about
  // unreachable loops.
  if (!L || !DT.isReachableFromEntry(L->getHeader()))
    return true;

  if (VerifyIR)
    assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
           "This cannot be done on broken IR!");


  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
    return true;

  // The remaining reasoning is anchored on a single latch block.
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return false;

  // If the latch ends in a conditional branch, its condition controls the
  // backedge. When successor 0 is not the header, the backedge is taken on
  // the false edge, so the condition is used inverted.
  BranchInst *LoopContinuePredicate =
      dyn_cast<BranchInst>(Latch->getTerminator());
  if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
      isImpliedCond(Pred, LHS, RHS,
                    LoopContinuePredicate->getCondition(),
                    LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
    return true;

  // We don't want more than one activation of the following loops on the stack
  // -- that can lead to O(n!) time complexity.
  if (WalkingBEDominatingConds)
    return false;

  // The flag is restored to its previous value when this function returns.
  SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);

  // See if we can exploit a trip count to prove the predicate.
  const auto &BETakenInfo = getBackedgeTakenInfo(L);
  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
  if (LatchBECount != getCouldNotCompute()) {
    // We know that Latch branches back to the loop header exactly
    // LatchBECount times. This means the backedge condition at Latch is
    // equivalent to "{0,+,1} u< LatchBECount".
    Type *Ty = LatchBECount->getType();
    auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
    const SCEV *LoopCounter =
        getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
    if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
                      LatchBECount))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  // Only assumes that dominate the latch terminator are usable.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, Latch->getTerminator()))
      continue;

    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
      return true;
  }

  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
    return true;

  // Walk the dominator tree from the latch up to (but excluding) the header.
  for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
       DTN != HeaderDTN; DTN = DTN->getIDom()) {
    assert(DTN && "should reach the loop header before reaching the root!");

    BasicBlock *BB = DTN->getBlock();
    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
      return true;

    // Only blocks with a single predecessor ending in a conditional branch
    // contribute a branch condition.
    BasicBlock *PBB = BB->getSinglePredecessor();
    if (!PBB)
      continue;

    BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
    if (!ContinuePredicate || !ContinuePredicate->isConditional())
      continue;

    Value *Condition = ContinuePredicate->getCondition();

    // If we have an edge `E` within the loop body that dominates the only
    // latch, the condition guarding `E` also guards the backedge. This
    // reasoning works only for loops with a single latch.

    BasicBlockEdge DominatingEdge(PBB, BB);
    if (DominatingEdge.isSingleEdge()) {
      // We're constructively (and conservatively) enumerating edges within the
      // loop body that dominate the latch. The dominator tree better agree
      // with us on this:
      assert(DT.dominates(DominatingEdge, Latch) && "should be!");

      // BB is reached on the false edge when it is not successor 0, in which
      // case the condition is used inverted.
      if (isImpliedCond(Pred, LHS, RHS, Condition,
                        BB != ContinuePredicate->getSuccessor(0)))
        return true;
    }
  }

  return false;
}
11623
const SCEV *LHS,
const SCEV *RHS) {
  // Decide whether (LHS Pred RHS) is known on entry to BB. The evidence used
  // is:
  //   - non-recursive reasoning;
  //   - conditional branches along the chain of predecessors with unique
  //     successors;
  //   - @llvm.assume calls dominating BB;
  //   - @llvm.experimental.guard calls in this function dominating BB.
  // Returns true for blocks unreachable from the entry block.

  // Do not bother proving facts for unreachable code.
  if (!DT.isReachableFromEntry(BB))
    return true;
  if (VerifyIR)
    assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
           "This cannot be done on broken IR!");

  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
  // the facts (a >= b && a != b) separately. A typical situation is when the
  // non-strict comparison is known from ranges and non-equality is known from
  // dominating predicates. If we are proving strict comparison, we always try
  // to prove non-equality and non-strict comparison separately.
  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
  const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
  // These two flags are shared by every SplitAndProve call below. The halves
  // proven from different sources (ranges, distinct conditions) therefore
  // accumulate until both are established.
  bool ProvedNonStrictComparison = false;
  bool ProvedNonEquality = false;

  auto SplitAndProve =
    [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
    if (!ProvedNonStrictComparison)
      ProvedNonStrictComparison = Fn(NonStrictPredicate);
    if (!ProvedNonEquality)
      ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
    if (ProvedNonStrictComparison && ProvedNonEquality)
      return true;
    return false;
  };

  if (ProvingStrictComparison) {
    auto ProofFn = [&](ICmpInst::Predicate P) {
      return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
    };
    if (SplitAndProve(ProofFn))
      return true;
  }

  // Try to prove (Pred, LHS, RHS) using isImpliedCond.
  // The first instruction of BB serves as the context instruction.
  auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
    const Instruction *CtxI = &BB->front();
    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
      return true;
    if (ProvingStrictComparison) {
      auto ProofFn = [&](ICmpInst::Predicate P) {
        return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
      };
      if (SplitAndProve(ProofFn))
        return true;
    }
    return false;
  };

  // Starting at the block's predecessor, climb up the predecessor chain, as long
  // as there are predecessors that can be found that have unique successors
  // leading to the original block.
  // For a loop header, the walk starts from the loop predecessor rather than
  // from a single predecessor, since the latch also branches to the header.
  const Loop *ContainingLoop = LI.getLoopFor(BB);
  const BasicBlock *PredBB;
  if (ContainingLoop && ContainingLoop->getHeader() == BB)
    PredBB = ContainingLoop->getLoopPredecessor();
  else
    PredBB = BB->getSinglePredecessor();
  for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
    const BranchInst *BlockEntryPredicate =
        dyn_cast<BranchInst>(Pair.first->getTerminator());
    if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
      continue;

    // Pair.second is reached on the false edge when it is not successor 0.
    if (ProveViaCond(BlockEntryPredicate->getCondition(),
                     BlockEntryPredicate->getSuccessor(0) != Pair.second))
      return true;
  }

  // Check conditions due to any @llvm.assume intrinsics.
  for (auto &AssumeVH : AC.assumptions()) {
    if (!AssumeVH)
      continue;
    auto *CI = cast<CallInst>(AssumeVH);
    if (!DT.dominates(CI, BB))
      continue;

    if (ProveViaCond(CI->getArgOperand(0), false))
      return true;
  }

  // Check conditions due to any @llvm.experimental.guard intrinsics.
  // Guards are found through the users of the intrinsic declaration. Only
  // guards in BB's function that dominate BB are considered.
  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      F.getParent(), Intrinsic::experimental_guard);
  if (GuardDecl)
    for (const auto *GU : GuardDecl->users())
      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
        if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
          if (ProveViaCond(Guard->getArgOperand(0), false))
            return true;
  return false;
}
11723
const SCEV *LHS,
const SCEV *RHS) {
  // Decide whether (LHS Pred RHS) is known on entry to loop L. Returns false
  // for a null loop. Otherwise tries non-recursive reasoning first, then the
  // conditions guarding entry to the loop header.

  // Interpret a null as meaning no loop, where there is obviously no guard
  // (interprocedural conditions notwithstanding).
  if (!L)
    return false;

  // Both LHS and RHS must be available at loop entry.
         "LHS is not available at Loop Entry");
         "RHS is not available at Loop Entry");

  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
    return true;

  return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
}
11744
11745bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11746 const SCEV *RHS,
11747 const Value *FoundCondValue, bool Inverse,
11748 const Instruction *CtxI) {
11749 // False conditions implies anything. Do not bother analyzing it further.
11750 if (FoundCondValue ==
11751 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11752 return true;
11753
11754 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11755 return false;
11756
11757 auto ClearOnExit =
11758 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11759
11760 // Recursively handle And and Or conditions.
11761 const Value *Op0, *Op1;
11762 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11763 if (!Inverse)
11764 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11765 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11766 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11767 if (Inverse)
11768 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11769 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11770 }
11771
11772 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11773 if (!ICI) return false;
11774
11775 // Now that we found a conditional branch that dominates the loop or controls
11776 // the loop latch. Check to see if it is the comparison we are looking for.
11777 ICmpInst::Predicate FoundPred;
11778 if (Inverse)
11779 FoundPred = ICI->getInversePredicate();
11780 else
11781 FoundPred = ICI->getPredicate();
11782
11783 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11784 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11785
11786 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11787}
11788
/// Returns true if (FoundLHS FoundPred FoundRHS) implies (LHS Pred RHS).
///
/// The two comparisons may use integer types of different widths. The
/// narrower side is extended (sign- or zero-, according to its own
/// predicate) so that both sides are balanced before delegating to
/// isImpliedCondBalancedTypes. When the found comparison is the wider one,
/// and it is unsigned or an equality, a narrowed (truncated) attempt is made
/// first. Pointer-typed operands that would need extending make the
/// function return false.
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                    const SCEV *RHS,
                                    ICmpInst::Predicate FoundPred,
                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
                                    const Instruction *CtxI) {
  // Balance the types.
  if (getTypeSizeInBits(LHS->getType()) <
      getTypeSizeInBits(FoundLHS->getType())) {
    // For unsigned and equality predicates, try to prove that both found
    // operands fit into narrow unsigned range. If so, try to prove facts in
    // narrow types.
    if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy()) {
      auto *NarrowType = LHS->getType();
      auto *WideType = FoundLHS->getType();
      auto BitWidth = getTypeSizeInBits(NarrowType);
      const SCEV *MaxValue = getZeroExtendExpr(
      // Truncation is only attempted when both found operands are known to
      // be u<= MaxValue.
      if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
                                          MaxValue) &&
          isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
                                          MaxValue)) {
        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
                                       TruncFoundRHS, CtxI))
          return true;
      }
    }

    // Otherwise widen the query side to the found side's type. Pointers
    // cannot be extended, so give up on them.
    if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(Pred)) {
      LHS = getSignExtendExpr(LHS, FoundLHS->getType());
      RHS = getSignExtendExpr(RHS, FoundLHS->getType());
    } else {
      LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
      RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
    }
  } else if (getTypeSizeInBits(LHS->getType()) >
             getTypeSizeInBits(FoundLHS->getType())) {
    // The found comparison is narrower: widen it to the query's type,
    // extending according to the signedness of FoundPred.
    if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
      return false;
    if (CmpInst::isSigned(FoundPred)) {
      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
    } else {
      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
    }
  }
  return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
                                    FoundRHS, CtxI);
}
11843
/// Returns true if (FoundLHS FoundPred FoundRHS) implies (LHS Pred RHS),
/// where both comparisons operate on types of equal bit width.
///
/// Both comparisons are first simplified, then matched against each other.
/// The strategies tried are:
///   - identical predicates;
///   - swapped predicates;
///   - predicates differing only in signedness;
///   - range sharpening from an ICMP_NE against a constant;
///   - EQ/NE special cases;
///   - finally, constant ranges.
bool ScalarEvolution::isImpliedCondBalancedTypes(
    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
    ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
    const Instruction *CtxI) {
         getTypeSizeInBits(FoundLHS->getType()) &&
         "Types should be balanced!");
  // Canonicalize the query to match the way instcombine will have
  // canonicalized the comparison.
  // If simplification makes both query operands identical, the answer
  // depends on Pred alone.
  if (SimplifyICmpOperands(Pred, LHS, RHS))
    if (LHS == RHS)
      return CmpInst::isTrueWhenEqual(Pred);
  // A found condition comparing a value with itself is either always false
  // (and so implies anything) or always true (and so implies nothing).
  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
    if (FoundLHS == FoundRHS)
      return CmpInst::isFalseWhenEqual(FoundPred);

  // Check to see if we can make the LHS or RHS match.
  // Prefer swapping the found condition when RHS is a constant, so the
  // query keeps its constant on the right.
  if (LHS == FoundRHS || RHS == FoundLHS) {
    if (isa<SCEVConstant>(RHS)) {
      std::swap(FoundLHS, FoundRHS);
      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
    } else {
      std::swap(LHS, RHS);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    }
  }

  // Check whether the found predicate is the same as the desired predicate.
  if (FoundPred == Pred)
    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);

  // Check whether swapping the found predicate makes it the same as the
  // desired predicate.
  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
    // We can write the implication
    // 0. LHS Pred      RHS  <-   FoundLHS SwapPred  FoundRHS
    // using one of the following ways:
    // 1. LHS Pred      RHS  <-   FoundRHS Pred      FoundLHS
    // 2. RHS SwapPred  LHS  <-   FoundLHS SwapPred  FoundRHS
    // 3. LHS Pred      RHS  <-  ~FoundLHS Pred     ~FoundRHS
    // 4. ~LHS SwapPred ~RHS <-   FoundLHS SwapPred  FoundRHS
    // Forms 1. and 2. require swapping the operands of one condition. Don't
    // do this if it would break canonical constant/addrec ordering.
    if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
      return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
                                   CtxI);
    if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);

    // There's no clear preference between forms 3. and 4., try both.  Avoid
    // forming getNotSCEV of pointer values as the resulting subtract is
    // not legal.
    if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
        isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
                              FoundLHS, FoundRHS, CtxI))
      return true;

    if (!FoundLHS->getType()->isPointerTy() &&
        !FoundRHS->getType()->isPointerTy() &&
        isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
                              getNotSCEV(FoundRHS), CtxI))
      return true;

    return false;
  }

  // True when P1 and P2 are relational predicates that differ only in
  // signedness (e.g. slt vs ult).
  auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
                                   CmpInst::Predicate P2) {
    assert(P1 != P2 && "Handled earlier!");
    return CmpInst::isRelational(P2) &&
  };
  if (IsSignFlippedPredicate(Pred, FoundPred)) {
    // Unsigned comparison is the same as signed comparison when both the
    // operands are non-negative or negative.
    if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
        (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
      return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
    // Create local copies that we can freely swap and canonicalize our
    // conditions to "le/lt".
    ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
    const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
               *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
    if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
      CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
      CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
      std::swap(CanonicalLHS, CanonicalRHS);
      std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
    }
    assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
           "Must be!");
    assert((ICmpInst::isLT(CanonicalFoundPred) ||
            ICmpInst::isLE(CanonicalFoundPred)) &&
           "Must be!");
    // NOTE(review): unlike the other isImpliedCondOperands calls in this
    // function, the two calls below do not pass CtxI.
    if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
      // Use implication:
      // x <u y && y >=s 0 --> x <s y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
    if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
      // Use implication:
      // x <s y && y <s 0 --> x <u y.
      // If we can prove the left part, the right part is also proven.
      return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
                                   CanonicalRHS, CanonicalFoundLHS,
                                   CanonicalFoundRHS);
  }

  // Check if we can make progress by sharpening ranges.
  if (FoundPred == ICmpInst::ICMP_NE &&
      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {

    const SCEVConstant *C = nullptr;
    const SCEV *V = nullptr;

    // Split the found "!=" into the constant C and the other operand V.
    if (isa<SCEVConstant>(FoundLHS)) {
      C = cast<SCEVConstant>(FoundLHS);
      V = FoundRHS;
    } else {
      C = cast<SCEVConstant>(FoundRHS);
      V = FoundLHS;
    }

    // The guarding predicate tells us that C != V. If the known range
    // of V is [C, t), we can sharpen the range to [C + 1, t).  The
    // range we consider has to correspond to same signedness as the
    // predicate we're interested in folding.

    APInt Min = ICmpInst::isSigned(Pred) ?

    if (Min == C->getAPInt()) {
      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
      // This is true even if (Min + 1) wraps around -- in case of
      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).

      APInt SharperMin = Min + 1;

      switch (Pred) {
        case ICmpInst::ICMP_SGE:
        case ICmpInst::ICMP_UGE:
          // We know V `Pred` SharperMin.  If this implies LHS `Pred`
          // RHS, we're done.
          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
                                    CtxI))
            return true;
          [[fallthrough]];

        case ICmpInst::ICMP_SGT:
        case ICmpInst::ICMP_UGT:
          // We know from the range information that (V `Pred` Min ||
          // V == Min).  We know from the guarding condition that !(V
          // == Min).  This gives us
          //
          //       V `Pred` Min || V == Min && !(V == Min)
          //   =>  V `Pred` Min
          //
          // If V `Pred` Min implies LHS `Pred` RHS, we're done.

          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
            return true;
          break;

        // `LHS < RHS` and `LHS <= RHS` are handled in the same way as
        // `RHS > LHS` and `RHS >= LHS` respectively.
        case ICmpInst::ICMP_SLE:
        case ICmpInst::ICMP_ULE:
          if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                    LHS, V, getConstant(SharperMin), CtxI))
            return true;
          [[fallthrough]];

        case ICmpInst::ICMP_SLT:
        case ICmpInst::ICMP_ULT:
          if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
                                    LHS, V, getConstant(Min), CtxI))
            return true;
          break;

        default:
          // No change
          break;
      }
    }
  }

  // Check whether the actual condition is beyond sufficient.
  // FoundLHS == FoundRHS proves any predicate that holds for equal values.
  if (FoundPred == ICmpInst::ICMP_EQ)
    if (ICmpInst::isTrueWhenEqual(Pred))
      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;
  // A found predicate that is false for equal values proves inequality.
  if (Pred == ICmpInst::ICMP_NE)
    if (!ICmpInst::isTrueWhenEqual(FoundPred))
      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
        return true;

  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
    return true;

  // Otherwise assume the worst.
  return false;
}
12047
12048bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12049 const SCEV *&L, const SCEV *&R,
12050 SCEV::NoWrapFlags &Flags) {
12051 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12052 if (!AE || AE->getNumOperands() != 2)
12053 return false;
12054
12055 L = AE->getOperand(0);
12056 R = AE->getOperand(1);
12057 Flags = AE->getNoWrapFlags();
12058 return true;
12059}
12060
/// Try to compute (More - Less) as a constant APInt of More's bit width,
/// without building a subtraction SCEV.
///
/// At most 8 rewriting rounds are attempted. Each round does one of:
///   - reduce two affine addrecs of the same loop with identical steps to
///     their start values;
///   - peel off a common constant multiplier;
///   - cancel matching operands of (add) expressions, accumulating constant
///     operands into the result.
/// Returns std::nullopt if the difference does not reduce to a constant
/// within that budget.
std::optional<APInt>
  // We avoid subtracting expressions here because this function is usually
  // fairly deep in the call stack (i.e. is called many times).

  unsigned BW = getTypeSizeInBits(More->getType());
  // Diff accumulates the constant part of the difference found so far.
  // DiffMul is the product of the common multipliers peeled off so far; it
  // scales constants discovered afterwards.
  APInt Diff(BW, 0);
  APInt DiffMul(BW, 1);
  // Try various simplifications to reduce the difference to a constant. Limit
  // the number of allowed simplifications to keep compile-time low.
  for (unsigned I = 0; I < 8; ++I) {
    if (More == Less)
      return Diff;

    // Reduce addrecs with identical steps to their start value.
    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
      const auto *LAR = cast<SCEVAddRecExpr>(Less);
      const auto *MAR = cast<SCEVAddRecExpr>(More);

      if (LAR->getLoop() != MAR->getLoop())
        return std::nullopt;

      // We look at affine expressions only; not for correctness but to keep
      // getStepRecurrence cheap.
      if (!LAR->isAffine() || !MAR->isAffine())
        return std::nullopt;

      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
        return std::nullopt;

      Less = LAR->getStart();
      More = MAR->getStart();
      continue;
    }

    // Try to match a common constant multiply.
    // Matches (C * X) with exactly two operands and returns {X, C}.
    auto MatchConstMul =
        [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
      auto *M = dyn_cast<SCEVMulExpr>(S);
      if (!M || M->getNumOperands() != 2 ||
          !isa<SCEVConstant>(M->getOperand(0)))
        return std::nullopt;
      return {
          {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
    };
    if (auto MatchedMore = MatchConstMul(More)) {
      if (auto MatchedLess = MatchConstMul(Less)) {
        if (MatchedMore->second == MatchedLess->second) {
          More = MatchedMore->first;
          Less = MatchedLess->first;
          DiffMul *= MatchedMore->second;
          continue;
        }
      }
    }

    // Try to cancel out common factors in two add expressions.
    // Constants are folded into Diff (scaled by DiffMul). Other operands are
    // counted in Multiplicity: +1 for each occurrence in More, -1 for each
    // occurrence in Less.
    auto Add = [&](const SCEV *S, int Mul) {
      if (auto *C = dyn_cast<SCEVConstant>(S)) {
        if (Mul == 1) {
          Diff += C->getAPInt() * DiffMul;
        } else {
          assert(Mul == -1);
          Diff -= C->getAPInt() * DiffMul;
        }
      } else
        Multiplicity[S] += Mul;
    };
    auto Decompose = [&](const SCEV *S, int Mul) {
      if (isa<SCEVAddExpr>(S)) {
        for (const SCEV *Op : S->operands())
          Add(Op, Mul);
      } else
        Add(S, Mul);
    };
    Decompose(More, 1);
    Decompose(Less, -1);

    // Check whether all the non-constants cancel out, or reduce to new
    // More/Less values.
    // At most one leftover term may remain on each side, each with
    // multiplicity exactly one.
    const SCEV *NewMore = nullptr, *NewLess = nullptr;
    for (const auto &[S, Mul] : Multiplicity) {
      if (Mul == 0)
        continue;
      if (Mul == 1) {
        if (NewMore)
          return std::nullopt;
        NewMore = S;
      } else if (Mul == -1) {
        if (NewLess)
          return std::nullopt;
        NewLess = S;
      } else
        return std::nullopt;
    }

    // Values stayed the same, no point in trying further.
    if (NewMore == More || NewLess == Less)
      return std::nullopt;

    More = NewMore;
    Less = NewLess;

    // Reduced to constant.
    if (!More && !Less)
      return Diff;

    // Left with variable on only one side, bail out.
    if (!More || !Less)
      return std::nullopt;
  }

  // Did not reduce to constant.
  return std::nullopt;
}
12177
12178bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12179 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12180 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12181 // Try to recognize the following pattern:
12182 //
12183 // FoundRHS = ...
12184 // ...
12185 // loop:
12186 // FoundLHS = {Start,+,W}
12187 // context_bb: // Basic block from the same loop
12188 // known(Pred, FoundLHS, FoundRHS)
12189 //
12190 // If some predicate is known in the context of a loop, it is also known on
12191 // each iteration of this loop, including the first iteration. Therefore, in
12192 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12193 // prove the original pred using this fact.
12194 if (!CtxI)
12195 return false;
12196 const BasicBlock *ContextBB = CtxI->getParent();
12197 // Make sure AR varies in the context block.
12198 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12199 const Loop *L = AR->getLoop();
12200 // Make sure that context belongs to the loop and executes on 1st iteration
12201 // (if it ever executes at all).
12202 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12203 return false;
12204 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12205 return false;
12206 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12207 }
12208
12209 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12210 const Loop *L = AR->getLoop();
12211 // Make sure that context belongs to the loop and executes on 1st iteration
12212 // (if it ever executes at all).
12213 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12214 return false;
12215 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12216 return false;
12217 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12218 }
12219
12220 return false;
12221}
12222
12223bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
12224 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
12225 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12226 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12227 return false;
12228
12229 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12230 if (!AddRecLHS)
12231 return false;
12232
12233 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12234 if (!AddRecFoundLHS)
12235 return false;
12236
12237 // We'd like to let SCEV reason about control dependencies, so we constrain
12238 // both the inequalities to be about add recurrences on the same loop. This
12239 // way we can use isLoopEntryGuardedByCond later.
12240
12241 const Loop *L = AddRecFoundLHS->getLoop();
12242 if (L != AddRecLHS->getLoop())
12243 return false;
12244
12245 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12246 //
12247 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12248 // ... (2)
12249 //
12250 // Informal proof for (2), assuming (1) [*]:
12251 //
12252 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12253 //
12254 // Then
12255 //
12256 // FoundLHS s< FoundRHS s< INT_MIN - C
12257 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12258 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12259 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12260 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12261 // <=> FoundLHS + C s< FoundRHS + C
12262 //
12263 // [*]: (1) can be proved by ruling out overflow.
12264 //
12265 // [**]: This can be proved by analyzing all the four possibilities:
12266 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12267 // (A s>= 0, B s>= 0).
12268 //
12269 // Note:
12270 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12271 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12272 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12273 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12274 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12275 // C)".
12276
12277 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12278 if (!LDiff)
12279 return false;
12280 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12281 if (!RDiff || *LDiff != *RDiff)
12282 return false;
12283
12284 if (LDiff->isMinValue())
12285 return true;
12286
12287 APInt FoundRHSLimit;
12288
12289 if (Pred == CmpInst::ICMP_ULT) {
12290 FoundRHSLimit = -(*RDiff);
12291 } else {
12292 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12293 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12294 }
12295
12296 // Try to prove (1) or (2), as needed.
12297 return isAvailableAtLoopEntry(FoundRHS, L) &&
12298 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12299 getConstant(FoundRHSLimit));
12300}
12301
12302bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12303 const SCEV *LHS, const SCEV *RHS,
12304 const SCEV *FoundLHS,
12305 const SCEV *FoundRHS, unsigned Depth) {
12306 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12307
12308 auto ClearOnExit = make_scope_exit([&]() {
12309 if (LPhi) {
12310 bool Erased = PendingMerges.erase(LPhi);
12311 assert(Erased && "Failed to erase LPhi!");
12312 (void)Erased;
12313 }
12314 if (RPhi) {
12315 bool Erased = PendingMerges.erase(RPhi);
12316 assert(Erased && "Failed to erase RPhi!");
12317 (void)Erased;
12318 }
12319 });
12320
12321 // Find respective Phis and check that they are not being pending.
12322 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12323 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12324 if (!PendingMerges.insert(Phi).second)
12325 return false;
12326 LPhi = Phi;
12327 }
12328 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12329 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12330 // If we detect a loop of Phi nodes being processed by this method, for
12331 // example:
12332 //
12333 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12334 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12335 //
12336 // we don't want to deal with a case that complex, so return conservative
12337 // answer false.
12338 if (!PendingMerges.insert(Phi).second)
12339 return false;
12340 RPhi = Phi;
12341 }
12342
12343 // If none of LHS, RHS is a Phi, nothing to do here.
12344 if (!LPhi && !RPhi)
12345 return false;
12346
12347 // If there is a SCEVUnknown Phi we are interested in, make it left.
12348 if (!LPhi) {
12349 std::swap(LHS, RHS);
12350 std::swap(FoundLHS, FoundRHS);
12351 std::swap(LPhi, RPhi);
12352 Pred = ICmpInst::getSwappedPredicate(Pred);
12353 }
12354
12355 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12356 const BasicBlock *LBB = LPhi->getParent();
12357 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12358
12359 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12360 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12361 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12362 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12363 };
12364
12365 if (RPhi && RPhi->getParent() == LBB) {
12366 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12367 // If we compare two Phis from the same block, and for each entry block
12368 // the predicate is true for incoming values from this block, then the
12369 // predicate is also true for the Phis.
12370 for (const BasicBlock *IncBB : predecessors(LBB)) {
12371 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12372 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12373 if (!ProvedEasily(L, R))
12374 return false;
12375 }
12376 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12377 // Case two: RHS is also a Phi from the same basic block, and it is an
12378 // AddRec. It means that there is a loop which has both AddRec and Unknown
12379 // PHIs, for it we can compare incoming values of AddRec from above the loop
12380 // and latch with their respective incoming values of LPhi.
12381 // TODO: Generalize to handle loops with many inputs in a header.
12382 if (LPhi->getNumIncomingValues() != 2) return false;
12383
12384 auto *RLoop = RAR->getLoop();
12385 auto *Predecessor = RLoop->getLoopPredecessor();
12386 assert(Predecessor && "Loop with AddRec with no predecessor?");
12387 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12388 if (!ProvedEasily(L1, RAR->getStart()))
12389 return false;
12390 auto *Latch = RLoop->getLoopLatch();
12391 assert(Latch && "Loop with AddRec with no latch?");
12392 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12393 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12394 return false;
12395 } else {
12396 // In all other cases go over inputs of LHS and compare each of them to RHS,
12397 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12398 // At this point RHS is either a non-Phi, or it is a Phi from some block
12399 // different from LBB.
12400 for (const BasicBlock *IncBB : predecessors(LBB)) {
12401 // Check that RHS is available in this block.
12402 if (!dominates(RHS, IncBB))
12403 return false;
12404 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12405 // Make sure L does not refer to a value from a potentially previous
12406 // iteration of a loop.
12407 if (!properlyDominates(L, LBB))
12408 return false;
12409 if (!ProvedEasily(L, RHS))
12410 return false;
12411 }
12412 }
12413 return true;
12414}
12415
12416bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12417 const SCEV *LHS,
12418 const SCEV *RHS,
12419 const SCEV *FoundLHS,
12420 const SCEV *FoundRHS) {
12421 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12422 // sure that we are dealing with same LHS.
12423 if (RHS == FoundRHS) {
12424 std::swap(LHS, RHS);
12425 std::swap(FoundLHS, FoundRHS);
12426 Pred = ICmpInst::getSwappedPredicate(Pred);
12427 }
12428 if (LHS != FoundLHS)
12429 return false;
12430
12431 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12432 if (!SUFoundRHS)
12433 return false;
12434
12435 Value *Shiftee, *ShiftValue;
12436
12437 using namespace PatternMatch;
12438 if (match(SUFoundRHS->getValue(),
12439 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12440 auto *ShifteeS = getSCEV(Shiftee);
12441 // Prove one of the following:
12442 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12443 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12444 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12445 // ---> LHS <s RHS
12446 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12447 // ---> LHS <=s RHS
12448 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12449 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12450 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12451 if (isKnownNonNegative(ShifteeS))
12452 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12453 }
12454
12455 return false;
12456}
12457
12458bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12459 const SCEV *LHS, const SCEV *RHS,
12460 const SCEV *FoundLHS,
12461 const SCEV *FoundRHS,
12462 const Instruction *CtxI) {
12463 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12464 return true;
12465
12466 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12467 return true;
12468
12469 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12470 return true;
12471
12472 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12473 CtxI))
12474 return true;
12475
12476 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12477 FoundLHS, FoundRHS);
12478}
12479
12480/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12481template <typename MinMaxExprType>
12482static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12483 const SCEV *Candidate) {
12484 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12485 if (!MinMaxExpr)
12486 return false;
12487
12488 return is_contained(MinMaxExpr->operands(), Candidate);
12489}
12490
12493 const SCEV *LHS, const SCEV *RHS) {
12494 // If both sides are affine addrecs for the same loop, with equal
12495 // steps, and we know the recurrences don't wrap, then we only
12496 // need to check the predicate on the starting values.
12497
12498 if (!ICmpInst::isRelational(Pred))
12499 return false;
12500
12501 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12502 if (!LAR)
12503 return false;
12504 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12505 if (!RAR)
12506 return false;
12507 if (LAR->getLoop() != RAR->getLoop())
12508 return false;
12509 if (!LAR->isAffine() || !RAR->isAffine())
12510 return false;
12511
12512 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12513 return false;
12514
12517 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12518 return false;
12519
12520 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12521}
12522
12523/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12524/// expression?
12527 const SCEV *LHS, const SCEV *RHS) {
12528 switch (Pred) {
12529 default:
12530 return false;
12531
12532 case ICmpInst::ICMP_SGE:
12533 std::swap(LHS, RHS);
12534 [[fallthrough]];
12535 case ICmpInst::ICMP_SLE:
12536 return
12537 // min(A, ...) <= A
12538 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12539 // A <= max(A, ...)
12540 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12541
12542 case ICmpInst::ICMP_UGE:
12543 std::swap(LHS, RHS);
12544 [[fallthrough]];
12545 case ICmpInst::ICMP_ULE:
12546 return
12547 // min(A, ...) <= A
12548 // FIXME: what about umin_seq?
12549 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12550 // A <= max(A, ...)
12551 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12552 }
12553
12554 llvm_unreachable("covered switch fell through?!");
12555}
12556
12557bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12558 const SCEV *LHS, const SCEV *RHS,
12559 const SCEV *FoundLHS,
12560 const SCEV *FoundRHS,
12561 unsigned Depth) {
12564 "LHS and RHS have different sizes?");
12565 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12566 getTypeSizeInBits(FoundRHS->getType()) &&
12567 "FoundLHS and FoundRHS have different sizes?");
12568 // We want to avoid hurting the compile time with analysis of too big trees.
12570 return false;
12571
12572 // We only want to work with GT comparison so far.
12573 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12574 Pred = CmpInst::getSwappedPredicate(Pred);
12575 std::swap(LHS, RHS);
12576 std::swap(FoundLHS, FoundRHS);
12577 }
12578
12579 // For unsigned, try to reduce it to corresponding signed comparison.
12580 if (Pred == ICmpInst::ICMP_UGT)
12581 // We can replace unsigned predicate with its signed counterpart if all
12582 // involved values are non-negative.
12583 // TODO: We could have better support for unsigned.
12584 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12585 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12586 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12587 // use this fact to prove that LHS and RHS are non-negative.
12588 const SCEV *MinusOne = getMinusOne(LHS->getType());
12589 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12590 FoundRHS) &&
12591 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12592 FoundRHS))
12593 Pred = ICmpInst::ICMP_SGT;
12594 }
12595
12596 if (Pred != ICmpInst::ICMP_SGT)
12597 return false;
12598
12599 auto GetOpFromSExt = [&](const SCEV *S) {
12600 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12601 return Ext->getOperand();
12602 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12603 // the constant in some cases.
12604 return S;
12605 };
12606
12607 // Acquire values from extensions.
12608 auto *OrigLHS = LHS;
12609 auto *OrigFoundLHS = FoundLHS;
12610 LHS = GetOpFromSExt(LHS);
12611 FoundLHS = GetOpFromSExt(FoundLHS);
12612
12613 // Is the SGT predicate can be proved trivially or using the found context.
12614 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12615 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12616 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12617 FoundRHS, Depth + 1);
12618 };
12619
12620 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12621 // We want to avoid creation of any new non-constant SCEV. Since we are
12622 // going to compare the operands to RHS, we should be certain that we don't
12623 // need any size extensions for this. So let's decline all cases when the
12624 // sizes of types of LHS and RHS do not match.
12625 // TODO: Maybe try to get RHS from sext to catch more cases?
12627 return false;
12628
12629 // Should not overflow.
12630 if (!LHSAddExpr->hasNoSignedWrap())
12631 return false;
12632
12633 auto *LL = LHSAddExpr->getOperand(0);
12634 auto *LR = LHSAddExpr->getOperand(1);
12635 auto *MinusOne = getMinusOne(RHS->getType());
12636
12637 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12638 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12639 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12640 };
12641 // Try to prove the following rule:
12642 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12643 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12644 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12645 return true;
12646 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12647 Value *LL, *LR;
12648 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12649
12650 using namespace llvm::PatternMatch;
12651
12652 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12653 // Rules for division.
12654 // We are going to perform some comparisons with Denominator and its
12655 // derivative expressions. In general case, creating a SCEV for it may
12656 // lead to a complex analysis of the entire graph, and in particular it
12657 // can request trip count recalculation for the same loop. This would
12658 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12659 // this, we only want to create SCEVs that are constants in this section.
12660 // So we bail if Denominator is not a constant.
12661 if (!isa<ConstantInt>(LR))
12662 return false;
12663
12664 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12665
12666 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12667 // then a SCEV for the numerator already exists and matches with FoundLHS.
12668 auto *Numerator = getExistingSCEV(LL);
12669 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12670 return false;
12671
12672 // Make sure that the numerator matches with FoundLHS and the denominator
12673 // is positive.
12674 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12675 return false;
12676
12677 auto *DTy = Denominator->getType();
12678 auto *FRHSTy = FoundRHS->getType();
12679 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12680 // One of types is a pointer and another one is not. We cannot extend
12681 // them properly to a wider type, so let us just reject this case.
12682 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12683 // to avoid this check.
12684 return false;
12685
12686 // Given that:
12687 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12688 auto *WTy = getWiderType(DTy, FRHSTy);
12689 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12690 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12691
12692 // Try to prove the following rule:
12693 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12694 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12695 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12696 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12697 if (isKnownNonPositive(RHS) &&
12698 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12699 return true;
12700
12701 // Try to prove the following rule:
12702 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12703 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12704 // If we divide it by Denominator > 2, then:
12705 // 1. If FoundLHS is negative, then the result is 0.
12706 // 2. If FoundLHS is non-negative, then the result is non-negative.
12707 // Anyways, the result is non-negative.
12708 auto *MinusOne = getMinusOne(WTy);
12709 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12710 if (isKnownNegative(RHS) &&
12711 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12712 return true;
12713 }
12714 }
12715
12716 // If our expression contained SCEVUnknown Phis, and we split it down and now
12717 // need to prove something for them, try to prove the predicate for every
12718 // possible incoming values of those Phis.
12719 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12720 return true;
12721
12722 return false;
12723}
12724
12726 const SCEV *LHS, const SCEV *RHS) {
12727 // zext x u<= sext x, sext x s<= zext x
12728 const SCEV *Op;
12729 switch (Pred) {
12730 case ICmpInst::ICMP_SGE:
12731 std::swap(LHS, RHS);
12732 [[fallthrough]];
12733 case ICmpInst::ICMP_SLE: {
12734 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12735 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12737 }
12738 case ICmpInst::ICMP_UGE:
12739 std::swap(LHS, RHS);
12740 [[fallthrough]];
12741 case ICmpInst::ICMP_ULE: {
12742 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12743 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12745 }
12746 default:
12747 return false;
12748 };
12749 llvm_unreachable("unhandled case");
12750}
12751
12752bool
12753ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12754 const SCEV *LHS, const SCEV *RHS) {
12755 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12756 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12757 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12758 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12759 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12760}
12761
12762bool
12763ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12764 const SCEV *LHS, const SCEV *RHS,
12765 const SCEV *FoundLHS,
12766 const SCEV *FoundRHS) {
12767 switch (Pred) {
12768 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12769 case ICmpInst::ICMP_EQ:
12770 case ICmpInst::ICMP_NE:
12771 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12772 return true;
12773 break;
12774 case ICmpInst::ICMP_SLT:
12775 case ICmpInst::ICMP_SLE:
12776 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12777 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12778 return true;
12779 break;
12780 case ICmpInst::ICMP_SGT:
12781 case ICmpInst::ICMP_SGE:
12782 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12783 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12784 return true;
12785 break;
12786 case ICmpInst::ICMP_ULT:
12787 case ICmpInst::ICMP_ULE:
12788 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12789 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12790 return true;
12791 break;
12792 case ICmpInst::ICMP_UGT:
12793 case ICmpInst::ICMP_UGE:
12794 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12795 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12796 return true;
12797 break;
12798 }
12799
12800 // Maybe it can be proved via operations?
12801 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12802 return true;
12803
12804 return false;
12805}
12806
12807bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12808 const SCEV *LHS,
12809 const SCEV *RHS,
12810 ICmpInst::Predicate FoundPred,
12811 const SCEV *FoundLHS,
12812 const SCEV *FoundRHS) {
12813 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12814 // The restriction on `FoundRHS` be lifted easily -- it exists only to
12815 // reduce the compile time impact of this optimization.
12816 return false;
12817
12818 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12819 if (!Addend)
12820 return false;
12821
12822 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12823
12824 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12825 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12826 ConstantRange FoundLHSRange =
12827 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12828
12829 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12830 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12831
12832 // We can also compute the range of values for `LHS` that satisfy the
12833 // consequent, "`LHS` `Pred` `RHS`":
12834 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12835 // The antecedent implies the consequent if every value of `LHS` that
12836 // satisfies the antecedent also satisfies the consequent.
12837 return LHSRange.icmp(Pred, ConstRHS);
12838}
12839
12840bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12841 bool IsSigned) {
12842 assert(isKnownPositive(Stride) && "Positive stride expected!");
12843
12844 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12845 const SCEV *One = getOne(Stride->getType());
12846
12847 if (IsSigned) {
12848 APInt MaxRHS = getSignedRangeMax(RHS);
12850 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12851
12852 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12853 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12854 }
12855
12856 APInt MaxRHS = getUnsignedRangeMax(RHS);
12857 APInt MaxValue = APInt::getMaxValue(BitWidth);
12858 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12859
12860 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12861 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12862}
12863
12864bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12865 bool IsSigned) {
12866
12867 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12868 const SCEV *One = getOne(Stride->getType());
12869
12870 if (IsSigned) {
12871 APInt MinRHS = getSignedRangeMin(RHS);
12873 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12874
12875 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12876 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12877 }
12878
12879 APInt MinRHS = getUnsignedRangeMin(RHS);
12880 APInt MinValue = APInt::getMinValue(BitWidth);
12881 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12882
12883 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12884 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12885}
12886
12888 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12889 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12890 // expression fixes the case of N=0.
12891 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12892 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12893 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12894}
12895
12896const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12897 const SCEV *Stride,
12898 const SCEV *End,
12899 unsigned BitWidth,
12900 bool IsSigned) {
12901 // The logic in this function assumes we can represent a positive stride.
12902 // If we can't, the backedge-taken count must be zero.
12903 if (IsSigned && BitWidth == 1)
12904 return getZero(Stride->getType());
12905
12906 // This code below only been closely audited for negative strides in the
12907 // unsigned comparison case, it may be correct for signed comparison, but
12908 // that needs to be established.
12909 if (IsSigned && isKnownNegative(Stride))
12910 return getCouldNotCompute();
12911
12912 // Calculate the maximum backedge count based on the range of values
12913 // permitted by Start, End, and Stride.
12914 APInt MinStart =
12915 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12916
12917 APInt MinStride =
12918 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12919
12920 // We assume either the stride is positive, or the backedge-taken count
12921 // is zero. So force StrideForMaxBECount to be at least one.
12922 APInt One(BitWidth, 1);
12923 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12924 : APIntOps::umax(One, MinStride);
12925
12926 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12927 : APInt::getMaxValue(BitWidth);
12928 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12929
12930 // Although End can be a MAX expression we estimate MaxEnd considering only
12931 // the case End = RHS of the loop termination condition. This is safe because
12932 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12933 // taken count.
12934 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12935 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12936
12937 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12938 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12939 : APIntOps::umax(MaxEnd, MinStart);
12940
12941 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12942 getConstant(StrideForMaxBECount) /* Step */);
12943}
12944
12946ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12947 const Loop *L, bool IsSigned,
12948 bool ControlsOnlyExit, bool AllowPredicates) {
12950
12951 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12952 bool PredicatedIV = false;
12953 if (!IV) {
12954 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12955 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12956 if (AR && AR->getLoop() == L && AR->isAffine()) {
12957 auto canProveNUW = [&]() {
12958 // We can use the comparison to infer no-wrap flags only if it fully
12959 // controls the loop exit.
12960 if (!ControlsOnlyExit)
12961 return false;
12962
12963 if (!isLoopInvariant(RHS, L))
12964 return false;
12965
12966 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12967 // We need the sequence defined by AR to strictly increase in the
12968 // unsigned integer domain for the logic below to hold.
12969 return false;
12970
12971 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12972 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12973 // If RHS <=u Limit, then there must exist a value V in the sequence
12974 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12975 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12976 // overflow occurs. This limit also implies that a signed comparison
12977 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12978 // the high bits on both sides must be zero.
12979 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12980 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12981 Limit = Limit.zext(OuterBitWidth);
12982 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12983 };
12984 auto Flags = AR->getNoWrapFlags();
12985 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12986 Flags = setFlags(Flags, SCEV::FlagNUW);
12987
12988 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12989 if (AR->hasNoUnsignedWrap()) {
12990 // Emulate what getZeroExtendExpr would have done during construction
12991 // if we'd been able to infer the fact just above at that time.
12992 const SCEV *Step = AR->getStepRecurrence(*this);
12993 Type *Ty = ZExt->getType();
12994 auto *S = getAddRecExpr(
12995 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12996 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12997 IV = dyn_cast<SCEVAddRecExpr>(S);
12998 }
12999 }
13000 }
13001 }
13002
13003
13004 if (!IV && AllowPredicates) {
13005 // Try to make this an AddRec using runtime tests, in the first X
13006 // iterations of this loop, where X is the SCEV expression found by the
13007 // algorithm below.
13008 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13009 PredicatedIV = true;
13010 }
13011
13012 // Avoid weird loops
13013 if (!IV || IV->getLoop() != L || !IV->isAffine())
13014 return getCouldNotCompute();
13015
13016 // A precondition of this method is that the condition being analyzed
13017 // reaches an exiting branch which dominates the latch. Given that, we can
13018 // assume that an increment which violates the nowrap specification and
13019 // produces poison must cause undefined behavior when the resulting poison
13020 // value is branched upon and thus we can conclude that the backedge is
13021 // taken no more often than would be required to produce that poison value.
13022 // Note that a well defined loop can exit on the iteration which violates
13023 // the nowrap specification if there is another exit (either explicit or
13024 // implicit/exceptional) which causes the loop to execute before the
13025 // exiting instruction we're analyzing would trigger UB.
13026 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13027 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13029
13030 const SCEV *Stride = IV->getStepRecurrence(*this);
13031
13032 bool PositiveStride = isKnownPositive(Stride);
13033
13034 // Avoid negative or zero stride values.
13035 if (!PositiveStride) {
13036 // We can compute the correct backedge taken count for loops with unknown
13037 // strides if we can prove that the loop is not an infinite loop with side
13038 // effects. Here's the loop structure we are trying to handle -
13039 //
13040 // i = start
13041 // do {
13042 // A[i] = i;
13043 // i += s;
13044 // } while (i < end);
13045 //
13046 // The backedge taken count for such loops is evaluated as -
13047 // (max(end, start + stride) - start - 1) /u stride
13048 //
13049 // The additional preconditions that we need to check to prove correctness
13050 // of the above formula is as follows -
13051 //
13052 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13053 // NoWrap flag).
13054 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13055 // no side effects within the loop)
13056 // c) loop has a single static exit (with no abnormal exits)
13057 //
13058 // Precondition a) implies that if the stride is negative, this is a single
13059 // trip loop. The backedge taken count formula reduces to zero in this case.
13060 //
13061 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13062 // then a zero stride means the backedge can't be taken without executing
13063 // undefined behavior.
13064 //
13065 // The positive stride case is the same as isKnownPositive(Stride) returning
13066 // true (original behavior of the function).
13067 //
13068 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13070 return getCouldNotCompute();
13071
13072 if (!isKnownNonZero(Stride)) {
13073 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13074 // if it might eventually be greater than start and if so, on which
13075 // iteration. We can't even produce a useful upper bound.
13076 if (!isLoopInvariant(RHS, L))
13077 return getCouldNotCompute();
13078
13079 // We allow a potentially zero stride, but we need to divide by stride
13080 // below. Since the loop can't be infinite and this check must control
13081 // the sole exit, we can infer the exit must be taken on the first
13082 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13083 // we know the numerator in the divides below must be zero, so we can
13084 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13085 // and produce the right result.
13086 // FIXME: Handle the case where Stride is poison?
13087 auto wouldZeroStrideBeUB = [&]() {
13088 // Proof by contradiction. Suppose the stride were zero. If we can
13089 // prove that the backedge *is* taken on the first iteration, then since
13090 // we know this condition controls the sole exit, we must have an
13091 // infinite loop. We can't have a (well defined) infinite loop per
13092 // check just above.
13093 // Note: The (Start - Stride) term is used to get the start' term from
13094 // (start' + stride,+,stride). Remember that we only care about the
13095 // result of this expression when stride == 0 at runtime.
13096 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13097 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13098 };
13099 if (!wouldZeroStrideBeUB()) {
13100 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13101 }
13102 }
13103 } else if (!NoWrap) {
13104 // Avoid proven overflow cases: this will ensure that the backedge taken
13105 // count will not generate any unsigned overflow.
13106 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13107 return getCouldNotCompute();
13108 }
13109
  // On all paths just preceding, we established the following invariant:
13111 // IV can be assumed not to overflow up to and including the exiting
13112 // iteration. We proved this in one of two ways:
13113 // 1) We can show overflow doesn't occur before the exiting iteration
13114 // 1a) canIVOverflowOnLT, and b) step of one
13115 // 2) We can show that if overflow occurs, the loop must execute UB
13116 // before any possible exit.
13117 // Note that we have not yet proved RHS invariant (in general).
13118
13119 const SCEV *Start = IV->getStart();
13120
13121 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13122 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13123 // Use integer-typed versions for actual computation; we can't subtract
13124 // pointers in general.
13125 const SCEV *OrigStart = Start;
13126 const SCEV *OrigRHS = RHS;
13127 if (Start->getType()->isPointerTy()) {
13128 Start = getLosslessPtrToIntExpr(Start);
13129 if (isa<SCEVCouldNotCompute>(Start))
13130 return Start;
13131 }
13132 if (RHS->getType()->isPointerTy()) {
13134 if (isa<SCEVCouldNotCompute>(RHS))
13135 return RHS;
13136 }
13137
13138 const SCEV *End = nullptr, *BECount = nullptr,
13139 *BECountIfBackedgeTaken = nullptr;
13140 if (!isLoopInvariant(RHS, L)) {
13141 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13142 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13143 RHSAddRec->getNoWrapFlags()) {
13144 // The structure of loop we are trying to calculate backedge count of:
13145 //
13146 // left = left_start
13147 // right = right_start
13148 //
13149 // while(left < right){
13150 // ... do something here ...
13151 // left += s1; // stride of left is s1 (s1 > 0)
13152 // right += s2; // stride of right is s2 (s2 < 0)
13153 // }
13154 //
13155
13156 const SCEV *RHSStart = RHSAddRec->getStart();
13157 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13158
13159 // If Stride - RHSStride is positive and does not overflow, we can write
13160 // backedge count as ->
13161 // ceil((End - Start) /u (Stride - RHSStride))
13162 // Where, End = max(RHSStart, Start)
13163
13164 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13165 if (isKnownNegative(RHSStride) &&
13166 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13167 RHSStride)) {
13168
13169 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13170 if (isKnownPositive(Denominator)) {
13171 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13172 : getUMaxExpr(RHSStart, Start);
13173
13174 // We can do this because End >= Start, as End = max(RHSStart, Start)
13175 const SCEV *Delta = getMinusSCEV(End, Start);
13176
13177 BECount = getUDivCeilSCEV(Delta, Denominator);
13178 BECountIfBackedgeTaken =
13179 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13180 }
13181 }
13182 }
13183 if (BECount == nullptr) {
13184 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13185 // given the start, stride and max value for the end bound of the
13186 // loop (RHS), and the fact that IV does not overflow (which is
13187 // checked above).
13188 const SCEV *MaxBECount = computeMaxBECountForLT(
13189 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13190 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13191 MaxBECount, false /*MaxOrZero*/, Predicates);
13192 }
13193 } else {
13194 // We use the expression (max(End,Start)-Start)/Stride to describe the
13195 // backedge count, as if the backedge is taken at least once
13196 // max(End,Start) is End and so the result is as above, and if not
13197 // max(End,Start) is Start so we get a backedge count of zero.
13198 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13199 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13200 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13201 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13202 // Can we prove (max(RHS,Start) > Start - Stride?
13203 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13204 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13205 // In this case, we can use a refined formula for computing backedge
13206 // taken count. The general formula remains:
13207 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13208 // We want to use the alternate formula:
13209 // "((End - 1) - (Start - Stride)) /u Stride"
13210 // Let's do a quick case analysis to show these are equivalent under
13211 // our precondition that max(RHS,Start) > Start - Stride.
13212 // * For RHS <= Start, the backedge-taken count must be zero.
13213 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
      // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
      // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
      // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13217 // reducing this to the stride of 1 case.
13218 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13219 // Stride".
13220 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13221 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13222 // "((RHS - (Start - Stride) - 1) /u Stride".
13223 // Our preconditions trivially imply no overflow in that form.
13224 const SCEV *MinusOne = getMinusOne(Stride->getType());
13225 const SCEV *Numerator =
13226 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13227 BECount = getUDivExpr(Numerator, Stride);
13228 }
13229
13230 if (!BECount) {
13231 auto canProveRHSGreaterThanEqualStart = [&]() {
13232 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13233 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13234 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13235
13236 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13237 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13238 return true;
13239
13240 // (RHS > Start - 1) implies RHS >= Start.
13241 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13242 // "Start - 1" doesn't overflow.
13243 // * For signed comparison, if Start - 1 does overflow, it's equal
13244 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13245 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13246 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13247 //
13248 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13249 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13250 auto *StartMinusOne =
13251 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13252 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13253 };
13254
13255 // If we know that RHS >= Start in the context of loop, then we know
13256 // that max(RHS, Start) = RHS at this point.
13257 if (canProveRHSGreaterThanEqualStart()) {
13258 End = RHS;
13259 } else {
13260 // If RHS < Start, the backedge will be taken zero times. So in
13261 // general, we can write the backedge-taken count as:
13262 //
13263 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13264 //
13265 // We convert it to the following to make it more convenient for SCEV:
13266 //
13267 // ceil(max(RHS, Start) - Start) / Stride
13268 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13269
13270 // See what would happen if we assume the backedge is taken. This is
13271 // used to compute MaxBECount.
13272 BECountIfBackedgeTaken =
13273 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13274 }
13275
13276 // At this point, we know:
13277 //
13278 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13279 // 2. The index variable doesn't overflow.
13280 //
13281 // Therefore, we know N exists such that
13282 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13283 // doesn't overflow.
13284 //
13285 // Using this information, try to prove whether the addition in
13286 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13287 const SCEV *One = getOne(Stride->getType());
13288 bool MayAddOverflow = [&] {
13289 if (isKnownToBeAPowerOfTwo(Stride)) {
13290 // Suppose Stride is a power of two, and Start/End are unsigned
13291 // integers. Let UMAX be the largest representable unsigned
13292 // integer.
13293 //
13294 // By the preconditions of this function, we know
13295 // "(Start + Stride * N) >= End", and this doesn't overflow.
13296 // As a formula:
13297 //
13298 // End <= (Start + Stride * N) <= UMAX
13299 //
13300 // Subtracting Start from all the terms:
13301 //
13302 // End - Start <= Stride * N <= UMAX - Start
13303 //
13304 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13305 //
13306 // End - Start <= Stride * N <= UMAX
13307 //
13308 // Stride * N is a multiple of Stride. Therefore,
13309 //
13310 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13311 //
13312 // Since Stride is a power of two, UMAX + 1 is divisible by
13313 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13314 // write:
13315 //
13316 // End - Start <= Stride * N <= UMAX - Stride - 1
13317 //
13318 // Dropping the middle term:
13319 //
13320 // End - Start <= UMAX - Stride - 1
13321 //
13322 // Adding Stride - 1 to both sides:
13323 //
13324 // (End - Start) + (Stride - 1) <= UMAX
13325 //
13326 // In other words, the addition doesn't have unsigned overflow.
13327 //
13328 // A similar proof works if we treat Start/End as signed values.
13329 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13330 // to use signed max instead of unsigned max. Note that we're
13331 // trying to prove a lack of unsigned overflow in either case.
13332 return false;
13333 }
13334 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13335 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13336 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13337 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13338 // 1 <s End.
13339 //
13340 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13341 // End.
13342 return false;
13343 }
13344 return true;
13345 }();
13346
13347 const SCEV *Delta = getMinusSCEV(End, Start);
13348 if (!MayAddOverflow) {
13349 // floor((D + (S - 1)) / S)
13350 // We prefer this formulation if it's legal because it's fewer
13351 // operations.
13352 BECount =
13353 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13354 } else {
13355 BECount = getUDivCeilSCEV(Delta, Stride);
13356 }
13357 }
13358 }
13359
13360 const SCEV *ConstantMaxBECount;
13361 bool MaxOrZero = false;
13362 if (isa<SCEVConstant>(BECount)) {
13363 ConstantMaxBECount = BECount;
13364 } else if (BECountIfBackedgeTaken &&
13365 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13366 // If we know exactly how many times the backedge will be taken if it's
13367 // taken at least once, then the backedge count will either be that or
13368 // zero.
13369 ConstantMaxBECount = BECountIfBackedgeTaken;
13370 MaxOrZero = true;
13371 } else {
13372 ConstantMaxBECount = computeMaxBECountForLT(
13373 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13374 }
13375
13376 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13377 !isa<SCEVCouldNotCompute>(BECount))
13378 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13379
13380 const SCEV *SymbolicMaxBECount =
13381 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13382 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13383 Predicates);
13384}
13385
13386ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13387 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13388 bool ControlsOnlyExit, bool AllowPredicates) {
13390 // We handle only IV > Invariant
13391 if (!isLoopInvariant(RHS, L))
13392 return getCouldNotCompute();
13393
13394 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13395 if (!IV && AllowPredicates)
13396 // Try to make this an AddRec using runtime tests, in the first X
13397 // iterations of this loop, where X is the SCEV expression found by the
13398 // algorithm below.
13399 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13400
13401 // Avoid weird loops
13402 if (!IV || IV->getLoop() != L || !IV->isAffine())
13403 return getCouldNotCompute();
13404
13405 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13406 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13408
13409 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13410
13411 // Avoid negative or zero stride values
13412 if (!isKnownPositive(Stride))
13413 return getCouldNotCompute();
13414
13415 // Avoid proven overflow cases: this will ensure that the backedge taken count
13416 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13417 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13418 // behaviors like the case of C language.
13419 if (!Stride->isOne() && !NoWrap)
13420 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13421 return getCouldNotCompute();
13422
13423 const SCEV *Start = IV->getStart();
13424 const SCEV *End = RHS;
13425 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13426 // If we know that Start >= RHS in the context of loop, then we know that
13427 // min(RHS, Start) = RHS at this point.
13429 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13430 End = RHS;
13431 else
13432 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13433 }
13434
13435 if (Start->getType()->isPointerTy()) {
13436 Start = getLosslessPtrToIntExpr(Start);
13437 if (isa<SCEVCouldNotCompute>(Start))
13438 return Start;
13439 }
13440 if (End->getType()->isPointerTy()) {
13442 if (isa<SCEVCouldNotCompute>(End))
13443 return End;
13444 }
13445
13446 // Compute ((Start - End) + (Stride - 1)) / Stride.
13447 // FIXME: This can overflow. Holding off on fixing this for now;
13448 // howManyGreaterThans will hopefully be gone soon.
13449 const SCEV *One = getOne(Stride->getType());
13450 const SCEV *BECount = getUDivExpr(
13451 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13452
13453 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13454 : getUnsignedRangeMax(Start);
13455
13456 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13457 : getUnsignedRangeMin(Stride);
13458
13459 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13460 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13461 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13462
13463 // Although End can be a MIN expression we estimate MinEnd considering only
13464 // the case End = RHS. This is safe because in the other case (Start - End)
13465 // is zero, leading to a zero maximum backedge taken count.
13466 APInt MinEnd =
13467 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13468 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13469
13470 const SCEV *ConstantMaxBECount =
13471 isa<SCEVConstant>(BECount)
13472 ? BECount
13473 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13474 getConstant(MinStride));
13475
13476 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13477 ConstantMaxBECount = BECount;
13478 const SCEV *SymbolicMaxBECount =
13479 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13480
13481 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13482 Predicates);
13483}
13484
13486 ScalarEvolution &SE) const {
13487 if (Range.isFullSet()) // Infinite loop.
13488 return SE.getCouldNotCompute();
13489
13490 // If the start is a non-zero constant, shift the range to simplify things.
13491 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13492 if (!SC->getValue()->isZero()) {
13494 Operands[0] = SE.getZero(SC->getType());
13495 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13496 getNoWrapFlags(FlagNW));
13497 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13498 return ShiftedAddRec->getNumIterationsInRange(
13499 Range.subtract(SC->getAPInt()), SE);
13500 // This is strange and shouldn't happen.
13501 return SE.getCouldNotCompute();
13502 }
13503
13504 // The only time we can solve this is when we have all constant indices.
13505 // Otherwise, we cannot determine the overflow conditions.
13506 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13507 return SE.getCouldNotCompute();
13508
13509 // Okay at this point we know that all elements of the chrec are constants and
13510 // that the start element is zero.
13511
13512 // First check to see if the range contains zero. If not, the first
13513 // iteration exits.
13514 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13515 if (!Range.contains(APInt(BitWidth, 0)))
13516 return SE.getZero(getType());
13517
13518 if (isAffine()) {
13519 // If this is an affine expression then we have this situation:
13520 // Solve {0,+,A} in Range === Ax in Range
13521
13522 // We know that zero is in the range. If A is positive then we know that
13523 // the upper value of the range must be the first possible exit value.
13524 // If A is negative then the lower of the range is the last possible loop
13525 // value. Also note that we already checked for a full range.
13526 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13527 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13528
13529 // The exit value should be (End+A)/A.
13530 APInt ExitVal = (End + A).udiv(A);
13531 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13532
13533 // Evaluate at the exit value. If we really did fall out of the valid
13534 // range, then we computed our trip count, otherwise wrap around or other
13535 // things must have happened.
13536 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13537 if (Range.contains(Val->getValue()))
13538 return SE.getCouldNotCompute(); // Something strange happened
13539
13540 // Ensure that the previous value is in the range.
13543 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13544 "Linear scev computation is off in a bad way!");
13545 return SE.getConstant(ExitValue);
13546 }
13547
13548 if (isQuadratic()) {
13549 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13550 return SE.getConstant(*S);
13551 }
13552
13553 return SE.getCouldNotCompute();
13554}
13555
const SCEVAddRecExpr *
// NOTE(review): the declarator line that belongs here (qualified function
// name, a parameter list providing `SE`, and the opening brace) is missing
// from this copy of the file. From the body, this builds the recurrence
// {A+B,+,B+C,+,...,+,N} for this {A,+,B,+,...,+,N}, i.e. "this + step",
// and always returns it as an AddRec.
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
  // may happen if we reach arithmetic depth limit while simplifying. So we
  // construct the returned value explicitly.
  // NOTE(review): the declaration of the `Ops` operand vector used below is
  // missing here in this copy.
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+...,+,N}.
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier). This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrency with zero step?");
  Ops.push_back(Last);
  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
  // NOTE(review): the remaining argument(s) and closing parentheses of this
  // call are missing from this copy.
}
13580
// Return true when S contains at least one undef value.
// Only SCEVUnknown leaves are inspected: a leaf counts when its wrapped IR
// value is an UndefValue; every other node kind is treated as not undef.
// NOTE(review): this function's signature line is missing from this copy; the
// body expects a SCEV pointer named `S`.
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return isa<UndefValue>(SU->getValue());
    return false;
  });
}
13589
// Return true when S contains a SCEVUnknown whose IR value is now a nullptr
// (e.g. the value it wrapped has been erased).
// NOTE(review): this function's signature line is missing from this copy; the
// body expects a SCEV pointer named `S`.
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return SU->getValue() == nullptr;
    return false;
  });
}
13598
/// Return the size of an element read or written by Inst.
/// For a store this is the stored value's type, for a load the loaded type;
/// any other instruction yields nullptr.
// NOTE(review): this function's signature line is missing from this copy; the
// body expects an instruction pointer named `Inst`.
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  // NOTE(review): the line declaring `ETy` (the type the size is expressed
  // in) is missing here in this copy.
  return getSizeOfExpr(ETy, Ty);
}
13612
13613//===----------------------------------------------------------------------===//
13614// SCEVCallbackVH Class Implementation
13615//===----------------------------------------------------------------------===//
13616
13617void ScalarEvolution::SCEVCallbackVH::deleted() {
13618 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13619 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13620 SE->ConstantEvolutionLoopExitValue.erase(PN);
13621 SE->eraseValueFromMap(getValPtr());
13622 // this now dangles!
13623}
13624
13625void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13626 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13627
13628 // Forget all the expressions associated with users of the old value,
13629 // so that future queries will recompute the expressions using the new
13630 // value.
13631 SE->forgetValue(getValPtr());
13632 // this now dangles!
13633}
13634
13635ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13636 : CallbackVH(V), SE(se) {}
13637
13638//===----------------------------------------------------------------------===//
13639// ScalarEvolution Class Implementation
13640//===----------------------------------------------------------------------===//
13641
// NOTE(review): the beginning of this constructor's signature (the declarator
// and the parameters preceding LI, which the initializer list reads as F, TLI,
// AC and DT) is missing from this copy of the file.
                                 LoopInfo &LI)
    : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
      LoopDispositions(64), BlockDispositions(64) {
  // NOTE(review): the 64 passed to ValuesAtScopes/LoopDispositions/
  // BlockDispositions is presumably an initial capacity hint — confirm.
  //
  // To use guards for proving predicates, we need to scan every instruction in
  // relevant basic blocks, and not just terminators.  Doing this is a waste of
  // time if the IR does not actually contain any calls to
  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
  //
  // This pessimizes the case where a pass that preserves ScalarEvolution wants
  // to _add_ guards to the module when there weren't any before, and wants
  // ScalarEvolution to optimize based on those guards.  For now we prefer to be
  // efficient in lieu of being smart in that rather obscure case.

  // HasGuards is true only if the module declares the guard intrinsic and that
  // declaration has at least one use.
  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      F.getParent(), Intrinsic::experimental_guard);
  HasGuards = GuardDecl && !GuardDecl->use_empty();
}
13662
13664 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13665 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13666 ValueExprMap(std::move(Arg.ValueExprMap)),
13667 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13668 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13669 PendingMerges(std::move(Arg.PendingMerges)),
13670 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13671 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13672 PredicatedBackedgeTakenCounts(
13673 std::move(Arg.PredicatedBackedgeTakenCounts)),
13674 BECountUsers(std::move(Arg.BECountUsers)),
13675 ConstantEvolutionLoopExitValue(
13676 std::move(Arg.ConstantEvolutionLoopExitValue)),
13677 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13678 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13679 LoopDispositions(std::move(Arg.LoopDispositions)),
13680 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13681 BlockDispositions(std::move(Arg.BlockDispositions)),
13682 SCEVUsers(std::move(Arg.SCEVUsers)),
13683 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13684 SignedRanges(std::move(Arg.SignedRanges)),
13685 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13686 UniquePreds(std::move(Arg.UniquePreds)),
13687 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13688 LoopUsers(std::move(Arg.LoopUsers)),
13689 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13690 FirstUnknown(Arg.FirstUnknown) {
13691 Arg.FirstUnknown = nullptr;
13692}
13693
13695 // Iterate through all the SCEVUnknown instances and call their
13696 // destructors, so that they release their references to their values.
13697 for (SCEVUnknown *U = FirstUnknown; U;) {
13698 SCEVUnknown *Tmp = U;
13699 U = U->Next;
13700 Tmp->~SCEVUnknown();
13701 }
13702 FirstUnknown = nullptr;
13703
13704 ExprValueMap.clear();
13705 ValueExprMap.clear();
13706 HasRecMap.clear();
13707 BackedgeTakenCounts.clear();
13708 PredicatedBackedgeTakenCounts.clear();
13709
13710 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13711 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13712 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13713 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13714 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13715}
13716
// Report whether getBackedgeTakenCount(L) produced something other than
// SCEVCouldNotCompute.
// NOTE(review): this function's signature line is missing from this copy; the
// body expects a loop named `L`.
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}
13720
13721/// When printing a top-level SCEV for trip counts, it's helpful to include
13722/// a type for constants which are otherwise hard to disambiguate.
13723static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13724 if (isa<SCEVConstant>(S))
13725 OS << *S->getType() << " ";
13726 OS << *S;
13727}
13728
// Print loop-count information for L, after first recursing into every loop
// nested in it. The output covers:
//  * the backedge-taken count;
//  * per-exit counts when there are several exiting blocks, retried with
//    predicates when unpredicated fails;
//  * the constant and symbolic max backedge-taken counts;
//  * any predicated counts that differ from the unpredicated ones, together
//    with their predicates;
//  * the trip multiple.
// NOTE(review): the first line of this function's signature is missing from
// this copy; from the recursive call below it presumably takes (OS, SE, L)
// with SE used as a pointer — confirm.
const Loop *L) {
  // Print all inner loops first
  for (Loop *I : *L)
    PrintLoopInfo(OS, SE, I);

  OS << "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  SmallVector<BasicBlock *, 8> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);
  if (ExitingBlocks.size() != 1)
    OS << "<multiple exits> ";

  auto *BTC = SE->getBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(BTC)) {
    OS << "backedge-taken count is ";
    // NOTE(review): a line is missing here in this copy; presumably it prints
    // BTC — confirm.
  } else
    OS << "Unpredictable backedge-taken count.";
  OS << "\n";

  // With several exiting blocks, also report each exit's own count.
  if (ExitingBlocks.size() > 1)
    for (BasicBlock *ExitingBlock : ExitingBlocks) {
      OS << " exit count for " << ExitingBlock->getName() << ": ";
      const SCEV *EC = SE->getExitCount(L, ExitingBlock);
      // NOTE(review): a line is missing here in this copy; presumably it
      // prints EC — confirm.
      if (isa<SCEVCouldNotCompute>(EC)) {
        // Retry with predicates.
        // NOTE(review): the declaration of `Predicates` is missing here in
        // this copy.
        EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
        if (!isa<SCEVCouldNotCompute>(EC)) {
          OS << "\n predicated exit count for " << ExitingBlock->getName()
             << ": ";
          // NOTE(review): a line is missing here in this copy; presumably it
          // prints EC — confirm.
          OS << "\n Predicates:\n";
          for (const auto *P : Predicates)
            P->print(OS, 4);
        }
      }
      OS << "\n";
    }

  OS << "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
    OS << "constant max backedge-taken count is ";
    PrintSCEVWithTypeHint(OS, ConstantBTC);
    // NOTE(review): a line is missing immediately before the next statement
    // in this copy. As written the "either this or zero" suffix is printed
    // unconditionally; it presumably depended on a max-or-zero query — confirm.
      OS << ", actual taken count either this or zero.";
  } else {
    OS << "Unpredictable constant max backedge-taken count. ";
  }

  OS << "\n"
        "Loop ";
  L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
  OS << ": ";

  auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
  if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
    OS << "symbolic max backedge-taken count is ";
    PrintSCEVWithTypeHint(OS, SymbolicBTC);
    // NOTE(review): a line is missing immediately before the next statement
    // in this copy; as above, the suffix presumably was conditional — confirm.
      OS << ", actual taken count either this or zero.";
  } else {
    OS << "Unpredictable symbolic max backedge-taken count. ";
  }
  OS << "\n";

  // Per-exit symbolic max counts, again retried with predicates on failure.
  if (ExitingBlocks.size() > 1)
    for (BasicBlock *ExitingBlock : ExitingBlocks) {
      OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
      auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
      // NOTE(review): the remaining argument(s) of the getExitCount call above
      // are missing from this copy.
      PrintSCEVWithTypeHint(OS, ExitBTC);
      if (isa<SCEVCouldNotCompute>(ExitBTC)) {
        // Retry with predicates.
        // NOTE(review): the declaration of `Predicates` is missing here in
        // this copy.
        ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
        // NOTE(review): the remaining argument(s) of the getPredicatedExitCount
        // call above are missing from this copy.
        if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
          OS << "\n predicated symbolic max exit count for "
             << ExitingBlock->getName() << ": ";
          PrintSCEVWithTypeHint(OS, ExitBTC);
          OS << "\n Predicates:\n";
          for (const auto *P : Predicates)
            P->print(OS, 4);
        }
      }
      OS << "\n";
    }

  // Predicated counts are only printed when they differ from the unpredicated
  // ones; Preds is cleared between the three queries.
  // NOTE(review): the declaration of `Preds` is missing here in this copy.
  auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
  if (PBT != BTC) {
    assert(!Preds.empty() && "Different predicated BTC, but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PBT)) {
      OS << "Predicated backedge-taken count is ";
      // NOTE(review): a line is missing here in this copy; presumably it
      // prints PBT — confirm.
    } else
      OS << "Unpredictable predicated backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }
  Preds.clear();

  auto *PredConstantMax =
  // NOTE(review): the initializer of PredConstantMax is missing from this copy.
  if (PredConstantMax != ConstantBTC) {
    assert(!Preds.empty() &&
           "different predicated constant max BTC but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
      OS << "Predicated constant max backedge-taken count is ";
      PrintSCEVWithTypeHint(OS, PredConstantMax);
    } else
      OS << "Unpredictable predicated constant max backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }
  Preds.clear();

  auto *PredSymbolicMax =
  // NOTE(review): the initializer of PredSymbolicMax is missing from this copy.
  if (SymbolicBTC != PredSymbolicMax) {
    assert(!Preds.empty() &&
           "Different predicated symbolic max BTC, but no predicates");
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
      OS << "Predicated symbolic max backedge-taken count is ";
      PrintSCEVWithTypeHint(OS, PredSymbolicMax);
    } else
      OS << "Unpredictable predicated symbolic max backedge-taken count.";
    OS << "\n";
    OS << " Predicates:\n";
    for (const auto *P : Preds)
      P->print(OS, 4);
  }

  // NOTE(review): a line is missing here in this copy. The extra closing
  // brace after the trip-multiple output shows these statements sat inside a
  // conditional block whose opening line was lost.
    OS << "Loop ";
    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ": ";
    OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
  }
}
13891
13892namespace llvm {
// Stream printer for a loop disposition: prints "Variant", "Invariant" or
// "Computable". Used when printing loop dispositions in print() and in the
// loop-disposition cache diagnostics of verify().
// NOTE(review): this listing omits the operator<< signature and the `case`
// labels (lines 13893, 13895, 13898, 13901).
13894 switch (LD) {
13896 OS << "Variant";
13897 break;
13899 OS << "Invariant";
13900 break;
13902 OS << "Computable";
13903 break;
13904 }
13905 return OS;
13906}
13907
// Stream printer for a block disposition: prints "DoesNotDominate",
// "Dominates" or "ProperlyDominates". Used by the block-disposition cache
// diagnostics of verify().
// NOTE(review): this listing omits the operator<< signature and the `case`
// labels (lines 13908, 13910, 13913, 13916).
13909 switch (BD) {
13911 OS << "DoesNotDominate";
13912 break;
13914 OS << "Dominates";
13915 break;
13917 OS << "ProperlyDominates";
13918 break;
13919 }
13920 return OS;
13921}
13922} // namespace llvm
13923
// Body of ScalarEvolution::print (NOTE(review): the signature line 13924 is
// not present in this listing).
// When ClassifyExpressions is set, the function handles each SCEVable,
// non-compare instruction in F. It prints the instruction's SCEV with its
// unsigned/signed ranges. It also prints the value at the scope of the loop
// LI reports for the instruction's block, and, inside a loop, the exit value
// and loop dispositions. Afterwards it calls PrintLoopInfo for every loop
// yielded by iterating LI.
13925 // ScalarEvolution's implementation of the print method is to print
13926 // out SCEV values of all instructions that are interesting. Doing
13927 // this potentially causes it to create new SCEV objects though,
13928 // which technically conflicts with the const qualifier. This isn't
13929 // observable from outside the class though, so casting away the
13930 // const isn't dangerous.
13931 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13932
13933 if (ClassifyExpressions) {
13934 OS << "Classifying expressions for: ";
13935 F.printAsOperand(OS, /*PrintType=*/false);
13936 OS << "\n";
13937 for (Instruction &I : instructions(F))
13938 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13939 OS << I << '\n';
13940 OS << " --> ";
13941 const SCEV *SV = SE.getSCEV(&I);
13942 SV->print(OS);
13943 if (!isa<SCEVCouldNotCompute>(SV)) {
13944 OS << " U: ";
13945 SE.getUnsignedRange(SV).print(OS);
13946 OS << " S: ";
13947 SE.getSignedRange(SV).print(OS);
13948 }
13949
13950 const Loop *L = LI.getLoopFor(I.getParent());
13951
// Print SV as evaluated at the scope of L, but only when that differs
// from SV itself.
13952 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13953 if (AtUse != SV) {
13954 OS << " --> ";
13955 AtUse->print(OS);
13956 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13957 OS << " U: ";
13958 SE.getUnsignedRange(AtUse).print(OS);
13959 OS << " S: ";
13960 SE.getSignedRange(AtUse).print(OS);
13961 }
13962 }
13963
13964 if (L) {
// Exit value: SV evaluated in L's parent scope. It is printed only
// when it is invariant in L, otherwise as <<Unknown>>.
13965 OS << "\t\t" "Exits: ";
13966 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13967 if (!SE.isLoopInvariant(ExitValue, L)) {
13968 OS << "<<Unknown>>";
13969 } else {
13970 OS << *ExitValue;
13971 }
13972
// Dispositions w.r.t. L and each of its enclosing loops...
13973 bool First = true;
13974 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13975 if (First) {
13976 OS << "\t\t" "LoopDispositions: { ";
13977 First = false;
13978 } else {
13979 OS << ", ";
13980 }
13981
13982 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13983 OS << ": " << SE.getLoopDisposition(SV, Iter);
13984 }
13985
// ...followed by every loop nested inside L (depth-first, skipping L).
13986 for (const auto *InnerL : depth_first(L)) {
13987 if (InnerL == L)
13988 continue;
13989 if (First) {
13990 OS << "\t\t" "LoopDispositions: { ";
13991 First = false;
13992 } else {
13993 OS << ", ";
13994 }
13995
13996 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13997 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13998 }
13999
14000 OS << " }";
14001 }
14002
14003 OS << "\n";
14004 }
14005 }
14006
// Loop execution-count section: PrintLoopInfo runs for each loop yielded by
// iterating LI.
14007 OS << "Determining loop execution counts for: ";
14008 F.printAsOperand(OS, /*PrintType=*/false);
14009 OS << "\n";
14010 for (Loop *I : LI)
14011 PrintLoopInfo(OS, &SE, I);
14012}
14013
// Memoized query for the disposition of S with respect to loop L; results are
// cached per (S, L) in LoopDispositions.
// NOTE(review): the signature lines (14014-14015) are not in this listing.
14016 auto &Values = LoopDispositions[S];
14017 for (auto &V : Values) {
14018 if (V.getPointer() == L)
14019 return V.getInt();
14020 }
// Record a provisional LoopVariant entry before computing. Presumably this
// makes a re-entrant query for the same (S, L) pair terminate with the
// conservative answer instead of recursing.
14021 Values.emplace_back(L, LoopVariant);
14022 LoopDisposition D = computeLoopDisposition(S, L);
// computeLoopDisposition() queries operand dispositions, which can insert
// into LoopDispositions and may invalidate the `Values` reference. Look the
// entry up again before overwriting the provisional value.
14023 auto &Values2 = LoopDispositions[S];
14024 for (auto &V : llvm::reverse(Values2)) {
14025 if (V.getPointer() == L) {
14026 V.setInt(D);
14027 break;
14028 }
14029 }
14030 return D;
14031}
14032
// Computes, without caching, the disposition of S with respect to loop L. A
// null L stands for the function body.
// NOTE(review): the return-type line (14033) is not in this listing.
14034ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14035 switch (S->getSCEVType()) {
14036 case scConstant:
14037 case scVScale:
14038 return LoopInvariant;
14039 case scAddRecExpr: {
14040 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14041
14042 // If L is the addrec's loop, it's computable.
14043 if (AR->getLoop() == L)
14044 return LoopComputable;
14045
14046 // Add recurrences are never invariant in the function-body (null loop).
14047 if (!L)
14048 return LoopVariant;
14049
14050 // Everything that is not defined at loop entry is variant.
14051 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14052 return LoopVariant;
14053 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14054 " dominate the contained loop's header?");
14055
14056 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14057 if (AR->getLoop()->contains(L))
14058 return LoopInvariant;
14059
14060 // This recurrence is variant w.r.t. L if any of its operands
14061 // are variant.
14062 for (const auto *Op : AR->operands())
14063 if (!isLoopInvariant(Op, L))
14064 return LoopVariant;
14065
14066 // Otherwise it's loop-invariant.
14067 return LoopInvariant;
14068 }
14069 case scTruncate:
14070 case scZeroExtend:
14071 case scSignExtend:
14072 case scPtrToInt:
14073 case scAddExpr:
14074 case scMulExpr:
14075 case scUDivExpr:
14076 case scUMaxExpr:
14077 case scSMaxExpr:
14078 case scUMinExpr:
14079 case scSMinExpr:
14080 case scSequentialUMinExpr: {
// The combined result is LoopVariant if any operand is variant,
// LoopComputable if any operand is computable, and LoopInvariant otherwise.
14081 bool HasVarying = false;
14082 for (const auto *Op : S->operands()) {
// NOTE(review): line 14083 is missing from this listing. It presumably
// defines D as Op's disposition w.r.t. L.
14084 if (D == LoopVariant)
14085 return LoopVariant;
14086 if (D == LoopComputable)
14087 HasVarying = true;
14088 }
14089 return HasVarying ? LoopComputable : LoopInvariant;
14090 }
14091 case scUnknown:
14092 // All non-instruction values are loop invariant. All instructions are loop
14093 // invariant if they are not contained in the specified loop.
14094 // Instructions are never considered invariant in the function body
14095 // (null loop) because they are defined within the "loop".
14096 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14097 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14098 return LoopInvariant;
14099 case scCouldNotCompute:
14100 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14101 }
14102 llvm_unreachable("Unknown SCEV kind!");
14103}
14104
// True iff the (memoized) disposition of S w.r.t. L is LoopInvariant.
// NOTE(review): the signature line (14105) is not in this listing.
14106 return getLoopDisposition(S, L) == LoopInvariant;
14107}
14108
// True iff the (memoized) disposition of S w.r.t. L is LoopComputable.
// NOTE(review): the signature line (14109) is not in this listing.
14110 return getLoopDisposition(S, L) == LoopComputable;
14111}
14112
// Memoized query for the disposition of S with respect to block BB; results
// are cached per (S, BB) in BlockDispositions.
// NOTE(review): the signature lines (14113-14114) are not in this listing.
14115 auto &Values = BlockDispositions[S];
14116 for (auto &V : Values) {
14117 if (V.getPointer() == BB)
14118 return V.getInt();
14119 }
// Record a provisional, conservative DoesNotDominateBlock entry before
// computing. Presumably this makes a re-entrant query for the same (S, BB)
// pair terminate.
14120 Values.emplace_back(BB, DoesNotDominateBlock);
14121 BlockDisposition D = computeBlockDisposition(S, BB);
// The computation can insert into BlockDispositions and may invalidate the
// `Values` reference, so look the entry up again before overwriting it.
14122 auto &Values2 = BlockDispositions[S];
14123 for (auto &V : llvm::reverse(Values2)) {
14124 if (V.getPointer() == BB) {
14125 V.setInt(D);
14126 break;
14127 }
14128 }
14129 return D;
14130}
14131
// Computes, without caching, how the value of S relates to block BB:
// DoesNotDominateBlock, DominatesBlock (available within BB) or
// ProperlyDominatesBlock.
// NOTE(review): the return-type line (14132) is not in this listing.
14133ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14134 switch (S->getSCEVType()) {
14135 case scConstant:
14136 case scVScale:
// NOTE(review): the return statement for scConstant/scVScale (line 14137) is
// missing from this listing. It is presumably ProperlyDominatesBlock.
14138 case scAddRecExpr: {
14139 // This uses a "dominates" query instead of "properly dominates" query
14140 // to test for proper dominance too, because the instruction which
14141 // produces the addrec's value is a PHI, and a PHI effectively properly
14142 // dominates its entire containing block.
14143 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14144 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14145 return DoesNotDominateBlock;
14146
14147 // Fall through into SCEVNAryExpr handling.
14148 [[fallthrough]];
14149 }
14150 case scTruncate:
14151 case scZeroExtend:
14152 case scSignExtend:
14153 case scPtrToInt:
14154 case scAddExpr:
14155 case scMulExpr:
14156 case scUDivExpr:
14157 case scUMaxExpr:
14158 case scSMaxExpr:
14159 case scUMinExpr:
14160 case scSMinExpr:
14161 case scSequentialUMinExpr: {
// The weakest operand disposition wins: any DoesNotDominate operand makes the
// whole expression DoesNotDominate. Any merely Dominates operand drops the
// result from ProperlyDominates to Dominates.
14162 bool Proper = true;
14163 for (const SCEV *NAryOp : S->operands()) {
// NOTE(review): line 14164 is missing from this listing. It presumably
// defines D as NAryOp's disposition w.r.t. BB.
14165 if (D == DoesNotDominateBlock)
14166 return DoesNotDominateBlock;
14167 if (D == DominatesBlock)
14168 Proper = false;
14169 }
14170 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14171 }
14172 case scUnknown:
14173 if (Instruction *I =
14174 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14175 if (I->getParent() == BB)
14176 return DominatesBlock;
// NOTE(review): the statement controlled by the following `if` (line 14178)
// and the fallthrough return for non-instruction values (line 14181) are
// missing from this listing. Both are presumably `return
// ProperlyDominatesBlock;`.
14177 if (DT.properlyDominates(I->getParent(), BB))
14179 return DoesNotDominateBlock;
14180 }
14182 case scCouldNotCompute:
14183 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14184 }
14185 llvm_unreachable("Unknown SCEV kind!");
14186}
14187
14188bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14189 return getBlockDisposition(S, BB) >= DominatesBlock;
14190}
14191
14194}
14195
14196bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14197 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14198}
14199
14200void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14201 bool Predicated) {
14202 auto &BECounts =
14203 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14204 auto It = BECounts.find(L);
14205 if (It != BECounts.end()) {
14206 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14207 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14208 if (!isa<SCEVConstant>(S)) {
14209 auto UserIt = BECountUsers.find(S);
14210 assert(UserIt != BECountUsers.end());
14211 UserIt->second.erase({L, Predicated});
14212 }
14213 }
14214 }
14215 BECounts.erase(It);
14216 }
14217}
14218
14219void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14220 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14221 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14222
14223 while (!Worklist.empty()) {
14224 const SCEV *Curr = Worklist.pop_back_val();
14225 auto Users = SCEVUsers.find(Curr);
14226 if (Users != SCEVUsers.end())
14227 for (const auto *User : Users->second)
14228 if (ToForget.insert(User).second)
14229 Worklist.push_back(User);
14230 }
14231
14232 for (const auto *S : ToForget)
14233 forgetMemoizedResultsImpl(S);
14234
14235 for (auto I = PredicatedSCEVRewrites.begin();
14236 I != PredicatedSCEVRewrites.end();) {
14237 std::pair<const SCEV *, const Loop *> Entry = I->first;
14238 if (ToForget.count(Entry.first))
14239 PredicatedSCEVRewrites.erase(I++);
14240 else
14241 ++I;
14242 }
14243}
14244
14245void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14246 LoopDispositions.erase(S);
14247 BlockDispositions.erase(S);
14248 UnsignedRanges.erase(S);
14249 SignedRanges.erase(S);
14250 HasRecMap.erase(S);
14251 ConstantMultipleCache.erase(S);
14252
14253 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14254 UnsignedWrapViaInductionTried.erase(AR);
14255 SignedWrapViaInductionTried.erase(AR);
14256 }
14257
14258 auto ExprIt = ExprValueMap.find(S);
14259 if (ExprIt != ExprValueMap.end()) {
14260 for (Value *V : ExprIt->second) {
14261 auto ValueIt = ValueExprMap.find_as(V);
14262 if (ValueIt != ValueExprMap.end())
14263 ValueExprMap.erase(ValueIt);
14264 }
14265 ExprValueMap.erase(ExprIt);
14266 }
14267
14268 auto ScopeIt = ValuesAtScopes.find(S);
14269 if (ScopeIt != ValuesAtScopes.end()) {
14270 for (const auto &Pair : ScopeIt->second)
14271 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14272 llvm::erase(ValuesAtScopesUsers[Pair.second],
14273 std::make_pair(Pair.first, S));
14274 ValuesAtScopes.erase(ScopeIt);
14275 }
14276
14277 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14278 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14279 for (const auto &Pair : ScopeUserIt->second)
14280 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14281 ValuesAtScopesUsers.erase(ScopeUserIt);
14282 }
14283
14284 auto BEUsersIt = BECountUsers.find(S);
14285 if (BEUsersIt != BECountUsers.end()) {
14286 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14287 auto Copy = BEUsersIt->second;
14288 for (const auto &Pair : Copy)
14289 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14290 BECountUsers.erase(BEUsersIt);
14291 }
14292
14293 auto FoldUser = FoldCacheUser.find(S);
14294 if (FoldUser != FoldCacheUser.end())
14295 for (auto &KV : FoldUser->second)
14296 FoldCache.erase(KV);
14297 FoldCacheUser.erase(S);
14298}
14299
// Adds to LoopsUsed the loop of every SCEVAddRecExpr visited by the
// FindUsedLoops traversal. follow() always returns true and isDone() always
// returns false, so the traversal never stops early.
// NOTE(review): line 14306 (presumably the `LoopsUsed` reference member that
// the constructor initializes) and line 14317 (presumably the traversal of S
// driven by F) are missing from this listing.
14300void
14301ScalarEvolution::getUsedLoops(const SCEV *S,
14302 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14303 struct FindUsedLoops {
14304 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14305 : LoopsUsed(LoopsUsed) {}
14307 bool follow(const SCEV *S) {
14308 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14309 LoopsUsed.insert(AR->getLoop());
14310 return true;
14311 }
14312
14313 bool isDone() const { return false; }
14314 };
14315
14316 FindUsedLoops F(LoopsUsed);
14318}
14319
// Fills Reachable with the blocks reachable from F's entry block. A
// conditional branch contributes only one successor when its condition is a
// ConstantInt or an icmp that constant-range reasoning decides.
// NOTE(review): lines 14321-14322 (the remaining parameters and the Worklist
// declaration) are missing from this listing.
14320void ScalarEvolution::getReachableBlocks(
14323 Worklist.push_back(&F.getEntryBlock());
14324 while (!Worklist.empty()) {
14325 BasicBlock *BB = Worklist.pop_back_val();
14326 if (!Reachable.insert(BB).second)
14327 continue;
14328
14329 Value *Cond;
14330 BasicBlock *TrueBB, *FalseBB;
14331 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14332 m_BasicBlock(FalseBB)))) {
// A constant condition only takes the corresponding edge.
14333 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14334 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14335 continue;
14336 }
14337
// Likewise when constant ranges prove the icmp true, or prove its
// inverse true.
14338 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14339 const SCEV *L = getSCEV(Cmp->getOperand(0));
14340 const SCEV *R = getSCEV(Cmp->getOperand(1));
14341 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14342 Worklist.push_back(TrueBB);
14343 continue;
14344 }
14345 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14346 R)) {
14347 Worklist.push_back(FalseBB);
14348 continue;
14349 }
14350 }
14351 }
14352
// Otherwise every successor is treated as reachable.
14353 append_range(Worklist, successors(BB));
14354 }
14355}
14356
// Body of (presumably) ScalarEvolution::verify. NOTE(review): the signature
// line 14357 is not present in this listing.
// The function builds an independent ScalarEvolution (SE2) over the same
// function and analyses. It then cross-checks:
//   - cached backedge-taken counts;
//   - ValueExprMap against ExprValueMap;
//   - SCEVUsers;
//   - ValuesAtScopes against ValuesAtScopesUsers;
//   - BECountUsers;
//   - the loop and block disposition caches;
//   - FoldCache against FoldCacheUser;
//   - ConstantMultipleCache.
// On any inconsistency it prints a diagnostic to dbgs() and calls
// std::abort().
14358 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14359 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14360
// Worklist of loops, seeded with the loops yielded by LI. Each popped loop
// appends its subloops below.
14361 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14362
14363 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14364 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14365 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14366
14367 const SCEV *visitConstant(const SCEVConstant *Constant) {
14368 return SE.getConstant(Constant->getAPInt());
14369 }
14370
14371 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14372 return SE.getUnknown(Expr->getValue());
14373 }
14374
14375 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14376 return SE.getCouldNotCompute();
14377 }
14378 };
14379
14380 SCEVMapper SCM(SE2);
14381 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14382 SE2.getReachableBlocks(ReachableBlocks, F);
14383
14384 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14385 if (containsUndefs(Old) || containsUndefs(New)) {
14386 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14387 // not propagate undef aggressively). This means we can (and do) fail
14388 // verification in cases where a transform makes a value go from "undef"
14389 // to "undef+1" (say). The transform is fine, since in both cases the
14390 // result is "undef", but SCEV thinks the value increased by 1.
14391 return nullptr;
14392 }
14393
14394 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14395 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14396 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14397 return nullptr;
14398
14399 return Delta;
14400 };
14401
14402 while (!LoopStack.empty()) {
14403 auto *L = LoopStack.pop_back_val();
14404 llvm::append_range(LoopStack, *L);
14405
14406 // Only verify BECounts in reachable loops. For an unreachable loop,
14407 // any BECount is legal.
14408 if (!ReachableBlocks.contains(L->getHeader()))
14409 continue;
14410
14411 // Only verify cached BECounts. Computing new BECounts may change the
14412 // results of subsequent SCEV uses.
14413 auto It = BackedgeTakenCounts.find(L);
14414 if (It == BackedgeTakenCounts.end())
14415 continue;
14416
14417 auto *CurBECount =
14418 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14419 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14420
14421 if (CurBECount == SE2.getCouldNotCompute() ||
14422 NewBECount == SE2.getCouldNotCompute()) {
14423 // NB! This situation is legal, but is very suspicious -- whatever pass
14424 // changes the loop to make a trip count go from could not compute to
14425 // computable or vice-versa *should have* invalidated SCEV. However, we
14426 // choose not to assert here (for now) since we don't want false
14427 // positives.
14428 continue;
14429 }
14430
// Zero-extend the narrower count so both sides have the same width before
// subtracting.
14431 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14432 SE.getTypeSizeInBits(NewBECount->getType()))
14433 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14434 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14435 SE.getTypeSizeInBits(NewBECount->getType()))
14436 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14437
14438 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14439 if (Delta && !Delta->isZero()) {
14440 dbgs() << "Trip Count for " << *L << " Changed!\n";
14441 dbgs() << "Old: " << *CurBECount << "\n";
14442 dbgs() << "New: " << *NewBECount << "\n";
14443 dbgs() << "Delta: " << *Delta << "\n";
14444 std::abort();
14445 }
14446 }
14447
14448 // Collect all valid loops currently in LoopInfo.
14449 SmallPtrSet<Loop *, 32> ValidLoops;
14450 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14451 while (!Worklist.empty()) {
14452 Loop *L = Worklist.pop_back_val();
14453 if (ValidLoops.insert(L).second)
14454 Worklist.append(L->begin(), L->end());
14455 }
14456 for (const auto &KV : ValueExprMap) {
14457#ifndef NDEBUG
14458 // Check for SCEV expressions referencing invalid/deleted loops.
14459 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14460 assert(ValidLoops.contains(AR->getLoop()) &&
14461 "AddRec references invalid loop");
14462 }
14463#endif
14464
14465 // Check that the value is also part of the reverse map.
14466 auto It = ExprValueMap.find(KV.second);
14467 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14468 dbgs() << "Value " << *KV.first
14469 << " is in ValueExprMap but not in ExprValueMap\n";
14470 std::abort();
14471 }
14472
// For instructions in reachable blocks, the cached SCEV must agree with a
// freshly computed one, up to the deltas GetDelta ignores.
14473 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14474 if (!ReachableBlocks.contains(I->getParent()))
14475 continue;
14476 const SCEV *OldSCEV = SCM.visit(KV.second);
14477 const SCEV *NewSCEV = SE2.getSCEV(I);
14478 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14479 if (Delta && !Delta->isZero()) {
14480 dbgs() << "SCEV for value " << *I << " changed!\n"
14481 << "Old: " << *OldSCEV << "\n"
14482 << "New: " << *NewSCEV << "\n"
14483 << "Delta: " << *Delta << "\n";
14484 std::abort();
14485 }
14486 }
14487 }
14488
// Reverse direction: every Value in ExprValueMap must map back to the same
// SCEV in ValueExprMap.
14489 for (const auto &KV : ExprValueMap) {
14490 for (Value *V : KV.second) {
14491 auto It = ValueExprMap.find_as(V);
14492 if (It == ValueExprMap.end()) {
14493 dbgs() << "Value " << *V
14494 << " is in ExprValueMap but not in ValueExprMap\n";
14495 std::abort();
14496 }
14497 if (It->second != KV.first) {
14498 dbgs() << "Value " << *V << " mapped to " << *It->second
14499 << " rather than " << *KV.first << "\n";
14500 std::abort();
14501 }
14502 }
14503 }
14504
14505 // Verify integrity of SCEV users.
14506 for (const auto &S : UniqueSCEVs) {
14507 for (const auto *Op : S.operands()) {
14508 // We do not store dependencies of constants.
14509 if (isa<SCEVConstant>(Op))
14510 continue;
14511 auto It = SCEVUsers.find(Op);
14512 if (It != SCEVUsers.end() && It->second.count(&S))
14513 continue;
14514 dbgs() << "Use of operand " << *Op << " by user " << S
14515 << " is not being tracked!\n";
14516 std::abort();
14517 }
14518 }
14519
14520 // Verify integrity of ValuesAtScopes users.
14521 for (const auto &ValueAndVec : ValuesAtScopes) {
14522 const SCEV *Value = ValueAndVec.first;
14523 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14524 const Loop *L = LoopAndValueAtScope.first;
14525 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14526 if (!isa<SCEVConstant>(ValueAtScope)) {
14527 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14528 if (It != ValuesAtScopesUsers.end() &&
14529 is_contained(It->second, std::make_pair(L, Value)))
14530 continue;
14531 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14532 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14533 std::abort();
14534 }
14535 }
14536 }
14537
14538 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14539 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14540 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14541 const Loop *L = LoopAndValue.first;
14542 const SCEV *Value = LoopAndValue.second;
14543 assert(!isa<SCEVConstant>(Value));
14544 auto It = ValuesAtScopes.find(Value);
14545 if (It != ValuesAtScopes.end() &&
14546 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14547 continue;
14548 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14549 << *ValueAtScope << " missing in ValuesAtScopes\n";
14550 std::abort();
14551 }
14552 }
14553
14554 // Verify integrity of BECountUsers.
14555 auto VerifyBECountUsers = [&](bool Predicated) {
14556 auto &BECounts =
14557 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14558 for (const auto &LoopAndBEInfo : BECounts) {
14559 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14560 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14561 if (!isa<SCEVConstant>(S)) {
14562 auto UserIt = BECountUsers.find(S);
14563 if (UserIt != BECountUsers.end() &&
14564 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14565 continue;
14566 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14567 << " missing from BECountUsers\n";
14568 std::abort();
14569 }
14570 }
14571 }
14572 }
14573 };
14574 VerifyBECountUsers(/* Predicated */ false);
14575 VerifyBECountUsers(/* Predicated */ true);
14576
14577 // Verify integrity of loop disposition cache.
14578 for (auto &[S, Values] : LoopDispositions) {
14579 for (auto [Loop, CachedDisposition] : Values) {
14580 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14581 if (CachedDisposition != RecomputedDisposition) {
14582 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14583 << " is incorrect: cached " << CachedDisposition << ", actual "
14584 << RecomputedDisposition << "\n";
14585 std::abort();
14586 }
14587 }
14588 }
14589
14590 // Verify integrity of the block disposition cache.
14591 for (auto &[S, Values] : BlockDispositions) {
14592 for (auto [BB, CachedDisposition] : Values) {
14593 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14594 if (CachedDisposition != RecomputedDisposition) {
14595 dbgs() << "Cached disposition of " << *S << " for block %"
14596 << BB->getName() << " is incorrect: cached " << CachedDisposition
14597 << ", actual " << RecomputedDisposition << "\n";
14598 std::abort();
14599 }
14600 }
14601 }
14602
14603 // Verify FoldCache/FoldCacheUser caches.
14604 for (auto [FoldID, Expr] : FoldCache) {
14605 auto I = FoldCacheUser.find(Expr);
14606 if (I == FoldCacheUser.end()) {
14607 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14608 << "!\n";
14609 std::abort();
14610 }
14611 if (!is_contained(I->second, FoldID)) {
14612 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14613 std::abort();
14614 }
14615 }
14616 for (auto [Expr, IDs] : FoldCacheUser) {
14617 for (auto &FoldID : IDs) {
14618 auto I = FoldCache.find(FoldID);
14619 if (I == FoldCache.end()) {
14620 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14621 << "!\n";
14622 std::abort();
14623 }
14624 if (I->second != Expr) {
14625 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14626 << *I->second << " != " << *Expr << "!\n";
14627 std::abort();
14628 }
14629 }
14630 }
14631
14632 // Verify that ConstantMultipleCache computations are correct. We check that
14633 // cached multiples and recomputed multiples are multiples of each other to
14634 // verify correctness. It is possible that a recomputed multiple is different
14635 // from the cached multiple due to strengthened no wrap flags or changes in
14636 // KnownBits computations.
14637 for (auto [S, Multiple] : ConstantMultipleCache) {
14638 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14639 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14640 Multiple.urem(RecomputedMultiple) != 0 &&
14641 RecomputedMultiple.urem(Multiple) != 0)) {
14642 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14643 << *S << " : Computed " << RecomputedMultiple
14644 << " but cache contains " << Multiple << "!\n";
14645 std::abort();
14646 }
14647 }
14648}
14649
14651 Function &F, const PreservedAnalyses &PA,
14653 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14654 // of its dependencies is invalidated.
// NOTE(review): the signature lines (14650, 14652) and line 14658 are missing
// from this listing. Line 14658 is presumably another dependency's
// Inv.invalidate<> check between AssumptionAnalysis and LoopAnalysis.
14655 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14656 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14657 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14659 Inv.invalidate<LoopAnalysis>(F, PA);
14660}
14661
// Out-of-line definition of the static ScalarEvolutionAnalysis::Key
// (an AnalysisKey).
14662AnalysisKey ScalarEvolutionAnalysis::Key;
14663
// Builds a ScalarEvolution for F from the TargetLibrary, Assumption,
// DominatorTree and Loop analysis results obtained through AM.
// NOTE(review): the signature lines (14664-14665) are missing from this
// listing.
14666 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14667 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14668 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14669 auto &LI = AM.getResult<LoopAnalysis>(F);
14670 return ScalarEvolution(F, TLI, AC, DT, LI);
14671}
14672
14676 return PreservedAnalyses::all();
14677}
14678
// Printer pass body: prints a header naming the function and returns
// PreservedAnalyses::all(). NOTE(review): the signature lines (14679-14680)
// are missing from this listing.
14681 // For compatibility with opt's -analyze feature under legacy pass manager
14682 // which was not ported to NPM. This keeps tests using
14683 // update_analyze_test_checks.py working.
14684 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14685 << F.getName() << "':\n";
// NOTE(review): line 14686 is missing from this listing. It is presumably the
// call that prints the function's ScalarEvolution result to OS.
14687 return PreservedAnalyses::all();
14688}
14689
14691 "Scalar Evolution Analysis", false, true)
14697 "Scalar Evolution Analysis", false, true)
14698
14700
14703}
14704
// Creates a new ScalarEvolution for F from the analyses obtained via
// getAnalysis<>() and stores it in SE. reset() replaces any previous
// instance. Returns false.
// NOTE(review): the signature line (14705) is missing from this listing.
14706 SE.reset(new ScalarEvolution(
14707 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14708 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14709 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14710 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14711 return false;
14712}
14713
14715
14717 SE->print(OS);
14718}
14719
// Runs SE->verify() only when the VerifySCEV flag is set.
// NOTE(review): the signature line (14720) is missing from this listing.
14721 if (!VerifySCEV)
14722 return;
14723
14724 SE->verify();
14725}
14726
// Declares that this pass preserves all analyses.
// NOTE(review): the signature (14727) and lines 14729-14732 (presumably the
// required-analysis declarations) are missing from this listing.
14728 AU.setPreservesAll();
14733}
14734
14736 const SCEV *RHS) {
14738}
14739
// Returns the uniqued SCEVComparePredicate for (Pred, LHS, RHS). On first
// request it allocates the predicate in SCEVAllocator and inserts it into
// UniquePreds. LHS and RHS must have the same type.
// NOTE(review): line 14741 (the function name and first parameter) and line
// 14743 (presumably the declaration of ID) are missing from this listing.
14740const SCEVPredicate *
14742 const SCEV *LHS, const SCEV *RHS) {
14744 assert(LHS->getType() == RHS->getType() &&
14745 "Type mismatch between LHS and RHS");
14746 // Unique this node based on the arguments
14747 ID.AddInteger(SCEVPredicate::P_Compare);
14748 ID.AddInteger(Pred);
14749 ID.AddPointer(LHS);
14750 ID.AddPointer(RHS);
14751 void *IP = nullptr;
14752 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14753 return S;
14754 SCEVComparePredicate *Eq = new (SCEVAllocator)
14755 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14756 UniquePreds.InsertNode(Eq, IP);
14757 return Eq;
14758}
14759
14761 const SCEVAddRecExpr *AR,
// Returns the uniqued SCEVWrapPredicate asserting AddedFlags on AR. On first
// request it allocates the predicate in SCEVAllocator and inserts it into
// UniquePreds.
// NOTE(review): lines 14760 and 14762-14763 (the rest of the signature and,
// presumably, the declaration of ID) are missing from this listing.
14764 // Unique this node based on the arguments
14765 ID.AddInteger(SCEVPredicate::P_Wrap);
14766 ID.AddPointer(AR);
14767 ID.AddInteger(AddedFlags);
14768 void *IP = nullptr;
14769 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14770 return S;
14771 auto *OF = new (SCEVAllocator)
14772 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14773 UniquePreds.InsertNode(OF, IP);
14774 return OF;
14775}
14776
14777namespace {
14778
// Rewrites SCEVs in the context of loop L. The rewriter can use the ICMP_EQ
// compare predicates already contained in Pred. It can also record new
// predicates (wrap predicates, or those returned by
// createAddRecFromPHIWithCasts) into NewPreds, which lets zext/sext of an
// affine AddRec in L and PHI-based unknowns become AddRecs.
14779class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14780public:
14781
14782 /// Rewrites \p S in the context of a loop L and the SCEV predication
14783 /// infrastructure.
14784 ///
14785 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14786 /// equivalences present in \p Pred.
14787 ///
14788 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14789 /// \p NewPreds such that the result will be an AddRecExpr.
// NOTE(review): line 14791 (presumably the NewPreds parameter) is missing from
// this listing.
14790 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14792 const SCEVPredicate *Pred) {
14793 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14794 return Rewriter.visit(S);
14795 }
14796
// Replaces Expr with RHS when Pred is, or a member of a union Pred contains,
// an `Expr == RHS` compare predicate. Otherwise it tries to turn a PHI into
// an AddRec under predicates.
14797 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14798 if (Pred) {
14799 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
// The loop variable below shadows the member `Pred`.
14800 for (const auto *Pred : U->getPredicates())
14801 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14802 if (IPred->getLHS() == Expr &&
14803 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14804 return IPred->getRHS();
14805 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14806 if (IPred->getLHS() == Expr &&
14807 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14808 return IPred->getRHS();
14809 }
14810 }
14811 return convertToAddRecWithPreds(Expr);
14812 }
14813
14814 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14815 const SCEV *Operand = visit(Expr->getOperand());
14816 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14817 if (AR && AR->getLoop() == L && AR->isAffine()) {
14818 // This couldn't be folded because the operand didn't have the nuw
14819 // flag. Add the nusw flag as an assumption that we could make.
// The start is zero-extended but the step is sign-extended. Presumably
// NUSW means the increment by a signed step does not wrap unsigned.
14820 const SCEV *Step = AR->getStepRecurrence(SE);
14821 Type *Ty = Expr->getType();
14822 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14823 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14824 SE.getSignExtendExpr(Step, Ty), L,
14825 AR->getNoWrapFlags());
14826 }
14827 return SE.getZeroExtendExpr(Operand, Expr->getType());
14828 }
14829
14830 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14831 const SCEV *Operand = visit(Expr->getOperand());
14832 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14833 if (AR && AR->getLoop() == L && AR->isAffine()) {
14834 // This couldn't be folded because the operand didn't have the nsw
14835 // flag. Add the nssw flag as an assumption that we could make.
14836 const SCEV *Step = AR->getStepRecurrence(SE);
14837 Type *Ty = Expr->getType();
14838 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14839 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14840 SE.getSignExtendExpr(Step, Ty), L,
14841 AR->getNoWrapFlags());
14842 }
14843 return SE.getSignExtendExpr(Operand, Expr->getType());
14844 }
14845
14846private:
// NOTE(review): line 14849 (presumably the NewPreds parameter) is missing from
// this listing.
14847 explicit SCEVPredicateRewriter(
14848 const Loop *L, ScalarEvolution &SE,
14850 const SCEVPredicate *Pred)
14851 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14852
// Makes predicate P available to the rewrite. Without a NewPreds vector this
// succeeds only if the existing Pred already implies P. Otherwise P is
// recorded in NewPreds and the call always succeeds.
14853 bool addOverflowAssumption(const SCEVPredicate *P) {
14854 if (!NewPreds) {
14855 // Check if we've already made this assumption.
14856 return Pred && Pred->implies(P, SE);
14857 }
14858 NewPreds->push_back(P);
14859 return true;
14860 }
14861
// Convenience overload: assumes the wrap predicate AddedFlags on AR.
// NOTE(review): line 14863 (presumably the AddedFlags parameter) is missing
// from this listing.
14862 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14864 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14865 return addOverflowAssumption(A);
14866 }
14867
14868 // If \p Expr represents a PHINode, we try to see if it can be represented
14869 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14870 // to add this predicate as a runtime overflow check, we return the AddRec.
14871 // If \p Expr does not meet these conditions (is not a PHI node, or we
14872 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14873 // return \p Expr.
14874 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14875 if (!isa<PHINode>(Expr->getValue()))
14876 return Expr;
14877 std::optional<
14878 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14879 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14880 if (!PredicatedRewrite)
14881 return Expr;
14882 for (const auto *P : PredicatedRewrite->second){
14883 // Wrap predicates from outer loops are not supported.
14884 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14885 if (L != WP->getExpr()->getLoop())
14886 return Expr;
14887 }
14888 if (!addOverflowAssumption(P))
14889 return Expr;
14890 }
14891 return PredicatedRewrite->first;
14892 }
14893
// NOTE(review): line 14894 (presumably the NewPreds member) is missing from
// this listing.
14895 const SCEVPredicate *Pred;
14896 const Loop *L;
14897};
14898
14899} // end anonymous namespace
14900
14901const SCEV *
14903 const SCEVPredicate &Preds) {
14904 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14905}
14906
14908 const SCEV *S, const Loop *L,
14911 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14912 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14913
14914 if (!AddRec)
14915 return nullptr;
14916
14917 // Since the transformation was successful, we can now transfer the SCEV
14918 // predicates.
14919 Preds.append(TransformPreds.begin(), TransformPreds.end());
14920
14921 return AddRec;
14922}
14923
14924/// SCEV predicates
14926 SCEVPredicateKind Kind)
14927 : FastID(ID), Kind(Kind) {}
14928
14930 const ICmpInst::Predicate Pred,
14931 const SCEV *LHS, const SCEV *RHS)
14932 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14933 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14934 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14935}
14936
14938 ScalarEvolution &SE) const {
14939 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14940
14941 if (!Op)
14942 return false;
14943
14944 if (Pred != ICmpInst::ICMP_EQ)
14945 return false;
14946
14947 return Op->LHS == LHS && Op->RHS == RHS;
14948}
14949
14950bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14951
14953 if (Pred == ICmpInst::ICMP_EQ)
14954 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14955 else
14956 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14957 << *RHS << "\n";
14958
14959}
14960
14962 const SCEVAddRecExpr *AR,
14963 IncrementWrapFlags Flags)
14964 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14965
14966const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14967
14969 ScalarEvolution &SE) const {
14970 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14971 if (!Op || setFlags(Flags, Op->Flags) != Flags)
14972 return false;
14973
14974 if (Op->AR == AR)
14975 return true;
14976
14977 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14979 return false;
14980
14981 const SCEV *Start = AR->getStart();
14982 const SCEV *OpStart = Op->AR->getStart();
14983 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
14984 return false;
14985
14986 const SCEV *Step = AR->getStepRecurrence(SE);
14987 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14988 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14989 return false;
14990
14991 // If both steps are positive, this implies N, if N's start and step are
14992 // ULE/SLE (for NSUW/NSSW) than this'.
14993 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14994 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14995 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14996
14997 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14998 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
14999 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15000 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15001 : SE.getNoopOrSignExtend(Start, WiderTy);
15003 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15004 SE.isKnownPredicate(Pred, OpStart, Start);
15005}
15006
15008 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15009 IncrementWrapFlags IFlags = Flags;
15010
15011 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15012 IFlags = clearFlags(IFlags, IncrementNSSW);
15013
15014 return IFlags == IncrementAnyWrap;
15015}
15016
15018 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15020 OS << "<nusw>";
15022 OS << "<nssw>";
15023 OS << "\n";
15024}
15025
15028 ScalarEvolution &SE) {
15029 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15030 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15031
15032 // We can safely transfer the NSW flag as NSSW.
15033 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15034 ImpliedFlags = IncrementNSSW;
15035
15036 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15037 // If the increment is positive, the SCEV NUW flag will also imply the
15038 // WrapPredicate NUSW flag.
15039 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15040 if (Step->getValue()->getValue().isNonNegative())
15041 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15042 }
15043
15044 return ImpliedFlags;
15045}
15046
15047/// Union predicates don't get cached so create a dummy set ID for it.
15049 ScalarEvolution &SE)
15050 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15051 for (const auto *P : Preds)
15052 add(P, SE);
15053}
15054
15056 return all_of(Preds,
15057 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15058}
15059
15061 ScalarEvolution &SE) const {
15062 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15063 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15064 return this->implies(I, SE);
15065 });
15066
15067 return any_of(Preds,
15068 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15069}
15070
15072 for (const auto *Pred : Preds)
15073 Pred->print(OS, Depth);
15074}
15075
15076void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15077 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15078 for (const auto *Pred : Set->Preds)
15079 add(Pred, SE);
15080 return;
15081 }
15082
15083 // Only add predicate if it is not already implied by this union predicate.
15084 if (!implies(N, SE))
15085 Preds.push_back(N);
15086}
15087
15089 Loop &L)
15090 : SE(SE), L(L) {
15092 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15093}
15094
15097 for (const auto *Op : Ops)
15098 // We do not expect that forgetting cached data for SCEVConstants will ever
15099 // open any prospects for sharpening or introduce any correctness issues,
15100 // so we don't bother storing their dependencies.
15101 if (!isa<SCEVConstant>(Op))
15102 SCEVUsers[Op].insert(User);
15103}
15104
15106 const SCEV *Expr = SE.getSCEV(V);
15107 RewriteEntry &Entry = RewriteMap[Expr];
15108
15109 // If we already have an entry and the version matches, return it.
15110 if (Entry.second && Generation == Entry.first)
15111 return Entry.second;
15112
15113 // We found an entry but it's stale. Rewrite the stale entry
15114 // according to the current predicate.
15115 if (Entry.second)
15116 Expr = Entry.second;
15117
15118 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15119 Entry = {Generation, NewSCEV};
15120
15121 return NewSCEV;
15122}
15123
15125 if (!BackedgeCount) {
15127 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15128 for (const auto *P : Preds)
15129 addPredicate(*P);
15130 }
15131 return BackedgeCount;
15132}
15133
15135 if (!SymbolicMaxBackedgeCount) {
15137 SymbolicMaxBackedgeCount =
15139 for (const auto *P : Preds)
15140 addPredicate(*P);
15141 }
15142 return SymbolicMaxBackedgeCount;
15143}
15144
15146 if (!SmallConstantMaxTripCount) {
15148 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15149 for (const auto *P : Preds)
15150 addPredicate(*P);
15151 }
15152 return *SmallConstantMaxTripCount;
15153}
15154
15156 if (Preds->implies(&Pred, SE))
15157 return;
15158
15159 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15160 NewPreds.push_back(&Pred);
15161 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15162 updateGeneration();
15163}
15164
15166 return *Preds;
15167}
15168
15169void PredicatedScalarEvolution::updateGeneration() {
15170 // If the generation number wrapped recompute everything.
15171 if (++Generation == 0) {
15172 for (auto &II : RewriteMap) {
15173 const SCEV *Rewritten = II.second.second;
15174 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15175 }
15176 }
15177}
15178
15181 const SCEV *Expr = getSCEV(V);
15182 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15183
15184 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15185
15186 // Clear the statically implied flags.
15187 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15188 addPredicate(*SE.getWrapPredicate(AR, Flags));
15189
15190 auto II = FlagsMap.insert({V, Flags});
15191 if (!II.second)
15192 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15193}
15194
15197 const SCEV *Expr = getSCEV(V);
15198 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15199
15201 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15202
15203 auto II = FlagsMap.find(V);
15204
15205 if (II != FlagsMap.end())
15206 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15207
15209}
15210
15212 const SCEV *Expr = this->getSCEV(V);
15214 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15215
15216 if (!New)
15217 return nullptr;
15218
15219 for (const auto *P : NewPreds)
15220 addPredicate(*P);
15221
15222 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15223 return New;
15224}
15225
15228 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15229 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15230 SE)),
15231 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15232 for (auto I : Init.FlagsMap)
15233 FlagsMap.insert(I);
15234}
15235
15237 // For each block.
15238 for (auto *BB : L.getBlocks())
15239 for (auto &I : *BB) {
15240 if (!SE.isSCEVable(I.getType()))
15241 continue;
15242
15243 auto *Expr = SE.getSCEV(&I);
15244 auto II = RewriteMap.find(Expr);
15245
15246 if (II == RewriteMap.end())
15247 continue;
15248
15249 // Don't print things that are not interesting.
15250 if (II->second.second == Expr)
15251 continue;
15252
15253 OS.indent(Depth) << "[PSE]" << I << ":\n";
15254 OS.indent(Depth + 2) << *Expr << "\n";
15255 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15256 }
15257}
15258
15259// Match the mathematical pattern A - (A / B) * B, where A and B can be
15260// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15261// for URem with constant power-of-2 second operands.
15262// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15263// 4, A / B becomes X / 8).
15264bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15265 const SCEV *&RHS) {
15266 if (Expr->getType()->isPointerTy())
15267 return false;
15268
15269 // Try to match 'zext (trunc A to iB) to iY', which is used
15270 // for URem with constant power-of-2 second operands. Make sure the size of
15271 // the operand A matches the size of the whole expressions.
15272 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15273 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15274 LHS = Trunc->getOperand();
15275 // Bail out if the type of the LHS is larger than the type of the
15276 // expression for now.
15277 if (getTypeSizeInBits(LHS->getType()) >
15278 getTypeSizeInBits(Expr->getType()))
15279 return false;
15280 if (LHS->getType() != Expr->getType())
15281 LHS = getZeroExtendExpr(LHS, Expr->getType());
15283 << getTypeSizeInBits(Trunc->getType()));
15284 return true;
15285 }
15286 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15287 if (Add == nullptr || Add->getNumOperands() != 2)
15288 return false;
15289
15290 const SCEV *A = Add->getOperand(1);
15291 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15292
15293 if (Mul == nullptr)
15294 return false;
15295
15296 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15297 // (SomeExpr + (-(SomeExpr / B) * B)).
15298 if (Expr == getURemExpr(A, B)) {
15299 LHS = A;
15300 RHS = B;
15301 return true;
15302 }
15303 return false;
15304 };
15305
15306 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15307 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15308 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15309 MatchURemWithDivisor(Mul->getOperand(2));
15310
15311 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15312 if (Mul->getNumOperands() == 2)
15313 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15314 MatchURemWithDivisor(Mul->getOperand(0)) ||
15315 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15316 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15317 return false;
15318}
15319
15322 BasicBlock *Header = L->getHeader();
15323 BasicBlock *Pred = L->getLoopPredecessor();
15324 LoopGuards Guards(SE);
15326 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15327 return Guards;
15328}
15329
15330void ScalarEvolution::LoopGuards::collectFromPHI(
15332 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15334 unsigned Depth) {
15335 if (!SE.isSCEVable(Phi.getType()))
15336 return;
15337
15338 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15339 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15340 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15341 if (!VisitedBlocks.insert(InBlock).second)
15342 return {nullptr, scCouldNotCompute};
15343 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15344 if (Inserted)
15345 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15346 Depth + 1);
15347 auto &RewriteMap = G->second.RewriteMap;
15348 if (RewriteMap.empty())
15349 return {nullptr, scCouldNotCompute};
15350 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15351 if (S == RewriteMap.end())
15352 return {nullptr, scCouldNotCompute};
15353 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15354 if (!SM)
15355 return {nullptr, scCouldNotCompute};
15356 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15357 return {C0, SM->getSCEVType()};
15358 return {nullptr, scCouldNotCompute};
15359 };
15360 auto MergeMinMaxConst = [](MinMaxPattern P1,
15361 MinMaxPattern P2) -> MinMaxPattern {
15362 auto [C1, T1] = P1;
15363 auto [C2, T2] = P2;
15364 if (!C1 || !C2 || T1 != T2)
15365 return {nullptr, scCouldNotCompute};
15366 switch (T1) {
15367 case scUMaxExpr:
15368 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15369 case scSMaxExpr:
15370 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15371 case scUMinExpr:
15372 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15373 case scSMinExpr:
15374 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15375 default:
15376 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15377 }
15378 };
15379 auto P = GetMinMaxConst(0);
15380 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15381 if (!P.first)
15382 break;
15383 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15384 }
15385 if (P.first) {
15386 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15387 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15388 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15389 Guards.RewriteMap.insert({LHS, RHS});
15390 }
15391}
15392
15393void ScalarEvolution::LoopGuards::collectFromBlock(
15395 const BasicBlock *Block, const BasicBlock *Pred,
15396 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15397 SmallVector<const SCEV *> ExprsToRewrite;
15398 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15399 const SCEV *RHS,
15401 &RewriteMap) {
15402 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15403 // replacement SCEV which isn't directly implied by the structure of that
15404 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15405 // legal. See the scoping rules for flags in the header to understand why.
15406
15407 // If LHS is a constant, apply information to the other expression.
15408 if (isa<SCEVConstant>(LHS)) {
15409 std::swap(LHS, RHS);
15411 }
15412
15413 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15414 // create this form when combining two checks of the form (X u< C2 + C1) and
15415 // (X >=u C1).
15416 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15417 &ExprsToRewrite]() {
15418 const SCEVConstant *C1;
15419 const SCEVUnknown *LHSUnknown;
15420 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15421 if (!match(LHS,
15422 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15423 !C2)
15424 return false;
15425
15426 auto ExactRegion =
15427 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15428 .sub(C1->getAPInt());
15429
15430 // Bail out, unless we have a non-wrapping, monotonic range.
15431 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15432 return false;
15433 auto I = RewriteMap.find(LHSUnknown);
15434 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15435 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15436 SE.getConstant(ExactRegion.getUnsignedMin()),
15437 SE.getUMinExpr(RewrittenLHS,
15438 SE.getConstant(ExactRegion.getUnsignedMax())));
15439 ExprsToRewrite.push_back(LHSUnknown);
15440 return true;
15441 };
15442 if (MatchRangeCheckIdiom())
15443 return;
15444
15445 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15446 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15447 // the non-constant operand and in \p LHS the constant operand.
15448 auto IsMinMaxSCEVWithNonNegativeConstant =
15449 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15450 const SCEV *&RHS) {
15451 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15452 if (MinMax->getNumOperands() != 2)
15453 return false;
15454 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15455 if (C->getAPInt().isNegative())
15456 return false;
15457 SCTy = MinMax->getSCEVType();
15458 LHS = MinMax->getOperand(0);
15459 RHS = MinMax->getOperand(1);
15460 return true;
15461 }
15462 }
15463 return false;
15464 };
15465
15466 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15467 // constant, and returns their APInt in ExprVal and in DivisorVal.
15468 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15469 APInt &ExprVal, APInt &DivisorVal) {
15470 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15471 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15472 if (!ConstExpr || !ConstDivisor)
15473 return false;
15474 ExprVal = ConstExpr->getAPInt();
15475 DivisorVal = ConstDivisor->getAPInt();
15476 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15477 };
15478
15479 // Return a new SCEV that modifies \p Expr to the closest number divides by
15480 // \p Divisor and greater or equal than Expr.
15481 // For now, only handle constant Expr and Divisor.
15482 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15483 const SCEV *Divisor) {
15484 APInt ExprVal;
15485 APInt DivisorVal;
15486 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15487 return Expr;
15488 APInt Rem = ExprVal.urem(DivisorVal);
15489 if (!Rem.isZero())
15490 // return the SCEV: Expr + Divisor - Expr % Divisor
15491 return SE.getConstant(ExprVal + DivisorVal - Rem);
15492 return Expr;
15493 };
15494
15495 // Return a new SCEV that modifies \p Expr to the closest number divides by
15496 // \p Divisor and less or equal than Expr.
15497 // For now, only handle constant Expr and Divisor.
15498 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15499 const SCEV *Divisor) {
15500 APInt ExprVal;
15501 APInt DivisorVal;
15502 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15503 return Expr;
15504 APInt Rem = ExprVal.urem(DivisorVal);
15505 // return the SCEV: Expr - Expr % Divisor
15506 return SE.getConstant(ExprVal - Rem);
15507 };
15508
15509 // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15510 // recursively. This is done by aligning up/down the constant value to the
15511 // Divisor.
15512 std::function<const SCEV *(const SCEV *, const SCEV *)>
15513 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15514 const SCEV *Divisor) {
15515 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15516 SCEVTypes SCTy;
15517 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15518 MinMaxRHS))
15519 return MinMaxExpr;
15520 auto IsMin =
15521 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15522 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15523 "Expected non-negative operand!");
15524 auto *DivisibleExpr =
15525 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15526 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15528 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15529 return SE.getMinMaxExpr(SCTy, Ops);
15530 };
15531
15532 // If we have LHS == 0, check if LHS is computing a property of some unknown
15533 // SCEV %v which we can rewrite %v to express explicitly.
15534 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15535 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15536 // explicitly express that.
15537 const SCEV *URemLHS = nullptr;
15538 const SCEV *URemRHS = nullptr;
15539 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15540 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15541 auto I = RewriteMap.find(LHSUnknown);
15542 const SCEV *RewrittenLHS =
15543 I != RewriteMap.end() ? I->second : LHSUnknown;
15544 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15545 const auto *Multiple =
15546 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15547 RewriteMap[LHSUnknown] = Multiple;
15548 ExprsToRewrite.push_back(LHSUnknown);
15549 return;
15550 }
15551 }
15552 }
15553
15554 // Do not apply information for constants or if RHS contains an AddRec.
15555 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15556 return;
15557
15558 // If RHS is SCEVUnknown, make sure the information is applied to it.
15559 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15560 std::swap(LHS, RHS);
15562 }
15563
15564 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15565 // and \p FromRewritten are the same (i.e. there has been no rewrite
15566 // registered for \p From), then puts this value in the list of rewritten
15567 // expressions.
15568 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15569 const SCEV *To) {
15570 if (From == FromRewritten)
15571 ExprsToRewrite.push_back(From);
15572 RewriteMap[From] = To;
15573 };
15574
15575 // Checks whether \p S has already been rewritten. In that case returns the
15576 // existing rewrite because we want to chain further rewrites onto the
15577 // already rewritten value. Otherwise returns \p S.
15578 auto GetMaybeRewritten = [&](const SCEV *S) {
15579 auto I = RewriteMap.find(S);
15580 return I != RewriteMap.end() ? I->second : S;
15581 };
15582
15583 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15584 // \p Expr. The check is done recuresively on \p Expr, which is assumed to
15585 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15586 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15587 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15588 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15589 // DividesBy.
15590 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15591 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15592 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15593 if (Mul->getNumOperands() != 2)
15594 return false;
15595 auto *MulLHS = Mul->getOperand(0);
15596 auto *MulRHS = Mul->getOperand(1);
15597 if (isa<SCEVConstant>(MulLHS))
15598 std::swap(MulLHS, MulRHS);
15599 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15600 if (Div->getOperand(1) == MulRHS) {
15601 DividesBy = MulRHS;
15602 return true;
15603 }
15604 }
15605 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15606 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15607 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15608 return false;
15609 };
15610
15611 // Return true if Expr known to divide by \p DividesBy.
15612 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15613 [&](const SCEV *Expr, const SCEV *DividesBy) {
15614 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15615 return true;
15616 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15617 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15618 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15619 return false;
15620 };
15621
15622 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15623 const SCEV *DividesBy = nullptr;
15624 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15625 // Check that the whole expression is divided by DividesBy
15626 DividesBy =
15627 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15628
15629 // Collect rewrites for LHS and its transitive operands based on the
15630 // condition.
15631 // For min/max expressions, also apply the guard to its operands:
15632 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15633 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15634 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15635 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15636
15637 // We cannot express strict predicates in SCEV, so instead we replace them
15638 // with non-strict ones against plus or minus one of RHS depending on the
15639 // predicate.
15640 const SCEV *One = SE.getOne(RHS->getType());
15641 switch (Predicate) {
15642 case CmpInst::ICMP_ULT:
15643 if (RHS->getType()->isPointerTy())
15644 return;
15645 RHS = SE.getUMaxExpr(RHS, One);
15646 [[fallthrough]];
15647 case CmpInst::ICMP_SLT: {
15648 RHS = SE.getMinusSCEV(RHS, One);
15649 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15650 break;
15651 }
15652 case CmpInst::ICMP_UGT:
15653 case CmpInst::ICMP_SGT:
15654 RHS = SE.getAddExpr(RHS, One);
15655 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15656 break;
15657 case CmpInst::ICMP_ULE:
15658 case CmpInst::ICMP_SLE:
15659 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15660 break;
15661 case CmpInst::ICMP_UGE:
15662 case CmpInst::ICMP_SGE:
15663 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15664 break;
15665 default:
15666 break;
15667 }
15668
15669 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15671
15672 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15673 append_range(Worklist, S->operands());
15674 };
15675
15676 while (!Worklist.empty()) {
15677 const SCEV *From = Worklist.pop_back_val();
15678 if (isa<SCEVConstant>(From))
15679 continue;
15680 if (!Visited.insert(From).second)
15681 continue;
15682 const SCEV *FromRewritten = GetMaybeRewritten(From);
15683 const SCEV *To = nullptr;
15684
15685 switch (Predicate) {
15686 case CmpInst::ICMP_ULT:
15687 case CmpInst::ICMP_ULE:
15688 To = SE.getUMinExpr(FromRewritten, RHS);
15689 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15690 EnqueueOperands(UMax);
15691 break;
15692 case CmpInst::ICMP_SLT:
15693 case CmpInst::ICMP_SLE:
15694 To = SE.getSMinExpr(FromRewritten, RHS);
15695 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15696 EnqueueOperands(SMax);
15697 break;
15698 case CmpInst::ICMP_UGT:
15699 case CmpInst::ICMP_UGE:
15700 To = SE.getUMaxExpr(FromRewritten, RHS);
15701 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15702 EnqueueOperands(UMin);
15703 break;
15704 case CmpInst::ICMP_SGT:
15705 case CmpInst::ICMP_SGE:
15706 To = SE.getSMaxExpr(FromRewritten, RHS);
15707 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15708 EnqueueOperands(SMin);
15709 break;
15710 case CmpInst::ICMP_EQ:
15711 if (isa<SCEVConstant>(RHS))
15712 To = RHS;
15713 break;
15714 case CmpInst::ICMP_NE:
15715 if (match(RHS, m_scev_Zero())) {
15716 const SCEV *OneAlignedUp =
15717 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15718 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15719 }
15720 break;
15721 default:
15722 break;
15723 }
15724
15725 if (To)
15726 AddRewrite(From, FromRewritten, To);
15727 }
15728 };
15729
15731 // First, collect information from assumptions dominating the loop.
15732 for (auto &AssumeVH : SE.AC.assumptions()) {
15733 if (!AssumeVH)
15734 continue;
15735 auto *AssumeI = cast<CallInst>(AssumeVH);
15736 if (!SE.DT.dominates(AssumeI, Block))
15737 continue;
15738 Terms.emplace_back(AssumeI->getOperand(0), true);
15739 }
15740
15741 // Second, collect information from llvm.experimental.guards dominating the loop.
15742 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15743 SE.F.getParent(), Intrinsic::experimental_guard);
15744 if (GuardDecl)
15745 for (const auto *GU : GuardDecl->users())
15746 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15747 if (Guard->getFunction() == Block->getParent() &&
15748 SE.DT.dominates(Guard, Block))
15749 Terms.emplace_back(Guard->getArgOperand(0), true);
15750
15751 // Third, collect conditions from dominating branches. Starting at the loop
15752 // predecessor, climb up the predecessor chain, as long as there are
15753 // predecessors that can be found that have unique successors leading to the
15754 // original header.
15755 // TODO: share this logic with isLoopEntryGuardedByCond.
15756 unsigned NumCollectedConditions = 0;
15757 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15758 for (; Pair.first;
15759 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15760 VisitedBlocks.insert(Pair.second);
15761 const BranchInst *LoopEntryPredicate =
15762 dyn_cast<BranchInst>(Pair.first->getTerminator());
15763 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15764 continue;
15765
15766 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15767 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15768 NumCollectedConditions++;
15769
15770 // If we are recursively collecting guards stop after 2
15771 // conditions to limit compile-time impact for now.
15772 if (Depth > 0 && NumCollectedConditions == 2)
15773 break;
15774 }
15775 // Finally, if we stopped climbing the predecessor chain because
15776 // there wasn't a unique one to continue, try to collect conditions
15777 // for PHINodes by recursively following all of their incoming
15778 // blocks and try to merge the found conditions to build a new one
15779 // for the Phi.
15780 if (Pair.second->hasNPredecessorsOrMore(2) &&
15783 for (auto &Phi : Pair.second->phis())
15784 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15785 }
15786
15787 // Now apply the information from the collected conditions to
15788 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15789 // earliest conditions is processed first. This ensures the SCEVs with the
15790 // shortest dependency chains are constructed first.
15791 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15792 SmallVector<Value *, 8> Worklist;
15794 Worklist.push_back(Term);
15795 while (!Worklist.empty()) {
15796 Value *Cond = Worklist.pop_back_val();
15797 if (!Visited.insert(Cond).second)
15798 continue;
15799
15800 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15801 auto Predicate =
15802 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15803 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15804 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15805 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15806 continue;
15807 }
15808
15809 Value *L, *R;
15810 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15811 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15812 Worklist.push_back(L);
15813 Worklist.push_back(R);
15814 }
15815 }
15816 }
15817
15818 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15819 // the replacement expressions are contained in the ranges of the replaced
15820 // expressions.
15821 Guards.PreserveNUW = true;
15822 Guards.PreserveNSW = true;
15823 for (const SCEV *Expr : ExprsToRewrite) {
15824 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15825 Guards.PreserveNUW &=
15826 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15827 Guards.PreserveNSW &=
15828 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15829 }
15830
15831 // Now that all rewrite information is collect, rewrite the collected
15832 // expressions with the information in the map. This applies information to
15833 // sub-expressions.
15834 if (ExprsToRewrite.size() > 1) {
15835 for (const SCEV *Expr : ExprsToRewrite) {
15836 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15837 Guards.RewriteMap.erase(Expr);
15838 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15839 }
15840 }
15841}
15842
15844 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15845 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15846 /// replacement is loop invariant in the loop of the AddRec.
15847 class SCEVLoopGuardRewriter
15848 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15850
15852
15853 public:
15854 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15855 const ScalarEvolution::LoopGuards &Guards)
15856 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15857 if (Guards.PreserveNUW)
15858 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15859 if (Guards.PreserveNSW)
15860 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15861 }
15862
15863 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15864
15865 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15866 auto I = Map.find(Expr);
15867 if (I == Map.end())
15868 return Expr;
15869 return I->second;
15870 }
15871
15872 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15873 auto I = Map.find(Expr);
15874 if (I == Map.end()) {
15875 // If we didn't find the extact ZExt expr in the map, check if there's
15876 // an entry for a smaller ZExt we can use instead.
15877 Type *Ty = Expr->getType();
15878 const SCEV *Op = Expr->getOperand(0);
15879 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15880 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15881 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15882 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15883 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15884 auto I = Map.find(NarrowExt);
15885 if (I != Map.end())
15886 return SE.getZeroExtendExpr(I->second, Ty);
15887 Bitwidth = Bitwidth / 2;
15888 }
15889
15891 Expr);
15892 }
15893 return I->second;
15894 }
15895
15896 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15897 auto I = Map.find(Expr);
15898 if (I == Map.end())
15900 Expr);
15901 return I->second;
15902 }
15903
15904 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15905 auto I = Map.find(Expr);
15906 if (I == Map.end())
15908 return I->second;
15909 }
15910
15911 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15912 auto I = Map.find(Expr);
15913 if (I == Map.end())
15915 return I->second;
15916 }
15917
15918 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15920 bool Changed = false;
15921 for (const auto *Op : Expr->operands()) {
15922 Operands.push_back(
15924 Changed |= Op != Operands.back();
15925 }
15926 // We are only replacing operands with equivalent values, so transfer the
15927 // flags from the original expression.
15928 return !Changed ? Expr
15929 : SE.getAddExpr(Operands,
15931 Expr->getNoWrapFlags(), FlagMask));
15932 }
15933
15934 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15936 bool Changed = false;
15937 for (const auto *Op : Expr->operands()) {
15938 Operands.push_back(
15940 Changed |= Op != Operands.back();
15941 }
15942 // We are only replacing operands with equivalent values, so transfer the
15943 // flags from the original expression.
15944 return !Changed ? Expr
15945 : SE.getMulExpr(Operands,
15947 Expr->getNoWrapFlags(), FlagMask));
15948 }
15949 };
15950
15951 if (RewriteMap.empty())
15952 return Expr;
15953
15954 SCEVLoopGuardRewriter Rewriter(SE, *this);
15955 return Rewriter.visit(Expr);
15956}
15957
15958const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15959 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15960}
15961
15963 const LoopGuards &Guards) {
15964 return Guards.rewrite(Expr);
15965}
@ Poison
static const LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
basic Basic Alias true
block Block Frequency Analysis
BlockVerifier::State From
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition: Compiler.h:622
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(...)
Definition: Debug.h:106
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
uint64_t Size
bool End
Definition: ELF_riscv.cpp:480
Generic implementation of equivalence classes through the use Tarjan's efficient union-find algorithm...
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
static bool isSigned(unsigned int Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition: IVUsers.cpp:48
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition: Lint.cpp:533
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
mir Rename Register Operands
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
static GCMetadataPrinterRegistry::Add< OcamlGCMetadataPrinter > Y("ocaml", "ocaml 3.10-compatible collector")
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
if(PassOpts->AAPipeline)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
raw_pwrite_stream & OS
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recrusive loop guard collection"), cl::init(1))
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:166
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:39
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:261
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1945
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1547
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:986
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:423
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1520
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1392
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:612
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1007
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1492
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:910
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:206
APInt abs() const
Get the absolute value.
Definition: APInt.h:1773
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition: APInt.h:1201
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1182
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:380
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:466
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1640
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1468
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1111
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:209
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:216
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:329
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1166
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:219
unsigned countTrailingZeros() const
Definition: APInt.h:1626
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:356
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:827
APInt multiplicativeInverse() const
Definition: APInt.cpp:1248
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1150
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:959
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:873
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:306
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:341
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1130
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:200
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:432
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:239
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1221
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:49
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:292
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:310
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:231
iterator end() const
Definition: ArrayRef.h:157
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:168
iterator begin() const
Definition: ArrayRef.h:156
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:448
const Instruction & front() const
Definition: BasicBlock.h:471
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:219
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:370
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:946
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:673
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:702
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:703
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:697
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:696
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:700
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:698
@ ICMP_EQ
equal
Definition: InstrTypes.h:694
@ ICMP_NE
not equal
Definition: InstrTypes.h:695
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:701
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:699
bool isSigned() const
Definition: InstrTypes.h:928
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:825
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:940
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:869
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:787
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:763
bool isUnsigned() const
Definition: InstrTypes.h:934
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:924
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2631
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2293
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1267
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2637
static Constant * getNeg(Constant *C, bool HasNSW=false)
Definition: Constants.cpp:2625
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2279
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:208
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:873
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:157
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:148
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:880
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
ConstantRange subtract(const APInt &CI) const
Subtract the specified constant from the endpoints of this constant range.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:42
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:709
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:851
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:754
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:878
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:617
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:194
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:156
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition: DenseMap.h:226
bool erase(const KeyT &Val)
Definition: DenseMap.h:321
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:176
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:152
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:147
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:211
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:327
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
const BasicBlock & getEntryBlock() const
Definition: Function.h:809
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:731
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:406
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:403
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
Definition: DerivedTypes.h:42
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
An instruction for reading from memory.
Definition: Instructions.h:176
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:39
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:61
Metadata node.
Definition: Metadata.h:1069
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:32
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:42
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:77
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:110
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:104
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:686
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1878
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:458
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
size_type size() const
Definition: SmallSet.h:170
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void reserve(size_type N)
Definition: SmallVector.h:663
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:805
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:567
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:596
TypeSize getSizeInBits() const
Definition: DataLayout.h:576
Class to represent struct types.
Definition: DerivedTypes.h:218
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:264
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:252
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:288
Use & Op()
Definition: User.h:192
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5144
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2217
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2222
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2227
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2785
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2232
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:771
@ Entry
Definition: COFF.h:844
@ Exit
Definition: COFF.h:845
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)
This version supports overloaded intrinsics.
Definition: Intrinsics.cpp:746
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
Definition: PPCPredicates.h:26
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:885
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:832
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bind_ty< const SCEVConstant > m_SCEVConstant(const SCEVConstant *&V)
bind_ty< const SCEV > m_SCEV(const SCEV *&V)
Match a SCEV, capturing it if we match.
SCEVBinaryExpr_match< SCEVAddExpr, Op0_t, Op1_t > m_scev_Add(const Op0_t &Op0, const Op1_t &Op1)
bool match(const SCEV *S, const Pattern &P)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:390
void stable_sort(R &&Range)
Definition: STLExtras.h:2037
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:255
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7293
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2115
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2107
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1150
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1140
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:250
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:260
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2014
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:303
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1873
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1945
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1903
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:293
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:100
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:43
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:188
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:137
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:121
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:97
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.