1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
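// For illustration (hypothetical client code, assuming a ScalarEvolution
// instance SE and two same-typed integer Values X and Y), the uniquing and
// canonicalization guarantees mean structurally equal expressions are the
// same object:
//
//   const SCEV *A = SE.getSCEV(X), *B = SE.getSCEV(Y);
//   SE.getAddExpr(A, B) == SE.getAddExpr(B, A);      // canonical form => pointer equality
//   SE.getAddExpr(A, SE.getZero(A->getType())) == A; // identities are folded away
//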
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
85#include "llvm/Config/llvm-config.h"
86#include "llvm/IR/Argument.h"
87#include "llvm/IR/BasicBlock.h"
88#include "llvm/IR/CFG.h"
89#include "llvm/IR/Constant.h"
91#include "llvm/IR/Constants.h"
92#include "llvm/IR/DataLayout.h"
94#include "llvm/IR/Dominators.h"
95#include "llvm/IR/Function.h"
96#include "llvm/IR/GlobalAlias.h"
97#include "llvm/IR/GlobalValue.h"
99#include "llvm/IR/InstrTypes.h"
100#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Instructions.h"
103#include "llvm/IR/Intrinsics.h"
104#include "llvm/IR/LLVMContext.h"
105#include "llvm/IR/Operator.h"
106#include "llvm/IR/PatternMatch.h"
107#include "llvm/IR/Type.h"
108#include "llvm/IR/Use.h"
109#include "llvm/IR/User.h"
110#include "llvm/IR/Value.h"
111#include "llvm/IR/Verifier.h"
113#include "llvm/Pass.h"
114#include "llvm/Support/Casting.h"
117#include "llvm/Support/Debug.h"
122#include <algorithm>
123#include <cassert>
124#include <climits>
125#include <cstdint>
126#include <cstdlib>
127#include <map>
128#include <memory>
129#include <numeric>
130#include <optional>
131#include <tuple>
132#include <utility>
133#include <vector>
134
135using namespace llvm;
136using namespace PatternMatch;
137using namespace SCEVPatternMatch;
138
139#define DEBUG_TYPE "scalar-evolution"
140
141STATISTIC(NumExitCountsComputed,
142 "Number of loop exits with predictable exit counts");
143STATISTIC(NumExitCountsNotComputed,
144 "Number of loop exits without predictable exit counts");
145STATISTIC(NumBruteForceTripCountsComputed,
146 "Number of loops with trip counts computed by force");
147
148#ifdef EXPENSIVE_CHECKS
149bool llvm::VerifySCEV = true;
150#else
151bool llvm::VerifySCEV = false;
152#endif
153
154static cl::opt<unsigned>
155 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
156 cl::desc("Maximum number of iterations SCEV will "
157 "symbolically execute a constant "
158 "derived loop"),
159 cl::init(100));
160
162 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
163 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
165 "verify-scev-strict", cl::Hidden,
166 cl::desc("Enable stricter verification when -verify-scev is passed"));
167
169 "scev-verify-ir", cl::Hidden,
170 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
171 cl::init(false));
172
174 "scev-mulops-inline-threshold", cl::Hidden,
175 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
176 cl::init(32));
177
179 "scev-addops-inline-threshold", cl::Hidden,
180 cl::desc("Threshold for inlining addition operands into a SCEV"),
181 cl::init(500));
182
184 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
185 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
186 cl::init(32));
187
189 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
190 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
191 cl::init(2));
192
194 "scalar-evolution-max-value-compare-depth", cl::Hidden,
195 cl::desc("Maximum depth of recursive value complexity comparisons"),
196 cl::init(2));
197
198static cl::opt<unsigned>
199 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
200 cl::desc("Maximum depth of recursive arithmetics"),
201 cl::init(32));
202
204 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
205 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
206
207static cl::opt<unsigned>
208 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
209 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
210 cl::init(8));
211
212static cl::opt<unsigned>
213 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
214 cl::desc("Max coefficients in AddRec during evolving"),
215 cl::init(8));
216
217static cl::opt<unsigned>
218 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
219 cl::desc("Size of the expression which is considered huge"),
220 cl::init(4096));
221
223 "scev-range-iter-threshold", cl::Hidden,
224 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
225 cl::init(32));
226
228 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
229 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
230
231static cl::opt<bool>
232ClassifyExpressions("scalar-evolution-classify-expressions",
233 cl::Hidden, cl::init(true),
234 cl::desc("When printing analysis, include information on every instruction"));
235
237 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
238 cl::init(false),
239 cl::desc("Use more powerful methods of sharpening expression ranges. May "
240 "be costly in terms of compile time"));
241
243 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
244 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
245 "Phi strongly connected components"),
246 cl::init(8));
247
248static cl::opt<bool>
249 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
250 cl::desc("Handle <= and >= in finite loops"),
251 cl::init(true));
252
254 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
255 cl::desc("Infer nuw/nsw flags using context where suitable"),
256 cl::init(true));
257
258//===----------------------------------------------------------------------===//
259// SCEV class definitions
260//===----------------------------------------------------------------------===//
261
262//===----------------------------------------------------------------------===//
263// Implementation of the SCEV class.
264//
265
266#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
267LLVM_DUMP_METHOD void SCEV::dump() const {
268 print(dbgs());
269 dbgs() << '\n';
270}
271#endif
272
273void SCEV::print(raw_ostream &OS) const {
274 switch (getSCEVType()) {
275 case scConstant:
276 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
277 return;
278 case scVScale:
279 OS << "vscale";
280 return;
281 case scPtrToInt: {
282 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
283 const SCEV *Op = PtrToInt->getOperand();
284 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
285 << *PtrToInt->getType() << ")";
286 return;
287 }
288 case scTruncate: {
289 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
290 const SCEV *Op = Trunc->getOperand();
291 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
292 << *Trunc->getType() << ")";
293 return;
294 }
295 case scZeroExtend: {
296 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
297 const SCEV *Op = ZExt->getOperand();
298 OS << "(zext " << *Op->getType() << " " << *Op << " to "
299 << *ZExt->getType() << ")";
300 return;
301 }
302 case scSignExtend: {
303 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
304 const SCEV *Op = SExt->getOperand();
305 OS << "(sext " << *Op->getType() << " " << *Op << " to "
306 << *SExt->getType() << ")";
307 return;
308 }
309 case scAddRecExpr: {
310 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
311 OS << "{" << *AR->getOperand(0);
312 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
313 OS << ",+," << *AR->getOperand(i);
314 OS << "}<";
315 if (AR->hasNoUnsignedWrap())
316 OS << "nuw><";
317 if (AR->hasNoSignedWrap())
318 OS << "nsw><";
319 if (AR->hasNoSelfWrap() &&
320 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
321 OS << "nw><";
322 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
323 OS << ">";
324 return;
325 }
326 case scAddExpr:
327 case scMulExpr:
328 case scUMaxExpr:
329 case scSMaxExpr:
330 case scUMinExpr:
331 case scSMinExpr:
332 case scSequentialUMinExpr: {
333 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
334 const char *OpStr = nullptr;
335 switch (NAry->getSCEVType()) {
336 case scAddExpr: OpStr = " + "; break;
337 case scMulExpr: OpStr = " * "; break;
338 case scUMaxExpr: OpStr = " umax "; break;
339 case scSMaxExpr: OpStr = " smax "; break;
340 case scUMinExpr:
341 OpStr = " umin ";
342 break;
343 case scSMinExpr:
344 OpStr = " smin ";
345 break;
346 case scSequentialUMinExpr:
347 OpStr = " umin_seq ";
348 break;
349 default:
350 llvm_unreachable("There are no other nary expression types.");
351 }
352 OS << "(";
353 ListSeparator LS(OpStr);
354 for (const SCEV *Op : NAry->operands())
355 OS << LS << *Op;
356 OS << ")";
357 switch (NAry->getSCEVType()) {
358 case scAddExpr:
359 case scMulExpr:
360 if (NAry->hasNoUnsignedWrap())
361 OS << "<nuw>";
362 if (NAry->hasNoSignedWrap())
363 OS << "<nsw>";
364 break;
365 default:
366 // Nothing to print for other nary expressions.
367 break;
368 }
369 return;
370 }
371 case scUDivExpr: {
372 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
373 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
374 return;
375 }
376 case scUnknown:
377 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
378 return;
379 case scCouldNotCompute:
380 OS << "***COULDNOTCOMPUTE***";
381 return;
382 }
383 llvm_unreachable("Unknown SCEV kind!");
384}
385
386Type *SCEV::getType() const {
387 switch (getSCEVType()) {
388 case scConstant:
389 return cast<SCEVConstant>(this)->getType();
390 case scVScale:
391 return cast<SCEVVScale>(this)->getType();
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
406 case scSequentialUMinExpr:
407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
414 case scCouldNotCompute:
415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
420ArrayRef<const SCEV *> SCEV::operands() const {
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToInt:
427 case scTruncate:
428 case scZeroExtend:
429 case scSignExtend:
430 return cast<SCEVCastExpr>(this)->operands();
431 case scAddRecExpr:
432 case scAddExpr:
433 case scMulExpr:
434 case scUMaxExpr:
435 case scSMaxExpr:
436 case scUMinExpr:
437 case scSMinExpr:
438 case scSequentialUMinExpr:
439 return cast<SCEVNAryExpr>(this)->operands();
440 case scUDivExpr:
441 return cast<SCEVUDivExpr>(this)->operands();
442 case scCouldNotCompute:
443 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
444 }
445 llvm_unreachable("Unknown SCEV kind!");
446}
447
448bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
449
450bool SCEV::isOne() const { return match(this, m_scev_One()); }
451
452bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
453
454bool SCEV::isNonConstantNegative() const {
455 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
456 if (!Mul) return false;
457
458 // If there is a constant factor, it will be first.
459 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
460 if (!SC) return false;
461
462 // Return true if the value is negative, this matches things like (-42 * V).
463 return SC->getAPInt().isNegative();
464}
465
465
466SCEVCouldNotCompute::SCEVCouldNotCompute() :
467 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
468
469bool SCEVCouldNotCompute::classof(const SCEV *S) {
470 return S->getSCEVType() == scCouldNotCompute;
471}
472
473const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
474 FoldingSetNodeID ID;
475 ID.AddInteger(scConstant);
476 ID.AddPointer(V);
477 void *IP = nullptr;
478 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
479 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
480 UniqueSCEVs.InsertNode(S, IP);
481 return S;
482}
483
484const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
485 return getConstant(ConstantInt::get(getContext(), Val));
486}
487
488const SCEV *
489ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
490 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
491 return getConstant(ConstantInt::get(ITy, V, isSigned));
492}
493
494const SCEV *ScalarEvolution::getVScale(Type *Ty) {
495 FoldingSetNodeID ID;
496 ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
506const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) {
507 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
508 if (EC.isScalable())
509 Res = getMulExpr(Res, getVScale(Ty));
510 return Res;
511}
512
513SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
514 const SCEV *op, Type *ty)
515 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
516
517SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
518 Type *ITy)
519 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
520 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
521 "Must be a non-bit-width-changing pointer-to-integer cast!");
522}
523
524SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
525 SCEVTypes SCEVTy, const SCEV *op,
526 Type *ty)
527 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
528
529SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
530 Type *ty)
531 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
532 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
533 "Cannot truncate non-integer value!");
534}
535
536SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
537 const SCEV *op, Type *ty)
538 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
539 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
540 "Cannot zero extend non-integer value!");
541}
542
543SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
544 const SCEV *op, Type *ty)
545 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
546 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
547 "Cannot sign extend non-integer value!");
548}
549
550void SCEVUnknown::deleted() {
551 // Clear this SCEVUnknown from various maps.
552 SE->forgetMemoizedResults(this);
553
554 // Remove this SCEVUnknown from the uniquing map.
555 SE->UniqueSCEVs.RemoveNode(this);
556
557 // Release the value.
558 setValPtr(nullptr);
559}
560
561void SCEVUnknown::allUsesReplacedWith(Value *New) {
562 // Clear this SCEVUnknown from various maps.
563 SE->forgetMemoizedResults(this);
564
565 // Remove this SCEVUnknown from the uniquing map.
566 SE->UniqueSCEVs.RemoveNode(this);
567
568 // Replace the value pointer in case someone is still using this SCEVUnknown.
569 setValPtr(New);
570}
571
572//===----------------------------------------------------------------------===//
573// SCEV Utilities
574//===----------------------------------------------------------------------===//
575
576/// Compare the two values \p LV and \p RV in terms of their "complexity" where
577/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
578/// operands in SCEV expressions.
579static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
580 Value *RV, unsigned Depth) {
581 if (Depth > MaxValueCompareDepth)
582 return 0;
583
584 // Order pointer values after integer values. This helps SCEVExpander form
585 // GEPs.
586 bool LIsPointer = LV->getType()->isPointerTy(),
587 RIsPointer = RV->getType()->isPointerTy();
588 if (LIsPointer != RIsPointer)
589 return (int)LIsPointer - (int)RIsPointer;
590
591 // Compare getValueID values.
592 unsigned LID = LV->getValueID(), RID = RV->getValueID();
593 if (LID != RID)
594 return (int)LID - (int)RID;
595
596 // Sort arguments by their position.
597 if (const auto *LA = dyn_cast<Argument>(LV)) {
598 const auto *RA = cast<Argument>(RV);
599 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
600 return (int)LArgNo - (int)RArgNo;
601 }
602
603 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
604 const auto *RGV = cast<GlobalValue>(RV);
605
606 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
607 auto LT = GV->getLinkage();
608 return !(GlobalValue::isPrivateLinkage(LT) ||
609 GlobalValue::isInternalLinkage(LT));
610 };
611
612 // Use the names to distinguish the two values, but only if the
613 // names are semantically important.
614 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
615 return LGV->getName().compare(RGV->getName());
616 }
617
618 // For instructions, compare their loop depth, and their operand count. This
619 // is pretty loose.
620 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
621 const auto *RInst = cast<Instruction>(RV);
622
623 // Compare loop depths.
624 const BasicBlock *LParent = LInst->getParent(),
625 *RParent = RInst->getParent();
626 if (LParent != RParent) {
627 unsigned LDepth = LI->getLoopDepth(LParent),
628 RDepth = LI->getLoopDepth(RParent);
629 if (LDepth != RDepth)
630 return (int)LDepth - (int)RDepth;
631 }
632
633 // Compare the number of operands.
634 unsigned LNumOps = LInst->getNumOperands(),
635 RNumOps = RInst->getNumOperands();
636 if (LNumOps != RNumOps)
637 return (int)LNumOps - (int)RNumOps;
638
639 for (unsigned Idx : seq(LNumOps)) {
640 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
641 RInst->getOperand(Idx), Depth + 1);
642 if (Result != 0)
643 return Result;
644 }
645 }
646
647 return 0;
648}
649
650// Return negative, zero, or positive, if LHS is less than, equal to, or greater
651// than RHS, respectively. A three-way result allows recursive comparisons to be
652// more efficient.
653// If the max analysis depth was reached, return std::nullopt, assuming we do
654// not know if they are equivalent for sure.
655static std::optional<int>
656CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
657 const LoopInfo *const LI, const SCEV *LHS,
658 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
659 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
660 if (LHS == RHS)
661 return 0;
662
663 // Primarily, sort the SCEVs by their getSCEVType().
664 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
665 if (LType != RType)
666 return (int)LType - (int)RType;
667
668 if (EqCacheSCEV.isEquivalent(LHS, RHS))
669 return 0;
670
672 return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 if (X == 0)
685 EqCacheSCEV.unionSets(LHS, RHS);
686 return X;
687 }
688
689 case scConstant: {
690 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
691 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
692
693 // Compare constant values.
694 const APInt &LA = LC->getAPInt();
695 const APInt &RA = RC->getAPInt();
696 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
697 if (LBitWidth != RBitWidth)
698 return (int)LBitWidth - (int)RBitWidth;
699 return LA.ult(RA) ? -1 : 1;
700 }
701
702 case scVScale: {
703 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
704 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
705 return LTy->getBitWidth() - RTy->getBitWidth();
706 }
707
708 case scAddRecExpr: {
709 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
710 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
711
712 // There is always a dominance between two recs that are used by one SCEV,
713 // so we can safely sort recs by loop header dominance. We require such
714 // order in getAddExpr.
715 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
716 if (LLoop != RLoop) {
717 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
718 assert(LHead != RHead && "Two loops share the same header?");
719 if (DT.dominates(LHead, RHead))
720 return 1;
721 assert(DT.dominates(RHead, LHead) &&
722 "No dominance between recurrences used by one SCEV?");
723 return -1;
724 }
725
726 [[fallthrough]];
727 }
728
729 case scTruncate:
730 case scZeroExtend:
731 case scSignExtend:
732 case scPtrToInt:
733 case scAddExpr:
734 case scMulExpr:
735 case scUDivExpr:
736 case scSMaxExpr:
737 case scUMaxExpr:
738 case scSMinExpr:
739 case scUMinExpr:
740 case scSequentialUMinExpr: {
741 ArrayRef<const SCEV *> LOps = LHS->operands();
742 ArrayRef<const SCEV *> ROps = RHS->operands();
743
744 // Lexicographically compare n-ary-like expressions.
745 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
746 if (LNumOps != RNumOps)
747 return (int)LNumOps - (int)RNumOps;
748
749 for (unsigned i = 0; i != LNumOps; ++i) {
750 auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
751 Depth + 1);
752 if (X != 0)
753 return X;
754 }
755 EqCacheSCEV.unionSets(LHS, RHS);
756 return 0;
757 }
758
759 case scCouldNotCompute:
760 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
761 }
762 llvm_unreachable("Unknown SCEV kind!");
763}
764
765/// Given a list of SCEV objects, order them by their complexity, and group
766/// objects of the same complexity together by value. When this routine is
767/// finished, we know that any duplicates in the vector are consecutive and that
768/// complexity is monotonically increasing.
769///
770/// Note that we take special precautions to ensure that we get deterministic
771/// results from this routine. In other words, we don't want the results of
772/// this to depend on where the addresses of various SCEV objects happened to
773/// land in memory.
774static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
775 LoopInfo *LI, DominatorTree &DT) {
776 if (Ops.size() < 2) return; // Noop
777
779
780 // Whether LHS has provably less complexity than RHS.
781 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
782 auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
783 return Complexity && *Complexity < 0;
784 };
785 if (Ops.size() == 2) {
786 // This is the common case, which also happens to be trivially simple.
787 // Special case it.
788 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
789 if (IsLessComplex(RHS, LHS))
790 std::swap(LHS, RHS);
791 return;
792 }
793
794 // Do the rough sort by complexity.
795 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
796 return IsLessComplex(LHS, RHS);
797 });
798
799 // Now that we are sorted by complexity, group elements of the same
800 // complexity. Note that this is, at worst, N^2, but the vector is likely to
801 // be extremely short in practice. Note that we take this approach because we
802 // do not want to depend on the addresses of the objects we are grouping.
803 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
804 const SCEV *S = Ops[i];
805 unsigned Complexity = S->getSCEVType();
806
807 // If there are any objects of the same complexity and same value as this
808 // one, group them.
809 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
810 if (Ops[j] == S) { // Found a duplicate.
811 // Move it to immediately after i'th element.
812 std::swap(Ops[i+1], Ops[j]);
813 ++i; // no need to rescan it.
814 if (i == e-2) return; // Done!
815 }
816 }
817 }
818}
819
820/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
821/// least HugeExprThreshold nodes).
822static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
823 return any_of(Ops, [](const SCEV *S) {
824 return S->getExpressionSize() >= HugeExprThreshold;
825 });
826}
827
828/// Performs a number of common optimizations on the passed \p Ops. If the
829/// whole expression reduces down to a single operand, it will be returned.
830///
831/// The following optimizations are performed:
832/// * Fold constants using the \p Fold function.
833/// * Remove identity constants satisfying \p IsIdentity.
834/// * If a constant satisfies \p IsAbsorber, return it.
835/// * Sort operands by complexity.
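///
/// For example, folding an addition with Ops = {5, x, 7} combines the two
/// constants into 12; since 12 is neither the identity (0) nor an absorber for
/// addition, Ops becomes the complexity-sorted list {12, x} and nullptr is
/// returned (the expression did not reduce to a single operand).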
836template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
837static const SCEV *
838constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
839 SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
840 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
841 const SCEVConstant *Folded = nullptr;
842 for (unsigned Idx = 0; Idx < Ops.size();) {
843 const SCEV *Op = Ops[Idx];
844 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
845 if (!Folded)
846 Folded = C;
847 else
848 Folded = cast<SCEVConstant>(
849 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
850 Ops.erase(Ops.begin() + Idx);
851 continue;
852 }
853 ++Idx;
854 }
855
856 if (Ops.empty()) {
857 assert(Folded && "Must have folded value");
858 return Folded;
859 }
860
861 if (Folded && IsAbsorber(Folded->getAPInt()))
862 return Folded;
863
864 GroupByComplexity(Ops, &LI, DT);
865 if (Folded && !IsIdentity(Folded->getAPInt()))
866 Ops.insert(Ops.begin(), Folded);
867
868 return Ops.size() == 1 ? Ops[0] : nullptr;
869}
870
871//===----------------------------------------------------------------------===//
872// Simple SCEV method implementations
873//===----------------------------------------------------------------------===//
874
875/// Compute BC(It, K). The result has width W. Assume K > 0.
876static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
877 ScalarEvolution &SE,
878 Type *ResultTy) {
879 // Handle the simplest case efficiently.
880 if (K == 1)
881 return SE.getTruncateOrZeroExtend(It, ResultTy);
882
883 // We are using the following formula for BC(It, K):
884 //
885 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
886 //
887 // Suppose, W is the bitwidth of the return value. We must be prepared for
888 // overflow. Hence, we must assure that the result of our computation is
889 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
890 // safe in modular arithmetic.
891 //
892 // However, this code doesn't use exactly that formula; the formula it uses
893 // is something like the following, where T is the number of factors of 2 in
894 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
895 // exponentiation:
896 //
897 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
898 //
899 // This formula is trivially equivalent to the previous formula. However,
900 // this formula can be implemented much more efficiently. The trick is that
901 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
902 // arithmetic. To do exact division in modular arithmetic, all we have
903 // to do is multiply by the inverse. Therefore, this step can be done at
904 // width W.
905 //
906 // The next issue is how to safely do the division by 2^T. The way this
907 // is done is by doing the multiplication step at a width of at least W + T
908 // bits. This way, the bottom W+T bits of the product are accurate. Then,
909 // when we perform the division by 2^T (which is equivalent to a right shift
910 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
911 // truncated out after the division by 2^T.
912 //
913 // In comparison to just directly using the first formula, this technique
914 // is much more efficient; using the first formula requires W * K bits,
915// but this formula uses less than W + K bits. Also, the first formula requires
916 // a division step, whereas this formula only requires multiplies and shifts.
917 //
918 // It doesn't matter whether the subtraction step is done in the calculation
919 // width or the input iteration count's width; if the subtraction overflows,
920 // the result must be zero anyway. We prefer here to do it in the width of
921 // the induction variable because it helps a lot for certain cases; CodeGen
922 // isn't smart enough to ignore the overflow, which leads to much less
923 // efficient code if the width of the subtraction is wider than the native
924 // register width.
925 //
926 // (It's possible to not widen at all by pulling out factors of 2 before
927 // the multiplication; for example, K=2 can be calculated as
928 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
929 // extra arithmetic, so it's not an obvious win, and it gets
930 // much more complicated for K > 3.)
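  //
  // As a worked instance of the scheme above, take K = 4 and W = 32: K! = 24 =
  // 2^3 * 3, so T = 3 and the odd part K!/2^T = 3. The product
  // It*(It-1)*(It-2)*(It-3) is formed at width W + T = 35 bits, divided by
  // 2^T = 8 (lossless, because the low W + T bits of the product are exact),
  // truncated back to 32 bits, and finally multiplied by the multiplicative
  // inverse of 3 modulo 2^32 to complete the exact division by K!.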
931
932 // Protection from insane SCEVs; this bound is conservative,
933 // but it probably doesn't matter.
934 if (K > 1000)
935 return SE.getCouldNotCompute();
936
937 unsigned W = SE.getTypeSizeInBits(ResultTy);
938
939 // Calculate K! / 2^T and T; we divide out the factors of two before
940 // multiplying for calculating K! / 2^T to avoid overflow.
941 // Other overflow doesn't matter because we only care about the bottom
942 // W bits of the result.
943 APInt OddFactorial(W, 1);
944 unsigned T = 1;
945 for (unsigned i = 3; i <= K; ++i) {
946 unsigned TwoFactors = countr_zero(i);
947 T += TwoFactors;
948 OddFactorial *= (i >> TwoFactors);
949 }
950
951 // We need at least W + T bits for the multiplication step
952 unsigned CalculationBits = W + T;
953
954 // Calculate 2^T, at width T+W.
955 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
956
957 // Calculate the multiplicative inverse of K! / 2^T;
958 // this multiplication factor will perform the exact division by
959 // K! / 2^T.
960 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
961
962 // Calculate the product, at width T+W
963 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
964 CalculationBits);
965 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
966 for (unsigned i = 1; i != K; ++i) {
967 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
968 Dividend = SE.getMulExpr(Dividend,
969 SE.getTruncateOrZeroExtend(S, CalculationTy));
970 }
971
972 // Divide by 2^T
973 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
974
975 // Truncate the result, and divide by K! / 2^T.
976
977 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
978 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
979}
980
981/// Return the value of this chain of recurrences at the specified iteration
982/// number. We can evaluate this recurrence by multiplying each element in the
983/// chain by the binomial coefficient corresponding to it. In other words, we
984/// can evaluate {A,+,B,+,C,+,D} as:
985///
986/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
987///
988/// where BC(It, k) stands for binomial coefficient.
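///
/// For instance, {A,+,B} evaluates to A + B*It, and {0,+,1,+,1} evaluates to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2 = It*(It+1)/2.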
989const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
990 ScalarEvolution &SE) const {
991 return evaluateAtIteration(operands(), It, SE);
992}
993
994const SCEV *
995SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
996 const SCEV *It, ScalarEvolution &SE) {
997 assert(Operands.size() > 0);
998 const SCEV *Result = Operands[0];
999 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1000 // The computation is correct in the face of overflow provided that the
1001 // multiplication is performed _after_ the evaluation of the binomial
1002 // coefficient.
1003 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1004 if (isa<SCEVCouldNotCompute>(Coeff))
1005 return Coeff;
1006
1007 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1008 }
1009 return Result;
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// SCEV Expression folder implementations
1014//===----------------------------------------------------------------------===//
1015
1017 unsigned Depth) {
1018 assert(Depth <= 1 &&
1019 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1020
1021 // We could be called with an integer-typed operand during SCEV rewrites.
1022 // Since the operand is an integer already, just perform zext/trunc/self cast.
1023 if (!Op->getType()->isPointerTy())
1024 return Op;
1025
1026 // What would be an ID for such a SCEV cast expression?
1027 FoldingSetNodeID ID;
1028 ID.AddInteger(scPtrToInt);
1029 ID.AddPointer(Op);
1030
1031 void *IP = nullptr;
1032
1033 // Is there already an expression for such a cast?
1034 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1035 return S;
1036
1037 // It isn't legal for optimizations to construct new ptrtoint expressions
1038 // for non-integral pointers.
1039 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1040 return getCouldNotCompute();
1041
1042 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1043
1044 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1045 // is sufficiently wide to represent all possible pointer values.
1046 // We could theoretically teach SCEV to truncate wider pointers, but
1047 // that isn't implemented for now.
1049 getDataLayout().getTypeSizeInBits(IntPtrTy))
1050 return getCouldNotCompute();
1051
1052 // If not, is this expression something we can't reduce any further?
1053 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1054 // Perform some basic constant folding. If the operand of the ptr2int cast
1055 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1056 // left as-is), but produce a zero constant.
1057 // NOTE: We could handle a more general case, but lack motivational cases.
1058 if (isa<ConstantPointerNull>(U->getValue()))
1059 return getZero(IntPtrTy);
1060
1061 // Create an explicit cast node.
1062 // We can reuse the existing insert position since if we get here,
1063 // we won't have made any changes which would invalidate it.
1064 SCEV *S = new (SCEVAllocator)
1065 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1066 UniqueSCEVs.InsertNode(S, IP);
1067 registerUser(S, Op);
1068 return S;
1069 }
1070
1071 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1072 "non-SCEVUnknown's.");
1073
1074 // Otherwise, we've got some expression that is more complex than just a
1075 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1076 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1077 // only, and the expressions must otherwise be integer-typed.
1078 // So sink the cast down to the SCEVUnknown's.
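  //
  // For example, a pointer-typed SCEV such as (16 + %base) is rewritten to
  // (16 + (ptrtoint %base)), so the only pointer-typed node left under the
  // cast is the SCEVUnknown %base itself.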
1079
1080 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1081 /// which computes a pointer-typed value, and rewrites the whole expression
1082 /// tree so that *all* the computations are done on integers, and the only
1083 /// pointer-typed operands in the expression are SCEVUnknown.
1084 class SCEVPtrToIntSinkingRewriter
1085 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1086 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1087
1088 public:
1089 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1090
1091 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1092 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1093 return Rewriter.visit(Scev);
1094 }
1095
1096 const SCEV *visit(const SCEV *S) {
1097 Type *STy = S->getType();
1098 // If the expression is not pointer-typed, just keep it as-is.
1099 if (!STy->isPointerTy())
1100 return S;
1101 // Else, recursively sink the cast down into it.
1102 return Base::visit(S);
1103 }
1104
1105 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1106 SmallVector<const SCEV *, 2> Operands;
1107 bool Changed = false;
1108 for (const auto *Op : Expr->operands()) {
1109 Operands.push_back(visit(Op));
1110 Changed |= Op != Operands.back();
1111 }
1112 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1113 }
1114
1115 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1116 SmallVector<const SCEV *, 2> Operands;
1117 bool Changed = false;
1118 for (const auto *Op : Expr->operands()) {
1119 Operands.push_back(visit(Op));
1120 Changed |= Op != Operands.back();
1121 }
1122 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1123 }
1124
1125 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1126 assert(Expr->getType()->isPointerTy() &&
1127 "Should only reach pointer-typed SCEVUnknown's.");
1128 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1129 }
1130 };
1131
1132 // And actually perform the cast sinking.
1133 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1134 assert(IntOp->getType()->isIntegerTy() &&
1135 "We must have succeeded in sinking the cast, "
1136 "and ending up with an integer-typed expression!");
1137 return IntOp;
1138}
1139
1141 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1142
1143 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1144 if (isa<SCEVCouldNotCompute>(IntOp))
1145 return IntOp;
1146
1147 return getTruncateOrZeroExtend(IntOp, Ty);
1148}
1149
1150const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1151 unsigned Depth) {
1152 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1153 "This is not a truncating conversion!");
1154 assert(isSCEVable(Ty) &&
1155 "This is not a conversion to a SCEVable type!");
1156 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1157 Ty = getEffectiveSCEVType(Ty);
1158
1159 FoldingSetNodeID ID;
1160 ID.AddInteger(scTruncate);
1161 ID.AddPointer(Op);
1162 ID.AddPointer(Ty);
1163 void *IP = nullptr;
1164 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1165
1166 // Fold if the operand is constant.
1167 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1168 return getConstant(
1169 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1170
1171 // trunc(trunc(x)) --> trunc(x)
1172 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1173 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1174
1175 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1176 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1178
1179 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1180 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1182
1183 if (Depth > MaxCastDepth) {
1184 SCEV *S =
1185 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1186 UniqueSCEVs.InsertNode(S, IP);
1187 registerUser(S, Op);
1188 return S;
1189 }
1190
1191 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1192 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1193 // if after transforming we have at most one truncate, not counting truncates
1194 // that replace other casts.
1195 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1196 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1197 SmallVector<const SCEV *, 4> Operands;
1198 unsigned numTruncs = 0;
1199 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1200 ++i) {
1201 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1202 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1203 isa<SCEVTruncateExpr>(S))
1204 numTruncs++;
1205 Operands.push_back(S);
1206 }
1207 if (numTruncs < 2) {
1208 if (isa<SCEVAddExpr>(Op))
1209 return getAddExpr(Operands);
1210 if (isa<SCEVMulExpr>(Op))
1211 return getMulExpr(Operands);
1212 llvm_unreachable("Unexpected SCEV type for Op.");
1213 }
1214 // Although we checked in the beginning that ID is not in the cache, it is
1215 // possible that, during recursion, a different modification inserted ID
1216 // into the cache. So if we find it, just return it.
1217 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1218 return S;
1219 }
1220
1221 // If the input value is a chrec scev, truncate the chrec's operands.
1222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1223 SmallVector<const SCEV *, 4> Operands;
1224 for (const SCEV *Op : AddRec->operands())
1225 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1226 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1227 }
1228
1229 // Return zero if truncating to known zeros.
1230 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1231 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1232 return getZero(Ty);
1233
1234 // The cast wasn't folded; create an explicit cast node. We can reuse
1235 // the existing insert position since if we get here, we won't have
1236 // made any changes which would invalidate it.
1237 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1238 Op, Ty);
1239 UniqueSCEVs.InsertNode(S, IP);
1240 registerUser(S, Op);
1241 return S;
1242}
1243
1244// Get the limit of a recurrence such that incrementing by Step cannot cause
1245// signed overflow as long as the value of the recurrence within the
1246// loop does not exceed this limit before incrementing.
1247static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1248 ICmpInst::Predicate *Pred,
1249 ScalarEvolution *SE) {
1250 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1251 if (SE->isKnownPositive(Step)) {
1252 *Pred = ICmpInst::ICMP_SLT;
1253 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1254 SE->getSignedRangeMax(Step));
1255 }
1256 if (SE->isKnownNegative(Step)) {
1257 *Pred = ICmpInst::ICMP_SGT;
1258 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1259 SE->getSignedRangeMin(Step));
1260 }
1261 return nullptr;
1262}
1263
1264// Get the limit of a recurrence such that incrementing by Step cannot cause
1265// unsigned overflow as long as the value of the recurrence within the loop does
1266// not exceed this limit before incrementing.
1267static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1268 ICmpInst::Predicate *Pred,
1269 ScalarEvolution *SE) {
1270 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1271 *Pred = ICmpInst::ICMP_ULT;
1272
1273 return SE->getConstant(APInt::getMaxValue(BitWidth) -
1274 SE->getUnsignedRangeMax(Step));
1275}
1276
1277namespace {
1278
1279struct ExtendOpTraitsBase {
1280 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1281 unsigned);
1282};
1283
1284// Used to make code generic over signed and unsigned overflow.
1285template <typename ExtendOp> struct ExtendOpTraits {
1286 // Members present:
1287 //
1288 // static const SCEV::NoWrapFlags WrapType;
1289 //
1290 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1291 //
1292 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1293 // ICmpInst::Predicate *Pred,
1294 // ScalarEvolution *SE);
1295};
1296
1297template <>
1298struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1299 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1300
1301 static const GetExtendExprTy GetExtendExpr;
1302
1303 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1304 ICmpInst::Predicate *Pred,
1305 ScalarEvolution *SE) {
1306 return getSignedOverflowLimitForStep(Step, Pred, SE);
1307 }
1308};
1309
1310const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1311 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1312
1313template <>
1314struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1315 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1316
1317 static const GetExtendExprTy GetExtendExpr;
1318
1319 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1320 ICmpInst::Predicate *Pred,
1321 ScalarEvolution *SE) {
1322 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1323 }
1324};
1325
1326const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1327 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1328
1329} // end anonymous namespace
1330
1331// The recurrence AR has been shown to have no signed/unsigned wrap or something
1332// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1333// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1334// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1335// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1336// expression "Step + sext/zext(PreIncAR)" is congruent with
1337// "sext/zext(PostIncAR)"
1338template <typename ExtendOpTy>
1339static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1340 ScalarEvolution *SE, unsigned Depth) {
1341 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1342 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1343
1344 const Loop *L = AR->getLoop();
1345 const SCEV *Start = AR->getStart();
1346 const SCEV *Step = AR->getStepRecurrence(*SE);
1347
1348 // Check for a simple looking step prior to loop entry.
1349 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1350 if (!SA)
1351 return nullptr;
1352
1353 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1354 // subtraction is expensive. For this purpose, perform a quick and dirty
1355 // difference, by checking for Step in the operand list. Note, that
1356 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1357 SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1358 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1359 if (*It == Step) {
1360 DiffOps.erase(It);
1361 break;
1362 }
1363
1364 if (DiffOps.size() == SA->getNumOperands())
1365 return nullptr;
1366
1367 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1368 // `Step`:
1369
1370 // 1. NSW/NUW flags on the step increment.
1371 auto PreStartFlags =
1372 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1373 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1374 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1375 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1376
1377 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1378 // "S+X does not sign/unsign-overflow".
1379 //
1380
1381 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1382 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1383 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1384 return PreStart;
1385
1386 // 2. Direct overflow check on the step operation's expression.
1387 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1388 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1389 const SCEV *OperandExtendedStart =
1390 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1391 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1392 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1393 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1394 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1395 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1396 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1397 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1398 }
1399 return PreStart;
1400 }
1401
1402 // 3. Loop precondition.
1403 ICmpInst::Predicate Pred;
1404 const SCEV *OverflowLimit =
1405 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1406
1407 if (OverflowLimit &&
1408 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1409 return PreStart;
1410
1411 return nullptr;
1412}
1413
1414// Get the normalized zero or sign extended expression for this AddRec's Start.
1415template <typename ExtendOpTy>
1416static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1417 ScalarEvolution *SE,
1418 unsigned Depth) {
1419 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1420
1421 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1422 if (!PreStart)
1423 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1424
1425 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1426 Depth),
1427 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1428}
1429
1430// Try to prove away overflow by looking at "nearby" add recurrences. A
1431// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1432// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1433//
1434// Formally:
1435//
1436// {S,+,X} == {S-T,+,X} + T
1437// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1438//
1439// If ({S-T,+,X} + T) does not overflow ... (1)
1440//
1441// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1442//
1443// If {S-T,+,X} does not overflow ... (2)
1444//
1445// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1446// == {Ext(S-T)+Ext(T),+,Ext(X)}
1447//
1448// If (S-T)+T does not overflow ... (3)
1449//
1450// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1451// == {Ext(S),+,Ext(X)} == LHS
1452//
1453// Thus, if (1), (2) and (3) are true for some T, then
1454// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1455//
1456// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1457// does not overflow" restricted to the 0th iteration. Therefore we only need
1458// to check for (1) and (2).
1459//
1460// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1461// is `Delta` (defined below).
1462template <typename ExtendOpTy>
1463bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1464 const SCEV *Step,
1465 const Loop *L) {
1466 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1467
1468 // We restrict `Start` to a constant to prevent SCEV from spending too much
1469 // time here. It is correct (but more expensive) to continue with a
1470 // non-constant `Start` and do a general SCEV subtraction to compute
1471 // `PreStart` below.
1472 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1473 if (!StartC)
1474 return false;
1475
1476 APInt StartAI = StartC->getAPInt();
1477
1478 for (unsigned Delta : {-2, -1, 1, 2}) {
1479 const SCEV *PreStart = getConstant(StartAI - Delta);
1480
1481 FoldingSetNodeID ID;
1482 ID.AddInteger(scAddRecExpr);
1483 ID.AddPointer(PreStart);
1484 ID.AddPointer(Step);
1485 ID.AddPointer(L);
1486 void *IP = nullptr;
1487 const auto *PreAR =
1488 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1489
1490 // Give up if we don't already have the add recurrence we need because
1491 // actually constructing an add recurrence is relatively expensive.
1492 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1493 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1494 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1495 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1496 DeltaS, &Pred, this);
1497 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1498 return true;
1499 }
1500 }
1501
1502 return false;
1503}
1504
1505// Finds an integer D for an expression (C + x + y + ...) such that the top
1506// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1507// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1508// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1509// the (C + x + y + ...) expression is \p WholeAddExpr.
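//
// For example, with 8-bit values, C = 0x2D and x + y known to be a multiple of
// 16 (TZ = 4): D = C.trunc(4).zext(8) = 0x0D, and C - D = 0x20 remains a
// multiple of 16, so the top-level addition of D only fills the four known-zero
// low bits and cannot wrap in either direction.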
1510static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1511 const SCEVConstant *ConstantTerm,
1512 const SCEVAddExpr *WholeAddExpr) {
1513 const APInt &C = ConstantTerm->getAPInt();
1514 const unsigned BitWidth = C.getBitWidth();
1515 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1516 uint32_t TZ = BitWidth;
1517 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1518 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1519 if (TZ) {
1520 // Set D to be as many least significant bits of C as possible while still
1521 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1522 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1523 }
1524 return APInt(BitWidth, 0);
1525}
1526
1527// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1528// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1529// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1530// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1531static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1532 const APInt &ConstantStart,
1533 const SCEV *Step) {
1534 const unsigned BitWidth = ConstantStart.getBitWidth();
1535 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1536 if (TZ)
1537 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1538 : ConstantStart;
1539 return APInt(BitWidth, 0);
1540}
1541
1542static void insertFoldCacheEntry(
1543 const ScalarEvolution::FoldID &ID, const SCEV *S,
1544 DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1545 DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1546 &FoldCacheUser) {
1547 auto I = FoldCache.insert({ID, S});
1548 if (!I.second) {
1549 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1550 // entry.
1551 auto &UserIDs = FoldCacheUser[I.first->second];
1552 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1553 for (unsigned I = 0; I != UserIDs.size(); ++I)
1554 if (UserIDs[I] == ID) {
1555 std::swap(UserIDs[I], UserIDs.back());
1556 break;
1557 }
1558 UserIDs.pop_back();
1559 I.first->second = S;
1560 }
1561 FoldCacheUser[S].push_back(ID);
1562}
1563
1564const SCEV *
1565ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1566 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1567 "This is not an extending conversion!");
1568 assert(isSCEVable(Ty) &&
1569 "This is not a conversion to a SCEVable type!");
1570 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1571 Ty = getEffectiveSCEVType(Ty);
1572
1573 FoldID ID(scZeroExtend, Op, Ty);
1574 auto Iter = FoldCache.find(ID);
1575 if (Iter != FoldCache.end())
1576 return Iter->second;
1577
1578 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1579 if (!isa<SCEVZeroExtendExpr>(S))
1580 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1581 return S;
1582}
1583
1584const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1585 unsigned Depth) {
1586 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1587 "This is not an extending conversion!");
1588 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1589 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1590
1591 // Fold if the operand is constant.
1592 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1593 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1594
1595 // zext(zext(x)) --> zext(x)
1596 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1597 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1598
1599 // Before doing any expensive analysis, check to see if we've already
1600 // computed a SCEV for this Op and Ty.
1601 FoldingSetNodeID ID;
1602 ID.AddInteger(scZeroExtend);
1603 ID.AddPointer(Op);
1604 ID.AddPointer(Ty);
1605 void *IP = nullptr;
1606 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1607 if (Depth > MaxCastDepth) {
1608 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1609 Op, Ty);
1610 UniqueSCEVs.InsertNode(S, IP);
1611 registerUser(S, Op);
1612 return S;
1613 }
1614
1615 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1616 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1617 // It's possible the bits taken off by the truncate were all zero bits. If
1618 // so, we should be able to simplify this further.
1619 const SCEV *X = ST->getOperand();
1620 ConstantRange CR = getUnsignedRange(X);
1621 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1622 unsigned NewBits = getTypeSizeInBits(Ty);
1623 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1624 CR.zextOrTrunc(NewBits)))
1625 return getTruncateOrZeroExtend(X, Ty, Depth);
1626 }
1627
1628 // If the input value is a chrec scev, and we can prove that the value
1629 // did not overflow the old, smaller, value, we can zero extend all of the
1630 // operands (often constants). This allows analysis of something like
1631 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1632 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633 if (AR->isAffine()) {
1634 const SCEV *Start = AR->getStart();
1635 const SCEV *Step = AR->getStepRecurrence(*this);
1636 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637 const Loop *L = AR->getLoop();
1638
1639 // If we have special knowledge that this addrec won't overflow,
1640 // we don't need to do any further analysis.
1641 if (AR->hasNoUnsignedWrap()) {
1642 Start =
1643 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1644 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1645 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1646 }
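// For instance (illustrative, assuming SCEV has already proven <nuw> for the
// induction variable of the loop in the comment above):
//   for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
// X is the i8 addrec {0,+,1}<nuw>, and this fast path widens it directly to
// the i32 addrec {0,+,1}<nuw> instead of leaving a zext around the addrec.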
1647
1648 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1649 // Note that this serves two purposes: It filters out loops that are
1650 // simply not analyzable, and it covers the case where this code is
1651 // being called from within backedge-taken count analysis, such that
1652 // attempting to ask for the backedge-taken count would likely result
1653 // in infinite recursion. In the latter case, the analysis code will
1654 // cope with a conservative value, and it will take care to purge
1655 // that value once it has finished.
1656 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1657 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1658 // Manually compute the final value for AR, checking for overflow.
1659
1660 // Check whether the backedge-taken count can be losslessly casted to
1661 // the addrec's type. The count is always unsigned.
1662 const SCEV *CastedMaxBECount =
1663 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1664 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1665 CastedMaxBECount, MaxBECount->getType(), Depth);
1666 if (MaxBECount == RecastedMaxBECount) {
1667 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1668 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1669 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1671 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1673 Depth + 1),
1674 WideTy, Depth + 1);
1675 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1676 const SCEV *WideMaxBECount =
1677 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1678 const SCEV *OperandExtendedAdd =
1679 getAddExpr(WideStart,
1680 getMulExpr(WideMaxBECount,
1681 getZeroExtendExpr(Step, WideTy, Depth + 1),
1684 if (ZAdd == OperandExtendedAdd) {
1685 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1686 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1687 // Return the expression with the addrec on the outside.
1688 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1689 Depth + 1);
1690 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1691 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1692 }
1693 // Similar to above, only this time treat the step value as signed.
1694 // This covers loops that count down.
1695 OperandExtendedAdd =
1696 getAddExpr(WideStart,
1697 getMulExpr(WideMaxBECount,
1698 getSignExtendExpr(Step, WideTy, Depth + 1),
1701 if (ZAdd == OperandExtendedAdd) {
1702 // Cache knowledge of AR NW, which is propagated to this AddRec.
1703 // Negative step causes unsigned wrap, but it still can't self-wrap.
1704 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1705 // Return the expression with the addrec on the outside.
1706 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1707 Depth + 1);
1708 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1709 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1710 }
1711 }
1712 }
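// Worked instance of the check above (a sketch, not from an actual run):
// for the i8 addrec {0,+,1} with MaxBECount = 99, the narrow final value is
// 0 + 1 * 99 = 99, and zero-extending it to the doubled-width type i16 gives
// 99; extending the operands first gives 0 + 99 * 1 = 99 as well, so
// ZAdd == OperandExtendedAdd, the addrec is marked <nuw>, and the zext can be
// pushed into Start and Step.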
1713
1714 // Normally, in the cases we can prove no-overflow via a
1715 // backedge guarding condition, we can also compute a backedge
1716 // taken count for the loop. The exceptions are assumptions and
1717 // guards present in the loop -- SCEV is not great at exploiting
1718 // these to compute max backedge taken counts, but can still use
1719 // these to prove lack of overflow. Use this fact to avoid
1720 // doing extra work that may not pay off.
1721 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1722 !AC.assumptions().empty()) {
1723
1724 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1725 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1726 if (AR->hasNoUnsignedWrap()) {
1727 // Same as nuw case above - duplicated here to avoid a compile time
1728 // issue. It's not clear that the order of checks matters, but
1729 // it's one of two possible causes of a change which was
1730 // reverted. Be conservative for the moment.
1731 Start =
1732 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1733 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1734 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1735 }
1736
1737 // For a negative step, we can extend the operands iff doing so only
1738 // traverses values in the range zext([0,UINT_MAX]).
1739 if (isKnownNegative(Step)) {
1741 getSignedRangeMin(Step));
1744 // Cache knowledge of AR NW, which is propagated to this
1745 // AddRec. Negative step causes unsigned wrap, but it
1746 // still can't self-wrap.
1747 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1748 // Return the expression with the addrec on the outside.
1749 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1750 Depth + 1);
1751 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1752 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1753 }
1754 }
1755 }
1756
1757 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1758 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1759 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1760 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1761 const APInt &C = SC->getAPInt();
1762 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1763 if (D != 0) {
1764 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1765 const SCEV *SResidual =
1766 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1767 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1768 return getAddExpr(SZExtD, SZExtR,
1770 Depth + 1);
1771 }
1772 }
1773
1774 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1775 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1776 Start =
1777 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1778 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1779 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1780 }
1781 }
1782
1783 // zext(A % B) --> zext(A) % zext(B)
1784 {
1785 const SCEV *LHS;
1786 const SCEV *RHS;
1787 if (matchURem(Op, LHS, RHS))
1788 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1789 getZeroExtendExpr(RHS, Ty, Depth + 1));
1790 }
1791
1792 // zext(A / B) --> zext(A) / zext(B).
1793 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1794 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1795 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1796
1797 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1798 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1799 if (SA->hasNoUnsignedWrap()) {
1800 // If the addition does not unsign overflow then we can, by definition,
1801 // commute the zero extension with the addition operation.
1802 SmallVector<const SCEV *, 4> Ops;
1803 for (const auto *Op : SA->operands())
1804 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1805 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1806 }
1807
1808 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1809 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1810 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1811 //
1812 // Often address arithmetic contains expressions like
1813 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1814 // This transformation is useful while proving that such expressions are
1815 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1816 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1817 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1818 if (D != 0) {
1819 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1820 const SCEV *SResidual =
1822 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1823 return getAddExpr(SZExtD, SZExtR,
1825 Depth + 1);
1826 }
1827 }
1828 }
1829
1830 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1831 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1832 if (SM->hasNoUnsignedWrap()) {
1833 // If the multiply does not unsign overflow then we can, by definition,
1834 // commute the zero extension with the multiply operation.
1835 SmallVector<const SCEV *, 4> Ops;
1836 for (const auto *Op : SM->operands())
1837 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1838 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1839 }
1840
1841 // zext(2^K * (trunc X to iN)) to iM ->
1842 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1843 //
1844 // Proof:
1845 //
1846 // zext(2^K * (trunc X to iN)) to iM
1847 // = zext((trunc X to iN) << K) to iM
1848 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1849 // (because shl removes the top K bits)
1850 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1851 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1852 //
1853 if (SM->getNumOperands() == 2)
1854 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1855 if (MulLHS->getAPInt().isPowerOf2())
1856 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1857 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1858 MulLHS->getAPInt().logBase2();
1859 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1860 return getMulExpr(
1861 getZeroExtendExpr(MulLHS, Ty),
1863 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1864 SCEV::FlagNUW, Depth + 1);
1865 }
1866 }
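// A concrete instance of the 2^K * trunc rule above (illustrative): with
// K = 2, N = 8, M = 32,
//   zext(4 * (trunc X to i8)) to i32
// becomes
//   (4 * (zext (trunc X to i6) to i32))<nuw>,
// since a 6-bit value multiplied by 4 cannot exceed 8 bits.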
1867
1868 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1869 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1870 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1871 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1872 SmallVector<const SCEV *, 4> Operands;
1873 for (auto *Operand : MinMax->operands())
1874 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1875 if (isa<SCEVUMinExpr>(MinMax))
1876 return getUMinExpr(Operands);
1877 return getUMaxExpr(Operands);
1878 }
1879
1880 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1881 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1882 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1883 SmallVector<const SCEV *, 4> Operands;
1884 for (auto *Operand : MinMax->operands())
1885 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1886 return getUMinExpr(Operands, /*Sequential*/ true);
1887 }
1888
1889 // The cast wasn't folded; create an explicit cast node.
1890 // Recompute the insert position, as it may have been invalidated.
1891 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1892 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1893 Op, Ty);
1894 UniqueSCEVs.InsertNode(S, IP);
1895 registerUser(S, Op);
1896 return S;
1897}
1898
1899const SCEV *
1900 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1901 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1902 "This is not an extending conversion!");
1903 assert(isSCEVable(Ty) &&
1904 "This is not a conversion to a SCEVable type!");
1905 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1906 Ty = getEffectiveSCEVType(Ty);
1907
1908 FoldID ID(scSignExtend, Op, Ty);
1909 auto Iter = FoldCache.find(ID);
1910 if (Iter != FoldCache.end())
1911 return Iter->second;
1912
1913 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1914 if (!isa<SCEVSignExtendExpr>(S))
1915 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1916 return S;
1917}
1918
1919 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1920                                                    unsigned Depth) {
1921 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1922 "This is not an extending conversion!");
1923 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1924 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1925 Ty = getEffectiveSCEVType(Ty);
1926
1927 // Fold if the operand is constant.
1928 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1929 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1930
1931 // sext(sext(x)) --> sext(x)
1932 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1933 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1934
1935 // sext(zext(x)) --> zext(x)
1936 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1937 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1938
1939 // Before doing any expensive analysis, check to see if we've already
1940 // computed a SCEV for this Op and Ty.
1941 FoldingSetNodeID ID;
1942 ID.AddInteger(scSignExtend);
1943 ID.AddPointer(Op);
1944 ID.AddPointer(Ty);
1945 void *IP = nullptr;
1946 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1947 // Limit recursion depth.
1948 if (Depth > MaxCastDepth) {
1949 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1950 Op, Ty);
1951 UniqueSCEVs.InsertNode(S, IP);
1952 registerUser(S, Op);
1953 return S;
1954 }
1955
1956 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1957 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1958 // It's possible the bits taken off by the truncate were all sign bits. If
1959 // so, we should be able to simplify this further.
1960 const SCEV *X = ST->getOperand();
1961 ConstantRange CR = getSignedRange(X);
1962 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1963 unsigned NewBits = getTypeSizeInBits(Ty);
1964 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1965 CR.sextOrTrunc(NewBits)))
1966 return getTruncateOrSignExtend(X, Ty, Depth);
1967 }
1968
1969 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1970 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1971 if (SA->hasNoSignedWrap()) {
1972 // If the addition does not sign overflow then we can, by definition,
1973 // commute the sign extension with the addition operation.
1974 SmallVector<const SCEV *, 4> Ops;
1975 for (const auto *Op : SA->operands())
1976 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1977 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1978 }
1979
1980 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1981 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1982 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1983 //
1984 // For instance, this will bring two seemingly different expressions:
1985 // 1 + sext(5 + 20 * %x + 24 * %y) and
1986 // sext(6 + 20 * %x + 24 * %y)
1987 // to the same form:
1988 // 2 + sext(4 + 20 * %x + 24 * %y)
1989 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1990 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1991 if (D != 0) {
1992 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1993 const SCEV *SResidual =
1995 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1996 return getAddExpr(SSExtD, SSExtR,
1998 Depth + 1);
1999 }
2000 }
2001 }
2002 // If the input value is a chrec scev, and we can prove that the value
2003 // did not overflow the old, smaller, value, we can sign extend all of the
2004 // operands (often constants). This allows analysis of something like
2005 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2006 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
2007 if (AR->isAffine()) {
2008 const SCEV *Start = AR->getStart();
2009 const SCEV *Step = AR->getStepRecurrence(*this);
2010 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2011 const Loop *L = AR->getLoop();
2012
2013 // If we have special knowledge that this addrec won't overflow,
2014 // we don't need to do any further analysis.
2015 if (AR->hasNoSignedWrap()) {
2016 Start =
2017 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2018 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2019 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2020 }
2021
2022 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2023 // Note that this serves two purposes: It filters out loops that are
2024 // simply not analyzable, and it covers the case where this code is
2025 // being called from within backedge-taken count analysis, such that
2026 // attempting to ask for the backedge-taken count would likely result
2027 // in infinite recursion. In the latter case, the analysis code will
2028 // cope with a conservative value, and it will take care to purge
2029 // that value once it has finished.
2030 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2031 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2032 // Manually compute the final value for AR, checking for
2033 // overflow.
2034
2035 // Check whether the backedge-taken count can be losslessly casted to
2036 // the addrec's type. The count is always unsigned.
2037 const SCEV *CastedMaxBECount =
2038 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2039 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2040 CastedMaxBECount, MaxBECount->getType(), Depth);
2041 if (MaxBECount == RecastedMaxBECount) {
2042 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2043 // Check whether Start+Step*MaxBECount has no signed overflow.
2044 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2046 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2048 Depth + 1),
2049 WideTy, Depth + 1);
2050 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2051 const SCEV *WideMaxBECount =
2052 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2053 const SCEV *OperandExtendedAdd =
2054 getAddExpr(WideStart,
2055 getMulExpr(WideMaxBECount,
2056 getSignExtendExpr(Step, WideTy, Depth + 1),
2059 if (SAdd == OperandExtendedAdd) {
2060 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2061 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2062 // Return the expression with the addrec on the outside.
2063 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2064 Depth + 1);
2065 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2066 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2067 }
2068 // Similar to above, only this time treat the step value as unsigned.
2069 // This covers loops that count up with an unsigned step.
2070 OperandExtendedAdd =
2071 getAddExpr(WideStart,
2072 getMulExpr(WideMaxBECount,
2073 getZeroExtendExpr(Step, WideTy, Depth + 1),
2076 if (SAdd == OperandExtendedAdd) {
2077 // If AR wraps around then
2078 //
2079 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2080 // => SAdd != OperandExtendedAdd
2081 //
2082 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2083 // (SAdd == OperandExtendedAdd => AR is NW)
2084
2085 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2086
2087 // Return the expression with the addrec on the outside.
2088 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2089 Depth + 1);
2090 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2091 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2092 }
2093 }
2094 }
2095
2096 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2097 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2098 if (AR->hasNoSignedWrap()) {
2099 // Same as nsw case above - duplicated here to avoid a compile time
2100 // issue. It's not clear that the order of checks matters, but
2101 // it's one of two possible causes of a change which was
2102 // reverted. Be conservative for the moment.
2103 Start =
2104 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2105 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2106 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2107 }
2108
2109 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2110 // if D + (C - D + Step * n) could be proven to not signed wrap
2111 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2112 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2113 const APInt &C = SC->getAPInt();
2114 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2115 if (D != 0) {
2116 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2117 const SCEV *SResidual =
2118 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2119 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2120 return getAddExpr(SSExtD, SSExtR,
2122 Depth + 1);
2123 }
2124 }
2125
2126 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2127 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2128 Start =
2129 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2130 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2131 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2132 }
2133 }
2134
2135 // If the input value is provably positive and we could not simplify
2136 // away the sext, build a zext instead.
2137 if (isKnownNonNegative(Op))
2138 return getZeroExtendExpr(Op, Ty, Depth + 1);
2139
2140 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2141 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2142 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2143 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2144 SmallVector<const SCEV *, 4> Operands;
2145 for (auto *Operand : MinMax->operands())
2146 Operands.push_back(getSignExtendExpr(Operand, Ty));
2147 if (isa<SCEVSMinExpr>(MinMax))
2148 return getSMinExpr(Operands);
2149 return getSMaxExpr(Operands);
2150 }
2151
2152 // The cast wasn't folded; create an explicit cast node.
2153 // Recompute the insert position, as it may have been invalidated.
2154 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2155 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2156 Op, Ty);
2157 UniqueSCEVs.InsertNode(S, IP);
2158 registerUser(S, { Op });
2159 return S;
2160}
2161
2162 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2163                                          Type *Ty) {
2164 switch (Kind) {
2165 case scTruncate:
2166 return getTruncateExpr(Op, Ty);
2167 case scZeroExtend:
2168 return getZeroExtendExpr(Op, Ty);
2169 case scSignExtend:
2170 return getSignExtendExpr(Op, Ty);
2171 case scPtrToInt:
2172 return getPtrToIntExpr(Op, Ty);
2173 default:
2174 llvm_unreachable("Not a SCEV cast expression!");
2175 }
2176}
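// Hedged usage sketch: getCastExpr is only a dispatcher, so, for example,
//   SE.getCastExpr(scZeroExtend, X, WideTy)
// behaves exactly like SE.getZeroExtendExpr(X, WideTy) at the default depth.
// SE, X and WideTy here are placeholder names, not values from this file.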
2177
2178/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2179/// unspecified bits out to the given type.
2180 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2181                                               Type *Ty) {
2182 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2183 "This is not an extending conversion!");
2184 assert(isSCEVable(Ty) &&
2185 "This is not a conversion to a SCEVable type!");
2186 Ty = getEffectiveSCEVType(Ty);
2187
2188 // Sign-extend negative constants.
2189 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2190 if (SC->getAPInt().isNegative())
2191 return getSignExtendExpr(Op, Ty);
2192
2193 // Peel off a truncate cast.
2194 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2195 const SCEV *NewOp = T->getOperand();
2196 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2197 return getAnyExtendExpr(NewOp, Ty);
2198 return getTruncateOrNoop(NewOp, Ty);
2199 }
2200
2201 // Next try a zext cast. If the cast is folded, use it.
2202 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2203 if (!isa<SCEVZeroExtendExpr>(ZExt))
2204 return ZExt;
2205
2206 // Next try a sext cast. If the cast is folded, use it.
2207 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2208 if (!isa<SCEVSignExtendExpr>(SExt))
2209 return SExt;
2210
2211 // Force the cast to be folded into the operands of an addrec.
2212 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2213 SmallVector<const SCEV *, 4> Ops;
2214 for (const SCEV *Op : AR->operands())
2215 Ops.push_back(getAnyExtendExpr(Op, Ty));
2216 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2217 }
2218
2219 // If the expression is obviously signed, use the sext cast value.
2220 if (isa<SCEVSMaxExpr>(Op))
2221 return SExt;
2222
2223 // Absent any other information, use the zext cast value.
2224 return ZExt;
2225}
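// Behavioural sketch of getAnyExtendExpr: a negative constant is sign
// extended (e.g. anyext(i8 -1) becomes the wider -1); otherwise zext and then
// sext are tried and whichever one folds away is kept, with the zero-extended
// form used as the fallback when neither folds.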
2226
2227/// Process the given Ops list, which is a list of operands to be added under
2228 /// the given scale, and update the given map. This is a helper function for
2229 /// getAddExpr. As an example of what it does, given a sequence of operands
2230/// that would form an add expression like this:
2231///
2232/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2233///
2234/// where A and B are constants, update the map with these values:
2235///
2236/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2237///
2238/// and add 13 + A*B*29 to AccumulatedConstant.
2239 /// This will allow getAddExpr to produce this:
2240///
2241/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2242///
2243/// This form often exposes folding opportunities that are hidden in
2244/// the original operand list.
2245///
2246/// Return true iff it appears that any interesting folding opportunities
2247 /// may be exposed. This helps getAddExpr short-circuit extra work in
2248/// the common case where no interesting opportunities are present, and
2249/// is also used as a check to avoid infinite recursion.
2250static bool
2253 APInt &AccumulatedConstant,
2254 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2255 ScalarEvolution &SE) {
2256 bool Interesting = false;
2257
2258 // Iterate over the add operands. They are sorted, with constants first.
2259 unsigned i = 0;
2260 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2261 ++i;
2262 // Pull a buried constant out to the outside.
2263 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2264 Interesting = true;
2265 AccumulatedConstant += Scale * C->getAPInt();
2266 }
2267
2268 // Next comes everything else. We're especially interested in multiplies
2269 // here, but they're in the middle, so just visit the rest with one loop.
2270 for (; i != Ops.size(); ++i) {
2271 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2272 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2273 APInt NewScale =
2274 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2275 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2276 // A multiplication of a constant with another add; recurse.
2277 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2278 Interesting |=
2279 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2280 Add->operands(), NewScale, SE);
2281 } else {
2282 // A multiplication of a constant with some other value. Update
2283 // the map.
2284 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2285 const SCEV *Key = SE.getMulExpr(MulOps);
2286 auto Pair = M.insert({Key, NewScale});
2287 if (Pair.second) {
2288 NewOps.push_back(Pair.first->first);
2289 } else {
2290 Pair.first->second += NewScale;
2291 // The map already had an entry for this value, which may indicate
2292 // a folding opportunity.
2293 Interesting = true;
2294 }
2295 }
2296 } else {
2297 // An ordinary operand. Update the map.
2298 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2299 M.insert({Ops[i], Scale});
2300 if (Pair.second) {
2301 NewOps.push_back(Pair.first->first);
2302 } else {
2303 Pair.first->second += Scale;
2304 // The map already had an entry for this value, which may indicate
2305 // a folding opportunity.
2306 Interesting = true;
2307 }
2308 }
2309 }
2310
2311 return Interesting;
2312}
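// Plugging concrete constants into the example from the function comment
// above (illustrative): with A = 2 and B = 3, the expression
//   m + n + 13 + (2 * (o + p + (3 * (q + m + 29)))) + r + (-1 * r)
// leaves AccumulatedConstant = 13 + 2*3*29 = 187 and the map
//   (m, 7), (n, 1), (o, 2), (p, 2), (q, 6), (r, 0),
// from which getAddExpr can rebuild 187 + n + 7*m + 2*(o + p) + 6*q.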
2313
2314 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2315                                       const SCEV *LHS, const SCEV *RHS,
2316 const Instruction *CtxI) {
2317 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2318 SCEV::NoWrapFlags, unsigned);
2319 switch (BinOp) {
2320 default:
2321 llvm_unreachable("Unsupported binary op");
2322 case Instruction::Add:
2323 Operation = &ScalarEvolution::getAddExpr;
2324 break;
2325 case Instruction::Sub:
2326 Operation = &ScalarEvolution::getMinusSCEV;
2327 break;
2328 case Instruction::Mul:
2329 Operation = &ScalarEvolution::getMulExpr;
2330 break;
2331 }
2332
2333 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2334     Signed ? &ScalarEvolution::getSignExtendExpr
2335            : &ScalarEvolution::getZeroExtendExpr;
2336
2337 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2338 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2339 auto *WideTy =
2340 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2341
2342 const SCEV *A = (this->*Extension)(
2343 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2344 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2345 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2346 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2347 if (A == B)
2348 return true;
2349 // Can we use context to prove the fact we need?
2350 if (!CtxI)
2351 return false;
2352 // TODO: Support mul.
2353 if (BinOp == Instruction::Mul)
2354 return false;
2355 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2356 // TODO: Lift this limitation.
2357 if (!RHSC)
2358 return false;
2359 APInt C = RHSC->getAPInt();
2360 unsigned NumBits = C.getBitWidth();
2361 bool IsSub = (BinOp == Instruction::Sub);
2362 bool IsNegativeConst = (Signed && C.isNegative());
2363 // Compute the direction and magnitude by which we need to check overflow.
2364 bool OverflowDown = IsSub ^ IsNegativeConst;
2365 APInt Magnitude = C;
2366 if (IsNegativeConst) {
2367 if (C == APInt::getSignedMinValue(NumBits))
2368 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2369 // want to deal with that.
2370 return false;
2371 Magnitude = -C;
2372 }
2373
2375 if (OverflowDown) {
2376 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2377 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2378 : APInt::getMinValue(NumBits);
2379 APInt Limit = Min + Magnitude;
2380 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2381 } else {
2382 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2383 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2384 : APInt::getMaxValue(NumBits);
2385 APInt Limit = Max - Magnitude;
2386 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2387 }
2388}
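// Worked instance of the context-based path above (a sketch): for an
// unsigned i8 addition LHS + 5, Magnitude is 5 and OverflowDown is false, so
// the query becomes "is LHS u<= 255 - 5 = 250 at CtxI?"; if the context
// proves that bound, the addition cannot wrap.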
2389
2390std::optional<SCEV::NoWrapFlags>
2392 const OverflowingBinaryOperator *OBO) {
2393 // It cannot be done any better.
2394 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2395 return std::nullopt;
2396
2397 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
2398
2399 if (OBO->hasNoUnsignedWrap())
2400 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2401 if (OBO->hasNoSignedWrap())
2402 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2403
2404 bool Deduced = false;
2405
2406 if (OBO->getOpcode() != Instruction::Add &&
2407 OBO->getOpcode() != Instruction::Sub &&
2408 OBO->getOpcode() != Instruction::Mul)
2409 return std::nullopt;
2410
2411 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2412 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2413
2414 const Instruction *CtxI =
2415 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2416 if (!OBO->hasNoUnsignedWrap() &&
2418 /* Signed */ false, LHS, RHS, CtxI)) {
2419 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2420 Deduced = true;
2421 }
2422
2423 if (!OBO->hasNoSignedWrap() &&
2425 /* Signed */ true, LHS, RHS, CtxI)) {
2426 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2427 Deduced = true;
2428 }
2429
2430 if (Deduced)
2431 return Flags;
2432 return std::nullopt;
2433}
2434
2435// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2436// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2437// can't-overflow flags for the operation if possible.
2438static SCEV::NoWrapFlags
2440 const ArrayRef<const SCEV *> Ops,
2441 SCEV::NoWrapFlags Flags) {
2442 using namespace std::placeholders;
2443
2444 using OBO = OverflowingBinaryOperator;
2445
2446 bool CanAnalyze =
2447 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2448 (void)CanAnalyze;
2449 assert(CanAnalyze && "don't call from other places!");
2450
2451 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2452 SCEV::NoWrapFlags SignOrUnsignWrap =
2453 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2454
2455 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2456 auto IsKnownNonNegative = [&](const SCEV *S) {
2457 return SE->isKnownNonNegative(S);
2458 };
2459
2460 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2461 Flags =
2462 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2463
2464 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2465
2466 if (SignOrUnsignWrap != SignOrUnsignMask &&
2467 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2468 isa<SCEVConstant>(Ops[0])) {
2469
2470 auto Opcode = [&] {
2471 switch (Type) {
2472 case scAddExpr:
2473 return Instruction::Add;
2474 case scMulExpr:
2475 return Instruction::Mul;
2476 default:
2477 llvm_unreachable("Unexpected SCEV op.");
2478 }
2479 }();
2480
2481 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2482
2483 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2484 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2485 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2486 Opcode, C, OBO::NoSignedWrap);
2487 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2488 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2489 }
2490
2491 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2492 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2493 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2494 Opcode, C, OBO::NoUnsignedWrap);
2495 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2496 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2497 }
2498 }
2499
2500 // <0,+,nonnegative><nw> is also nuw
2501 // TODO: Add corresponding nsw case
2502 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2503 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2504 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2505 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2506
2507 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2508 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2509 Ops.size() == 2) {
2510 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2511 if (UDiv->getOperand(1) == Ops[1])
2512 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2513 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2514 if (UDiv->getOperand(1) == Ops[0])
2515 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2516 }
2517
2518 return Flags;
2519}
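// Two illustrative consequences of the logic above: an add that is <nsw> and
// whose operands are all known non-negative is additionally marked <nuw>;
// and for (C + X) or (C * X) with a constant C, the missing flag is added
// whenever X's signed (resp. unsigned) range fits inside the no-wrap region
// computed for that constant and opcode.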
2520
2522 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2523}
2524
2525/// Get a canonical add expression, or something simpler if possible.
2527 SCEV::NoWrapFlags OrigFlags,
2528 unsigned Depth) {
2529 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2530 "only nuw or nsw allowed");
2531 assert(!Ops.empty() && "Cannot get empty add!");
2532 if (Ops.size() == 1) return Ops[0];
2533#ifndef NDEBUG
2534 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2535 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2536 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2537 "SCEVAddExpr operand types don't match!");
2538 unsigned NumPtrs = count_if(
2539 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2540 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2541#endif
2542
2543 const SCEV *Folded = constantFoldAndGroupOps(
2544 *this, LI, DT, Ops,
2545 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2546 [](const APInt &C) { return C.isZero(); }, // identity
2547 [](const APInt &C) { return false; }); // absorber
2548 if (Folded)
2549 return Folded;
2550
2551 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2552
2553 // Delay expensive flag strengthening until necessary.
2554 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2555 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2556 };
2557
2558 // Limit recursion depth.
2560 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2561
2562 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2563 // Don't strengthen flags if we have no new information.
2564 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2565 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2566 Add->setNoWrapFlags(ComputeFlags(Ops));
2567 return S;
2568 }
2569
2570 // Okay, check to see if the same value occurs in the operand list more than
2571 // once. If so, merge them into a multiply expression. Since we
2572 // sorted the list, these values are required to be adjacent.
2573 Type *Ty = Ops[0]->getType();
2574 bool FoundMatch = false;
2575 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2576 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2577 // Scan ahead to count how many equal operands there are.
2578 unsigned Count = 2;
2579 while (i+Count != e && Ops[i+Count] == Ops[i])
2580 ++Count;
2581 // Merge the values into a multiply.
2582 const SCEV *Scale = getConstant(Ty, Count);
2583 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2584 if (Ops.size() == Count)
2585 return Mul;
2586 Ops[i] = Mul;
2587 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2588 --i; e -= Count - 1;
2589 FoundMatch = true;
2590 }
2591 if (FoundMatch)
2592 return getAddExpr(Ops, OrigFlags, Depth + 1);
2593
2594 // Check for truncates. If all the operands are truncated from the same
2595 // type, see if factoring out the truncate would permit the result to be
2596 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2597 // if the contents of the resulting outer trunc fold to something simple.
2598 auto FindTruncSrcType = [&]() -> Type * {
2599 // We're ultimately looking to fold an addrec of truncs and muls of only
2600 // constants and truncs, so if we find any other types of SCEV
2601 // as operands of the addrec then we bail and return nullptr here.
2602 // Otherwise, we return the type of the operand of a trunc that we find.
2603 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2604 return T->getOperand()->getType();
2605 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2606 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2607 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2608 return T->getOperand()->getType();
2609 }
2610 return nullptr;
2611 };
2612 if (auto *SrcType = FindTruncSrcType()) {
2613 SmallVector<const SCEV *, 8> LargeOps;
2614 bool Ok = true;
2615 // Check all the operands to see if they can be represented in the
2616 // source type of the truncate.
2617 for (const SCEV *Op : Ops) {
2618 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2619 if (T->getOperand()->getType() != SrcType) {
2620 Ok = false;
2621 break;
2622 }
2623 LargeOps.push_back(T->getOperand());
2624 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2625 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2626 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2627 SmallVector<const SCEV *, 8> LargeMulOps;
2628 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2629 if (const SCEVTruncateExpr *T =
2630 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2631 if (T->getOperand()->getType() != SrcType) {
2632 Ok = false;
2633 break;
2634 }
2635 LargeMulOps.push_back(T->getOperand());
2636 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2637 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2638 } else {
2639 Ok = false;
2640 break;
2641 }
2642 }
2643 if (Ok)
2644 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2645 } else {
2646 Ok = false;
2647 break;
2648 }
2649 }
2650 if (Ok) {
2651 // Evaluate the expression in the larger type.
2652 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2653 // If it folds to something simple, use it. Otherwise, don't.
2654 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2655 return getTruncateExpr(Fold, Ty);
2656 }
2657 }
2658
2659 if (Ops.size() == 2) {
2660 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2661 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2662 // C1).
2663 const SCEV *A = Ops[0];
2664 const SCEV *B = Ops[1];
2665 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2666 auto *C = dyn_cast<SCEVConstant>(A);
2667 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2668 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2669 auto C2 = C->getAPInt();
2670 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2671
2672 APInt ConstAdd = C1 + C2;
2673 auto AddFlags = AddExpr->getNoWrapFlags();
2674 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2676 ConstAdd.ule(C1)) {
2677 PreservedFlags =
2679 }
2680
2681 // Adding a constant with the same sign and small magnitude is NSW, if the
2682 // original AddExpr was NSW.
2684 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2685 ConstAdd.abs().ule(C1.abs())) {
2686 PreservedFlags =
2688 }
2689
2690 if (PreservedFlags != SCEV::FlagAnyWrap) {
2691 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2692 NewOps[0] = getConstant(ConstAdd);
2693 return getAddExpr(NewOps, PreservedFlags);
2694 }
2695 }
2696 }
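// Illustrative case of the fold above: for (-2) + (X + 5)<nuw>, ConstAdd is
// 3 and 3 u<= 5, so <nuw> is preserved and the result is (X + 3)<nuw>.
// Adding +2 instead (giving X + 7) would not keep <nuw> this way, because
// the new constant is not u<= the old one.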
2697
2698 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2699 if (Ops.size() == 2) {
2700 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2701 if (Mul && Mul->getNumOperands() == 2 &&
2702 Mul->getOperand(0)->isAllOnesValue()) {
2703 const SCEV *X;
2704 const SCEV *Y;
2705 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2706 return getMulExpr(Y, getUDivExpr(X, Y));
2707 }
2708 }
2709 }
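// Why the rewrite above is sound (sketch): by definition
//   X = Y * (X /u Y) + (X urem Y),
// so X + (-1 * (X urem Y)) is exactly Y * (X /u Y), which is what the code
// returns.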
2710
2711 // Skip past any other cast SCEVs.
2712 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2713 ++Idx;
2714
2715 // If there are add operands they would be next.
2716 if (Idx < Ops.size()) {
2717 bool DeletedAdd = false;
2718 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2719 // common NUW flag for expression after inlining. Other flags cannot be
2720 // preserved, because they may depend on the original order of operations.
2721 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2722 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2723 if (Ops.size() > AddOpsInlineThreshold ||
2724 Add->getNumOperands() > AddOpsInlineThreshold)
2725 break;
2726 // If we have an add, expand the add operands onto the end of the operands
2727 // list.
2728 Ops.erase(Ops.begin()+Idx);
2729 append_range(Ops, Add->operands());
2730 DeletedAdd = true;
2731 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2732 }
2733
2734 // If we deleted at least one add, we added operands to the end of the list,
2735 // and they are not necessarily sorted. Recurse to resort and resimplify
2736 // any operands we just acquired.
2737 if (DeletedAdd)
2738 return getAddExpr(Ops, CommonFlags, Depth + 1);
2739 }
2740
2741 // Skip over the add expression until we get to a multiply.
2742 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2743 ++Idx;
2744
2745 // Check to see if there are any folding opportunities present with
2746 // operands multiplied by constant values.
2747 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2748 uint64_t BitWidth = getTypeSizeInBits(Ty);
2749 DenseMap<const SCEV *, APInt> M;
2750 SmallVector<const SCEV *, 8> NewOps;
2751 APInt AccumulatedConstant(BitWidth, 0);
2752 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2753 Ops, APInt(BitWidth, 1), *this)) {
2754 struct APIntCompare {
2755 bool operator()(const APInt &LHS, const APInt &RHS) const {
2756 return LHS.ult(RHS);
2757 }
2758 };
2759
2760 // Some interesting folding opportunity is present, so it's worthwhile to
2761 // re-generate the operands list. Group the operands by constant scale,
2762 // to avoid multiplying by the same constant scale multiple times.
2763 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2764 for (const SCEV *NewOp : NewOps)
2765 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2766 // Re-generate the operands list.
2767 Ops.clear();
2768 if (AccumulatedConstant != 0)
2769 Ops.push_back(getConstant(AccumulatedConstant));
2770 for (auto &MulOp : MulOpLists) {
2771 if (MulOp.first == 1) {
2772 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2773 } else if (MulOp.first != 0) {
2774 Ops.push_back(getMulExpr(
2775 getConstant(MulOp.first),
2776 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2777 SCEV::FlagAnyWrap, Depth + 1));
2778 }
2779 }
2780 if (Ops.empty())
2781 return getZero(Ty);
2782 if (Ops.size() == 1)
2783 return Ops[0];
2784 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2785 }
2786 }
2787
2788 // If we are adding something to a multiply expression, make sure the
2789 // something is not already an operand of the multiply. If so, merge it into
2790 // the multiply.
2791 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2792 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2793 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2794 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2795 if (isa<SCEVConstant>(MulOpSCEV))
2796 continue;
2797 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2798 if (MulOpSCEV == Ops[AddOp]) {
2799 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2800 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2801 if (Mul->getNumOperands() != 2) {
2802 // If the multiply has more than two operands, we must get the
2803 // Y*Z term.
2804 SmallVector<const SCEV *, 4> MulOps(
2805 Mul->operands().take_front(MulOp));
2806 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2807 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2808 }
2809 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2810 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2811 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 if (Ops.size() == 2) return OuterMul;
2814 if (AddOp < Idx) {
2815 Ops.erase(Ops.begin()+AddOp);
2816 Ops.erase(Ops.begin()+Idx-1);
2817 } else {
2818 Ops.erase(Ops.begin()+Idx);
2819 Ops.erase(Ops.begin()+AddOp-1);
2820 }
2821 Ops.push_back(OuterMul);
2822 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2823 }
2824
2825 // Check this multiply against other multiplies being added together.
2826 for (unsigned OtherMulIdx = Idx+1;
2827 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2828 ++OtherMulIdx) {
2829 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2830 // If MulOp occurs in OtherMul, we can fold the two multiplies
2831 // together.
2832 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2833 OMulOp != e; ++OMulOp)
2834 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2835 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2836 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2837 if (Mul->getNumOperands() != 2) {
2838 SmallVector<const SCEV *, 4> MulOps(
2839 Mul->operands().take_front(MulOp));
2840 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2841 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2842 }
2843 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2844 if (OtherMul->getNumOperands() != 2) {
2845 SmallVector<const SCEV *, 4> MulOps(
2846 OtherMul->operands().take_front(OMulOp));
2847 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2848 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2849 }
2850 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2851 const SCEV *InnerMulSum =
2852 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2853 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 if (Ops.size() == 2) return OuterMul;
2856 Ops.erase(Ops.begin()+Idx);
2857 Ops.erase(Ops.begin()+OtherMulIdx-1);
2858 Ops.push_back(OuterMul);
2859 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2860 }
2861 }
2862 }
2863 }
2864
2865 // If there are any add recurrences in the operands list, see if any other
2866 // added values are loop invariant. If so, we can fold them into the
2867 // recurrence.
2868 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2869 ++Idx;
2870
2871 // Scan over all recurrences, trying to fold loop invariants into them.
2872 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2873 // Scan all of the other operands to this add and add them to the vector if
2874 // they are loop invariant w.r.t. the recurrence.
2875 SmallVector<const SCEV *, 8> LIOps;
2876 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2877 const Loop *AddRecLoop = AddRec->getLoop();
2878 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2879 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2880 LIOps.push_back(Ops[i]);
2881 Ops.erase(Ops.begin()+i);
2882 --i; --e;
2883 }
2884
2885 // If we found some loop invariants, fold them into the recurrence.
2886 if (!LIOps.empty()) {
2887 // Compute nowrap flags for the addition of the loop-invariant ops and
2888 // the addrec. Temporarily push it as an operand for that purpose. These
2889 // flags are valid in the scope of the addrec only.
2890 LIOps.push_back(AddRec);
2891 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2892 LIOps.pop_back();
2893
2894 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2895 LIOps.push_back(AddRec->getStart());
2896
2897 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2898
2899 // It is not in general safe to propagate flags valid on an add within
2900 // the addrec scope to one outside it. We must prove that the inner
2901 // scope is guaranteed to execute if the outer one does to be able to
2902 // safely propagate. We know the program is undefined if poison is
2903 // produced on the inner scoped addrec. We also know that *for this use*
2904 // the outer scoped add can't overflow (because of the flags we just
2905 // computed for the inner scoped add) without the program being undefined.
2906 // Proving that entry to the outer scope necessitates entry to the inner
2907 // scope, thus proves the program undefined if the flags would be violated
2908 // in the outer scope.
2909 SCEV::NoWrapFlags AddFlags = Flags;
2910 if (AddFlags != SCEV::FlagAnyWrap) {
2911 auto *DefI = getDefiningScopeBound(LIOps);
2912 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2913 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2914 AddFlags = SCEV::FlagAnyWrap;
2915 }
2916 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2917
2918 // Build the new addrec. Propagate the NUW and NSW flags if both the
2919 // outer add and the inner addrec are guaranteed to have no overflow.
2920 // Always propagate NW.
2921 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2922 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2923
2924 // If all of the other operands were loop invariant, we are done.
2925 if (Ops.size() == 1) return NewRec;
2926
2927 // Otherwise, add the folded AddRec by the non-invariant parts.
2928 for (unsigned i = 0;; ++i)
2929 if (Ops[i] == AddRec) {
2930 Ops[i] = NewRec;
2931 break;
2932 }
2933 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2934 }
2935
2936 // Okay, if there weren't any loop invariants to be folded, check to see if
2937 // there are multiple AddRec's with the same loop induction variable being
2938 // added together. If so, we can fold them.
2939 for (unsigned OtherIdx = Idx+1;
2940 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2941 ++OtherIdx) {
2942 // We expect the AddRecExprs to be sorted in reverse dominance order,
2943 // so that the 1st found AddRecExpr is dominated by all others.
2944 assert(DT.dominates(
2945 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2946 AddRec->getLoop()->getHeader()) &&
2947 "AddRecExprs are not sorted in reverse dominance order?");
2948 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2949 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2950 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2951 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2952 ++OtherIdx) {
2953 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2954 if (OtherAddRec->getLoop() == AddRecLoop) {
2955 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2956 i != e; ++i) {
2957 if (i >= AddRecOps.size()) {
2958 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2959 break;
2960 }
2961 SmallVector<const SCEV *, 2> TwoOps = {
2962 AddRecOps[i], OtherAddRec->getOperand(i)};
2963 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2966 }
2967 }
2968 // Step size has changed, so we cannot guarantee no self-wraparound.
2969 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2970 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 }
2973
2974 // Otherwise couldn't fold anything into this recurrence. Move onto the
2975 // next one.
2976 }
2977
2978 // Okay, it looks like we really DO need an add expr. Check to see if we
2979 // already have one, otherwise create a new one.
2980 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2981}
2982
2983const SCEV *
2984ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2985 SCEV::NoWrapFlags Flags) {
2986 FoldingSetNodeID ID;
2987 ID.AddInteger(scAddExpr);
2988 for (const SCEV *Op : Ops)
2989 ID.AddPointer(Op);
2990 void *IP = nullptr;
2991 SCEVAddExpr *S =
2992 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2993 if (!S) {
2994 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2995 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2996 S = new (SCEVAllocator)
2997 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2998 UniqueSCEVs.InsertNode(S, IP);
2999 registerUser(S, Ops);
3000 }
3001 S->setNoWrapFlags(Flags);
3002 return S;
3003}
3004
3005const SCEV *
3006ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3007 const Loop *L, SCEV::NoWrapFlags Flags) {
3008 FoldingSetNodeID ID;
3009 ID.AddInteger(scAddRecExpr);
3010 for (const SCEV *Op : Ops)
3011 ID.AddPointer(Op);
3012 ID.AddPointer(L);
3013 void *IP = nullptr;
3014 SCEVAddRecExpr *S =
3015 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3016 if (!S) {
3017 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3018 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3019 S = new (SCEVAllocator)
3020 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3021 UniqueSCEVs.InsertNode(S, IP);
3022 LoopUsers[L].push_back(S);
3023 registerUser(S, Ops);
3024 }
3025 setNoWrapFlags(S, Flags);
3026 return S;
3027}
3028
3029const SCEV *
3030ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3031 SCEV::NoWrapFlags Flags) {
3032 FoldingSetNodeID ID;
3033 ID.AddInteger(scMulExpr);
3034 for (const SCEV *Op : Ops)
3035 ID.AddPointer(Op);
3036 void *IP = nullptr;
3037 SCEVMulExpr *S =
3038 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3039 if (!S) {
3040 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3041 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3042 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3043 O, Ops.size());
3044 UniqueSCEVs.InsertNode(S, IP);
3045 registerUser(S, Ops);
3046 }
3047 S->setNoWrapFlags(Flags);
3048 return S;
3049}
3050
3051static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3052 uint64_t k = i*j;
3053 if (j > 1 && k / j != i) Overflow = true;
3054 return k;
3055}
3056
3057/// Compute the result of "n choose k", the binomial coefficient. If an
3058/// intermediate computation overflows, Overflow will be set and the return will
3059/// be garbage. Overflow is not cleared on absence of overflow.
3060static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3061 // We use the multiplicative formula:
3062 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3063 // At each iteration, we take the n-th term of the numerator and divide by the
3064 // (k-n)th term of the denominator. This division will always produce an
3065 // integral result, and helps reduce the chance of overflow in the
3066 // intermediate computations. However, we can still overflow even when the
3067 // final result would fit.
3068
3069 if (n == 0 || n == k) return 1;
3070 if (k > n) return 0;
3071
3072 if (k > n/2)
3073 k = n-k;
3074
3075 uint64_t r = 1;
3076 for (uint64_t i = 1; i <= k; ++i) {
3077 r = umul_ov(r, n-(i-1), Overflow);
3078 r /= i;
3079 }
3080 return r;
3081}
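// Example run (illustrative): Choose(6, 2) keeps k = 2, then
//   i = 1: r = 1 * 6 = 6,  r /= 1 -> 6
//   i = 2: r = 6 * 5 = 30, r /= 2 -> 15
// giving C(6,2) = 15 with Overflow left untouched.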
3082
3083/// Determine if any of the operands in this SCEV are a constant or if
3084/// any of the add or multiply expressions in this SCEV contain a constant.
3085static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3086 struct FindConstantInAddMulChain {
3087 bool FoundConstant = false;
3088
3089 bool follow(const SCEV *S) {
3090 FoundConstant |= isa<SCEVConstant>(S);
3091 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3092 }
3093
3094 bool isDone() const {
3095 return FoundConstant;
3096 }
3097 };
3098
3099 FindConstantInAddMulChain F;
3100 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3101 ST.visitAll(StartExpr);
3102 return F.FoundConstant;
3103}
3104
3105/// Get a canonical multiply expression, or something simpler if possible.
3107 SCEV::NoWrapFlags OrigFlags,
3108 unsigned Depth) {
3109 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3110 "only nuw or nsw allowed");
3111 assert(!Ops.empty() && "Cannot get empty mul!");
3112 if (Ops.size() == 1) return Ops[0];
3113#ifndef NDEBUG
3114 Type *ETy = Ops[0]->getType();
3115 assert(!ETy->isPointerTy());
3116 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3117 assert(Ops[i]->getType() == ETy &&
3118 "SCEVMulExpr operand types don't match!");
3119#endif
3120
3121 const SCEV *Folded = constantFoldAndGroupOps(
3122 *this, LI, DT, Ops,
3123 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3124 [](const APInt &C) { return C.isOne(); }, // identity
3125 [](const APInt &C) { return C.isZero(); }); // absorber
3126 if (Folded)
3127 return Folded;
3128
3129 // Delay expensive flag strengthening until necessary.
3130 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3131 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3132 };
3133
3134 // Limit recursion depth.
3136 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3137
3138 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3139 // Don't strengthen flags if we have no new information.
3140 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3141 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3142 Mul->setNoWrapFlags(ComputeFlags(Ops));
3143 return S;
3144 }
3145
3146 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3147 if (Ops.size() == 2) {
3148 // C1*(C2+V) -> C1*C2 + C1*V
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3157          const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3158                                       SCEV::FlagAnyWrap, Depth + 1);
3159          const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3160                                       SCEV::FlagAnyWrap, Depth + 1);
3161          return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3168          SmallVector<const SCEV *, 4> NewOps;
3169          bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3180          SmallVector<const SCEV *, 4> Operands;
3181          for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184 // Let M be the minimum representable signed value. AddRec with nsw
3185 // multiplied by -1 can have signed overflow if and only if it takes a
3186 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3187 // maximum signed value. In all other cases signed overflow is
3188 // impossible.
3189 auto FlagsMask = SCEV::FlagNW;
3190 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3191 auto MinInt =
3192 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3193 if (getSignedRangeMin(AddRec) != MinInt)
3194 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3195 }
3196 return getAddRecExpr(Operands, AddRec->getLoop(),
3197 AddRec->getNoWrapFlags(FlagsMask));
3198 }
3199 }
3200 }
3201 }
3202
3203 // Skip over the add expression until we get to a multiply.
3204 unsigned Idx = 0;
3205 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3206 ++Idx;
3207
3208 // If there are mul operands inline them all into this expression.
3209 if (Idx < Ops.size()) {
3210 bool DeletedMul = false;
3211 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3212 if (Ops.size() > MulOpsInlineThreshold)
3213 break;
3214      // If we have a mul, expand the mul operands onto the end of the
3215      // operands list.
3216 Ops.erase(Ops.begin()+Idx);
3217 append_range(Ops, Mul->operands());
3218 DeletedMul = true;
3219 }
3220
3221 // If we deleted at least one mul, we added operands to the end of the
3222 // list, and they are not necessarily sorted. Recurse to resort and
3223 // resimplify any operands we just acquired.
3224 if (DeletedMul)
3225 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3226 }
3227
3228  // If there are any add recurrences in the operands list, see if any other
3229  // operands are loop invariant. If so, we can fold them into the
3230  // recurrence.
3231 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3232 ++Idx;
3233
3234 // Scan over all recurrences, trying to fold loop invariants into them.
3235 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3236 // Scan all of the other operands to this mul and add them to the vector
3237 // if they are loop invariant w.r.t. the recurrence.
3238    SmallVector<const SCEV *, 8> LIOps;
3239    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3240 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3241 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3242 LIOps.push_back(Ops[i]);
3243 Ops.erase(Ops.begin()+i);
3244 --i; --e;
3245 }
3246
3247 // If we found some loop invariants, fold them into the recurrence.
3248 if (!LIOps.empty()) {
3249 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3250      SmallVector<const SCEV *, 4> NewOps;
3251      NewOps.reserve(AddRec->getNumOperands());
3252 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3253
3254 // If both the mul and addrec are nuw, we can preserve nuw.
3255 // If both the mul and addrec are nsw, we can only preserve nsw if either
3256 // a) they are also nuw, or
3257 // b) all multiplications of addrec operands with scale are nsw.
3258 SCEV::NoWrapFlags Flags =
3259 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3260
3261 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3262 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3263 SCEV::FlagAnyWrap, Depth + 1));
3264
3265 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3266          auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3267              Instruction::Mul, getSignedRange(Scale),
3268              OverflowingBinaryOperator::NoSignedWrap);
3269 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3270 Flags = clearFlags(Flags, SCEV::FlagNSW);
3271 }
3272 }
3273
3274 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3275
3276 // If all of the other operands were loop invariant, we are done.
3277 if (Ops.size() == 1) return NewRec;
3278
3279 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3280 for (unsigned i = 0;; ++i)
3281 if (Ops[i] == AddRec) {
3282 Ops[i] = NewRec;
3283 break;
3284 }
3285 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 // Okay, if there weren't any loop invariants to be folded, check to see
3289 // if there are multiple AddRec's with the same loop induction variable
3290 // being multiplied together. If so, we can fold them.
3291
3292 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3293 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3294 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3295 // ]]],+,...up to x=2n}.
3296 // Note that the arguments to choose() are always integers with values
3297 // known at compile time, never SCEV objects.
3298 //
3299 // The implementation avoids pointless extra computations when the two
3300 // addrec's are of different length (mathematically, it's equivalent to
3301 // an infinite stream of zeros on the right).
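    // As a sanity check of the formula above, for two affine recurrences over
    // the same loop:
    //   {a,+,b} * {c,+,d} = {a*c, +, a*d + b*c + b*d, +, 2*b*d}
    // since (a + b*i) * (c + d*i) = a*c + (a*d + b*c)*i + b*d*i^2, and a chrec
    // {X,+,Y,+,Z} evaluates to X + Y*i + Z*i*(i-1)/2 at iteration i.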
3302 bool OpsModified = false;
3303 for (unsigned OtherIdx = Idx+1;
3304 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3305 ++OtherIdx) {
3306 const SCEVAddRecExpr *OtherAddRec =
3307 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3308 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3309 continue;
3310
3311 // Limit max number of arguments to avoid creation of unreasonably big
3312 // SCEVAddRecs with very complex operands.
3313 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3314 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3315 continue;
3316
3317 bool Overflow = false;
3318 Type *Ty = AddRec->getType();
3319 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3320      SmallVector<const SCEV*, 7> AddRecOps;
3321      for (int x = 0, xe = AddRec->getNumOperands() +
3322 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3323 SmallVector <const SCEV *, 7> SumOps;
3324 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3325 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3326 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3327 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3328 z < ze && !Overflow; ++z) {
3329 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3330 uint64_t Coeff;
3331 if (LargerThan64Bits)
3332 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3333 else
3334 Coeff = Coeff1*Coeff2;
3335 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3336 const SCEV *Term1 = AddRec->getOperand(y-z);
3337 const SCEV *Term2 = OtherAddRec->getOperand(z);
3338 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3339 SCEV::FlagAnyWrap, Depth + 1));
3340 }
3341 }
3342 if (SumOps.empty())
3343 SumOps.push_back(getZero(Ty));
3344 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3345 }
3346 if (!Overflow) {
3347 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3348                                                  SCEV::FlagAnyWrap);
3349        if (Ops.size() == 2) return NewAddRec;
3350 Ops[Idx] = NewAddRec;
3351 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3352 OpsModified = true;
3353 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3354 if (!AddRec)
3355 break;
3356 }
3357 }
3358 if (OpsModified)
3359 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3360
3361 // Otherwise couldn't fold anything into this recurrence. Move onto the
3362 // next one.
3363 }
3364
3365  // Okay, it looks like we really DO need a mul expr. Check to see if we
3366 // already have one, otherwise create a new one.
3367 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3368}
3369
3370/// Represents an unsigned remainder expression based on unsigned division.
3371const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3372                                         const SCEV *RHS) {
3373  assert(getEffectiveSCEVType(LHS->getType()) ==
3374             getEffectiveSCEVType(RHS->getType()) &&
3375         "SCEVURemExpr operand types don't match!");
3376
3377 // Short-circuit easy cases
3378 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3379 // If constant is one, the result is trivial
3380 if (RHSC->getValue()->isOne())
3381 return getZero(LHS->getType()); // X urem 1 --> 0
3382
3383 // If constant is a power of two, fold into a zext(trunc(LHS)).
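    // For example, on i32: %x urem 8 becomes zext(trunc %x to i3) back to i32,
    // i.e. the low log2(8) = 3 bits of %x.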
3384 if (RHSC->getAPInt().isPowerOf2()) {
3385 Type *FullTy = LHS->getType();
3386 Type *TruncTy =
3387 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3388 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3389 }
3390 }
3391
3392  // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
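  // e.g. 7 urem 3 == 7 - (7 udiv 3) * 3 == 7 - 6 == 1.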
3393 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3394 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3395 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3396}
3397
3398/// Get a canonical unsigned division expression, or something simpler if
3399/// possible.
3400const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3401                                         const SCEV *RHS) {
3402 assert(!LHS->getType()->isPointerTy() &&
3403 "SCEVUDivExpr operand can't be pointer!");
3404 assert(LHS->getType() == RHS->getType() &&
3405 "SCEVUDivExpr operand types don't match!");
3406
3407  FoldingSetNodeID ID;
3408  ID.AddInteger(scUDivExpr);
3409 ID.AddPointer(LHS);
3410 ID.AddPointer(RHS);
3411 void *IP = nullptr;
3412 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3413 return S;
3414
3415 // 0 udiv Y == 0
3416 if (match(LHS, m_scev_Zero()))
3417 return LHS;
3418
3419 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3420 if (RHSC->getValue()->isOne())
3421 return LHS; // X udiv 1 --> x
3422 // If the denominator is zero, the result of the udiv is undefined. Don't
3423 // try to analyze it, because the resolution chosen here may differ from
3424 // the resolution chosen in other parts of the compiler.
3425 if (!RHSC->getValue()->isZero()) {
3426      // Determine whether the division can be folded into the operands of
3427      // the LHS expression.
3428 // TODO: Generalize this to non-constants by using known-bits information.
3429 Type *Ty = LHS->getType();
3430 unsigned LZ = RHSC->getAPInt().countl_zero();
3431 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3432 // For non-power-of-two values, effectively round the value up to the
3433 // nearest power of two.
3434 if (!RHSC->getAPInt().isPowerOf2())
3435 ++MaxShiftAmt;
3436 IntegerType *ExtTy =
3437 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3438 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3439 if (const SCEVConstant *Step =
3440 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3441 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
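          // For example, assuming the zext-based safety check below succeeds,
          // {0,+,8} /u 4 can be rewritten as {0,+,2}.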
3442 const APInt &StepInt = Step->getAPInt();
3443 const APInt &DivInt = RHSC->getAPInt();
3444 if (!StepInt.urem(DivInt) &&
3445 getZeroExtendExpr(AR, ExtTy) ==
3446 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3447 getZeroExtendExpr(Step, ExtTy),
3448 AR->getLoop(), SCEV::FlagAnyWrap)) {
3449              SmallVector<const SCEV *, 4> Operands;
3450              for (const SCEV *Op : AR->operands())
3451 Operands.push_back(getUDivExpr(Op, RHS));
3452 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3453 }
3454          // Get a canonical UDivExpr for a recurrence.
3455          // {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3456          // We can currently only fold X%N if X is constant.
3457 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3458 if (StartC && !DivInt.urem(StepInt) &&
3459 getZeroExtendExpr(AR, ExtTy) ==
3460 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3461 getZeroExtendExpr(Step, ExtTy),
3462 AR->getLoop(), SCEV::FlagAnyWrap)) {
3463 const APInt &StartInt = StartC->getAPInt();
3464 const APInt &StartRem = StartInt.urem(StepInt);
3465 if (StartRem != 0) {
3466 const SCEV *NewLHS =
3467 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3468 AR->getLoop(), SCEV::FlagNW);
3469 if (LHS != NewLHS) {
3470 LHS = NewLHS;
3471
3472 // Reset the ID to include the new LHS, and check if it is
3473 // already cached.
3474 ID.clear();
3475 ID.AddInteger(scUDivExpr);
3476 ID.AddPointer(LHS);
3477 ID.AddPointer(RHS);
3478 IP = nullptr;
3479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3480 return S;
3481 }
3482 }
3483 }
3484 }
3485 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3486 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3487        SmallVector<const SCEV *, 4> Operands;
3488        for (const SCEV *Op : M->operands())
3489 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3490 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3491 // Find an operand that's safely divisible.
3492 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3493 const SCEV *Op = M->getOperand(i);
3494 const SCEV *Div = getUDivExpr(Op, RHSC);
3495 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3496 Operands = SmallVector<const SCEV *, 4>(M->operands());
3497 Operands[i] = Div;
3498 return getMulExpr(Operands);
3499 }
3500 }
3501 }
3502
3503 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
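      // e.g. (%x /u 6) /u 2 becomes %x /u 12; if the combined divisor would
      // overflow the bit width, the whole expression is known to be zero.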
3504 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3505 if (auto *DivisorConstant =
3506 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3507 bool Overflow = false;
3508 APInt NewRHS =
3509 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3510 if (Overflow) {
3511 return getConstant(RHSC->getType(), 0, false);
3512 }
3513 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3514 }
3515 }
3516
3517 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3518 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3519        SmallVector<const SCEV *, 4> Operands;
3520        for (const SCEV *Op : A->operands())
3521 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3522 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3523 Operands.clear();
3524 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3525 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3526 if (isa<SCEVUDivExpr>(Op) ||
3527 getMulExpr(Op, RHS) != A->getOperand(i))
3528 break;
3529 Operands.push_back(Op);
3530 }
3531 if (Operands.size() == A->getNumOperands())
3532 return getAddExpr(Operands);
3533 }
3534 }
3535
3536 // Fold if both operands are constant.
3537 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3538 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3539 }
3540 }
3541
3542 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
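  // To see why: if %x s<= C the numerator is C - C = 0; if %x s> C (and so
  // %x > 0) the numerator is %x - C, which is strictly less than %x, so the
  // unsigned division yields 0 either way.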
3543 if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3544 AE && AE->getNumOperands() == 2) {
3545 if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3546 const APInt &NegC = VC->getAPInt();
3547 if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3548 const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3549 if (MME && MME->getNumOperands() == 2 &&
3550 isa<SCEVConstant>(MME->getOperand(0)) &&
3551 cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3552 MME->getOperand(1) == RHS)
3553 return getZero(LHS->getType());
3554 }
3555 }
3556 }
3557
3558 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3559 // changes). Make sure we get a new one.
3560 IP = nullptr;
3561 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3562 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3563 LHS, RHS);
3564 UniqueSCEVs.InsertNode(S, IP);
3565 registerUser(S, {LHS, RHS});
3566 return S;
3567}
3568
3569APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3570 APInt A = C1->getAPInt().abs();
3571 APInt B = C2->getAPInt().abs();
3572 uint32_t ABW = A.getBitWidth();
3573 uint32_t BBW = B.getBitWidth();
3574
3575 if (ABW > BBW)
3576 B = B.zext(ABW);
3577 else if (ABW < BBW)
3578 A = A.zext(BBW);
3579
3580 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3581}
3582
3583/// Get a canonical unsigned division expression, or something simpler if
3584/// possible. There is no representation for an exact udiv in SCEV IR, but we
3585/// can attempt to remove factors from the LHS and RHS. We can't do this when
3586/// it's not exact because the udiv may be clearing bits.
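/// For example (assuming the division is known exact, as this helper
/// requires), (6 * %x)<nuw> /u 3 simplifies to 2 * %x, which a plain udiv
/// could not do in general.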
3587const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3588                                              const SCEV *RHS) {
3589 // TODO: we could try to find factors in all sorts of things, but for now we
3590 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3591 // end of this file for inspiration.
3592
3593 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3594 if (!Mul || !Mul->hasNoUnsignedWrap())
3595 return getUDivExpr(LHS, RHS);
3596
3597 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3598 // If the mulexpr multiplies by a constant, then that constant must be the
3599 // first element of the mulexpr.
3600 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3601 if (LHSCst == RHSCst) {
3602        SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3603        return getMulExpr(Operands);
3604 }
3605
3606      // We can't just assume that LHSCst divides RHSCst cleanly; it could be
3607      // that there's a factor provided by one of the other terms. We need to
3608      // check.
3609 APInt Factor = gcd(LHSCst, RHSCst);
3610 if (!Factor.isIntN(1)) {
3611 LHSCst =
3612 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3613 RHSCst =
3614 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3615        SmallVector<const SCEV *, 2> Operands;
3616        Operands.push_back(LHSCst);
3617        append_range(Operands, Mul->operands().drop_front());
3618        LHS = getMulExpr(Operands);
3619        RHS = RHSCst;
3620 Mul = dyn_cast<SCEVMulExpr>(LHS);
3621 if (!Mul)
3622 return getUDivExactExpr(LHS, RHS);
3623 }
3624 }
3625 }
3626
3627 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3628 if (Mul->getOperand(i) == RHS) {
3629      SmallVector<const SCEV *, 2> Operands;
3630      append_range(Operands, Mul->operands().take_front(i));
3631 append_range(Operands, Mul->operands().drop_front(i + 1));
3632 return getMulExpr(Operands);
3633 }
3634 }
3635
3636 return getUDivExpr(LHS, RHS);
3637}
3638
3639/// Get an add recurrence expression for the specified loop. Simplify the
3640/// expression as much as possible.
3641const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3642 const Loop *L,
3643 SCEV::NoWrapFlags Flags) {
3644  SmallVector<const SCEV *, 4> Operands;
3645  Operands.push_back(Start);
3646 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3647 if (StepChrec->getLoop() == L) {
3648 append_range(Operands, StepChrec->operands());
3649 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3650 }
3651
3652 Operands.push_back(Step);
3653 return getAddRecExpr(Operands, L, Flags);
3654}
3655
3656/// Get an add recurrence expression for the specified loop. Simplify the
3657/// expression as much as possible.
3658const SCEV *
3659ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3660                               const Loop *L, SCEV::NoWrapFlags Flags) {
3661 if (Operands.size() == 1) return Operands[0];
3662#ifndef NDEBUG
3663  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3664  for (const SCEV *Op : llvm::drop_begin(Operands)) {
3665 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3666 "SCEVAddRecExpr operand types don't match!");
3667 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3668 }
3669 for (const SCEV *Op : Operands)
3670    assert(isAvailableAtLoopEntry(Op, L) &&
3671           "SCEVAddRecExpr operand is not available at loop entry!");
3672#endif
3673
3674 if (Operands.back()->isZero()) {
3675 Operands.pop_back();
3676 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3677 }
3678
3679  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3680 // use that information to infer NUW and NSW flags. However, computing a
3681 // BE count requires calling getAddRecExpr, so we may not yet have a
3682 // meaningful BE count at this point (and if we don't, we'd be stuck
3683 // with a SCEVCouldNotCompute as the cached BE count).
3684
3685 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3686
3687  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3688 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3689 const Loop *NestedLoop = NestedAR->getLoop();
3690 if (L->contains(NestedLoop)
3691 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3692 : (!NestedLoop->contains(L) &&
3693 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3694 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3695 Operands[0] = NestedAR->getStart();
3696 // AddRecs require their operands be loop-invariant with respect to their
3697 // loops. Don't perform this transformation if it would break this
3698 // requirement.
3699 bool AllInvariant = all_of(
3700 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3701
3702 if (AllInvariant) {
3703 // Create a recurrence for the outer loop with the same step size.
3704 //
3705 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3706 // inner recurrence has the same property.
3707 SCEV::NoWrapFlags OuterFlags =
3708 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3709
3710 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3711 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3712 return isLoopInvariant(Op, NestedLoop);
3713 });
3714
3715 if (AllInvariant) {
3716 // Ok, both add recurrences are valid after the transformation.
3717 //
3718 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3719 // the outer recurrence has the same property.
3720 SCEV::NoWrapFlags InnerFlags =
3721 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3722 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3723 }
3724 }
3725 // Reset Operands to its original state.
3726 Operands[0] = NestedAR;
3727 }
3728 }
3729
3730 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3731 // already have one, otherwise create a new one.
3732 return getOrCreateAddRecExpr(Operands, L, Flags);
3733}
3734
3735const SCEV *
3736ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3737                            const SmallVectorImpl<const SCEV *> &IndexExprs) {
3738 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3739 // getSCEV(Base)->getType() has the same address space as Base->getType()
3740 // because SCEV::getType() preserves the address space.
3741 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3742 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3743 if (NW != GEPNoWrapFlags::none()) {
3744 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3745 // but to do that, we have to ensure that said flag is valid in the entire
3746 // defined scope of the SCEV.
3747    // TODO: non-instructions have global scope. We might be able to prove
3748    // some of those global-scope cases.
3749 auto *GEPI = dyn_cast<Instruction>(GEP);
3750 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3751 NW = GEPNoWrapFlags::none();
3752 }
3753
3754  SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3755  if (NW.hasNoUnsignedSignedWrap())
3756 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3757 if (NW.hasNoUnsignedWrap())
3758 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3759
3760 Type *CurTy = GEP->getType();
3761 bool FirstIter = true;
3762  SmallVector<const SCEV *, 4> Offsets;
3763  for (const SCEV *IndexExpr : IndexExprs) {
3764 // Compute the (potentially symbolic) offset in bytes for this index.
3765 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3766 // For a struct, add the member offset.
3767 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3768 unsigned FieldNo = Index->getZExtValue();
3769 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3770 Offsets.push_back(FieldOffset);
3771
3772 // Update CurTy to the type of the field at Index.
3773 CurTy = STy->getTypeAtIndex(Index);
3774 } else {
3775 // Update CurTy to its element type.
3776 if (FirstIter) {
3777 assert(isa<PointerType>(CurTy) &&
3778 "The first index of a GEP indexes a pointer");
3779 CurTy = GEP->getSourceElementType();
3780 FirstIter = false;
3781 } else {
3782        CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3783      }
3784 // For an array, add the element offset, explicitly scaled.
3785 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3786 // Getelementptr indices are signed.
3787 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3788
3789 // Multiply the index by the element size to compute the element offset.
3790 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3791 Offsets.push_back(LocalOffset);
3792 }
3793 }
3794
3795 // Handle degenerate case of GEP without offsets.
3796 if (Offsets.empty())
3797 return BaseExpr;
3798
3799 // Add the offsets together, assuming nsw if inbounds.
3800 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3801 // Add the base address and the offset. We cannot use the nsw flag, as the
3802 // base address is unsigned. However, if we know that the offset is
3803 // non-negative, we can use nuw.
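  // For example, for an inbounds GEP the per-index multiplies and the offset
  // add above get nsw, but the final base + offset add below only gets nuw
  // when the offset is known non-negative (or the GEP itself has nuw).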
3804 bool NUW = NW.hasNoUnsignedWrap() ||
3805             (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3806  SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3807  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3808 assert(BaseExpr->getType() == GEPExpr->getType() &&
3809 "GEP should not change type mid-flight.");
3810 return GEPExpr;
3811}
3812
3813SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3814                                               ArrayRef<const SCEV *> Ops) {
3815  FoldingSetNodeID ID;
3816  ID.AddInteger(SCEVType);
3817 for (const SCEV *Op : Ops)
3818 ID.AddPointer(Op);
3819 void *IP = nullptr;
3820 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3821}
3822
3823const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3824  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3825  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3826}
3827
3828const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3829                                           SmallVectorImpl<const SCEV *> &Ops) {
3830  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3831 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3832 if (Ops.size() == 1) return Ops[0];
3833#ifndef NDEBUG
3834 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3835 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3836 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3837 "Operand types don't match!");
3838 assert(Ops[0]->getType()->isPointerTy() ==
3839 Ops[i]->getType()->isPointerTy() &&
3840 "min/max should be consistently pointerish");
3841 }
3842#endif
3843
3844 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3845 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3846
3847 const SCEV *Folded = constantFoldAndGroupOps(
3848 *this, LI, DT, Ops,
3849 [&](const APInt &C1, const APInt &C2) {
3850 switch (Kind) {
3851 case scSMaxExpr:
3852 return APIntOps::smax(C1, C2);
3853 case scSMinExpr:
3854 return APIntOps::smin(C1, C2);
3855 case scUMaxExpr:
3856 return APIntOps::umax(C1, C2);
3857 case scUMinExpr:
3858 return APIntOps::umin(C1, C2);
3859 default:
3860 llvm_unreachable("Unknown SCEV min/max opcode");
3861 }
3862 },
3863 [&](const APInt &C) {
3864 // identity
3865 if (IsMax)
3866 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3867 else
3868 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3869 },
3870 [&](const APInt &C) {
3871 // absorber
3872 if (IsMax)
3873 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3874 else
3875 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3876 });
3877 if (Folded)
3878 return Folded;
3879
3880 // Check if we have created the same expression before.
3881 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3882 return S;
3883 }
3884
3885 // Find the first operation of the same kind
3886 unsigned Idx = 0;
3887 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3888 ++Idx;
3889
3890 // Check to see if one of the operands is of the same kind. If so, expand its
3891 // operands onto our operand list, and recurse to simplify.
3892 if (Idx < Ops.size()) {
3893 bool DeletedAny = false;
3894 while (Ops[Idx]->getSCEVType() == Kind) {
3895 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3896 Ops.erase(Ops.begin()+Idx);
3897 append_range(Ops, SMME->operands());
3898 DeletedAny = true;
3899 }
3900
3901 if (DeletedAny)
3902 return getMinMaxExpr(Kind, Ops);
3903 }
3904
3905 // Okay, check to see if the same value occurs in the operand list twice. If
3906 // so, delete one. Since we sorted the list, these values are required to
3907 // be adjacent.
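  // For example, smax(%x, %x, %y) becomes smax(%x, %y), and umax(%x, %y)
  // becomes %y when %x ule %y is known.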
3908  llvm::CmpInst::Predicate GEPred =
3909      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3910  llvm::CmpInst::Predicate LEPred =
3911      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3912  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3913 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3914 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3915 if (Ops[i] == Ops[i + 1] ||
3916 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3917 // X op Y op Y --> X op Y
3918 // X op Y --> X, if we know X, Y are ordered appropriately
3919 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3920 --i;
3921 --e;
3922 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3923 Ops[i + 1])) {
3924 // X op Y --> Y, if we know X, Y are ordered appropriately
3925 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3926 --i;
3927 --e;
3928 }
3929 }
3930
3931 if (Ops.size() == 1) return Ops[0];
3932
3933 assert(!Ops.empty() && "Reduced smax down to nothing!");
3934
3935 // Okay, it looks like we really DO need an expr. Check to see if we
3936 // already have one, otherwise create a new one.
3937  FoldingSetNodeID ID;
3938  ID.AddInteger(Kind);
3939 for (const SCEV *Op : Ops)
3940 ID.AddPointer(Op);
3941 void *IP = nullptr;
3942 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3943 if (ExistingSCEV)
3944 return ExistingSCEV;
3945 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3946 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3947 SCEV *S = new (SCEVAllocator)
3948 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3949
3950 UniqueSCEVs.InsertNode(S, IP);
3951 registerUser(S, Ops);
3952 return S;
3953}
3954
3955namespace {
3956
3957class SCEVSequentialMinMaxDeduplicatingVisitor final
3958 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3959 std::optional<const SCEV *>> {
3960 using RetVal = std::optional<const SCEV *>;
3961  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3962
3963 ScalarEvolution &SE;
3964 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3965 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3966  SmallPtrSet<const SCEV *, 16> SeenOps;
3967
3968 bool canRecurseInto(SCEVTypes Kind) const {
3969 // We can only recurse into the SCEV expression of the same effective type
3970 // as the type of our root SCEV expression.
3971 return RootKind == Kind || NonSequentialRootKind == Kind;
3972 };
3973
3974 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3975 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3976 "Only for min/max expressions.");
3977 SCEVTypes Kind = S->getSCEVType();
3978
3979 if (!canRecurseInto(Kind))
3980 return S;
3981
3982 auto *NAry = cast<SCEVNAryExpr>(S);
3983    SmallVector<const SCEV *> NewOps;
3984    bool Changed = visit(Kind, NAry->operands(), NewOps);
3985
3986 if (!Changed)
3987 return S;
3988 if (NewOps.empty())
3989 return std::nullopt;
3990
3991 return isa<SCEVSequentialMinMaxExpr>(S)
3992 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3993 : SE.getMinMaxExpr(Kind, NewOps);
3994 }
3995
3996 RetVal visit(const SCEV *S) {
3997 // Has the whole operand been seen already?
3998 if (!SeenOps.insert(S).second)
3999 return std::nullopt;
4000 return Base::visit(S);
4001 }
4002
4003public:
4004 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4005 SCEVTypes RootKind)
4006 : SE(SE), RootKind(RootKind),
4007 NonSequentialRootKind(
4008 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4009 RootKind)) {}
4010
4011 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4012            SmallVectorImpl<const SCEV *> &NewOps) {
4013    bool Changed = false;
4014    SmallVector<const SCEV *> Ops;
4015    Ops.reserve(OrigOps.size());
4016
4017 for (const SCEV *Op : OrigOps) {
4018 RetVal NewOp = visit(Op);
4019 if (NewOp != Op)
4020 Changed = true;
4021 if (NewOp)
4022 Ops.emplace_back(*NewOp);
4023 }
4024
4025 if (Changed)
4026 NewOps = std::move(Ops);
4027 return Changed;
4028 }
4029
4030 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4031
4032 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4033
4034 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4035
4036 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4037
4038 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4039
4040 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4041
4042 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4043
4044 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4045
4046 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4047
4048 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4049
4050 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4051 return visitAnyMinMaxExpr(Expr);
4052 }
4053
4054 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4055 return visitAnyMinMaxExpr(Expr);
4056 }
4057
4058 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4059 return visitAnyMinMaxExpr(Expr);
4060 }
4061
4062 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4063 return visitAnyMinMaxExpr(Expr);
4064 }
4065
4066 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4067 return visitAnyMinMaxExpr(Expr);
4068 }
4069
4070 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4071
4072 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4073};
4074
4075} // namespace
4076
4077static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4078  switch (Kind) {
4079 case scConstant:
4080 case scVScale:
4081 case scTruncate:
4082 case scZeroExtend:
4083 case scSignExtend:
4084 case scPtrToInt:
4085 case scAddExpr:
4086 case scMulExpr:
4087 case scUDivExpr:
4088 case scAddRecExpr:
4089 case scUMaxExpr:
4090 case scSMaxExpr:
4091 case scUMinExpr:
4092 case scSMinExpr:
4093 case scUnknown:
4094 // If any operand is poison, the whole expression is poison.
4095 return true;
4096  case scSequentialUMinExpr:
4097    // FIXME: if the *first* operand is poison, the whole expression is poison.
4098    return false; // Pessimistically, say that it does not propagate poison.
4099 case scCouldNotCompute:
4100 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4101 }
4102 llvm_unreachable("Unknown SCEV kind!");
4103}
4104
4105namespace {
4106// The only way poison may be introduced in a SCEV expression is from a
4107// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4108// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4109// introduce poison -- they encode guaranteed, non-speculated knowledge.
4110//
4111// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4112// with the notable exception of umin_seq, where only poison from the first
4113// operand is (unconditionally) propagated.
4114struct SCEVPoisonCollector {
4115 bool LookThroughMaybePoisonBlocking;
4116  SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4117  SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4118 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4119
4120 bool follow(const SCEV *S) {
4121 if (!LookThroughMaybePoisonBlocking &&
4122        !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4123      return false;
4124
4125 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4126 if (!isGuaranteedNotToBePoison(SU->getValue()))
4127 MaybePoison.insert(SU);
4128 }
4129 return true;
4130 }
4131 bool isDone() const { return false; }
4132};
4133} // namespace
4134
4135/// Return true if S is poison given that AssumedPoison is already poison.
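/// For example, impliesPoison(%x, %x + %y) holds: if %x is poison, so is
/// %x + %y. The converse, impliesPoison(%x + %y, %x), does not hold in
/// general, since %y alone being poison makes the sum poison while leaving
/// %x well-defined.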
4136static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4137 // First collect all SCEVs that might result in AssumedPoison to be poison.
4138 // We need to look through potentially poison-blocking operations here,
4139 // because we want to find all SCEVs that *might* result in poison, not only
4140 // those that are *required* to.
4141 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4142 visitAll(AssumedPoison, PC1);
4143
4144 // AssumedPoison is never poison. As the assumption is false, the implication
4145 // is true. Don't bother walking the other SCEV in this case.
4146 if (PC1.MaybePoison.empty())
4147 return true;
4148
4149 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4150 // as well. We cannot look through potentially poison-blocking operations
4151 // here, as their arguments only *may* make the result poison.
4152 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4153 visitAll(S, PC2);
4154
4155 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4156 // it will also make S poison by being part of PC2.MaybePoison.
4157 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4158}
4159
4160void ScalarEvolution::getPoisonGeneratingValues(
4161    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4162 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4163 visitAll(S, PC);
4164 for (const SCEVUnknown *SU : PC.MaybePoison)
4165 Result.insert(SU->getValue());
4166}
4167
4168bool ScalarEvolution::canReuseInstruction(
4169    const SCEV *S, Instruction *I,
4170 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4171 // If the instruction cannot be poison, it's always safe to reuse.
4172  if (isGuaranteedNotToBePoison(I))
4173    return true;
4174
4175  // Otherwise, it is possible that I is more poisonous than S. Collect the
4176 // poison-contributors of S, and then check whether I has any additional
4177 // poison-contributors. Poison that is contributed through poison-generating
4178 // flags is handled by dropping those flags instead.
4179  SmallPtrSet<const Value *, 8> PoisonVals;
4180  getPoisonGeneratingValues(PoisonVals, S);
4181
4182 SmallVector<Value *> Worklist;
4183  SmallPtrSet<Value *, 8> Visited;
4184  Worklist.push_back(I);
4185 while (!Worklist.empty()) {
4186 Value *V = Worklist.pop_back_val();
4187 if (!Visited.insert(V).second)
4188 continue;
4189
4190 // Avoid walking large instruction graphs.
4191 if (Visited.size() > 16)
4192 return false;
4193
4194 // Either the value can't be poison, or the S would also be poison if it
4195 // is.
4196 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4197 continue;
4198
4199 auto *I = dyn_cast<Instruction>(V);
4200 if (!I)
4201 return false;
4202
4203 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4204 // can't replace an arbitrary add with disjoint or, even if we drop the
4205 // flag. We would need to convert the or into an add.
4206 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4207 if (PDI->isDisjoint())
4208 return false;
4209
4210    // FIXME: Ignore vscale, even though it technically could be poison. Do this
4211    // because SCEV currently assumes it can't be poison. Remove this special
4212    // case once we properly model when vscale can be poison.
4213 if (auto *II = dyn_cast<IntrinsicInst>(I);
4214 II && II->getIntrinsicID() == Intrinsic::vscale)
4215 continue;
4216
4217 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4218 return false;
4219
4220 // If the instruction can't create poison, we can recurse to its operands.
4221 if (I->hasPoisonGeneratingAnnotations())
4222 DropPoisonGeneratingInsts.push_back(I);
4223
4224 for (Value *Op : I->operands())
4225 Worklist.push_back(Op);
4226 }
4227 return true;
4228}
4229
4230const SCEV *
4233 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4234 "Not a SCEVSequentialMinMaxExpr!");
4235 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4236 if (Ops.size() == 1)
4237 return Ops[0];
4238#ifndef NDEBUG
4239 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4240 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4241 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4242 "Operand types don't match!");
4243 assert(Ops[0]->getType()->isPointerTy() ==
4244 Ops[i]->getType()->isPointerTy() &&
4245 "min/max should be consistently pointerish");
4246 }
4247#endif
4248
4249 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4250 // so we can *NOT* do any kind of sorting of the expressions!
4251
4252 // Check if we have created the same expression before.
4253 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4254 return S;
4255
4256 // FIXME: there are *some* simplifications that we can do here.
4257
4258 // Keep only the first instance of an operand.
4259 {
4260 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4261 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4262 if (Changed)
4263 return getSequentialMinMaxExpr(Kind, Ops);
4264 }
4265
4266 // Check to see if one of the operands is of the same kind. If so, expand its
4267 // operands onto our operand list, and recurse to simplify.
4268 {
4269 unsigned Idx = 0;
4270 bool DeletedAny = false;
4271 while (Idx < Ops.size()) {
4272 if (Ops[Idx]->getSCEVType() != Kind) {
4273 ++Idx;
4274 continue;
4275 }
4276 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4277 Ops.erase(Ops.begin() + Idx);
4278 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4279 SMME->operands().end());
4280 DeletedAny = true;
4281 }
4282
4283 if (DeletedAny)
4284 return getSequentialMinMaxExpr(Kind, Ops);
4285 }
4286
4287 const SCEV *SaturationPoint;
4288  ICmpInst::Predicate Pred;
4289  switch (Kind) {
4290  case scSequentialUMinExpr:
4291    SaturationPoint = getZero(Ops[0]->getType());
4292 Pred = ICmpInst::ICMP_ULE;
4293 break;
4294 default:
4295 llvm_unreachable("Not a sequential min/max type.");
4296 }
4297
4298 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4299 if (!isGuaranteedNotToCauseUB(Ops[i]))
4300 continue;
4301 // We can replace %x umin_seq %y with %x umin %y if either:
4302 // * %y being poison implies %x is also poison.
4303 // * %x cannot be the saturating value (e.g. zero for umin).
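    // For example, if %x is known non-zero it can never hit the umin
    // saturation point, so %x umin_seq %y never ignores %y and is equivalent
    // to %x umin %y.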
4304 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4305 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4306 SaturationPoint)) {
4307 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4308 Ops[i - 1] = getMinMaxExpr(
4309          SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4310          SeqOps);
4311 Ops.erase(Ops.begin() + i);
4312 return getSequentialMinMaxExpr(Kind, Ops);
4313 }
4314 // Fold %x umin_seq %y to %x if %x ule %y.
4315 // TODO: We might be able to prove the predicate for a later operand.
4316 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4317 Ops.erase(Ops.begin() + i);
4318 return getSequentialMinMaxExpr(Kind, Ops);
4319 }
4320 }
4321
4322 // Okay, it looks like we really DO need an expr. Check to see if we
4323 // already have one, otherwise create a new one.
4324  FoldingSetNodeID ID;
4325  ID.AddInteger(Kind);
4326 for (const SCEV *Op : Ops)
4327 ID.AddPointer(Op);
4328 void *IP = nullptr;
4329 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4330 if (ExistingSCEV)
4331 return ExistingSCEV;
4332
4333 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4334 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4335 SCEV *S = new (SCEVAllocator)
4336 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4337
4338 UniqueSCEVs.InsertNode(S, IP);
4339 registerUser(S, Ops);
4340 return S;
4341}
4342
4343const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4345 return getSMaxExpr(Ops);
4346}
4347
4349 return getMinMaxExpr(scSMaxExpr, Ops);
4350}
4351
4352const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4354 return getUMaxExpr(Ops);
4355}
4356
4358 return getMinMaxExpr(scUMaxExpr, Ops);
4359}
4360
4362 const SCEV *RHS) {
4364 return getSMinExpr(Ops);
4365}
4366
4368 return getMinMaxExpr(scSMinExpr, Ops);
4369}
4370
4371const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4372 bool Sequential) {
4374 return getUMinExpr(Ops, Sequential);
4375}
4376
4378 bool Sequential) {
4379 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4380 : getMinMaxExpr(scUMinExpr, Ops);
4381}
4382
4383const SCEV *
4385 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4386 if (Size.isScalable())
4387 Res = getMulExpr(Res, getVScale(IntTy));
4388 return Res;
4389}
4390
4392 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4393}
4394
4396 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4397}
4398
4400 StructType *STy,
4401 unsigned FieldNo) {
4402 // We can bypass creating a target-independent constant expression and then
4403 // folding it back into a ConstantInt. This is just a compile-time
4404 // optimization.
4405 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4406 assert(!SL->getSizeInBits().isScalable() &&
4407 "Cannot get offset for structure containing scalable vector types");
4408 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4409}
4410
4412 // Don't attempt to do anything other than create a SCEVUnknown object
4413 // here. createSCEV only calls getUnknown after checking for all other
4414 // interesting possibilities, and any other code that calls getUnknown
4415 // is doing so in order to hide a value from SCEV canonicalization.
4416
4418 ID.AddInteger(scUnknown);
4419 ID.AddPointer(V);
4420 void *IP = nullptr;
4421 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4422 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4423 "Stale SCEVUnknown in uniquing map!");
4424 return S;
4425 }
4426 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4427 FirstUnknown);
4428 FirstUnknown = cast<SCEVUnknown>(S);
4429 UniqueSCEVs.InsertNode(S, IP);
4430 return S;
4431}
4432
4433//===----------------------------------------------------------------------===//
4434// Basic SCEV Analysis and PHI Idiom Recognition Code
4435//
4436
4437/// Test if values of the given type are analyzable within the SCEV
4438/// framework. This primarily includes integer types, and it can optionally
4439/// include pointer types if the ScalarEvolution class has access to
4440/// target-specific information.
4442 // Integers and pointers are always SCEVable.
4443 return Ty->isIntOrPtrTy();
4444}
4445
4446/// Return the size in bits of the specified type, for which isSCEVable must
4447/// return true.
4449 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4450 if (Ty->isPointerTy())
4452 return getDataLayout().getTypeSizeInBits(Ty);
4453}
4454
4455/// Return a type with the same bitwidth as the given type and which represents
4456/// how SCEV will treat the given type, for which isSCEVable must return
4457/// true. For pointer types, this is the pointer index sized integer type.
4459 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4460
4461 if (Ty->isIntegerTy())
4462 return Ty;
4463
4464  // The only other supported type is pointer.
4465 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4466 return getDataLayout().getIndexType(Ty);
4467}
4468
4470 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4471}
4472
4474 const SCEV *B) {
4475 /// For a valid use point to exist, the defining scope of one operand
4476 /// must dominate the other.
4477 bool PreciseA, PreciseB;
4478 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4479 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4480 if (!PreciseA || !PreciseB)
4481 // Can't tell.
4482 return false;
4483 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4484 DT.dominates(ScopeB, ScopeA);
4485}
4486
4488 return CouldNotCompute.get();
4489}
4490
4491bool ScalarEvolution::checkValidity(const SCEV *S) const {
4492 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4493 auto *SU = dyn_cast<SCEVUnknown>(S);
4494 return SU && SU->getValue() == nullptr;
4495 });
4496
4497 return !ContainsNulls;
4498}
4499
4501 HasRecMapType::iterator I = HasRecMap.find(S);
4502 if (I != HasRecMap.end())
4503 return I->second;
4504
4505 bool FoundAddRec =
4506 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4507 HasRecMap.insert({S, FoundAddRec});
4508 return FoundAddRec;
4509}
4510
4511/// Return the ValueOffsetPair set for \p S. \p S can be represented
4512/// by the value and offset from any ValueOffsetPair in the set.
4513ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4514 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4515 if (SI == ExprValueMap.end())
4516 return {};
4517 return SI->second.getArrayRef();
4518}
4519
4520/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4521/// cannot be used separately. eraseValueFromMap should be used to remove
4522/// V from ValueExprMap and ExprValueMap at the same time.
4523void ScalarEvolution::eraseValueFromMap(Value *V) {
4524 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4525 if (I != ValueExprMap.end()) {
4526 auto EVIt = ExprValueMap.find(I->second);
4527 bool Removed = EVIt->second.remove(V);
4528 (void) Removed;
4529 assert(Removed && "Value not in ExprValueMap?");
4530 ValueExprMap.erase(I);
4531 }
4532}
4533
4534void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4535 // A recursive query may have already computed the SCEV. It should be
4536 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4537 // inferred nowrap flags.
4538 auto It = ValueExprMap.find_as(V);
4539 if (It == ValueExprMap.end()) {
4540 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4541 ExprValueMap[S].insert(V);
4542 }
4543}
4544
4545/// Return an existing SCEV if it exists, otherwise analyze the expression and
4546/// create a new one.
4548 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4549
4550 if (const SCEV *S = getExistingSCEV(V))
4551 return S;
4552 return createSCEVIter(V);
4553}
4554
4556 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4557
4558 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4559 if (I != ValueExprMap.end()) {
4560 const SCEV *S = I->second;
4561 assert(checkValidity(S) &&
4562 "existing SCEV has not been properly invalidated");
4563 return S;
4564 }
4565 return nullptr;
4566}
4567
4568/// Return a SCEV corresponding to -V = -1*V
4570 SCEV::NoWrapFlags Flags) {
4571 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4572 return getConstant(
4573 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4574
4575 Type *Ty = V->getType();
4576 Ty = getEffectiveSCEVType(Ty);
4577 return getMulExpr(V, getMinusOne(Ty), Flags);
4578}
4579
4580/// If Expr computes ~A, return A else return nullptr
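/// In SCEV form ~A is (-1) + (-1) * A, so this matches a two-operand add whose
/// first operand is all-ones and whose second operand is a multiply by
/// all-ones, and returns the multiplied value.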
4581static const SCEV *MatchNotExpr(const SCEV *Expr) {
4582 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4583 if (!Add || Add->getNumOperands() != 2 ||
4584 !Add->getOperand(0)->isAllOnesValue())
4585 return nullptr;
4586
4587 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4588 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4589 !AddRHS->getOperand(0)->isAllOnesValue())
4590 return nullptr;
4591
4592 return AddRHS->getOperand(1);
4593}
4594
4595/// Return a SCEV corresponding to ~V = -1-V
4597 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4598
4599 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4600 return getConstant(
4601 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4602
4603 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4604 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4605 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4606 SmallVector<const SCEV *, 2> MatchedOperands;
4607 for (const SCEV *Operand : MME->operands()) {
4608 const SCEV *Matched = MatchNotExpr(Operand);
4609 if (!Matched)
4610 return (const SCEV *)nullptr;
4611 MatchedOperands.push_back(Matched);
4612 }
4613 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4614 MatchedOperands);
4615 };
4616 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4617 return Replaced;
4618 }
4619
4620 Type *Ty = V->getType();
4621 Ty = getEffectiveSCEVType(Ty);
4622 return getMinusSCEV(getMinusOne(Ty), V);
4623}
4624
4626 assert(P->getType()->isPointerTy());
4627
4628 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4629 // The base of an AddRec is the first operand.
4630 SmallVector<const SCEV *> Ops{AddRec->operands()};
4631 Ops[0] = removePointerBase(Ops[0]);
4632 // Don't try to transfer nowrap flags for now. We could in some cases
4633 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4634 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4635 }
4636 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4637 // The base of an Add is the pointer operand.
4638 SmallVector<const SCEV *> Ops{Add->operands()};
4639 const SCEV **PtrOp = nullptr;
4640 for (const SCEV *&AddOp : Ops) {
4641 if (AddOp->getType()->isPointerTy()) {
4642 assert(!PtrOp && "Cannot have multiple pointer ops");
4643 PtrOp = &AddOp;
4644 }
4645 }
4646 *PtrOp = removePointerBase(*PtrOp);
4647 // Don't try to transfer nowrap flags for now. We could in some cases
4648 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4649 return getAddExpr(Ops);
4650 }
4651 // Any other expression must be a pointer base.
4652 return getZero(P->getType());
4653}
4654
4655const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4656 SCEV::NoWrapFlags Flags,
4657 unsigned Depth) {
4658 // Fast path: X - X --> 0.
4659 if (LHS == RHS)
4660 return getZero(LHS->getType());
4661
4662 // If we subtract two pointers with different pointer bases, bail.
4663 // Eventually, we're going to add an assertion to getMulExpr that we
4664 // can't multiply by a pointer.
4665 if (RHS->getType()->isPointerTy()) {
4666 if (!LHS->getType()->isPointerTy() ||
4668 return getCouldNotCompute();
4671 }
4672
4673 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4674 // makes it so that we cannot make much use of NUW.
4675 auto AddFlags = SCEV::FlagAnyWrap;
4676 const bool RHSIsNotMinSigned =
4678 if (hasFlags(Flags, SCEV::FlagNSW)) {
4679 // Let M be the minimum representable signed value. Then (-1)*RHS
4680 // signed-wraps if and only if RHS is M. That can happen even for
4681 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4682 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4683 // (-1)*RHS, we need to prove that RHS != M.
4684 //
4685 // If LHS is non-negative and we know that LHS - RHS does not
4686 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4687 // either by proving that RHS > M or that LHS >= 0.
4688 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4689 AddFlags = SCEV::FlagNSW;
4690 }
4691 }
4692
4693 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4694 // RHS is NSW and LHS >= 0.
4695 //
4696 // The difficulty here is that the NSW flag may have been proven
4697 // relative to a loop that is to be found in a recurrence in LHS and
4698 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4699 // larger scope than intended.
4700 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4701
4702 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4703}
4704
4706 unsigned Depth) {
4707 Type *SrcTy = V->getType();
4708 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4709 "Cannot truncate or zero extend with non-integer arguments!");
4710 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4711 return V; // No conversion
4712 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4713 return getTruncateExpr(V, Ty, Depth);
4714 return getZeroExtendExpr(V, Ty, Depth);
4715}
4716
4718 unsigned Depth) {
4719 Type *SrcTy = V->getType();
4720 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4721 "Cannot truncate or zero extend with non-integer arguments!");
4722 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4723 return V; // No conversion
4724 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4725 return getTruncateExpr(V, Ty, Depth);
4726 return getSignExtendExpr(V, Ty, Depth);
4727}
4728
4729const SCEV *
4731 Type *SrcTy = V->getType();
4732 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4733 "Cannot noop or zero extend with non-integer arguments!");
4735 "getNoopOrZeroExtend cannot truncate!");
4736 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4737 return V; // No conversion
4738 return getZeroExtendExpr(V, Ty);
4739}
4740
4741const SCEV *
4743 Type *SrcTy = V->getType();
4744 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4745 "Cannot noop or sign extend with non-integer arguments!");
4747 "getNoopOrSignExtend cannot truncate!");
4748 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4749 return V; // No conversion
4750 return getSignExtendExpr(V, Ty);
4751}
4752
4753const SCEV *
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot noop or any extend with non-integer arguments!");
4759 "getNoopOrAnyExtend cannot truncate!");
4760 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4761 return V; // No conversion
4762 return getAnyExtendExpr(V, Ty);
4763}
4764
4765const SCEV *
4766ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot truncate or noop with non-integer arguments!");
4770 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4771 "getTruncateOrNoop cannot extend!");
4772 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4773 return V; // No conversion
4774 return getTruncateExpr(V, Ty);
4775}
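// Illustrative sketch (added commentary, not from the original source): the
// conversion helpers above pick the cheapest legal cast for a known size
// relationship. For example, a hypothetical caller widening an exit count EC
// to the type IVTy of an induction variable known to be at least as wide
// would write:
//
//   const SCEV *WideEC = SE.getNoopOrZeroExtend(EC, IVTy);
//
// getTruncateOrZeroExtend/getTruncateOrSignExtend handle either direction,
// and getTruncateOrNoop the case where the destination is never wider.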
4776
4777const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4778 const SCEV *RHS) {
4779 const SCEV *PromotedLHS = LHS;
4780 const SCEV *PromotedRHS = RHS;
4781
4782 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4783 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4784 else
4785 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4786
4787 return getUMaxExpr(PromotedLHS, PromotedRHS);
4788}
4789
4790const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4791 const SCEV *RHS,
4792 bool Sequential) {
4793 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4794 return getUMinFromMismatchedTypes(Ops, Sequential);
4795}
4796
4797const SCEV *
4798ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4799 bool Sequential) {
4800 assert(!Ops.empty() && "At least one operand must be!");
4801 // Trivial case.
4802 if (Ops.size() == 1)
4803 return Ops[0];
4804
4805 // Find the max type first.
4806 Type *MaxType = nullptr;
4807 for (const auto *S : Ops)
4808 if (MaxType)
4809 MaxType = getWiderType(MaxType, S->getType());
4810 else
4811 MaxType = S->getType();
4812 assert(MaxType && "Failed to find maximum type!");
4813
4814 // Extend all ops to max type.
4815 SmallVector<const SCEV *, 2> PromotedOps;
4816 for (const auto *S : Ops)
4817 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4818
4819 // Generate umin.
4820 return getUMinExpr(PromotedOps, Sequential);
4821}
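// Illustrative sketch (added commentary, not from the original source):
// combining exit counts of different widths, e.g. a hypothetical i32 count
// C32 and i64 count C64, promotes both to the wider type before the umin:
//
//   const SCEV *Combined = SE.getUMinFromMismatchedTypes(C32, C64);
//   // equivalent to umin((zext i32 C32 to i64), C64)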
4822
4823const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4824 // A pointer operand may evaluate to a nonpointer expression, such as null.
4825 if (!V->getType()->isPointerTy())
4826 return V;
4827
4828 while (true) {
4829 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4830 V = AddRec->getStart();
4831 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4832 const SCEV *PtrOp = nullptr;
4833 for (const SCEV *AddOp : Add->operands()) {
4834 if (AddOp->getType()->isPointerTy()) {
4835 assert(!PtrOp && "Cannot have multiple pointer ops");
4836 PtrOp = AddOp;
4837 }
4838 }
4839 assert(PtrOp && "Must have pointer op");
4840 V = PtrOp;
4841 } else // Not something we can look further into.
4842 return V;
4843 }
4844}
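// Illustrative example (added commentary, not from the original source): for
// a pointer expression such as {%base,+,4}<%loop> + 16, the loop above first
// strips the AddRec down to its start and then follows the unique
// pointer-typed operand of the add, returning the SCEVUnknown %base.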
4845
4846/// Push users of the given Instruction onto the given Worklist.
4847static void PushDefUseChildren(Instruction *I,
4848 SmallVectorImpl<Instruction *> &Worklist,
4849 SmallPtrSetImpl<Instruction *> &Visited) {
4850 // Push the def-use children onto the Worklist stack.
4851 for (User *U : I->users()) {
4852 auto *UserInsn = cast<Instruction>(U);
4853 if (Visited.insert(UserInsn).second)
4854 Worklist.push_back(UserInsn);
4855 }
4856}
4857
4858namespace {
4859
4860/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4861/// use its start expression. If the sub-expression's loop is not L, then use
4862/// the AddRec itself when IgnoreOtherLoops is true; otherwise the rewrite
4863/// cannot be done.
4864/// If S contains a loop-variant SCEVUnknown, the rewrite cannot be done.
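/// For example (illustrative, not from the original source), rewriting
/// {%start,+,%step}<L> with respect to L yields %start; an AddRec of a
/// different loop is kept as-is when IgnoreOtherLoops is true, and otherwise
/// causes SCEVCouldNotCompute to be returned.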
4865class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4866public:
4867 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4868 bool IgnoreOtherLoops = true) {
4869 SCEVInitRewriter Rewriter(L, SE);
4870 const SCEV *Result = Rewriter.visit(S);
4871 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4872 return SE.getCouldNotCompute();
4873 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4874 ? SE.getCouldNotCompute()
4875 : Result;
4876 }
4877
4878 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4879 if (!SE.isLoopInvariant(Expr, L))
4880 SeenLoopVariantSCEVUnknown = true;
4881 return Expr;
4882 }
4883
4884 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4885 // Only re-write AddRecExprs for this loop.
4886 if (Expr->getLoop() == L)
4887 return Expr->getStart();
4888 SeenOtherLoops = true;
4889 return Expr;
4890 }
4891
4892 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4893
4894 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4895
4896private:
4897 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4898 : SCEVRewriteVisitor(SE), L(L) {}
4899
4900 const Loop *L;
4901 bool SeenLoopVariantSCEVUnknown = false;
4902 bool SeenOtherLoops = false;
4903};
4904
4905/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4906/// use its post-increment expression; for AddRecs of other loops, use the
4907/// AddRec itself.
4908/// If S contains a loop-variant SCEVUnknown, the rewrite cannot be done.
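/// For example (illustrative, not from the original source), rewriting
/// {%start,+,%step}<L> with respect to L yields {%start + %step,+,%step}<L>,
/// i.e. the value of the recurrence after one more increment.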
4909class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4910public:
4911 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4912 SCEVPostIncRewriter Rewriter(L, SE);
4913 const SCEV *Result = Rewriter.visit(S);
4914 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4915 ? SE.getCouldNotCompute()
4916 : Result;
4917 }
4918
4919 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4920 if (!SE.isLoopInvariant(Expr, L))
4921 SeenLoopVariantSCEVUnknown = true;
4922 return Expr;
4923 }
4924
4925 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4926 // Only re-write AddRecExprs for this loop.
4927 if (Expr->getLoop() == L)
4928 return Expr->getPostIncExpr(SE);
4929 SeenOtherLoops = true;
4930 return Expr;
4931 }
4932
4933 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4934
4935 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4936
4937private:
4938 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4939 : SCEVRewriteVisitor(SE), L(L) {}
4940
4941 const Loop *L;
4942 bool SeenLoopVariantSCEVUnknown = false;
4943 bool SeenOtherLoops = false;
4944};
4945
4946/// This class evaluates the compare condition by matching it against the
4947/// condition of the loop latch. If there is a match, we assume a true value
4948/// for the condition while building SCEV nodes.
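/// For example (illustrative, not from the original source), if the latch
/// takes the backedge on %cond, then while this rewriter runs a value such as
///   %v = select i1 %cond, i64 %a, i64 %b
/// inside the loop is replaced by the SCEV of %a, since %cond must have been
/// true for the backedge to be taken.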
4949class SCEVBackedgeConditionFolder
4950 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4951public:
4952 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4953 ScalarEvolution &SE) {
4954 bool IsPosBECond = false;
4955 Value *BECond = nullptr;
4956 if (BasicBlock *Latch = L->getLoopLatch()) {
4957 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4958 if (BI && BI->isConditional()) {
4959 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4960 "Both outgoing branches should not target same header!");
4961 BECond = BI->getCondition();
4962 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4963 } else {
4964 return S;
4965 }
4966 }
4967 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4968 return Rewriter.visit(S);
4969 }
4970
4971 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4972 const SCEV *Result = Expr;
4973 bool InvariantF = SE.isLoopInvariant(Expr, L);
4974
4975 if (!InvariantF) {
4976 Instruction *I = cast<Instruction>(Expr->getValue());
4977 switch (I->getOpcode()) {
4978 case Instruction::Select: {
4979 SelectInst *SI = cast<SelectInst>(I);
4980 std::optional<const SCEV *> Res =
4981 compareWithBackedgeCondition(SI->getCondition());
4982 if (Res) {
4983 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
4984 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4985 }
4986 break;
4987 }
4988 default: {
4989 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4990 if (Res)
4991 Result = *Res;
4992 break;
4993 }
4994 }
4995 }
4996 return Result;
4997 }
4998
4999private:
5000 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5001 bool IsPosBECond, ScalarEvolution &SE)
5002 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5003 IsPositiveBECond(IsPosBECond) {}
5004
5005 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5006
5007 const Loop *L;
5008 /// Loop back condition.
5009 Value *BackedgeCond = nullptr;
5010 /// Set to true if the backedge is taken on the positive branch condition.
5011 bool IsPositiveBECond;
5012};
5013
5014std::optional<const SCEV *>
5015SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5016
5017 // If the value matches the backedge condition of the loop latch,
5018 // return a constant i1 SCEV node based on whether the backedge
5019 // branch is taken.
5020 if (BackedgeCond == IC)
5021 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5022 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5023 return std::nullopt;
5024}
5025
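/// Shifts AddRecs of loop L back by one iteration (summary added here for
/// clarity; it follows from visitAddRecExpr below): an affine
/// {%start,+,%step}<L> becomes {%start - %step,+,%step}<L>. Loop-variant
/// SCEVUnknowns and non-affine AddRecs of L invalidate the rewrite.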
5026class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5027public:
5028 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5029 ScalarEvolution &SE) {
5030 SCEVShiftRewriter Rewriter(L, SE);
5031 const SCEV *Result = Rewriter.visit(S);
5032 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5033 }
5034
5035 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5036 // Only allow AddRecExprs for this loop.
5037 if (!SE.isLoopInvariant(Expr, L))
5038 Valid = false;
5039 return Expr;
5040 }
5041
5042 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5043 if (Expr->getLoop() == L && Expr->isAffine())
5044 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5045 Valid = false;
5046 return Expr;
5047 }
5048
5049 bool isValid() { return Valid; }
5050
5051private:
5052 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5053 : SCEVRewriteVisitor(SE), L(L) {}
5054
5055 const Loop *L;
5056 bool Valid = true;
5057};
5058
5059} // end anonymous namespace
5060
5061SCEV::NoWrapFlags
5062ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5063 if (!AR->isAffine())
5064 return SCEV::FlagAnyWrap;
5065
5066 using OBO = OverflowingBinaryOperator;
5067
5068 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5069
5070 if (!AR->hasNoSelfWrap()) {
5071 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5072 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5073 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5074 const APInt &BECountAP = BECountMax->getAPInt();
5075 unsigned NoOverflowBitWidth =
5076 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5077 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5078 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5079 }
5080 }
5081
5082 if (!AR->hasNoSignedWrap()) {
5083 ConstantRange AddRecRange = getSignedRange(AR);
5084 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5085
5086 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5087 Instruction::Add, IncRange, OBO::NoSignedWrap);
5088 if (NSWRegion.contains(AddRecRange))
5089 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5090 }
5091
5092 if (!AR->hasNoUnsignedWrap()) {
5093 ConstantRange AddRecRange = getUnsignedRange(AR);
5094 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5095
5096 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5097 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5098 if (NUWRegion.contains(AddRecRange))
5099 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5100 }
5101
5102 return Result;
5103}
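// Worked example (added commentary, not from the original source) for the
// self-wrap check above: with a constant max backedge-taken count of 100
// (7 active bits) and a step whose signed range fits in 3 bits, the total
// movement of the recurrence needs at most 7 + 3 = 10 bits; for an i32
// AddRec, 10 <= 32, so FlagNW can be proven.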
5104
5105SCEV::NoWrapFlags
5106ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5107 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5108
5109 if (AR->hasNoSignedWrap())
5110 return Result;
5111
5112 if (!AR->isAffine())
5113 return Result;
5114
5115 // This function can be expensive, only try to prove NSW once per AddRec.
5116 if (!SignedWrapViaInductionTried.insert(AR).second)
5117 return Result;
5118
5119 const SCEV *Step = AR->getStepRecurrence(*this);
5120 const Loop *L = AR->getLoop();
5121
5122 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5123 // Note that this serves two purposes: It filters out loops that are
5124 // simply not analyzable, and it covers the case where this code is
5125 // being called from within backedge-taken count analysis, such that
5126 // attempting to ask for the backedge-taken count would likely result
5127 // in infinite recursion. In the latter case, the analysis code will
5128 // cope with a conservative value, and it will take care to purge
5129 // that value once it has finished.
5130 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5131
5132 // Normally, in the cases we can prove no-overflow via a
5133 // backedge guarding condition, we can also compute a backedge
5134 // taken count for the loop. The exceptions are assumptions and
5135 // guards present in the loop -- SCEV is not great at exploiting
5136 // these to compute max backedge taken counts, but can still use
5137 // these to prove lack of overflow. Use this fact to avoid
5138 // doing extra work that may not pay off.
5139
5140 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5141 AC.assumptions().empty())
5142 return Result;
5143
5144 // If the backedge is guarded by a comparison with the pre-inc value the
5145 // addrec is safe. Also, if the entry is guarded by a comparison with the
5146 // start value and the backedge is guarded by a comparison with the post-inc
5147 // value, the addrec is safe.
5148 ICmpInst::Predicate Pred;
5149 const SCEV *OverflowLimit =
5150 getSignedOverflowLimitForStep(Step, &Pred, this);
5151 if (OverflowLimit &&
5152 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5153 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5154 Result = setFlags(Result, SCEV::FlagNSW);
5155 }
5156 return Result;
5157}
5158SCEV::NoWrapFlags
5159ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5160 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5161
5162 if (AR->hasNoUnsignedWrap())
5163 return Result;
5164
5165 if (!AR->isAffine())
5166 return Result;
5167
5168 // This function can be expensive, only try to prove NUW once per AddRec.
5169 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5170 return Result;
5171
5172 const SCEV *Step = AR->getStepRecurrence(*this);
5173 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5174 const Loop *L = AR->getLoop();
5175
5176 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5177 // Note that this serves two purposes: It filters out loops that are
5178 // simply not analyzable, and it covers the case where this code is
5179 // being called from within backedge-taken count analysis, such that
5180 // attempting to ask for the backedge-taken count would likely result
5181 // in infinite recursion. In the latter case, the analysis code will
5182 // cope with a conservative value, and it will take care to purge
5183 // that value once it has finished.
5184 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5185
5186 // Normally, in the cases we can prove no-overflow via a
5187 // backedge guarding condition, we can also compute a backedge
5188 // taken count for the loop. The exceptions are assumptions and
5189 // guards present in the loop -- SCEV is not great at exploiting
5190 // these to compute max backedge taken counts, but can still use
5191 // these to prove lack of overflow. Use this fact to avoid
5192 // doing extra work that may not pay off.
5193
5194 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5195 AC.assumptions().empty())
5196 return Result;
5197
5198 // If the backedge is guarded by a comparison with the pre-inc value the
5199 // addrec is safe. Also, if the entry is guarded by a comparison with the
5200 // start value and the backedge is guarded by a comparison with the post-inc
5201 // value, the addrec is safe.
5202 if (isKnownPositive(Step)) {
5203 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5204 getUnsignedRangeMax(Step));
5205 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5206 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5207 Result = setFlags(Result, SCEV::FlagNUW);
5208 }
5209 }
5210
5211 return Result;
5212}
5213
5214namespace {
5215
5216/// Represents an abstract binary operation. This may exist as a
5217/// normal instruction or constant expression, or may have been
5218/// derived from an expression tree.
5219struct BinaryOp {
5220 unsigned Opcode;
5221 Value *LHS;
5222 Value *RHS;
5223 bool IsNSW = false;
5224 bool IsNUW = false;
5225
5226 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5227 /// constant expression.
5228 Operator *Op = nullptr;
5229
5230 explicit BinaryOp(Operator *Op)
5231 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5232 Op(Op) {
5233 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5234 IsNSW = OBO->hasNoSignedWrap();
5235 IsNUW = OBO->hasNoUnsignedWrap();
5236 }
5237 }
5238
5239 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5240 bool IsNUW = false)
5241 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5242};
5243
5244} // end anonymous namespace
5245
5246/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5247static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5248 AssumptionCache &AC,
5249 const DominatorTree &DT,
5250 const Instruction *CxtI) {
5251 auto *Op = dyn_cast<Operator>(V);
5252 if (!Op)
5253 return std::nullopt;
5254
5255 // Implementation detail: all the cleverness here should happen without
5256 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5257 // SCEV expressions when possible, and we should not break that.
5258
5259 switch (Op->getOpcode()) {
5260 case Instruction::Add:
5261 case Instruction::Sub:
5262 case Instruction::Mul:
5263 case Instruction::UDiv:
5264 case Instruction::URem:
5265 case Instruction::And:
5266 case Instruction::AShr:
5267 case Instruction::Shl:
5268 return BinaryOp(Op);
5269
5270 case Instruction::Or: {
5271 // Convert or disjoint into add nuw nsw.
5272 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5273 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5274 /*IsNSW=*/true, /*IsNUW=*/true);
5275 return BinaryOp(Op);
5276 }
5277
5278 case Instruction::Xor:
5279 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5280 // If the RHS of the xor is a signmask, then this is just an add.
5281 // Instcombine turns add of signmask into xor as a strength reduction step.
5282 if (RHSC->getValue().isSignMask())
5283 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5284 // Binary `xor` is a bit-wise `add`.
5285 if (V->getType()->isIntegerTy(1))
5286 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5287 return BinaryOp(Op);
5288
5289 case Instruction::LShr:
5290 // Turn a logical shift right by a constant into an unsigned divide.
5291 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5292 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5293
5294 // If the shift count is not less than the bitwidth, the result of
5295 // the shift is undefined. Don't try to analyze it, because the
5296 // resolution chosen here may differ from the resolution chosen in
5297 // other parts of the compiler.
5298 if (SA->getValue().ult(BitWidth)) {
5299 Constant *X =
5300 ConstantInt::get(SA->getContext(),
5301 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5302 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5303 }
5304 }
5305 return BinaryOp(Op);
5306
5307 case Instruction::ExtractValue: {
5308 auto *EVI = cast<ExtractValueInst>(Op);
5309 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5310 break;
5311
5312 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5313 if (!WO)
5314 break;
5315
5316 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5317 bool Signed = WO->isSigned();
5318 // TODO: Should add nuw/nsw flags for mul as well.
5319 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5320 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5321
5322 // Now that we know that all uses of the arithmetic-result component of
5323 // CI are guarded by the overflow check, we can go ahead and pretend
5324 // that the arithmetic is non-overflowing.
5325 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5326 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5327 }
5328
5329 default:
5330 break;
5331 }
5332
5333 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5334 // semantics as a Sub, return a binary sub expression.
5335 if (auto *II = dyn_cast<IntrinsicInst>(V))
5336 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5337 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5338
5339 return std::nullopt;
5340}
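// Illustrative examples (added commentary, not from the original source) of
// the normalization performed by MatchBinaryOp:
//   %a = or disjoint i32 %x, %y  --> BinaryOp(Add, %x, %y, IsNSW, IsNUW)
//   %b = lshr i32 %x, 3          --> BinaryOp(UDiv, %x, 8)
//   %c = xor i32 %x, -2147483648 --> BinaryOp(Add, %x, -2147483648)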
5341
5342/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5343/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5344/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5345/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5346/// follows one of the following patterns:
5347/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5348/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5349/// If the SCEV expression of \p Op conforms with one of the expected patterns
5350/// we return the type of the truncation operation, and indicate whether the
5351/// truncated type should be treated as signed/unsigned by setting
5352/// \p Signed to true/false, respectively.
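/// For example (illustrative, not from the original source), with an i64
/// %SymbolicPHI and Op == (sext i32 (trunc i64 %SymbolicPHI to i32) to i64),
/// this returns the i32 type and sets \p Signed to true.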
5353static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5354 bool &Signed, ScalarEvolution &SE) {
5355 // The case where Op == SymbolicPHI (that is, with no type conversions on
5356 // the way) is handled by the regular add recurrence creating logic and
5357 // would have already been triggered in createAddRecForPHI. Reaching it here
5358 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5359 // because one of the other operands of the SCEVAddExpr updating this PHI is
5360 // not invariant).
5361 //
5362 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5363 // this case predicates that allow us to prove that Op == SymbolicPHI will
5364 // be added.
5365 if (Op == SymbolicPHI)
5366 return nullptr;
5367
5368 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5369 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5370 if (SourceBits != NewBits)
5371 return nullptr;
5372
5373 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5374 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5375 if (!SExt && !ZExt)
5376 return nullptr;
5377 const SCEVTruncateExpr *Trunc =
5378 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5379 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5380 if (!Trunc)
5381 return nullptr;
5382 const SCEV *X = Trunc->getOperand();
5383 if (X != SymbolicPHI)
5384 return nullptr;
5385 Signed = SExt != nullptr;
5386 return Trunc->getType();
5387}
5388
5389static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5390 if (!PN->getType()->isIntegerTy())
5391 return nullptr;
5392 const Loop *L = LI.getLoopFor(PN->getParent());
5393 if (!L || L->getHeader() != PN->getParent())
5394 return nullptr;
5395 return L;
5396}
5397
5398// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5399// computation that updates the phi follows the following pattern:
5400// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5401// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5402// If so, try to see if it can be rewritten as an AddRecExpr under some
5403// Predicates. If successful, return them as a pair. Also cache the results
5404// of the analysis.
5405//
5406// Example usage scenario:
5407// Say the Rewriter is called for the following SCEV:
5408// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5409// where:
5410// %X = phi i64 (%Start, %BEValue)
5411// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5412// and call this function with %SymbolicPHI = %X.
5413//
5414// The analysis will find that the value coming around the backedge has
5415// the following SCEV:
5416// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5417// Upon concluding that this matches the desired pattern, the function
5418// will return the pair {NewAddRec, SmallPredsVec} where:
5419// NewAddRec = {%Start,+,%Step}
5420// SmallPredsVec = {P1, P2, P3} as follows:
5421// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5422// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5423// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5424// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5425// under the predicates {P1,P2,P3}.
5426// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5427// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5428//
5429// TODO's:
5430//
5431// 1) Extend the Induction descriptor to also support inductions that involve
5432// casts: When needed (namely, when we are called in the context of the
5433// vectorizer induction analysis), a Set of cast instructions will be
5434// populated by this method, and provided back to isInductionPHI. This is
5435// needed to allow the vectorizer to properly record them to be ignored by
5436// the cost model and to avoid vectorizing them (otherwise these casts,
5437// which are redundant under the runtime overflow checks, will be
5438// vectorized, which can be costly).
5439//
5440// 2) Support additional induction/PHISCEV patterns: We also want to support
5441// inductions where the sext-trunc / zext-trunc operations (partly) occur
5442// after the induction update operation (the induction increment):
5443//
5444// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5445// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5446//
5447// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5448// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5449//
5450// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5451std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5452ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5453 SmallVector<const SCEVPredicate *, 3> Predicates;
5454
5455 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5456 // return an AddRec expression under some predicate.
5457
5458 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5459 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5460 assert(L && "Expecting an integer loop header phi");
5461
5462 // The loop may have multiple entrances or multiple exits; we can analyze
5463 // this phi as an addrec if it has a unique entry value and a unique
5464 // backedge value.
5465 Value *BEValueV = nullptr, *StartValueV = nullptr;
5466 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5467 Value *V = PN->getIncomingValue(i);
5468 if (L->contains(PN->getIncomingBlock(i))) {
5469 if (!BEValueV) {
5470 BEValueV = V;
5471 } else if (BEValueV != V) {
5472 BEValueV = nullptr;
5473 break;
5474 }
5475 } else if (!StartValueV) {
5476 StartValueV = V;
5477 } else if (StartValueV != V) {
5478 StartValueV = nullptr;
5479 break;
5480 }
5481 }
5482 if (!BEValueV || !StartValueV)
5483 return std::nullopt;
5484
5485 const SCEV *BEValue = getSCEV(BEValueV);
5486
5487 // If the value coming around the backedge is an add with the symbolic
5488 // value we just inserted, possibly with casts that we can ignore under
5489 // an appropriate runtime guard, then we found a simple induction variable!
5490 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5491 if (!Add)
5492 return std::nullopt;
5493
5494 // If there is a single occurrence of the symbolic value, possibly
5495 // casted, replace it with a recurrence.
5496 unsigned FoundIndex = Add->getNumOperands();
5497 Type *TruncTy = nullptr;
5498 bool Signed;
5499 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5500 if ((TruncTy =
5501 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5502 if (FoundIndex == e) {
5503 FoundIndex = i;
5504 break;
5505 }
5506
5507 if (FoundIndex == Add->getNumOperands())
5508 return std::nullopt;
5509
5510 // Create an add with everything but the specified operand.
5511 SmallVector<const SCEV *, 8> Ops;
5512 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5513 if (i != FoundIndex)
5514 Ops.push_back(Add->getOperand(i));
5515 const SCEV *Accum = getAddExpr(Ops);
5516
5517 // The runtime checks will not be valid if the step amount is
5518 // varying inside the loop.
5519 if (!isLoopInvariant(Accum, L))
5520 return std::nullopt;
5521
5522 // *** Part2: Create the predicates
5523
5524 // Analysis was successful: we have a phi-with-cast pattern for which we
5525 // can return an AddRec expression under the following predicates:
5526 //
5527 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5528 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5529 // P2: An Equal predicate that guarantees that
5530 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5531 // P3: An Equal predicate that guarantees that
5532 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5533 //
5534 // As we next prove, the above predicates guarantee that:
5535 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5536 //
5537 //
5538 // More formally, we want to prove that:
5539 // Expr(i+1) = Start + (i+1) * Accum
5540 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5541 //
5542 // Given that:
5543 // 1) Expr(0) = Start
5544 // 2) Expr(1) = Start + Accum
5545 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5546 // 3) Induction hypothesis (step i):
5547 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5548 //
5549 // Proof:
5550 // Expr(i+1) =
5551 // = Start + (i+1)*Accum
5552 // = (Start + i*Accum) + Accum
5553 // = Expr(i) + Accum
5554 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5555 // :: from step i
5556 //
5557 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5558 //
5559 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5560 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5561 // + Accum :: from P3
5562 //
5563 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5564 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5565 //
5566 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5567 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5568 //
5569 // By induction, the same applies to all iterations 1<=i<n:
5570 //
5571
5572 // Create a truncated addrec for which we will add a no overflow check (P1).
5573 const SCEV *StartVal = getSCEV(StartValueV);
5574 const SCEV *PHISCEV =
5575 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5576 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5577
5578 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5579 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5580 // will be constant.
5581 //
5582 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5583 // add P1.
5584 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5585 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5586 Signed ? SCEVWrapPredicate::IncrementNSSW
5587 : SCEVWrapPredicate::IncrementNUSW;
5588 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5589 Predicates.push_back(AddRecPred);
5590 }
5591
5592 // Create the Equal Predicates P2,P3:
5593
5594 // It is possible that the predicates P2 and/or P3 are computable at
5595 // compile time due to StartVal and/or Accum being constants.
5596 // If either one is, then we can check that now and escape if either P2
5597 // or P3 is false.
5598
5599 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5600 // for each of StartVal and Accum
5601 auto getExtendedExpr = [&](const SCEV *Expr,
5602 bool CreateSignExtend) -> const SCEV * {
5603 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5604 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5605 const SCEV *ExtendedExpr =
5606 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5607 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5608 return ExtendedExpr;
5609 };
5610
5611 // Given:
5612 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5613 // = getExtendedExpr(Expr)
5614 // Determine whether the predicate P: Expr == ExtendedExpr
5615 // is known to be false at compile time
5616 auto PredIsKnownFalse = [&](const SCEV *Expr,
5617 const SCEV *ExtendedExpr) -> bool {
5618 return Expr != ExtendedExpr &&
5619 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5620 };
5621
5622 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5623 if (PredIsKnownFalse(StartVal, StartExtended)) {
5624 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5625 return std::nullopt;
5626 }
5627
5628 // The Step is always Signed (because the overflow checks are either
5629 // NSSW or NUSW)
5630 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5631 if (PredIsKnownFalse(Accum, AccumExtended)) {
5632 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5633 return std::nullopt;
5634 }
5635
5636 auto AppendPredicate = [&](const SCEV *Expr,
5637 const SCEV *ExtendedExpr) -> void {
5638 if (Expr != ExtendedExpr &&
5639 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5640 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5641 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5642 Predicates.push_back(Pred);
5643 }
5644 };
5645
5646 AppendPredicate(StartVal, StartExtended);
5647 AppendPredicate(Accum, AccumExtended);
5648
5649 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5650 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5651 // into NewAR if it will also add the runtime overflow checks specified in
5652 // Predicates.
5653 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5654
5655 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5656 std::make_pair(NewAR, Predicates);
5657 // Remember the result of the analysis for this SCEV at this location.
5658 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5659 return PredRewrite;
5660}
5661
5662std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5663ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5664 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5665 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5666 if (!L)
5667 return std::nullopt;
5668
5669 // Check to see if we already analyzed this PHI.
5670 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5671 if (I != PredicatedSCEVRewrites.end()) {
5672 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5673 I->second;
5674 // Analysis was done before and failed to create an AddRec:
5675 if (Rewrite.first == SymbolicPHI)
5676 return std::nullopt;
5677 // Analysis was done before and succeeded in creating an AddRec under
5678 // a predicate:
5679 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5680 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5681 return Rewrite;
5682 }
5683
5684 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5685 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5686
5687 // Record in the cache that the analysis failed
5688 if (!Rewrite) {
5689 SmallVector<const SCEVPredicate *, 3> Predicates;
5690 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5691 return std::nullopt;
5692 }
5693
5694 return Rewrite;
5695}
5696
5697// FIXME: This utility is currently required because the Rewriter currently
5698// does not rewrite this expression:
5699// {0, +, (sext ix (trunc iy to ix) to iy)}
5700// into {0, +, %step},
5701// even when the following Equal predicate exists:
5702// "%step == (sext ix (trunc iy to ix) to iy)".
5703bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5704 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5705 if (AR1 == AR2)
5706 return true;
5707
5708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709 if (Expr1 != Expr2 &&
5710 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5712 return false;
5713 return true;
5714 };
5715
5716 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5717 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5718 return false;
5719 return true;
5720}
5721
5722/// A helper function for createAddRecFromPHI to handle simple cases.
5723///
5724/// This function tries to find an AddRec expression for the simplest (yet most
5725/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5726/// If it fails, createAddRecFromPHI will use a more general, but slow,
5727/// technique for finding the AddRec expression.
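/// For example (illustrative, not from the original source), for
///   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
///   %iv.next = add nuw nsw i64 %iv, 4
/// this routine produces {0,+,4}<nuw><nsw><%loop> directly, without the
/// symbolic-PHI machinery used by the general path.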
5728const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5729 Value *BEValueV,
5730 Value *StartValueV) {
5731 const Loop *L = LI.getLoopFor(PN->getParent());
5732 assert(L && L->getHeader() == PN->getParent());
5733 assert(BEValueV && StartValueV);
5734
5735 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5736 if (!BO)
5737 return nullptr;
5738
5739 if (BO->Opcode != Instruction::Add)
5740 return nullptr;
5741
5742 const SCEV *Accum = nullptr;
5743 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5744 Accum = getSCEV(BO->RHS);
5745 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5746 Accum = getSCEV(BO->LHS);
5747
5748 if (!Accum)
5749 return nullptr;
5750
5751 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5752 if (BO->IsNUW)
5753 Flags = setFlags(Flags, SCEV::FlagNUW);
5754 if (BO->IsNSW)
5755 Flags = setFlags(Flags, SCEV::FlagNSW);
5756
5757 const SCEV *StartVal = getSCEV(StartValueV);
5758 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5759 insertValueToMap(PN, PHISCEV);
5760
5761 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5762 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5763 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5764 proveNoWrapViaConstantRanges(AR)));
5765 }
5766
5767 // We can add Flags to the post-inc expression only if we
5768 // know that it is *undefined behavior* for BEValueV to
5769 // overflow.
5770 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5771 assert(isLoopInvariant(Accum, L) &&
5772 "Accum is defined outside L, but is not invariant?");
5773 if (isAddRecNeverPoison(BEInst, L))
5774 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5775 }
5776
5777 return PHISCEV;
5778}
5779
5780const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5781 const Loop *L = LI.getLoopFor(PN->getParent());
5782 if (!L || L->getHeader() != PN->getParent())
5783 return nullptr;
5784
5785 // The loop may have multiple entrances or multiple exits; we can analyze
5786 // this phi as an addrec if it has a unique entry value and a unique
5787 // backedge value.
5788 Value *BEValueV = nullptr, *StartValueV = nullptr;
5789 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5790 Value *V = PN->getIncomingValue(i);
5791 if (L->contains(PN->getIncomingBlock(i))) {
5792 if (!BEValueV) {
5793 BEValueV = V;
5794 } else if (BEValueV != V) {
5795 BEValueV = nullptr;
5796 break;
5797 }
5798 } else if (!StartValueV) {
5799 StartValueV = V;
5800 } else if (StartValueV != V) {
5801 StartValueV = nullptr;
5802 break;
5803 }
5804 }
5805 if (!BEValueV || !StartValueV)
5806 return nullptr;
5807
5808 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5809 "PHI node already processed?");
5810
5811 // First, try to find an AddRec expression without creating a fictitious
5812 // symbolic value for PN.
5813 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5814 return S;
5815
5816 // Handle PHI node value symbolically.
5817 const SCEV *SymbolicName = getUnknown(PN);
5818 insertValueToMap(PN, SymbolicName);
5819
5820 // Using this symbolic name for the PHI, analyze the value coming around
5821 // the back-edge.
5822 const SCEV *BEValue = getSCEV(BEValueV);
5823
5824 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5825 // has a special value for the first iteration of the loop.
5826
5827 // If the value coming around the backedge is an add with the symbolic
5828 // value we just inserted, then we found a simple induction variable!
5829 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5830 // If there is a single occurrence of the symbolic value, replace it
5831 // with a recurrence.
5832 unsigned FoundIndex = Add->getNumOperands();
5833 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5834 if (Add->getOperand(i) == SymbolicName)
5835 if (FoundIndex == e) {
5836 FoundIndex = i;
5837 break;
5838 }
5839
5840 if (FoundIndex != Add->getNumOperands()) {
5841 // Create an add with everything but the specified operand.
5842 SmallVector<const SCEV *, 8> Ops;
5843 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5844 if (i != FoundIndex)
5845 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5846 L, *this));
5847 const SCEV *Accum = getAddExpr(Ops);
5848
5849 // This is not a valid addrec if the step amount is varying each
5850 // loop iteration, but is not itself an addrec in this loop.
5851 if (isLoopInvariant(Accum, L) ||
5852 (isa<SCEVAddRecExpr>(Accum) &&
5853 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5854 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5855
5856 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5857 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5858 if (BO->IsNUW)
5859 Flags = setFlags(Flags, SCEV::FlagNUW);
5860 if (BO->IsNSW)
5861 Flags = setFlags(Flags, SCEV::FlagNSW);
5862 }
5863 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5864 if (GEP->getOperand(0) == PN) {
5865 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5866 // If the increment has any nowrap flags, then we know the address
5867 // space cannot be wrapped around.
5868 if (NW != GEPNoWrapFlags::none())
5869 Flags = setFlags(Flags, SCEV::FlagNW);
5870 // If the GEP is nuw or nusw with non-negative offset, we know that
5871 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5872 // offset is treated as signed, while the base is unsigned.
5873 if (NW.hasNoUnsignedWrap() ||
5874 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5875 Flags = setFlags(Flags, SCEV::FlagNUW);
5876 }
5877
5878 // We cannot transfer nuw and nsw flags from subtraction
5879 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5880 // for instance.
5881 }
5882
5883 const SCEV *StartVal = getSCEV(StartValueV);
5884 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5885
5886 // Okay, for the entire analysis of this edge we assumed the PHI
5887 // to be symbolic. We now need to go back and purge all of the
5888 // entries for the scalars that use the symbolic expression.
5889 forgetMemoizedResults(SymbolicName);
5890 insertValueToMap(PN, PHISCEV);
5891
5892 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5893 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5894 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5895 proveNoWrapViaConstantRanges(AR)));
5896 }
5897
5898 // We can add Flags to the post-inc expression only if we
5899 // know that it is *undefined behavior* for BEValueV to
5900 // overflow.
5901 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5902 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5903 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5904
5905 return PHISCEV;
5906 }
5907 }
5908 } else {
5909 // Otherwise, this could be a loop like this:
5910 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5911 // In this case, j = {1,+,1} and BEValue is j.
5912 // Because the other in-value of i (0) fits the evolution of BEValue,
5913 // i really is an addrec evolution.
5914 //
5915 // We can generalize this saying that i is the shifted value of BEValue
5916 // by one iteration:
5917 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5918
5919 // Do not allow refinement in rewriting of BEValue.
5920 if (isGuaranteedNotToCauseUB(BEValue)) {
5921 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5922 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5923 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5924 ::impliesPoison(BEValue, Start)) {
5925 const SCEV *StartVal = getSCEV(StartValueV);
5926 if (Start == StartVal) {
5927 // Okay, for the entire analysis of this edge we assumed the PHI
5928 // to be symbolic. We now need to go back and purge all of the
5929 // entries for the scalars that use the symbolic expression.
5930 forgetMemoizedResults(SymbolicName);
5931 insertValueToMap(PN, Shifted);
5932 return Shifted;
5933 }
5934 }
5935 }
5936 }
5937
5938 // Remove the temporary PHI node SCEV that has been inserted while intending
5939 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
5940 // as it would prevent later (possibly simpler) SCEV expressions from being
5941 // added to the ValueExprMap.
5942 eraseValueFromMap(PN);
5943
5944 return nullptr;
5945}
5946
5947// Try to match a control flow sequence that branches out at BI and merges back
5948// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5949// match.
5950static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5951 Value *&C, Value *&LHS, Value *&RHS) {
5952 C = BI->getCondition();
5953
5954 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5955 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5956
5957 if (!LeftEdge.isSingleEdge())
5958 return false;
5959
5960 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5961
5962 Use &LeftUse = Merge->getOperandUse(0);
5963 Use &RightUse = Merge->getOperandUse(1);
5964
5965 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5966 LHS = LeftUse;
5967 RHS = RightUse;
5968 return true;
5969 }
5970
5971 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5972 LHS = RightUse;
5973 RHS = LeftUse;
5974 return true;
5975 }
5976
5977 return false;
5978}
5979
5980const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5981 auto IsReachable =
5982 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5983 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5984 // Try to match
5985 //
5986 // br %cond, label %left, label %right
5987 // left:
5988 // br label %merge
5989 // right:
5990 // br label %merge
5991 // merge:
5992 // V = phi [ %x, %left ], [ %y, %right ]
5993 //
5994 // as "select %cond, %x, %y"
5995
5996 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5997 assert(IDom && "At least the entry block should dominate PN");
5998
5999 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6000 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6001
6002 if (BI && BI->isConditional() &&
6003 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6004 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6005 properlyDominates(getSCEV(RHS), PN->getParent()))
6006 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6007 }
6008
6009 return nullptr;
6010}
6011
6012/// Returns SCEV for the first operand of a phi if all phi operands have
6013/// identical opcodes and operands
6014/// e.g.
6015/// a: %add = %a + %b
6016/// br %c
6017/// b: %add1 = %a + %b
6018/// br %c
6019/// c: %phi = phi [%add, a], [%add1, b]
6020/// scev(%phi) => scev(%add)
6021const SCEV *
6022ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6023 BinaryOperator *CommonInst = nullptr;
6024 // Check if instructions are identical.
6025 for (Value *Incoming : PN->incoming_values()) {
6026 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6027 if (!IncomingInst)
6028 return nullptr;
6029 if (CommonInst) {
6030 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6031 return nullptr; // Not identical, give up
6032 } else {
6033 // Remember binary operator
6034 CommonInst = IncomingInst;
6035 }
6036 }
6037 if (!CommonInst)
6038 return nullptr;
6039
6040 // Check if SCEV exprs for instructions are identical.
6041 const SCEV *CommonSCEV = getSCEV(CommonInst);
6042 bool SCEVExprsIdentical =
6043 all_of(drop_begin(PN->incoming_values()),
6044 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6045 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6046}
6047
6048const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6049 if (const SCEV *S = createAddRecFromPHI(PN))
6050 return S;
6051
6052 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6053 // phi node for X.
6054 if (Value *V = simplifyInstruction(
6055 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6056 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6057 return getSCEV(V);
6058
6059 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6060 return S;
6061
6062 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6063 return S;
6064
6065 // If it's not a loop phi, we can't handle it yet.
6066 return getUnknown(PN);
6067}
6068
6069bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6070 SCEVTypes RootKind) {
6071 struct FindClosure {
6072 const SCEV *OperandToFind;
6073 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6074 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6075
6076 bool Found = false;
6077
6078 bool canRecurseInto(SCEVTypes Kind) const {
6079 // We can only recurse into the SCEV expression of the same effective type
6080 // as the type of our root SCEV expression, and into zero-extensions.
6081 return RootKind == Kind || NonSequentialRootKind == Kind ||
6082 scZeroExtend == Kind;
6083 };
6084
6085 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6086 : OperandToFind(OperandToFind), RootKind(RootKind),
6087 NonSequentialRootKind(
6088 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6089 RootKind)) {}
6090
6091 bool follow(const SCEV *S) {
6092 Found = S == OperandToFind;
6093
6094 return !isDone() && canRecurseInto(S->getSCEVType());
6095 }
6096
6097 bool isDone() const { return Found; }
6098 };
6099
6100 FindClosure FC(OperandToFind, RootKind);
6101 visitAll(Root, FC);
6102 return FC.Found;
6103}
6104
6105std::optional<const SCEV *>
6106ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6107 ICmpInst *Cond,
6108 Value *TrueVal,
6109 Value *FalseVal) {
6110 // Try to match some simple smax or umax patterns.
6111 auto *ICI = Cond;
6112
6113 Value *LHS = ICI->getOperand(0);
6114 Value *RHS = ICI->getOperand(1);
6115
6116 switch (ICI->getPredicate()) {
6117 case ICmpInst::ICMP_SLT:
6118 case ICmpInst::ICMP_SLE:
6119 case ICmpInst::ICMP_ULT:
6120 case ICmpInst::ICMP_ULE:
6121 std::swap(LHS, RHS);
6122 [[fallthrough]];
6123 case ICmpInst::ICMP_SGT:
6124 case ICmpInst::ICMP_SGE:
6125 case ICmpInst::ICMP_UGT:
6126 case ICmpInst::ICMP_UGE:
6127 // a > b ? a+x : b+x -> max(a, b)+x
6128 // a > b ? b+x : a+x -> min(a, b)+x
6129 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6130 bool Signed = ICI->isSigned();
6131 const SCEV *LA = getSCEV(TrueVal);
6132 const SCEV *RA = getSCEV(FalseVal);
6133 const SCEV *LS = getSCEV(LHS);
6134 const SCEV *RS = getSCEV(RHS);
6135 if (LA->getType()->isPointerTy()) {
6136 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6137 // Need to make sure we can't produce weird expressions involving
6138 // negated pointers.
6139 if (LA == LS && RA == RS)
6140 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6141 if (LA == RS && RA == LS)
6142 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6143 }
6144 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6145 if (Op->getType()->isPointerTy()) {
6146 Op = getLosslessPtrToIntExpr(Op);
6147 if (isa<SCEVCouldNotCompute>(Op))
6148 return Op;
6149 }
6150 if (Signed)
6151 Op = getNoopOrSignExtend(Op, Ty);
6152 else
6153 Op = getNoopOrZeroExtend(Op, Ty);
6154 return Op;
6155 };
6156 LS = CoerceOperand(LS);
6157 RS = CoerceOperand(RS);
6158 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6159 break;
6160 const SCEV *LDiff = getMinusSCEV(LA, LS);
6161 const SCEV *RDiff = getMinusSCEV(RA, RS);
6162 if (LDiff == RDiff)
6163 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6164 LDiff);
6165 LDiff = getMinusSCEV(LA, RS);
6166 RDiff = getMinusSCEV(RA, LS);
6167 if (LDiff == RDiff)
6168 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6169 LDiff);
6170 }
6171 break;
6172 case ICmpInst::ICMP_NE:
6173 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6174 std::swap(TrueVal, FalseVal);
6175 [[fallthrough]];
6176 case ICmpInst::ICMP_EQ:
6177 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6178 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6179 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6180 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6181 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6182 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6183 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6184 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6185 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6186 return getAddExpr(getUMaxExpr(X, C), Y);
6187 }
6188 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6189 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6190 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6191 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6192 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6193 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6194 const SCEV *X = getSCEV(LHS);
6195 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6196 X = ZExt->getOperand();
6197 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6198 const SCEV *FalseValExpr = getSCEV(FalseVal);
6199 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6200 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6201 /*Sequential=*/true);
6202 }
6203 }
6204 break;
6205 default:
6206 break;
6207 }
6208
6209 return std::nullopt;
6210}
6211
6212static std::optional<const SCEV *>
6214 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6215 assert(CondExpr->getType()->isIntegerTy(1) &&
6216 TrueExpr->getType() == FalseExpr->getType() &&
6217 TrueExpr->getType()->isIntegerTy(1) &&
6218 "Unexpected operands of a select.");
6219
6220 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6221 // --> C + (umin_seq cond, x - C)
6222 //
6223 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6224 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6225 // --> C + (umin_seq ~cond, x - C)
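  //
  // Worked example (added commentary, not from the original source): for
  //   select i1 %cond, i1 %x, i1 false
  // C is 0, so the result is 0 + (umin_seq %cond, %x - 0), i.e.
  // (%cond umin_seq %x): 0 whenever %cond is false, and %x otherwise.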
6226
6227 // FIXME: while we can't legally model the case where both of the hands
6228 // are fully variable, we only require that the *difference* is constant.
6229 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6230 return std::nullopt;
6231
6232 const SCEV *X, *C;
6233 if (isa<SCEVConstant>(TrueExpr)) {
6234 CondExpr = SE->getNotSCEV(CondExpr);
6235 X = FalseExpr;
6236 C = TrueExpr;
6237 } else {
6238 X = TrueExpr;
6239 C = FalseExpr;
6240 }
6241 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6242 /*Sequential=*/true));
6243}
6244
6245static std::optional<const SCEV *>
6247 Value *FalseVal) {
6248 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6249 return std::nullopt;
6250
6251 const auto *SECond = SE->getSCEV(Cond);
6252 const auto *SETrue = SE->getSCEV(TrueVal);
6253 const auto *SEFalse = SE->getSCEV(FalseVal);
6254 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6255}
6256
6257const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6258 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6259 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6260 assert(TrueVal->getType() == FalseVal->getType() &&
6261 V->getType() == TrueVal->getType() &&
6262 "Types of select hands and of the result must match.");
6263
6264 // For now, only deal with i1-typed `select`s.
6265 if (!V->getType()->isIntegerTy(1))
6266 return getUnknown(V);
6267
6268 if (std::optional<const SCEV *> S =
6269 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6270 return *S;
6271
6272 return getUnknown(V);
6273}
6274
6275const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6276 Value *TrueVal,
6277 Value *FalseVal) {
6278 // Handle "constant" branch or select. This can occur for instance when a
6279 // loop pass transforms an inner loop and moves on to process the outer loop.
6280 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6281 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6282
6283 if (auto *I = dyn_cast<Instruction>(V)) {
6284 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6285 if (std::optional<const SCEV *> S =
6286 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6287 TrueVal, FalseVal))
6288 return *S;
6289 }
6290 }
6291
6292 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6293}
6294
6295/// Expand GEP instructions into add and multiply operations. This allows them
6296/// to be analyzed by regular SCEV code.
6297const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6298 assert(GEP->getSourceElementType()->isSized() &&
6299 "GEP source element type must be sized");
6300
6301 SmallVector<const SCEV *, 4> IndexExprs;
6302 for (Value *Index : GEP->indices())
6303 IndexExprs.push_back(getSCEV(Index));
6304 return getGEPExpr(GEP, IndexExprs);
6305}
6306
6307APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6308 uint32_t BitWidth = getTypeSizeInBits(S->getType());
6309 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6310 return TrailingZeros >= BitWidth
6311 ? APInt::getZero(BitWidth)
6312 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6313 };
6314 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6315 // The result is GCD of all operands results.
6316 APInt Res = getConstantMultiple(N->getOperand(0));
6317 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6318 Res = APIntOps::GreatestCommonDivisor(
6319 Res, getConstantMultiple(N->getOperand(I)));
6320 return Res;
6321 };
6322
6323 switch (S->getSCEVType()) {
6324 case scConstant:
6325 return cast<SCEVConstant>(S)->getAPInt();
6326 case scPtrToInt:
6327 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6328 case scUDivExpr:
6329 case scVScale:
6330 return APInt(BitWidth, 1);
6331 case scTruncate: {
6332 // Only multiples that are a power of 2 will hold after truncation.
6333 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6334 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6335 return GetShiftedByZeros(TZ);
6336 }
6337 case scZeroExtend: {
6338 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6339 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6340 }
6341 case scSignExtend: {
6342 // Only multiples that are a power of 2 will hold after sext.
6343 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6344 uint32_t TZ = getMinTrailingZeros(E->getOperand());
6345 return GetShiftedByZeros(TZ);
6346 }
6347 case scMulExpr: {
6348 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6349 if (M->hasNoUnsignedWrap()) {
6350 // The result is the product of all operand results.
6351 APInt Res = getConstantMultiple(M->getOperand(0));
6352 for (const SCEV *Operand : M->operands().drop_front())
6353 Res = Res * getConstantMultiple(Operand);
6354 return Res;
6355 }
6356
6357 // If there are no wrap guarantees, find the trailing zeros, which is the
6358 // sum of trailing zeros for all its operands.
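// For example, multiplying a multiple of 2 by a multiple of 4 always yields a
// multiple of 8 (1 + 2 = 3 trailing zeros), even if the multiplication wraps.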
6359 uint32_t TZ = 0;
6360 for (const SCEV *Operand : M->operands())
6361 TZ += getMinTrailingZeros(Operand);
6362 return GetShiftedByZeros(TZ);
6363 }
6364 case scAddExpr:
6365 case scAddRecExpr: {
6366 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6367 if (N->hasNoUnsignedWrap())
6368 return GetGCDMultiple(N);
6369 // Find the trailing zeros, which is the minimum over all operands.
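// For example, adding a multiple of 4 to a multiple of 2 only guarantees a
// multiple of 2 (e.g. 4 + 2 = 6, which has a single trailing zero).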
6370 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6371 for (const SCEV *Operand : N->operands().drop_front())
6372 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6373 return GetShiftedByZeros(TZ);
6374 }
6375 case scUMaxExpr:
6376 case scSMaxExpr:
6377 case scUMinExpr:
6378 case scSMinExpr:
6379 case scSequentialUMinExpr:
6380 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6381 case scUnknown: {
6382 // ask ValueTracking for known bits
6383 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6384 unsigned Known =
6385 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6386 .countMinTrailingZeros();
6387 return GetShiftedByZeros(Known);
6388 }
6389 case scCouldNotCompute:
6390 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6391 }
6392 llvm_unreachable("Unknown SCEV kind!");
6393}
6394
6395 APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6396 auto I = ConstantMultipleCache.find(S);
6397 if (I != ConstantMultipleCache.end())
6398 return I->second;
6399
6400 APInt Result = getConstantMultipleImpl(S);
6401 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6402 assert(InsertPair.second && "Should insert a new key");
6403 return InsertPair.first->second;
6404}
6405
6406 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6407 APInt Multiple = getConstantMultiple(S);
6408 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6409}
6410
6411 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6412 return std::min(getConstantMultiple(S).countTrailingZeros(),
6413 (unsigned)getTypeSizeInBits(S->getType()));
6414}
6415
6416/// Helper method to assign a range to V from metadata present in the IR.
6417static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6418 if (Instruction *I = dyn_cast<Instruction>(V)) {
6419 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6420 return getConstantRangeFromMetadata(*MD);
6421 if (const auto *CB = dyn_cast<CallBase>(V))
6422 if (std::optional<ConstantRange> Range = CB->getRange())
6423 return Range;
6424 }
6425 if (auto *A = dyn_cast<Argument>(V))
6426 if (std::optional<ConstantRange> Range = A->getRange())
6427 return Range;
6428
6429 return std::nullopt;
6430}
6431
6432 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6433 SCEV::NoWrapFlags Flags) {
6434 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6435 AddRec->setNoWrapFlags(Flags);
6436 UnsignedRanges.erase(AddRec);
6437 SignedRanges.erase(AddRec);
6438 ConstantMultipleCache.erase(AddRec);
6439 }
6440}
6441
6442ConstantRange ScalarEvolution::
6443getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6444 const DataLayout &DL = getDataLayout();
6445
6446 unsigned BitWidth = getTypeSizeInBits(U->getType());
6447 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6448
6449 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6450 // use information about the trip count to improve our available range. Note
6451 // that the trip count independent cases are already handled by known bits.
6452 // WARNING: The definition of recurrence used here is subtly different than
6453 // the one used by AddRec (and thus most of this file). Step is allowed to
6454 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6455 // and other addrecs in the same loop (for non-affine addrecs). The code
6456 // below intentionally handles the case where step is not loop invariant.
6457 auto *P = dyn_cast<PHINode>(U->getValue());
6458 if (!P)
6459 return FullSet;
6460
6461 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6462 // even the values that are not available in these blocks may come from them,
6463 // and this leads to false-positive recurrence test.
6464 for (auto *Pred : predecessors(P->getParent()))
6465 if (!DT.isReachableFromEntry(Pred))
6466 return FullSet;
6467
6468 BinaryOperator *BO;
6469 Value *Start, *Step;
6470 if (!matchSimpleRecurrence(P, BO, Start, Step))
6471 return FullSet;
6472
6473 // If we found a recurrence in reachable code, we must be in a loop. Note
6474 // that BO might be in some subloop of L, and that's completely okay.
6475 auto *L = LI.getLoopFor(P->getParent());
6476 assert(L && L->getHeader() == P->getParent());
6477 if (!L->contains(BO->getParent()))
6478 // NOTE: This bailout should be an assert instead. However, asserting
6479 // the condition here exposes a case where LoopFusion is querying SCEV
6480 // with malformed loop information during the midst of the transform.
6481 // There doesn't appear to be an obvious fix, so for the moment bailout
6482 // until the caller issue can be fixed. PR49566 tracks the bug.
6483 return FullSet;
6484
6485 // TODO: Extend to other opcodes such as mul, and div
6486 switch (BO->getOpcode()) {
6487 default:
6488 return FullSet;
6489 case Instruction::AShr:
6490 case Instruction::LShr:
6491 case Instruction::Shl:
6492 break;
6493 };
6494
6495 if (BO->getOperand(0) != P)
6496 // TODO: Handle the power function forms some day.
6497 return FullSet;
6498
6499 unsigned TC = getSmallConstantMaxTripCount(L);
6500 if (!TC || TC >= BitWidth)
6501 return FullSet;
6502
6503 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6504 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6505 assert(KnownStart.getBitWidth() == BitWidth &&
6506 KnownStep.getBitWidth() == BitWidth);
6507
6508 // Compute total shift amount, being careful of overflow and bitwidths.
6509 auto MaxShiftAmt = KnownStep.getMaxValue();
6510 APInt TCAP(BitWidth, TC-1);
6511 bool Overflow = false;
6512 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6513 if (Overflow)
6514 return FullSet;
6515
6516 switch (BO->getOpcode()) {
6517 default:
6518 llvm_unreachable("filtered out above");
6519 case Instruction::AShr: {
6520 // For each ashr, three cases:
6521 // shift = 0 => unchanged value
6522 // saturation => 0 or -1
6523 // other => a value closer to zero (of the same sign)
6524 // Thus, the end value is closer to zero than the start.
6525 auto KnownEnd = KnownBits::ashr(KnownStart,
6526 KnownBits::makeConstant(TotalShift));
6527 if (KnownStart.isNonNegative())
6528 // Analogous to lshr (simply not yet canonicalized)
6529 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6530 KnownStart.getMaxValue() + 1);
6531 if (KnownStart.isNegative())
6532 // End >=u Start && End <=s Start
6533 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6534 KnownEnd.getMaxValue() + 1);
6535 break;
6536 }
6537 case Instruction::LShr: {
6538 // For each lshr, three cases:
6539 // shift = 0 => unchanged value
6540 // saturation => 0
6541 // other => a smaller positive number
6542 // Thus, the low end of the unsigned range is the last value produced.
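// For example, a recurrence whose start lies in [16, 32] and whose total
// shift amount is at most 2 only produces values in [4, 32].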
6543 auto KnownEnd = KnownBits::lshr(KnownStart,
6544 KnownBits::makeConstant(TotalShift));
6545 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6546 KnownStart.getMaxValue() + 1);
6547 }
6548 case Instruction::Shl: {
6549 // Iff no bits are shifted out, value increases on every shift.
6550 auto KnownEnd = KnownBits::shl(KnownStart,
6551 KnownBits::makeConstant(TotalShift));
6552 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6553 return ConstantRange(KnownStart.getMinValue(),
6554 KnownEnd.getMaxValue() + 1);
6555 break;
6556 }
6557 };
6558 return FullSet;
6559}
6560
6561const ConstantRange &
6562ScalarEvolution::getRangeRefIter(const SCEV *S,
6563 ScalarEvolution::RangeSignHint SignHint) {
6564 DenseMap<const SCEV *, ConstantRange> &Cache =
6565 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6566 : SignedRanges;
6567 SmallVector<const SCEV *> WorkList;
6568 SmallPtrSet<const SCEV *, 8> Seen;
6569
6570 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6571 // SCEVUnknown PHI node.
6572 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6573 if (!Seen.insert(Expr).second)
6574 return;
6575 if (Cache.contains(Expr))
6576 return;
6577 switch (Expr->getSCEVType()) {
6578 case scUnknown:
6579 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6580 break;
6581 [[fallthrough]];
6582 case scConstant:
6583 case scVScale:
6584 case scTruncate:
6585 case scZeroExtend:
6586 case scSignExtend:
6587 case scPtrToInt:
6588 case scAddExpr:
6589 case scMulExpr:
6590 case scUDivExpr:
6591 case scAddRecExpr:
6592 case scUMaxExpr:
6593 case scSMaxExpr:
6594 case scUMinExpr:
6595 case scSMinExpr:
6596 case scSequentialUMinExpr:
6597 WorkList.push_back(Expr);
6598 break;
6599 case scCouldNotCompute:
6600 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6601 }
6602 };
6603 AddToWorklist(S);
6604
6605 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6606 for (unsigned I = 0; I != WorkList.size(); ++I) {
6607 const SCEV *P = WorkList[I];
6608 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6609 // If it is not a `SCEVUnknown`, just recurse into operands.
6610 if (!UnknownS) {
6611 for (const SCEV *Op : P->operands())
6612 AddToWorklist(Op);
6613 continue;
6614 }
6615 // `SCEVUnknown`'s require special treatment.
6616 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6617 if (!PendingPhiRangesIter.insert(P).second)
6618 continue;
6619 for (auto &Op : reverse(P->operands()))
6620 AddToWorklist(getSCEV(Op));
6621 }
6622 }
6623
6624 if (!WorkList.empty()) {
6625 // Use getRangeRef to compute ranges for items in the worklist in reverse
6626 // order. This will force ranges for earlier operands to be computed before
6627 // their users in most cases.
6628 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6629 getRangeRef(P, SignHint);
6630
6631 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6632 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6633 PendingPhiRangesIter.erase(P);
6634 }
6635 }
6636
6637 return getRangeRef(S, SignHint, 0);
6638}
6639
6640/// Determine the range for a particular SCEV. If SignHint is
6641/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6642/// with a "cleaner" unsigned (resp. signed) representation.
6643const ConstantRange &ScalarEvolution::getRangeRef(
6644 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6645 DenseMap<const SCEV *, ConstantRange> &Cache =
6646 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6647 : SignedRanges;
6648 ConstantRange::PreferredRangeType RangeType =
6649 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6650 : ConstantRange::Signed;
6651
6652 // See if we've computed this range already.
6653 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6654 if (I != Cache.end())
6655 return I->second;
6656
6657 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6658 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6659
6660 // Switch to iteratively computing the range for S, if it is part of a deeply
6661 // nested expression.
6662 if (Depth > RangeIterThreshold)
6663 return getRangeRefIter(S, SignHint);
6664
6665 unsigned BitWidth = getTypeSizeInBits(S->getType());
6666 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6667 using OBO = OverflowingBinaryOperator;
6668
6669 // If the value has known zeros, the maximum value will have those known zeros
6670 // as well.
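// For example, if S is an i8 value known to be a multiple of 8, its largest
// possible value is 248, so the unsigned range can be tightened to [0, 249).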
6671 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6672 APInt Multiple = getNonZeroConstantMultiple(S);
6673 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6674 if (!Remainder.isZero())
6675 ConservativeResult =
6676 ConstantRange(APInt::getZero(BitWidth),
6677 APInt::getMaxValue(BitWidth) - Remainder + 1);
6678 }
6679 else {
6680 uint32_t TZ = getMinTrailingZeros(S);
6681 if (TZ != 0) {
6682 ConservativeResult = ConstantRange(
6683 APInt::getSignedMinValue(BitWidth).ashr(TZ).shl(TZ),
6684 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6685 }
6686 }
6687
6688 switch (S->getSCEVType()) {
6689 case scConstant:
6690 llvm_unreachable("Already handled above.");
6691 case scVScale:
6692 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6693 case scTruncate: {
6694 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6695 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6696 return setRange(
6697 Trunc, SignHint,
6698 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6699 }
6700 case scZeroExtend: {
6701 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6702 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6703 return setRange(
6704 ZExt, SignHint,
6705 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6706 }
6707 case scSignExtend: {
6708 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6709 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6710 return setRange(
6711 SExt, SignHint,
6712 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6713 }
6714 case scPtrToInt: {
6715 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6716 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6717 return setRange(PtrToInt, SignHint, X);
6718 }
6719 case scAddExpr: {
6720 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6721 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6722 unsigned WrapType = OBO::AnyWrap;
6723 if (Add->hasNoSignedWrap())
6724 WrapType |= OBO::NoSignedWrap;
6725 if (Add->hasNoUnsignedWrap())
6726 WrapType |= OBO::NoUnsignedWrap;
6727 for (const SCEV *Op : drop_begin(Add->operands()))
6728 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6729 RangeType);
6730 return setRange(Add, SignHint,
6731 ConservativeResult.intersectWith(X, RangeType));
6732 }
6733 case scMulExpr: {
6734 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6735 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6736 for (const SCEV *Op : drop_begin(Mul->operands()))
6737 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6738 return setRange(Mul, SignHint,
6739 ConservativeResult.intersectWith(X, RangeType));
6740 }
6741 case scUDivExpr: {
6742 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6743 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6744 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6745 return setRange(UDiv, SignHint,
6746 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6747 }
6748 case scAddRecExpr: {
6749 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6750 // If there's no unsigned wrap, the value will never be less than its
6751 // initial value.
6752 if (AddRec->hasNoUnsignedWrap()) {
6753 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6754 if (!UnsignedMinValue.isZero())
6755 ConservativeResult = ConservativeResult.intersectWith(
6756 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6757 }
6758
6759 // If there's no signed wrap, and all the operands except initial value have
6760 // the same sign or zero, the value won't ever be:
6761 // 1: smaller than the initial value if the operands are non-negative,
6762 // 2: bigger than the initial value if the operands are non-positive.
6763 // For both cases, the value cannot cross the signed min/max boundary.
6764 if (AddRec->hasNoSignedWrap()) {
6765 bool AllNonNeg = true;
6766 bool AllNonPos = true;
6767 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6768 if (!isKnownNonNegative(AddRec->getOperand(i)))
6769 AllNonNeg = false;
6770 if (!isKnownNonPositive(AddRec->getOperand(i)))
6771 AllNonPos = false;
6772 }
6773 if (AllNonNeg)
6774 ConservativeResult = ConservativeResult.intersectWith(
6775 ConstantRange(getSignedRangeMin(AddRec->getStart()),
6776 APInt::getSignedMinValue(BitWidth)),
6777 RangeType);
6778 else if (AllNonPos)
6779 ConservativeResult = ConservativeResult.intersectWith(
6780 ConstantRange(APInt::getSignedMinValue(BitWidth),
6781 getSignedRangeMax(AddRec->getStart()) +
6782 1),
6783 RangeType);
6784 }
6785
6786 // TODO: non-affine addrec
6787 if (AddRec->isAffine()) {
6788 const SCEV *MaxBEScev =
6789 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6790 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6791 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6792
6793 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6794 // MaxBECount's active bits are all <= AddRec's bit width.
6795 if (MaxBECount.getBitWidth() > BitWidth &&
6796 MaxBECount.getActiveBits() <= BitWidth)
6797 MaxBECount = MaxBECount.trunc(BitWidth);
6798 else if (MaxBECount.getBitWidth() < BitWidth)
6799 MaxBECount = MaxBECount.zext(BitWidth);
6800
6801 if (MaxBECount.getBitWidth() == BitWidth) {
6802 auto RangeFromAffine = getRangeForAffineAR(
6803 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6804 ConservativeResult =
6805 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6806
6807 auto RangeFromFactoring = getRangeViaFactoring(
6808 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6809 ConservativeResult =
6810 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6811 }
6812 }
6813
6814 // Now try symbolic BE count and more powerful methods.
6815 if (UseExpensiveRangeSharpening) {
6816 const SCEV *SymbolicMaxBECount =
6817 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6818 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6819 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6820 AddRec->hasNoSelfWrap()) {
6821 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6822 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6823 ConservativeResult =
6824 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6825 }
6826 }
6827 }
6828
6829 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6830 }
6831 case scUMaxExpr:
6832 case scSMaxExpr:
6833 case scUMinExpr:
6834 case scSMinExpr:
6835 case scSequentialUMinExpr: {
6836 Intrinsic::ID ID;
6837 switch (S->getSCEVType()) {
6838 case scUMaxExpr:
6839 ID = Intrinsic::umax;
6840 break;
6841 case scSMaxExpr:
6842 ID = Intrinsic::smax;
6843 break;
6844 case scUMinExpr:
6845 case scSequentialUMinExpr:
6846 ID = Intrinsic::umin;
6847 break;
6848 case scSMinExpr:
6849 ID = Intrinsic::smin;
6850 break;
6851 default:
6852 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6853 }
6854
6855 const auto *NAry = cast<SCEVNAryExpr>(S);
6856 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6857 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6858 X = X.intrinsic(
6859 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6860 return setRange(S, SignHint,
6861 ConservativeResult.intersectWith(X, RangeType));
6862 }
6863 case scUnknown: {
6864 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6865 Value *V = U->getValue();
6866
6867 // Check if the IR explicitly contains !range metadata.
6868 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6869 if (MDRange)
6870 ConservativeResult =
6871 ConservativeResult.intersectWith(*MDRange, RangeType);
6872
6873 // Use facts about recurrences in the underlying IR. Note that add
6874 // recurrences are AddRecExprs and thus don't hit this path. This
6875 // primarily handles shift recurrences.
6876 auto CR = getRangeForUnknownRecurrence(U);
6877 ConservativeResult = ConservativeResult.intersectWith(CR);
6878
6879 // See if ValueTracking can give us a useful range.
6880 const DataLayout &DL = getDataLayout();
6881 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6882 if (Known.getBitWidth() != BitWidth)
6883 Known = Known.zextOrTrunc(BitWidth);
6884
6885 // ValueTracking may be able to compute a tighter result for the number of
6886 // sign bits than for the value of those sign bits.
6887 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6888 if (U->getType()->isPointerTy()) {
6889 // If the pointer size is larger than the index size type, this can cause
6890 // NS to be larger than BitWidth. So compensate for this.
6891 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6892 int ptrIdxDiff = ptrSize - BitWidth;
6893 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6894 NS -= ptrIdxDiff;
6895 }
6896
6897 if (NS > 1) {
6898 // If we know any of the sign bits, we know all of the sign bits.
6899 if (!Known.Zero.getHiBits(NS).isZero())
6900 Known.Zero.setHighBits(NS);
6901 if (!Known.One.getHiBits(NS).isZero())
6902 Known.One.setHighBits(NS);
6903 }
6904
6905 if (Known.getMinValue() != Known.getMaxValue() + 1)
6906 ConservativeResult = ConservativeResult.intersectWith(
6907 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6908 RangeType);
6909 if (NS > 1)
6910 ConservativeResult = ConservativeResult.intersectWith(
6911 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6912 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6913 RangeType);
6914
6915 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6916 // Strengthen the range if the underlying IR value is a
6917 // global/alloca/heap allocation using the size of the object.
6918 bool CanBeNull, CanBeFreed;
6919 uint64_t DerefBytes =
6920 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6921 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6922 // The highest address the object can start is DerefBytes bytes before
6923 // the end (unsigned max value). If this value is not a multiple of the
6924 // alignment, the last possible start value is the next lowest multiple
6925 // of the alignment. Note: The computations below cannot overflow,
6926 // because if they did, there would be no possible start address for the
6927 // object.
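// For example, with a 16-bit index type, 100 dereferenceable bytes and 8-byte
// alignment: MaxVal = 65535 - 100 = 65435, rounded down to 65432; if the
// pointer is also known non-null, the range becomes [8, 65433).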
6928 APInt MaxVal =
6929 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6930 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6931 uint64_t Rem = MaxVal.urem(Align);
6932 MaxVal -= APInt(BitWidth, Rem);
6933 APInt MinVal = APInt::getZero(BitWidth);
6934 if (llvm::isKnownNonZero(V, DL))
6935 MinVal = Align;
6936 ConservativeResult = ConservativeResult.intersectWith(
6937 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6938 }
6939 }
6940
6941 // A range of Phi is a subset of union of all ranges of its input.
6942 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6943 // Make sure that we do not run over cycled Phis.
6944 if (PendingPhiRanges.insert(Phi).second) {
6945 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6946
6947 for (const auto &Op : Phi->operands()) {
6948 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6949 RangeFromOps = RangeFromOps.unionWith(OpRange);
6950 // No point to continue if we already have a full set.
6951 if (RangeFromOps.isFullSet())
6952 break;
6953 }
6954 ConservativeResult =
6955 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6956 bool Erased = PendingPhiRanges.erase(Phi);
6957 assert(Erased && "Failed to erase Phi properly?");
6958 (void)Erased;
6959 }
6960 }
6961
6962 // vscale can't be equal to zero
6963 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6964 if (II->getIntrinsicID() == Intrinsic::vscale) {
6965 ConstantRange Disallowed = APInt::getZero(BitWidth);
6966 ConservativeResult = ConservativeResult.difference(Disallowed);
6967 }
6968
6969 return setRange(U, SignHint, std::move(ConservativeResult));
6970 }
6971 case scCouldNotCompute:
6972 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6973 }
6974
6975 return setRange(S, SignHint, std::move(ConservativeResult));
6976}
6977
6978// Given a StartRange, Step and MaxBECount for an expression compute a range of
6979// values that the expression can take. Initially, the expression has a value
6980// from StartRange and then is changed by Step up to MaxBECount times. Signed
6981// argument defines if we treat Step as signed or unsigned.
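// For example, with Step = 2, MaxBECount = 10 and StartRange = [5, 8), the
// expression can grow by at most 2 * 10 = 20, giving the range [5, 28).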
6982 static ConstantRange getRangeForAffineARHelper(APInt Step,
6983 const ConstantRange &StartRange,
6984 const APInt &MaxBECount,
6985 bool Signed) {
6986 unsigned BitWidth = Step.getBitWidth();
6987 assert(BitWidth == StartRange.getBitWidth() &&
6988 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6989 // If either Step or MaxBECount is 0, then the expression won't change, and we
6990 // just need to return the initial range.
6991 if (Step == 0 || MaxBECount == 0)
6992 return StartRange;
6993
6994 // If we don't know anything about the initial value (i.e. StartRange is
6995 // FullRange), then we don't know anything about the final range either.
6996 // Return FullRange.
6997 if (StartRange.isFullSet())
6998 return ConstantRange::getFull(BitWidth);
6999
7000 // If Step is signed and negative, then we use its absolute value, but we also
7001 // note that we're moving in the opposite direction.
7002 bool Descending = Signed && Step.isNegative();
7003
7004 if (Signed)
7005 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7006 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7007 // These equations hold true due to the well-defined wrap-around behavior of
7008 // APInt.
7009 Step = Step.abs();
7010
7011 // Check if Offset is more than full span of BitWidth. If it is, the
7012 // expression is guaranteed to overflow.
7013 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7014 return ConstantRange::getFull(BitWidth);
7015
7016 // Offset is by how much the expression can change. Checks above guarantee no
7017 // overflow here.
7018 APInt Offset = Step * MaxBECount;
7019
7020 // Minimum value of the final range will match the minimal value of StartRange
7021 // if the expression is increasing and will be decreased by Offset otherwise.
7022 // Maximum value of the final range will match the maximal value of StartRange
7023 // if the expression is decreasing and will be increased by Offset otherwise.
7024 APInt StartLower = StartRange.getLower();
7025 APInt StartUpper = StartRange.getUpper() - 1;
7026 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7027 : (StartUpper + std::move(Offset));
7028
7029 // It's possible that the new minimum/maximum value will fall into the initial
7030 // range (due to wrap around). This means that the expression can take any
7031 // value in this bitwidth, and we have to return full range.
7032 if (StartRange.contains(MovedBoundary))
7033 return ConstantRange::getFull(BitWidth);
7034
7035 APInt NewLower =
7036 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7037 APInt NewUpper =
7038 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7039 NewUpper += 1;
7040
7041 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7042 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7043}
7044
7045ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7046 const SCEV *Step,
7047 const APInt &MaxBECount) {
7048 assert(getTypeSizeInBits(Start->getType()) ==
7049 getTypeSizeInBits(Step->getType()) &&
7050 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7051 "mismatched bit widths");
7052
7053 // First, consider step signed.
7054 ConstantRange StartSRange = getSignedRange(Start);
7055 ConstantRange StepSRange = getSignedRange(Step);
7056
7057 // If Step can be both positive and negative, we need to find ranges for the
7058 // maximum absolute step values in both directions and union them.
7059 ConstantRange SR = getRangeForAffineARHelper(
7060 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7061 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7062 StartSRange, MaxBECount,
7063 /* Signed = */ true));
7064
7065 // Next, consider step unsigned.
7066 ConstantRange UR = getRangeForAffineARHelper(
7067 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7068 /* Signed = */ false);
7069
7070 // Finally, intersect signed and unsigned ranges.
7071 return SR.intersectWith(UR, ConstantRange::Smallest);
7072}
7073
7074ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7075 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7076 ScalarEvolution::RangeSignHint SignHint) {
7077 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7078 assert(AddRec->hasNoSelfWrap() &&
7079 "This only works for non-self-wrapping AddRecs!");
7080 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7081 const SCEV *Step = AddRec->getStepRecurrence(*this);
7082 // Only deal with constant step to save compile time.
7083 if (!isa<SCEVConstant>(Step))
7084 return ConstantRange::getFull(BitWidth);
7085 // Let's make sure that we can prove that we do not self-wrap during
7086 // MaxBECount iterations. We need this because MaxBECount is a maximum
7087 // iteration count estimate, and we might infer nw from some exit for which we
7088 // do not know max exit count (or any other side reasoning).
7089 // TODO: Turn into assert at some point.
7090 if (getTypeSizeInBits(MaxBECount->getType()) >
7091 getTypeSizeInBits(AddRec->getType()))
7092 return ConstantRange::getFull(BitWidth);
7093 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7094 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7095 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7096 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7097 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7098 MaxItersWithoutWrap))
7099 return ConstantRange::getFull(BitWidth);
7100
7101 ICmpInst::Predicate LEPred =
7102 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7103 ICmpInst::Predicate GEPred =
7104 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7105 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7106
7107 // We know that there is no self-wrap. Let's take Start and End values and
7108 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7109 // the iteration. They either lie inside the range [Min(Start, End),
7110 // Max(Start, End)] or outside it:
7111 //
7112 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7113 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7114 //
7115 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7116 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7117 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7118 // Start <= End and step is positive, or Start >= End and step is negative.
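// For example, for a non-self-wrapping {10,+,2} with at most 5 iterations we
// have Start = 10 and End = 20; the step is positive and 10 <= 20 (Case 1),
// so every intermediate value lies in [10, 20].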
7119 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7120 ConstantRange StartRange = getRangeRef(Start, SignHint);
7121 ConstantRange EndRange = getRangeRef(End, SignHint);
7122 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7123 // If they already cover full iteration space, we will know nothing useful
7124 // even if we prove what we want to prove.
7125 if (RangeBetween.isFullSet())
7126 return RangeBetween;
7127 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7128 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7129 : RangeBetween.isWrappedSet();
7130 if (IsWrappedSet)
7131 return ConstantRange::getFull(BitWidth);
7132
7133 if (isKnownPositive(Step) &&
7134 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7135 return RangeBetween;
7136 if (isKnownNegative(Step) &&
7137 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7138 return RangeBetween;
7139 return ConstantRange::getFull(BitWidth);
7140}
7141
7142ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7143 const SCEV *Step,
7144 const APInt &MaxBECount) {
7145 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7146 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
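// For example, Start = (%c ? 0 : 10) and Step = (%c ? 1 : 2) with MaxBECount = 3
// is handled as the union of RangeOf({0,+,1}) = [0, 4) and RangeOf({10,+,2}) = [10, 17).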
7147
7148 unsigned BitWidth = MaxBECount.getBitWidth();
7149 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7150 getTypeSizeInBits(Step->getType()) == BitWidth &&
7151 "mismatched bit widths");
7152
7153 struct SelectPattern {
7154 Value *Condition = nullptr;
7155 APInt TrueValue;
7156 APInt FalseValue;
7157
7158 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7159 const SCEV *S) {
7160 std::optional<unsigned> CastOp;
7161 APInt Offset(BitWidth, 0);
7162
7164 "Should be!");
7165
7166 // Peel off a constant offset:
7167 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7168 // In the future we could consider being smarter here and handle
7169 // {Start+Step,+,Step} too.
7170 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7171 return;
7172
7173 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7174 S = SA->getOperand(1);
7175 }
7176
7177 // Peel off a cast operation
7178 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7179 CastOp = SCast->getSCEVType();
7180 S = SCast->getOperand();
7181 }
7182
7183 using namespace llvm::PatternMatch;
7184
7185 auto *SU = dyn_cast<SCEVUnknown>(S);
7186 const APInt *TrueVal, *FalseVal;
7187 if (!SU ||
7188 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7189 m_APInt(FalseVal)))) {
7190 Condition = nullptr;
7191 return;
7192 }
7193
7194 TrueValue = *TrueVal;
7195 FalseValue = *FalseVal;
7196
7197 // Re-apply the cast we peeled off earlier
7198 if (CastOp)
7199 switch (*CastOp) {
7200 default:
7201 llvm_unreachable("Unknown SCEV cast type!");
7202
7203 case scTruncate:
7204 TrueValue = TrueValue.trunc(BitWidth);
7205 FalseValue = FalseValue.trunc(BitWidth);
7206 break;
7207 case scZeroExtend:
7208 TrueValue = TrueValue.zext(BitWidth);
7209 FalseValue = FalseValue.zext(BitWidth);
7210 break;
7211 case scSignExtend:
7212 TrueValue = TrueValue.sext(BitWidth);
7213 FalseValue = FalseValue.sext(BitWidth);
7214 break;
7215 }
7216
7217 // Re-apply the constant offset we peeled off earlier
7218 TrueValue += Offset;
7219 FalseValue += Offset;
7220 }
7221
7222 bool isRecognized() { return Condition != nullptr; }
7223 };
7224
7225 SelectPattern StartPattern(*this, BitWidth, Start);
7226 if (!StartPattern.isRecognized())
7227 return ConstantRange::getFull(BitWidth);
7228
7229 SelectPattern StepPattern(*this, BitWidth, Step);
7230 if (!StepPattern.isRecognized())
7231 return ConstantRange::getFull(BitWidth);
7232
7233 if (StartPattern.Condition != StepPattern.Condition) {
7234 // We don't handle this case today; but we could, by considering four
7235 // possibilities below instead of two. I'm not sure if there are cases where
7236 // that will help over what getRange already does, though.
7237 return ConstantRange::getFull(BitWidth);
7238 }
7239
7240 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7241 // construct arbitrary general SCEV expressions here. This function is called
7242 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7243 // say) can end up caching a suboptimal value.
7244
7245 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7246 // C2352 and C2512 (otherwise it isn't needed).
7247
7248 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7249 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7250 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7251 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7252
7253 ConstantRange TrueRange =
7254 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7255 ConstantRange FalseRange =
7256 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7257
7258 return TrueRange.unionWith(FalseRange);
7259}
7260
7261SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7262 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7263 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7264
7265 // Return early if there are no flags to propagate to the SCEV.
7266 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7267 if (BinOp->hasNoUnsignedWrap())
7268 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7269 if (BinOp->hasNoSignedWrap())
7270 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7271 if (Flags == SCEV::FlagAnyWrap)
7272 return SCEV::FlagAnyWrap;
7273
7274 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7275}
7276
7277const Instruction *
7278ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7279 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7280 return &*AddRec->getLoop()->getHeader()->begin();
7281 if (auto *U = dyn_cast<SCEVUnknown>(S))
7282 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7283 return I;
7284 return nullptr;
7285}
7286
7287const Instruction *
7288ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7289 bool &Precise) {
7290 Precise = true;
7291 // Do a bounded search of the def relation of the requested SCEVs.
7292 SmallSet<const SCEV *, 16> Visited;
7293 SmallVector<const SCEV *> Worklist;
7294 auto pushOp = [&](const SCEV *S) {
7295 if (!Visited.insert(S).second)
7296 return;
7297 // Threshold of 30 here is arbitrary.
7298 if (Visited.size() > 30) {
7299 Precise = false;
7300 return;
7301 }
7302 Worklist.push_back(S);
7303 };
7304
7305 for (const auto *S : Ops)
7306 pushOp(S);
7307
7308 const Instruction *Bound = nullptr;
7309 while (!Worklist.empty()) {
7310 auto *S = Worklist.pop_back_val();
7311 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7312 if (!Bound || DT.dominates(Bound, DefI))
7313 Bound = DefI;
7314 } else {
7315 for (const auto *Op : S->operands())
7316 pushOp(Op);
7317 }
7318 }
7319 return Bound ? Bound : &*F.getEntryBlock().begin();
7320}
7321
7322const Instruction *
7323ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7324 bool Discard;
7325 return getDefiningScopeBound(Ops, Discard);
7326}
7327
7328bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7329 const Instruction *B) {
7330 if (A->getParent() == B->getParent() &&
7331 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7332 B->getIterator()))
7333 return true;
7334
7335 auto *BLoop = LI.getLoopFor(B->getParent());
7336 if (BLoop && BLoop->getHeader() == B->getParent() &&
7337 BLoop->getLoopPreheader() == A->getParent() &&
7338 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7339 A->getParent()->end()) &&
7340 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7341 B->getIterator()))
7342 return true;
7343 return false;
7344}
7345
7346bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7347 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7348 visitAll(Op, PC);
7349 return PC.MaybePoison.empty();
7350}
7351
7352bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7353 return !SCEVExprContains(Op, [this](const SCEV *S) {
7354 auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
7355 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7356 // is a non-zero constant, we have to assume the UDiv may be UB.
7357 return UDiv && (!isKnownNonZero(UDiv->getOperand(1)) ||
7358 !isGuaranteedNotToBePoison(UDiv->getOperand(1)));
7359 });
7360}
7361
7362bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7363 // Only proceed if we can prove that I does not yield poison.
7364 if (!programUndefinedIfPoison(I))
7365 return false;
7366
7367 // At this point we know that if I is executed, then it does not wrap
7368 // according to at least one of NSW or NUW. If I is not executed, then we do
7369 // not know if the calculation that I represents would wrap. Multiple
7370 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7371 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7372 // derived from other instructions that map to the same SCEV. We cannot make
7373 // that guarantee for cases where I is not executed. So we need to find an
7374 // upper bound on the defining scope for the SCEV, and prove that I is
7375 // executed every time we enter that scope. When the bounding scope is a
7376 // loop (the common case), this is equivalent to proving I executes on every
7377 // iteration of that loop.
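// For instance, `%a = add nsw i32 %x, %y` in a conditionally executed block
// and an unconditional `%b = add i32 %x, %y` map to the same SCEV; nsw may
// only be applied if the flag-bearing add executes whenever its defining
// scope is entered.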
7378 SmallVector<const SCEV *, 4> SCEVOps;
7379 for (const Use &Op : I->operands()) {
7380 // I could be an extractvalue from a call to an overflow intrinsic.
7381 // TODO: We can do better here in some cases.
7382 if (isSCEVable(Op->getType()))
7383 SCEVOps.push_back(getSCEV(Op));
7384 }
7385 auto *DefI = getDefiningScopeBound(SCEVOps);
7386 return isGuaranteedToTransferExecutionTo(DefI, I);
7387}
7388
7389bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7390 // If we know that \c I can never be poison period, then that's enough.
7391 if (isSCEVExprNeverPoison(I))
7392 return true;
7393
7394 // If the loop only has one exit, then we know that, if the loop is entered,
7395 // any instruction dominating that exit will be executed. If any such
7396 // instruction would result in UB, the addrec cannot be poison.
7397 //
7398 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7399 // also handles uses outside the loop header (they just need to dominate the
7400 // single exit).
7401
7402 auto *ExitingBB = L->getExitingBlock();
7403 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7404 return false;
7405
7406 SmallPtrSet<const Instruction *, 16> KnownPoison;
7407 SmallVector<const Instruction *, 8> Worklist;
7408
7409 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7410 // things that are known to be poison under that assumption go on the
7411 // Worklist.
7412 KnownPoison.insert(I);
7413 Worklist.push_back(I);
7414
7415 while (!Worklist.empty()) {
7416 const Instruction *Poison = Worklist.pop_back_val();
7417
7418 for (const Use &U : Poison->uses()) {
7419 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7420 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7421 DT.dominates(PoisonUser->getParent(), ExitingBB))
7422 return true;
7423
7424 if (propagatesPoison(U) && L->contains(PoisonUser))
7425 if (KnownPoison.insert(PoisonUser).second)
7426 Worklist.push_back(PoisonUser);
7427 }
7428 }
7429
7430 return false;
7431}
7432
7433ScalarEvolution::LoopProperties
7434ScalarEvolution::getLoopProperties(const Loop *L) {
7435 using LoopProperties = ScalarEvolution::LoopProperties;
7436
7437 auto Itr = LoopPropertiesCache.find(L);
7438 if (Itr == LoopPropertiesCache.end()) {
7439 auto HasSideEffects = [](Instruction *I) {
7440 if (auto *SI = dyn_cast<StoreInst>(I))
7441 return !SI->isSimple();
7442
7443 return I->mayThrow() || I->mayWriteToMemory();
7444 };
7445
7446 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7447 /*HasNoSideEffects*/ true};
7448
7449 for (auto *BB : L->getBlocks())
7450 for (auto &I : *BB) {
7451 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7452 LP.HasNoAbnormalExits = false;
7453 if (HasSideEffects(&I))
7454 LP.HasNoSideEffects = false;
7455 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7456 break; // We're already as pessimistic as we can get.
7457 }
7458
7459 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7460 assert(InsertPair.second && "We just checked!");
7461 Itr = InsertPair.first;
7462 }
7463
7464 return Itr->second;
7465}
7466
7467 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7468 // A mustprogress loop without side effects must be finite.
7469 // TODO: The check used here is very conservative. It's only *specific*
7470 // side effects which are well defined in infinite loops.
7471 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7472}
7473
7474const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7475 // Worklist item with a Value and a bool indicating whether all operands have
7476 // been visited already.
7477 using PointerTy = PointerIntPair<Value *, 1, bool>;
7478 SmallVector<PointerTy> Stack;
7479
7480 Stack.emplace_back(V, true);
7481 Stack.emplace_back(V, false);
7482 while (!Stack.empty()) {
7483 auto E = Stack.pop_back_val();
7484 Value *CurV = E.getPointer();
7485
7486 if (getExistingSCEV(CurV))
7487 continue;
7488
7489 SmallVector<Value *> Ops;
7490 const SCEV *CreatedSCEV = nullptr;
7491 // If all operands have been visited already, create the SCEV.
7492 if (E.getInt()) {
7493 CreatedSCEV = createSCEV(CurV);
7494 } else {
7495 // Otherwise get the operands we need to create SCEV's for before creating
7496 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7497 // just use it.
7498 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7499 }
7500
7501 if (CreatedSCEV) {
7502 insertValueToMap(CurV, CreatedSCEV);
7503 } else {
7504 // Queue CurV for SCEV creation, followed by its operands which need to
7505 // be constructed first.
7506 Stack.emplace_back(CurV, true);
7507 for (Value *Op : Ops)
7508 Stack.emplace_back(Op, false);
7509 }
7510 }
7511
7512 return getExistingSCEV(V);
7513}
7514
7515const SCEV *
7516ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7517 if (!isSCEVable(V->getType()))
7518 return getUnknown(V);
7519
7520 if (Instruction *I = dyn_cast<Instruction>(V)) {
7521 // Don't attempt to analyze instructions in blocks that aren't
7522 // reachable. Such instructions don't matter, and they aren't required
7523 // to obey basic rules for definitions dominating uses which this
7524 // analysis depends on.
7525 if (!DT.isReachableFromEntry(I->getParent()))
7526 return getUnknown(PoisonValue::get(V->getType()));
7527 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7528 return getConstant(CI);
7529 else if (isa<GlobalAlias>(V))
7530 return getUnknown(V);
7531 else if (!isa<ConstantExpr>(V))
7532 return getUnknown(V);
7533
7534 Operator *U = cast<Operator>(V);
7535 if (auto BO =
7536 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7537 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7538 switch (BO->Opcode) {
7539 case Instruction::Add:
7540 case Instruction::Mul: {
7541 // For additions and multiplications, traverse add/mul chains for which we
7542 // can potentially create a single SCEV, to reduce the number of
7543 // get{Add,Mul}Expr calls.
7544 do {
7545 if (BO->Op) {
7546 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7547 Ops.push_back(BO->Op);
7548 break;
7549 }
7550 }
7551 Ops.push_back(BO->RHS);
7552 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7553 dyn_cast<Instruction>(V));
7554 if (!NewBO ||
7555 (BO->Opcode == Instruction::Add &&
7556 (NewBO->Opcode != Instruction::Add &&
7557 NewBO->Opcode != Instruction::Sub)) ||
7558 (BO->Opcode == Instruction::Mul &&
7559 NewBO->Opcode != Instruction::Mul)) {
7560 Ops.push_back(BO->LHS);
7561 break;
7562 }
7563 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7564 // requires a SCEV for the LHS.
7565 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7566 auto *I = dyn_cast<Instruction>(BO->Op);
7567 if (I && programUndefinedIfPoison(I)) {
7568 Ops.push_back(BO->LHS);
7569 break;
7570 }
7571 }
7572 BO = NewBO;
7573 } while (true);
7574 return nullptr;
7575 }
7576 case Instruction::Sub:
7577 case Instruction::UDiv:
7578 case Instruction::URem:
7579 break;
7580 case Instruction::AShr:
7581 case Instruction::Shl:
7582 case Instruction::Xor:
7583 if (!IsConstArg)
7584 return nullptr;
7585 break;
7586 case Instruction::And:
7587 case Instruction::Or:
7588 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7589 return nullptr;
7590 break;
7591 case Instruction::LShr:
7592 return getUnknown(V);
7593 default:
7594 llvm_unreachable("Unhandled binop");
7595 break;
7596 }
7597
7598 Ops.push_back(BO->LHS);
7599 Ops.push_back(BO->RHS);
7600 return nullptr;
7601 }
7602
7603 switch (U->getOpcode()) {
7604 case Instruction::Trunc:
7605 case Instruction::ZExt:
7606 case Instruction::SExt:
7607 case Instruction::PtrToInt:
7608 Ops.push_back(U->getOperand(0));
7609 return nullptr;
7610
7611 case Instruction::BitCast:
7612 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7613 Ops.push_back(U->getOperand(0));
7614 return nullptr;
7615 }
7616 return getUnknown(V);
7617
7618 case Instruction::SDiv:
7619 case Instruction::SRem:
7620 Ops.push_back(U->getOperand(0));
7621 Ops.push_back(U->getOperand(1));
7622 return nullptr;
7623
7624 case Instruction::GetElementPtr:
7625 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7626 "GEP source element type must be sized");
7627 for (Value *Index : U->operands())
7628 Ops.push_back(Index);
7629 return nullptr;
7630
7631 case Instruction::IntToPtr:
7632 return getUnknown(V);
7633
7634 case Instruction::PHI:
7635 // Keep constructing SCEVs for phis recursively for now.
7636 return nullptr;
7637
7638 case Instruction::Select: {
7639 // Check if U is a select that can be simplified to a SCEVUnknown.
7640 auto CanSimplifyToUnknown = [this, U]() {
7641 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7642 return false;
7643
7644 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7645 if (!ICI)
7646 return false;
7647 Value *LHS = ICI->getOperand(0);
7648 Value *RHS = ICI->getOperand(1);
7649 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7650 ICI->getPredicate() == CmpInst::ICMP_NE) {
7651 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7652 return true;
7653 } else if (getTypeSizeInBits(LHS->getType()) >
7654 getTypeSizeInBits(U->getType()))
7655 return true;
7656 return false;
7657 };
7658 if (CanSimplifyToUnknown())
7659 return getUnknown(U);
7660
7661 for (Value *Inc : U->operands())
7662 Ops.push_back(Inc);
7663 return nullptr;
7664 break;
7665 }
7666 case Instruction::Call:
7667 case Instruction::Invoke:
7668 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7669 Ops.push_back(RV);
7670 return nullptr;
7671 }
7672
7673 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7674 switch (II->getIntrinsicID()) {
7675 case Intrinsic::abs:
7676 Ops.push_back(II->getArgOperand(0));
7677 return nullptr;
7678 case Intrinsic::umax:
7679 case Intrinsic::umin:
7680 case Intrinsic::smax:
7681 case Intrinsic::smin:
7682 case Intrinsic::usub_sat:
7683 case Intrinsic::uadd_sat:
7684 Ops.push_back(II->getArgOperand(0));
7685 Ops.push_back(II->getArgOperand(1));
7686 return nullptr;
7687 case Intrinsic::start_loop_iterations:
7688 case Intrinsic::annotation:
7689 case Intrinsic::ptr_annotation:
7690 Ops.push_back(II->getArgOperand(0));
7691 return nullptr;
7692 default:
7693 break;
7694 }
7695 }
7696 break;
7697 }
7698
7699 return nullptr;
7700}
7701
7702const SCEV *ScalarEvolution::createSCEV(Value *V) {
7703 if (!isSCEVable(V->getType()))
7704 return getUnknown(V);
7705
7706 if (Instruction *I = dyn_cast<Instruction>(V)) {
7707 // Don't attempt to analyze instructions in blocks that aren't
7708 // reachable. Such instructions don't matter, and they aren't required
7709 // to obey basic rules for definitions dominating uses which this
7710 // analysis depends on.
7711 if (!DT.isReachableFromEntry(I->getParent()))
7712 return getUnknown(PoisonValue::get(V->getType()));
7713 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7714 return getConstant(CI);
7715 else if (isa<GlobalAlias>(V))
7716 return getUnknown(V);
7717 else if (!isa<ConstantExpr>(V))
7718 return getUnknown(V);
7719
7720 const SCEV *LHS;
7721 const SCEV *RHS;
7722
7723 Operator *U = cast<Operator>(V);
7724 if (auto BO =
7725 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7726 switch (BO->Opcode) {
7727 case Instruction::Add: {
7728 // The simple thing to do would be to just call getSCEV on both operands
7729 // and call getAddExpr with the result. However if we're looking at a
7730 // bunch of things all added together, this can be quite inefficient,
7731 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7732 // Instead, gather up all the operands and make a single getAddExpr call.
7733 // LLVM IR canonical form means we need only traverse the left operands.
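// For example, for ((%a + %b) + %c) + %d this loop gathers %d, %c, %b and
// finally %a, and then issues a single getAddExpr call for all four operands.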
7734 SmallVector<const SCEV *, 4> AddOps;
7735 do {
7736 if (BO->Op) {
7737 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7738 AddOps.push_back(OpSCEV);
7739 break;
7740 }
7741
7742 // If a NUW or NSW flag can be applied to the SCEV for this
7743 // addition, then compute the SCEV for this addition by itself
7744 // with a separate call to getAddExpr. We need to do that
7745 // instead of pushing the operands of the addition onto AddOps,
7746 // since the flags are only known to apply to this particular
7747 // addition - they may not apply to other additions that can be
7748 // formed with operands from AddOps.
7749 const SCEV *RHS = getSCEV(BO->RHS);
7750 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7751 if (Flags != SCEV::FlagAnyWrap) {
7752 const SCEV *LHS = getSCEV(BO->LHS);
7753 if (BO->Opcode == Instruction::Sub)
7754 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7755 else
7756 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7757 break;
7758 }
7759 }
7760
7761 if (BO->Opcode == Instruction::Sub)
7762 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7763 else
7764 AddOps.push_back(getSCEV(BO->RHS));
7765
7766 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7767 dyn_cast<Instruction>(V));
7768 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7769 NewBO->Opcode != Instruction::Sub)) {
7770 AddOps.push_back(getSCEV(BO->LHS));
7771 break;
7772 }
7773 BO = NewBO;
7774 } while (true);
7775
7776 return getAddExpr(AddOps);
7777 }
7778
7779 case Instruction::Mul: {
7780 SmallVector<const SCEV *, 4> MulOps;
7781 do {
7782 if (BO->Op) {
7783 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7784 MulOps.push_back(OpSCEV);
7785 break;
7786 }
7787
7788 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7789 if (Flags != SCEV::FlagAnyWrap) {
7790 LHS = getSCEV(BO->LHS);
7791 RHS = getSCEV(BO->RHS);
7792 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7793 break;
7794 }
7795 }
7796
7797 MulOps.push_back(getSCEV(BO->RHS));
7798 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7799 dyn_cast<Instruction>(V));
7800 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7801 MulOps.push_back(getSCEV(BO->LHS));
7802 break;
7803 }
7804 BO = NewBO;
7805 } while (true);
7806
7807 return getMulExpr(MulOps);
7808 }
7809 case Instruction::UDiv:
7810 LHS = getSCEV(BO->LHS);
7811 RHS = getSCEV(BO->RHS);
7812 return getUDivExpr(LHS, RHS);
7813 case Instruction::URem:
7814 LHS = getSCEV(BO->LHS);
7815 RHS = getSCEV(BO->RHS);
7816 return getURemExpr(LHS, RHS);
7817 case Instruction::Sub: {
7818 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7819 if (BO->Op)
7820 Flags = getNoWrapFlagsFromUB(BO->Op);
7821 LHS = getSCEV(BO->LHS);
7822 RHS = getSCEV(BO->RHS);
7823 return getMinusSCEV(LHS, RHS, Flags);
7824 }
7825 case Instruction::And:
7826 // For an expression like x&255 that merely masks off the high bits,
7827 // use zext(trunc(x)) as the SCEV expression.
7828 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7829 if (CI->isZero())
7830 return getSCEV(BO->RHS);
7831 if (CI->isMinusOne())
7832 return getSCEV(BO->LHS);
7833 const APInt &A = CI->getValue();
7834
7835 // Instcombine's ShrinkDemandedConstant may strip bits out of
7836 // constants, obscuring what would otherwise be a low-bits mask.
7837 // Use computeKnownBits to compute what ShrinkDemandedConstant
7838 // knew about to reconstruct a low-bits mask value.
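// For example, `%x & 0xF0` on i32 (LZ = 24, TZ = 4) is modelled as
// (zext (trunc (%x /u 16) to i4) to i32) * 16.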
7839 unsigned LZ = A.countl_zero();
7840 unsigned TZ = A.countr_zero();
7841 unsigned BitWidth = A.getBitWidth();
7842 KnownBits Known(BitWidth);
7843 computeKnownBits(BO->LHS, Known, getDataLayout(),
7844 0, &AC, nullptr, &DT);
7845
7846 APInt EffectiveMask =
7847 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7848 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7849 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7850 const SCEV *LHS = getSCEV(BO->LHS);
7851 const SCEV *ShiftedLHS = nullptr;
7852 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7853 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7854 // For an expression like (x * 8) & 8, simplify the multiply.
7855 unsigned MulZeros = OpC->getAPInt().countr_zero();
7856 unsigned GCD = std::min(MulZeros, TZ);
7857 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7858 SmallVector<const SCEV *, 4> MulOps;
7859 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7860 append_range(MulOps, LHSMul->operands().drop_front());
7861 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7862 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7863 }
7864 }
7865 if (!ShiftedLHS)
7866 ShiftedLHS = getUDivExpr(LHS, MulCount);
7867 return getMulExpr(
7868 getZeroExtendExpr(
7869 getTruncateExpr(ShiftedLHS,
7870 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7871 BO->LHS->getType()),
7872 MulCount);
7873 }
7874 }
7875 // Binary `and` is a bit-wise `umin`.
7876 if (BO->LHS->getType()->isIntegerTy(1)) {
7877 LHS = getSCEV(BO->LHS);
7878 RHS = getSCEV(BO->RHS);
7879 return getUMinExpr(LHS, RHS);
7880 }
7881 break;
7882
7883 case Instruction::Or:
7884 // Binary `or` is a bit-wise `umax`.
7885 if (BO->LHS->getType()->isIntegerTy(1)) {
7886 LHS = getSCEV(BO->LHS);
7887 RHS = getSCEV(BO->RHS);
7888 return getUMaxExpr(LHS, RHS);
7889 }
7890 break;
7891
7892 case Instruction::Xor:
7893 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7894 // If the RHS of xor is -1, then this is a not operation.
7895 if (CI->isMinusOne())
7896 return getNotSCEV(getSCEV(BO->LHS));
7897
7898 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7899 // This is a variant of the check for xor with -1, and it handles
7900 // the case where instcombine has trimmed non-demanded bits out
7901 // of an xor with -1.
7902 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7903 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7904 if (LBO->getOpcode() == Instruction::And &&
7905 LCI->getValue() == CI->getValue())
7906 if (const SCEVZeroExtendExpr *Z =
7907 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7908 Type *UTy = BO->LHS->getType();
7909 const SCEV *Z0 = Z->getOperand();
7910 Type *Z0Ty = Z0->getType();
7911 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7912
7913 // If C is a low-bits mask, the zero extend is serving to
7914 // mask off the high bits. Complement the operand and
7915 // re-apply the zext.
7916 if (CI->getValue().isMask(Z0TySize))
7917 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7918
7919 // If C is a single bit, it may be in the sign-bit position
7920 // before the zero-extend. In this case, represent the xor
7921 // using an add, which is equivalent, and re-apply the zext.
7922 APInt Trunc = CI->getValue().trunc(Z0TySize);
7923 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7924 Trunc.isSignMask())
7925 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7926 UTy);
7927 }
7928 }
7929 break;
7930
7931 case Instruction::Shl:
7932 // Turn shift left of a constant amount into a multiply.
7933 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7934 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7935
7936 // If the shift count is not less than the bitwidth, the result of
7937 // the shift is undefined. Don't try to analyze it, because the
7938 // resolution chosen here may differ from the resolution chosen in
7939 // other parts of the compiler.
7940 if (SA->getValue().uge(BitWidth))
7941 break;
7942
7943 // We can safely preserve the nuw flag in all cases. It's also safe to
7944 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7945 // requires special handling. It can be preserved as long as we're not
7946 // left shifting by bitwidth - 1.
7947 auto Flags = SCEV::FlagAnyWrap;
7948 if (BO->Op) {
7949 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7950 if ((MulFlags & SCEV::FlagNSW) &&
7951 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7952 Flags = setFlags(Flags, SCEV::FlagNSW);
7953 if (MulFlags & SCEV::FlagNUW)
7954 Flags = setFlags(Flags, SCEV::FlagNUW);
7955 }
7956
7957 ConstantInt *X = ConstantInt::get(
7958 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7959 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7960 }
7961 break;
7962
7963 case Instruction::AShr:
7964 // AShr X, C, where C is a constant.
7965 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7966 if (!CI)
7967 break;
7968
7969 Type *OuterTy = BO->LHS->getType();
7970 unsigned BitWidth = getTypeSizeInBits(OuterTy);
7971 // If the shift count is not less than the bitwidth, the result of
7972 // the shift is undefined. Don't try to analyze it, because the
7973 // resolution chosen here may differ from the resolution chosen in
7974 // other parts of the compiler.
7975 if (CI->getValue().uge(BitWidth))
7976 break;
7977
7978 if (CI->isZero())
7979 return getSCEV(BO->LHS); // shift by zero --> noop
7980
7981 uint64_t AShrAmt = CI->getZExtValue();
7982 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7983
7984 Operator *L = dyn_cast<Operator>(BO->LHS);
7985 const SCEV *AddTruncateExpr = nullptr;
7986 ConstantInt *ShlAmtCI = nullptr;
7987 const SCEV *AddConstant = nullptr;
7988
7989 if (L && L->getOpcode() == Instruction::Add) {
7990 // X = Shl A, n
7991 // Y = Add X, c
7992 // Z = AShr Y, m
7993 // n, c and m are constants.
7994
7995 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7996 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7997 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7998 if (AddOperandCI) {
7999 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8000 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8001 // Since we truncate to TruncTy, the AddConstant should be of the
8002 // same type, so create a new Constant with the same type as TruncTy.
8003 // Also, the Add constant should be shifted right by the AShr amount.
8004 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8005 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8006 // We model the expression as sext(add(trunc(A), c << n)). Since the
8007 // sext(trunc) part is already handled below, we create an
8008 // AddExpr(TruncExp) which will be used later.
8009 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8010 }
8011 }
8012 } else if (L && L->getOpcode() == Instruction::Shl) {
8013 // X = Shl A, n
8014 // Y = AShr X, m
8015 // Both n and m are constant.
8016
8017 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8018 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8019 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8020 }
8021
8022 if (AddTruncateExpr && ShlAmtCI) {
8023 // We can merge the two given cases into a single SCEV statement,
8024 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8025 // a simpler case. The following code handles the two cases:
8026 //
8027 // 1) For a two-shift sext-inreg, i.e. n = m,
8028 // use sext(trunc(x)) as the SCEV expression.
8029 //
8030 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8031 // expression. We already checked that ShlAmt < BitWidth, so
8032 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8033 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
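// For example, with i32 operands, n = 8 and m = 4 yield
// sext(mul(trunc(x) to i28, 2^4)) extended back to i32.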
8034 const APInt &ShlAmt = ShlAmtCI->getValue();
8035 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8036 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8037 ShlAmtCI->getZExtValue() - AShrAmt);
8038 const SCEV *CompositeExpr =
8039 getMulExpr(AddTruncateExpr, getConstant(Mul));
8040 if (L->getOpcode() != Instruction::Shl)
8041 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8042
8043 return getSignExtendExpr(CompositeExpr, OuterTy);
8044 }
8045 }
8046 break;
8047 }
8048 }
8049
8050 switch (U->getOpcode()) {
8051 case Instruction::Trunc:
8052 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8053
8054 case Instruction::ZExt:
8055 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8056
8057 case Instruction::SExt:
8058 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8059 dyn_cast<Instruction>(V))) {
8060 // The NSW flag of a subtract does not always survive the conversion to
8061 // A + (-1)*B. By pushing sign extension onto its operands we are much
8062 // more likely to preserve NSW and allow later AddRec optimisations.
8063 //
8064 // NOTE: This is effectively duplicating this logic from getSignExtend:
8065 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8066 // but by that point the NSW information has potentially been lost.
8067 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8068 Type *Ty = U->getType();
8069 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8070 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8071 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8072 }
8073 }
8074 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8075
8076 case Instruction::BitCast:
8077 // BitCasts are no-op casts so we just eliminate the cast.
8078 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8079 return getSCEV(U->getOperand(0));
8080 break;
8081
8082 case Instruction::PtrToInt: {
8083 // Pointer to integer cast is straight-forward, so do model it.
8084 const SCEV *Op = getSCEV(U->getOperand(0));
8085 Type *DstIntTy = U->getType();
8086 // But only if effective SCEV (integer) type is wide enough to represent
8087 // all possible pointer values.
8088 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8089 if (isa<SCEVCouldNotCompute>(IntOp))
8090 return getUnknown(V);
8091 return IntOp;
8092 }
8093 case Instruction::IntToPtr:
8094 // Just don't deal with inttoptr casts.
8095 return getUnknown(V);
8096
8097 case Instruction::SDiv:
8098 // If both operands are non-negative, this is just an udiv.
8099 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8100 isKnownNonNegative(getSCEV(U->getOperand(1))))
8101 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8102 break;
8103
8104 case Instruction::SRem:
8105 // If both operands are non-negative, this is just an urem.
8106 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8107 isKnownNonNegative(getSCEV(U->getOperand(1))))
8108 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8109 break;
8110
8111 case Instruction::GetElementPtr:
8112 return createNodeForGEP(cast<GEPOperator>(U));
8113
8114 case Instruction::PHI:
8115 return createNodeForPHI(cast<PHINode>(U));
8116
8117 case Instruction::Select:
8118 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8119 U->getOperand(2));
8120
8121 case Instruction::Call:
8122 case Instruction::Invoke:
8123 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8124 return getSCEV(RV);
8125
8126 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8127 switch (II->getIntrinsicID()) {
8128 case Intrinsic::abs:
8129 return getAbsExpr(
8130 getSCEV(II->getArgOperand(0)),
8131 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8132 case Intrinsic::umax:
8133 LHS = getSCEV(II->getArgOperand(0));
8134 RHS = getSCEV(II->getArgOperand(1));
8135 return getUMaxExpr(LHS, RHS);
8136 case Intrinsic::umin:
8137 LHS = getSCEV(II->getArgOperand(0));
8138 RHS = getSCEV(II->getArgOperand(1));
8139 return getUMinExpr(LHS, RHS);
8140 case Intrinsic::smax:
8141 LHS = getSCEV(II->getArgOperand(0));
8142 RHS = getSCEV(II->getArgOperand(1));
8143 return getSMaxExpr(LHS, RHS);
8144 case Intrinsic::smin:
8145 LHS = getSCEV(II->getArgOperand(0));
8146 RHS = getSCEV(II->getArgOperand(1));
8147 return getSMinExpr(LHS, RHS);
8148 case Intrinsic::usub_sat: {
8149 const SCEV *X = getSCEV(II->getArgOperand(0));
8150 const SCEV *Y = getSCEV(II->getArgOperand(1));
8151 const SCEV *ClampedY = getUMinExpr(X, Y);
8152 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8153 }
8154 case Intrinsic::uadd_sat: {
8155 const SCEV *X = getSCEV(II->getArgOperand(0));
8156 const SCEV *Y = getSCEV(II->getArgOperand(1));
8157 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8158 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8159 }
8160 case Intrinsic::start_loop_iterations:
8161 case Intrinsic::annotation:
8162 case Intrinsic::ptr_annotation:
8163 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8164 // just equivalent to the first operand for SCEV purposes.
8165 return getSCEV(II->getArgOperand(0));
8166 case Intrinsic::vscale:
8167 return getVScale(II->getType());
8168 default:
8169 break;
8170 }
8171 }
8172 break;
8173 }
8174
8175 return getUnknown(V);
8176}
8177
8178//===----------------------------------------------------------------------===//
8179// Iteration Count Computation Code
8180//
8181
8182 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8183 if (isa<SCEVCouldNotCompute>(ExitCount))
8184 return getCouldNotCompute();
8185
8186 auto *ExitCountType = ExitCount->getType();
8187 assert(ExitCountType->isIntegerTy());
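// Evaluate in a type one bit wider than the exit count so that adding one
// to the largest possible exit count cannot overflow.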
8188 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8189 1 + ExitCountType->getScalarSizeInBits());
8190 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8191}
8192
8193 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8194 Type *EvalTy,
8195 const Loop *L) {
8196 if (isa<SCEVCouldNotCompute>(ExitCount))
8197 return getCouldNotCompute();
8198
8199 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8200 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8201
8202 auto CanAddOneWithoutOverflow = [&]() {
8203 ConstantRange ExitCountRange =
8204 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8205 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8206 return true;
8207
8208 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8209 getMinusOne(ExitCount->getType()));
8210 };
8211
8212 // If we need to zero extend the backedge count, check if we can add one to
8213 // it prior to zero extending without overflow. Provided this is safe, it
8214 // allows better simplification of the +1.
8215 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8216 return getZeroExtendExpr(
8217 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8218
8219 // Get the total trip count from the count by adding 1. This may wrap.
8220 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8221}
8222
8223static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8224 if (!ExitCount)
8225 return 0;
8226
8227 ConstantInt *ExitConst = ExitCount->getValue();
8228
8229 // Guard against huge trip counts.
8230 if (ExitConst->getValue().getActiveBits() > 32)
8231 return 0;
8232
8233 // In case of integer overflow, this returns 0, which is correct.
8234 return ((unsigned)ExitConst->getZExtValue()) + 1;
8235}
8236
8238 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8239 return getConstantTripCount(ExitCount);
8240}
8241
8242unsigned
8244 const BasicBlock *ExitingBlock) {
8245 assert(ExitingBlock && "Must pass a non-null exiting block!");
8246 assert(L->isLoopExiting(ExitingBlock) &&
8247 "Exiting block must actually branch out of the loop!");
8248 const SCEVConstant *ExitCount =
8249 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8250 return getConstantTripCount(ExitCount);
8251}
8252
8254 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8255
8256 const auto *MaxExitCount =
8257 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8259 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8260}
8261
8263 SmallVector<BasicBlock *, 8> ExitingBlocks;
8264 L->getExitingBlocks(ExitingBlocks);
8265
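// The overall trip multiple is the GCD of the per-exit trip multiples; for
// example, exits with multiples 4 and 6 give an overall multiple of 2.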
8266 std::optional<unsigned> Res;
8267 for (auto *ExitingBB : ExitingBlocks) {
8268 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8269 if (!Res)
8270 Res = Multiple;
8271 Res = (unsigned)std::gcd(*Res, Multiple);
8272 }
8273 return Res.value_or(1);
8274}
8275
8277 const SCEV *ExitCount) {
8278 if (ExitCount == getCouldNotCompute())
8279 return 1;
8280
8281 // Get the trip count
8282 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8283
8284 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8285 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8286 // the greatest power of 2 divisor less than 2^32.
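// For example, a constant trip multiple of 3 * 2^40 is reported as 2^31.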
8287 return Multiple.getActiveBits() > 32
8288 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8289 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8290}
8291
8292/// Returns the largest constant divisor of the trip count of this loop as a
8293/// normal unsigned value, if possible. This means that the actual trip count is
8294/// always a multiple of the returned value (don't forget the trip count could
8295/// very well be zero as well!).
8296///
8297 /// Returns 1 if the trip count is unknown or not guaranteed to be a
8298 /// multiple of a constant (which is also the case if the trip count is simply
8299 /// a constant; use getSmallConstantTripCount for that case). It will also
8300 /// return 1 if the trip count is very large (>= 2^32).
8301///
8302/// As explained in the comments for getSmallConstantTripCount, this assumes
8303/// that control exits the loop via ExitingBlock.
8304unsigned
8306 const BasicBlock *ExitingBlock) {
8307 assert(ExitingBlock && "Must pass a non-null exiting block!");
8308 assert(L->isLoopExiting(ExitingBlock) &&
8309 "Exiting block must actually branch out of the loop!");
8310 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8311 return getSmallConstantTripMultiple(L, ExitCount);
8312}
8313
8315 const BasicBlock *ExitingBlock,
8316 ExitCountKind Kind) {
8317 switch (Kind) {
8318 case Exact:
8319 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8320 case SymbolicMaximum:
8321 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8322 case ConstantMaximum:
8323 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8324 };
8325 llvm_unreachable("Invalid ExitCountKind!");
8326}
8327
8329 const Loop *L, const BasicBlock *ExitingBlock,
8331 switch (Kind) {
8332 case Exact:
8333 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8334 Predicates);
8335 case SymbolicMaximum:
8336 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8337 Predicates);
8338 case ConstantMaximum:
8339 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8340 Predicates);
8341 };
8342 llvm_unreachable("Invalid ExitCountKind!");
8343}
8344
8347 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8348}
8349
8351 ExitCountKind Kind) {
8352 switch (Kind) {
8353 case Exact:
8354 return getBackedgeTakenInfo(L).getExact(L, this);
8355 case ConstantMaximum:
8356 return getBackedgeTakenInfo(L).getConstantMax(this);
8357 case SymbolicMaximum:
8358 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8359 };
8360 llvm_unreachable("Invalid ExitCountKind!");
8361}
8362
8365 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8366}
8367
8370 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8371}
8372
8374 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8375}
8376
8377/// Push PHI nodes in the header of the given loop onto the given Worklist.
8378static void PushLoopPHIs(const Loop *L,
8381 BasicBlock *Header = L->getHeader();
8382
8383 // Push all Loop-header PHIs onto the Worklist stack.
8384 for (PHINode &PN : Header->phis())
8385 if (Visited.insert(&PN).second)
8386 Worklist.push_back(&PN);
8387}
8388
8389ScalarEvolution::BackedgeTakenInfo &
8390ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8391 auto &BTI = getBackedgeTakenInfo(L);
8392 if (BTI.hasFullInfo())
8393 return BTI;
8394
8395 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8396
8397 if (!Pair.second)
8398 return Pair.first->second;
8399
8400 BackedgeTakenInfo Result =
8401 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8402
8403 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8404}
8405
8406ScalarEvolution::BackedgeTakenInfo &
8407ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8408 // Initially insert an invalid entry for this loop. If the insertion
8409 // succeeds, proceed to actually compute a backedge-taken count and
8410 // update the value. The temporary CouldNotCompute value tells SCEV
8411 // code elsewhere that it shouldn't attempt to request a new
8412 // backedge-taken count, which could result in infinite recursion.
8413 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8414 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8415 if (!Pair.second)
8416 return Pair.first->second;
8417
8418 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8419 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8420 // must be cleared in this scope.
8421 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8422
8423 // Now that we know more about the trip count for this loop, forget any
8424 // existing SCEV values for PHI nodes in this loop since they are only
8425 // conservative estimates made without the benefit of trip count
8426 // information. This invalidation is not necessary for correctness, and is
8427 // only done to produce more precise results.
8428 if (Result.hasAnyInfo()) {
8429 // Invalidate any expression using an addrec in this loop.
8430 SmallVector<const SCEV *, 8> ToForget;
8431 auto LoopUsersIt = LoopUsers.find(L);
8432 if (LoopUsersIt != LoopUsers.end())
8433 append_range(ToForget, LoopUsersIt->second);
8434 forgetMemoizedResults(ToForget);
8435
8436 // Invalidate constant-evolved loop header phis.
8437 for (PHINode &PN : L->getHeader()->phis())
8438 ConstantEvolutionLoopExitValue.erase(&PN);
8439 }
8440
8441 // Re-lookup the insert position, since the call to
8442 // computeBackedgeTakenCount above could result in a
8443 // recursive call to getBackedgeTakenInfo (on a different
8444 // loop), which would invalidate the iterator computed
8445 // earlier.
8446 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8447}
8448
8450 // This method is intended to forget all info about loops. It should
8451 // invalidate caches as if the following happened:
8452 // - The trip counts of all loops have changed arbitrarily
8453 // - Every llvm::Value has been updated in place to produce a different
8454 // result.
8455 BackedgeTakenCounts.clear();
8456 PredicatedBackedgeTakenCounts.clear();
8457 BECountUsers.clear();
8458 LoopPropertiesCache.clear();
8459 ConstantEvolutionLoopExitValue.clear();
8460 ValueExprMap.clear();
8461 ValuesAtScopes.clear();
8462 ValuesAtScopesUsers.clear();
8463 LoopDispositions.clear();
8464 BlockDispositions.clear();
8465 UnsignedRanges.clear();
8466 SignedRanges.clear();
8467 ExprValueMap.clear();
8468 HasRecMap.clear();
8469 ConstantMultipleCache.clear();
8470 PredicatedSCEVRewrites.clear();
8471 FoldCache.clear();
8472 FoldCacheUser.clear();
8473}
8474void ScalarEvolution::visitAndClearUsers(
8478 while (!Worklist.empty()) {
8479 Instruction *I = Worklist.pop_back_val();
8480 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8481 continue;
8482
8484 ValueExprMap.find_as(static_cast<Value *>(I));
8485 if (It != ValueExprMap.end()) {
8486 eraseValueFromMap(It->first);
8487 ToForget.push_back(It->second);
8488 if (PHINode *PN = dyn_cast<PHINode>(I))
8489 ConstantEvolutionLoopExitValue.erase(PN);
8490 }
8491
8492 PushDefUseChildren(I, Worklist, Visited);
8493 }
8494}
8495
8497 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8501
8502 // Iterate over all the loops and sub-loops to drop SCEV information.
8503 while (!LoopWorklist.empty()) {
8504 auto *CurrL = LoopWorklist.pop_back_val();
8505
8506 // Drop any stored trip count value.
8507 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8508 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8509
8510 // Drop information about predicated SCEV rewrites for this loop.
8511 for (auto I = PredicatedSCEVRewrites.begin();
8512 I != PredicatedSCEVRewrites.end();) {
8513 std::pair<const SCEV *, const Loop *> Entry = I->first;
8514 if (Entry.second == CurrL)
8515 PredicatedSCEVRewrites.erase(I++);
8516 else
8517 ++I;
8518 }
8519
8520 auto LoopUsersItr = LoopUsers.find(CurrL);
8521 if (LoopUsersItr != LoopUsers.end()) {
8522 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8523 LoopUsersItr->second.end());
8524 }
8525
8526 // Drop information about expressions based on loop-header PHIs.
8527 PushLoopPHIs(CurrL, Worklist, Visited);
8528 visitAndClearUsers(Worklist, Visited, ToForget);
8529
8530 LoopPropertiesCache.erase(CurrL);
8531 // Forget all contained loops too, to avoid dangling entries in the
8532 // ValuesAtScopes map.
8533 LoopWorklist.append(CurrL->begin(), CurrL->end());
8534 }
8535 forgetMemoizedResults(ToForget);
8536}
8537
8539 forgetLoop(L->getOutermostLoop());
8540}
8541
8543 Instruction *I = dyn_cast<Instruction>(V);
8544 if (!I) return;
8545
8546 // Drop information about expressions based on loop-header PHIs.
8550 Worklist.push_back(I);
8551 Visited.insert(I);
8552 visitAndClearUsers(Worklist, Visited, ToForget);
8553
8554 forgetMemoizedResults(ToForget);
8555}
8556
8558 if (!isSCEVable(V->getType()))
8559 return;
8560
8561 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8562 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8563 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8564 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8565 if (const SCEV *S = getExistingSCEV(V)) {
8566 struct InvalidationRootCollector {
8567 Loop *L;
8569
8570 InvalidationRootCollector(Loop *L) : L(L) {}
8571
8572 bool follow(const SCEV *S) {
8573 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8574 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8575 if (L->contains(I))
8576 Roots.push_back(S);
8577 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8578 if (L->contains(AddRec->getLoop()))
8579 Roots.push_back(S);
8580 }
8581 return true;
8582 }
8583 bool isDone() const { return false; }
8584 };
8585
8586 InvalidationRootCollector C(L);
8587 visitAll(S, C);
8588 forgetMemoizedResults(C.Roots);
8589 }
8590
8591 // Also perform the normal invalidation.
8592 forgetValue(V);
8593}
8594
8595void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8596
8598 // Unless a specific value is passed to invalidation, completely clear both
8599 // caches.
8600 if (!V) {
8601 BlockDispositions.clear();
8602 LoopDispositions.clear();
8603 return;
8604 }
8605
8606 if (!isSCEVable(V->getType()))
8607 return;
8608
8609 const SCEV *S = getExistingSCEV(V);
8610 if (!S)
8611 return;
8612
8613 // Invalidate the block and loop dispositions cached for S. Dispositions of
8614 // S's users may change if S's disposition changes (i.e. a user may change to
8615 // loop-invariant, if S changes to loop invariant), so also invalidate
8616 // dispositions of S's users recursively.
8617 SmallVector<const SCEV *, 8> Worklist = {S};
8618 SmallPtrSet<const SCEV *, 8> Seen = {S};
8619 while (!Worklist.empty()) {
8620 const SCEV *Curr = Worklist.pop_back_val();
8621 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8622 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8623 if (!LoopDispoRemoved && !BlockDispoRemoved)
8624 continue;
8625 auto Users = SCEVUsers.find(Curr);
8626 if (Users != SCEVUsers.end())
8627 for (const auto *User : Users->second)
8628 if (Seen.insert(User).second)
8629 Worklist.push_back(User);
8630 }
8631}
8632
8633/// Get the exact loop backedge taken count considering all loop exits. A
8634/// computable result can only be returned for loops with all exiting blocks
8635/// dominating the latch. howFarToZero assumes that the limit of each loop test
8636/// is never skipped. This is a valid assumption as long as the loop exits via
8637/// that test. For precise results, it is the caller's responsibility to specify
8638/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8639const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8640 const Loop *L, ScalarEvolution *SE,
8642 // If any exits were not computable, the loop is not computable.
8643 if (!isComplete() || ExitNotTaken.empty())
8644 return SE->getCouldNotCompute();
8645
8646 const BasicBlock *Latch = L->getLoopLatch();
8647 // All exiting blocks we have collected must dominate the only backedge.
8648 if (!Latch)
8649 return SE->getCouldNotCompute();
8650
8651 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8652 // count is simply a minimum out of all these calculated exit counts.
8653 SmallVector<const SCEV *, 2> Ops;
8654 for (const auto &ENT : ExitNotTaken) {
8655 const SCEV *BECount = ENT.ExactNotTaken;
8656 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8657 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8658 "We should only have known counts for exiting blocks that dominate "
8659 "latch!");
8660
8661 Ops.push_back(BECount);
8662
8663 if (Preds)
8664 append_range(*Preds, ENT.Predicates);
8665
8666 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8667 "Predicate should be always true!");
8668 }
8669
8670 // If an earlier exit exits on the first iteration (exit count zero), then
8671 // a later poison exit count should not propagate into the result. These are
8672 // exactly the semantics provided by umin_seq.
8673 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8674}
8675
8676const ScalarEvolution::ExitNotTakenInfo *
8677ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8678 const BasicBlock *ExitingBlock,
8679 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8680 for (const auto &ENT : ExitNotTaken)
8681 if (ENT.ExitingBlock == ExitingBlock) {
8682 if (ENT.hasAlwaysTruePredicate())
8683 return &ENT;
8684 else if (Predicates) {
8685 append_range(*Predicates, ENT.Predicates);
8686 return &ENT;
8687 }
8688 }
8689
8690 return nullptr;
8691}
8692
8693/// getConstantMax - Get the constant max backedge taken count for the loop.
8694const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8695 ScalarEvolution *SE,
8696 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8697 if (!getConstantMax())
8698 return SE->getCouldNotCompute();
8699
8700 for (const auto &ENT : ExitNotTaken)
8701 if (!ENT.hasAlwaysTruePredicate()) {
8702 if (!Predicates)
8703 return SE->getCouldNotCompute();
8704 append_range(*Predicates, ENT.Predicates);
8705 }
8706
8707 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8708 isa<SCEVConstant>(getConstantMax())) &&
8709 "No point in having a non-constant max backedge taken count!");
8710 return getConstantMax();
8711}
8712
8713const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8714 const Loop *L, ScalarEvolution *SE,
8716 if (!SymbolicMax) {
8717 // Form an expression for the maximum exit count possible for this loop. We
8718 // merge the max and exact information to approximate a version of
8719 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8720 // constants.
8721 SmallVector<const SCEV *, 4> ExitCounts;
8722
8723 for (const auto &ENT : ExitNotTaken) {
8724 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8725 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8726 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8727 "We should only have known counts for exiting blocks that "
8728 "dominate latch!");
8729 ExitCounts.push_back(ExitCount);
8730 if (Predicates)
8731 append_range(*Predicates, ENT.Predicates);
8732
8733 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8734 "Predicate should be always true!");
8735 }
8736 }
8737 if (ExitCounts.empty())
8738 SymbolicMax = SE->getCouldNotCompute();
8739 else
8740 SymbolicMax =
8741 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8742 }
8743 return SymbolicMax;
8744}
8745
8746bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8747 ScalarEvolution *SE) const {
8748 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8749 return !ENT.hasAlwaysTruePredicate();
8750 };
8751 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8752}
8753
8755 : ExitLimit(E, E, E, false) {}
8756
8758 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8759 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8761 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8762 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8763 // If we prove the max count is zero, so is the symbolic bound. This happens
8764 // in practice due to differences in a) how context sensitive we've chosen
8765 // to be and b) how we reason about bounds implied by UB.
8766 if (ConstantMaxNotTaken->isZero()) {
8768 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8769 }
8770
8771 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8772 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8773 "Exact is not allowed to be less precise than Constant Max");
8774 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8775 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8776 "Exact is not allowed to be less precise than Symbolic Max");
8777 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8778 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8779 "Symbolic Max is not allowed to be less precise than Constant Max");
8780 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8781 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8782 "No point in having a non-constant max backedge taken count!");
8784 for (const auto PredList : PredLists)
8785 for (const auto *P : PredList) {
8786 if (SeenPreds.contains(P))
8787 continue;
8788 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8789 SeenPreds.insert(P);
8790 Predicates.push_back(P);
8791 }
8792 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8793 "Backedge count should be int");
8794 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8796 "Max backedge count should be int");
8797}
8798
8800 const SCEV *ConstantMaxNotTaken,
8801 const SCEV *SymbolicMaxNotTaken,
8802 bool MaxOrZero,
8804 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8805 ArrayRef({PredList})) {}
8806
8807/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8808/// computable exit into a persistent ExitNotTakenInfo array.
8809ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8811 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8812 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8813 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8814
8815 ExitNotTaken.reserve(ExitCounts.size());
8816 std::transform(ExitCounts.begin(), ExitCounts.end(),
8817 std::back_inserter(ExitNotTaken),
8818 [&](const EdgeExitInfo &EEI) {
8819 BasicBlock *ExitBB = EEI.first;
8820 const ExitLimit &EL = EEI.second;
8821 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8822 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8823 EL.Predicates);
8824 });
8825 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8826 isa<SCEVConstant>(ConstantMax)) &&
8827 "No point in having a non-constant max backedge taken count!");
8828}
8829
8830/// Compute the number of times the backedge of the specified loop will execute.
8831ScalarEvolution::BackedgeTakenInfo
8832ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8833 bool AllowPredicates) {
8834 SmallVector<BasicBlock *, 8> ExitingBlocks;
8835 L->getExitingBlocks(ExitingBlocks);
8836
8837 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8838
8839 SmallVector<EdgeExitInfo, 4> ExitCounts;
8840 bool CouldComputeBECount = true;
8841 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8842 const SCEV *MustExitMaxBECount = nullptr;
8843 const SCEV *MayExitMaxBECount = nullptr;
8844 bool MustExitMaxOrZero = false;
8845 bool IsOnlyExit = ExitingBlocks.size() == 1;
8846
8847 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8848 // and compute maxBECount.
8849 // Do a union of all the predicates here.
8850 for (BasicBlock *ExitBB : ExitingBlocks) {
8851 // We canonicalize untaken exits to br (constant); ignore them so that
8852 // proving an exit untaken doesn't negatively impact our ability to reason
8853 // about the loop as a whole.
8854 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8855 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8856 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8857 if (ExitIfTrue == CI->isZero())
8858 continue;
8859 }
8860
8861 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8862
8863 assert((AllowPredicates || EL.Predicates.empty()) &&
8864 "Predicated exit limit when predicates are not allowed!");
8865
8866 // 1. For each exit that can be computed, add an entry to ExitCounts.
8867 // CouldComputeBECount is true only if all exits can be computed.
8868 if (EL.ExactNotTaken != getCouldNotCompute())
8869 ++NumExitCountsComputed;
8870 else
8871 // We couldn't compute an exact value for this exit, so
8872 // we won't be able to compute an exact value for the loop.
8873 CouldComputeBECount = false;
8874 // Remember exit count if either exact or symbolic is known. Because
8875 // Exact always implies symbolic, only check symbolic.
8876 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8877 ExitCounts.emplace_back(ExitBB, EL);
8878 else {
8879 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8880 "Exact is known but symbolic isn't?");
8881 ++NumExitCountsNotComputed;
8882 }
8883
8884 // 2. Derive the loop's MaxBECount from each exit's max number of
8885 // non-exiting iterations. Partition the loop exits into two kinds:
8886 // LoopMustExits and LoopMayExits.
8887 //
8888 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8889 // is a LoopMayExit. If any computable LoopMustExit is found, then
8890 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8891 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8892 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8893 // any
8894 // computable EL.ConstantMaxNotTaken.
8895 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8896 DT.dominates(ExitBB, Latch)) {
8897 if (!MustExitMaxBECount) {
8898 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8899 MustExitMaxOrZero = EL.MaxOrZero;
8900 } else {
8901 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8902 EL.ConstantMaxNotTaken);
8903 }
8904 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8905 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8906 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8907 else {
8908 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8909 EL.ConstantMaxNotTaken);
8910 }
8911 }
8912 }
8913 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8914 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8915 // The loop backedge will be taken the maximum or zero times if there's
8916 // a single exit that must be taken the maximum or zero times.
8917 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8918
8919 // Remember which SCEVs are used in exit limits for invalidation purposes.
8920 // We only care about non-constant SCEVs here, so we can ignore
8921 // EL.ConstantMaxNotTaken
8922 // and MaxBECount, which must be SCEVConstant.
8923 for (const auto &Pair : ExitCounts) {
8924 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8925 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8926 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8927 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8928 {L, AllowPredicates});
8929 }
8930 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8931 MaxBECount, MaxOrZero);
8932}
8933
8935ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8936 bool IsOnlyExit, bool AllowPredicates) {
8937 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8938 // If our exiting block does not dominate the latch, then its connection with
8939 // the loop's exit limit may be far from trivial.
8940 const BasicBlock *Latch = L->getLoopLatch();
8941 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8942 return getCouldNotCompute();
8943
8944 Instruction *Term = ExitingBlock->getTerminator();
8945 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8946 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8947 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8948 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8949 "It should have one successor in loop and one exit block!");
8950 // Proceed to the next level to examine the exit condition expression.
8951 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8952 /*ControlsOnlyExit=*/IsOnlyExit,
8953 AllowPredicates);
8954 }
8955
8956 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8957 // For switch, make sure that there is a single exit from the loop.
8958 BasicBlock *Exit = nullptr;
8959 for (auto *SBB : successors(ExitingBlock))
8960 if (!L->contains(SBB)) {
8961 if (Exit) // Multiple exit successors.
8962 return getCouldNotCompute();
8963 Exit = SBB;
8964 }
8965 assert(Exit && "Exiting block must have at least one exit");
8966 return computeExitLimitFromSingleExitSwitch(
8967 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
8968 }
8969
8970 return getCouldNotCompute();
8971}
8972
8974 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8975 bool AllowPredicates) {
8976 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8977 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8978 ControlsOnlyExit, AllowPredicates);
8979}
8980
8981std::optional<ScalarEvolution::ExitLimit>
8982ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8983 bool ExitIfTrue, bool ControlsOnlyExit,
8984 bool AllowPredicates) {
8985 (void)this->L;
8986 (void)this->ExitIfTrue;
8987 (void)this->AllowPredicates;
8988
8989 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8990 this->AllowPredicates == AllowPredicates &&
8991 "Variance in assumed invariant key components!");
8992 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8993 if (Itr == TripCountMap.end())
8994 return std::nullopt;
8995 return Itr->second;
8996}
8997
8998void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8999 bool ExitIfTrue,
9000 bool ControlsOnlyExit,
9001 bool AllowPredicates,
9002 const ExitLimit &EL) {
9003 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9004 this->AllowPredicates == AllowPredicates &&
9005 "Variance in assumed invariant key components!");
9006
9007 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9008 assert(InsertResult.second && "Expected successful insertion!");
9009 (void)InsertResult;
9010 (void)ExitIfTrue;
9011}
9012
9013ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9014 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9015 bool ControlsOnlyExit, bool AllowPredicates) {
9016
9017 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9018 AllowPredicates))
9019 return *MaybeEL;
9020
9021 ExitLimit EL = computeExitLimitFromCondImpl(
9022 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9023 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9024 return EL;
9025}
9026
9027ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9028 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9029 bool ControlsOnlyExit, bool AllowPredicates) {
9030 // Handle BinOp conditions (And, Or).
9031 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9032 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9033 return *LimitFromBinOp;
9034
9035 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9036 // Proceed to the next level to examine the icmp.
9037 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9038 ExitLimit EL =
9039 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9040 if (EL.hasFullInfo() || !AllowPredicates)
9041 return EL;
9042
9043 // Try again, but use SCEV predicates this time.
9044 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9045 ControlsOnlyExit,
9046 /*AllowPredicates=*/true);
9047 }
9048
9049 // Check for a constant condition. These are normally stripped out by
9050 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9051 // preserve the CFG and is temporarily leaving constant conditions
9052 // in place.
9053 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9054 if (ExitIfTrue == !CI->getZExtValue())
9055 // The backedge is always taken.
9056 return getCouldNotCompute();
9057 // The backedge is never taken.
9058 return getZero(CI->getType());
9059 }
9060
9061 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9062 // with a constant step, we can form an equivalent icmp predicate and figure
9063 // out how many iterations will be taken before we exit.
9064 const WithOverflowInst *WO;
9065 const APInt *C;
9066 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9067 match(WO->getRHS(), m_APInt(C))) {
9068 ConstantRange NWR =
9070 WO->getNoWrapKind());
9071 CmpInst::Predicate Pred;
9072 APInt NewRHSC, Offset;
9073 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9074 if (!ExitIfTrue)
9075 Pred = ICmpInst::getInversePredicate(Pred);
9076 auto *LHS = getSCEV(WO->getLHS());
9077 if (Offset != 0)
9079 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9080 ControlsOnlyExit, AllowPredicates);
9081 if (EL.hasAnyInfo())
9082 return EL;
9083 }
9084
9085 // If it's not an integer or pointer comparison then compute it the hard way.
9086 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9087}
9088
9089std::optional<ScalarEvolution::ExitLimit>
9090ScalarEvolution::computeExitLimitFromCondFromBinOp(
9091 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9092 bool ControlsOnlyExit, bool AllowPredicates) {
9093 // Check if the controlling expression for this loop is an And or Or.
9094 Value *Op0, *Op1;
9095 bool IsAnd = false;
9096 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9097 IsAnd = true;
9098 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9099 IsAnd = false;
9100 else
9101 return std::nullopt;
9102
9103 // EitherMayExit is true in these two cases:
9104 // br (and Op0 Op1), loop, exit
9105 // br (or Op0 Op1), exit, loop
9106 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9107 ExitLimit EL0 = computeExitLimitFromCondCached(
9108 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9109 AllowPredicates);
9110 ExitLimit EL1 = computeExitLimitFromCondCached(
9111 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9112 AllowPredicates);
9113
9114 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9115 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9116 if (isa<ConstantInt>(Op1))
9117 return Op1 == NeutralElement ? EL0 : EL1;
9118 if (isa<ConstantInt>(Op0))
9119 return Op0 == NeutralElement ? EL1 : EL0;
9120
9121 const SCEV *BECount = getCouldNotCompute();
9122 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9123 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9124 if (EitherMayExit) {
9125 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9126 // Both conditions must be the same for the loop to continue executing.
9127 // Choose the less conservative count.
9128 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9129 EL1.ExactNotTaken != getCouldNotCompute()) {
9130 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9131 UseSequentialUMin);
9132 }
9133 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9134 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9135 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9136 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9137 else
9138 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9139 EL1.ConstantMaxNotTaken);
9140 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9141 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9142 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9143 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9144 else
9145 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9146 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9147 } else {
9148 // Both conditions must be the same at the same time for the loop to exit.
9149 // For now, be conservative.
9150 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9151 BECount = EL0.ExactNotTaken;
9152 }
9153
9154 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9155 // to be more aggressive when computing BECount than when computing
9156 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9157 // and
9158 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9159 // EL1.ConstantMaxNotTaken to not.
9160 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9161 !isa<SCEVCouldNotCompute>(BECount))
9162 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9163 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9164 SymbolicMaxBECount =
9165 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9166 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9167 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9168}
9169
9170ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9171 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9172 bool AllowPredicates) {
9173 // If the condition was exit on true, convert the condition to exit on false
9174 CmpPredicate Pred;
9175 if (!ExitIfTrue)
9176 Pred = ExitCond->getCmpPredicate();
9177 else
9178 Pred = ExitCond->getInverseCmpPredicate();
9179 const ICmpInst::Predicate OriginalPred = Pred;
9180
9181 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9182 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9183
9184 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9185 AllowPredicates);
9186 if (EL.hasAnyInfo())
9187 return EL;
9188
9189 auto *ExhaustiveCount =
9190 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9191
9192 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9193 return ExhaustiveCount;
9194
9195 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9196 ExitCond->getOperand(1), L, OriginalPred);
9197}
9198ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9199 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9200 bool ControlsOnlyExit, bool AllowPredicates) {
9201
9202 // Try to evaluate any dependencies out of the loop.
9203 LHS = getSCEVAtScope(LHS, L);
9204 RHS = getSCEVAtScope(RHS, L);
9205
9206 // At this point, we would like to compute how many iterations of the
9207 // loop the predicate will return true for these inputs.
9208 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9209 // If there is a loop-invariant, force it into the RHS.
9210 std::swap(LHS, RHS);
9212 }
9213
9214 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9216 // Simplify the operands before analyzing them.
9217 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9218
9219 // If we have a comparison of a chrec against a constant, try to use value
9220 // ranges to answer this query.
9221 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9222 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9223 if (AddRec->getLoop() == L) {
9224 // Form the constant range.
9225 ConstantRange CompRange =
9226 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9227
9228 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9229 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9230 }
9231
9232 // If this loop must exit based on this condition (or execute undefined
9233 // behaviour), see if we can improve wrap flags. This is essentially
9234 // a must execute style proof.
9235 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9236 // If we can prove the test sequence produced must repeat the same values
9237 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9238 // because if it did, we'd have an infinite (undefined) loop.
9239 // TODO: We can peel off any functions which are invertible *in L*. Loop
9240 // invariant terms are effectively constants for our purposes here.
9241 auto *InnerLHS = LHS;
9242 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9243 InnerLHS = ZExt->getOperand();
9244 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9245 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9246 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9247 /*OrNegative=*/true)) {
9248 auto Flags = AR->getNoWrapFlags();
9249 Flags = setFlags(Flags, SCEV::FlagNW);
9252 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9253 }
9254
9255 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9256 // From no-self-wrap, this follows trivially from the fact that every
9257 // (un)signed-wrapped, but not self-wrapped, value must be less than the
9258 // last value before the (un)signed wrap. Since we know that last value
9259 // didn't exit, no smaller one will either.
9260 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9261 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9262 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9263 AR && AR->getLoop() == L && AR->isAffine() &&
9264 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9265 isKnownPositive(AR->getStepRecurrence(*this))) {
9266 auto Flags = AR->getNoWrapFlags();
9267 Flags = setFlags(Flags, WrapType);
9270 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9271 }
9272 }
9273 }
9274
9275 switch (Pred) {
9276 case ICmpInst::ICMP_NE: { // while (X != Y)
9277 // Convert to: while (X-Y != 0)
9278 if (LHS->getType()->isPointerTy()) {
9280 if (isa<SCEVCouldNotCompute>(LHS))
9281 return LHS;
9282 }
9283 if (RHS->getType()->isPointerTy()) {
9285 if (isa<SCEVCouldNotCompute>(RHS))
9286 return RHS;
9287 }
9288 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9289 AllowPredicates);
9290 if (EL.hasAnyInfo())
9291 return EL;
9292 break;
9293 }
9294 case ICmpInst::ICMP_EQ: { // while (X == Y)
9295 // Convert to: while (X-Y == 0)
9296 if (LHS->getType()->isPointerTy()) {
9298 if (isa<SCEVCouldNotCompute>(LHS))
9299 return LHS;
9300 }
9301 if (RHS->getType()->isPointerTy()) {
9303 if (isa<SCEVCouldNotCompute>(RHS))
9304 return RHS;
9305 }
9306 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9307 if (EL.hasAnyInfo()) return EL;
9308 break;
9309 }
9310 case ICmpInst::ICMP_SLE:
9311 case ICmpInst::ICMP_ULE:
9312 // Since the loop is finite, an invariant RHS cannot include the boundary
9313 // value, otherwise it would loop forever.
9314 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9315 !isLoopInvariant(RHS, L)) {
9316 // Otherwise, perform the addition in a wider type, to avoid overflow.
9317 // If the LHS is an addrec with the appropriate nowrap flag, the
9318 // extension will be sunk into it and the exit count can be analyzed.
9319 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9320 if (!OldType)
9321 break;
9322 // Prefer doubling the bitwidth over adding a single bit to make it more
9323 // likely that we use a legal type.
9324 auto *NewType =
9325 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9326 if (ICmpInst::isSigned(Pred)) {
9327 LHS = getSignExtendExpr(LHS, NewType);
9328 RHS = getSignExtendExpr(RHS, NewType);
9329 } else {
9330 LHS = getZeroExtendExpr(LHS, NewType);
9331 RHS = getZeroExtendExpr(RHS, NewType);
9332 }
9333 }
9334 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9335 [[fallthrough]];
9336 case ICmpInst::ICMP_SLT:
9337 case ICmpInst::ICMP_ULT: { // while (X < Y)
9338 bool IsSigned = ICmpInst::isSigned(Pred);
9339 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9340 AllowPredicates);
9341 if (EL.hasAnyInfo())
9342 return EL;
9343 break;
9344 }
9345 case ICmpInst::ICMP_SGE:
9346 case ICmpInst::ICMP_UGE:
9347 // Since the loop is finite, an invariant RHS cannot include the boundary
9348 // value, otherwise it would loop forever.
9349 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9350 !isLoopInvariant(RHS, L))
9351 break;
9352 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9353 [[fallthrough]];
9354 case ICmpInst::ICMP_SGT:
9355 case ICmpInst::ICMP_UGT: { // while (X > Y)
9356 bool IsSigned = ICmpInst::isSigned(Pred);
9357 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9358 AllowPredicates);
9359 if (EL.hasAnyInfo())
9360 return EL;
9361 break;
9362 }
9363 default:
9364 break;
9365 }
9366
9367 return getCouldNotCompute();
9368}
9369
9371ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9372 SwitchInst *Switch,
9373 BasicBlock *ExitingBlock,
9374 bool ControlsOnlyExit) {
9375 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9376
9377 // Give up if the exit is the default dest of a switch.
9378 if (Switch->getDefaultDest() == ExitingBlock)
9379 return getCouldNotCompute();
9380
9381 assert(L->contains(Switch->getDefaultDest()) &&
9382 "Default case must not exit the loop!");
9383 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9384 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9385
9386 // while (X != Y) --> while (X-Y != 0)
9387 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9388 if (EL.hasAnyInfo())
9389 return EL;
9390
9391 return getCouldNotCompute();
9392}
9393
9394static ConstantInt *
9396 ScalarEvolution &SE) {
9397 const SCEV *InVal = SE.getConstant(C);
9398 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9399 assert(isa<SCEVConstant>(Val) &&
9400 "Evaluation of SCEV at constant didn't fold correctly?");
9401 return cast<SCEVConstant>(Val)->getValue();
9402}
9403
9404ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9405 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9406 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9407 if (!RHS)
9408 return getCouldNotCompute();
9409
9410 const BasicBlock *Latch = L->getLoopLatch();
9411 if (!Latch)
9412 return getCouldNotCompute();
9413
9414 const BasicBlock *Predecessor = L->getLoopPredecessor();
9415 if (!Predecessor)
9416 return getCouldNotCompute();
9417
9418 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9419  // Return LHS in OutLHS and shift_op in OutOpCode.
9420 auto MatchPositiveShift =
9421 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9422
9423 using namespace PatternMatch;
9424
9425 ConstantInt *ShiftAmt;
9426 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9427 OutOpCode = Instruction::LShr;
9428 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9429 OutOpCode = Instruction::AShr;
9430 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9431 OutOpCode = Instruction::Shl;
9432 else
9433 return false;
9434
9435 return ShiftAmt->getValue().isStrictlyPositive();
9436 };
9437
9438  // Recognize a "shift recurrence" of either the form %iv or %iv.shifted in
9439 //
9440 // loop:
9441 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9442 // %iv.shifted = lshr i32 %iv, <positive constant>
9443 //
9444 // Return true on a successful match. Return the corresponding PHI node (%iv
9445 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9446 auto MatchShiftRecurrence =
9447 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9448 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9449
9450 {
9451          Instruction::BinaryOps OpC;
9452          Value *V;
9453
9454 // If we encounter a shift instruction, "peel off" the shift operation,
9455 // and remember that we did so. Later when we inspect %iv's backedge
9456 // value, we will make sure that the backedge value uses the same
9457 // operation.
9458 //
9459 // Note: the peeled shift operation does not have to be the same
9460 // instruction as the one feeding into the PHI's backedge value. We only
9461 // really care about it being the same *kind* of shift instruction --
9462 // that's all that is required for our later inferences to hold.
9463 if (MatchPositiveShift(LHS, V, OpC)) {
9464 PostShiftOpCode = OpC;
9465 LHS = V;
9466 }
9467 }
9468
9469 PNOut = dyn_cast<PHINode>(LHS);
9470 if (!PNOut || PNOut->getParent() != L->getHeader())
9471 return false;
9472
9473 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9474 Value *OpLHS;
9475
9476 return
9477 // The backedge value for the PHI node must be a shift by a positive
9478 // amount
9479 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9480
9481 // of the PHI node itself
9482 OpLHS == PNOut &&
9483
9484        // and the kind of shift should match the kind of shift we peeled
9485 // off, if any.
9486 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9487 };
9488
9489 PHINode *PN;
9490  Instruction::BinaryOps OpCode;
9491  if (!MatchShiftRecurrence(LHS, PN, OpCode))
9492 return getCouldNotCompute();
9493
9494 const DataLayout &DL = getDataLayout();
9495
9496 // The key rationale for this optimization is that for some kinds of shift
9497 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9498 // within a finite number of iterations. If the condition guarding the
9499 // backedge (in the sense that the backedge is taken if the condition is true)
9500 // is false for the value the shift recurrence stabilizes to, then we know
9501 // that the backedge is taken only a finite number of times.
9502
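  // For example, the i8 recurrence {64,ashr,1} takes the values
  // 64, 32, 16, 8, 4, 2, 1, 0, 0, ... while {-64,ashr,1} takes the values
  // -64, -32, ..., -2, -1, -1, ...; an ashr recurrence therefore stabilizes
  // to 0 or -1 depending on the sign of its start value, whereas lshr and
  // shl recurrences always stabilize to 0.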
9503 ConstantInt *StableValue = nullptr;
9504 switch (OpCode) {
9505 default:
9506 llvm_unreachable("Impossible case!");
9507
9508 case Instruction::AShr: {
9509 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9510 // bitwidth(K) iterations.
9511 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9512 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9513 Predecessor->getTerminator(), &DT);
9514 auto *Ty = cast<IntegerType>(RHS->getType());
9515 if (Known.isNonNegative())
9516 StableValue = ConstantInt::get(Ty, 0);
9517 else if (Known.isNegative())
9518 StableValue = ConstantInt::get(Ty, -1, true);
9519 else
9520 return getCouldNotCompute();
9521
9522 break;
9523 }
9524 case Instruction::LShr:
9525 case Instruction::Shl:
9526 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9527 // stabilize to 0 in at most bitwidth(K) iterations.
9528 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9529 break;
9530 }
9531
9532 auto *Result =
9533 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9534 assert(Result->getType()->isIntegerTy(1) &&
9535 "Otherwise cannot be an operand to a branch instruction");
9536
9537 if (Result->isZeroValue()) {
9538 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9539 const SCEV *UpperBound =
9540        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9541    return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9542 }
9543
9544 return getCouldNotCompute();
9545}
9546
9547/// Return true if we can constant fold an instruction of the specified type,
9548/// assuming that all operands were constants.
9549static bool CanConstantFold(const Instruction *I) {
9550 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9551 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9552 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9553 return true;
9554
9555 if (const CallInst *CI = dyn_cast<CallInst>(I))
9556 if (const Function *F = CI->getCalledFunction())
9557 return canConstantFoldCallTo(CI, F);
9558 return false;
9559}
9560
9561/// Determine whether this instruction can constant evolve within this loop
9562/// assuming its operands can all constant evolve.
9563static bool canConstantEvolve(Instruction *I, const Loop *L) {
9564 // An instruction outside of the loop can't be derived from a loop PHI.
9565 if (!L->contains(I)) return false;
9566
9567 if (isa<PHINode>(I)) {
9568 // We don't currently keep track of the control flow needed to evaluate
9569 // PHIs, so we cannot handle PHIs inside of loops.
9570 return L->getHeader() == I->getParent();
9571 }
9572
9573 // If we won't be able to constant fold this expression even if the operands
9574 // are constants, bail early.
9575 return CanConstantFold(I);
9576}
9577
9578/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9579/// recursing through each instruction operand until reaching a loop header phi.
9580static PHINode *
9581getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9582                               DenseMap<Instruction *, PHINode *> &PHIMap,
9583                               unsigned Depth) {
9584  if (Depth > MaxConstantEvolvingDepth)
9585    return nullptr;
9586
9587 // Otherwise, we can evaluate this instruction if all of its operands are
9588 // constant or derived from a PHI node themselves.
9589 PHINode *PHI = nullptr;
9590 for (Value *Op : UseInst->operands()) {
9591 if (isa<Constant>(Op)) continue;
9592
9593 Instruction *OpInst = dyn_cast<Instruction>(Op);
9594 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9595
9596 PHINode *P = dyn_cast<PHINode>(OpInst);
9597 if (!P)
9598 // If this operand is already visited, reuse the prior result.
9599 // We may have P != PHI if this is the deepest point at which the
9600 // inconsistent paths meet.
9601 P = PHIMap.lookup(OpInst);
9602 if (!P) {
9603 // Recurse and memoize the results, whether a phi is found or not.
9604 // This recursive call invalidates pointers into PHIMap.
9605 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9606 PHIMap[OpInst] = P;
9607 }
9608 if (!P)
9609 return nullptr; // Not evolving from PHI
9610 if (PHI && PHI != P)
9611 return nullptr; // Evolving from multiple different PHIs.
9612 PHI = P;
9613 }
9614  // This is an expression evolving from a constant PHI!
9615 return PHI;
9616}
9617
9618/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9619/// in the loop that V is derived from. We allow arbitrary operations along the
9620/// way, but the operands of an operation must either be constants or a value
9621/// derived from a constant PHI. If this expression does not fit with these
9622/// constraints, return null.
9623static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9624  Instruction *I = dyn_cast<Instruction>(V);
9625 if (!I || !canConstantEvolve(I, L)) return nullptr;
9626
9627 if (PHINode *PN = dyn_cast<PHINode>(I))
9628 return PN;
9629
9630 // Record non-constant instructions contained by the loop.
9631  DenseMap<Instruction *, PHINode *> PHIMap;
9632  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9633}
9634
9635/// EvaluateExpression - Given an expression that passes the
9636/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9637/// in the loop has the value PHIVal. If we can't fold this expression for some
9638/// reason, return null.
9639static Constant *EvaluateExpression(Value *V, const Loop *L,
9640                                    DenseMap<Instruction *, Constant *> &Vals,
9641                                    const DataLayout &DL,
9642 const TargetLibraryInfo *TLI) {
9643 // Convenient constant check, but redundant for recursive calls.
9644 if (Constant *C = dyn_cast<Constant>(V)) return C;
9645 Instruction *I = dyn_cast<Instruction>(V);
9646 if (!I) return nullptr;
9647
9648 if (Constant *C = Vals.lookup(I)) return C;
9649
9650 // An instruction inside the loop depends on a value outside the loop that we
9651 // weren't given a mapping for, or a value such as a call inside the loop.
9652 if (!canConstantEvolve(I, L)) return nullptr;
9653
9654 // An unmapped PHI can be due to a branch or another loop inside this loop,
9655 // or due to this not being the initial iteration through a loop where we
9656 // couldn't compute the evolution of this particular PHI last time.
9657 if (isa<PHINode>(I)) return nullptr;
9658
9659 std::vector<Constant*> Operands(I->getNumOperands());
9660
9661 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9662 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9663 if (!Operand) {
9664 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9665 if (!Operands[i]) return nullptr;
9666 continue;
9667 }
9668 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9669 Vals[Operand] = C;
9670 if (!C) return nullptr;
9671 Operands[i] = C;
9672 }
9673
9674 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9675 /*AllowNonDeterministic=*/false);
9676}
9677
9678
9679// If every incoming value to PN except the one for BB is a specific Constant,
9680// return that, else return nullptr.
9681static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9682  Constant *IncomingVal = nullptr;
9683
9684 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9685 if (PN->getIncomingBlock(i) == BB)
9686 continue;
9687
9688 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9689 if (!CurrentVal)
9690 return nullptr;
9691
9692 if (IncomingVal != CurrentVal) {
9693 if (IncomingVal)
9694 return nullptr;
9695 IncomingVal = CurrentVal;
9696 }
9697 }
9698
9699 return IncomingVal;
9700}
9701
9702/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9703/// in the header of its containing loop, we know the loop executes a
9704/// constant number of times, and the PHI node is just a recurrence
9705/// involving constants, fold it.
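/// For example, a header PHI %x = phi i32 [ 3, %preheader ], [ %x.next, %latch ]
/// with %x.next = mul i32 %x, 2 and a backedge-taken count of 4 folds to the
/// constant 48 (3 doubled four times).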
9706Constant *
9707ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9708 const APInt &BEs,
9709 const Loop *L) {
9710 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9711 if (!Inserted)
9712 return I->second;
9713
9714  if (BEs.ugt(MaxBruteForceIterations))
9715 return nullptr; // Not going to evaluate it.
9716
9717 Constant *&RetVal = I->second;
9718
9719  DenseMap<Instruction *, Constant *> CurrentIterVals;
9720 BasicBlock *Header = L->getHeader();
9721 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9722
9723 BasicBlock *Latch = L->getLoopLatch();
9724 if (!Latch)
9725 return nullptr;
9726
9727 for (PHINode &PHI : Header->phis()) {
9728 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9729 CurrentIterVals[&PHI] = StartCST;
9730 }
9731 if (!CurrentIterVals.count(PN))
9732 return RetVal = nullptr;
9733
9734 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9735
9736 // Execute the loop symbolically to determine the exit value.
9737 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9738 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9739
9740 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9741 unsigned IterationNum = 0;
9742 const DataLayout &DL = getDataLayout();
9743 for (; ; ++IterationNum) {
9744 if (IterationNum == NumIterations)
9745 return RetVal = CurrentIterVals[PN]; // Got exit value!
9746
9747 // Compute the value of the PHIs for the next iteration.
9748 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9749    DenseMap<Instruction *, Constant *> NextIterVals;
9750    Constant *NextPHI =
9751 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9752 if (!NextPHI)
9753 return nullptr; // Couldn't evaluate!
9754 NextIterVals[PN] = NextPHI;
9755
9756 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9757
9758 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9759 // cease to be able to evaluate one of them or if they stop evolving,
9760 // because that doesn't necessarily prevent us from computing PN.
9761    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9762    for (const auto &I : CurrentIterVals) {
9763 PHINode *PHI = dyn_cast<PHINode>(I.first);
9764 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9765 PHIsToCompute.emplace_back(PHI, I.second);
9766 }
9767 // We use two distinct loops because EvaluateExpression may invalidate any
9768 // iterators into CurrentIterVals.
9769 for (const auto &I : PHIsToCompute) {
9770 PHINode *PHI = I.first;
9771 Constant *&NextPHI = NextIterVals[PHI];
9772 if (!NextPHI) { // Not already computed.
9773 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9774 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9775 }
9776 if (NextPHI != I.second)
9777 StoppedEvolving = false;
9778 }
9779
9780 // If all entries in CurrentIterVals == NextIterVals then we can stop
9781 // iterating, the loop can't continue to change.
9782 if (StoppedEvolving)
9783 return RetVal = CurrentIterVals[PN];
9784
9785 CurrentIterVals.swap(NextIterVals);
9786 }
9787}
9788
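/// Brute-force computation of an exit count: symbolically execute the header
/// PHIs of the loop iteration by iteration (up to MaxBruteForceIterations)
/// and return the first iteration number at which Cond evaluates to ExitWhen,
/// or SCEVCouldNotCompute if that does not happen within the budget.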
9789const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9790 Value *Cond,
9791 bool ExitWhen) {
9792  PHINode *PN = getConstantEvolvingPHI(Cond, L);
9793  if (!PN) return getCouldNotCompute();
9794
9795 // If the loop is canonicalized, the PHI will have exactly two entries.
9796 // That's the only form we support here.
9797 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9798
9799  DenseMap<Instruction *, Constant *> CurrentIterVals;
9800  BasicBlock *Header = L->getHeader();
9801 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9802
9803 BasicBlock *Latch = L->getLoopLatch();
9804 assert(Latch && "Should follow from NumIncomingValues == 2!");
9805
9806 for (PHINode &PHI : Header->phis()) {
9807 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9808 CurrentIterVals[&PHI] = StartCST;
9809 }
9810 if (!CurrentIterVals.count(PN))
9811 return getCouldNotCompute();
9812
9813  // Okay, we found a PHI node that defines the trip count of this loop. Execute
9814 // the loop symbolically to determine when the condition gets a value of
9815 // "ExitWhen".
9816 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9817 const DataLayout &DL = getDataLayout();
9818 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9819 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9820 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9821
9822 // Couldn't symbolically evaluate.
9823 if (!CondVal) return getCouldNotCompute();
9824
9825 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9826 ++NumBruteForceTripCountsComputed;
9827 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9828 }
9829
9830 // Update all the PHI nodes for the next iteration.
9831    DenseMap<Instruction *, Constant *> NextIterVals;
9832
9833 // Create a list of which PHIs we need to compute. We want to do this before
9834 // calling EvaluateExpression on them because that may invalidate iterators
9835 // into CurrentIterVals.
9836 SmallVector<PHINode *, 8> PHIsToCompute;
9837 for (const auto &I : CurrentIterVals) {
9838 PHINode *PHI = dyn_cast<PHINode>(I.first);
9839 if (!PHI || PHI->getParent() != Header) continue;
9840 PHIsToCompute.push_back(PHI);
9841 }
9842 for (PHINode *PHI : PHIsToCompute) {
9843 Constant *&NextPHI = NextIterVals[PHI];
9844 if (NextPHI) continue; // Already computed!
9845
9846 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9847 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9848 }
9849 CurrentIterVals.swap(NextIterVals);
9850 }
9851
9852 // Too many iterations were needed to evaluate.
9853 return getCouldNotCompute();
9854}
9855
9856const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9857  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9858      ValuesAtScopes[V];
9859 // Check to see if we've folded this expression at this loop before.
9860 for (auto &LS : Values)
9861 if (LS.first == L)
9862 return LS.second ? LS.second : V;
9863
9864 Values.emplace_back(L, nullptr);
9865
9866 // Otherwise compute it.
9867 const SCEV *C = computeSCEVAtScope(V, L);
9868 for (auto &LS : reverse(ValuesAtScopes[V]))
9869 if (LS.first == L) {
9870 LS.second = C;
9871 if (!isa<SCEVConstant>(C))
9872 ValuesAtScopesUsers[C].push_back({L, V});
9873 break;
9874 }
9875 return C;
9876}
9877
9878/// This builds up a Constant using the ConstantExpr interface. That way, we
9879/// will return Constants for objects which aren't represented by a
9880/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9881/// Returns NULL if the SCEV isn't representable as a Constant.
9882static Constant *BuildConstantFromSCEV(const SCEV *V) {
9883  switch (V->getSCEVType()) {
9884 case scCouldNotCompute:
9885 case scAddRecExpr:
9886 case scVScale:
9887 return nullptr;
9888 case scConstant:
9889 return cast<SCEVConstant>(V)->getValue();
9890 case scUnknown:
9891 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9892 case scPtrToInt: {
9893 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9894 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9895 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9896
9897 return nullptr;
9898 }
9899 case scTruncate: {
9900 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9901 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9902 return ConstantExpr::getTrunc(CastOp, ST->getType());
9903 return nullptr;
9904 }
9905 case scAddExpr: {
9906 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9907 Constant *C = nullptr;
9908 for (const SCEV *Op : SA->operands()) {
9909      Constant *OpC = BuildConstantFromSCEV(Op);
9910      if (!OpC)
9911 return nullptr;
9912 if (!C) {
9913 C = OpC;
9914 continue;
9915 }
9916 assert(!C->getType()->isPointerTy() &&
9917 "Can only have one pointer, and it must be last");
9918 if (OpC->getType()->isPointerTy()) {
9919 // The offsets have been converted to bytes. We can add bytes using
9920 // an i8 GEP.
9921        C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9922                                           OpC, C);
9923 } else {
9924 C = ConstantExpr::getAdd(C, OpC);
9925 }
9926 }
9927 return C;
9928 }
9929 case scMulExpr:
9930 case scSignExtend:
9931 case scZeroExtend:
9932 case scUDivExpr:
9933 case scSMaxExpr:
9934 case scUMaxExpr:
9935 case scSMinExpr:
9936 case scUMinExpr:
9937  case scSequentialUMinExpr:
9938    return nullptr;
9939 }
9940 llvm_unreachable("Unknown SCEV kind!");
9941}
9942
9943const SCEV *
9944ScalarEvolution::getWithOperands(const SCEV *S,
9945                                 SmallVectorImpl<const SCEV *> &NewOps) {
9946  switch (S->getSCEVType()) {
9947 case scTruncate:
9948 case scZeroExtend:
9949 case scSignExtend:
9950 case scPtrToInt:
9951 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9952 case scAddRecExpr: {
9953 auto *AddRec = cast<SCEVAddRecExpr>(S);
9954 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9955 }
9956 case scAddExpr:
9957 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9958 case scMulExpr:
9959 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9960 case scUDivExpr:
9961 return getUDivExpr(NewOps[0], NewOps[1]);
9962 case scUMaxExpr:
9963 case scSMaxExpr:
9964 case scUMinExpr:
9965 case scSMinExpr:
9966 return getMinMaxExpr(S->getSCEVType(), NewOps);
9967  case scSequentialUMinExpr:
9968    return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9969 case scConstant:
9970 case scVScale:
9971 case scUnknown:
9972 return S;
9973 case scCouldNotCompute:
9974 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9975 }
9976 llvm_unreachable("Unknown SCEV kind!");
9977}
9978
9979const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9980 switch (V->getSCEVType()) {
9981 case scConstant:
9982 case scVScale:
9983 return V;
9984 case scAddRecExpr: {
9985 // If this is a loop recurrence for a loop that does not contain L, then we
9986 // are dealing with the final value computed by the loop.
9987 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9988 // First, attempt to evaluate each operand.
9989 // Avoid performing the look-up in the common case where the specified
9990 // expression has no loop-variant portions.
9991 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9992 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9993 if (OpAtScope == AddRec->getOperand(i))
9994 continue;
9995
9996 // Okay, at least one of these operands is loop variant but might be
9997 // foldable. Build a new instance of the folded commutative expression.
9998      SmallVector<const SCEV *, 8> NewOps;
9999      NewOps.reserve(AddRec->getNumOperands());
10000 append_range(NewOps, AddRec->operands().take_front(i));
10001 NewOps.push_back(OpAtScope);
10002 for (++i; i != e; ++i)
10003 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10004
10005 const SCEV *FoldedRec = getAddRecExpr(
10006 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10007 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10008 // The addrec may be folded to a nonrecurrence, for example, if the
10009 // induction variable is multiplied by zero after constant folding. Go
10010 // ahead and return the folded value.
10011 if (!AddRec)
10012 return FoldedRec;
10013 break;
10014 }
10015
10016 // If the scope is outside the addrec's loop, evaluate it by using the
10017 // loop exit value of the addrec.
10018 if (!AddRec->getLoop()->contains(L)) {
10019 // To evaluate this recurrence, we need to know how many times the AddRec
10020 // loop iterates. Compute this now.
10021 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10022 if (BackedgeTakenCount == getCouldNotCompute())
10023 return AddRec;
10024
10025 // Then, evaluate the AddRec.
10026 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10027 }
10028
10029 return AddRec;
10030 }
10031 case scTruncate:
10032 case scZeroExtend:
10033 case scSignExtend:
10034 case scPtrToInt:
10035 case scAddExpr:
10036 case scMulExpr:
10037 case scUDivExpr:
10038 case scUMaxExpr:
10039 case scSMaxExpr:
10040 case scUMinExpr:
10041 case scSMinExpr:
10042 case scSequentialUMinExpr: {
10043 ArrayRef<const SCEV *> Ops = V->operands();
10044 // Avoid performing the look-up in the common case where the specified
10045 // expression has no loop-variant portions.
10046 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10047 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10048 if (OpAtScope != Ops[i]) {
10049 // Okay, at least one of these operands is loop variant but might be
10050 // foldable. Build a new instance of the folded commutative expression.
10051        SmallVector<const SCEV *, 8> NewOps;
10052        NewOps.reserve(Ops.size());
10053 append_range(NewOps, Ops.take_front(i));
10054 NewOps.push_back(OpAtScope);
10055
10056 for (++i; i != e; ++i) {
10057 OpAtScope = getSCEVAtScope(Ops[i], L);
10058 NewOps.push_back(OpAtScope);
10059 }
10060
10061 return getWithOperands(V, NewOps);
10062 }
10063 }
10064 // If we got here, all operands are loop invariant.
10065 return V;
10066 }
10067 case scUnknown: {
10068 // If this instruction is evolved from a constant-evolving PHI, compute the
10069 // exit value from the loop without using SCEVs.
10070 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10071 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10072 if (!I)
10073 return V; // This is some other type of SCEVUnknown, just return it.
10074
10075 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10076 const Loop *CurrLoop = this->LI[I->getParent()];
10077 // Looking for loop exit value.
10078 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10079 PN->getParent() == CurrLoop->getHeader()) {
10080 // Okay, there is no closed form solution for the PHI node. Check
10081 // to see if the loop that contains it has a known backedge-taken
10082 // count. If so, we may be able to force computation of the exit
10083 // value.
10084 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10085 // This trivial case can show up in some degenerate cases where
10086 // the incoming IR has not yet been fully simplified.
10087 if (BackedgeTakenCount->isZero()) {
10088 Value *InitValue = nullptr;
10089 bool MultipleInitValues = false;
10090 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10091 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10092 if (!InitValue)
10093 InitValue = PN->getIncomingValue(i);
10094 else if (InitValue != PN->getIncomingValue(i)) {
10095 MultipleInitValues = true;
10096 break;
10097 }
10098 }
10099 }
10100 if (!MultipleInitValues && InitValue)
10101 return getSCEV(InitValue);
10102 }
10103 // Do we have a loop invariant value flowing around the backedge
10104 // for a loop which must execute the backedge?
10105 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10106 isKnownNonZero(BackedgeTakenCount) &&
10107 PN->getNumIncomingValues() == 2) {
10108
10109 unsigned InLoopPred =
10110 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10111 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10112 if (CurrLoop->isLoopInvariant(BackedgeVal))
10113 return getSCEV(BackedgeVal);
10114 }
10115 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10116 // Okay, we know how many times the containing loop executes. If
10117 // this is a constant evolving PHI node, get the final value at
10118 // the specified iteration number.
10119 Constant *RV =
10120 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10121 if (RV)
10122 return getSCEV(RV);
10123 }
10124 }
10125 }
10126
10127 // Okay, this is an expression that we cannot symbolically evaluate
10128 // into a SCEV. Check to see if it's possible to symbolically evaluate
10129 // the arguments into constants, and if so, try to constant propagate the
10130 // result. This is particularly useful for computing loop exit values.
10131 if (!CanConstantFold(I))
10132 return V; // This is some other type of SCEVUnknown, just return it.
10133
10135 Operands.reserve(I->getNumOperands());
10136 bool MadeImprovement = false;
10137 for (Value *Op : I->operands()) {
10138 if (Constant *C = dyn_cast<Constant>(Op)) {
10139 Operands.push_back(C);
10140 continue;
10141 }
10142
10143 // If any of the operands is non-constant and if they are
10144 // non-integer and non-pointer, don't even try to analyze them
10145 // with scev techniques.
10146 if (!isSCEVable(Op->getType()))
10147 return V;
10148
10149 const SCEV *OrigV = getSCEV(Op);
10150 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10151 MadeImprovement |= OrigV != OpV;
10152
10153      Constant *C = BuildConstantFromSCEV(OpV);
10154      if (!C)
10155 return V;
10156 assert(C->getType() == Op->getType() && "Type mismatch");
10157 Operands.push_back(C);
10158 }
10159
10160 // Check to see if getSCEVAtScope actually made an improvement.
10161 if (!MadeImprovement)
10162 return V; // This is some other type of SCEVUnknown, just return it.
10163
10164 Constant *C = nullptr;
10165 const DataLayout &DL = getDataLayout();
10166 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10167 /*AllowNonDeterministic=*/false);
10168 if (!C)
10169 return V;
10170 return getSCEV(C);
10171 }
10172 case scCouldNotCompute:
10173 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10174 }
10175 llvm_unreachable("Unknown SCEV type!");
10176}
10177
10178const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10179  return getSCEVAtScope(getSCEV(V), L);
10180}
10181
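/// Strip zero-extend and sign-extend casts off of S. Both casts are
/// injective and map zero to zero, so the stripped expression is zero exactly
/// when S is zero; this lets the exit-count logic analyze the underlying
/// addrec directly.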
10182const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10183 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10184 return stripInjectiveFunctions(ZExt->getOperand());
10185 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10186 return stripInjectiveFunctions(SExt->getOperand());
10187 return S;
10188}
10189
10190/// Finds the minimum unsigned root of the following equation:
10191///
10192/// A * X = B (mod N)
10193///
10194/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10195/// A and B isn't important.
10196///
10197/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
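/// For example, with BW = 8, A = 6 and B = 18: D = gcd(6, 256) = 2, so
/// Mult2 = 1 and B is divisible by D; A/D = 3 has multiplicative inverse 43
/// modulo 128, and the minimum root is (43 * 18 mod 256) / 2 = 3, which
/// indeed satisfies 6 * 3 = 18 (mod 256).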
10198static const SCEV *
10199SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10200                             SmallVectorImpl<const SCEVPredicate *> *Predicates,
10201
10202 ScalarEvolution &SE) {
10203 uint32_t BW = A.getBitWidth();
10204 assert(BW == SE.getTypeSizeInBits(B->getType()));
10205 assert(A != 0 && "A must be non-zero.");
10206
10207 // 1. D = gcd(A, N)
10208 //
10209 // The gcd of A and N may have only one prime factor: 2. The number of
10210 // trailing zeros in A is its multiplicity
10211 uint32_t Mult2 = A.countr_zero();
10212 // D = 2^Mult2
10213
10214 // 2. Check if B is divisible by D.
10215 //
10216 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10217 // is not less than multiplicity of this prime factor for D.
10218 if (SE.getMinTrailingZeros(B) < Mult2) {
10219 // Check if we can prove there's no remainder using URem.
10220 const SCEV *URem =
10221 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10222 const SCEV *Zero = SE.getZero(B->getType());
10223 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10224 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10225 if (!Predicates)
10226 return SE.getCouldNotCompute();
10227
10228 // Avoid adding a predicate that is known to be false.
10229 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10230 return SE.getCouldNotCompute();
10231 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10232 }
10233 }
10234
10235 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10236 // modulo (N / D).
10237 //
10238 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10239 // (N / D) in general. The inverse itself always fits into BW bits, though,
10240 // so we immediately truncate it.
10241 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10242 APInt I = AD.multiplicativeInverse().zext(BW);
10243
10244 // 4. Compute the minimum unsigned root of the equation:
10245 // I * (B / D) mod (N / D)
10246 // To simplify the computation, we factor out the divide by D:
10247 // (I * B mod N) / D
10248 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10249 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10250}
10251
10252/// For a given quadratic addrec, generate coefficients of the corresponding
10253/// quadratic equation, multiplied by a common value to ensure that they are
10254/// integers.
10255/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10256/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10257/// were multiplied by, and BitWidth is the bit width of the original addrec
10258/// coefficients.
10259/// This function returns std::nullopt if the addrec coefficients are not
10260/// compile-time constants.
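/// For example, the addrec {1,+,2,+,2} accumulates to 1 + 2n + n(n-1) after n
/// iterations, giving the equation 2n^2 + 2n + 2 = 0, i.e. A = 2, B = 2,
/// C = 2 with a common multiplier of 2.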
10261static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10262GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10263  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10264 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10265 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10266 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10267 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10268 << *AddRec << '\n');
10269
10270 // We currently can only solve this if the coefficients are constants.
10271 if (!LC || !MC || !NC) {
10272 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10273 return std::nullopt;
10274 }
10275
10276 APInt L = LC->getAPInt();
10277 APInt M = MC->getAPInt();
10278 APInt N = NC->getAPInt();
10279 assert(!N.isZero() && "This is not a quadratic addrec");
10280
10281 unsigned BitWidth = LC->getAPInt().getBitWidth();
10282 unsigned NewWidth = BitWidth + 1;
10283 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10284 << BitWidth << '\n');
10285 // The sign-extension (as opposed to a zero-extension) here matches the
10286 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10287 N = N.sext(NewWidth);
10288 M = M.sext(NewWidth);
10289 L = L.sext(NewWidth);
10290
10291 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10292 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10293 // L+M, L+2M+N, L+3M+3N, ...
10294 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10295 //
10296 // The equation Acc = 0 is then
10297 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10298 // In a quadratic form it becomes:
10299 // N n^2 + (2M-N) n + 2L = 0.
10300
10301 APInt A = N;
10302 APInt B = 2 * M - A;
10303 APInt C = 2 * L;
10304 APInt T = APInt(NewWidth, 2);
10305 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10306 << "x + " << C << ", coeff bw: " << NewWidth
10307 << ", multiplied by " << T << '\n');
10308 return std::make_tuple(A, B, C, T, BitWidth);
10309}
10310
10311/// Helper function to compare optional APInts:
10312/// (a) if X and Y both exist, return min(X, Y),
10313/// (b) if neither X nor Y exist, return std::nullopt,
10314/// (c) if exactly one of X and Y exists, return that value.
10315static std::optional<APInt> MinOptional(std::optional<APInt> X,
10316 std::optional<APInt> Y) {
10317 if (X && Y) {
10318 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10319 APInt XW = X->sext(W);
10320 APInt YW = Y->sext(W);
10321 return XW.slt(YW) ? *X : *Y;
10322 }
10323 if (!X && !Y)
10324 return std::nullopt;
10325 return X ? *X : *Y;
10326}
10327
10328/// Helper function to truncate an optional APInt to a given BitWidth.
10329/// When solving addrec-related equations, it is preferable to return a value
10330/// that has the same bit width as the original addrec's coefficients. If the
10331/// solution fits in the original bit width, truncate it (except for i1).
10332/// Returning a value of a different bit width may inhibit some optimizations.
10333///
10334/// In general, a solution to a quadratic equation generated from an addrec
10335/// may require BW+1 bits, where BW is the bit width of the addrec's
10336/// coefficients. The reason is that the coefficients of the quadratic
10337/// equation are BW+1 bits wide (to avoid truncation when converting from
10338/// the addrec to the equation).
10339static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10340 unsigned BitWidth) {
10341 if (!X)
10342 return std::nullopt;
10343 unsigned W = X->getBitWidth();
10344 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10345 return X->trunc(BitWidth);
10346 return X;
10347}
10348
10349/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10350/// iterations. The values L, M, N are assumed to be signed, and they
10351/// should all have the same bit widths.
10352/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10353/// where BW is the bit width of the addrec's coefficients.
10354/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10355/// returned as such, otherwise the bit width of the returned value may
10356/// be greater than BW.
10357///
10358/// This function returns std::nullopt if
10359/// (a) the addrec coefficients are not constant, or
10360/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10361/// like x^2 = 5, no integer solutions exist, in other cases an integer
10362/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10363static std::optional<APInt>
10364SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10365  APInt A, B, C, M;
10366 unsigned BitWidth;
10367 auto T = GetQuadraticEquation(AddRec);
10368 if (!T)
10369 return std::nullopt;
10370
10371 std::tie(A, B, C, M, BitWidth) = *T;
10372 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10373 std::optional<APInt> X =
10374      APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10375  if (!X)
10376 return std::nullopt;
10377
10378 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10379 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10380 if (!V->isZero())
10381 return std::nullopt;
10382
10383 return TruncIfPossible(X, BitWidth);
10384}
10385
10386/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10387/// iterations. The values M, N are assumed to be signed, and they
10388/// should all have the same bit widths.
10389/// Find the least n such that c(n) does not belong to the given range,
10390/// while c(n-1) does.
10391///
10392/// This function returns std::nullopt if
10393/// (a) the addrec coefficients are not constant, or
10394/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10395/// bounds of the range.
10396static std::optional<APInt>
10397SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10398                          const ConstantRange &Range, ScalarEvolution &SE) {
10399 assert(AddRec->getOperand(0)->isZero() &&
10400 "Starting value of addrec should be 0");
10401 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10402 << Range << ", addrec " << *AddRec << '\n');
10403 // This case is handled in getNumIterationsInRange. Here we can assume that
10404 // we start in the range.
10405 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10406 "Addrec's initial value should be in range");
10407
10408 APInt A, B, C, M;
10409 unsigned BitWidth;
10410 auto T = GetQuadraticEquation(AddRec);
10411 if (!T)
10412 return std::nullopt;
10413
10414 // Be careful about the return value: there can be two reasons for not
10415 // returning an actual number. First, if no solutions to the equations
10416 // were found, and second, if the solutions don't leave the given range.
10417 // The first case means that the actual solution is "unknown", the second
10418 // means that it's known, but not valid. If the solution is unknown, we
10419 // cannot make any conclusions.
10420 // Return a pair: the optional solution and a flag indicating if the
10421 // solution was found.
10422 auto SolveForBoundary =
10423 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10424 // Solve for signed overflow and unsigned overflow, pick the lower
10425 // solution.
10426 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10427 << Bound << " (before multiplying by " << M << ")\n");
10428 Bound *= M; // The quadratic equation multiplier.
10429
10430 std::optional<APInt> SO;
10431 if (BitWidth > 1) {
10432 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10433 "signed overflow\n");
10434      SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10435    }
10436 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10437 "unsigned overflow\n");
10438 std::optional<APInt> UO =
10439        APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10440
10441 auto LeavesRange = [&] (const APInt &X) {
10442 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10443 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10444 if (Range.contains(V0->getValue()))
10445 return false;
10446 // X should be at least 1, so X-1 is non-negative.
10447 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10448 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10449 if (Range.contains(V1->getValue()))
10450 return true;
10451 return false;
10452 };
10453
10454 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10455 // can be a solution, but the function failed to find it. We cannot treat it
10456 // as "no solution".
10457 if (!SO || !UO)
10458 return {std::nullopt, false};
10459
10460 // Check the smaller value first to see if it leaves the range.
10461 // At this point, both SO and UO must have values.
10462 std::optional<APInt> Min = MinOptional(SO, UO);
10463 if (LeavesRange(*Min))
10464 return { Min, true };
10465 std::optional<APInt> Max = Min == SO ? UO : SO;
10466 if (LeavesRange(*Max))
10467 return { Max, true };
10468
10469 // Solutions were found, but were eliminated, hence the "true".
10470 return {std::nullopt, true};
10471 };
10472
10473 std::tie(A, B, C, M, BitWidth) = *T;
10474 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10475 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10476 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10477 auto SL = SolveForBoundary(Lower);
10478 auto SU = SolveForBoundary(Upper);
10479  // If any of the solutions was unknown, no meaningful conclusions can
10480 // be made.
10481 if (!SL.second || !SU.second)
10482 return std::nullopt;
10483
10484 // Claim: The correct solution is not some value between Min and Max.
10485 //
10486 // Justification: Assuming that Min and Max are different values, one of
10487 // them is when the first signed overflow happens, the other is when the
10488 // first unsigned overflow happens. Crossing the range boundary is only
10489 // possible via an overflow (treating 0 as a special case of it, modeling
10490 // an overflow as crossing k*2^W for some k).
10491 //
10492 // The interesting case here is when Min was eliminated as an invalid
10493 // solution, but Max was not. The argument is that if there was another
10494 // overflow between Min and Max, it would also have been eliminated if
10495 // it was considered.
10496 //
10497 // For a given boundary, it is possible to have two overflows of the same
10498 // type (signed/unsigned) without having the other type in between: this
10499 // can happen when the vertex of the parabola is between the iterations
10500 // corresponding to the overflows. This is only possible when the two
10501 // overflows cross k*2^W for the same k. In such case, if the second one
10502 // left the range (and was the first one to do so), the first overflow
10503 // would have to enter the range, which would mean that either we had left
10504 // the range before or that we started outside of it. Both of these cases
10505 // are contradictions.
10506 //
10507 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10508 // solution is not some value between the Max for this boundary and the
10509 // Min of the other boundary.
10510 //
10511 // Justification: Assume that we had such Max_A and Min_B corresponding
10512 // to range boundaries A and B and such that Max_A < Min_B. If there was
10513 // a solution between Max_A and Min_B, it would have to be caused by an
10514 // overflow corresponding to either A or B. It cannot correspond to B,
10515 // since Min_B is the first occurrence of such an overflow. If it
10516 // corresponded to A, it would have to be either a signed or an unsigned
10517 // overflow that is larger than both eliminated overflows for A. But
10518 // between the eliminated overflows and this overflow, the values would
10519 // cover the entire value space, thus crossing the other boundary, which
10520 // is a contradiction.
10521
10522 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10523}
10524
10525ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10526 const Loop *L,
10527 bool ControlsOnlyExit,
10528 bool AllowPredicates) {
10529
10530 // This is only used for loops with a "x != y" exit test. The exit condition
10531 // is now expressed as a single expression, V = x-y. So the exit test is
10532  // effectively V != 0.  We know and take advantage of the fact that this
10533  // expression is only ever used in a comparison-against-zero context.
10534
10535  SmallVector<const SCEVPredicate *> Predicates;
10536  // If the value is a constant
10537 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10538 // If the value is already zero, the branch will execute zero times.
10539 if (C->getValue()->isZero()) return C;
10540 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10541 }
10542
10543 const SCEVAddRecExpr *AddRec =
10544 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10545
10546 if (!AddRec && AllowPredicates)
10547 // Try to make this an AddRec using runtime tests, in the first X
10548 // iterations of this loop, where X is the SCEV expression found by the
10549 // algorithm below.
10550 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10551
10552 if (!AddRec || AddRec->getLoop() != L)
10553 return getCouldNotCompute();
10554
10555 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10556 // the quadratic equation to solve it.
10557 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10558 // We can only use this value if the chrec ends up with an exact zero
10559 // value at this index. When solving for "X*X != 5", for example, we
10560 // should not accept a root of 2.
10561 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10562 const auto *R = cast<SCEVConstant>(getConstant(*S));
10563 return ExitLimit(R, R, R, false, Predicates);
10564 }
10565 return getCouldNotCompute();
10566 }
10567
10568 // Otherwise we can only handle this if it is affine.
10569 if (!AddRec->isAffine())
10570 return getCouldNotCompute();
10571
10572 // If this is an affine expression, the execution count of this branch is
10573 // the minimum unsigned root of the following equation:
10574 //
10575 // Start + Step*N = 0 (mod 2^BW)
10576 //
10577 // equivalent to:
10578 //
10579 // Step*N = -Start (mod 2^BW)
10580 //
10581 // where BW is the common bit width of Start and Step.
10582
10583 // Get the initial value for the loop.
10584 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10585 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10586
10587 if (!isLoopInvariant(Step, L))
10588 return getCouldNotCompute();
10589
10590 LoopGuards Guards = LoopGuards::collect(L, *this);
10591 // Specialize step for this loop so we get context sensitive facts below.
10592 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10593
10594 // For positive steps (counting up until unsigned overflow):
10595 // N = -Start/Step (as unsigned)
10596 // For negative steps (counting down to zero):
10597 // N = Start/-Step
10598 // First compute the unsigned distance from zero in the direction of Step.
10599 bool CountDown = isKnownNegative(StepWLG);
10600 if (!CountDown && !isKnownNonNegative(StepWLG))
10601 return getCouldNotCompute();
10602
10603 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10604  // Handle unitary steps, which cannot wrap around.
10605 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10606 // N = Distance (as unsigned)
10607
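  // For example, the addrec {10,+,-1} takes the values 10, 9, ..., 1, 0:
  // CountDown is true, Distance = Start = 10, and the exact backedge-taken
  // count is 10.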
10608 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10609 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10610 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10611
10612 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10613 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10614 // case, and see if we can improve the bound.
10615 //
10616 // Explicitly handling this here is necessary because getUnsignedRange
10617 // isn't context-sensitive; it doesn't know that we only care about the
10618 // range inside the loop.
10619 const SCEV *Zero = getZero(Distance->getType());
10620 const SCEV *One = getOne(Distance->getType());
10621 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10622 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10623 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10624 // as "unsigned_max(Distance + 1) - 1".
10625 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10626 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10627 }
10628 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10629 Predicates);
10630 }
10631
10632 // If the condition controls loop exit (the loop exits only if the expression
10633 // is true) and the addition is no-wrap we can use unsigned divide to
10634 // compute the backedge count. In this case, the step may not divide the
10635 // distance, but we don't care because if the condition is "missed" the loop
10636 // will have undefined behavior due to wrapping.
10637 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10638 loopHasNoAbnormalExits(AddRec->getLoop())) {
10639
10640 // If the stride is zero, the loop must be infinite. In C++, most loops
10641 // are finite by assumption, in which case the step being zero implies
10642 // UB must execute if the loop is entered.
10643 if (!loopIsFiniteByAssumption(L) && !isKnownNonZero(StepWLG))
10644 return getCouldNotCompute();
10645
10646 const SCEV *Exact =
10647 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10648 const SCEV *ConstantMax = getCouldNotCompute();
10649 if (Exact != getCouldNotCompute()) {
10650      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10651      ConstantMax =
10652          getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10653    }
10654 const SCEV *SymbolicMax =
10655 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10656 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10657 }
10658
10659 // Solve the general equation.
10660 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10661 if (!StepC || StepC->getValue()->isZero())
10662 return getCouldNotCompute();
10663  const SCEV *E = SolveLinEquationWithOverflow(
10664      StepC->getAPInt(), getNegativeSCEV(Start),
10665 AllowPredicates ? &Predicates : nullptr, *this);
10666
10667 const SCEV *M = E;
10668 if (E != getCouldNotCompute()) {
10669 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10670 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10671 }
10672 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10673 return ExitLimit(E, M, S, false, Predicates);
10674}
10675
10676ScalarEvolution::ExitLimit
10677ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10678 // Loops that look like: while (X == 0) are very strange indeed. We don't
10679 // handle them yet except for the trivial case. This could be expanded in the
10680 // future as needed.
10681
10682 // If the value is a constant, check to see if it is known to be non-zero
10683 // already. If so, the backedge will execute zero times.
10684 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10685 if (!C->getValue()->isZero())
10686 return getZero(C->getType());
10687 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10688 }
10689
10690 // We could implement others, but I really doubt anyone writes loops like
10691 // this, and if they did, they would already be constant folded.
10692 return getCouldNotCompute();
10693}
10694
10695std::pair<const BasicBlock *, const BasicBlock *>
10696ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10697 const {
10698 // If the block has a unique predecessor, then there is no path from the
10699 // predecessor to the block that does not go through the direct edge
10700 // from the predecessor to the block.
10701 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10702 return {Pred, BB};
10703
10704 // A loop's header is defined to be a block that dominates the loop.
10705 // If the header has a unique predecessor outside the loop, it must be
10706 // a block that has exactly one successor that can reach the loop.
10707 if (const Loop *L = LI.getLoopFor(BB))
10708 return {L->getLoopPredecessor(), L->getHeader()};
10709
10710 return {nullptr, BB};
10711}
10712
10713/// SCEV structural equivalence is usually sufficient for testing whether two
10714/// expressions are equal, however for the purposes of looking for a condition
10715/// guarding a loop, it can be useful to be a little more general, since a
10716/// front-end may have replicated the controlling expression.
10717static bool HasSameValue(const SCEV *A, const SCEV *B) {
10718 // Quick check to see if they are the same SCEV.
10719 if (A == B) return true;
10720
10721 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10722 // Not all instructions that are "identical" compute the same value. For
10723 // instance, two distinct alloca instructions allocating the same type are
10724    // identical and do not read memory, but compute distinct values.
10725 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10726 };
10727
10728 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10729 // two different instructions with the same value. Check for this case.
10730 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10731 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10732 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10733 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10734 if (ComputesEqualValues(AI, BI))
10735 return true;
10736
10737 // Otherwise assume they may have a different value.
10738 return false;
10739}
10740
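/// Match S against the expanded form of a subtraction, (-1 * RHS) + LHS,
/// which is how getMinusSCEV represents LHS - RHS. On success the two sides
/// are returned through LHS and RHS.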
10741static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10742 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10743 if (!Add || Add->getNumOperands() != 2)
10744 return false;
10745 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10746 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10747 LHS = Add->getOperand(1);
10748 RHS = ME->getOperand(1);
10749 return true;
10750 }
10751 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10752 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10753 LHS = Add->getOperand(0);
10754 RHS = ME->getOperand(1);
10755 return true;
10756 }
10757 return false;
10758}
10759
10760bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10761                                           const SCEV *&RHS, unsigned Depth) {
10762 bool Changed = false;
10763 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10764 // '0 != 0'.
10765 auto TrivialCase = [&](bool TriviallyTrue) {
10766    LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10767    Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10768 return true;
10769 };
10770 // If we hit the max recursion limit bail out.
10771 if (Depth >= 3)
10772 return false;
10773
10774 // Canonicalize a constant to the right side.
10775 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10776 // Check for both operands constant.
10777 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10778 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10779 return TrivialCase(false);
10780 return TrivialCase(true);
10781 }
10782 // Otherwise swap the operands to put the constant on the right.
10783 std::swap(LHS, RHS);
10784    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10785    Changed = true;
10786 }
10787
10788 // If we're comparing an addrec with a value which is loop-invariant in the
10789 // addrec's loop, put the addrec on the left. Also make a dominance check,
10790 // as both operands could be addrecs loop-invariant in each other's loop.
10791 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10792 const Loop *L = AR->getLoop();
10793 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10794 std::swap(LHS, RHS);
10795      Pred = ICmpInst::getSwappedCmpPredicate(Pred);
10796      Changed = true;
10797 }
10798 }
10799
10800 // If there's a constant operand, canonicalize comparisons with boundary
10801 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10802 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10803 const APInt &RA = RC->getAPInt();
10804
10805 bool SimplifiedByConstantRange = false;
10806
10807 if (!ICmpInst::isEquality(Pred)) {
10808      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10809      if (ExactCR.isFullSet())
10810 return TrivialCase(true);
10811 if (ExactCR.isEmptySet())
10812 return TrivialCase(false);
10813
10814 APInt NewRHS;
10815 CmpInst::Predicate NewPred;
10816 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10817 ICmpInst::isEquality(NewPred)) {
10818 // We were able to convert an inequality to an equality.
10819 Pred = NewPred;
10820 RHS = getConstant(NewRHS);
10821 Changed = SimplifiedByConstantRange = true;
10822 }
10823 }
10824
10825 if (!SimplifiedByConstantRange) {
10826 switch (Pred) {
10827 default:
10828 break;
10829 case ICmpInst::ICMP_EQ:
10830 case ICmpInst::ICMP_NE:
10831 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10832 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10833 Changed = true;
10834 break;
10835
10836 // The "Should have been caught earlier!" messages refer to the fact
10837 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10838 // should have fired on the corresponding cases, and canonicalized the
10839      // check to a trivial case.
10840
10841 case ICmpInst::ICMP_UGE:
10842 assert(!RA.isMinValue() && "Should have been caught earlier!");
10843 Pred = ICmpInst::ICMP_UGT;
10844 RHS = getConstant(RA - 1);
10845 Changed = true;
10846 break;
10847 case ICmpInst::ICMP_ULE:
10848 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10849 Pred = ICmpInst::ICMP_ULT;
10850 RHS = getConstant(RA + 1);
10851 Changed = true;
10852 break;
10853 case ICmpInst::ICMP_SGE:
10854 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10855 Pred = ICmpInst::ICMP_SGT;
10856 RHS = getConstant(RA - 1);
10857 Changed = true;
10858 break;
10859 case ICmpInst::ICMP_SLE:
10860 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10861 Pred = ICmpInst::ICMP_SLT;
10862 RHS = getConstant(RA + 1);
10863 Changed = true;
10864 break;
10865 }
10866 }
10867 }
10868
10869 // Check for obvious equality.
10870 if (HasSameValue(LHS, RHS)) {
10871 if (ICmpInst::isTrueWhenEqual(Pred))
10872 return TrivialCase(true);
10873    if (ICmpInst::isFalseWhenEqual(Pred))
10874      return TrivialCase(false);
10875 }
10876
10877 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10878 // adding or subtracting 1 from one of the operands.
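// For example, "X s<= Y" becomes "X s< Y + 1" when Y + 1 cannot signed-wrap,
// or "X + (-1) s< Y" when X - 1 cannot signed-wrap; the range checks below
// ensure the adjusted operand does not overflow.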
10879 switch (Pred) {
10880 case ICmpInst::ICMP_SLE:
10881 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10882 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10883 SCEV::FlagNSW);
10884 Pred = ICmpInst::ICMP_SLT;
10885 Changed = true;
10886 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10887 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10888 SCEV::FlagNSW);
10889 Pred = ICmpInst::ICMP_SLT;
10890 Changed = true;
10891 }
10892 break;
10893 case ICmpInst::ICMP_SGE:
10894 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10895 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10896 SCEV::FlagNSW);
10897 Pred = ICmpInst::ICMP_SGT;
10898 Changed = true;
10899 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10900 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10901 SCEV::FlagNSW);
10902 Pred = ICmpInst::ICMP_SGT;
10903 Changed = true;
10904 }
10905 break;
10906 case ICmpInst::ICMP_ULE:
10907 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10908 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10909 SCEV::FlagNUW);
10910 Pred = ICmpInst::ICMP_ULT;
10911 Changed = true;
10912 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10913 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10914 Pred = ICmpInst::ICMP_ULT;
10915 Changed = true;
10916 }
10917 break;
10918 case ICmpInst::ICMP_UGE:
10919 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10920 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10921 Pred = ICmpInst::ICMP_UGT;
10922 Changed = true;
10923 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10924 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10925 SCEV::FlagNUW);
10926 Pred = ICmpInst::ICMP_UGT;
10927 Changed = true;
10928 }
10929 break;
10930 default:
10931 break;
10932 }
10933
10934 // TODO: More simplifications are possible here.
10935
10936 // Recursively simplify until we either hit a recursion limit or nothing
10937 // changes.
10938 if (Changed)
10939 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10940
10941 return Changed;
10942}
10943
10944 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10945 return getSignedRangeMax(S).isNegative();
10946}
10947
10950}
10951
10952 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10953 return !getSignedRangeMin(S).isNegative();
10954}
10955
10958}
10959
10960 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10961 // Query push down for cases where the unsigned range is
10962 // less than sufficient.
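// For example, an i8 value with unsigned range [1, 201) is known non-zero,
// but its sign-extension straddles the sign boundary, so the widened range
// is conservatively [-128, 128) and contains zero again; recursing into the
// operand recovers the fact.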
10963 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10964 return isKnownNonZero(SExt->getOperand(0));
10965 return getUnsignedRangeMin(S) != 0;
10966}
10967
10968 bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
10969 bool OrNegative) {
10970 auto NonRecursive = [this, OrNegative](const SCEV *S) {
10971 if (auto *C = dyn_cast<SCEVConstant>(S))
10972 return C->getAPInt().isPowerOf2() ||
10973 (OrNegative && C->getAPInt().isNegatedPowerOf2());
10974
10975 // The vscale_range indicates vscale is a power-of-two.
10976 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
10977 };
10978
10979 if (NonRecursive(S))
10980 return true;
10981
10982 auto *Mul = dyn_cast<SCEVMulExpr>(S);
10983 if (!Mul)
10984 return false;
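// A product of power-of-two factors, e.g. (4 * vscale) with the vscale_range
// attribute present, is itself a power of two as long as the whole product is
// known non-zero (or OrZero is set).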
10985 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
10986}
10987
10988std::pair<const SCEV *, const SCEV *>
10989 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10990 // Compute SCEV on entry of loop L.
10991 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10992 if (Start == getCouldNotCompute())
10993 return { Start, Start };
10994 // Compute post increment SCEV for loop L.
10995 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10996 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10997 return { Start, PostInc };
10998}
10999
11000 bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11001 const SCEV *RHS) {
11002 // First collect all loops.
11003 SmallPtrSet<const Loop *, 8> LoopsUsed;
11004 getUsedLoops(LHS, LoopsUsed);
11005 getUsedLoops(RHS, LoopsUsed);
11006
11007 if (LoopsUsed.empty())
11008 return false;
11009
11010 // Domination relationship must be a linear order on collected loops.
11011#ifndef NDEBUG
11012 for (const auto *L1 : LoopsUsed)
11013 for (const auto *L2 : LoopsUsed)
11014 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11015 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11016 "Domination relationship is not a linear order");
11017#endif
11018
11019 const Loop *MDL =
11020 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11021 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11022 });
11023
11024 // Get init and post increment value for LHS.
11025 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11026 // If LHS contains an unknown non-invariant SCEV, bail out.
11027 if (SplitLHS.first == getCouldNotCompute())
11028 return false;
11029 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11030 // Get init and post increment value for RHS.
11031 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11032 // If RHS contains an unknown non-invariant SCEV, bail out.
11033 if (SplitRHS.first == getCouldNotCompute())
11034 return false;
11035 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11036 // It is possible that init SCEV contains an invariant load but it does
11037 // not dominate MDL and is not available at MDL loop entry, so we should
11038 // check it here.
11039 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11040 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11041 return false;
11042
11043 // It seems the backedge guard check is faster than the entry one, so in some
11044 // cases it can speed up the whole estimation by short-circuiting.
11045 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11046 SplitRHS.second) &&
11047 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11048}
11049
11050 bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11051 const SCEV *RHS) {
11052 // Canonicalize the inputs first.
11053 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11054
11055 if (isKnownViaInduction(Pred, LHS, RHS))
11056 return true;
11057
11058 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11059 return true;
11060
11061 // Otherwise see what can be done with some simple reasoning.
11062 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11063}
11064
11065 std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11066 const SCEV *LHS,
11067 const SCEV *RHS) {
11068 if (isKnownPredicate(Pred, LHS, RHS))
11069 return true;
11070 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11071 return false;
11072 return std::nullopt;
11073}
11074
11075 bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11076 const SCEV *RHS,
11077 const Instruction *CtxI) {
11078 // TODO: Analyze guards and assumes from Context's block.
11079 return isKnownPredicate(Pred, LHS, RHS) ||
11081}
11082
11083std::optional<bool>
11084 ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11085 const SCEV *RHS, const Instruction *CtxI) {
11086 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11087 if (KnownWithoutContext)
11088 return KnownWithoutContext;
11089
11090 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11091 return true;
11092 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
11093 ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11094 return false;
11095 return std::nullopt;
11096}
11097
11098 bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11099 const SCEVAddRecExpr *LHS,
11100 const SCEV *RHS) {
11101 const Loop *L = LHS->getLoop();
11102 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11103 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11104}
11105
11106std::optional<ScalarEvolution::MonotonicPredicateType>
11107 ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11108 ICmpInst::Predicate Pred) {
11109 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11110
11111#ifndef NDEBUG
11112 // Verify an invariant: inverting the predicate should turn a monotonically
11113 // increasing change to a monotonically decreasing one, and vice versa.
11114 if (Result) {
11115 auto ResultSwapped =
11116 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11117
11118 assert(*ResultSwapped != *Result &&
11119 "monotonicity should flip as we flip the predicate");
11120 }
11121#endif
11122
11123 return Result;
11124}
11125
11126std::optional<ScalarEvolution::MonotonicPredicateType>
11127ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11128 ICmpInst::Predicate Pred) {
11129 // A zero step value for LHS means the induction variable is essentially a
11130 // loop invariant value. We don't really depend on the predicate actually
11131 // flipping from false to true (for increasing predicates, and the other way
11132 // around for decreasing predicates); all we care about is that *if* the
11133 // predicate changes then it only changes from false to true.
11134 //
11135 // A zero step value in itself is not very useful, but there may be places
11136 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11137 // as general as possible.
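// For example, with LHS = {0,+,1}<nuw>, the predicate "LHS u> %n" can only
// switch from false to true as the recurrence advances, so it is classified
// as MonotonicallyIncreasing; "LHS u< %n" is MonotonicallyDecreasing.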
11138
11139 // Only handle LE/LT/GE/GT predicates.
11140 if (!ICmpInst::isRelational(Pred))
11141 return std::nullopt;
11142
11143 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11144 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11145 "Should be greater or less!");
11146
11147 // Check that AR does not wrap.
11148 if (ICmpInst::isUnsigned(Pred)) {
11149 if (!LHS->hasNoUnsignedWrap())
11150 return std::nullopt;
11151 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11152 }
11153 assert(ICmpInst::isSigned(Pred) &&
11154 "Relational predicate is either signed or unsigned!");
11155 if (!LHS->hasNoSignedWrap())
11156 return std::nullopt;
11157
11158 const SCEV *Step = LHS->getStepRecurrence(*this);
11159
11160 if (isKnownNonNegative(Step))
11161 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11162
11163 if (isKnownNonPositive(Step))
11164 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11165
11166 return std::nullopt;
11167}
11168
11169std::optional<ScalarEvolution::LoopInvariantPredicate>
11170 ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred,
11171 const SCEV *LHS, const SCEV *RHS,
11172 const Loop *L,
11173 const Instruction *CtxI) {
11174 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11175 if (!isLoopInvariant(RHS, L)) {
11176 if (!isLoopInvariant(LHS, L))
11177 return std::nullopt;
11178
11179 std::swap(LHS, RHS);
11180 Pred = ICmpInst::getSwappedPredicate(Pred);
11181 }
11182
11183 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11184 if (!ArLHS || ArLHS->getLoop() != L)
11185 return std::nullopt;
11186
11187 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11188 if (!MonotonicType)
11189 return std::nullopt;
11190 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11191 // true as the loop iterates, and the backedge is control dependent on
11192 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11193 //
11194 // * if the predicate was false in the first iteration then the predicate
11195 // is never evaluated again, since the loop exits without taking the
11196 // backedge.
11197 // * if the predicate was true in the first iteration then it will
11198 // continue to be true for all future iterations since it is
11199 // monotonically increasing.
11200 //
11201 // For both the above possibilities, we can replace the loop varying
11202 // predicate with its value on the first iteration of the loop (which is
11203 // loop invariant).
11204 //
11205 // A similar reasoning applies for a monotonically decreasing predicate, by
11206 // replacing true with false and false with true in the above two bullets.
11207 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11208 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11209
11210 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11211 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11212 RHS);
11213
11214 if (!CtxI)
11215 return std::nullopt;
11216 // Try to prove via context.
11217 // TODO: Support other cases.
11218 switch (Pred) {
11219 default:
11220 break;
11221 case ICmpInst::ICMP_ULE:
11222 case ICmpInst::ICMP_ULT: {
11223 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11224 // Given preconditions
11225 // (1) ArLHS does not cross the border of positive and negative parts of
11226 // range because of:
11227 // - Positive step; (TODO: lift this limitation)
11228 // - nuw - does not cross zero boundary;
11229 // - nsw - does not cross SINT_MAX boundary;
11230 // (2) ArLHS <s RHS
11231 // (3) RHS >=s 0
11232 // we can replace the loop variant ArLHS <u RHS condition with loop
11233 // invariant Start(ArLHS) <u RHS.
11234 //
11235 // Because of (1) there are two options:
11236 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11237 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11238 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11239 // Because of (2) ArLHS <u RHS is trivially true.
11240 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11241 // We can strengthen this to Start(ArLHS) <u RHS.
11242 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11243 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11244 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11245 isKnownNonNegative(RHS) &&
11246 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11247 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11248 RHS);
11249 }
11250 }
11251
11252 return std::nullopt;
11253}
11254
11255std::optional<ScalarEvolution::LoopInvariantPredicate>
11256 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11257 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11258 const Instruction *CtxI, const SCEV *MaxIter) {
11259 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11260 Pred, LHS, RHS, L, CtxI, MaxIter))
11261 return LIP;
11262 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11263 // Number of iterations expressed as UMIN isn't always great for expressing
11264 // the value on the last iteration. If the straightforward approach didn't
11265 // work, try the following trick: if a predicate is invariant for X, it
11266 // is also invariant for umin(X, ...). So try to find something that works
11267 // among subexpressions of MaxIter expressed as umin.
11268 for (auto *Op : UMin->operands())
11269 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11270 Pred, LHS, RHS, L, CtxI, Op))
11271 return LIP;
11272 return std::nullopt;
11273}
11274
11275std::optional<ScalarEvolution::LoopInvariantPredicate>
11276 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11277 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11278 const Instruction *CtxI, const SCEV *MaxIter) {
11279 // Try to prove the following set of facts:
11280 // - The predicate is monotonic in the iteration space.
11281 // - If the check does not fail on the 1st iteration:
11282 // - No overflow will happen during first MaxIter iterations;
11283 // - It will not fail on the MaxIter'th iteration.
11284 // If the check does fail on the 1st iteration, we leave the loop and no
11285 // other checks matter.
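// For example, for an exit check {%start,+,1} u< %len taken at most MaxIter
// times, the code below verifies that the check still holds for the value on
// iteration MaxIter and that %start u<= that value, which rules out a wrap.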
11286
11287 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11288 if (!isLoopInvariant(RHS, L)) {
11289 if (!isLoopInvariant(LHS, L))
11290 return std::nullopt;
11291
11292 std::swap(LHS, RHS);
11293 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11294 }
11295
11296 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11297 if (!AR || AR->getLoop() != L)
11298 return std::nullopt;
11299
11300 // The predicate must be relational (i.e. <, <=, >=, >).
11301 if (!ICmpInst::isRelational(Pred))
11302 return std::nullopt;
11303
11304 // TODO: Support steps other than +/- 1.
11305 const SCEV *Step = AR->getStepRecurrence(*this);
11306 auto *One = getOne(Step->getType());
11307 auto *MinusOne = getNegativeSCEV(One);
11308 if (Step != One && Step != MinusOne)
11309 return std::nullopt;
11310
11311 // A type mismatch here means that MaxIter is potentially larger than the max
11312 // unsigned value in the start type, which means we cannot prove no-wrap for
11313 // the indvar.
11314 if (AR->getType() != MaxIter->getType())
11315 return std::nullopt;
11316
11317 // Value of IV on suggested last iteration.
11318 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11319 // Does it still meet the requirement?
11320 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11321 return std::nullopt;
11322 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11323 // not exceed max unsigned value of this type), this effectively proves
11324 // that there is no wrap during the iteration. To prove that there is no
11325 // signed/unsigned wrap, we need to check that
11326 // Start <= Last for step = 1 or Start >= Last for step = -1.
11327 ICmpInst::Predicate NoOverflowPred =
11328 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11329 if (Step == MinusOne)
11330 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11331 const SCEV *Start = AR->getStart();
11332 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11333 return std::nullopt;
11334
11335 // Everything is fine.
11336 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11337}
11338
11339bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11340 const SCEV *LHS,
11341 const SCEV *RHS) {
11342 if (HasSameValue(LHS, RHS))
11343 return ICmpInst::isTrueWhenEqual(Pred);
11344
11345 // This code is split out from isKnownPredicate because it is called from
11346 // within isLoopEntryGuardedByCond.
11347
11348 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11349 const ConstantRange &RangeRHS) {
11350 return RangeLHS.icmp(Pred, RangeRHS);
11351 };
11352
11353 // The check at the top of the function catches the case where the values are
11354 // known to be equal.
11355 if (Pred == CmpInst::ICMP_EQ)
11356 return false;
11357
11358 if (Pred == CmpInst::ICMP_NE) {
11359 auto SL = getSignedRange(LHS);
11360 auto SR = getSignedRange(RHS);
11361 if (CheckRanges(SL, SR))
11362 return true;
11363 auto UL = getUnsignedRange(LHS);
11364 auto UR = getUnsignedRange(RHS);
11365 if (CheckRanges(UL, UR))
11366 return true;
11367 auto *Diff = getMinusSCEV(LHS, RHS);
11368 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11369 }
11370
11371 if (CmpInst::isSigned(Pred)) {
11372 auto SL = getSignedRange(LHS);
11373 auto SR = getSignedRange(RHS);
11374 return CheckRanges(SL, SR);
11375 }
11376
11377 auto UL = getUnsignedRange(LHS);
11378 auto UR = getUnsignedRange(RHS);
11379 return CheckRanges(UL, UR);
11380}
11381
11382bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11383 const SCEV *LHS,
11384 const SCEV *RHS) {
11385 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11386 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11387 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11388 // OutC1 and OutC2.
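// For example, X = (%a + 3)<nsw> and Y = (%a + 7)<nsw> match with OutC1 = 3
// and OutC2 = 7; a plain %a is treated as (%a + 0).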
11389 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11390 APInt &OutC1, APInt &OutC2,
11391 SCEV::NoWrapFlags ExpectedFlags) {
11392 const SCEV *XNonConstOp, *XConstOp;
11393 const SCEV *YNonConstOp, *YConstOp;
11394 SCEV::NoWrapFlags XFlagsPresent;
11395 SCEV::NoWrapFlags YFlagsPresent;
11396
11397 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11398 XConstOp = getZero(X->getType());
11399 XNonConstOp = X;
11400 XFlagsPresent = ExpectedFlags;
11401 }
11402 if (!isa<SCEVConstant>(XConstOp) ||
11403 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11404 return false;
11405
11406 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11407 YConstOp = getZero(Y->getType());
11408 YNonConstOp = Y;
11409 YFlagsPresent = ExpectedFlags;
11410 }
11411
11412 if (!isa<SCEVConstant>(YConstOp) ||
11413 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11414 return false;
11415
11416 if (YNonConstOp != XNonConstOp)
11417 return false;
11418
11419 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11420 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11421
11422 return true;
11423 };
11424
11425 APInt C1;
11426 APInt C2;
11427
11428 switch (Pred) {
11429 default:
11430 break;
11431
11432 case ICmpInst::ICMP_SGE:
11433 std::swap(LHS, RHS);
11434 [[fallthrough]];
11435 case ICmpInst::ICMP_SLE:
11436 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11437 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11438 return true;
11439
11440 break;
11441
11442 case ICmpInst::ICMP_SGT:
11443 std::swap(LHS, RHS);
11444 [[fallthrough]];
11445 case ICmpInst::ICMP_SLT:
11446 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11447 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11448 return true;
11449
11450 break;
11451
11452 case ICmpInst::ICMP_UGE:
11453 std::swap(LHS, RHS);
11454 [[fallthrough]];
11455 case ICmpInst::ICMP_ULE:
11456 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11457 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11458 return true;
11459
11460 break;
11461
11462 case ICmpInst::ICMP_UGT:
11463 std::swap(LHS, RHS);
11464 [[fallthrough]];
11465 case ICmpInst::ICMP_ULT:
11466 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11467 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11468 return true;
11469 break;
11470 }
11471
11472 return false;
11473}
11474
11475bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11476 const SCEV *LHS,
11477 const SCEV *RHS) {
11478 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11479 return false;
11480
11481 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11482 // on the stack can result in exponential time complexity.
11483 SaveAndRestore Restore(ProvingSplitPredicate, true);
11484
11485 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11486 //
11487 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11488 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11489 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11490 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11491 // use isKnownPredicate later if needed.
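// For example, proving I u< L this way only needs L s>= 0, I s>= 0 and
// I s< L, which is exactly the conjunction checked below.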
11492 return isKnownNonNegative(RHS) &&
11495}
11496
11497bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11498 const SCEV *LHS, const SCEV *RHS) {
11499 // No need to even try if we know the module has no guards.
11500 if (!HasGuards)
11501 return false;
11502
11503 return any_of(*BB, [&](const Instruction &I) {
11504 using namespace llvm::PatternMatch;
11505
11506 Value *Condition;
11507 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11508 m_Value(Condition))) &&
11509 isImpliedCond(Pred, LHS, RHS, Condition, false);
11510 });
11511}
11512
11513/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11514 /// protected by a conditional between LHS and RHS. This is used to
11515 /// eliminate casts.
11516 bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11517 CmpPredicate Pred,
11518 const SCEV *LHS,
11519 const SCEV *RHS) {
11520 // Interpret a null as meaning no loop, where there is obviously no guard
11521 // (interprocedural conditions notwithstanding). Do not bother about
11522 // unreachable loops.
11523 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11524 return true;
11525
11526 if (VerifyIR)
11527 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11528 "This cannot be done on broken IR!");
11529
11530
11531 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11532 return true;
11533
11534 BasicBlock *Latch = L->getLoopLatch();
11535 if (!Latch)
11536 return false;
11537
11538 BranchInst *LoopContinuePredicate =
11539 dyn_cast<BranchInst>(Latch->getTerminator());
11540 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11541 isImpliedCond(Pred, LHS, RHS,
11542 LoopContinuePredicate->getCondition(),
11543 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11544 return true;
11545
11546 // We don't want more than one activation of the following loops on the stack
11547 // -- that can lead to O(n!) time complexity.
11548 if (WalkingBEDominatingConds)
11549 return false;
11550
11551 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11552
11553 // See if we can exploit a trip count to prove the predicate.
11554 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11555 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11556 if (LatchBECount != getCouldNotCompute()) {
11557 // We know that Latch branches back to the loop header exactly
11558 // LatchBECount times. This means the backedge condition at Latch is
11559 // equivalent to "{0,+,1} u< LatchBECount".
11560 Type *Ty = LatchBECount->getType();
11561 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11562 const SCEV *LoopCounter =
11563 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11564 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11565 LatchBECount))
11566 return true;
11567 }
11568
11569 // Check conditions due to any @llvm.assume intrinsics.
11570 for (auto &AssumeVH : AC.assumptions()) {
11571 if (!AssumeVH)
11572 continue;
11573 auto *CI = cast<CallInst>(AssumeVH);
11574 if (!DT.dominates(CI, Latch->getTerminator()))
11575 continue;
11576
11577 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11578 return true;
11579 }
11580
11581 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11582 return true;
11583
11584 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11585 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11586 assert(DTN && "should reach the loop header before reaching the root!");
11587
11588 BasicBlock *BB = DTN->getBlock();
11589 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11590 return true;
11591
11592 BasicBlock *PBB = BB->getSinglePredecessor();
11593 if (!PBB)
11594 continue;
11595
11596 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11597 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11598 continue;
11599
11600 Value *Condition = ContinuePredicate->getCondition();
11601
11602 // If we have an edge `E` within the loop body that dominates the only
11603 // latch, the condition guarding `E` also guards the backedge. This
11604 // reasoning works only for loops with a single latch.
11605
11606 BasicBlockEdge DominatingEdge(PBB, BB);
11607 if (DominatingEdge.isSingleEdge()) {
11608 // We're constructively (and conservatively) enumerating edges within the
11609 // loop body that dominate the latch. The dominator tree better agree
11610 // with us on this:
11611 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11612
11613 if (isImpliedCond(Pred, LHS, RHS, Condition,
11614 BB != ContinuePredicate->getSuccessor(0)))
11615 return true;
11616 }
11617 }
11618
11619 return false;
11620}
11621
11622 bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11623 CmpPredicate Pred,
11624 const SCEV *LHS,
11625 const SCEV *RHS) {
11626 // Do not bother proving facts for unreachable code.
11627 if (!DT.isReachableFromEntry(BB))
11628 return true;
11629 if (VerifyIR)
11630 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11631 "This cannot be done on broken IR!");
11632
11633 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11634 // the facts (a >= b && a != b) separately. A typical situation is when the
11635 // non-strict comparison is known from ranges and non-equality is known from
11636 // dominating predicates. If we are proving strict comparison, we always try
11637 // to prove non-equality and non-strict comparison separately.
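// For example, a u> b may not be provable directly, but a u>= b may follow
// from ranges while a != b follows from a dominating condition; together they
// imply the strict comparison.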
11638 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11639 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11640 bool ProvedNonStrictComparison = false;
11641 bool ProvedNonEquality = false;
11642
11643 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11644 if (!ProvedNonStrictComparison)
11645 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11646 if (!ProvedNonEquality)
11647 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11648 if (ProvedNonStrictComparison && ProvedNonEquality)
11649 return true;
11650 return false;
11651 };
11652
11653 if (ProvingStrictComparison) {
11654 auto ProofFn = [&](CmpPredicate P) {
11655 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11656 };
11657 if (SplitAndProve(ProofFn))
11658 return true;
11659 }
11660
11661 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11662 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11663 const Instruction *CtxI = &BB->front();
11664 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11665 return true;
11666 if (ProvingStrictComparison) {
11667 auto ProofFn = [&](CmpPredicate P) {
11668 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11669 };
11670 if (SplitAndProve(ProofFn))
11671 return true;
11672 }
11673 return false;
11674 };
11675
11676 // Starting at the block's predecessor, climb up the predecessor chain as long
11677 // as we can find predecessors that have a unique successor leading to the
11678 // original block.
11679 const Loop *ContainingLoop = LI.getLoopFor(BB);
11680 const BasicBlock *PredBB;
11681 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11682 PredBB = ContainingLoop->getLoopPredecessor();
11683 else
11684 PredBB = BB->getSinglePredecessor();
11685 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11686 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11687 const BranchInst *BlockEntryPredicate =
11688 dyn_cast<BranchInst>(Pair.first->getTerminator());
11689 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11690 continue;
11691
11692 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11693 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11694 return true;
11695 }
11696
11697 // Check conditions due to any @llvm.assume intrinsics.
11698 for (auto &AssumeVH : AC.assumptions()) {
11699 if (!AssumeVH)
11700 continue;
11701 auto *CI = cast<CallInst>(AssumeVH);
11702 if (!DT.dominates(CI, BB))
11703 continue;
11704
11705 if (ProveViaCond(CI->getArgOperand(0), false))
11706 return true;
11707 }
11708
11709 // Check conditions due to any @llvm.experimental.guard intrinsics.
11710 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11711 F.getParent(), Intrinsic::experimental_guard);
11712 if (GuardDecl)
11713 for (const auto *GU : GuardDecl->users())
11714 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11715 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11716 if (ProveViaCond(Guard->getArgOperand(0), false))
11717 return true;
11718 return false;
11719}
11720
11721 bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11722 const SCEV *LHS,
11723 const SCEV *RHS) {
11724 // Interpret a null as meaning no loop, where there is obviously no guard
11725 // (interprocedural conditions notwithstanding).
11726 if (!L)
11727 return false;
11728
11729 // Both LHS and RHS must be available at loop entry.
11730 assert(isAvailableAtLoopEntry(LHS, L) &&
11731 "LHS is not available at Loop Entry");
11732 assert(isAvailableAtLoopEntry(RHS, L) &&
11733 "RHS is not available at Loop Entry");
11734
11735 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11736 return true;
11737
11738 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11739}
11740
11741bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11742 const SCEV *RHS,
11743 const Value *FoundCondValue, bool Inverse,
11744 const Instruction *CtxI) {
11745 // A false condition implies anything. Do not bother analyzing it further.
11746 if (FoundCondValue ==
11747 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11748 return true;
11749
11750 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11751 return false;
11752
11753 auto ClearOnExit =
11754 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11755
11756 // Recursively handle And and Or conditions.
11757 const Value *Op0, *Op1;
11758 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11759 if (!Inverse)
11760 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11761 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11762 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11763 if (Inverse)
11764 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11765 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11766 }
11767
11768 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11769 if (!ICI) return false;
11770
11771 // We have found a conditional branch that dominates the loop or controls
11772 // the loop latch. Check to see if it is the comparison we are looking for.
11773 CmpPredicate FoundPred;
11774 if (Inverse)
11775 FoundPred = ICI->getInverseCmpPredicate();
11776 else
11777 FoundPred = ICI->getCmpPredicate();
11778
11779 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11780 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11781
11782 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11783}
11784
11785bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11786 const SCEV *RHS, CmpPredicate FoundPred,
11787 const SCEV *FoundLHS, const SCEV *FoundRHS,
11788 const Instruction *CtxI) {
11789 // Balance the types.
11790 if (getTypeSizeInBits(LHS->getType()) <
11791 getTypeSizeInBits(FoundLHS->getType())) {
11792 // For unsigned and equality predicates, try to prove that both found
11793 // operands fit into narrow unsigned range. If so, try to prove facts in
11794 // narrow types.
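// For example, if LHS/RHS are i32 while FoundLHS/FoundRHS are i64 values
// provably no larger than UINT32_MAX, the found condition can be truncated
// to i32 and used to reason about LHS and RHS directly.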
11795 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11796 !FoundRHS->getType()->isPointerTy()) {
11797 auto *NarrowType = LHS->getType();
11798 auto *WideType = FoundLHS->getType();
11799 auto BitWidth = getTypeSizeInBits(NarrowType);
11800 const SCEV *MaxValue = getZeroExtendExpr(
11801 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11802 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11803 MaxValue) &&
11804 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11805 MaxValue)) {
11806 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11807 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11808 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11809 TruncFoundRHS, CtxI))
11810 return true;
11811 }
11812 }
11813
11814 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11815 return false;
11816 if (CmpInst::isSigned(Pred)) {
11817 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11818 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11819 } else {
11820 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11821 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11822 }
11823 } else if (getTypeSizeInBits(LHS->getType()) >
11824 getTypeSizeInBits(FoundLHS->getType())) {
11825 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11826 return false;
11827 if (CmpInst::isSigned(FoundPred)) {
11828 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11829 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11830 } else {
11831 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11832 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11833 }
11834 }
11835 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11836 FoundRHS, CtxI);
11837}
11838
11839bool ScalarEvolution::isImpliedCondBalancedTypes(
11840 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11841 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11842 assert(getTypeSizeInBits(LHS->getType()) ==
11843 getTypeSizeInBits(FoundLHS->getType()) &&
11844 "Types should be balanced!");
11845 // Canonicalize the query to match the way instcombine will have
11846 // canonicalized the comparison.
11847 if (SimplifyICmpOperands(Pred, LHS, RHS))
11848 if (LHS == RHS)
11849 return CmpInst::isTrueWhenEqual(Pred);
11850 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11851 if (FoundLHS == FoundRHS)
11852 return CmpInst::isFalseWhenEqual(FoundPred);
11853
11854 // Check to see if we can make the LHS or RHS match.
11855 if (LHS == FoundRHS || RHS == FoundLHS) {
11856 if (isa<SCEVConstant>(RHS)) {
11857 std::swap(FoundLHS, FoundRHS);
11858 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
11859 } else {
11860 std::swap(LHS, RHS);
11861 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11862 }
11863 }
11864
11865 // Check whether the found predicate is the same as the desired predicate.
11866 // FIXME: use CmpPredicate::getMatching here.
11867 if (FoundPred == static_cast<CmpInst::Predicate>(Pred))
11868 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11869
11870 // Check whether swapping the found predicate makes it the same as the
11871 // desired predicate.
11872 // FIXME: use CmpPredicate::getMatching here.
11873 if (ICmpInst::getSwappedCmpPredicate(FoundPred) ==
11874 static_cast<CmpInst::Predicate>(Pred)) {
11875 // We can write the implication
11876 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11877 // using one of the following ways:
11878 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11879 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11880 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11881 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11882 // Forms 1. and 2. require swapping the operands of one condition. Don't
11883 // do this if it would break canonical constant/addrec ordering.
11884 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11885 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11886 CtxI);
11887 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11888 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11889
11890 // There's no clear preference between forms 3. and 4., try both. Avoid
11891 // forming getNotSCEV of pointer values as the resulting subtract is
11892 // not legal.
11893 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11894 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11895 FoundLHS, FoundRHS, CtxI))
11896 return true;
11897
11898 if (!FoundLHS->getType()->isPointerTy() &&
11899 !FoundRHS->getType()->isPointerTy() &&
11900 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11901 getNotSCEV(FoundRHS), CtxI))
11902 return true;
11903
11904 return false;
11905 }
11906
11907 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11908 CmpInst::Predicate P2) {
11909 assert(P1 != P2 && "Handled earlier!");
11910 return CmpInst::isRelational(P2) &&
11911 ICmpInst::getFlippedSignednessPredicate(P1) == P2;
11912 };
11913 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11914 // Unsigned comparison is the same as signed comparison when both the
11915 // operands are non-negative or negative.
11916 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11917 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11918 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11919 // Create local copies that we can freely swap and canonicalize our
11920 // conditions to "le/lt".
11921 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11922 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11923 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11924 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11925 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
11926 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
11927 std::swap(CanonicalLHS, CanonicalRHS);
11928 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11929 }
11930 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11931 "Must be!");
11932 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11933 ICmpInst::isLE(CanonicalFoundPred)) &&
11934 "Must be!");
11935 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11936 // Use implication:
11937 // x <u y && y >=s 0 --> x <s y.
11938 // If we can prove the left part, the right part is also proven.
11939 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11940 CanonicalRHS, CanonicalFoundLHS,
11941 CanonicalFoundRHS);
11942 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11943 // Use implication:
11944 // x <s y && y <s 0 --> x <u y.
11945 // If we can prove the left part, the right part is also proven.
11946 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11947 CanonicalRHS, CanonicalFoundLHS,
11948 CanonicalFoundRHS);
11949 }
11950
11951 // Check if we can make progress by sharpening ranges.
11952 if (FoundPred == ICmpInst::ICMP_NE &&
11953 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11954
11955 const SCEVConstant *C = nullptr;
11956 const SCEV *V = nullptr;
11957
11958 if (isa<SCEVConstant>(FoundLHS)) {
11959 C = cast<SCEVConstant>(FoundLHS);
11960 V = FoundRHS;
11961 } else {
11962 C = cast<SCEVConstant>(FoundRHS);
11963 V = FoundLHS;
11964 }
11965
11966 // The guarding predicate tells us that C != V. If the known range
11967 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11968 // range we consider has to correspond to same signedness as the
11969 // predicate we're interested in folding.
11970
11971 APInt Min = ICmpInst::isSigned(Pred) ?
11972 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11973
11974 if (Min == C->getAPInt()) {
11975 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11976 // This is true even if (Min + 1) wraps around -- in case of
11977 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
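// For example, if the guard gives V != 0, Pred is unsigned, and the unsigned
// minimum of V is 0, then for a u>= or u> query the sharper fact V u>= 1 can
// be used below.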
11978
11979 APInt SharperMin = Min + 1;
11980
11981 switch (Pred) {
11982 case ICmpInst::ICMP_SGE:
11983 case ICmpInst::ICMP_UGE:
11984 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11985 // RHS, we're done.
11986 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11987 CtxI))
11988 return true;
11989 [[fallthrough]];
11990
11991 case ICmpInst::ICMP_SGT:
11992 case ICmpInst::ICMP_UGT:
11993 // We know from the range information that (V `Pred` Min ||
11994 // V == Min). We know from the guarding condition that !(V
11995 // == Min). This gives us
11996 //
11997 // V `Pred` Min || V == Min && !(V == Min)
11998 // => V `Pred` Min
11999 //
12000 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12001
12002 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12003 return true;
12004 break;
12005
12006 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12007 case ICmpInst::ICMP_SLE:
12008 case ICmpInst::ICMP_ULE:
12009 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12010 LHS, V, getConstant(SharperMin), CtxI))
12011 return true;
12012 [[fallthrough]];
12013
12014 case ICmpInst::ICMP_SLT:
12015 case ICmpInst::ICMP_ULT:
12016 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12017 LHS, V, getConstant(Min), CtxI))
12018 return true;
12019 break;
12020
12021 default:
12022 // No change
12023 break;
12024 }
12025 }
12026 }
12027
12028 // Check whether the actual condition is beyond sufficient.
12029 if (FoundPred == ICmpInst::ICMP_EQ)
12030 if (ICmpInst::isTrueWhenEqual(Pred))
12031 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12032 return true;
12033 if (Pred == ICmpInst::ICMP_NE)
12034 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12035 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12036 return true;
12037
12038 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12039 return true;
12040
12041 // Otherwise assume the worst.
12042 return false;
12043}
12044
12045bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12046 const SCEV *&L, const SCEV *&R,
12047 SCEV::NoWrapFlags &Flags) {
12048 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12049 if (!AE || AE->getNumOperands() != 2)
12050 return false;
12051
12052 L = AE->getOperand(0);
12053 R = AE->getOperand(1);
12054 Flags = AE->getNoWrapFlags();
12055 return true;
12056}
12057
12058std::optional<APInt>
12060 // We avoid subtracting expressions here because this function is usually
12061 // fairly deep in the call stack (i.e. is called many times).
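// For example, computeConstantDifference((%x + 5), (%x + 2)) and
// computeConstantDifference({5,+,1}, {2,+,1}) (addrecs of the same loop)
// both return 3 without ever forming a subtraction expression.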
12062
12063 unsigned BW = getTypeSizeInBits(More->getType());
12064 APInt Diff(BW, 0);
12065 APInt DiffMul(BW, 1);
12066 // Try various simplifications to reduce the difference to a constant. Limit
12067 // the number of allowed simplifications to keep compile-time low.
12068 for (unsigned I = 0; I < 8; ++I) {
12069 if (More == Less)
12070 return Diff;
12071
12072 // Reduce addrecs with identical steps to their start value.
12073 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12074 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12075 const auto *MAR = cast<SCEVAddRecExpr>(More);
12076
12077 if (LAR->getLoop() != MAR->getLoop())
12078 return std::nullopt;
12079
12080 // We look at affine expressions only, not for correctness but to keep
12081 // getStepRecurrence cheap.
12082 if (!LAR->isAffine() || !MAR->isAffine())
12083 return std::nullopt;
12084
12085 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12086 return std::nullopt;
12087
12088 Less = LAR->getStart();
12089 More = MAR->getStart();
12090 continue;
12091 }
12092
12093 // Try to match a common constant multiply.
12094 auto MatchConstMul =
12095 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12096 auto *M = dyn_cast<SCEVMulExpr>(S);
12097 if (!M || M->getNumOperands() != 2 ||
12098 !isa<SCEVConstant>(M->getOperand(0)))
12099 return std::nullopt;
12100 return {
12101 {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
12102 };
12103 if (auto MatchedMore = MatchConstMul(More)) {
12104 if (auto MatchedLess = MatchConstMul(Less)) {
12105 if (MatchedMore->second == MatchedLess->second) {
12106 More = MatchedMore->first;
12107 Less = MatchedLess->first;
12108 DiffMul *= MatchedMore->second;
12109 continue;
12110 }
12111 }
12112 }
12113
12114 // Try to cancel out common factors in two add expressions.
12115 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12116 auto Add = [&](const SCEV *S, int Mul) {
12117 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12118 if (Mul == 1) {
12119 Diff += C->getAPInt() * DiffMul;
12120 } else {
12121 assert(Mul == -1);
12122 Diff -= C->getAPInt() * DiffMul;
12123 }
12124 } else
12125 Multiplicity[S] += Mul;
12126 };
12127 auto Decompose = [&](const SCEV *S, int Mul) {
12128 if (isa<SCEVAddExpr>(S)) {
12129 for (const SCEV *Op : S->operands())
12130 Add(Op, Mul);
12131 } else
12132 Add(S, Mul);
12133 };
12134 Decompose(More, 1);
12135 Decompose(Less, -1);
12136
12137 // Check whether all the non-constants cancel out, or reduce to new
12138 // More/Less values.
12139 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12140 for (const auto &[S, Mul] : Multiplicity) {
12141 if (Mul == 0)
12142 continue;
12143 if (Mul == 1) {
12144 if (NewMore)
12145 return std::nullopt;
12146 NewMore = S;
12147 } else if (Mul == -1) {
12148 if (NewLess)
12149 return std::nullopt;
12150 NewLess = S;
12151 } else
12152 return std::nullopt;
12153 }
12154
12155 // Values stayed the same, no point in trying further.
12156 if (NewMore == More || NewLess == Less)
12157 return std::nullopt;
12158
12159 More = NewMore;
12160 Less = NewLess;
12161
12162 // Reduced to constant.
12163 if (!More && !Less)
12164 return Diff;
12165
12166 // Left with variable on only one side, bail out.
12167 if (!More || !Less)
12168 return std::nullopt;
12169 }
12170
12171 // Did not reduce to constant.
12172 return std::nullopt;
12173}
12174
12175bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12176 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12177 const SCEV *FoundRHS, const Instruction *CtxI) {
12178 // Try to recognize the following pattern:
12179 //
12180 // FoundRHS = ...
12181 // ...
12182 // loop:
12183 // FoundLHS = {Start,+,W}
12184 // context_bb: // Basic block from the same loop
12185 // known(Pred, FoundLHS, FoundRHS)
12186 //
12187 // If some predicate is known in the context of a loop, it is also known on
12188 // each iteration of this loop, including the first iteration. Therefore, in
12189 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12190 // prove the original pred using this fact.
12191 if (!CtxI)
12192 return false;
12193 const BasicBlock *ContextBB = CtxI->getParent();
12194 // Make sure AR varies in the context block.
12195 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12196 const Loop *L = AR->getLoop();
12197 // Make sure that context belongs to the loop and executes on 1st iteration
12198 // (if it ever executes at all).
12199 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12200 return false;
12201 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12202 return false;
12203 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12204 }
12205
12206 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12207 const Loop *L = AR->getLoop();
12208 // Make sure that context belongs to the loop and executes on 1st iteration
12209 // (if it ever executes at all).
12210 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12211 return false;
12212 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12213 return false;
12214 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12215 }
12216
12217 return false;
12218}
12219
12220bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12221 const SCEV *LHS,
12222 const SCEV *RHS,
12223 const SCEV *FoundLHS,
12224 const SCEV *FoundRHS) {
12225 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12226 return false;
12227
12228 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12229 if (!AddRecLHS)
12230 return false;
12231
12232 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12233 if (!AddRecFoundLHS)
12234 return false;
12235
12236 // We'd like to let SCEV reason about control dependencies, so we constrain
12237 // both the inequalities to be about add recurrences on the same loop. This
12238 // way we can use isLoopEntryGuardedByCond later.
12239
12240 const Loop *L = AddRecFoundLHS->getLoop();
12241 if (L != AddRecLHS->getLoop())
12242 return false;
12243
12244 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12245 //
12246 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12247 // ... (2)
12248 //
12249 // Informal proof for (2), assuming (1) [*]:
12250 //
12251 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12252 //
12253 // Then
12254 //
12255 // FoundLHS s< FoundRHS s< INT_MIN - C
12256 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12257 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12258 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12259 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12260 // <=> FoundLHS + C s< FoundRHS + C
12261 //
12262 // [*]: (1) can be proved by ruling out overflow.
12263 //
12264 // [**]: This can be proved by analyzing all the four possibilities:
12265 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12266 // (A s>= 0, B s>= 0).
12267 //
12268 // Note:
12269 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12270 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12271 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12272 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12273 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12274 // C)".
12275
12276 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12277 if (!LDiff)
12278 return false;
12279 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12280 if (!RDiff || *LDiff != *RDiff)
12281 return false;
12282
12283 if (LDiff->isMinValue())
12284 return true;
12285
12286 APInt FoundRHSLimit;
12287
12288 if (Pred == CmpInst::ICMP_ULT) {
12289 FoundRHSLimit = -(*RDiff);
12290 } else {
12291 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12292 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12293 }
12294
12295 // Try to prove (1) or (2), as needed.
12296 return isAvailableAtLoopEntry(FoundRHS, L) &&
12297 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12298 getConstant(FoundRHSLimit));
12299}
12300
12301bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12302 const SCEV *RHS, const SCEV *FoundLHS,
12303 const SCEV *FoundRHS, unsigned Depth) {
12304 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12305
12306 auto ClearOnExit = make_scope_exit([&]() {
12307 if (LPhi) {
12308 bool Erased = PendingMerges.erase(LPhi);
12309 assert(Erased && "Failed to erase LPhi!");
12310 (void)Erased;
12311 }
12312 if (RPhi) {
12313 bool Erased = PendingMerges.erase(RPhi);
12314 assert(Erased && "Failed to erase RPhi!");
12315 (void)Erased;
12316 }
12317 });
12318
12319 // Find the respective Phis and check that they are not already pending.
12320 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12321 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12322 if (!PendingMerges.insert(Phi).second)
12323 return false;
12324 LPhi = Phi;
12325 }
12326 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12327 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12328 // If we detect a loop of Phi nodes being processed by this method, for
12329 // example:
12330 //
12331 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12332 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12333 //
12334 // we don't want to deal with a case that complex, so return conservative
12335 // answer false.
12336 if (!PendingMerges.insert(Phi).second)
12337 return false;
12338 RPhi = Phi;
12339 }
12340
12341 // If none of LHS, RHS is a Phi, nothing to do here.
12342 if (!LPhi && !RPhi)
12343 return false;
12344
12345 // If there is a SCEVUnknown Phi we are interested in, make it left.
12346 if (!LPhi) {
12347 std::swap(LHS, RHS);
12348 std::swap(FoundLHS, FoundRHS);
12349 std::swap(LPhi, RPhi);
12350 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12351 }
12352
12353 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12354 const BasicBlock *LBB = LPhi->getParent();
12355 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12356
12357 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12358 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12359 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12360 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12361 };
12362
12363 if (RPhi && RPhi->getParent() == LBB) {
12364 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12365 // If we compare two Phis from the same block, and for each entry block
12366 // the predicate is true for incoming values from this block, then the
12367 // predicate is also true for the Phis.
12368 for (const BasicBlock *IncBB : predecessors(LBB)) {
12369 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12370 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12371 if (!ProvedEasily(L, R))
12372 return false;
12373 }
12374 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12375 // Case two: RHS is also a Phi from the same basic block, and it is an
12376 // AddRec. It means that there is a loop which has both an AddRec and an
12377 // Unknown PHI; for it, we can compare the incoming values of the AddRec from
12378 // above the loop and from the latch with the respective incoming values of LPhi.
12379 // TODO: Generalize to handle loops with many inputs in a header.
12380 if (LPhi->getNumIncomingValues() != 2) return false;
12381
12382 auto *RLoop = RAR->getLoop();
12383 auto *Predecessor = RLoop->getLoopPredecessor();
12384 assert(Predecessor && "Loop with AddRec with no predecessor?");
12385 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12386 if (!ProvedEasily(L1, RAR->getStart()))
12387 return false;
12388 auto *Latch = RLoop->getLoopLatch();
12389 assert(Latch && "Loop with AddRec with no latch?");
12390 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12391 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12392 return false;
12393 } else {
12394 // In all other cases go over inputs of LHS and compare each of them to RHS,
12395 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12396 // At this point RHS is either a non-Phi, or it is a Phi from some block
12397 // different from LBB.
12398 for (const BasicBlock *IncBB : predecessors(LBB)) {
12399 // Check that RHS is available in this block.
12400 if (!dominates(RHS, IncBB))
12401 return false;
12402 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12403 // Make sure L does not refer to a value from a potentially previous
12404 // iteration of a loop.
12405 if (!properlyDominates(L, LBB))
12406 return false;
12407 if (!ProvedEasily(L, RHS))
12408 return false;
12409 }
12410 }
12411 return true;
12412}
12413
12414bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12415 const SCEV *LHS,
12416 const SCEV *RHS,
12417 const SCEV *FoundLHS,
12418 const SCEV *FoundRHS) {
12419 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12420 // sure that we are dealing with same LHS.
12421 if (RHS == FoundRHS) {
12422 std::swap(LHS, RHS);
12423 std::swap(FoundLHS, FoundRHS);
12424 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12425 }
12426 if (LHS != FoundLHS)
12427 return false;
12428
12429 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12430 if (!SUFoundRHS)
12431 return false;
12432
12433 Value *Shiftee, *ShiftValue;
12434
12435 using namespace PatternMatch;
12436 if (match(SUFoundRHS->getValue(),
12437 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12438 auto *ShifteeS = getSCEV(Shiftee);
12439 // Prove one of the following:
12440 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12441 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12442 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12443 // ---> LHS <s RHS
12444 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12445 // ---> LHS <=s RHS
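 // For instance, if FoundRHS is some value (%n >>u 2) and we know
 // LHS <u (%n >>u 2), then since (%n >>u 2) <=u %n, proving %n <=u RHS
 // is enough to conclude LHS <u RHS.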
12446 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12447 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12448 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12449 if (isKnownNonNegative(ShifteeS))
12450 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12451 }
12452
12453 return false;
12454}
12455
12456bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12457 const SCEV *RHS,
12458 const SCEV *FoundLHS,
12459 const SCEV *FoundRHS,
12460 const Instruction *CtxI) {
12461 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12462 return true;
12463
12464 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12465 return true;
12466
12467 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12468 return true;
12469
12470 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12471 CtxI))
12472 return true;
12473
12474 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12475 FoundLHS, FoundRHS);
12476}
12477
12478/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12479template <typename MinMaxExprType>
12480static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12481 const SCEV *Candidate) {
12482 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12483 if (!MinMaxExpr)
12484 return false;
12485
12486 return is_contained(MinMaxExpr->operands(), Candidate);
12487}
12488
12489static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12490 CmpPredicate Pred, const SCEV *LHS,
12491 const SCEV *RHS) {
12492 // If both sides are affine addrecs for the same loop, with equal
12493 // steps, and we know the recurrences don't wrap, then we only
12494 // need to check the predicate on the starting values.
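 // For example, {5,+,2}<nuw> u< {9,+,2}<nuw> for the same loop follows from
 // 5 u< 9: both recurrences advance by the same step of 2 every iteration
 // and neither wraps.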
12495
12496 if (!ICmpInst::isRelational(Pred))
12497 return false;
12498
12499 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12500 if (!LAR)
12501 return false;
12502 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12503 if (!RAR)
12504 return false;
12505 if (LAR->getLoop() != RAR->getLoop())
12506 return false;
12507 if (!LAR->isAffine() || !RAR->isAffine())
12508 return false;
12509
12510 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12511 return false;
12512
12513 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12514 SCEV::FlagNSW : SCEV::FlagNUW;
12515 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12516 return false;
12517
12518 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12519}
12520
12521/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12522/// expression?
12523static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12524 const SCEV *LHS, const SCEV *RHS) {
12525 switch (Pred) {
12526 default:
12527 return false;
12528
12529 case ICmpInst::ICMP_SGE:
12530 std::swap(LHS, RHS);
12531 [[fallthrough]];
12532 case ICmpInst::ICMP_SLE:
12533 return
12534 // min(A, ...) <= A
12535 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12536 // A <= max(A, ...)
12537 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12538
12539 case ICmpInst::ICMP_UGE:
12540 std::swap(LHS, RHS);
12541 [[fallthrough]];
12542 case ICmpInst::ICMP_ULE:
12543 return
12544 // min(A, ...) <= A
12545 // FIXME: what about umin_seq?
12546 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12547 // A <= max(A, ...)
12548 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12549 }
12550
12551 llvm_unreachable("covered switch fell through?!");
12552}
12553
12554bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12555 const SCEV *RHS,
12556 const SCEV *FoundLHS,
12557 const SCEV *FoundRHS,
12558 unsigned Depth) {
12559 assert(getTypeSizeInBits(LHS->getType()) ==
12560 getTypeSizeInBits(RHS->getType()) &&
12561 "LHS and RHS have different sizes?");
12562 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12563 getTypeSizeInBits(FoundRHS->getType()) &&
12564 "FoundLHS and FoundRHS have different sizes?");
12565 // We want to avoid hurting the compile time with analysis of too big trees.
12566 if (Depth > MaxSCEVOperationsImplicationDepth)
12567 return false;
12568
12569 // We only want to work with GT comparison so far.
12570 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12571 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12572 std::swap(LHS, RHS);
12573 std::swap(FoundLHS, FoundRHS);
12574 }
12575
12576 // For unsigned, try to reduce it to corresponding signed comparison.
12577 if (Pred == ICmpInst::ICMP_UGT)
12578 // We can replace unsigned predicate with its signed counterpart if all
12579 // involved values are non-negative.
12580 // TODO: We could have better support for unsigned.
12581 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12582 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12583 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12584 // use this fact to prove that LHS and RHS are non-negative.
12585 const SCEV *MinusOne = getMinusOne(LHS->getType());
12586 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12587 FoundRHS) &&
12588 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12589 FoundRHS))
12590 Pred = ICmpInst::ICMP_SGT;
12591 }
12592
12593 if (Pred != ICmpInst::ICMP_SGT)
12594 return false;
12595
12596 auto GetOpFromSExt = [&](const SCEV *S) {
12597 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12598 return Ext->getOperand();
12599 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12600 // the constant in some cases.
12601 return S;
12602 };
12603
12604 // Acquire values from extensions.
12605 auto *OrigLHS = LHS;
12606 auto *OrigFoundLHS = FoundLHS;
12607 LHS = GetOpFromSExt(LHS);
12608 FoundLHS = GetOpFromSExt(FoundLHS);
12609
12610 // Check whether the SGT predicate can be proved trivially or using the found context.
12611 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12612 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12613 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12614 FoundRHS, Depth + 1);
12615 };
12616
12617 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12618 // We want to avoid creation of any new non-constant SCEV. Since we are
12619 // going to compare the operands to RHS, we should be certain that we don't
12620 // need any size extensions for this. So let's decline all cases when the
12621 // sizes of types of LHS and RHS do not match.
12622 // TODO: Maybe try to get RHS from sext to catch more cases?
12623 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12624 return false;
12625
12626 // Should not overflow.
12627 if (!LHSAddExpr->hasNoSignedWrap())
12628 return false;
12629
12630 auto *LL = LHSAddExpr->getOperand(0);
12631 auto *LR = LHSAddExpr->getOperand(1);
12632 auto *MinusOne = getMinusOne(RHS->getType());
12633
12634 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12635 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12636 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12637 };
12638 // Try to prove the following rule:
12639 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12640 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12641 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12642 return true;
12643 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12644 Value *LL, *LR;
12645 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12646
12647 using namespace llvm::PatternMatch;
12648
12649 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12650 // Rules for division.
12651 // We are going to perform some comparisons with Denominator and its
12652 // derivative expressions. In the general case, creating a SCEV for it may
12653 // lead to a complex analysis of the entire graph, and in particular it
12654 // can request trip count recalculation for the same loop. That result would
12655 // be cached as SCEVCouldNotCompute to avoid infinite recursion. To avoid
12656 // this, we only want to create SCEVs that are constants in this section.
12657 // So we bail if Denominator is not a constant.
12658 if (!isa<ConstantInt>(LR))
12659 return false;
12660
12661 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12662
12663 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12664 // then a SCEV for the numerator already exists and matches with FoundLHS.
12665 auto *Numerator = getExistingSCEV(LL);
12666 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12667 return false;
12668
12669 // Make sure that the numerator matches with FoundLHS and the denominator
12670 // is positive.
12671 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12672 return false;
12673
12674 auto *DTy = Denominator->getType();
12675 auto *FRHSTy = FoundRHS->getType();
12676 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12677 // One of the types is a pointer and the other is not. We cannot extend
12678 // them properly to a wider type, so let us just reject this case.
12679 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12680 // to avoid this check.
12681 return false;
12682
12683 // Given that:
12684 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12685 auto *WTy = getWiderType(DTy, FRHSTy);
12686 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12687 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12688
12689 // Try to prove the following rule:
12690 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12691 // For example, if FoundRHS = 2, then Denominator < 4 and FoundLHS > 2, so
12692 // FoundLHS is at least 3; dividing it by a Denominator of at most 3 gives at least 1.
12693 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12694 if (isKnownNonPositive(RHS) &&
12695 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12696 return true;
12697
12698 // Try to prove the following rule:
12699 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12700 // For example, given FoundLHS > -3, FoundLHS is at least -2.
12701 // If we divide it by Denominator > 2, then:
12702 // 1. If FoundLHS is negative, then the result is 0.
12703 // 2. If FoundLHS is non-negative, then the result is non-negative.
12704 // Either way, the result is non-negative.
12705 auto *MinusOne = getMinusOne(WTy);
12706 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12707 if (isKnownNegative(RHS) &&
12708 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12709 return true;
12710 }
12711 }
12712
12713 // If our expression contained SCEVUnknown Phis, and we split it down and now
12714 // need to prove something for them, try to prove the predicate for all
12715 // possible incoming values of those Phis.
12716 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12717 return true;
12718
12719 return false;
12720}
12721
12722static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12723 const SCEV *RHS) {
12724 // zext x u<= sext x, sext x s<= zext x
12725 const SCEV *Op;
12726 switch (Pred) {
12727 case ICmpInst::ICMP_SGE:
12728 std::swap(LHS, RHS);
12729 [[fallthrough]];
12730 case ICmpInst::ICMP_SLE: {
12731 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12732 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12733 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12734 }
12735 case ICmpInst::ICMP_UGE:
12736 std::swap(LHS, RHS);
12737 [[fallthrough]];
12738 case ICmpInst::ICMP_ULE: {
12739 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12740 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12741 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12742 }
12743 default:
12744 return false;
12745 };
12746 llvm_unreachable("unhandled case");
12747}
12748
12749bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12750 const SCEV *LHS,
12751 const SCEV *RHS) {
12752 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12753 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12754 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12755 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12756 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12757}
12758
12759bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12760 const SCEV *LHS,
12761 const SCEV *RHS,
12762 const SCEV *FoundLHS,
12763 const SCEV *FoundRHS) {
12764 switch (Pred) {
12765 default:
12766 llvm_unreachable("Unexpected CmpPredicate value!");
12767 case ICmpInst::ICMP_EQ:
12768 case ICmpInst::ICMP_NE:
12769 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12770 return true;
12771 break;
12772 case ICmpInst::ICMP_SLT:
12773 case ICmpInst::ICMP_SLE:
12774 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12775 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12776 return true;
12777 break;
12778 case ICmpInst::ICMP_SGT:
12779 case ICmpInst::ICMP_SGE:
12780 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12781 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12782 return true;
12783 break;
12784 case ICmpInst::ICMP_ULT:
12785 case ICmpInst::ICMP_ULE:
12786 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12787 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12788 return true;
12789 break;
12790 case ICmpInst::ICMP_UGT:
12791 case ICmpInst::ICMP_UGE:
12792 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12793 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12794 return true;
12795 break;
12796 }
12797
12798 // Maybe it can be proved via operations?
12799 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12800 return true;
12801
12802 return false;
12803}
12804
12805bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12806 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12807 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12808 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12809 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12810 // reduce the compile time impact of this optimization.
12811 return false;
12812
12813 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12814 if (!Addend)
12815 return false;
12816
12817 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12818
12819 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12820 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12821 ConstantRange FoundLHSRange =
12822 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12823
12824 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12825 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
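 // For example, with illustrative constants: FoundLHS u< 8 gives
 // FoundLHSRange = [0, 8); with Addend = 2 this yields LHSRange = [2, 10),
 // which is enough to prove LHS u< 10 but not LHS u< 9.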
12826
12827 // We can also compute the range of values for `LHS` that satisfy the
12828 // consequent, "`LHS` `Pred` `RHS`":
12829 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12830 // The antecedent implies the consequent if every value of `LHS` that
12831 // satisfies the antecedent also satisfies the consequent.
12832 return LHSRange.icmp(Pred, ConstRHS);
12833}
12834
12835bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12836 bool IsSigned) {
12837 assert(isKnownPositive(Stride) && "Positive stride expected!");
12838
12839 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12840 const SCEV *One = getOne(Stride->getType());
12841
12842 if (IsSigned) {
12843 APInt MaxRHS = getSignedRangeMax(RHS);
12844 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12845 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12846
12847 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12848 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12849 }
12850
12851 APInt MaxRHS = getUnsignedRangeMax(RHS);
12852 APInt MaxValue = APInt::getMaxValue(BitWidth);
12853 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12854
12855 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
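 // For example, with 8-bit values: if MaxRHS = 250 and MaxStrideMinusOne = 9,
 // then 255 - 9 = 246 u< 250, so the IV may step past RHS and wrap.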
12856 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12857}
12858
12859bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12860 bool IsSigned) {
12861
12862 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12863 const SCEV *One = getOne(Stride->getType());
12864
12865 if (IsSigned) {
12866 APInt MinRHS = getSignedRangeMin(RHS);
12867 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12868 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12869
12870 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12871 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12872 }
12873
12874 APInt MinRHS = getUnsignedRangeMin(RHS);
12875 APInt MinValue = APInt::getMinValue(BitWidth);
12876 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12877
12878 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12879 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12880}
12881
12882const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12883 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12884 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12885 // expression fixes the case of N=0.
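 // For example, N = 7, D = 3: umin(7, 1) = 1 and floor((7 - 1) / 3) = 2,
 // giving 3 = ceil(7 / 3); for N = 0 both terms are 0, giving 0.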
12886 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12887 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12888 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12889}
12890
12891const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12892 const SCEV *Stride,
12893 const SCEV *End,
12894 unsigned BitWidth,
12895 bool IsSigned) {
12896 // The logic in this function assumes we can represent a positive stride.
12897 // If we can't, the backedge-taken count must be zero.
12898 if (IsSigned && BitWidth == 1)
12899 return getZero(Stride->getType());
12900
12901 // The code below has only been closely audited for negative strides in the
12902 // unsigned comparison case; it may be correct for signed comparison, but
12903 // that needs to be established.
12904 if (IsSigned && isKnownNegative(Stride))
12905 return getCouldNotCompute();
12906
12907 // Calculate the maximum backedge count based on the range of values
12908 // permitted by Start, End, and Stride.
12909 APInt MinStart =
12910 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12911
12912 APInt MinStride =
12913 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12914
12915 // We assume either the stride is positive, or the backedge-taken count
12916 // is zero. So force StrideForMaxBECount to be at least one.
12917 APInt One(BitWidth, 1);
12918 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12919 : APIntOps::umax(One, MinStride);
12920
12921 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12922 : APInt::getMaxValue(BitWidth);
12923 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12924
12925 // Although End can be a MAX expression we estimate MaxEnd considering only
12926 // the case End = RHS of the loop termination condition. This is safe because
12927 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12928 // taken count.
12929 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12930 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12931
12932 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12933 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12934 : APIntOps::umax(MaxEnd, MinStart);
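 // For example, with illustrative ranges MinStart = 0, MaxEnd = 100 and
 // StrideForMaxBECount = 7, this yields MaxBECount = ceil(100 / 7) = 15.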
12935
12936 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12937 getConstant(StrideForMaxBECount) /* Step */);
12938}
12939
12940ScalarEvolution::ExitLimit
12941ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12942 const Loop *L, bool IsSigned,
12943 bool ControlsOnlyExit, bool AllowPredicates) {
12944 SmallVector<const SCEVPredicate *> Predicates;
12945
12946 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12947 bool PredicatedIV = false;
12948 if (!IV) {
12949 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12950 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12951 if (AR && AR->getLoop() == L && AR->isAffine()) {
12952 auto canProveNUW = [&]() {
12953 // We can use the comparison to infer no-wrap flags only if it fully
12954 // controls the loop exit.
12955 if (!ControlsOnlyExit)
12956 return false;
12957
12958 if (!isLoopInvariant(RHS, L))
12959 return false;
12960
12961 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12962 // We need the sequence defined by AR to strictly increase in the
12963 // unsigned integer domain for the logic below to hold.
12964 return false;
12965
12966 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12967 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12968 // If RHS <=u Limit, then there must exist a value V in the sequence
12969 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12970 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12971 // overflow occurs. This limit also implies that a signed comparison
12972 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12973 // the high bits on both sides must be zero.
12974 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12975 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12976 Limit = Limit.zext(OuterBitWidth);
12977 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12978 };
12979 auto Flags = AR->getNoWrapFlags();
12980 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12981 Flags = setFlags(Flags, SCEV::FlagNUW);
12982
12983 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12984 if (AR->hasNoUnsignedWrap()) {
12985 // Emulate what getZeroExtendExpr would have done during construction
12986 // if we'd been able to infer the fact just above at that time.
12987 const SCEV *Step = AR->getStepRecurrence(*this);
12988 Type *Ty = ZExt->getType();
12989 auto *S = getAddRecExpr(
12990 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12991 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12992 IV = dyn_cast<SCEVAddRecExpr>(S);
12993 }
12994 }
12995 }
12996 }
12997
12998
12999 if (!IV && AllowPredicates) {
13000 // Try to make this an AddRec using runtime tests, in the first X
13001 // iterations of this loop, where X is the SCEV expression found by the
13002 // algorithm below.
13003 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13004 PredicatedIV = true;
13005 }
13006
13007 // Avoid weird loops
13008 if (!IV || IV->getLoop() != L || !IV->isAffine())
13009 return getCouldNotCompute();
13010
13011 // A precondition of this method is that the condition being analyzed
13012 // reaches an exiting branch which dominates the latch. Given that, we can
13013 // assume that an increment which violates the nowrap specification and
13014 // produces poison must cause undefined behavior when the resulting poison
13015 // value is branched upon and thus we can conclude that the backedge is
13016 // taken no more often than would be required to produce that poison value.
13017 // Note that a well defined loop can exit on the iteration which violates
13018 // the nowrap specification if there is another exit (either explicit or
13019 // implicit/exceptional) which causes the loop to execute before the
13020 // exiting instruction we're analyzing would trigger UB.
13021 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13022 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13023 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13024
13025 const SCEV *Stride = IV->getStepRecurrence(*this);
13026
13027 bool PositiveStride = isKnownPositive(Stride);
13028
13029 // Avoid negative or zero stride values.
13030 if (!PositiveStride) {
13031 // We can compute the correct backedge taken count for loops with unknown
13032 // strides if we can prove that the loop is not an infinite loop with side
13033 // effects. Here's the loop structure we are trying to handle -
13034 //
13035 // i = start
13036 // do {
13037 // A[i] = i;
13038 // i += s;
13039 // } while (i < end);
13040 //
13041 // The backedge taken count for such loops is evaluated as -
13042 // (max(end, start + stride) - start - 1) /u stride
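 // For example, with start = 0, stride = 3 and end = 10:
 // (max(10, 0 + 3) - 0 - 1) /u 3 = 9 /u 3 = 3; the body above runs for
 // i = 0, 3, 6, 9 and the backedge is taken three times.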
13043 //
13044 // The additional preconditions that we need to check to prove correctness
13045 // of the above formula are as follows -
13046 //
13047 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13048 // NoWrap flag).
13049 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13050 // no side effects within the loop)
13051 // c) loop has a single static exit (with no abnormal exits)
13052 //
13053 // Precondition a) implies that if the stride is negative, this is a single
13054 // trip loop. The backedge taken count formula reduces to zero in this case.
13055 //
13056 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13057 // then a zero stride means the backedge can't be taken without executing
13058 // undefined behavior.
13059 //
13060 // The positive stride case is the same as isKnownPositive(Stride) returning
13061 // true (original behavior of the function).
13062 //
13063 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13064 !loopHasNoAbnormalExits(L))
13065 return getCouldNotCompute();
13066
13067 if (!isKnownNonZero(Stride)) {
13068 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13069 // if it might eventually be greater than start and if so, on which
13070 // iteration. We can't even produce a useful upper bound.
13071 if (!isLoopInvariant(RHS, L))
13072 return getCouldNotCompute();
13073
13074 // We allow a potentially zero stride, but we need to divide by stride
13075 // below. Since the loop can't be infinite and this check must control
13076 // the sole exit, we can infer the exit must be taken on the first
13077 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13078 // we know the numerator in the divides below must be zero, so we can
13079 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13080 // and produce the right result.
13081 // FIXME: Handle the case where Stride is poison?
13082 auto wouldZeroStrideBeUB = [&]() {
13083 // Proof by contradiction. Suppose the stride were zero. If we can
13084 // prove that the backedge *is* taken on the first iteration, then since
13085 // we know this condition controls the sole exit, we must have an
13086 // infinite loop. We can't have a (well defined) infinite loop per
13087 // check just above.
13088 // Note: The (Start - Stride) term is used to get the start' term from
13089 // (start' + stride,+,stride). Remember that we only care about the
13090 // result of this expression when stride == 0 at runtime.
13091 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13092 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13093 };
13094 if (!wouldZeroStrideBeUB()) {
13095 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13096 }
13097 }
13098 } else if (!NoWrap) {
13099 // Avoid proven overflow cases: this will ensure that the backedge taken
13100 // count will not generate any unsigned overflow.
13101 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13102 return getCouldNotCompute();
13103 }
13104
13105 // On all paths just preceding, we established the following invariant:
13106 // IV can be assumed not to overflow up to and including the exiting
13107 // iteration. We proved this in one of two ways:
13108 // 1) We can show overflow doesn't occur before the exiting iteration
13109 // 1a) canIVOverflowOnLT, and b) step of one
13110 // 2) We can show that if overflow occurs, the loop must execute UB
13111 // before any possible exit.
13112 // Note that we have not yet proved RHS invariant (in general).
13113
13114 const SCEV *Start = IV->getStart();
13115
13116 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13117 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13118 // Use integer-typed versions for actual computation; we can't subtract
13119 // pointers in general.
13120 const SCEV *OrigStart = Start;
13121 const SCEV *OrigRHS = RHS;
13122 if (Start->getType()->isPointerTy()) {
13123 Start = getLosslessPtrToIntExpr(Start);
13124 if (isa<SCEVCouldNotCompute>(Start))
13125 return Start;
13126 }
13127 if (RHS->getType()->isPointerTy()) {
13128 RHS = getLosslessPtrToIntExpr(RHS);
13129 if (isa<SCEVCouldNotCompute>(RHS))
13130 return RHS;
13131 }
13132
13133 const SCEV *End = nullptr, *BECount = nullptr,
13134 *BECountIfBackedgeTaken = nullptr;
13135 if (!isLoopInvariant(RHS, L)) {
13136 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13137 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13138 RHSAddRec->getNoWrapFlags()) {
13139 // The structure of loop we are trying to calculate backedge count of:
13140 //
13141 // left = left_start
13142 // right = right_start
13143 //
13144 // while(left < right){
13145 // ... do something here ...
13146 // left += s1; // stride of left is s1 (s1 > 0)
13147 // right += s2; // stride of right is s2 (s2 < 0)
13148 // }
13149 //
13150
13151 const SCEV *RHSStart = RHSAddRec->getStart();
13152 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13153
13154 // If Stride - RHSStride is positive and does not overflow, we can write
13155 // backedge count as ->
13156 // ceil((End - Start) /u (Stride - RHSStride))
13157 // Where, End = max(RHSStart, Start)
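 // For example, Start = 0, Stride = 2, RHSStart = 10, RHSStride = -3:
 // the gap End - Start = 10 closes by Stride - RHSStride = 5 per iteration,
 // so the backedge is taken ceil(10 /u 5) = 2 times.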
13158
13159 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13160 if (isKnownNegative(RHSStride) &&
13161 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13162 RHSStride)) {
13163
13164 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13165 if (isKnownPositive(Denominator)) {
13166 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13167 : getUMaxExpr(RHSStart, Start);
13168
13169 // We can do this because End >= Start, as End = max(RHSStart, Start)
13170 const SCEV *Delta = getMinusSCEV(End, Start);
13171
13172 BECount = getUDivCeilSCEV(Delta, Denominator);
13173 BECountIfBackedgeTaken =
13174 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13175 }
13176 }
13177 }
13178 if (BECount == nullptr) {
13179 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13180 // given the start, stride and max value for the end bound of the
13181 // loop (RHS), and the fact that IV does not overflow (which is
13182 // checked above).
13183 const SCEV *MaxBECount = computeMaxBECountForLT(
13184 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13185 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13186 MaxBECount, false /*MaxOrZero*/, Predicates);
13187 }
13188 } else {
13189 // We use the expression (max(End,Start)-Start)/Stride to describe the
13190 // backedge count: if the backedge is taken at least once, then
13191 // max(End,Start) is End and the result is as above; if not,
13192 // max(End,Start) is Start, so we get a backedge count of zero.
13193 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13194 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13195 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13196 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13197 // Can we prove (max(RHS,Start) > Start - Stride)?
13198 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13199 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13200 // In this case, we can use a refined formula for computing backedge
13201 // taken count. The general formula remains:
13202 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13203 // We want to use the alternate formula:
13204 // "((End - 1) - (Start - Stride)) /u Stride"
13205 // Let's do a quick case analysis to show these are equivalent under
13206 // our precondition that max(RHS,Start) > Start - Stride.
13207 // * For RHS <= Start, the backedge-taken count must be zero.
13208 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13209 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13210 // "(Stride - 1) /u Stride" which is indeed zero for all non-zero values
13211 // of Stride. For 0 stride, we've used umin(1,Stride) above,
13212 // reducing this to the stride of 1 case.
13213 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13214 // Stride".
13215 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13216 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13217 // "((RHS - (Start - Stride) - 1) /u Stride".
13218 // Our preconditions trivially imply no overflow in that form.
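 // For example, Start = 0, Stride = 3, RHS = 10:
 // ((10 - 1) - (0 - 3)) /u 3 = 12 /u 3 = 4, matching
 // ceil((max(10, 0) - 0) / 3) = 4 from the general formula.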
13219 const SCEV *MinusOne = getMinusOne(Stride->getType());
13220 const SCEV *Numerator =
13221 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13222 BECount = getUDivExpr(Numerator, Stride);
13223 }
13224
13225 if (!BECount) {
13226 auto canProveRHSGreaterThanEqualStart = [&]() {
13227 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13228 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13229 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13230
13231 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13232 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13233 return true;
13234
13235 // (RHS > Start - 1) implies RHS >= Start.
13236 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13237 // "Start - 1" doesn't overflow.
13238 // * For signed comparison, if Start - 1 does overflow, it's equal
13239 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13240 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13241 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13242 //
13243 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13244 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13245 auto *StartMinusOne =
13246 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13247 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13248 };
13249
13250 // If we know that RHS >= Start in the context of loop, then we know
13251 // that max(RHS, Start) = RHS at this point.
13252 if (canProveRHSGreaterThanEqualStart()) {
13253 End = RHS;
13254 } else {
13255 // If RHS < Start, the backedge will be taken zero times. So in
13256 // general, we can write the backedge-taken count as:
13257 //
13258 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13259 //
13260 // We convert it to the following to make it more convenient for SCEV:
13261 //
13262 // ceil(max(RHS, Start) - Start) / Stride
13263 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13264
13265 // See what would happen if we assume the backedge is taken. This is
13266 // used to compute MaxBECount.
13267 BECountIfBackedgeTaken =
13268 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13269 }
13270
13271 // At this point, we know:
13272 //
13273 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13274 // 2. The index variable doesn't overflow.
13275 //
13276 // Therefore, we know N exists such that
13277 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13278 // doesn't overflow.
13279 //
13280 // Using this information, try to prove whether the addition in
13281 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13282 const SCEV *One = getOne(Stride->getType());
13283 bool MayAddOverflow = [&] {
13284 if (isKnownToBeAPowerOfTwo(Stride)) {
13285 // Suppose Stride is a power of two, and Start/End are unsigned
13286 // integers. Let UMAX be the largest representable unsigned
13287 // integer.
13288 //
13289 // By the preconditions of this function, we know
13290 // "(Start + Stride * N) >= End", and this doesn't overflow.
13291 // As a formula:
13292 //
13293 // End <= (Start + Stride * N) <= UMAX
13294 //
13295 // Subtracting Start from all the terms:
13296 //
13297 // End - Start <= Stride * N <= UMAX - Start
13298 //
13299 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13300 //
13301 // End - Start <= Stride * N <= UMAX
13302 //
13303 // Stride * N is a multiple of Stride. Therefore,
13304 //
13305 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13306 //
13307 // Since Stride is a power of two, UMAX + 1 is divisible by
13308 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13309 // write:
13310 //
13311 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13312 //
13313 // Dropping the middle term:
13314 //
13315 // End - Start <= UMAX - (Stride - 1)
13316 //
13317 // Adding Stride - 1 to both sides:
13318 //
13319 // (End - Start) + (Stride - 1) <= UMAX
13320 //
13321 // In other words, the addition doesn't have unsigned overflow.
13322 //
13323 // A similar proof works if we treat Start/End as signed values.
13324 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13325 // to use signed max instead of unsigned max. Note that we're
13326 // trying to prove a lack of unsigned overflow in either case.
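 // As an 8-bit illustration: UMAX = 255 and Stride = 8 give
 // UMAX mod Stride = 7 = Stride - 1, so End - Start <= 255 - 7 = 248,
 // and adding Stride - 1 = 7 stays within 255.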
13327 return false;
13328 }
13329 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13330 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13331 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13332 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13333 // 1 <s End.
13334 //
13335 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13336 // End.
13337 return false;
13338 }
13339 return true;
13340 }();
13341
13342 const SCEV *Delta = getMinusSCEV(End, Start);
13343 if (!MayAddOverflow) {
13344 // floor((D + (S - 1)) / S)
13345 // We prefer this formulation if it's legal because it's fewer
13346 // operations.
13347 BECount =
13348 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13349 } else {
13350 BECount = getUDivCeilSCEV(Delta, Stride);
13351 }
13352 }
13353 }
13354
13355 const SCEV *ConstantMaxBECount;
13356 bool MaxOrZero = false;
13357 if (isa<SCEVConstant>(BECount)) {
13358 ConstantMaxBECount = BECount;
13359 } else if (BECountIfBackedgeTaken &&
13360 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13361 // If we know exactly how many times the backedge will be taken if it's
13362 // taken at least once, then the backedge count will either be that or
13363 // zero.
13364 ConstantMaxBECount = BECountIfBackedgeTaken;
13365 MaxOrZero = true;
13366 } else {
13367 ConstantMaxBECount = computeMaxBECountForLT(
13368 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13369 }
13370
13371 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13372 !isa<SCEVCouldNotCompute>(BECount))
13373 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13374
13375 const SCEV *SymbolicMaxBECount =
13376 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13377 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13378 Predicates);
13379}
13380
13381ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13382 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13383 bool ControlsOnlyExit, bool AllowPredicates) {
13384 SmallVector<const SCEVPredicate *> Predicates;
13385 // We handle only IV > Invariant
13386 if (!isLoopInvariant(RHS, L))
13387 return getCouldNotCompute();
13388
13389 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13390 if (!IV && AllowPredicates)
13391 // Try to make this an AddRec using runtime tests, in the first X
13392 // iterations of this loop, where X is the SCEV expression found by the
13393 // algorithm below.
13394 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13395
13396 // Avoid weird loops
13397 if (!IV || IV->getLoop() != L || !IV->isAffine())
13398 return getCouldNotCompute();
13399
13400 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13401 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13402 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13403
13404 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13405
13406 // Avoid negative or zero stride values
13407 if (!isKnownPositive(Stride))
13408 return getCouldNotCompute();
13409
13410 // Avoid proven overflow cases: this will ensure that the backedge taken count
13411 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13412 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13413 // behaviors like the case of C language.
13414 if (!Stride->isOne() && !NoWrap)
13415 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13416 return getCouldNotCompute();
13417
13418 const SCEV *Start = IV->getStart();
13419 const SCEV *End = RHS;
13420 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13421 // If we know that Start >= RHS in the context of loop, then we know that
13422 // min(RHS, Start) = RHS at this point.
13423 if (isLoopEntryGuardedByCond(
13424 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13425 End = RHS;
13426 else
13427 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13428 }
13429
13430 if (Start->getType()->isPointerTy()) {
13431 Start = getLosslessPtrToIntExpr(Start);
13432 if (isa<SCEVCouldNotCompute>(Start))
13433 return Start;
13434 }
13435 if (End->getType()->isPointerTy()) {
13436 End = getLosslessPtrToIntExpr(End);
13437 if (isa<SCEVCouldNotCompute>(End))
13438 return End;
13439 }
13440
13441 // Compute ((Start - End) + (Stride - 1)) / Stride.
13442 // FIXME: This can overflow. Holding off on fixing this for now;
13443 // howManyGreaterThans will hopefully be gone soon.
13444 const SCEV *One = getOne(Stride->getType());
13445 const SCEV *BECount = getUDivExpr(
13446 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13447
13448 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13449 : getUnsignedRangeMax(Start);
13450
13451 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13452 : getUnsignedRangeMin(Stride);
13453
13454 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13455 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13456 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13457
13458 // Although End can be a MIN expression we estimate MinEnd considering only
13459 // the case End = RHS. This is safe because in the other case (Start - End)
13460 // is zero, leading to a zero maximum backedge taken count.
13461 APInt MinEnd =
13462 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13463 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13464
13465 const SCEV *ConstantMaxBECount =
13466 isa<SCEVConstant>(BECount)
13467 ? BECount
13468 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13469 getConstant(MinStride));
13470
13471 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13472 ConstantMaxBECount = BECount;
13473 const SCEV *SymbolicMaxBECount =
13474 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13475
13476 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13477 Predicates);
13478}
13479
13480const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13481 ScalarEvolution &SE) const {
13482 if (Range.isFullSet()) // Infinite loop.
13483 return SE.getCouldNotCompute();
13484
13485 // If the start is a non-zero constant, shift the range to simplify things.
13486 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13487 if (!SC->getValue()->isZero()) {
13488 SmallVector<const SCEV *, 4> Operands(operands());
13489 Operands[0] = SE.getZero(SC->getType());
13490 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13491 getNoWrapFlags(FlagNW));
13492 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13493 return ShiftedAddRec->getNumIterationsInRange(
13494 Range.subtract(SC->getAPInt()), SE);
13495 // This is strange and shouldn't happen.
13496 return SE.getCouldNotCompute();
13497 }
13498
13499 // The only time we can solve this is when we have all constant indices.
13500 // Otherwise, we cannot determine the overflow conditions.
13501 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13502 return SE.getCouldNotCompute();
13503
13504 // Okay at this point we know that all elements of the chrec are constants and
13505 // that the start element is zero.
13506
13507 // First check to see if the range contains zero. If not, the first
13508 // iteration exits.
13509 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13510 if (!Range.contains(APInt(BitWidth, 0)))
13511 return SE.getZero(getType());
13512
13513 if (isAffine()) {
13514 // If this is an affine expression then we have this situation:
13515 // Solve {0,+,A} in Range === Ax in Range
13516
13517 // We know that zero is in the range. If A is positive then we know that
13518 // the upper value of the range must be the first possible exit value.
13519 // If A is negative then the lower of the range is the last possible loop
13520 // value. Also note that we already checked for a full range.
13521 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13522 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13523
13524 // The exit value should be (End+A)/A.
13525 APInt ExitVal = (End + A).udiv(A);
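 // For example, Range = [0, 10) and A = 3: End = 9 and
 // ExitVal = (9 + 3) /u 3 = 4; the values 0, 3, 6, 9 stay in the range
 // and iteration 4 (value 12) is the first to leave it.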
13526 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13527
13528 // Evaluate at the exit value. If we really did fall out of the valid
13529 // range, then we computed our trip count, otherwise wrap around or other
13530 // things must have happened.
13531 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13532 if (Range.contains(Val->getValue()))
13533 return SE.getCouldNotCompute(); // Something strange happened
13534
13535 // Ensure that the previous value is in the range.
13536 assert(Range.contains(
13537 EvaluateConstantChrecAtConstant(this,
13538 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13539 "Linear scev computation is off in a bad way!");
13540 return SE.getConstant(ExitValue);
13541 }
13542
13543 if (isQuadratic()) {
13544 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13545 return SE.getConstant(*S);
13546 }
13547
13548 return SE.getCouldNotCompute();
13549}
13550
13551const SCEVAddRecExpr *
13552SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13553 assert(getNumOperands() > 1 && "AddRec with zero step?");
13554 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13555 // but in this case we cannot guarantee that the value returned will be an
13556 // AddRec because SCEV does not have a fixed point where it stops
13557 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13558 // may happen if we reach arithmetic depth limit while simplifying. So we
13559 // construct the returned value explicitly.
13560 SmallVector<const SCEV *, 3> Ops;
13561 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13562 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13563 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13564 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13565 // We know that the last operand is not a constant zero (otherwise it would
13566 // have been popped out earlier). This guarantees us that if the result has
13567 // the same last operand, then it will also not be popped out, meaning that
13568 // the returned value will be an AddRec.
13569 const SCEV *Last = getOperand(getNumOperands() - 1);
13570 assert(!Last->isZero() && "Recurrency with zero step?");
13571 Ops.push_back(Last);
13572 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13573 SCEV::FlagAnyWrap));
13574}
13575
13576// Return true when S contains at least an undef value.
13577bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13578 return SCEVExprContains(S, [](const SCEV *S) {
13579 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13580 return isa<UndefValue>(SU->getValue());
13581 return false;
13582 });
13583}
13584
13585// Return true when S contains a value that is a nullptr.
13586bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13587 return SCEVExprContains(S, [](const SCEV *S) {
13588 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13589 return SU->getValue() == nullptr;
13590 return false;
13591 });
13592}
13593
13594/// Return the size of an element read or written by Inst.
13595const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13596 Type *Ty;
13597 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13598 Ty = Store->getValueOperand()->getType();
13599 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13600 Ty = Load->getType();
13601 else
13602 return nullptr;
13603
13604 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13605 return getSizeOfExpr(ETy, Ty);
13606}
13607
13608//===----------------------------------------------------------------------===//
13609// SCEVCallbackVH Class Implementation
13610//===----------------------------------------------------------------------===//
13611
13612void ScalarEvolution::SCEVCallbackVH::deleted() {
13613 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13614 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13615 SE->ConstantEvolutionLoopExitValue.erase(PN);
13616 SE->eraseValueFromMap(getValPtr());
13617 // this now dangles!
13618}
13619
13620void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13621 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13622
13623 // Forget all the expressions associated with users of the old value,
13624 // so that future queries will recompute the expressions using the new
13625 // value.
13626 SE->forgetValue(getValPtr());
13627 // this now dangles!
13628}
13629
13630ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13631 : CallbackVH(V), SE(se) {}
13632
13633//===----------------------------------------------------------------------===//
13634// ScalarEvolution Class Implementation
13635//===----------------------------------------------------------------------===//
13636
13637ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13638 AssumptionCache &AC, DominatorTree &DT,
13639 LoopInfo &LI)
13640 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13641 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13642 LoopDispositions(64), BlockDispositions(64) {
13643 // To use guards for proving predicates, we need to scan every instruction in
13644 // relevant basic blocks, and not just terminators. Doing this is a waste of
13645 // time if the IR does not actually contain any calls to
13646 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13647 //
13648 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13649 // to _add_ guards to the module when there weren't any before, and wants
13650 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13651 // efficient in lieu of being smart in that rather obscure case.
13652
13653 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13654 F.getParent(), Intrinsic::experimental_guard);
13655 HasGuards = GuardDecl && !GuardDecl->use_empty();
13656}
13657
13658ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13659 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13660 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13661 ValueExprMap(std::move(Arg.ValueExprMap)),
13662 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13663 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13664 PendingMerges(std::move(Arg.PendingMerges)),
13665 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13666 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13667 PredicatedBackedgeTakenCounts(
13668 std::move(Arg.PredicatedBackedgeTakenCounts)),
13669 BECountUsers(std::move(Arg.BECountUsers)),
13670 ConstantEvolutionLoopExitValue(
13671 std::move(Arg.ConstantEvolutionLoopExitValue)),
13672 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13673 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13674 LoopDispositions(std::move(Arg.LoopDispositions)),
13675 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13676 BlockDispositions(std::move(Arg.BlockDispositions)),
13677 SCEVUsers(std::move(Arg.SCEVUsers)),
13678 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13679 SignedRanges(std::move(Arg.SignedRanges)),
13680 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13681 UniquePreds(std::move(Arg.UniquePreds)),
13682 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13683 LoopUsers(std::move(Arg.LoopUsers)),
13684 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13685 FirstUnknown(Arg.FirstUnknown) {
13686 Arg.FirstUnknown = nullptr;
13687}
13688
13689ScalarEvolution::~ScalarEvolution() {
13690 // Iterate through all the SCEVUnknown instances and call their
13691 // destructors, so that they release their references to their values.
13692 for (SCEVUnknown *U = FirstUnknown; U;) {
13693 SCEVUnknown *Tmp = U;
13694 U = U->Next;
13695 Tmp->~SCEVUnknown();
13696 }
13697 FirstUnknown = nullptr;
13698
13699 ExprValueMap.clear();
13700 ValueExprMap.clear();
13701 HasRecMap.clear();
13702 BackedgeTakenCounts.clear();
13703 PredicatedBackedgeTakenCounts.clear();
13704
13705 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13706 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13707 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13708 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13709 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13710}
13711
13712bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13713 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13714}
13715
13716/// When printing a top-level SCEV for trip counts, it's helpful to include
13717/// a type for constants which are otherwise hard to disambiguate.
13718static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13719 if (isa<SCEVConstant>(S))
13720 OS << *S->getType() << " ";
13721 OS << *S;
13722}
13723
13724static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13725 const Loop *L) {
13726 // Print all inner loops first
13727 for (Loop *I : *L)
13728 PrintLoopInfo(OS, SE, I);
13729
13730 OS << "Loop ";
13731 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13732 OS << ": ";
13733
13734 SmallVector<BasicBlock *, 8> ExitingBlocks;
13735 L->getExitingBlocks(ExitingBlocks);
13736 if (ExitingBlocks.size() != 1)
13737 OS << "<multiple exits> ";
13738
13739 auto *BTC = SE->getBackedgeTakenCount(L);
13740 if (!isa<SCEVCouldNotCompute>(BTC)) {
13741 OS << "backedge-taken count is ";
13742 PrintSCEVWithTypeHint(OS, BTC);
13743 } else
13744 OS << "Unpredictable backedge-taken count.";
13745 OS << "\n";
13746
13747 if (ExitingBlocks.size() > 1)
13748 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13749 OS << " exit count for " << ExitingBlock->getName() << ": ";
13750 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13751 PrintSCEVWithTypeHint(OS, EC);
13752 if (isa<SCEVCouldNotCompute>(EC)) {
13753 // Retry with predicates.
13754 SmallVector<const SCEVPredicate *> Predicates;
13755 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13756 if (!isa<SCEVCouldNotCompute>(EC)) {
13757 OS << "\n predicated exit count for " << ExitingBlock->getName()
13758 << ": ";
13759 PrintSCEVWithTypeHint(OS, EC);
13760 OS << "\n Predicates:\n";
13761 for (const auto *P : Predicates)
13762 P->print(OS, 4);
13763 }
13764 }
13765 OS << "\n";
13766 }
13767
13768 OS << "Loop ";
13769 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13770 OS << ": ";
13771
13772 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13773 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13774 OS << "constant max backedge-taken count is ";
13775 PrintSCEVWithTypeHint(OS, ConstantBTC);
13776 if (SE->isBackedgeTakenCountMaxOrZero(L))
13777 OS << ", actual taken count either this or zero.";
13778 } else {
13779 OS << "Unpredictable constant max backedge-taken count. ";
13780 }
13781
13782 OS << "\n"
13783 "Loop ";
13784 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13785 OS << ": ";
13786
13787 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13788 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13789 OS << "symbolic max backedge-taken count is ";
13790 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13791 if (SE->isBackedgeTakenCountMaxOrZero(L))
13792 OS << ", actual taken count either this or zero.";
13793 } else {
13794 OS << "Unpredictable symbolic max backedge-taken count. ";
13795 }
13796 OS << "\n";
13797
13798 if (ExitingBlocks.size() > 1)
13799 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13800 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13801 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13802 ScalarEvolution::SymbolicMaximum);
13803 PrintSCEVWithTypeHint(OS, ExitBTC);
13804 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13805 // Retry with predicates.
13806 SmallVector<const SCEVPredicate *> Predicates;
13807 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13808 ScalarEvolution::SymbolicMaximum);
13809 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13810 OS << "\n predicated symbolic max exit count for "
13811 << ExitingBlock->getName() << ": ";
13812 PrintSCEVWithTypeHint(OS, ExitBTC);
13813 OS << "\n Predicates:\n";
13814 for (const auto *P : Predicates)
13815 P->print(OS, 4);
13816 }
13817 }
13818 OS << "\n";
13819 }
13820
13821 SmallVector<const SCEVPredicate *, 4> Preds;
13822 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13823 if (PBT != BTC) {
13824 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13825 OS << "Loop ";
13826 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13827 OS << ": ";
13828 if (!isa<SCEVCouldNotCompute>(PBT)) {
13829 OS << "Predicated backedge-taken count is ";
13830 PrintSCEVWithTypeHint(OS, PBT);
13831 } else
13832 OS << "Unpredictable predicated backedge-taken count.";
13833 OS << "\n";
13834 OS << " Predicates:\n";
13835 for (const auto *P : Preds)
13836 P->print(OS, 4);
13837 }
13838 Preds.clear();
13839
13840 auto *PredConstantMax =
13841 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13842 if (PredConstantMax != ConstantBTC) {
13843 assert(!Preds.empty() &&
13844 "different predicated constant max BTC but no predicates");
13845 OS << "Loop ";
13846 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13847 OS << ": ";
13848 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13849 OS << "Predicated constant max backedge-taken count is ";
13850 PrintSCEVWithTypeHint(OS, PredConstantMax);
13851 } else
13852 OS << "Unpredictable predicated constant max backedge-taken count.";
13853 OS << "\n";
13854 OS << " Predicates:\n";
13855 for (const auto *P : Preds)
13856 P->print(OS, 4);
13857 }
13858 Preds.clear();
13859
13860 auto *PredSymbolicMax =
13861 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13862 if (SymbolicBTC != PredSymbolicMax) {
13863 assert(!Preds.empty() &&
13864 "Different predicated symbolic max BTC, but no predicates");
13865 OS << "Loop ";
13866 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13867 OS << ": ";
13868 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13869 OS << "Predicated symbolic max backedge-taken count is ";
13870 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
13871 } else
13872 OS << "Unpredictable predicated symbolic max backedge-taken count.";
13873 OS << "\n";
13874 OS << " Predicates:\n";
13875 for (const auto *P : Preds)
13876 P->print(OS, 4);
13877 }
13878
13879 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13880 OS << "Loop ";
13881 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13882 OS << ": ";
13883 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13884 }
13885}
13886
13887namespace llvm {
13888raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13889 switch (LD) {
13890 case ScalarEvolution::LoopVariant:
13891 OS << "Variant";
13892 break;
13893 case ScalarEvolution::LoopInvariant:
13894 OS << "Invariant";
13895 break;
13896 case ScalarEvolution::LoopComputable:
13897 OS << "Computable";
13898 break;
13899 }
13900 return OS;
13901}
13902
13903raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13904 switch (BD) {
13905 case ScalarEvolution::DoesNotDominateBlock:
13906 OS << "DoesNotDominate";
13907 break;
13908 case ScalarEvolution::DominatesBlock:
13909 OS << "Dominates";
13910 break;
13911 case ScalarEvolution::ProperlyDominatesBlock:
13912 OS << "ProperlyDominates";
13913 break;
13914 }
13915 return OS;
13916}
13917} // namespace llvm
13918
13919void ScalarEvolution::print(raw_ostream &OS) const {
13920 // ScalarEvolution's implementation of the print method is to print
13921 // out SCEV values of all instructions that are interesting. Doing
13922 // this potentially causes it to create new SCEV objects though,
13923 // which technically conflicts with the const qualifier. This isn't
13924 // observable from outside the class though, so casting away the
13925 // const isn't dangerous.
13926 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13927
13928 if (ClassifyExpressions) {
13929 OS << "Classifying expressions for: ";
13930 F.printAsOperand(OS, /*PrintType=*/false);
13931 OS << "\n";
13932 for (Instruction &I : instructions(F))
13933 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13934 OS << I << '\n';
13935 OS << " --> ";
13936 const SCEV *SV = SE.getSCEV(&I);
13937 SV->print(OS);
13938 if (!isa<SCEVCouldNotCompute>(SV)) {
13939 OS << " U: ";
13940 SE.getUnsignedRange(SV).print(OS);
13941 OS << " S: ";
13942 SE.getSignedRange(SV).print(OS);
13943 }
13944
13945 const Loop *L = LI.getLoopFor(I.getParent());
13946
13947 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13948 if (AtUse != SV) {
13949 OS << " --> ";
13950 AtUse->print(OS);
13951 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13952 OS << " U: ";
13953 SE.getUnsignedRange(AtUse).print(OS);
13954 OS << " S: ";
13955 SE.getSignedRange(AtUse).print(OS);
13956 }
13957 }
13958
13959 if (L) {
13960 OS << "\t\t" "Exits: ";
13961 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13962 if (!SE.isLoopInvariant(ExitValue, L)) {
13963 OS << "<<Unknown>>";
13964 } else {
13965 OS << *ExitValue;
13966 }
13967
13968 bool First = true;
13969 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13970 if (First) {
13971 OS << "\t\t" "LoopDispositions: { ";
13972 First = false;
13973 } else {
13974 OS << ", ";
13975 }
13976
13977 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13978 OS << ": " << SE.getLoopDisposition(SV, Iter);
13979 }
13980
13981 for (const auto *InnerL : depth_first(L)) {
13982 if (InnerL == L)
13983 continue;
13984 if (First) {
13985 OS << "\t\t" "LoopDispositions: { ";
13986 First = false;
13987 } else {
13988 OS << ", ";
13989 }
13990
13991 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13992 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13993 }
13994
13995 OS << " }";
13996 }
13997
13998 OS << "\n";
13999 }
14000 }
14001
14002 OS << "Determining loop execution counts for: ";
14003 F.printAsOperand(OS, /*PrintType=*/false);
14004 OS << "\n";
14005 for (Loop *I : LI)
14006 PrintLoopInfo(OS, &SE, I);
14007}
14008
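// Note on the caching pattern in getLoopDisposition below (explanatory only):
// the entry is first recorded with a conservative LoopVariant answer, and the
// bucket is looked up again after computeLoopDisposition returns because the
// recursive queries it issues may grow LoopDispositions and invalidate the
// reference obtained before the call.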
14009ScalarEvolution::LoopDisposition
14010ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14011 auto &Values = LoopDispositions[S];
14012 for (auto &V : Values) {
14013 if (V.getPointer() == L)
14014 return V.getInt();
14015 }
14016 Values.emplace_back(L, LoopVariant);
14017 LoopDisposition D = computeLoopDisposition(S, L);
14018 auto &Values2 = LoopDispositions[S];
14019 for (auto &V : llvm::reverse(Values2)) {
14020 if (V.getPointer() == L) {
14021 V.setInt(D);
14022 break;
14023 }
14024 }
14025 return D;
14026}
14027
14028ScalarEvolution::LoopDisposition
14029ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14030 switch (S->getSCEVType()) {
14031 case scConstant:
14032 case scVScale:
14033 return LoopInvariant;
14034 case scAddRecExpr: {
14035 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14036
14037 // If L is the addrec's loop, it's computable.
14038 if (AR->getLoop() == L)
14039 return LoopComputable;
14040
14041 // Add recurrences are never invariant in the function-body (null loop).
14042 if (!L)
14043 return LoopVariant;
14044
14045 // Everything that is not defined at loop entry is variant.
14046 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14047 return LoopVariant;
14048 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14049 " dominate the contained loop's header?");
14050
14051 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14052 if (AR->getLoop()->contains(L))
14053 return LoopInvariant;
14054
14055 // This recurrence is variant w.r.t. L if any of its operands
14056 // are variant.
14057 for (const auto *Op : AR->operands())
14058 if (!isLoopInvariant(Op, L))
14059 return LoopVariant;
14060
14061 // Otherwise it's loop-invariant.
14062 return LoopInvariant;
14063 }
14064 case scTruncate:
14065 case scZeroExtend:
14066 case scSignExtend:
14067 case scPtrToInt:
14068 case scAddExpr:
14069 case scMulExpr:
14070 case scUDivExpr:
14071 case scUMaxExpr:
14072 case scSMaxExpr:
14073 case scUMinExpr:
14074 case scSMinExpr:
14075 case scSequentialUMinExpr: {
14076 bool HasVarying = false;
14077 for (const auto *Op : S->operands()) {
14078 LoopDisposition D = getLoopDisposition(Op, L);
14079 if (D == LoopVariant)
14080 return LoopVariant;
14081 if (D == LoopComputable)
14082 HasVarying = true;
14083 }
14084 return HasVarying ? LoopComputable : LoopInvariant;
14085 }
14086 case scUnknown:
14087 // All non-instruction values are loop invariant. All instructions are loop
14088 // invariant if they are not contained in the specified loop.
14089 // Instructions are never considered invariant in the function body
14090 // (null loop) because they are defined within the "loop".
14091 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14092 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14093 return LoopInvariant;
14094 case scCouldNotCompute:
14095 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14096 }
14097 llvm_unreachable("Unknown SCEV kind!");
14098}
14099
14100bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14101 return getLoopDisposition(S, L) == LoopInvariant;
14102}
14103
14104bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14105 return getLoopDisposition(S, L) == LoopComputable;
14106}
14107
14108ScalarEvolution::BlockDisposition
14109ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14110 auto &Values = BlockDispositions[S];
14111 for (auto &V : Values) {
14112 if (V.getPointer() == BB)
14113 return V.getInt();
14114 }
14115 Values.emplace_back(BB, DoesNotDominateBlock);
14116 BlockDisposition D = computeBlockDisposition(S, BB);
14117 auto &Values2 = BlockDispositions[S];
14118 for (auto &V : llvm::reverse(Values2)) {
14119 if (V.getPointer() == BB) {
14120 V.setInt(D);
14121 break;
14122 }
14123 }
14124 return D;
14125}
14126
14127ScalarEvolution::BlockDisposition
14128ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14129 switch (S->getSCEVType()) {
14130 case scConstant:
14131 case scVScale:
14132 return ProperlyDominatesBlock;
14133 case scAddRecExpr: {
14134 // This uses a "dominates" query instead of "properly dominates" query
14135 // to test for proper dominance too, because the instruction which
14136 // produces the addrec's value is a PHI, and a PHI effectively properly
14137 // dominates its entire containing block.
14138 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14139 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14140 return DoesNotDominateBlock;
14141
14142 // Fall through into SCEVNAryExpr handling.
14143 [[fallthrough]];
14144 }
14145 case scTruncate:
14146 case scZeroExtend:
14147 case scSignExtend:
14148 case scPtrToInt:
14149 case scAddExpr:
14150 case scMulExpr:
14151 case scUDivExpr:
14152 case scUMaxExpr:
14153 case scSMaxExpr:
14154 case scUMinExpr:
14155 case scSMinExpr:
14156 case scSequentialUMinExpr: {
14157 bool Proper = true;
14158 for (const SCEV *NAryOp : S->operands()) {
14159 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14160 if (D == DoesNotDominateBlock)
14161 return DoesNotDominateBlock;
14162 if (D == DominatesBlock)
14163 Proper = false;
14164 }
14165 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14166 }
14167 case scUnknown:
14168 if (Instruction *I =
14169 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14170 if (I->getParent() == BB)
14171 return DominatesBlock;
14172 if (DT.properlyDominates(I->getParent(), BB))
14173 return ProperlyDominatesBlock;
14174 return DoesNotDominateBlock;
14175 }
14176 return ProperlyDominatesBlock;
14177 case scCouldNotCompute:
14178 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14179 }
14180 llvm_unreachable("Unknown SCEV kind!");
14181}
14182
14183bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14184 return getBlockDisposition(S, BB) >= DominatesBlock;
14185}
14186
14187bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14188 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14189}
14190
14191bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14192 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14193}
14194
14195void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14196 bool Predicated) {
14197 auto &BECounts =
14198 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14199 auto It = BECounts.find(L);
14200 if (It != BECounts.end()) {
14201 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14202 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14203 if (!isa<SCEVConstant>(S)) {
14204 auto UserIt = BECountUsers.find(S);
14205 assert(UserIt != BECountUsers.end());
14206 UserIt->second.erase({L, Predicated});
14207 }
14208 }
14209 }
14210 BECounts.erase(It);
14211 }
14212}
14213
14214void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14215 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
14216 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14217
14218 while (!Worklist.empty()) {
14219 const SCEV *Curr = Worklist.pop_back_val();
14220 auto Users = SCEVUsers.find(Curr);
14221 if (Users != SCEVUsers.end())
14222 for (const auto *User : Users->second)
14223 if (ToForget.insert(User).second)
14224 Worklist.push_back(User);
14225 }
14226
14227 for (const auto *S : ToForget)
14228 forgetMemoizedResultsImpl(S);
14229
14230 for (auto I = PredicatedSCEVRewrites.begin();
14231 I != PredicatedSCEVRewrites.end();) {
14232 std::pair<const SCEV *, const Loop *> Entry = I->first;
14233 if (ToForget.count(Entry.first))
14234 PredicatedSCEVRewrites.erase(I++);
14235 else
14236 ++I;
14237 }
14238}
14239
14240void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14241 LoopDispositions.erase(S);
14242 BlockDispositions.erase(S);
14243 UnsignedRanges.erase(S);
14244 SignedRanges.erase(S);
14245 HasRecMap.erase(S);
14246 ConstantMultipleCache.erase(S);
14247
14248 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14249 UnsignedWrapViaInductionTried.erase(AR);
14250 SignedWrapViaInductionTried.erase(AR);
14251 }
14252
14253 auto ExprIt = ExprValueMap.find(S);
14254 if (ExprIt != ExprValueMap.end()) {
14255 for (Value *V : ExprIt->second) {
14256 auto ValueIt = ValueExprMap.find_as(V);
14257 if (ValueIt != ValueExprMap.end())
14258 ValueExprMap.erase(ValueIt);
14259 }
14260 ExprValueMap.erase(ExprIt);
14261 }
14262
14263 auto ScopeIt = ValuesAtScopes.find(S);
14264 if (ScopeIt != ValuesAtScopes.end()) {
14265 for (const auto &Pair : ScopeIt->second)
14266 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14267 llvm::erase(ValuesAtScopesUsers[Pair.second],
14268 std::make_pair(Pair.first, S));
14269 ValuesAtScopes.erase(ScopeIt);
14270 }
14271
14272 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14273 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14274 for (const auto &Pair : ScopeUserIt->second)
14275 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14276 ValuesAtScopesUsers.erase(ScopeUserIt);
14277 }
14278
14279 auto BEUsersIt = BECountUsers.find(S);
14280 if (BEUsersIt != BECountUsers.end()) {
14281 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14282 auto Copy = BEUsersIt->second;
14283 for (const auto &Pair : Copy)
14284 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14285 BECountUsers.erase(BEUsersIt);
14286 }
14287
14288 auto FoldUser = FoldCacheUser.find(S);
14289 if (FoldUser != FoldCacheUser.end())
14290 for (auto &KV : FoldUser->second)
14291 FoldCache.erase(KV);
14292 FoldCacheUser.erase(S);
14293}
14294
14295void
14296ScalarEvolution::getUsedLoops(const SCEV *S,
14297 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14298 struct FindUsedLoops {
14299 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14300 : LoopsUsed(LoopsUsed) {}
14301 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14302 bool follow(const SCEV *S) {
14303 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14304 LoopsUsed.insert(AR->getLoop());
14305 return true;
14306 }
14307
14308 bool isDone() const { return false; }
14309 };
14310
14311 FindUsedLoops F(LoopsUsed);
14312 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14313}
14314
14315void ScalarEvolution::getReachableBlocks(
14316 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14317 SmallVector<BasicBlock *> Worklist;
14318 Worklist.push_back(&F.getEntryBlock());
14319 while (!Worklist.empty()) {
14320 BasicBlock *BB = Worklist.pop_back_val();
14321 if (!Reachable.insert(BB).second)
14322 continue;
14323
14324 Value *Cond;
14325 BasicBlock *TrueBB, *FalseBB;
14326 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14327 m_BasicBlock(FalseBB)))) {
14328 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14329 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14330 continue;
14331 }
14332
14333 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14334 const SCEV *L = getSCEV(Cmp->getOperand(0));
14335 const SCEV *R = getSCEV(Cmp->getOperand(1));
14336 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14337 Worklist.push_back(TrueBB);
14338 continue;
14339 }
14340 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14341 R)) {
14342 Worklist.push_back(FalseBB);
14343 continue;
14344 }
14345 }
14346 }
14347
14348 append_range(Worklist, successors(BB));
14349 }
14350}
14351
14352void ScalarEvolution::verify() const {
14353 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14354 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14355
14356 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14357
14358 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14359 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14360 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14361
14362 const SCEV *visitConstant(const SCEVConstant *Constant) {
14363 return SE.getConstant(Constant->getAPInt());
14364 }
14365
14366 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14367 return SE.getUnknown(Expr->getValue());
14368 }
14369
14370 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14371 return SE.getCouldNotCompute();
14372 }
14373 };
14374
14375 SCEVMapper SCM(SE2);
14376 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14377 SE2.getReachableBlocks(ReachableBlocks, F);
14378
14379 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14380 if (containsUndefs(Old) || containsUndefs(New)) {
14381 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14382 // not propagate undef aggressively). This means we can (and do) fail
14383 // verification in cases where a transform makes a value go from "undef"
14384 // to "undef+1" (say). The transform is fine, since in both cases the
14385 // result is "undef", but SCEV thinks the value increased by 1.
14386 return nullptr;
14387 }
14388
14389 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14390 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14391 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14392 return nullptr;
14393
14394 return Delta;
14395 };
14396
14397 while (!LoopStack.empty()) {
14398 auto *L = LoopStack.pop_back_val();
14399 llvm::append_range(LoopStack, *L);
14400
14401 // Only verify BECounts in reachable loops. For an unreachable loop,
14402 // any BECount is legal.
14403 if (!ReachableBlocks.contains(L->getHeader()))
14404 continue;
14405
14406 // Only verify cached BECounts. Computing new BECounts may change the
14407 // results of subsequent SCEV uses.
14408 auto It = BackedgeTakenCounts.find(L);
14409 if (It == BackedgeTakenCounts.end())
14410 continue;
14411
14412 auto *CurBECount =
14413 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14414 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14415
14416 if (CurBECount == SE2.getCouldNotCompute() ||
14417 NewBECount == SE2.getCouldNotCompute()) {
14418 // NB! This situation is legal, but is very suspicious -- whatever pass
14419 // changed the loop to make a trip count go from could not compute to
14420 // computable or vice-versa *should have* invalidated SCEV. However, we
14421 // choose not to assert here (for now) since we don't want false
14422 // positives.
14423 continue;
14424 }
14425
14426 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14427 SE.getTypeSizeInBits(NewBECount->getType()))
14428 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14429 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14430 SE.getTypeSizeInBits(NewBECount->getType()))
14431 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14432
14433 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14434 if (Delta && !Delta->isZero()) {
14435 dbgs() << "Trip Count for " << *L << " Changed!\n";
14436 dbgs() << "Old: " << *CurBECount << "\n";
14437 dbgs() << "New: " << *NewBECount << "\n";
14438 dbgs() << "Delta: " << *Delta << "\n";
14439 std::abort();
14440 }
14441 }
14442
14443 // Collect all valid loops currently in LoopInfo.
14444 SmallPtrSet<Loop *, 32> ValidLoops;
14445 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14446 while (!Worklist.empty()) {
14447 Loop *L = Worklist.pop_back_val();
14448 if (ValidLoops.insert(L).second)
14449 Worklist.append(L->begin(), L->end());
14450 }
14451 for (const auto &KV : ValueExprMap) {
14452#ifndef NDEBUG
14453 // Check for SCEV expressions referencing invalid/deleted loops.
14454 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14455 assert(ValidLoops.contains(AR->getLoop()) &&
14456 "AddRec references invalid loop");
14457 }
14458#endif
14459
14460 // Check that the value is also part of the reverse map.
14461 auto It = ExprValueMap.find(KV.second);
14462 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14463 dbgs() << "Value " << *KV.first
14464 << " is in ValueExprMap but not in ExprValueMap\n";
14465 std::abort();
14466 }
14467
14468 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14469 if (!ReachableBlocks.contains(I->getParent()))
14470 continue;
14471 const SCEV *OldSCEV = SCM.visit(KV.second);
14472 const SCEV *NewSCEV = SE2.getSCEV(I);
14473 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14474 if (Delta && !Delta->isZero()) {
14475 dbgs() << "SCEV for value " << *I << " changed!\n"
14476 << "Old: " << *OldSCEV << "\n"
14477 << "New: " << *NewSCEV << "\n"
14478 << "Delta: " << *Delta << "\n";
14479 std::abort();
14480 }
14481 }
14482 }
14483
14484 for (const auto &KV : ExprValueMap) {
14485 for (Value *V : KV.second) {
14486 auto It = ValueExprMap.find_as(V);
14487 if (It == ValueExprMap.end()) {
14488 dbgs() << "Value " << *V
14489 << " is in ExprValueMap but not in ValueExprMap\n";
14490 std::abort();
14491 }
14492 if (It->second != KV.first) {
14493 dbgs() << "Value " << *V << " mapped to " << *It->second
14494 << " rather than " << *KV.first << "\n";
14495 std::abort();
14496 }
14497 }
14498 }
14499
14500 // Verify integrity of SCEV users.
14501 for (const auto &S : UniqueSCEVs) {
14502 for (const auto *Op : S.operands()) {
14503 // We do not store dependencies of constants.
14504 if (isa<SCEVConstant>(Op))
14505 continue;
14506 auto It = SCEVUsers.find(Op);
14507 if (It != SCEVUsers.end() && It->second.count(&S))
14508 continue;
14509 dbgs() << "Use of operand " << *Op << " by user " << S
14510 << " is not being tracked!\n";
14511 std::abort();
14512 }
14513 }
14514
14515 // Verify integrity of ValuesAtScopes users.
14516 for (const auto &ValueAndVec : ValuesAtScopes) {
14517 const SCEV *Value = ValueAndVec.first;
14518 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14519 const Loop *L = LoopAndValueAtScope.first;
14520 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14521 if (!isa<SCEVConstant>(ValueAtScope)) {
14522 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14523 if (It != ValuesAtScopesUsers.end() &&
14524 is_contained(It->second, std::make_pair(L, Value)))
14525 continue;
14526 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14527 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14528 std::abort();
14529 }
14530 }
14531 }
14532
14533 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14534 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14535 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14536 const Loop *L = LoopAndValue.first;
14537 const SCEV *Value = LoopAndValue.second;
14538 assert(!isa<SCEVConstant>(Value));
14539 auto It = ValuesAtScopes.find(Value);
14540 if (It != ValuesAtScopes.end() &&
14541 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14542 continue;
14543 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14544 << *ValueAtScope << " missing in ValuesAtScopes\n";
14545 std::abort();
14546 }
14547 }
14548
14549 // Verify integrity of BECountUsers.
14550 auto VerifyBECountUsers = [&](bool Predicated) {
14551 auto &BECounts =
14552 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14553 for (const auto &LoopAndBEInfo : BECounts) {
14554 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14555 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14556 if (!isa<SCEVConstant>(S)) {
14557 auto UserIt = BECountUsers.find(S);
14558 if (UserIt != BECountUsers.end() &&
14559 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14560 continue;
14561 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14562 << " missing from BECountUsers\n";
14563 std::abort();
14564 }
14565 }
14566 }
14567 }
14568 };
14569 VerifyBECountUsers(/* Predicated */ false);
14570 VerifyBECountUsers(/* Predicated */ true);
14571
14572 // Verify integrity of the loop disposition cache.
14573 for (auto &[S, Values] : LoopDispositions) {
14574 for (auto [Loop, CachedDisposition] : Values) {
14575 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14576 if (CachedDisposition != RecomputedDisposition) {
14577 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14578 << " is incorrect: cached " << CachedDisposition << ", actual "
14579 << RecomputedDisposition << "\n";
14580 std::abort();
14581 }
14582 }
14583 }
14584
14585 // Verify integrity of the block disposition cache.
14586 for (auto &[S, Values] : BlockDispositions) {
14587 for (auto [BB, CachedDisposition] : Values) {
14588 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14589 if (CachedDisposition != RecomputedDisposition) {
14590 dbgs() << "Cached disposition of " << *S << " for block %"
14591 << BB->getName() << " is incorrect: cached " << CachedDisposition
14592 << ", actual " << RecomputedDisposition << "\n";
14593 std::abort();
14594 }
14595 }
14596 }
14597
14598 // Verify FoldCache/FoldCacheUser caches.
14599 for (auto [FoldID, Expr] : FoldCache) {
14600 auto I = FoldCacheUser.find(Expr);
14601 if (I == FoldCacheUser.end()) {
14602 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14603 << "!\n";
14604 std::abort();
14605 }
14606 if (!is_contained(I->second, FoldID)) {
14607 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14608 std::abort();
14609 }
14610 }
14611 for (auto [Expr, IDs] : FoldCacheUser) {
14612 for (auto &FoldID : IDs) {
14613 auto I = FoldCache.find(FoldID);
14614 if (I == FoldCache.end()) {
14615 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14616 << "!\n";
14617 std::abort();
14618 }
14619 if (I->second != Expr) {
14620 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14621 << *I->second << " != " << *Expr << "!\n";
14622 std::abort();
14623 }
14624 }
14625 }
14626
14627 // Verify that ConstantMultipleCache computations are correct. We check that
14628 // cached multiples and recomputed multiples are multiples of each other to
14629 // verify correctness. It is possible that a recomputed multiple is different
14630 // from the cached multiple due to strengthened no wrap flags or changes in
14631 // KnownBits computations.
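// Illustrative example (hypothetical values): a cached multiple of 4 against a
// recomputed multiple of 8 passes this check since 8 urem 4 == 0, whereas a
// cached 4 against a recomputed 6 would abort below.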
14632 for (auto [S, Multiple] : ConstantMultipleCache) {
14633 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14634 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14635 Multiple.urem(RecomputedMultiple) != 0 &&
14636 RecomputedMultiple.urem(Multiple) != 0)) {
14637 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14638 << *S << " : Computed " << RecomputedMultiple
14639 << " but cache contains " << Multiple << "!\n";
14640 std::abort();
14641 }
14642 }
14643}
14644
14645bool ScalarEvolution::invalidate(
14646 Function &F, const PreservedAnalyses &PA,
14647 FunctionAnalysisManager::Invalidator &Inv) {
14648 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14649 // of its dependencies is invalidated.
14650 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14651 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14652 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14653 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14654 Inv.invalidate<LoopAnalysis>(F, PA);
14655}
14656
14657AnalysisKey ScalarEvolutionAnalysis::Key;
14658
14659ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14660 FunctionAnalysisManager &AM) {
14661 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14662 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14663 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14664 auto &LI = AM.getResult<LoopAnalysis>(F);
14665 return ScalarEvolution(F, TLI, AC, DT, LI);
14666}
14667
14668PreservedAnalyses
14669ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14670 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14671 return PreservedAnalyses::all();
14672}
14673
14674PreservedAnalyses
14675ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14676 // For compatibility with opt's -analyze feature under legacy pass manager
14677 // which was not ported to NPM. This keeps tests using
14678 // update_analyze_test_checks.py working.
14679 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14680 << F.getName() << "':\n";
14681 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14682 return PreservedAnalyses::all();
14683}
14684
14686 "Scalar Evolution Analysis", false, true)
14692 "Scalar Evolution Analysis", false, true)
14693
14695
14696ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14697 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14698}
14699
14700bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14701 SE.reset(new ScalarEvolution(
14702 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14703 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14704 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14705 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14706 return false;
14707}
14708
14709void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14710
14711void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14712 SE->print(OS);
14713}
14714
14715void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14716 if (!VerifySCEV)
14717 return;
14718
14719 SE->verify();
14720}
14721
14722void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14723 AU.setPreservesAll();
14724 AU.addRequiredTransitive<AssumptionCacheTracker>();
14725 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14726 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14727 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14728}
14729
14730const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14731 const SCEV *RHS) {
14732 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14733}
14734
14735const SCEVPredicate *
14736ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14737 const SCEV *LHS, const SCEV *RHS) {
14738 FoldingSetNodeID ID;
14739 assert(LHS->getType() == RHS->getType() &&
14740 "Type mismatch between LHS and RHS");
14741 // Unique this node based on the arguments
14742 ID.AddInteger(SCEVPredicate::P_Compare);
14743 ID.AddInteger(Pred);
14744 ID.AddPointer(LHS);
14745 ID.AddPointer(RHS);
14746 void *IP = nullptr;
14747 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14748 return S;
14749 SCEVComparePredicate *Eq = new (SCEVAllocator)
14750 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14751 UniquePreds.InsertNode(Eq, IP);
14752 return Eq;
14753}
14754
14755const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14756 const SCEVAddRecExpr *AR,
14757 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14758 FoldingSetNodeID ID;
14759 // Unique this node based on the arguments
14760 ID.AddInteger(SCEVPredicate::P_Wrap);
14761 ID.AddPointer(AR);
14762 ID.AddInteger(AddedFlags);
14763 void *IP = nullptr;
14764 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14765 return S;
14766 auto *OF = new (SCEVAllocator)
14767 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14768 UniquePreds.InsertNode(OF, IP);
14769 return OF;
14770}
14771
14772namespace {
14773
14774class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14775public:
14776
14777 /// Rewrites \p S in the context of a loop L and the SCEV predication
14778 /// infrastructure.
14779 ///
14780 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14781 /// equivalences present in \p Pred.
14782 ///
14783 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14784 /// \p NewPreds such that the result will be an AddRecExpr.
14785 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14786 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14787 const SCEVPredicate *Pred) {
14788 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14789 return Rewriter.visit(S);
14790 }
14791
14792 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14793 if (Pred) {
14794 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14795 for (const auto *Pred : U->getPredicates())
14796 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14797 if (IPred->getLHS() == Expr &&
14798 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14799 return IPred->getRHS();
14800 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14801 if (IPred->getLHS() == Expr &&
14802 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14803 return IPred->getRHS();
14804 }
14805 }
14806 return convertToAddRecWithPreds(Expr);
14807 }
14808
14809 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14810 const SCEV *Operand = visit(Expr->getOperand());
14811 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14812 if (AR && AR->getLoop() == L && AR->isAffine()) {
14813 // This couldn't be folded because the operand didn't have the nuw
14814 // flag. Add the nusw flag as an assumption that we could make.
14815 const SCEV *Step = AR->getStepRecurrence(SE);
14816 Type *Ty = Expr->getType();
14817 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14818 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14819 SE.getSignExtendExpr(Step, Ty), L,
14820 AR->getNoWrapFlags());
14821 }
14822 return SE.getZeroExtendExpr(Operand, Expr->getType());
14823 }
14824
14825 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14826 const SCEV *Operand = visit(Expr->getOperand());
14827 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14828 if (AR && AR->getLoop() == L && AR->isAffine()) {
14829 // This couldn't be folded because the operand didn't have the nsw
14830 // flag. Add the nssw flag as an assumption that we could make.
14831 const SCEV *Step = AR->getStepRecurrence(SE);
14832 Type *Ty = Expr->getType();
14833 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14834 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14835 SE.getSignExtendExpr(Step, Ty), L,
14836 AR->getNoWrapFlags());
14837 }
14838 return SE.getSignExtendExpr(Operand, Expr->getType());
14839 }
14840
14841private:
14842 explicit SCEVPredicateRewriter(
14843 const Loop *L, ScalarEvolution &SE,
14844 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14845 const SCEVPredicate *Pred)
14846 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14847
14848 bool addOverflowAssumption(const SCEVPredicate *P) {
14849 if (!NewPreds) {
14850 // Check if we've already made this assumption.
14851 return Pred && Pred->implies(P, SE);
14852 }
14853 NewPreds->push_back(P);
14854 return true;
14855 }
14856
14857 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14858 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14859 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14860 return addOverflowAssumption(A);
14861 }
14862
14863 // If \p Expr represents a PHINode, we try to see if it can be represented
14864 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14865 // to add this predicate as a runtime overflow check, we return the AddRec.
14866 // If \p Expr does not meet these conditions (is not a PHI node, or we
14867 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14868 // return \p Expr.
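  // Illustrative sketch (hypothetical names): for a PHI %iv whose SCEV is an
  // unknown, createAddRecFromPHIWithCasts may produce {%start,+,%step}<%loop>
  // together with a wrap predicate; if every such predicate can be recorded
  // via addOverflowAssumption, the AddRec is returned in place of the unknown.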
14869 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14870 if (!isa<PHINode>(Expr->getValue()))
14871 return Expr;
14872 std::optional<
14873 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14874 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14875 if (!PredicatedRewrite)
14876 return Expr;
14877 for (const auto *P : PredicatedRewrite->second){
14878 // Wrap predicates from outer loops are not supported.
14879 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14880 if (L != WP->getExpr()->getLoop())
14881 return Expr;
14882 }
14883 if (!addOverflowAssumption(P))
14884 return Expr;
14885 }
14886 return PredicatedRewrite->first;
14887 }
14888
14889 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
14890 const SCEVPredicate *Pred;
14891 const Loop *L;
14892};
14893
14894} // end anonymous namespace
14895
14896const SCEV *
14897ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14898 const SCEVPredicate &Preds) {
14899 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14900}
14901
14902const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14903 const SCEV *S, const Loop *L,
14904 SmallVectorImpl<const SCEVPredicate *> &Preds) {
14905 SmallVector<const SCEVPredicate *, 3> TransformPreds;
14906 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14907 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14908
14909 if (!AddRec)
14910 return nullptr;
14911
14912 // Since the transformation was successful, we can now transfer the SCEV
14913 // predicates.
14914 Preds.append(TransformPreds.begin(), TransformPreds.end());
14915
14916 return AddRec;
14917}
14918
14919/// SCEV predicates
14920SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14921 SCEVPredicateKind Kind)
14922 : FastID(ID), Kind(Kind) {}
14923
14924SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14925 const ICmpInst::Predicate Pred,
14926 const SCEV *LHS, const SCEV *RHS)
14927 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14928 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14929 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14930}
14931
14932bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14933 ScalarEvolution &SE) const {
14934 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14935
14936 if (!Op)
14937 return false;
14938
14939 if (Pred != ICmpInst::ICMP_EQ)
14940 return false;
14941
14942 return Op->LHS == LHS && Op->RHS == RHS;
14943}
14944
14945bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14946
14948 if (Pred == ICmpInst::ICMP_EQ)
14949 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14950 else
14951 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14952 << *RHS << "\n";
14953
14954}
14955
14956SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14957 const SCEVAddRecExpr *AR,
14958 IncrementWrapFlags Flags)
14959 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14960
14961const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14962
14963bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14964 ScalarEvolution &SE) const {
14965 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14966 if (!Op || setFlags(Flags, Op->Flags) != Flags)
14967 return false;
14968
14969 if (Op->AR == AR)
14970 return true;
14971
14972 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14973 Flags != SCEVWrapPredicate::IncrementNUSW)
14974 return false;
14975
14976 const SCEV *Start = AR->getStart();
14977 const SCEV *OpStart = Op->AR->getStart();
14978 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
14979 return false;
14980
14981 const SCEV *Step = AR->getStepRecurrence(SE);
14982 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14983 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14984 return false;
14985
14986 // If both steps are positive, this predicate implies N if N's start and
14987 // step are ULE/SLE (for NUSW/NSSW) compared to this predicate's.
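 // Illustrative example (hypothetical recurrences): a <nusw> predicate on
 // {16,+,4}<%loop> implies one on {8,+,2}<%loop>, since with a start and step
 // that are no larger the second recurrence cannot wrap if the first does not.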
14988 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14989 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14990 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14991
14992 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14993 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
14994 : SE.getNoopOrSignExtend(OpStart, WiderTy);
14995 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
14996 : SE.getNoopOrSignExtend(Start, WiderTy);
14997 CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
14998 return SE.isKnownPredicate(Pred, OpStep, Step) &&
14999 SE.isKnownPredicate(Pred, OpStart, Start);
15000}
15001
15002bool SCEVWrapPredicate::isAlwaysTrue() const {
15003 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15004 IncrementWrapFlags IFlags = Flags;
15005
15006 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15007 IFlags = clearFlags(IFlags, IncrementNSSW);
15008
15009 return IFlags == IncrementAnyWrap;
15010}
15011
15012void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15013 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15014 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
15015 OS << "<nusw>";
15016 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
15017 OS << "<nssw>";
15018 OS << "\n";
15019}
15020
15021SCEVWrapPredicate::IncrementWrapFlags
15022SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15023 ScalarEvolution &SE) {
15024 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15025 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15026
15027 // We can safely transfer the NSW flag as NSSW.
15028 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15029 ImpliedFlags = IncrementNSSW;
15030
15031 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15032 // If the increment is positive, the SCEV NUW flag will also imply the
15033 // WrapPredicate NUSW flag.
15034 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15035 if (Step->getValue()->getValue().isNonNegative())
15036 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15037 }
15038
15039 return ImpliedFlags;
15040}
15041
15042/// Union predicates don't get cached so create a dummy set ID for it.
15043SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15044 ScalarEvolution &SE)
15045 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15046 for (const auto *P : Preds)
15047 add(P, SE);
15048}
15049
15050bool SCEVUnionPredicate::isAlwaysTrue() const {
15051 return all_of(Preds,
15052 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15053}
15054
15055bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15056 ScalarEvolution &SE) const {
15057 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15058 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15059 return this->implies(I, SE);
15060 });
15061
15062 return any_of(Preds,
15063 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15064}
15065
15066void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15067 for (const auto *Pred : Preds)
15068 Pred->print(OS, Depth);
15069}
15070
15071void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15072 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15073 for (const auto *Pred : Set->Preds)
15074 add(Pred, SE);
15075 return;
15076 }
15077
15078 // Only add predicate if it is not already implied by this union predicate.
15079 if (implies(N, SE))
15080 return;
15081
15082 // Build a new vector containing the current predicates, except the ones that
15083 // are implied by the new predicate N.
15084 SmallVector<const SCEVPredicate *> PrunedPreds;
15085 for (auto *P : Preds) {
15086 if (N->implies(P, SE))
15087 continue;
15088 PrunedPreds.push_back(P);
15089 }
15090 Preds = std::move(PrunedPreds);
15091 Preds.push_back(N);
15092}
15093
15094PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15095 Loop &L)
15096 : SE(SE), L(L) {
15097 SmallVector<const SCEVPredicate *, 4> Empty;
15098 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15099}
15100
15101void ScalarEvolution::registerUser(const SCEV *User,
15102 ArrayRef<const SCEV *> Ops) {
15103 for (const auto *Op : Ops)
15104 // We do not expect that forgetting cached data for SCEVConstants will ever
15105 // open any prospects for sharpening or introduce any correctness issues,
15106 // so we don't bother storing their dependencies.
15107 if (!isa<SCEVConstant>(Op))
15108 SCEVUsers[Op].insert(User);
15109}
15110
15111const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15112 const SCEV *Expr = SE.getSCEV(V);
15113 RewriteEntry &Entry = RewriteMap[Expr];
15114
15115 // If we already have an entry and the version matches, return it.
15116 if (Entry.second && Generation == Entry.first)
15117 return Entry.second;
15118
15119 // We found an entry but it's stale. Rewrite the stale entry
15120 // according to the current predicate.
15121 if (Entry.second)
15122 Expr = Entry.second;
15123
15124 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15125 Entry = {Generation, NewSCEV};
15126
15127 return NewSCEV;
15128}
15129
15130const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15131 if (!BackedgeCount) {
15132 SmallVector<const SCEVPredicate *, 4> Preds;
15133 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15134 for (const auto *P : Preds)
15135 addPredicate(*P);
15136 }
15137 return BackedgeCount;
15138}
15139
15140const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15141 if (!SymbolicMaxBackedgeCount) {
15142 SmallVector<const SCEVPredicate *, 4> Preds;
15143 SymbolicMaxBackedgeCount =
15144 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15145 for (const auto *P : Preds)
15146 addPredicate(*P);
15147 }
15148 return SymbolicMaxBackedgeCount;
15149}
15150
15151unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15152 if (!SmallConstantMaxTripCount) {
15153 SmallVector<const SCEVPredicate *, 4> Preds;
15154 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15155 for (const auto *P : Preds)
15156 addPredicate(*P);
15157 }
15158 return *SmallConstantMaxTripCount;
15159}
15160
15161void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15162 if (Preds->implies(&Pred, SE))
15163 return;
15164
15165 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15166 NewPreds.push_back(&Pred);
15167 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15168 updateGeneration();
15169}
15170
15171const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15172 return *Preds;
15173}
15174
15175void PredicatedScalarEvolution::updateGeneration() {
15176 // If the generation number wrapped recompute everything.
15177 if (++Generation == 0) {
15178 for (auto &II : RewriteMap) {
15179 const SCEV *Rewritten = II.second.second;
15180 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15181 }
15182 }
15183}
15184
15185void PredicatedScalarEvolution::setNoOverflow(
15186 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15187 const SCEV *Expr = getSCEV(V);
15188 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15189
15190 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15191
15192 // Clear the statically implied flags.
15193 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15194 addPredicate(*SE.getWrapPredicate(AR, Flags));
15195
15196 auto II = FlagsMap.insert({V, Flags});
15197 if (!II.second)
15198 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15199}
15200
15201bool PredicatedScalarEvolution::hasNoOverflow(
15202 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15203 const SCEV *Expr = getSCEV(V);
15204 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15205
15206 Flags = SCEVWrapPredicate::clearFlags(
15207 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15208
15209 auto II = FlagsMap.find(V);
15210
15211 if (II != FlagsMap.end())
15212 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15213
15214 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15215}
15216
15217const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15218 const SCEV *Expr = this->getSCEV(V);
15219 SmallVector<const SCEVPredicate *, 4> NewPreds;
15220 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15221
15222 if (!New)
15223 return nullptr;
15224
15225 for (const auto *P : NewPreds)
15226 addPredicate(*P);
15227
15228 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15229 return New;
15230}
15231
15232PredicatedScalarEvolution::PredicatedScalarEvolution(
15233 const PredicatedScalarEvolution &Init)
15234 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15235 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15236 SE)),
15237 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15238 for (auto I : Init.FlagsMap)
15239 FlagsMap.insert(I);
15240}
15241
15242void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15243 // For each block.
15244 for (auto *BB : L.getBlocks())
15245 for (auto &I : *BB) {
15246 if (!SE.isSCEVable(I.getType()))
15247 continue;
15248
15249 auto *Expr = SE.getSCEV(&I);
15250 auto II = RewriteMap.find(Expr);
15251
15252 if (II == RewriteMap.end())
15253 continue;
15254
15255 // Don't print things that are not interesting.
15256 if (II->second.second == Expr)
15257 continue;
15258
15259 OS.indent(Depth) << "[PSE]" << I << ":\n";
15260 OS.indent(Depth + 2) << *Expr << "\n";
15261 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15262 }
15263}
15264
15265// Match the mathematical pattern A - (A / B) * B, where A and B can be
15266// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
15267// for URem with constant power-of-2 second operands.
15268// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
15269// 4, A / B becomes X / 8).
15270bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15271 const SCEV *&RHS) {
15272 if (Expr->getType()->isPointerTy())
15273 return false;
15274
15275 // Try to match 'zext (trunc A to iB) to iY', which is used
15276 // for URem with constant power-of-2 second operands. Make sure the size of
15277 // the operand A matches the size of the whole expressions.
15278 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
15279 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
15280 LHS = Trunc->getOperand();
15281 // Bail out if the type of the LHS is larger than the type of the
15282 // expression for now.
15283 if (getTypeSizeInBits(LHS->getType()) >
15284 getTypeSizeInBits(Expr->getType()))
15285 return false;
15286 if (LHS->getType() != Expr->getType())
15287 LHS = getZeroExtendExpr(LHS, Expr->getType());
15288 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
15289 << getTypeSizeInBits(Trunc->getType()));
15290 return true;
15291 }
15292 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
15293 if (Add == nullptr || Add->getNumOperands() != 2)
15294 return false;
15295
15296 const SCEV *A = Add->getOperand(1);
15297 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
15298
15299 if (Mul == nullptr)
15300 return false;
15301
15302 const auto MatchURemWithDivisor = [&](const SCEV *B) {
15303 // (SomeExpr + (-(SomeExpr / B) * B)).
15304 if (Expr == getURemExpr(A, B)) {
15305 LHS = A;
15306 RHS = B;
15307 return true;
15308 }
15309 return false;
15310 };
15311
15312 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
15313 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
15314 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15315 MatchURemWithDivisor(Mul->getOperand(2));
15316
15317 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
15318 if (Mul->getNumOperands() == 2)
15319 return MatchURemWithDivisor(Mul->getOperand(1)) ||
15320 MatchURemWithDivisor(Mul->getOperand(0)) ||
15321 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
15322 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
15323 return false;
15324}
15325
15326ScalarEvolution::LoopGuards
15327ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15328 BasicBlock *Header = L->getHeader();
15329 BasicBlock *Pred = L->getLoopPredecessor();
15330 LoopGuards Guards(SE);
15331 SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15332 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15333 return Guards;
15334}
15335
15336void ScalarEvolution::LoopGuards::collectFromPHI(
15337 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15338 const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15339 SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15340 unsigned Depth) {
15341 if (!SE.isSCEVable(Phi.getType()))
15342 return;
15343
15344 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15345 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15346 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15347 if (!VisitedBlocks.insert(InBlock).second)
15348 return {nullptr, scCouldNotCompute};
15349 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15350 if (Inserted)
15351 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15352 Depth + 1);
15353 auto &RewriteMap = G->second.RewriteMap;
15354 if (RewriteMap.empty())
15355 return {nullptr, scCouldNotCompute};
15356 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15357 if (S == RewriteMap.end())
15358 return {nullptr, scCouldNotCompute};
15359 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15360 if (!SM)
15361 return {nullptr, scCouldNotCompute};
15362 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15363 return {C0, SM->getSCEVType()};
15364 return {nullptr, scCouldNotCompute};
15365 };
15366 auto MergeMinMaxConst = [](MinMaxPattern P1,
15367 MinMaxPattern P2) -> MinMaxPattern {
15368 auto [C1, T1] = P1;
15369 auto [C2, T2] = P2;
15370 if (!C1 || !C2 || T1 != T2)
15371 return {nullptr, scCouldNotCompute};
15372 switch (T1) {
15373 case scUMaxExpr:
15374 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15375 case scSMaxExpr:
15376 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15377 case scUMinExpr:
15378 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15379 case scSMinExpr:
15380 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15381 default:
15382 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15383 }
15384 };
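 // Illustrative example (hypothetical guards): if one incoming edge rewrites
 // the PHI to umax(2, %phi) and another to umax(5, %phi), the merge keeps
 // {2, scUMaxExpr}, the weaker bound that holds on every path.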
15385 auto P = GetMinMaxConst(0);
15386 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15387 if (!P.first)
15388 break;
15389 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15390 }
15391 if (P.first) {
15392 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15393 SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15394 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15395 Guards.RewriteMap.insert({LHS, RHS});
15396 }
15397}
15398
15399void ScalarEvolution::LoopGuards::collectFromBlock(
15401 const BasicBlock *Block, const BasicBlock *Pred,
15402 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15403 SmallVector<const SCEV *> ExprsToRewrite;
15404 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15405 const SCEV *RHS,
15406 DenseMap<const SCEV *, const SCEV *>
15407 &RewriteMap) {
15408 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15409 // replacement SCEV which isn't directly implied by the structure of that
15410 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15411 // legal. See the scoping rules for flags in the header to understand why.
15412
15413 // If LHS is a constant, apply information to the other expression.
15414 if (isa<SCEVConstant>(LHS)) {
15415 std::swap(LHS, RHS);
15416 Predicate = CmpInst::getSwappedPredicate(Predicate);
15417 }
15418
15419 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15420 // create this form when combining two checks of the form (X u< C2 + C1) and
15421 // (X >=u C1).
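 // Illustrative example (hypothetical values): for the guard (-4 + %x) u< 8
 // the exact region for %x is [4, 12), so %x is rewritten below to
 // umax(4, umin(%x, 11)).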
15422 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15423 &ExprsToRewrite]() {
15424 const SCEVConstant *C1;
15425 const SCEVUnknown *LHSUnknown;
15426 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15427 if (!match(LHS,
15428 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15429 !C2)
15430 return false;
15431
15432 auto ExactRegion =
15433 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15434 .sub(C1->getAPInt());
15435
15436 // Bail out, unless we have a non-wrapping, monotonic range.
15437 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15438 return false;
15439 auto I = RewriteMap.find(LHSUnknown);
15440 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15441 RewriteMap[LHSUnknown] = SE.getUMaxExpr(
15442 SE.getConstant(ExactRegion.getUnsignedMin()),
15443 SE.getUMinExpr(RewrittenLHS,
15444 SE.getConstant(ExactRegion.getUnsignedMax())));
15445 ExprsToRewrite.push_back(LHSUnknown);
15446 return true;
15447 };
15448 if (MatchRangeCheckIdiom())
15449 return;
15450
15451 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15452 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15453 // the non-constant operand and in \p LHS the constant operand.
15454 auto IsMinMaxSCEVWithNonNegativeConstant =
15455 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15456 const SCEV *&RHS) {
15457 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15458 if (MinMax->getNumOperands() != 2)
15459 return false;
15460 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15461 if (C->getAPInt().isNegative())
15462 return false;
15463 SCTy = MinMax->getSCEVType();
15464 LHS = MinMax->getOperand(0);
15465 RHS = MinMax->getOperand(1);
15466 return true;
15467 }
15468 }
15469 return false;
15470 };
15471
15472 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15473 // constant, and returns their APInt in ExprVal and in DivisorVal.
15474 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15475 APInt &ExprVal, APInt &DivisorVal) {
15476 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15477 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15478 if (!ConstExpr || !ConstDivisor)
15479 return false;
15480 ExprVal = ConstExpr->getAPInt();
15481 DivisorVal = ConstDivisor->getAPInt();
15482 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15483 };
15484
15485 // Return a new SCEV that modifies \p Expr to the closest number that
15486 // divides by \p Divisor and is greater than or equal to Expr.
15487 // For now, only handle constant Expr and Divisor.
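 // Illustrative example: for Expr = 5 and Divisor = 4 this returns 8; an Expr
 // that is already a multiple of Divisor is returned unchanged.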
15488 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15489 const SCEV *Divisor) {
15490 APInt ExprVal;
15491 APInt DivisorVal;
15492 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15493 return Expr;
15494 APInt Rem = ExprVal.urem(DivisorVal);
15495 if (!Rem.isZero())
15496 // return the SCEV: Expr + Divisor - Expr % Divisor
15497 return SE.getConstant(ExprVal + DivisorVal - Rem);
15498 return Expr;
15499 };
15500
15501 // Return a new SCEV that modifies \p Expr to the closest number that
15502 // divides by \p Divisor and is less than or equal to Expr.
15503 // For now, only handle constant Expr and Divisor.
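 // Illustrative example: for Expr = 5 and Divisor = 4 this returns 4; multiples
 // of Divisor are returned unchanged.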
15504 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15505 const SCEV *Divisor) {
15506 APInt ExprVal;
15507 APInt DivisorVal;
15508 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15509 return Expr;
15510 APInt Rem = ExprVal.urem(DivisorVal);
15511 // return the SCEV: Expr - Expr % Divisor
15512 return SE.getConstant(ExprVal - Rem);
15513 };
15514
15515 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15516 // recursively. This is done by aligning up/down the constant value to the
15517 // Divisor.
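 // Illustrative example (hypothetical expression): with Divisor = 4,
 // umin(21, umax(5, %x)) becomes umin(20, umax(8, %x)); the umin constant is
 // aligned down, the umax constant is aligned up, and %x is left untouched.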
15518 std::function<const SCEV *(const SCEV *, const SCEV *)>
15519 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15520 const SCEV *Divisor) {
15521 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15522 SCEVTypes SCTy;
15523 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15524 MinMaxRHS))
15525 return MinMaxExpr;
15526 auto IsMin =
15527 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15528 assert(SE.isKnownNonNegative(MinMaxLHS) &&
15529 "Expected non-negative operand!");
15530 auto *DivisibleExpr =
15531 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15532 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15533 SmallVector<const SCEV *> Ops = {
15534 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15535 return SE.getMinMaxExpr(SCTy, Ops);
15536 };
15537
15538 // If we have LHS == 0, check if LHS is computing a property of some unknown
15539 // SCEV %v which we can rewrite %v to express explicitly.
15540 if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15541 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15542 // explicitly express that.
15543 const SCEV *URemLHS = nullptr;
15544 const SCEV *URemRHS = nullptr;
15545 if (SE.matchURem(LHS, URemLHS, URemRHS)) {
15546 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15547 auto I = RewriteMap.find(LHSUnknown);
15548 const SCEV *RewrittenLHS =
15549 I != RewriteMap.end() ? I->second : LHSUnknown;
15550 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15551 const auto *Multiple =
15552 SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15553 RewriteMap[LHSUnknown] = Multiple;
15554 ExprsToRewrite.push_back(LHSUnknown);
15555 return;
15556 }
15557 }
15558 }
15559
15560 // Do not apply information for constants or if RHS contains an AddRec.
15561 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15562 return;
15563
15564 // If RHS is SCEVUnknown, make sure the information is applied to it.
15565 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15566 std::swap(LHS, RHS);
15567 Predicate = CmpInst::getSwappedPredicate(Predicate);
15568 }
15569
15570 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15571 // and \p FromRewritten are the same (i.e. there has been no rewrite
15572 // registered for \p From), then puts this value in the list of rewritten
15573 // expressions.
15574 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15575 const SCEV *To) {
15576 if (From == FromRewritten)
15577 ExprsToRewrite.push_back(From);
15578 RewriteMap[From] = To;
15579 };
15580
15581 // Checks whether \p S has already been rewritten. In that case returns the
15582 // existing rewrite because we want to chain further rewrites onto the
15583 // already rewritten value. Otherwise returns \p S.
15584 auto GetMaybeRewritten = [&](const SCEV *S) {
15585 auto I = RewriteMap.find(S);
15586 return I != RewriteMap.end() ? I->second : S;
15587 };
15588
15589 // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
15590 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15591 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15592 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15593 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15594 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15595 // DividesBy.
15596 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15597 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15598 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15599 if (Mul->getNumOperands() != 2)
15600 return false;
15601 auto *MulLHS = Mul->getOperand(0);
15602 auto *MulRHS = Mul->getOperand(1);
15603 if (isa<SCEVConstant>(MulLHS))
15604 std::swap(MulLHS, MulRHS);
15605 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15606 if (Div->getOperand(1) == MulRHS) {
15607 DividesBy = MulRHS;
15608 return true;
15609 }
15610 }
15611 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15612 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15613 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15614 return false;
15615 };
15616
15617  // Return true if \p Expr is known to be divisible by \p DividesBy.
15618 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15619 [&](const SCEV *Expr, const SCEV *DividesBy) {
15620 if (SE.getURemExpr(Expr, DividesBy)->isZero())
15621 return true;
15622 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15623 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15624 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15625 return false;
15626 };
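// Worked example (illustrative): for RewrittenLHS = umin((%a /u 8) * 8, 32),
// HasDivisibiltyInfo finds DividesBy = 8 in the first operand, and
// IsKnownToDivideBy then confirms both operands (and hence the whole umin)
// are multiples of 8.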
15627
15628 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15629 const SCEV *DividesBy = nullptr;
15630 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15631    // Check that the whole expression is divisible by DividesBy
15632 DividesBy =
15633 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15634
15635 // Collect rewrites for LHS and its transitive operands based on the
15636 // condition.
15637 // For min/max expressions, also apply the guard to its operands:
15638 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15639 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15640 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15641 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15642
15643 // We cannot express strict predicates in SCEV, so instead we replace them
15644 // with non-strict ones against plus or minus one of RHS depending on the
15645 // predicate.
15646 const SCEV *One = SE.getOne(RHS->getType());
15647 switch (Predicate) {
15648 case CmpInst::ICMP_ULT:
15649 if (RHS->getType()->isPointerTy())
15650 return;
15651 RHS = SE.getUMaxExpr(RHS, One);
15652 [[fallthrough]];
15653 case CmpInst::ICMP_SLT: {
15654 RHS = SE.getMinusSCEV(RHS, One);
15655 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15656 break;
15657 }
15658 case CmpInst::ICMP_UGT:
15659 case CmpInst::ICMP_SGT:
15660 RHS = SE.getAddExpr(RHS, One);
15661 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15662 break;
15663 case CmpInst::ICMP_ULE:
15664 case CmpInst::ICMP_SLE:
15665 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15666 break;
15667 case CmpInst::ICMP_UGE:
15668 case CmpInst::ICMP_SGE:
15669 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15670 break;
15671 default:
15672 break;
15673 }
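// Worked example (illustrative): for a guard '%i u< 20' with DividesBy == 8,
// RHS becomes umax(20, 1) - 1 = 19 and is then aligned down to 16, so the
// rewrite below effectively records %i u<= 16.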
15674
15675 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15676  SmallPtrSet<const SCEV *, 16> Visited;
15677
15678 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15679 append_range(Worklist, S->operands());
15680 };
15681
15682 while (!Worklist.empty()) {
15683 const SCEV *From = Worklist.pop_back_val();
15684 if (isa<SCEVConstant>(From))
15685 continue;
15686 if (!Visited.insert(From).second)
15687 continue;
15688 const SCEV *FromRewritten = GetMaybeRewritten(From);
15689 const SCEV *To = nullptr;
15690
15691 switch (Predicate) {
15692 case CmpInst::ICMP_ULT:
15693 case CmpInst::ICMP_ULE:
15694 To = SE.getUMinExpr(FromRewritten, RHS);
15695 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15696 EnqueueOperands(UMax);
15697 break;
15698 case CmpInst::ICMP_SLT:
15699 case CmpInst::ICMP_SLE:
15700 To = SE.getSMinExpr(FromRewritten, RHS);
15701 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15702 EnqueueOperands(SMax);
15703 break;
15704 case CmpInst::ICMP_UGT:
15705 case CmpInst::ICMP_UGE:
15706 To = SE.getUMaxExpr(FromRewritten, RHS);
15707 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15708 EnqueueOperands(UMin);
15709 break;
15710 case CmpInst::ICMP_SGT:
15711 case CmpInst::ICMP_SGE:
15712 To = SE.getSMaxExpr(FromRewritten, RHS);
15713 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15714 EnqueueOperands(SMin);
15715 break;
15716 case CmpInst::ICMP_EQ:
15717 if (isa<SCEVConstant>(RHS))
15718 To = RHS;
15719 break;
15720 case CmpInst::ICMP_NE:
15721 if (match(RHS, m_scev_Zero())) {
15722 const SCEV *OneAlignedUp =
15723 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15724 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15725 }
15726 break;
15727 default:
15728 break;
15729 }
15730
15731 if (To)
15732 AddRewrite(From, FromRewritten, To);
15733 }
15734 };
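// Worked example (illustrative): a guard '%a u< 100' makes CollectCondition
// record the rewrite %a -> umin(%a, 99); if %a already had a rewrite, the new
// bound is chained onto it via GetMaybeRewritten.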
15735
15736  SmallVector<std::pair<Value *, bool>> Terms;
15737  // First, collect information from assumptions dominating the loop.
15738 for (auto &AssumeVH : SE.AC.assumptions()) {
15739 if (!AssumeVH)
15740 continue;
15741 auto *AssumeI = cast<CallInst>(AssumeVH);
15742 if (!SE.DT.dominates(AssumeI, Block))
15743 continue;
15744 Terms.emplace_back(AssumeI->getOperand(0), true);
15745 }
15746
15747 // Second, collect information from llvm.experimental.guards dominating the loop.
15748 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15749 SE.F.getParent(), Intrinsic::experimental_guard);
15750 if (GuardDecl)
15751 for (const auto *GU : GuardDecl->users())
15752 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15753 if (Guard->getFunction() == Block->getParent() &&
15754 SE.DT.dominates(Guard, Block))
15755 Terms.emplace_back(Guard->getArgOperand(0), true);
15756
15757 // Third, collect conditions from dominating branches. Starting at the loop
15758  // predecessor, climb up the predecessor chain as long as we can find
15759  // predecessors that have a unique successor leading to the original
15760  // header.
15761 // TODO: share this logic with isLoopEntryGuardedByCond.
15762 unsigned NumCollectedConditions = 0;
15763 VisitedBlocks.insert(Block);
15764 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15765 for (; Pair.first;
15766 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15767 VisitedBlocks.insert(Pair.second);
15768 const BranchInst *LoopEntryPredicate =
15769 dyn_cast<BranchInst>(Pair.first->getTerminator());
15770 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15771 continue;
15772
15773 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15774 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15775 NumCollectedConditions++;
15776
15777    // If we are recursively collecting guards, stop after 2
15778    // conditions to limit the compile-time impact for now.
15779 if (Depth > 0 && NumCollectedConditions == 2)
15780 break;
15781 }
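// Illustrative: for a loop entered via 'br i1 (%n > 0), %preheader, %exit',
// the condition '%n > 0' is collected here, with the flag recording whether
// the branch's first successor is the block on the path toward the loop.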
15782 // Finally, if we stopped climbing the predecessor chain because
15783 // there wasn't a unique one to continue, try to collect conditions
15784 // for PHINodes by recursively following all of their incoming
15785 // blocks and try to merge the found conditions to build a new one
15786 // for the Phi.
15787 if (Pair.second->hasNPredecessorsOrMore(2) &&
15788      Depth < MaxLoopGuardCollectionDepth) {
15789    SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15790    for (auto &Phi : Pair.second->phis())
15791 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15792 }
15793
15794 // Now apply the information from the collected conditions to
15795 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15796  // earliest condition is processed first. This ensures the SCEVs with the
15797 // shortest dependency chains are constructed first.
15798 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15799 SmallVector<Value *, 8> Worklist;
15800      SmallPtrSet<Value *, 8> Visited;
15801      Worklist.push_back(Term);
15802 while (!Worklist.empty()) {
15803 Value *Cond = Worklist.pop_back_val();
15804 if (!Visited.insert(Cond).second)
15805 continue;
15806
15807 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15808 auto Predicate =
15809 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15810 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15811 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15812 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15813 continue;
15814 }
15815
15816 Value *L, *R;
15817 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15818 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15819 Worklist.push_back(L);
15820 Worklist.push_back(R);
15821 }
15822 }
15823 }
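// Worked example (illustrative): a term '%a u< %n && %b u< %n' that must be
// true pushes both operands of the logical-and onto the worklist, and each
// compare is fed to CollectCondition; for a term known to be false, the
// inverse predicates are used and a logical-or is split instead.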
15824
15825 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15826 // the replacement expressions are contained in the ranges of the replaced
15827 // expressions.
15828 Guards.PreserveNUW = true;
15829 Guards.PreserveNSW = true;
15830 for (const SCEV *Expr : ExprsToRewrite) {
15831 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15832 Guards.PreserveNUW &=
15833 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15834 Guards.PreserveNSW &=
15835 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15836 }
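// Illustrative: if %x, whose unsigned range is the full range, is rewritten
// to umin(%x, 9), the replacement's unsigned range is contained in the
// original's, so NUW flags may be preserved; NSW is handled analogously with
// signed ranges.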
15837
15838  // Now that all rewrite information is collected, rewrite the collected
15839 // expressions with the information in the map. This applies information to
15840 // sub-expressions.
15841 if (ExprsToRewrite.size() > 1) {
15842 for (const SCEV *Expr : ExprsToRewrite) {
15843 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15844 Guards.RewriteMap.erase(Expr);
15845 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15846 }
15847 }
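// Illustrative: given rewrites %a -> umin(%a, 9) and %b -> umin(%b, %a), the
// entry for %b is re-run through the rewriter (with %b itself temporarily
// removed from the map), so the %a inside it picks up %a's own bound.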
15848}
15849
15850 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15851  /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15852 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15853 /// replacement is loop invariant in the loop of the AddRec.
15854 class SCEVLoopGuardRewriter
15855 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15856    const DenseMap<const SCEV *, const SCEV *> &Map;
15857
15858    SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15859
15860 public:
15861 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15862 const ScalarEvolution::LoopGuards &Guards)
15863 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15864 if (Guards.PreserveNUW)
15865 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15866 if (Guards.PreserveNSW)
15867 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15868 }
15869
15870 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15871
15872 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15873 auto I = Map.find(Expr);
15874 if (I == Map.end())
15875 return Expr;
15876 return I->second;
15877 }
15878
15879 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15880 auto I = Map.find(Expr);
15881 if (I == Map.end()) {
15882        // If we didn't find the exact ZExt expr in the map, check if there's
15883 // an entry for a smaller ZExt we can use instead.
15884 Type *Ty = Expr->getType();
15885 const SCEV *Op = Expr->getOperand(0);
15886 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15887 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15888 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15889 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15890 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15891 auto I = Map.find(NarrowExt);
15892 if (I != Map.end())
15893 return SE.getZeroExtendExpr(I->second, Ty);
15894 Bitwidth = Bitwidth / 2;
15895 }
15896
15897      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15898          Expr);
15899 }
15900 return I->second;
15901 }
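// Worked example (illustrative): if the map has an entry for
// (zext i8 %v to i16) but we visit (zext i8 %v to i64), the loop tries widths
// 32 and then 16, finds the i16 entry, and returns its value zero-extended to
// i64.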
15902
15903 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15904 auto I = Map.find(Expr);
15905 if (I == Map.end())
15906      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15907          Expr);
15908 return I->second;
15909 }
15910
15911 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15912 auto I = Map.find(Expr);
15913 if (I == Map.end())
15914      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15915      return I->second;
15916 }
15917
15918 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15919 auto I = Map.find(Expr);
15920 if (I == Map.end())
15921      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15922      return I->second;
15923 }
15924
15925 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15926      SmallVector<const SCEV *, 2> Operands;
15927      bool Changed = false;
15928 for (const auto *Op : Expr->operands()) {
15929 Operands.push_back(
15930          SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15931        Changed |= Op != Operands.back();
15932 }
15933 // We are only replacing operands with equivalent values, so transfer the
15934 // flags from the original expression.
15935 return !Changed ? Expr
15936 : SE.getAddExpr(Operands,
15937                                      ScalarEvolution::maskFlags(
15938                                          Expr->getNoWrapFlags(), FlagMask));
15939 }
15940
15941 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15942      SmallVector<const SCEV *, 2> Operands;
15943      bool Changed = false;
15944 for (const auto *Op : Expr->operands()) {
15945 Operands.push_back(
15946          SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15947        Changed |= Op != Operands.back();
15948 }
15949 // We are only replacing operands with equivalent values, so transfer the
15950 // flags from the original expression.
15951 return !Changed ? Expr
15952 : SE.getMulExpr(Operands,
15953                                      ScalarEvolution::maskFlags(
15954                                          Expr->getNoWrapFlags(), FlagMask));
15955 }
15956 };
15957
15958 if (RewriteMap.empty())
15959 return Expr;
15960
15961 SCEVLoopGuardRewriter Rewriter(SE, *this);
15962 return Rewriter.visit(Expr);
15963}
15964
15965const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15966 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
15967}
15968
15969 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
15970                                              const LoopGuards &Guards) {
15971 return Guards.rewrite(Expr);
15972}
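// Illustrative usage sketch (hypothetical client code): guards can be
// collected once per loop and reused for several expressions:
//   ScalarEvolution::LoopGuards Guards =
//       ScalarEvolution::LoopGuards::collect(L, SE);
//   const SCEV *Bounded = SE.applyLoopGuards(SE.getSCEV(V), Guards);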
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified SCEV is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
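As a hedged aside on the arithmetic (not necessarily the exact rewrite this routine performs): for unsigned D > 0,

\lceil N / D \rceil = \lfloor (N + D - 1) / D \rfloor

which only holds when the addition N + D - 1 does not wrap; guarding against that wrap is precisely why a dedicated helper is needed rather than a naive add-and-divide.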
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Does operation BinOp between LHS and RHS provably avoid signed/unsigned overflow (Signed)?...

ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which cannot resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
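A minimal sketch of building the canonical induction variable {0,+,1}<L> through this interface; SE, L and Ty are assumed to be supplied by the caller, and FlagAnyWrap is only an illustrative default.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Hedged sketch: SE, L and Ty come from the enclosing pass.
static const SCEV *canonicalIV(ScalarEvolution &SE, const Loop *L, Type *Ty) {
  // {0,+,1}<L>: starts at 0 and steps by 1 on every backedge of L.
  return SE.getAddRecExpr(SE.getZero(Ty), SE.getOne(Ty), L, SCEV::FlagAnyWrap);
}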
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
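A small sketch tying these kinds together: querying the exact and constant-maximum backedge-taken counts and checking for SCEVCouldNotCompute. All names other than the ScalarEvolution API are assumptions.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

// Hedged sketch: SE and L are assumed to be supplied by the caller.
static void reportBackedgeCounts(ScalarEvolution &SE, const Loop *L,
                                 raw_ostream &OS) {
  const SCEV *Exact = SE.getBackedgeTakenCount(L, ScalarEvolution::Exact);
  const SCEV *ConstMax = SE.getConstantMaxBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(Exact))
    OS << "exact backedge-taken count not computable\n";
  else
    OS << "exact: " << *Exact << "\n";
  OS << "constant max: " << *ConstMax << "\n";
}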
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
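A short sketch of one typical use, checking whether two addresses or indices differ by a compile-time constant; A and B are assumed to be SCEV-able Values supplied by the caller.

#include "llvm/Analysis/ScalarEvolution.h"
#include <optional>
using namespace llvm;

// Hedged sketch: returns the constant difference A - B, if SCEV can prove one.
static std::optional<APInt> constantDelta(ScalarEvolution &SE, Value *A,
                                          Value *B) {
  return SE.computeConstantDifference(SE.getSCEV(A), SE.getSCEV(B));
}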
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
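A minimal sketch of asking SCEV to prove a signed comparison between two IR values; it assumes both values have SCEV-able types and that SE comes from the surrounding pass.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Hedged sketch: proves "A <s B" when ScalarEvolution can establish it.
static bool knownSignedLess(ScalarEvolution &SE, Value *A, Value *B) {
  return SE.isKnownPredicate(ICmpInst::ICMP_SLT, SE.getSCEV(A), SE.getSCEV(B));
}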
bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:363
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:458
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:132
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:181
size_type size() const
Definition: SmallSet.h:170
bool empty() const
Definition: SmallVector.h:81
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:937
void reserve(size_type N)
Definition: SmallVector.h:663
iterator erase(const_iterator CI)
Definition: SmallVector.h:737
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:683
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:805
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
An instruction for storing to memory.
Definition: Instructions.h:292
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:567
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:596
TypeSize getSizeInBits() const
Definition: DataLayout.h:576
Class to represent struct types.
Definition: DerivedTypes.h:218
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:264
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:252
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:237
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:288
Use & Op()
Definition: User.h:192
Value * getOperand(unsigned i) const
Definition: User.h:228
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5144
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1075
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
const ParentTy * getParent() const
Definition: ilist_node.h:32
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2217
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2222
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2227
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2785
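For intuition only, the closed form behind quadratic solving is the standard formula

n = \frac{-B \pm \sqrt{B^2 - 4AC}}{2A}

The APInt routine cannot simply evaluate this in the reals, since q(n) is interpreted in BW-bit modular arithmetic and must account for wrapping; the formula is a hedged reference point, not a description of the algorithm.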
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2232
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:771
@ Entry
Definition: COFF.h:844
@ Exit
Definition: COFF.h:845
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)
This version supports overloaded intrinsics.
Definition: Intrinsics.cpp:746
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
Definition: PPCPredicates.h:26
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match only the specific value V.
Definition: PatternMatch.h:885
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:168
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:832
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:299
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:189
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
Definition: PatternMatch.h:239
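A small sketch using the IR pattern matchers above to recognise a signed-max idiom written as a select; m_ICmp and m_Deferred are additional PatternMatch helpers assumed to be available, and all other names are illustrative.

#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

// Hedged sketch: recognises "select (icmp sgt A, B), A, B", i.e. smax(A, B).
static bool matchSMaxSelect(Value *V, Value *&A, Value *&B) {
  ICmpInst::Predicate Pred;
  return match(V, m_Select(m_ICmp(Pred, m_Value(A), m_Value(B)),
                           m_Deferred(A), m_Deferred(B))) &&
         Pred == ICmpInst::ICMP_SGT;
}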
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
bind_ty< const SCEVConstant > m_SCEVConstant(const SCEVConstant *&V)
bind_ty< const SCEV > m_SCEV(const SCEV *&V)
Match a SCEV, capturing it if we match.
SCEVBinaryExpr_match< SCEVAddExpr, Op0_t, Op1_t > m_scev_Add(const Op0_t &Op0, const Op1_t &Op1)
bool match(const SCEV *S, const Pattern &P)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
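A short sketch of the SCEV-level matchers listed above, peeling a constant offset off an add expression; the header name is an assumption based on the matcher namespace, and the helper name is illustrative.

#include "llvm/Analysis/ScalarEvolutionPatternMatch.h" // assumed header name
using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Hedged sketch: matches S == C + X with C a SCEVConstant, capturing both.
static bool splitConstantOffset(const SCEV *S, const SCEVConstant *&C,
                                const SCEV *&X) {
  return match(S, m_scev_Add(m_SCEVConstant(C), m_SCEV(X)));
}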
@ ReallyHidden
Definition: CommandLine.h:138
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:463
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:47
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
Definition: DynamicAPInt.h:390
void stable_sort(R &&Range)
Definition: STLExtras.h:2037
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:255
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:7297
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2115
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0s from the least significant bit up to the most significant bit, stopping at the first 1.
Definition: bit.h:215
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of WO's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2107
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1746
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:420
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1150
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1140
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:255
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:260
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1938
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition: STLExtras.h:2014
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:303
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:217
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1873
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1945
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1903
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:28
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:293
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:100
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:428
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:43
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:370
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:188
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:137
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:121
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:97
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:285
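A tiny sketch exercising the KnownBits helpers above on fully known inputs; the values and function name are illustrative only.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Hedged sketch: with both operands fully known, shl is fully known too.
static void knownBitsShlDemo() {
  KnownBits LHS = KnownBits::makeConstant(APInt(8, 3)); // exactly 3
  KnownBits RHS = KnownBits::makeConstant(APInt(8, 2)); // exactly 2
  KnownBits Shifted = KnownBits::shl(LHS, RHS);         // exactly 3 << 2 == 12
  APInt Min = Shifted.getMinValue();                    // 12
  APInt Max = Shifted.getMaxValue();                    // 12
  (void)Min;
  (void)Max;
}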
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.