LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
341 print(dbgs());
342 dbgs() << '\n';
343}
344#endif
345
347 getPointer()->print(OS);
348 SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(getInt());
349 if (Flags & SCEV::FlagNUW)
350 OS << "(u nuw)";
351 if (Flags & SCEV::FlagNSW)
352 OS << "(u nsw)";
353}
354
355//===----------------------------------------------------------------------===//
356// Implementation of the SCEV class.
357//
358
359#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
361 print(dbgs());
362 dbgs() << '\n';
363}
364#endif
365
366void SCEV::print(raw_ostream &OS) const {
367 switch (getSCEVType()) {
368 case scConstant:
369 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
370 return;
371 case scVScale:
372 OS << "vscale";
373 return;
374 case scPtrToAddr:
375 case scPtrToInt: {
376 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
377 const SCEV *Op = PtrCast->getOperand();
378 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
379 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
380 << *PtrCast->getType() << ")";
381 return;
382 }
383 case scTruncate: {
384 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
385 const SCEV *Op = Trunc->getOperand();
386 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
387 << *Trunc->getType() << ")";
388 return;
389 }
390 case scZeroExtend: {
392 const SCEV *Op = ZExt->getOperand();
393 OS << "(zext " << *Op->getType() << " " << *Op << " to "
394 << *ZExt->getType() << ")";
395 return;
396 }
397 case scSignExtend: {
399 const SCEV *Op = SExt->getOperand();
400 OS << "(sext " << *Op->getType() << " " << *Op << " to "
401 << *SExt->getType() << ")";
402 return;
403 }
404 case scAddRecExpr: {
405 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
406 OS << "{" << *AR->getOperand(0);
407 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
408 OS << ",+," << *AR->getOperand(i);
409 OS << "}<";
410 if (AR->hasNoUnsignedWrap())
411 OS << "nuw><";
412 if (AR->hasNoSignedWrap())
413 OS << "nsw><";
414 if (AR->hasNoSelfWrap() &&
416 OS << "nw><";
417 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
418 OS << ">";
419 return;
420 }
421 case scAddExpr:
422 case scMulExpr:
423 case scUMaxExpr:
424 case scSMaxExpr:
425 case scUMinExpr:
426 case scSMinExpr:
428 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
429 const char *OpStr = nullptr;
430 switch (NAry->getSCEVType()) {
431 case scAddExpr: OpStr = " + "; break;
432 case scMulExpr: OpStr = " * "; break;
433 case scUMaxExpr: OpStr = " umax "; break;
434 case scSMaxExpr: OpStr = " smax "; break;
435 case scUMinExpr:
436 OpStr = " umin ";
437 break;
438 case scSMinExpr:
439 OpStr = " smin ";
440 break;
442 OpStr = " umin_seq ";
443 break;
444 default:
445 llvm_unreachable("There are no other nary expression types.");
446 }
447 OS << "("
449 << ")";
450 switch (NAry->getSCEVType()) {
451 case scAddExpr:
452 case scMulExpr:
453 if (NAry->hasNoUnsignedWrap())
454 OS << "<nuw>";
455 if (NAry->hasNoSignedWrap())
456 OS << "<nsw>";
457 break;
458 default:
459 // Nothing to print for other nary expressions.
460 break;
461 }
462 return;
463 }
464 case scUDivExpr: {
465 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
466 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
467 return;
468 }
469 case scUnknown:
470 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
471 return;
473 OS << "***COULDNOTCOMPUTE***";
474 return;
475 }
476 llvm_unreachable("Unknown SCEV kind!");
477}
478
480 switch (getSCEVType()) {
481 case scConstant:
482 return cast<SCEVConstant>(this)->getType();
483 case scVScale:
484 return cast<SCEVVScale>(this)->getType();
485 case scPtrToAddr:
486 case scPtrToInt:
487 case scTruncate:
488 case scZeroExtend:
489 case scSignExtend:
490 return cast<SCEVCastExpr>(this)->getType();
491 case scAddRecExpr:
492 return cast<SCEVAddRecExpr>(this)->getType();
493 case scMulExpr:
494 return cast<SCEVMulExpr>(this)->getType();
495 case scUMaxExpr:
496 case scSMaxExpr:
497 case scUMinExpr:
498 case scSMinExpr:
499 return cast<SCEVMinMaxExpr>(this)->getType();
501 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
502 case scAddExpr:
503 return cast<SCEVAddExpr>(this)->getType();
504 case scUDivExpr:
505 return cast<SCEVUDivExpr>(this)->getType();
506 case scUnknown:
507 return cast<SCEVUnknown>(this)->getType();
509 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
510 }
511 llvm_unreachable("Unknown SCEV kind!");
512}
513
515 switch (getSCEVType()) {
516 case scConstant:
517 case scVScale:
518 case scUnknown:
519 return {};
520 case scPtrToAddr:
521 case scPtrToInt:
522 case scTruncate:
523 case scZeroExtend:
524 case scSignExtend:
525 return cast<SCEVCastExpr>(this)->operands();
526 case scAddRecExpr:
527 case scAddExpr:
528 case scMulExpr:
529 case scUMaxExpr:
530 case scSMaxExpr:
531 case scUMinExpr:
532 case scSMinExpr:
534 return cast<SCEVNAryExpr>(this)->operands();
535 case scUDivExpr:
536 return cast<SCEVUDivExpr>(this)->operands();
538 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
539 }
540 llvm_unreachable("Unknown SCEV kind!");
541}
542
543bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
544
545bool SCEV::isOne() const { return match(this, m_scev_One()); }
546
547bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
548
551 if (!Mul) return false;
552
553 // If there is a constant factor, it will be first.
554 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
555 if (!SC) return false;
556
557 // Return true if the value is negative, this matches things like (-42 * V).
558 return SC->getAPInt().isNegative();
559}
560
563
565 return S->getSCEVType() == scCouldNotCompute;
566}
567
570 ID.AddInteger(scConstant);
571 ID.AddPointer(V);
572 void *IP = nullptr;
573 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
574 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
575 UniqueSCEVs.InsertNode(S, IP);
576 S->computeAndSetCanonical(*this);
577 return S;
578}
579
581 return getConstant(ConstantInt::get(getContext(), Val));
582}
583
584const SCEV *
587 // TODO: Avoid implicit trunc?
588 // See https://github.com/llvm/llvm-project/issues/112510.
589 return getConstant(
590 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
591}
592
595 ID.AddInteger(scVScale);
596 ID.AddPointer(Ty);
597 void *IP = nullptr;
598 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
599 return S;
600 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
601 UniqueSCEVs.InsertNode(S, IP);
602 S->computeAndSetCanonical(*this);
603 return S;
604}
605
607 SCEV::NoWrapFlags Flags) {
608 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
609 if (EC.isScalable())
610 Res = getMulExpr(Res, getVScale(Ty), Flags);
611 return Res;
612}
613
617
618SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
619 const SCEV *Op, Type *ITy)
620 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
621 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
622 "Must be a non-bit-width-changing pointer-to-integer cast!");
623}
624
625SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
626 Type *ITy)
627 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
628 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
629 "Must be a non-bit-width-changing pointer-to-integer cast!");
630}
631
636
637SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
638 Type *ty)
640 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
641 "Cannot truncate non-integer value!");
642}
643
644SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
645 Type *ty)
647 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
648 "Cannot zero extend non-integer value!");
649}
650
651SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
652 Type *ty)
654 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
655 "Cannot sign extend non-integer value!");
656}
657
659 // Clear this SCEVUnknown from various maps.
660 SE->forgetMemoizedResults({this});
661
662 // Remove this SCEVUnknown from the uniquing map.
663 SE->UniqueSCEVs.RemoveNode(this);
664
665 // Release the value.
666 setValPtr(nullptr);
667}
668
669void SCEVUnknown::allUsesReplacedWith(Value *New) {
670 // Clear this SCEVUnknown from various maps.
671 SE->forgetMemoizedResults({this});
672
673 // Remove this SCEVUnknown from the uniquing map.
674 SE->UniqueSCEVs.RemoveNode(this);
675
676 // Replace the value pointer in case someone is still using this SCEVUnknown.
677 setValPtr(New);
678}
679
680//===----------------------------------------------------------------------===//
681// SCEV Utilities
682//===----------------------------------------------------------------------===//
683
684/// Compare the two values \p LV and \p RV in terms of their "complexity" where
685/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
686/// operands in SCEV expressions.
687static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
688 Value *RV, unsigned Depth) {
690 return 0;
691
692 // Order pointer values after integer values. This helps SCEVExpander form
693 // GEPs.
694 bool LIsPointer = LV->getType()->isPointerTy(),
695 RIsPointer = RV->getType()->isPointerTy();
696 if (LIsPointer != RIsPointer)
697 return (int)LIsPointer - (int)RIsPointer;
698
699 // Compare getValueID values.
700 unsigned LID = LV->getValueID(), RID = RV->getValueID();
701 if (LID != RID)
702 return (int)LID - (int)RID;
703
704 // Sort arguments by their position.
705 if (const auto *LA = dyn_cast<Argument>(LV)) {
706 const auto *RA = cast<Argument>(RV);
707 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
708 return (int)LArgNo - (int)RArgNo;
709 }
710
711 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
712 const auto *RGV = cast<GlobalValue>(RV);
713
714 if (auto L = LGV->getLinkage() - RGV->getLinkage())
715 return L;
716
717 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
718 auto LT = GV->getLinkage();
719 return !(GlobalValue::isPrivateLinkage(LT) ||
721 };
722
723 // Use the names to distinguish the two values, but only if the
724 // names are semantically important.
725 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
726 return LGV->getName().compare(RGV->getName());
727 }
728
729 // For instructions, compare their loop depth, and their operand count. This
730 // is pretty loose.
731 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
732 const auto *RInst = cast<Instruction>(RV);
733
734 // Compare loop depths.
735 const BasicBlock *LParent = LInst->getParent(),
736 *RParent = RInst->getParent();
737 if (LParent != RParent) {
738 unsigned LDepth = LI->getLoopDepth(LParent),
739 RDepth = LI->getLoopDepth(RParent);
740 if (LDepth != RDepth)
741 return (int)LDepth - (int)RDepth;
742 }
743
744 // Compare the number of operands.
745 unsigned LNumOps = LInst->getNumOperands(),
746 RNumOps = RInst->getNumOperands();
747 if (LNumOps != RNumOps)
748 return (int)LNumOps - (int)RNumOps;
749
750 for (unsigned Idx : seq(LNumOps)) {
751 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
752 RInst->getOperand(Idx), Depth + 1);
753 if (Result != 0)
754 return Result;
755 }
756 }
757
758 return 0;
759}
760
761// Return negative, zero, or positive, if LHS is less than, equal to, or greater
762// than RHS, respectively. A three-way result allows recursive comparisons to be
763// more efficient.
764// If the max analysis depth was reached, return std::nullopt, assuming we do
765// not know if they are equivalent for sure.
766static std::optional<int>
767CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
768 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
769 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
770 if (LHS == RHS)
771 return 0;
772
773 // Primarily, sort the SCEVs by their getSCEVType().
774 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
775 if (LType != RType)
776 return (int)LType - (int)RType;
777
779 return std::nullopt;
780
781 // Aside from the getSCEVType() ordering, the particular ordering
782 // isn't very important except that it's beneficial to be consistent,
783 // so that (a + b) and (b + a) don't end up as different expressions.
784 switch (LType) {
785 case scUnknown: {
786 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
787 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
788
789 int X =
790 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
791 return X;
792 }
793
794 case scConstant: {
797
798 // Compare constant values.
799 const APInt &LA = LC->getAPInt();
800 const APInt &RA = RC->getAPInt();
801 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
802 if (LBitWidth != RBitWidth)
803 return (int)LBitWidth - (int)RBitWidth;
804 return LA.ult(RA) ? -1 : 1;
805 }
806
807 case scVScale: {
808 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
809 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
810 return LTy->getBitWidth() - RTy->getBitWidth();
811 }
812
813 case scAddRecExpr: {
816
817 // There is always a dominance between two recs that are used by one SCEV,
818 // so we can safely sort recs by loop header dominance. We require such
819 // order in getAddExpr.
820 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
821 if (LLoop != RLoop) {
822 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
823 assert(LHead != RHead && "Two loops share the same header?");
824 if (DT.dominates(LHead, RHead))
825 return 1;
826 assert(DT.dominates(RHead, LHead) &&
827 "No dominance between recurrences used by one SCEV?");
828 return -1;
829 }
830
831 [[fallthrough]];
832 }
833
834 case scTruncate:
835 case scZeroExtend:
836 case scSignExtend:
837 case scPtrToAddr:
838 case scPtrToInt:
839 case scAddExpr:
840 case scMulExpr:
841 case scUDivExpr:
842 case scSMaxExpr:
843 case scUMaxExpr:
844 case scSMinExpr:
845 case scUMinExpr:
847 ArrayRef<SCEVUse> LOps = LHS->operands();
848 ArrayRef<SCEVUse> ROps = RHS->operands();
849
850 // Lexicographically compare n-ary-like expressions.
851 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
852 if (LNumOps != RNumOps)
853 return (int)LNumOps - (int)RNumOps;
854
855 for (unsigned i = 0; i != LNumOps; ++i) {
856 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
857 ROps[i].getPointer(), DT, Depth + 1);
858 if (X != 0)
859 return X;
860 }
861 return 0;
862 }
863
865 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
866 }
867 llvm_unreachable("Unknown SCEV kind!");
868}
869
870/// Given a list of SCEV objects, order them by their complexity, and group
871/// objects of the same complexity together by value. When this routine is
872/// finished, we know that any duplicates in the vector are consecutive and that
873/// complexity is monotonically increasing.
874///
875/// Note that we go take special precautions to ensure that we get deterministic
876/// results from this routine. In other words, we don't want the results of
877/// this to depend on where the addresses of various SCEV objects happened to
878/// land in memory.
880 DominatorTree &DT) {
881 if (Ops.size() < 2) return; // Noop
882
883 // Whether LHS has provably less complexity than RHS.
884 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
885 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
886 return Complexity && *Complexity < 0;
887 };
888 if (Ops.size() == 2) {
889 // This is the common case, which also happens to be trivially simple.
890 // Special case it.
891 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
892 if (IsLessComplex(RHS, LHS))
893 std::swap(LHS, RHS);
894 return;
895 }
896
897 // Do the rough sort by complexity.
899 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
900
901 // Now that we are sorted by complexity, group elements of the same
902 // complexity. Note that this is, at worst, N^2, but the vector is likely to
903 // be extremely short in practice. Note that we take this approach because we
904 // do not want to depend on the addresses of the objects we are grouping.
905 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
906 const SCEV *S = Ops[i];
907 unsigned Complexity = S->getSCEVType();
908
909 // If there are any objects of the same complexity and same value as this
910 // one, group them.
911 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
912 if (Ops[j] == S) { // Found a duplicate.
913 // Move it to immediately after i'th element.
914 std::swap(Ops[i+1], Ops[j]);
915 ++i; // no need to rescan it.
916 if (i == e-2) return; // Done!
917 }
918 }
919 }
920}
921
922/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
923/// least HugeExprThreshold nodes).
925 return any_of(Ops, [](const SCEV *S) {
927 });
928}
929
930/// Performs a number of common optimizations on the passed \p Ops. If the
931/// whole expression reduces down to a single operand, it will be returned.
932///
933/// The following optimizations are performed:
934/// * Fold constants using the \p Fold function.
935/// * Remove identity constants satisfying \p IsIdentity.
936/// * If a constant satisfies \p IsAbsorber, return it.
937/// * Sort operands by complexity.
938template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
939static const SCEV *
941 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
942 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
943 const SCEVConstant *Folded = nullptr;
944 for (unsigned Idx = 0; Idx < Ops.size();) {
945 const SCEV *Op = Ops[Idx];
946 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
947 if (!Folded)
948 Folded = C;
949 else
950 Folded = cast<SCEVConstant>(
951 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
952 Ops.erase(Ops.begin() + Idx);
953 continue;
954 }
955 ++Idx;
956 }
957
958 if (Ops.empty()) {
959 assert(Folded && "Must have folded value");
960 return Folded;
961 }
962
963 if (Folded && IsAbsorber(Folded->getAPInt()))
964 return Folded;
965
966 GroupByComplexity(Ops, &LI, DT);
967 if (Folded && !IsIdentity(Folded->getAPInt()))
968 Ops.insert(Ops.begin(), Folded);
969
970 return Ops.size() == 1 ? Ops[0] : nullptr;
971}
972
973//===----------------------------------------------------------------------===//
974// Simple SCEV method implementations
975//===----------------------------------------------------------------------===//
976
977/// Compute BC(It, K). The result has width W. Assume, K > 0.
978static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
979 ScalarEvolution &SE,
980 Type *ResultTy) {
981 // Handle the simplest case efficiently.
982 if (K == 1)
983 return SE.getTruncateOrZeroExtend(It, ResultTy);
984
985 // We are using the following formula for BC(It, K):
986 //
987 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
988 //
989 // Suppose, W is the bitwidth of the return value. We must be prepared for
990 // overflow. Hence, we must assure that the result of our computation is
991 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
992 // safe in modular arithmetic.
993 //
994 // However, this code doesn't use exactly that formula; the formula it uses
995 // is something like the following, where T is the number of factors of 2 in
996 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
997 // exponentiation:
998 //
999 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
1000 //
1001 // This formula is trivially equivalent to the previous formula. However,
1002 // this formula can be implemented much more efficiently. The trick is that
1003 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
1004 // arithmetic. To do exact division in modular arithmetic, all we have
1005 // to do is multiply by the inverse. Therefore, this step can be done at
1006 // width W.
1007 //
1008 // The next issue is how to safely do the division by 2^T. The way this
1009 // is done is by doing the multiplication step at a width of at least W + T
1010 // bits. This way, the bottom W+T bits of the product are accurate. Then,
1011 // when we perform the division by 2^T (which is equivalent to a right shift
1012 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
1013 // truncated out after the division by 2^T.
1014 //
1015 // In comparison to just directly using the first formula, this technique
1016 // is much more efficient; using the first formula requires W * K bits,
1017 // but this formula less than W + K bits. Also, the first formula requires
1018 // a division step, whereas this formula only requires multiplies and shifts.
1019 //
1020 // It doesn't matter whether the subtraction step is done in the calculation
1021 // width or the input iteration count's width; if the subtraction overflows,
1022 // the result must be zero anyway. We prefer here to do it in the width of
1023 // the induction variable because it helps a lot for certain cases; CodeGen
1024 // isn't smart enough to ignore the overflow, which leads to much less
1025 // efficient code if the width of the subtraction is wider than the native
1026 // register width.
1027 //
1028 // (It's possible to not widen at all by pulling out factors of 2 before
1029 // the multiplication; for example, K=2 can be calculated as
1030 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1031 // extra arithmetic, so it's not an obvious win, and it gets
1032 // much more complicated for K > 3.)
1033
1034 // Protection from insane SCEVs; this bound is conservative,
1035 // but it probably doesn't matter.
1036 if (K > 1000)
1037 return SE.getCouldNotCompute();
1038
1039 unsigned W = SE.getTypeSizeInBits(ResultTy);
1040
1041 // Calculate K! / 2^T and T; we divide out the factors of two before
1042 // multiplying for calculating K! / 2^T to avoid overflow.
1043 // Other overflow doesn't matter because we only care about the bottom
1044 // W bits of the result.
1045 APInt OddFactorial(W, 1);
1046 unsigned T = 1;
1047 for (unsigned i = 3; i <= K; ++i) {
1048 unsigned TwoFactors = countr_zero(i);
1049 T += TwoFactors;
1050 OddFactorial *= (i >> TwoFactors);
1051 }
1052
1053 // We need at least W + T bits for the multiplication step
1054 unsigned CalculationBits = W + T;
1055
1056 // Calculate 2^T, at width T+W.
1057 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1058
1059 // Calculate the multiplicative inverse of K! / 2^T;
1060 // this multiplication factor will perform the exact division by
1061 // K! / 2^T.
1062 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1063
1064 // Calculate the product, at width T+W
1065 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1066 CalculationBits);
1067 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1068 for (unsigned i = 1; i != K; ++i) {
1069 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1070 Dividend = SE.getMulExpr(Dividend,
1071 SE.getTruncateOrZeroExtend(S, CalculationTy));
1072 }
1073
1074 // Divide by 2^T
1075 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1076
1077 // Truncate the result, and divide by K! / 2^T.
1078
1079 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1080 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1081}
1082
1083/// Return the value of this chain of recurrences at the specified iteration
1084/// number. We can evaluate this recurrence by multiplying each element in the
1085/// chain by the binomial coefficient corresponding to it. In other words, we
1086/// can evaluate {A,+,B,+,C,+,D} as:
1087///
1088/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1089///
1090/// where BC(It, k) stands for binomial coefficient.
// NOTE(review): the extraction dropped the first signature line (original
// line 1091); code below is kept byte-identical.
1092 ScalarEvolution &SE) const {
  // Delegate to the static overload over this AddRec's operand list.
1093 return evaluateAtIteration(operands(), It, SE);
1094}
1095
// Evaluate the chain of recurrences given by Operands at iteration `It`:
//   Operands[0]*BC(It,0) + Operands[1]*BC(It,1) + ...
// Returns SCEVCouldNotCompute if any binomial coefficient is unanalyzable.
// NOTE(review): the extraction dropped the first signature line (original
// line 1096); code below is kept byte-identical.
1097 const SCEV *It,
1098 ScalarEvolution &SE) {
1099 assert(Operands.size() > 0);
 // BC(It, 0) == 1, so the iteration-0 term is just the first operand.
1100 const SCEV *Result = Operands[0].getPointer();
1101 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1102 // The computation is correct in the face of overflow provided that the
1103 // multiplication is performed _after_ the evaluation of the binomial
1104 // coefficient.
1105 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
 // If any coefficient is unanalyzable, the whole evaluation is.
1106 if (isa<SCEVCouldNotCompute>(Coeff))
1107 return Coeff;
1108
1109 Result =
1110 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1111 }
1112 return Result;
1113}
1114
1115//===----------------------------------------------------------------------===//
1116// SCEV Expression folder implementations
1117//===----------------------------------------------------------------------===//
1118
1119/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1120/// which computes a pointer-typed value, and rewrites the whole expression
1121/// tree so that *all* the computations are done on integers, and the only
1122/// pointer-typed operands in the expression are SCEVUnknown.
1123/// The CreatePtrCast callback is invoked to create the actual conversion
1124/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
// NOTE(review): the extraction dropped several hyperlinked lines of this
// class: the `class SCEVCastSinkingRewriter` header (1125), a `using Base`
// alias (1127), the constructor's name line (1133), and the null-pointer
// test in visitUnknown (1180). Remaining code is kept byte-identical.
1126 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1128 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
 // Integer type all computations are rewritten into.
1129 Type *TargetTy;
 // Callback that builds the actual cast at each pointer-typed SCEVUnknown.
1130 ConversionFn CreatePtrCast;
1131
1132public:
1134 ConversionFn CreatePtrCast)
1135 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1136
 // Entry point: rewrite `Scev` so all arithmetic is integer-typed.
1137 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1138 Type *TargetTy, ConversionFn CreatePtrCast) {
1139 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1140 return Rewriter.visit(Scev);
1141 }
1142
1143 const SCEV *visit(const SCEV *S) {
1144 Type *STy = S->getType();
1145 // If the expression is not pointer-typed, just keep it as-is.
1146 if (!STy->isPointerTy())
1147 return S;
1148 // Else, recursively sink the cast down into it.
1149 return Base::visit(S);
1150 }
1151
1152 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1153 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1154 // implementation drops.
1155 SmallVector<SCEVUse, 2> Operands;
1156 bool Changed = false;
1157 for (SCEVUse Op : Expr->operands()) {
1158 Operands.push_back(visit(Op.getPointer()));
1159 Changed |= Op.getPointer() != Operands.back();
1160 }
 // Return the original expression unchanged when no operand was rewritten.
1161 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1162 }
1163
 // Same flag-preserving treatment for multiplies.
1164 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1165 SmallVector<SCEVUse, 2> Operands;
1166 bool Changed = false;
1167 for (SCEVUse Op : Expr->operands()) {
1168 Operands.push_back(visit(Op.getPointer()));
1169 Changed |= Op.getPointer() != Operands.back();
1170 }
1171 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1172 }
1173
1174 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1175 assert(Expr->getType()->isPointerTy() &&
1176 "Should only reach pointer-typed SCEVUnknown's.");
1177 // Perform some basic constant folding. If the operand of the cast is a
1178 // null pointer, don't create a cast SCEV expression (that will be left
1179 // as-is), but produce a zero constant.
1181 return SE.getZero(TargetTy);
1182 return CreatePtrCast(Expr);
1183 }
1184};
1185
// Build a ptrtoint SCEV for `Op` without losing information: bail out for
// non-integral pointers and when the pointer is wider than SCEV's effective
// integer type, otherwise sink the cast to the SCEVUnknown leaves.
// NOTE(review): the extraction dropped the signature (1186), the bit-width
// comparison condition (1200), the `SCEVCastSinkingRewriter::rewrite` call
// line (1205), and a `FoldingSetNodeID` declaration (1207).
1187 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1188
1189 // It isn't legal for optimizations to construct new ptrtoint expressions
1190 // for non-integral pointers.
1191 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1192 return getCouldNotCompute();
1193
1194 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1195
1196 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1197 // is sufficiently wide to represent all possible pointer values.
1198 // We could theoretically teach SCEV to truncate wider pointers, but
1199 // that isn't implemented for now.
1201 getDataLayout().getTypeSizeInBits(IntPtrTy))
1202 return getCouldNotCompute();
1203
1204 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1206 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
 // Unique the SCEVPtrToIntExpr node: reuse an existing node if present.
1208 ID.AddInteger(scPtrToInt);
1209 ID.AddPointer(U);
1210 void *IP = nullptr;
1211 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1212 return S;
1213 SCEV *S = new (SCEVAllocator)
1214 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1215 UniqueSCEVs.InsertNode(S, IP);
1216 S->computeAndSetCanonical(*this);
1217 registerUser(S, U);
1218 return static_cast<const SCEV *>(S);
1219 });
1220 assert(IntOp->getType()->isIntegerTy() &&
1221 "We must have succeeded in sinking the cast, "
1222 "and ending up with an integer-typed expression!");
1223 return IntOp;
1224}
1225
// Build a ptrtoaddr SCEV for `Op`, bailing out for pointers with unstable
// representation, and sinking the cast to the SCEVUnknown leaves.
// NOTE(review): the extraction dropped the signature (1226), the
// `SCEVCastSinkingRewriter::rewrite` call line (1238), and a
// `FoldingSetNodeID` declaration (1240). Code is kept byte-identical.
1227 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1228
1229 // Treat pointers with unstable representation conservatively, since the
1230 // address bits may change.
1231 if (DL.hasUnstableRepresentation(Op->getType()))
1232 return getCouldNotCompute();
1233
1234 Type *Ty = DL.getAddressType(Op->getType());
1235
1236 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1237 // The rewriter handles null pointer constant folding.
1239 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
 // Unique the SCEVPtrToAddrExpr node: reuse an existing node if present.
1241 ID.AddInteger(scPtrToAddr);
1242 ID.AddPointer(U);
1243 ID.AddPointer(Ty);
1244 void *IP = nullptr;
1245 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1246 return S;
1247 SCEV *S = new (SCEVAllocator)
1248 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1249 UniqueSCEVs.InsertNode(S, IP);
1250 S->computeAndSetCanonical(*this);
1251 registerUser(S, U);
1252 return static_cast<const SCEV *>(S);
1253 });
1254 assert(IntOp->getType()->isIntegerTy() &&
1255 "We must have succeeded in sinking the cast, "
1256 "and ending up with an integer-typed expression!");
1257 return IntOp;
1258}
1259
// Convert a pointer-typed SCEV to integer type `Ty`: do the lossless
// ptrtoint first, then truncate/zero-extend to the requested width.
// NOTE(review): the extraction dropped the signature line (original 1260).
1261 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1262
1263 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
 // Propagate failure from the lossless conversion.
1264 if (isa<SCEVCouldNotCompute>(IntOp))
1265 return IntOp;
1266
1267 return getTruncateOrZeroExtend(IntOp, Ty);
1268}
1269
// Return a SCEV for `Op` truncated to the narrower integer type `Ty`,
// folding through constants, casts, add/mul, and add recurrences before
// falling back to an explicit SCEVTruncateExpr node.
// NOTE(review): the extraction dropped several hyperlinked lines here: the
// signature (1270), a `FoldingSetNodeID` declaration (1279), the
// `dyn_cast<SCEVTruncateExpr>` / `<SCEVSignExtendExpr>` /
// `<SCEVZeroExtendExpr>` conditions (1292/1296/1300), the add/mul dispatch
// condition (1316), and part of the numTruncs test (1324). Code below is
// kept byte-identical.
1271 unsigned Depth) {
1272 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1273 "This is not a truncating conversion!");
1274 assert(isSCEVable(Ty) &&
1275 "This is not a conversion to a SCEVable type!");
1276 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1277 Ty = getEffectiveSCEVType(Ty);
1278
 // Check the uniquing table first so identical truncates share one node.
1280 ID.AddInteger(scTruncate);
1281 ID.AddPointer(Op);
1282 ID.AddPointer(Ty);
1283 void *IP = nullptr;
1284 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1285
1286 // Fold if the operand is constant.
1287 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1288 return getConstant(
1289 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1290
1291 // trunc(trunc(x)) --> trunc(x)
1293 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1294
1295 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1297 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1298
1299 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1301 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1302
 // Recursion limit reached: emit an explicit truncate node and stop folding.
1303 if (Depth > MaxCastDepth) {
1304 SCEV *S =
1305 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1306 UniqueSCEVs.InsertNode(S, IP);
1307 S->computeAndSetCanonical(*this);
1308 registerUser(S, Op);
1309 return S;
1310 }
1311
1312 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1313 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1314 // if after transforming we have at most one truncate, not counting truncates
1315 // that replace other casts.
1317 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1318 SmallVector<SCEVUse, 4> Operands;
1319 unsigned numTruncs = 0;
 // Stop early once two real truncates appear; the fold is then unprofitable.
1320 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1321 ++i) {
1322 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1323 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1325 numTruncs++;
1326 Operands.push_back(S);
1327 }
1328 if (numTruncs < 2) {
1329 if (isa<SCEVAddExpr>(Op))
1330 return getAddExpr(Operands);
1331 if (isa<SCEVMulExpr>(Op))
1332 return getMulExpr(Operands);
1333 llvm_unreachable("Unexpected SCEV type for Op.");
1334 }
1335 // Although we checked in the beginning that ID is not in the cache, it is
1336 // possible that during recursion and different modification ID was inserted
1337 // into the cache. So if we find it, just return it.
1338 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1339 return S;
1340 }
1341
1342 // If the input value is a chrec scev, truncate the chrec's operands.
1343 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1344 SmallVector<SCEVUse, 4> Operands;
1345 for (const SCEV *Op : AddRec->operands())
1346 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1347 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1348 }
1349
1350 // Return zero if truncating to known zeros.
1351 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1352 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1353 return getZero(Ty);
1354
1355 // The cast wasn't folded; create an explicit cast node. We can reuse
1356 // the existing insert position since if we get here, we won't have
1357 // made any changes which would invalidate it.
1358 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1359 Op, Ty);
1360 UniqueSCEVs.InsertNode(S, IP);
1361 S->computeAndSetCanonical(*this);
1362 registerUser(S, Op);
1363 return S;
1364}
1365
1366// Get the limit of a recurrence such that incrementing by Step cannot cause
1367// signed overflow as long as the value of the recurrence within the
1368// loop does not exceed this limit before incrementing.
// NOTE(review): the extraction dropped the two `return SE->getConstant(...)`
// lines (originals 1375 and 1380) that compute the actual limits from the
// signed min/max and the step's signed range. Code is kept byte-identical.
1369 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1370 ICmpInst::Predicate *Pred,
1371 ScalarEvolution *SE) {
1372 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
 // Positive step: the recurrence must stay strictly below the limit.
1373 if (SE->isKnownPositive(Step)) {
1374 *Pred = ICmpInst::ICMP_SLT;
1376 SE->getSignedRangeMax(Step));
1377 }
 // Negative step: the recurrence must stay strictly above the limit.
1378 if (SE->isKnownNegative(Step)) {
1379 *Pred = ICmpInst::ICMP_SGT;
1381 SE->getSignedRangeMin(Step));
1382 }
 // Step sign unknown: no limit can be given.
1383 return nullptr;
1384}
1385
1386// Get the limit of a recurrence such that incrementing by Step cannot cause
1387// unsigned overflow as long as the value of the recurrence within the loop does
1388// not exceed this limit before incrementing.
// NOTE(review): the extraction dropped the signature line (original 1389)
// and the `return SE->getConstant(...)` line (1395). Code is byte-identical.
1390 ICmpInst::Predicate *Pred,
1391 ScalarEvolution *SE) {
1392 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
 // Unsigned case: the recurrence must stay strictly below the limit.
1393 *Pred = ICmpInst::ICMP_ULT;
1394
1396 SE->getUnsignedRangeMax(Step));
1397}
1398
1399namespace {
1400
1401struct ExtendOpTraitsBase {
1402 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1403 unsigned);
1404};
1405
1406 // Used to make code generic over signed and unsigned overflow.
 // The primary template is intentionally empty; the specializations below
 // for SCEVSignExtendExpr and SCEVZeroExtendExpr supply these members.
1407 template <typename ExtendOp> struct ExtendOpTraits {
1408 // Members present:
1409 //
1410 // static const SCEV::NoWrapFlags WrapType;
1411 //
1412 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1413 //
1414 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1415 // ICmpInst::Predicate *Pred,
1416 // ScalarEvolution *SE);
1417 };
1418
// Traits for the signed-extension case: prove NSW and use the signed
// overflow-limit helper.
1419 template <>
1420 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1422
 // Defined out-of-class below.
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getSignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430 };
1431
// Out-of-class definition of ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr.
// NOTE(review): the extraction dropped the continuation line (original 1433);
// presumably it binds &ScalarEvolution::getSignExtendExpr — confirm in source.
1432 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
// Traits for the unsigned-extension case: prove NUW and use the unsigned
// overflow-limit helper.
1435 template <>
1436 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1437 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1438
 // Defined out-of-class below.
1439 static const GetExtendExprTy GetExtendExpr;
1440
1441 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1442 ICmpInst::Predicate *Pred,
1443 ScalarEvolution *SE) {
1444 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1445 }
1446 };
1447
// Out-of-class definition of ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr.
// NOTE(review): the extraction dropped the continuation line (original 1449);
// presumably it binds &ScalarEvolution::getZeroExtendExpr — confirm in source.
1448 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1450
1451} // end anonymous namespace
1452
1453// The recurrence AR has been shown to have no signed/unsigned wrap or something
1454// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1455// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1456// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1457// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1458// expression "Step + sext/zext(PreIncAR)" is congruent with
1459// "sext/zext(PostIncAR)"
// NOTE(review): the extraction dropped the `PreStartFlags` initializer
// (original 1494), the `dyn_cast<SCEVAddRecExpr>` line defining `PreAR`
// (1496), and the `ICmpInst::Predicate Pred;` declaration (1525). Code is
// kept byte-identical.
1460 template <typename ExtendOpTy>
1461 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1462 ScalarEvolution *SE, unsigned Depth) {
1463 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1464 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1465
1466 const Loop *L = AR->getLoop();
1467 const SCEV *Start = AR->getStart();
1468 const SCEV *Step = AR->getStepRecurrence(*SE);
1469
1470 // Check for a simple looking step prior to loop entry.
1471 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1472 if (!SA)
1473 return nullptr;
1474
1475 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1476 // subtraction is expensive. For this purpose, perform a quick and dirty
1477 // difference, by checking for Step in the operand list. Note, that
1478 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1479 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1480 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1481 if (*It == Step) {
1482 DiffOps.erase(It);
1483 break;
1484 }
1485
 // Step did not occur in Start's operand list: no PreStart candidate.
1486 if (DiffOps.size() == SA->getNumOperands())
1487 return nullptr;
1488
1489 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1490 // `Step`:
1491
1492 // 1. NSW/NUW flags on the step increment.
1493 auto PreStartFlags =
1495 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1497 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1498
1499 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1500 // "S+X does not sign/unsign-overflow".
1501 //
1502
1503 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1504 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1505 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1506 return PreStart;
1507
1508 // 2. Direct overflow check on the step operation's expression.
1509 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1510 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1511 const SCEV *OperandExtendedStart =
1512 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1513 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1514 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1515 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1516 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1517 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1518 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1519 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1520 }
1521 return PreStart;
1522 }
1523
1524 // 3. Loop precondition.
1526 const SCEV *OverflowLimit =
1527 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1528
1529 if (OverflowLimit &&
1530 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1531 return PreStart;
1532
 // None of the three proofs succeeded.
1533 return nullptr;
1534}
1535
1536// Get the normalized zero or sign extended expression for this AddRec's Start.
1537template <typename ExtendOpTy>
1538static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1539 ScalarEvolution *SE,
1540 unsigned Depth) {
1541 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1542
1543 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1544 if (!PreStart)
1545 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1546
1547 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1548 Depth),
1549 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1550}
1551
1552// Try to prove away overflow by looking at "nearby" add recurrences. A
1553// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1554// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1555//
1556// Formally:
1557//
1558// {S,+,X} == {S-T,+,X} + T
1559// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1560//
1561// If ({S-T,+,X} + T) does not overflow ... (1)
1562//
1563// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1564//
1565// If {S-T,+,X} does not overflow ... (2)
1566//
1567// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1568// == {Ext(S-T)+Ext(T),+,Ext(X)}
1569//
1570// If (S-T)+T does not overflow ... (3)
1571//
1572// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1573// == {Ext(S),+,Ext(X)} == LHS
1574//
1575// Thus, if (1), (2) and (3) are true for some T, then
1576// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1577//
1578// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1579// does not overflow" restricted to the 0th iteration. Therefore we only need
1580// to check for (1) and (2).
1581//
1582// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1583// is `Delta` (defined below).
// NOTE(review): the extraction dropped the `ICmpInst::Predicate Pred;`
// declaration (original 1616). Code below is kept byte-identical.
1584 template <typename ExtendOpTy>
1585 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1586 const SCEV *Step,
1587 const Loop *L) {
1588 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1589
1590 // We restrict `Start` to a constant to prevent SCEV from spending too much
1591 // time here. It is correct (but more expensive) to continue with a
1592 // non-constant `Start` and do a general SCEV subtraction to compute
1593 // `PreStart` below.
1594 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1595 if (!StartC)
1596 return false;
1597
1598 APInt StartAI = StartC->getAPInt();
1599
 // Probe a few nearby starting values (Start - Delta) for known recurrences.
1600 for (unsigned Delta : {-2, -1, 1, 2}) {
1601 const SCEV *PreStart = getConstant(StartAI - Delta);
1602
 // Look up {PreStart,+,Step}<L> in the uniquing table without creating it.
1603 FoldingSetNodeID ID;
1604 ID.AddInteger(scAddRecExpr);
1605 ID.AddPointer(PreStart);
1606 ID.AddPointer(Step);
1607 ID.AddPointer(L);
1608 void *IP = nullptr;
1609 const auto *PreAR =
1610 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1611
1612 // Give up if we don't already have the add recurrence we need because
1613 // actually constructing an add recurrence is relatively expensive.
1614 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1615 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1617 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1618 DeltaS, &Pred, this);
1619 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1620 return true;
1621 }
1622 }
1623
1624 return false;
1625}
1626
1627// Finds an integer D for an expression (C + x + y + ...) such that the top
1628// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1629// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1630// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1631// the (C + x + y + ...) expression is \p WholeAddExpr.
// NOTE(review): the extraction dropped the first signature line (original
// 1632, the `static APInt extractConstantWithoutWrapping(...)` opener).
// Code below is kept byte-identical.
1633 const SCEVConstant *ConstantTerm,
1634 const SCEVAddExpr *WholeAddExpr) {
1635 const APInt &C = ConstantTerm->getAPInt();
1636 const unsigned BitWidth = C.getBitWidth();
1637 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1638 uint32_t TZ = BitWidth;
1639 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1640 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1641 if (TZ) {
1642 // Set D to be as many least significant bits of C as possible while still
1643 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1644 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1645 }
 // No known trailing zeros: no constant can be split off safely.
1646 return APInt(BitWidth, 0);
1647}
1648
1649// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1650// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1651// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1652// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
// NOTE(review): the extraction dropped the first signature line (original
// 1653) of this AddRec-oriented overload. Code below is kept byte-identical.
1654 const APInt &ConstantStart,
1655 const SCEV *Step) {
1656 const unsigned BitWidth = ConstantStart.getBitWidth();
 // D is bounded by the trailing zeros of the step so C - D + Step*n cannot
 // wrap in its low TZ bits.
1657 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1658 if (TZ)
1659 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1660 : ConstantStart;
1661 return APInt(BitWidth, 0);
1662}
1663
// Record the mapping ID -> S in FoldCache and the reverse mapping in
// FoldCacheUser, evicting a stale reverse entry if ID was already cached.
// NOTE(review): the extraction dropped the signature opener (original 1664)
// and two lines of the FoldCacheUser parameter's type (1666-1667). Code is
// kept byte-identical.
1665 const ScalarEvolution::FoldID &ID, const SCEV *S,
1668 &FoldCacheUser) {
1669 auto I = FoldCache.insert({ID, S});
1670 if (!I.second) {
1671 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1672 // entry.
1673 auto &UserIDs = FoldCacheUser[I.first->second];
1674 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
 // Swap-and-pop: order of UserIDs does not need to be preserved.
1675 for (unsigned I = 0; I != UserIDs.size(); ++I)
1676 if (UserIDs[I] == ID) {
1677 std::swap(UserIDs[I], UserIDs.back());
1678 break;
1679 }
1680 UserIDs.pop_back();
1681 I.first->second = S;
1682 }
1683 FoldCacheUser[S].push_back(ID);
1684}
1685
// Public zext entry point: consult the fold cache, compute via
// getZeroExtendExprImpl on a miss, and cache the result.
// NOTE(review): the extraction dropped the signature line (original 1687)
// and the condition guarding the cache insertion (1700). Code below is kept
// byte-identical.
1686 const SCEV *
1688 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1689 "This is not an extending conversion!");
1690 assert(isSCEVable(Ty) &&
1691 "This is not a conversion to a SCEVable type!");
1692 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1693 Ty = getEffectiveSCEVType(Ty);
1694
1695 FoldID ID(scZeroExtend, Op, Ty);
1696 if (const SCEV *S = FoldCache.lookup(ID))
1697 return S;
1698
1699 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1701 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1702 return S;
1703}
1704
1706 unsigned Depth) {
1707 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1708 "This is not an extending conversion!");
1709 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1710 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1711
1712 // Fold if the operand is constant.
1713 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1714 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1715
1716 // zext(zext(x)) --> zext(x)
1718 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1719
1720 // Before doing any expensive analysis, check to see if we've already
1721 // computed a SCEV for this Op and Ty.
1723 ID.AddInteger(scZeroExtend);
1724 ID.AddPointer(Op);
1725 ID.AddPointer(Ty);
1726 void *IP = nullptr;
1727 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1728 if (Depth > MaxCastDepth) {
1729 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1730 Op, Ty);
1731 UniqueSCEVs.InsertNode(S, IP);
1732 S->computeAndSetCanonical(*this);
1733 registerUser(S, Op);
1734 return S;
1735 }
1736
1737 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1739 // It's possible the bits taken off by the truncate were all zero bits. If
1740 // so, we should be able to simplify this further.
1741 const SCEV *X = ST->getOperand();
1743 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1744 unsigned NewBits = getTypeSizeInBits(Ty);
1745 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1746 CR.zextOrTrunc(NewBits)))
1747 return getTruncateOrZeroExtend(X, Ty, Depth);
1748 }
1749
1750 // If the input value is a chrec scev, and we can prove that the value
1751 // did not overflow the old, smaller, value, we can zero extend all of the
1752 // operands (often constants). This allows analysis of something like
1753 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1755 if (AR->isAffine()) {
1756 const SCEV *Start = AR->getStart();
1757 const SCEV *Step = AR->getStepRecurrence(*this);
1758 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1759 const Loop *L = AR->getLoop();
1760
1761 // If we have special knowledge that this addrec won't overflow,
1762 // we don't need to do any further analysis.
1763 if (AR->hasNoUnsignedWrap()) {
1764 Start =
1766 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1767 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1768 }
1769
1770 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1771 // Note that this serves two purposes: It filters out loops that are
1772 // simply not analyzable, and it covers the case where this code is
1773 // being called from within backedge-taken count analysis, such that
1774 // attempting to ask for the backedge-taken count would likely result
1775 // in infinite recursion. In the later case, the analysis code will
1776 // cope with a conservative value, and it will take care to purge
1777 // that value once it has finished.
1778 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1779 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1780 // Manually compute the final value for AR, checking for overflow.
1781
1782 // Check whether the backedge-taken count can be losslessly casted to
1783 // the addrec's type. The count is always unsigned.
1784 const SCEV *CastedMaxBECount =
1785 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1786 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1787 CastedMaxBECount, MaxBECount->getType(), Depth);
1788 if (MaxBECount == RecastedMaxBECount) {
1789 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1790 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1791 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1793 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1795 Depth + 1),
1796 WideTy, Depth + 1);
1797 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1798 const SCEV *WideMaxBECount =
1799 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1800 const SCEV *OperandExtendedAdd =
1801 getAddExpr(WideStart,
1802 getMulExpr(WideMaxBECount,
1803 getZeroExtendExpr(Step, WideTy, Depth + 1),
1806 if (ZAdd == OperandExtendedAdd) {
1807 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1808 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1809 // Return the expression with the addrec on the outside.
1810 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1811 Depth + 1);
1812 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1813 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1814 }
1815 // Similar to above, only this time treat the step value as signed.
1816 // This covers loops that count down.
1817 OperandExtendedAdd =
1818 getAddExpr(WideStart,
1819 getMulExpr(WideMaxBECount,
1820 getSignExtendExpr(Step, WideTy, Depth + 1),
1823 if (ZAdd == OperandExtendedAdd) {
1824 // Cache knowledge of AR NW, which is propagated to this AddRec.
1825 // Negative step causes unsigned wrap, but it still can't self-wrap.
1826 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1827 // Return the expression with the addrec on the outside.
1828 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1829 Depth + 1);
1830 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1831 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1832 }
1833 }
1834 }
1835
1836 // Normally, in the cases we can prove no-overflow via a
1837 // backedge guarding condition, we can also compute a backedge
1838 // taken count for the loop. The exceptions are assumptions and
1839 // guards present in the loop -- SCEV is not great at exploiting
1840 // these to compute max backedge taken counts, but can still use
1841 // these to prove lack of overflow. Use this fact to avoid
1842 // doing extra work that may not pay off.
1843 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1844 !AC.assumptions().empty()) {
1845
1846 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1847 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1848 if (AR->hasNoUnsignedWrap()) {
1849 // Same as nuw case above - duplicated here to avoid a compile time
1850 // issue. It's not clear that the order of checks does matter, but
1851 // it's one of two issue possible causes for a change which was
1852 // reverted. Be conservative for the moment.
1853 Start =
1855 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1856 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1857 }
1858
1859 // For a negative step, we can extend the operands iff doing so only
1860 // traverses values in the range zext([0,UINT_MAX]).
1861 if (isKnownNegative(Step)) {
1863 getSignedRangeMin(Step));
1866 // Cache knowledge of AR NW, which is propagated to this
1867 // AddRec. Negative step causes unsigned wrap, but it
1868 // still can't self-wrap.
1869 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1870 // Return the expression with the addrec on the outside.
1871 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1872 Depth + 1);
1873 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1874 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1875 }
1876 }
1877 }
1878
1879 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1880 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1881 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1882 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1883 const APInt &C = SC->getAPInt();
1884 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1885 if (D != 0) {
1886 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1887 const SCEV *SResidual =
1888 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1889 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1890 return getAddExpr(SZExtD, SZExtR,
1892 Depth + 1);
1893 }
1894 }
1895
1896 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1897 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1898 Start =
1900 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1901 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1902 }
1903 }
1904
1905 // zext(A % B) --> zext(A) % zext(B)
1906 {
1907 const SCEV *LHS;
1908 const SCEV *RHS;
1909 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1910 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1911 getZeroExtendExpr(RHS, Ty, Depth + 1));
1912 }
1913
1914 // zext(A / B) --> zext(A) / zext(B).
1915 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1916 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1917 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1918
1919 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1920 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1921 if (SA->hasNoUnsignedWrap()) {
1922 // If the addition does not unsign overflow then we can, by definition,
1923 // commute the zero extension with the addition operation.
1925 for (SCEVUse Op : SA->operands())
1926 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1927 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1928 }
1929
1930 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1931 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1932 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1933 //
1934 // Often address arithmetics contain expressions like
1935 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1936 // This transformation is useful while proving that such expressions are
1937 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1938 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1939 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1940 if (D != 0) {
1941 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1942 const SCEV *SResidual =
1944 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1945 return getAddExpr(SZExtD, SZExtR,
1947 Depth + 1);
1948 }
1949 }
1950 }
1951
1952 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1953 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1954 if (SM->hasNoUnsignedWrap()) {
1955 // If the multiply does not unsign overflow then we can, by definition,
1956 // commute the zero extension with the multiply operation.
1958 for (SCEVUse Op : SM->operands())
1959 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1960 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1961 }
1962
1963 // zext(2^K * (trunc X to iN)) to iM ->
1964 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1965 //
1966 // Proof:
1967 //
1968 // zext(2^K * (trunc X to iN)) to iM
1969 // = zext((trunc X to iN) << K) to iM
1970 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1971 // (because shl removes the top K bits)
1972 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1973 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1974 //
1975 const APInt *C;
1976 const SCEV *TruncRHS;
1977 if (match(SM,
1978 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1979 C->isPowerOf2()) {
1980 int NewTruncBits =
1981 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1982 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1983 return getMulExpr(
1984 getZeroExtendExpr(SM->getOperand(0), Ty),
1985 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1986 SCEV::FlagNUW, Depth + 1);
1987 }
1988 }
1989
1990 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1991 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1994 SmallVector<SCEVUse, 4> Operands;
1995 for (SCEVUse Operand : MinMax->operands())
1996 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1998 return getUMinExpr(Operands);
1999 return getUMaxExpr(Operands);
2000 }
2001
2002 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
2004 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
2005 SmallVector<SCEVUse, 4> Operands;
2006 for (SCEVUse Operand : MinMax->operands())
2007 Operands.push_back(getZeroExtendExpr(Operand, Ty));
2008 return getUMinExpr(Operands, /*Sequential*/ true);
2009 }
2010
2011 // The cast wasn't folded; create an explicit cast node.
2012 // Recompute the insert position, as it may have been invalidated.
2013 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2014 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
2015 Op, Ty);
2016 UniqueSCEVs.InsertNode(S, IP);
2017 S->computeAndSetCanonical(*this);
2018 registerUser(S, Op);
2019 return S;
2020}
2021
2022const SCEV *
2024 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2025 "This is not an extending conversion!");
2026 assert(isSCEVable(Ty) &&
2027 "This is not a conversion to a SCEVable type!");
2028 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2029 Ty = getEffectiveSCEVType(Ty);
2030
2031 FoldID ID(scSignExtend, Op, Ty);
2032 if (const SCEV *S = FoldCache.lookup(ID))
2033 return S;
2034
2035 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2037 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2038 return S;
2039}
2040
2042 unsigned Depth) {
2043 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2044 "This is not an extending conversion!");
2045 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2046 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2047 Ty = getEffectiveSCEVType(Ty);
2048
2049 // Fold if the operand is constant.
2050 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2051 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2052
2053 // sext(sext(x)) --> sext(x)
2055 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2056
2057 // sext(zext(x)) --> zext(x)
2059 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2060
2061 // Before doing any expensive analysis, check to see if we've already
2062 // computed a SCEV for this Op and Ty.
2064 ID.AddInteger(scSignExtend);
2065 ID.AddPointer(Op);
2066 ID.AddPointer(Ty);
2067 void *IP = nullptr;
2068 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2069 // Limit recursion depth.
2070 if (Depth > MaxCastDepth) {
2071 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2072 Op, Ty);
2073 UniqueSCEVs.InsertNode(S, IP);
2074 S->computeAndSetCanonical(*this);
2075 registerUser(S, Op);
2076 return S;
2077 }
2078
2079 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2081 // It's possible the bits taken off by the truncate were all sign bits. If
2082 // so, we should be able to simplify this further.
2083 const SCEV *X = ST->getOperand();
2085 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2086 unsigned NewBits = getTypeSizeInBits(Ty);
2087 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2088 CR.sextOrTrunc(NewBits)))
2089 return getTruncateOrSignExtend(X, Ty, Depth);
2090 }
2091
2092 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2093 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2094 if (SA->hasNoSignedWrap()) {
2095 // If the addition does not sign overflow then we can, by definition,
2096 // commute the sign extension with the addition operation.
2098 for (SCEVUse Op : SA->operands())
2099 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2100 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2101 }
2102
2103 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2104 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2105 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2106 //
2107 // For instance, this will bring two seemingly different expressions:
2108 // 1 + sext(5 + 20 * %x + 24 * %y) and
2109 // sext(6 + 20 * %x + 24 * %y)
2110 // to the same form:
2111 // 2 + sext(4 + 20 * %x + 24 * %y)
2112 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2113 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2114 if (D != 0) {
2115 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2116 const SCEV *SResidual =
2118 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2119 return getAddExpr(SSExtD, SSExtR,
2121 Depth + 1);
2122 }
2123 }
2124 }
2125 // If the input value is a chrec scev, and we can prove that the value
2126 // did not overflow the old, smaller, value, we can sign extend all of the
2127 // operands (often constants). This allows analysis of something like
2128 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2130 if (AR->isAffine()) {
2131 const SCEV *Start = AR->getStart();
2132 const SCEV *Step = AR->getStepRecurrence(*this);
2133 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2134 const Loop *L = AR->getLoop();
2135
2136 // If we have special knowledge that this addrec won't overflow,
2137 // we don't need to do any further analysis.
2138 if (AR->hasNoSignedWrap()) {
2139 Start =
2141 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2142 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2143 }
2144
2145 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2146 // Note that this serves two purposes: It filters out loops that are
2147 // simply not analyzable, and it covers the case where this code is
2148 // being called from within backedge-taken count analysis, such that
2149 // attempting to ask for the backedge-taken count would likely result
2150 // in infinite recursion. In the later case, the analysis code will
2151 // cope with a conservative value, and it will take care to purge
2152 // that value once it has finished.
2153 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2154 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2155 // Manually compute the final value for AR, checking for
2156 // overflow.
2157
2158 // Check whether the backedge-taken count can be losslessly casted to
2159 // the addrec's type. The count is always unsigned.
2160 const SCEV *CastedMaxBECount =
2161 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2162 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2163 CastedMaxBECount, MaxBECount->getType(), Depth);
2164 if (MaxBECount == RecastedMaxBECount) {
2165 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2166 // Check whether Start+Step*MaxBECount has no signed overflow.
2167 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2169 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2171 Depth + 1),
2172 WideTy, Depth + 1);
2173 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2174 const SCEV *WideMaxBECount =
2175 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2176 const SCEV *OperandExtendedAdd =
2177 getAddExpr(WideStart,
2178 getMulExpr(WideMaxBECount,
2179 getSignExtendExpr(Step, WideTy, Depth + 1),
2182 if (SAdd == OperandExtendedAdd) {
2183 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2184 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2185 // Return the expression with the addrec on the outside.
2186 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2187 Depth + 1);
2188 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2189 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2190 }
2191 // Similar to above, only this time treat the step value as unsigned.
2192 // This covers loops that count up with an unsigned step.
2193 OperandExtendedAdd =
2194 getAddExpr(WideStart,
2195 getMulExpr(WideMaxBECount,
2196 getZeroExtendExpr(Step, WideTy, Depth + 1),
2199 if (SAdd == OperandExtendedAdd) {
2200 // If AR wraps around then
2201 //
2202 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2203 // => SAdd != OperandExtendedAdd
2204 //
2205 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2206 // (SAdd == OperandExtendedAdd => AR is NW)
2207
2208 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2209
2210 // Return the expression with the addrec on the outside.
2211 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2212 Depth + 1);
2213 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2214 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2215 }
2216 }
2217 }
2218
2219 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2220 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2221 if (AR->hasNoSignedWrap()) {
2222 // Same as nsw case above - duplicated here to avoid a compile time
2223 // issue. It's not clear that the order of checks does matter, but
2224 // it's one of two issue possible causes for a change which was
2225 // reverted. Be conservative for the moment.
2226 Start =
2228 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2229 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2230 }
2231
2232 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2233 // if D + (C - D + Step * n) could be proven to not signed wrap
2234 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2235 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2236 const APInt &C = SC->getAPInt();
2237 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2238 if (D != 0) {
2239 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2240 const SCEV *SResidual =
2241 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2242 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2243 return getAddExpr(SSExtD, SSExtR,
2245 Depth + 1);
2246 }
2247 }
2248
2249 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2250 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2251 Start =
2253 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2254 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2255 }
2256 }
2257
2258 // If the input value is provably positive and we could not simplify
2259 // away the sext build a zext instead.
2261 return getZeroExtendExpr(Op, Ty, Depth + 1);
2262
2263 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2264 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2267 SmallVector<SCEVUse, 4> Operands;
2268 for (SCEVUse Operand : MinMax->operands())
2269 Operands.push_back(getSignExtendExpr(Operand, Ty));
2271 return getSMinExpr(Operands);
2272 return getSMaxExpr(Operands);
2273 }
2274
2275 // The cast wasn't folded; create an explicit cast node.
2276 // Recompute the insert position, as it may have been invalidated.
2277 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2278 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2279 Op, Ty);
2280 UniqueSCEVs.InsertNode(S, IP);
2281 S->computeAndSetCanonical(*this);
2282 registerUser(S, Op);
2283 return S;
2284}
2285
2287 Type *Ty) {
2288 switch (Kind) {
2289 case scTruncate:
2290 return getTruncateExpr(Op, Ty);
2291 case scZeroExtend:
2292 return getZeroExtendExpr(Op, Ty);
2293 case scSignExtend:
2294 return getSignExtendExpr(Op, Ty);
2295 case scPtrToInt:
2296 return getPtrToIntExpr(Op, Ty);
2297 default:
2298 llvm_unreachable("Not a SCEV cast expression!");
2299 }
2300}
2301
2302/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2303/// unspecified bits out to the given type.
2305 Type *Ty) {
2306 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2307 "This is not an extending conversion!");
2308 assert(isSCEVable(Ty) &&
2309 "This is not a conversion to a SCEVable type!");
2310 Ty = getEffectiveSCEVType(Ty);
2311
2312 // Sign-extend negative constants.
2313 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2314 if (SC->getAPInt().isNegative())
2315 return getSignExtendExpr(Op, Ty);
2316
2317 // Peel off a truncate cast.
2319 const SCEV *NewOp = T->getOperand();
2320 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2321 return getAnyExtendExpr(NewOp, Ty);
2322 return getTruncateOrNoop(NewOp, Ty);
2323 }
2324
2325 // Next try a zext cast. If the cast is folded, use it.
2326 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2327 if (!isa<SCEVZeroExtendExpr>(ZExt))
2328 return ZExt;
2329
2330 // Next try a sext cast. If the cast is folded, use it.
2331 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2332 if (!isa<SCEVSignExtendExpr>(SExt))
2333 return SExt;
2334
2335 // Force the cast to be folded into the operands of an addrec.
2336 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2338 for (const SCEV *Op : AR->operands())
2339 Ops.push_back(getAnyExtendExpr(Op, Ty));
2340 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2341 }
2342
2343 // If the expression is obviously signed, use the sext cast value.
2344 if (isa<SCEVSMaxExpr>(Op))
2345 return SExt;
2346
2347 // Absent any other information, use the zext cast value.
2348 return ZExt;
2349}
2350
2351/// Process the given Ops list, which is a list of operands to be added under
2352/// the given scale, update the given map. This is a helper function for
2353/// getAddRecExpr. As an example of what it does, given a sequence of operands
2354/// that would form an add expression like this:
2355///
2356/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2357///
2358/// where A and B are constants, update the map with these values:
2359///
2360/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2361///
2362/// and add 13 + A*B*29 to AccumulatedConstant.
2363/// This will allow getAddRecExpr to produce this:
2364///
2365/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2366///
2367/// This form often exposes folding opportunities that are hidden in
2368/// the original operand list.
2369///
2370/// Return true iff it appears that any interesting folding opportunities
2371/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2372/// the common case where no interesting opportunities are present, and
2373/// is also used as a check to avoid infinite recursion.
2376 APInt &AccumulatedConstant,
2378 const APInt &Scale,
2379 ScalarEvolution &SE) {
2380 bool Interesting = false;
2381
2382 // Iterate over the add operands. They are sorted, with constants first.
2383 unsigned i = 0;
2384 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2385 ++i;
2386 // Pull a buried constant out to the outside.
2387 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2388 Interesting = true;
2389 AccumulatedConstant += Scale * C->getAPInt();
2390 }
2391
2392 // Next comes everything else. We're especially interested in multiplies
2393 // here, but they're in the middle, so just visit the rest with one loop.
2394 for (; i != Ops.size(); ++i) {
2396 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2397 APInt NewScale =
2398 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2399 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2400 // A multiplication of a constant with another add; recurse.
2401 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2402 Interesting |= CollectAddOperandsWithScales(
2403 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2404 } else {
2405 // A multiplication of a constant with some other value. Update
2406 // the map.
2407 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2408 const SCEV *Key = SE.getMulExpr(MulOps);
2409 auto Pair = M.insert({Key, NewScale});
2410 if (Pair.second) {
2411 NewOps.push_back(Pair.first->first);
2412 } else {
2413 Pair.first->second += NewScale;
2414 // The map already had an entry for this value, which may indicate
2415 // a folding opportunity.
2416 Interesting = true;
2417 }
2418 }
2419 } else {
2420 // An ordinary operand. Update the map.
2421 auto Pair = M.insert({Ops[i], Scale});
2422 if (Pair.second) {
2423 NewOps.push_back(Pair.first->first);
2424 } else {
2425 Pair.first->second += Scale;
2426 // The map already had an entry for this value, which may indicate
2427 // a folding opportunity.
2428 Interesting = true;
2429 }
2430 }
2431 }
2432
2433 return Interesting;
2434}
2435
2437 const SCEV *LHS, const SCEV *RHS,
2438 const Instruction *CtxI) {
2440 unsigned);
2441 switch (BinOp) {
2442 default:
2443 llvm_unreachable("Unsupported binary op");
2444 case Instruction::Add:
2446 break;
2447 case Instruction::Sub:
2449 break;
2450 case Instruction::Mul:
2452 break;
2453 }
2454
2455 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2458
2459 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2460 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2461 auto *WideTy =
2462 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2463
2464 const SCEV *A = (this->*Extension)(
2465 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2466 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2467 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2468 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2469 if (A == B)
2470 return true;
2471 // Can we use context to prove the fact we need?
2472 if (!CtxI)
2473 return false;
2474 // TODO: Support mul.
2475 if (BinOp == Instruction::Mul)
2476 return false;
2477 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2478 // TODO: Lift this limitation.
2479 if (!RHSC)
2480 return false;
2481 APInt C = RHSC->getAPInt();
2482 unsigned NumBits = C.getBitWidth();
2483 bool IsSub = (BinOp == Instruction::Sub);
2484 bool IsNegativeConst = (Signed && C.isNegative());
2485 // Compute the direction and magnitude by which we need to check overflow.
2486 bool OverflowDown = IsSub ^ IsNegativeConst;
2487 APInt Magnitude = C;
2488 if (IsNegativeConst) {
2489 if (C == APInt::getSignedMinValue(NumBits))
2490 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2491 // want to deal with that.
2492 return false;
2493 Magnitude = -C;
2494 }
2495
2497 if (OverflowDown) {
2498 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2499 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2500 : APInt::getMinValue(NumBits);
2501 APInt Limit = Min + Magnitude;
2502 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2503 } else {
2504 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2505 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2506 : APInt::getMaxValue(NumBits);
2507 APInt Limit = Max - Magnitude;
2508 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2509 }
2510}
2511
2512std::optional<SCEV::NoWrapFlags>
2514 const OverflowingBinaryOperator *OBO) {
2515 // It cannot be done any better.
2516 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2517 return std::nullopt;
2518
2520
2521 if (OBO->hasNoUnsignedWrap())
2523 if (OBO->hasNoSignedWrap())
2525
2526 bool Deduced = false;
2527
2528 if (OBO->getOpcode() != Instruction::Add &&
2529 OBO->getOpcode() != Instruction::Sub &&
2530 OBO->getOpcode() != Instruction::Mul)
2531 return std::nullopt;
2532
2533 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2534 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2535
2536 const Instruction *CtxI =
2538 if (!OBO->hasNoUnsignedWrap() &&
2540 /* Signed */ false, LHS, RHS, CtxI)) {
2542 Deduced = true;
2543 }
2544
2545 if (!OBO->hasNoSignedWrap() &&
2547 /* Signed */ true, LHS, RHS, CtxI)) {
2549 Deduced = true;
2550 }
2551
2552 if (Deduced)
2553 return Flags;
2554 return std::nullopt;
2555}
2556
2557// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2558// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2559// can't-overflow flags for the operation if possible.
2563 SCEV::NoWrapFlags Flags) {
2564 using namespace std::placeholders;
2565
2566 using OBO = OverflowingBinaryOperator;
2567
2568 bool CanAnalyze =
2570 (void)CanAnalyze;
2571 assert(CanAnalyze && "don't call from other places!");
2572
2573 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2574 SCEV::NoWrapFlags SignOrUnsignWrap =
2575 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2576
2577 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2578 auto IsKnownNonNegative = [&](SCEVUse U) {
2579 return SE->isKnownNonNegative(U);
2580 };
2581
2582 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2583 Flags =
2584 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2585
2586 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2587
2588 if (SignOrUnsignWrap != SignOrUnsignMask &&
2589 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2590 isa<SCEVConstant>(Ops[0])) {
2591
2592 auto Opcode = [&] {
2593 switch (Type) {
2594 case scAddExpr:
2595 return Instruction::Add;
2596 case scMulExpr:
2597 return Instruction::Mul;
2598 default:
2599 llvm_unreachable("Unexpected SCEV op.");
2600 }
2601 }();
2602
2603 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2604
2605 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2606 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2608 Opcode, C, OBO::NoSignedWrap);
2609 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2611 }
2612
2613 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2614 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2616 Opcode, C, OBO::NoUnsignedWrap);
2617 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2619 }
2620 }
2621
2622 // <0,+,nonnegative><nw> is also nuw
2623 // TODO: Add corresponding nsw case
2625 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2626 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2628
2629 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2631 Ops.size() == 2) {
2632 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2633 if (UDiv->getOperand(1) == Ops[1])
2635 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2636 if (UDiv->getOperand(1) == Ops[0])
2638 }
2639
2640 return Flags;
2641}
2642
2644 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2645}
2646
2647/// Get a canonical add expression, or something simpler if possible.
2649 SCEV::NoWrapFlags OrigFlags,
2650 unsigned Depth) {
2651 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2652 "only nuw or nsw allowed");
2653 assert(!Ops.empty() && "Cannot get empty add!");
2654 if (Ops.size() == 1) return Ops[0];
2655#ifndef NDEBUG
2656 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2657 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2658 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2659 "SCEVAddExpr operand types don't match!");
2660 unsigned NumPtrs = count_if(
2661 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2662 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2663#endif
2664
2665 const SCEV *Folded = constantFoldAndGroupOps(
2666 *this, LI, DT, Ops,
2667 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2668 [](const APInt &C) { return C.isZero(); }, // identity
2669 [](const APInt &C) { return false; }); // absorber
2670 if (Folded)
2671 return Folded;
2672
2673 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2674
2675 // Delay expensive flag strengthening until necessary.
2676 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2677 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2678 };
2679
2680 // Limit recursion calls depth.
2682 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2683
2684 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2685 // Don't strengthen flags if we have no new information.
2686 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2687 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2688 Add->setNoWrapFlags(ComputeFlags(Ops));
2689 return S;
2690 }
2691
2692 // Okay, check to see if the same value occurs in the operand list more than
2693 // once. If so, merge them together into an multiply expression. Since we
2694 // sorted the list, these values are required to be adjacent.
2695 Type *Ty = Ops[0]->getType();
2696 bool FoundMatch = false;
2697 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2698 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2699 // Scan ahead to count how many equal operands there are.
2700 unsigned Count = 2;
2701 while (i+Count != e && Ops[i+Count] == Ops[i])
2702 ++Count;
2703 // Merge the values into a multiply.
2704 SCEVUse Scale = getConstant(Ty, Count);
2705 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2706 if (Ops.size() == Count)
2707 return Mul;
2708 Ops[i] = Mul;
2709 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2710 --i; e -= Count - 1;
2711 FoundMatch = true;
2712 }
2713 if (FoundMatch)
2714 return getAddExpr(Ops, OrigFlags, Depth + 1);
2715
2716 // Check for truncates. If all the operands are truncated from the same
2717 // type, see if factoring out the truncate would permit the result to be
2718 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2719 // if the contents of the resulting outer trunc fold to something simple.
2720 auto FindTruncSrcType = [&]() -> Type * {
2721 // We're ultimately looking to fold an addrec of truncs and muls of only
2722 // constants and truncs, so if we find any other types of SCEV
2723 // as operands of the addrec then we bail and return nullptr here.
2724 // Otherwise, we return the type of the operand of a trunc that we find.
2725 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2726 return T->getOperand()->getType();
2727 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2728 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2729 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2730 return T->getOperand()->getType();
2731 }
2732 return nullptr;
2733 };
2734 if (auto *SrcType = FindTruncSrcType()) {
2735 SmallVector<SCEVUse, 8> LargeOps;
2736 bool Ok = true;
2737 // Check all the operands to see if they can be represented in the
2738 // source type of the truncate.
2739 for (const SCEV *Op : Ops) {
2741 if (T->getOperand()->getType() != SrcType) {
2742 Ok = false;
2743 break;
2744 }
2745 LargeOps.push_back(T->getOperand());
2746 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2747 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2748 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2749 SmallVector<SCEVUse, 8> LargeMulOps;
2750 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2751 if (const SCEVTruncateExpr *T =
2752 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2753 if (T->getOperand()->getType() != SrcType) {
2754 Ok = false;
2755 break;
2756 }
2757 LargeMulOps.push_back(T->getOperand());
2758 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2759 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2760 } else {
2761 Ok = false;
2762 break;
2763 }
2764 }
2765 if (Ok)
2766 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2767 } else {
2768 Ok = false;
2769 break;
2770 }
2771 }
2772 if (Ok) {
2773 // Evaluate the expression in the larger type.
2774 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2775 // If it folds to something simple, use it. Otherwise, don't.
2776 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2777 return getTruncateExpr(Fold, Ty);
2778 }
2779 }
2780
2781 if (Ops.size() == 2) {
2782 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2783 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2784 // C1).
2785 const SCEV *A = Ops[0];
2786 const SCEV *B = Ops[1];
2787 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2788 auto *C = dyn_cast<SCEVConstant>(A);
2789 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2790 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2791 auto C2 = C->getAPInt();
2792 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2793
2794 APInt ConstAdd = C1 + C2;
2795 auto AddFlags = AddExpr->getNoWrapFlags();
2796 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2798 ConstAdd.ule(C1)) {
2799 PreservedFlags =
2801 }
2802
2803 // Adding a constant with the same sign and small magnitude is NSW, if the
2804 // original AddExpr was NSW.
2806 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2807 ConstAdd.abs().ule(C1.abs())) {
2808 PreservedFlags =
2810 }
2811
2812 if (PreservedFlags != SCEV::FlagAnyWrap) {
2813 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2814 NewOps[0] = getConstant(ConstAdd);
2815 return getAddExpr(NewOps, PreservedFlags);
2816 }
2817 }
2818
2819 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2820 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2821 const SCEVAddExpr *InnerAdd;
2822 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2823 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2824 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2825 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2826 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2828 SCEV::FlagNUW)) {
2829 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2830 }
2831 }
2832 }
2833
2834 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2835 const SCEV *Y;
2836 if (Ops.size() == 2 &&
2837 match(Ops[0],
2839 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2840 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2841
2842 // Skip past any other cast SCEVs.
2843 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2844 ++Idx;
2845
2846 // If there are add operands they would be next.
2847 if (Idx < Ops.size()) {
2848 bool DeletedAdd = false;
2849 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2850 // common NUW flag for expression after inlining. Other flags cannot be
2851 // preserved, because they may depend on the original order of operations.
2852 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2853 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2854 if (Ops.size() > AddOpsInlineThreshold ||
2855 Add->getNumOperands() > AddOpsInlineThreshold)
2856 break;
2857 // If we have an add, expand the add operands onto the end of the operands
2858 // list.
2859 Ops.erase(Ops.begin()+Idx);
2860 append_range(Ops, Add->operands());
2861 DeletedAdd = true;
2862 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2863 }
2864
2865 // If we deleted at least one add, we added operands to the end of the list,
2866 // and they are not necessarily sorted. Recurse to resort and resimplify
2867 // any operands we just acquired.
2868 if (DeletedAdd)
2869 return getAddExpr(Ops, CommonFlags, Depth + 1);
2870 }
2871
2872 // Skip over the add expression until we get to a multiply.
2873 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2874 ++Idx;
2875
2876 // Check to see if there are any folding opportunities present with
2877 // operands multiplied by constant values.
2878 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2882 APInt AccumulatedConstant(BitWidth, 0);
2883 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2884 Ops, APInt(BitWidth, 1), *this)) {
2885 struct APIntCompare {
2886 bool operator()(const APInt &LHS, const APInt &RHS) const {
2887 return LHS.ult(RHS);
2888 }
2889 };
2890
2891 // Some interesting folding opportunity is present, so its worthwhile to
2892 // re-generate the operands list. Group the operands by constant scale,
2893 // to avoid multiplying by the same constant scale multiple times.
2894 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2895 for (const SCEV *NewOp : NewOps)
2896 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2897 // Re-generate the operands list.
2898 Ops.clear();
2899 if (AccumulatedConstant != 0)
2900 Ops.push_back(getConstant(AccumulatedConstant));
2901 for (auto &MulOp : MulOpLists) {
2902 if (MulOp.first == 1) {
2903 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2904 } else if (MulOp.first != 0) {
2905 Ops.push_back(getMulExpr(
2906 getConstant(MulOp.first),
2907 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2908 SCEV::FlagAnyWrap, Depth + 1));
2909 }
2910 }
2911 if (Ops.empty())
2912 return getZero(Ty);
2913 if (Ops.size() == 1)
2914 return Ops[0];
2915 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2916 }
2917 }
2918
2919 // If we are adding something to a multiply expression, make sure the
2920 // something is not already an operand of the multiply. If so, merge it into
2921 // the multiply.
2922 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2923 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2924 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2925 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2926 if (isa<SCEVConstant>(MulOpSCEV))
2927 continue;
2928 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2929 if (MulOpSCEV == Ops[AddOp]) {
2930 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2931 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2932 if (Mul->getNumOperands() != 2) {
2933 // If the multiply has more than two operands, we must get the
2934 // Y*Z term.
2935 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2936 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2937 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2938 }
2939 const SCEV *AddOne =
2940 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2941 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2943 if (Ops.size() == 2) return OuterMul;
2944 if (AddOp < Idx) {
2945 Ops.erase(Ops.begin()+AddOp);
2946 Ops.erase(Ops.begin()+Idx-1);
2947 } else {
2948 Ops.erase(Ops.begin()+Idx);
2949 Ops.erase(Ops.begin()+AddOp-1);
2950 }
2951 Ops.push_back(OuterMul);
2952 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2953 }
2954
2955 // Check this multiply against other multiplies being added together.
2956 for (unsigned OtherMulIdx = Idx+1;
2957 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2958 ++OtherMulIdx) {
2959 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2960 // If MulOp occurs in OtherMul, we can fold the two multiplies
2961 // together.
2962 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2963 OMulOp != e; ++OMulOp)
2964 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2965 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2966 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2967 if (Mul->getNumOperands() != 2) {
2968 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2969 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2970 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2973 if (OtherMul->getNumOperands() != 2) {
2975 OtherMul->operands().take_front(OMulOp));
2976 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2977 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2978 }
2979 const SCEV *InnerMulSum =
2980 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2981 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2983 if (Ops.size() == 2) return OuterMul;
2984 Ops.erase(Ops.begin()+Idx);
2985 Ops.erase(Ops.begin()+OtherMulIdx-1);
2986 Ops.push_back(OuterMul);
2987 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2988 }
2989 }
2990 }
2991 }
2992
2993 // If there are any add recurrences in the operands list, see if any other
2994 // added values are loop invariant. If so, we can fold them into the
2995 // recurrence.
2996 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2997 ++Idx;
2998
2999 // Scan over all recurrences, trying to fold loop invariants into them.
3000 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3001 // Scan all of the other operands to this add and add them to the vector if
3002 // they are loop invariant w.r.t. the recurrence.
3004 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3005 const Loop *AddRecLoop = AddRec->getLoop();
3006 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3007 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3008 LIOps.push_back(Ops[i]);
3009 Ops.erase(Ops.begin()+i);
3010 --i; --e;
3011 }
3012
3013 // If we found some loop invariants, fold them into the recurrence.
3014 if (!LIOps.empty()) {
3015 // Compute nowrap flags for the addition of the loop-invariant ops and
3016 // the addrec. Temporarily push it as an operand for that purpose. These
3017 // flags are valid in the scope of the addrec only.
3018 LIOps.push_back(AddRec);
3019 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3020 LIOps.pop_back();
3021
3022 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3023 LIOps.push_back(AddRec->getStart());
3024
3025 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3026
3027 // It is not in general safe to propagate flags valid on an add within
3028 // the addrec scope to one outside it. We must prove that the inner
3029 // scope is guaranteed to execute if the outer one does to be able to
3030 // safely propagate. We know the program is undefined if poison is
3031 // produced on the inner scoped addrec. We also know that *for this use*
3032 // the outer scoped add can't overflow (because of the flags we just
3033 // computed for the inner scoped add) without the program being undefined.
3034 // Proving that entry to the outer scope neccesitates entry to the inner
3035 // scope, thus proves the program undefined if the flags would be violated
3036 // in the outer scope.
3037 SCEV::NoWrapFlags AddFlags = Flags;
3038 if (AddFlags != SCEV::FlagAnyWrap) {
3039 auto *DefI = getDefiningScopeBound(LIOps);
3040 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3041 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3042 AddFlags = SCEV::FlagAnyWrap;
3043 }
3044 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3045
3046 // Build the new addrec. Propagate the NUW and NSW flags if both the
3047 // outer add and the inner addrec are guaranteed to have no overflow.
3048 // Always propagate NW.
3049 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3050 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3051
3052 // If all of the other operands were loop invariant, we are done.
3053 if (Ops.size() == 1) return NewRec;
3054
3055 // Otherwise, add the folded AddRec by the non-invariant parts.
3056 for (unsigned i = 0;; ++i)
3057 if (Ops[i] == AddRec) {
3058 Ops[i] = NewRec;
3059 break;
3060 }
3061 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3062 }
3063
3064 // Okay, if there weren't any loop invariants to be folded, check to see if
3065 // there are multiple AddRec's with the same loop induction variable being
3066 // added together. If so, we can fold them.
3067 for (unsigned OtherIdx = Idx+1;
3068 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3069 ++OtherIdx) {
3070 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3071 // so that the 1st found AddRecExpr is dominated by all others.
3072 assert(DT.dominates(
3073 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3074 AddRec->getLoop()->getHeader()) &&
3075 "AddRecExprs are not sorted in reverse dominance order?");
3076 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3077 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3078 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3079 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3080 ++OtherIdx) {
3081 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3082 if (OtherAddRec->getLoop() == AddRecLoop) {
3083 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3084 i != e; ++i) {
3085 if (i >= AddRecOps.size()) {
3086 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3087 break;
3088 }
3089 AddRecOps[i] =
3090 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3092 }
3093 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3094 }
3095 }
3096 // Step size has changed, so we cannot guarantee no self-wraparound.
3097 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3098 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3099 }
3100 }
3101
3102 // Otherwise couldn't fold anything into this recurrence. Move onto the
3103 // next one.
3104 }
3105
3106 // Okay, it looks like we really DO need an add expr. Check to see if we
3107 // already have one, otherwise create a new one.
3108 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3109}
3110
const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // Uniquing: hash the expression kind plus every operand pointer so that
  // structurally identical add expressions map to a single SCEV object
  // (SCEVs are pointer-comparable for equality; see the file header).
  // NOTE(review): the declaration of the FoldingSetNodeID `ID` appears to
  // have been lost in this excerpt — confirm against the full source.
  ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Cache miss: allocate operand storage and the node from the SCEV
    // allocator, insert it into the uniquing set, then register it so it can
    // be invalidated when its operands are forgotten.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    // NOTE(review): the copy of Ops into O appears to be missing from this
    // excerpt — confirm against the full source.
    S = new (SCEVAllocator)
        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    registerUser(S, Ops);
  }
  // Apply the flags even on a cache hit: callers may have computed stronger
  // no-wrap information for an existing node.
  S->setNoWrapFlags(Flags);
  return S;
}
3132
const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
                                                   const Loop *L,
                                                   SCEV::NoWrapFlags Flags) {
  // Uniquing key: expression kind, all operand pointers, and the loop the
  // recurrence belongs to (the same operands in a different loop are a
  // distinct addrec).
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    // NOTE(review): the copy of Ops into O appears to be missing from this
    // excerpt — confirm against the full source.
    S = new (SCEVAllocator)
        SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Track the node per loop so it can be invalidated when the loop is
    // forgotten, in addition to the usual operand-based user registration.
    LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  // Apply flags even on a cache hit; callers may strengthen them later.
  setNoWrapFlags(S, Flags);
  return S;
}
3157
const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // Uniquing: hash the expression kind plus every operand pointer so that
  // structurally identical mul expressions share one SCEV object.
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Cache miss: allocate operand storage and the node, insert, and
    // register the new node as a user of its operands.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    // NOTE(review): the copy of Ops into O appears to be missing from this
    // excerpt — confirm against the full source.
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    registerUser(S, Ops);
  }
  // Apply flags even on a cache hit; callers may strengthen them later.
  S->setNoWrapFlags(Flags);
  return S;
}
3179
/// Unsigned 64-bit multiply that flags wraparound. Returns the (possibly
/// wrapped) product of \p i and \p j; sets \p Overflow to true when the
/// mathematical product does not fit in 64 bits. Overflow is sticky: it is
/// never cleared by this function.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  const uint64_t Product = i * j;
  // The division round-trips exactly iff the multiplication did not wrap.
  // The j > 1 guard both skips cases where wrapping is impossible and avoids
  // dividing by zero.
  if (j > 1)
    if (Product / j != i)
      Overflow = true;
  return Product;
}
3185
3186/// Compute the result of "n choose k", the binomial coefficient. If an
3187/// intermediate computation overflows, Overflow will be set and the return will
3188/// be garbage. Overflow is not cleared on absence of overflow.
3189static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3190 // We use the multiplicative formula:
3191 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3192 // At each iteration, we take the n-th term of the numeral and divide by the
3193 // (k-n)th term of the denominator. This division will always produce an
3194 // integral result, and helps reduce the chance of overflow in the
3195 // intermediate computations. However, we can still overflow even when the
3196 // final result would fit.
3197
3198 if (n == 0 || n == k) return 1;
3199 if (k > n) return 0;
3200
3201 if (k > n/2)
3202 k = n-k;
3203
3204 uint64_t r = 1;
3205 for (uint64_t i = 1; i <= k; ++i) {
3206 r = umul_ov(r, n-(i-1), Overflow);
3207 r /= i;
3208 }
3209 return r;
3210}
3211
3212/// Determine if any of the operands in this SCEV are a constant or if
3213/// any of the add or multiply expressions in this SCEV contain a constant.
/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
  // Visitor for a SCEV traversal: records whether a SCEVConstant was seen,
  // and only descends through add/mul nodes so the search stays within the
  // add/mul chain rooted at StartExpr.
  struct FindConstantInAddMulChain {
    bool FoundConstant = false;

    bool follow(const SCEV *S) {
      FoundConstant |= isa<SCEVConstant>(S);
      // Recurse only into adds and muls; any other node kind is a leaf for
      // the purposes of this search.
      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
    }

    // Allows the traversal to stop early once a constant has been found.
    bool isDone() const {
      return FoundConstant;
    }
  };

  FindConstantInAddMulChain F;
  // NOTE(review): the traversal object `ST` is used below but its
  // declaration (presumably a SCEVTraversal over F) is missing from this
  // excerpt — confirm against the full source.
  ST.visitAll(StartExpr);
  return F.FoundConstant;
}
3233
/// Get a canonical multiply expression, or something simpler if possible.
/// Performs constant folding, inlines nested muls, distributes constants over
/// adds, and folds loop-invariant factors into add recurrences before falling
/// back to creating a uniqued SCEVMulExpr node.
// NOTE(review): the opening signature line (function name and the Ops
// parameter) is absent from this excerpt, and several interior lines were
// also dropped; the obvious gaps are flagged inline. Confirm all of them
// against the full source.
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  // All operands of a mul must share one (non-pointer) type.
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  // Fold all constant operands into one, drop multiplicative identities (1),
  // and collapse the entire product when a zero operand (the absorber) is
  // present.
  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
      [](const APInt &C) { return C.isOne(); }, // identity
      [](const APInt &C) { return C.isZero(); }); // absorber
  if (Folded)
    return Folded;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion calls depth.
  // NOTE(review): the depth-limit condition guarding this early return is
  // missing from this excerpt.
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      // If any of Add's ops are Adds or Muls with a constant, apply this
      // transformation as well.
      //
      // TODO: There are some cases where this transformation is not
      // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
      // this transformation should be narrowed down.
      const SCEV *Op0, *Op1;
      if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
          // NOTE(review): the remainder of this condition is missing from
          // this excerpt.
        const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
        const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
        return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
      }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          // NOTE(review): the declaration of NewOps is missing from this
          // excerpt.
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
            // NOTE(review): the continuation of this call is missing from
            // this excerpt.
            // Only rebuild as an add if at least one product simplified
            // (i.e. did not remain a mul).
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<SCEVUse, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
                                          SCEV::FlagAnyWrap, Depth + 1));
          // Let M be the minimum representable signed value. AddRec with nsw
          // multiplied by -1 can have signed overflow if and only if it takes a
          // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
          // maximum signed value. In all other cases signed overflow is
          // impossible.
          auto FlagsMask = SCEV::FlagNW;
          if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
            auto MinInt =
                APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
            if (getSignedRangeMin(AddRec) != MinInt)
              FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(FlagsMask));
        }
      }

      // Try to push the constant operand into a ZExt: C * zext (A + B) ->
      // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
      const SCEVAddExpr *InnerAdd;
      if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
        const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
        if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
            getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
            hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
                     // NOTE(review): a continuation line of this call is
                     // missing from this excerpt.
                     SCEV::FlagNUW)) {
          auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
          return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
        };
      }

      // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
      // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
      // of C1, fold to (D /u (C2 /u C1)).
      const SCEV *D;
      APInt C1V = LHSC->getAPInt();
      // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
      // as -1 * 1, as it won't enable additional folds.
      if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
        C1V = C1V.abs();
      const SCEVConstant *C2;
      if (C1V.isPowerOf2() &&
          // NOTE(review): the match() line binding D and C2 is missing from
          // this excerpt.
          C2->getAPInt().isPowerOf2() &&
          C1V.logBase2() <= getMinTrailingZeros(D)) {
        const SCEV *NewMul = nullptr;
        if (C1V.uge(C2->getAPInt())) {
          NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
        } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
          assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
          NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
        }
        if (NewMul)
          // Re-apply the negation stripped from C1 above, if any.
          return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
      }
    }
  }

  // Skip over the add expression until we get to a multiply.
  // (Operands are kept sorted by SCEV type, so all muls are contiguous.)
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have an mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Mul->operands());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    // NOTE(review): the declaration of LIOps is missing from this excerpt.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
      // NOTE(review): the declaration of NewOps is missing from this excerpt.
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);

      // If both the mul and addrec are nuw, we can preserve nuw.
      // If both the mul and addrec are nsw, we can only preserve nsw if either
      // a) they are also nuw, or
      // b) all multiplications of addrec operands with scale are nsw.
      SCEV::NoWrapFlags Flags =
          AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));

      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
                                    SCEV::FlagAnyWrap, Depth + 1));

        if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
          // NOTE(review): the construction of NSWRegion (and part of its
          // argument list) is missing from this excerpt.
              Instruction::Mul, getSignedRange(Scale),
          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
            Flags = clearFlags(Flags, SCEV::FlagNSW);
        }
      }

      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
    //     ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
          dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
        continue;

      // Limit max number of arguments to avoid creation of unreasonably big
      // SCEVAddRecs with very complex operands.
      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
          MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
        continue;

      bool Overflow = false;
      Type *Ty = AddRec->getType();
      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
      SmallVector<SCEVUse, 7> AddRecOps;
      for (int x = 0, xe = AddRec->getNumOperands() +
             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
        // NOTE(review): the declaration of SumOps is missing from this
        // excerpt.
        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
                 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
               z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
            const SCEV *CoeffTerm = getConstant(Ty, Coeff);
            const SCEV *Term1 = AddRec->getOperand(y-z);
            const SCEV *Term2 = OtherAddRec->getOperand(z);
            SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                        SCEV::FlagAnyWrap, Depth + 1));
          }
        }
        if (SumOps.empty())
          SumOps.push_back(getZero(Ty));
        AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
      }
      if (!Overflow) {
        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
        // NOTE(review): the flags argument of this call is missing from this
        // excerpt.
        if (Ops.size() == 2) return NewAddRec;
        Ops[Idx] = NewAddRec;
        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        OpsModified = true;
        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
        if (!AddRec)
          break;
      }
    }
    if (OpsModified)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an mul expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
3537
/// Represents an unsigned remainder expression based on unsigned division.
/// Handles the trivial divisors (1, powers of two) directly and otherwise
/// lowers the remainder to LHS - (LHS /u RHS) * RHS.
// NOTE(review): the signature line (function name and the LHS/RHS
// parameters) is missing from this excerpt — confirm against full source.
  assert(getEffectiveSCEVType(LHS->getType()) ==
             getEffectiveSCEVType(RHS->getType()) &&
         "SCEVURemExpr operand types don't match!");

  // Short-circuit easy cases
  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    // If constant is one, the result is trivial
    if (RHSC->getValue()->isOne())
      return getZero(LHS->getType()); // X urem 1 --> 0

    // If constant is a power of two, fold into a zext(trunc(LHS)).
    // The low log2(RHS) bits of LHS are exactly LHS urem RHS.
    if (RHSC->getAPInt().isPowerOf2()) {
      Type *FullTy = LHS->getType();
      Type *TruncTy =
          IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
    }
  }

  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
  const SCEV *UDiv = getUDivExpr(LHS, RHS);
  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
}
3564
3565/// Get a canonical unsigned division expression, or something simpler if
3566/// possible.
// NOTE(review): this rendered dump elides several original lines (3567 — the
// function signature — and 3573, the FoldingSetNodeID declaration), so some
// names used below (ID, LHS, RHS) are declared on lines not visible here.
3568 assert(!LHS->getType()->isPointerTy() &&
3569 "SCEVUDivExpr operand can't be pointer!");
3570 assert(LHS->getType() == RHS->getType() &&
3571 "SCEVUDivExpr operand types don't match!");
3572
// Build the uniquing key for this udiv and return the cached node if an
// identical expression was created before.
3574 ID.AddInteger(scUDivExpr);
3575 ID.AddPointer(LHS);
3576 ID.AddPointer(RHS);
3577 void *IP = nullptr;
3578 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3579 return S;
3580
3581 // 0 udiv Y == 0
3582 if (match(LHS, m_scev_Zero()))
3583 return LHS;
3584
3585 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3586 if (RHSC->getValue()->isOne())
3587 return LHS; // X udiv 1 --> x
3588 // If the denominator is zero, the result of the udiv is undefined. Don't
3589 // try to analyze it, because the resolution chosen here may differ from
3590 // the resolution chosen in other parts of the compiler.
3591 if (!RHSC->getValue()->isZero()) {
3592 // Determine if the division can be folded into the operands of
3593 // its operands.
3594 // TODO: Generalize this to non-constants by using known-bits information.
3595 Type *Ty = LHS->getType();
3596 unsigned LZ = RHSC->getAPInt().countl_zero();
3597 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3598 // For non-power-of-two values, effectively round the value up to the
3599 // nearest power of two.
3600 if (!RHSC->getAPInt().isPowerOf2())
3601 ++MaxShiftAmt;
// ExtTy is wide enough that the folded arithmetic below cannot wrap; the
// zext-equality checks against it prove each fold is lossless.
3602 IntegerType *ExtTy =
3603 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3604 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3605 if (const SCEVConstant *Step =
3606 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3607 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3608 const APInt &StepInt = Step->getAPInt();
3609 const APInt &DivInt = RHSC->getAPInt();
3610 if (!StepInt.urem(DivInt) &&
3611 getZeroExtendExpr(AR, ExtTy) ==
3612 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3613 getZeroExtendExpr(Step, ExtTy),
3614 AR->getLoop(), SCEV::FlagAnyWrap)) {
3615 SmallVector<SCEVUse, 4> Operands;
3616 for (const SCEV *Op : AR->operands())
3617 Operands.push_back(getUDivExpr(Op, RHS));
3618 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3619 }
3620 /// Get a canonical UDivExpr for a recurrence.
3621 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3622 const APInt *StartRem;
3623 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3624 m_scev_APInt(StartRem))) {
// NOTE(review): original line 3629 (the flags argument of this
// getAddRecExpr call) is elided in this dump.
3625 bool NoWrap =
3626 getZeroExtendExpr(AR, ExtTy) ==
3627 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3628 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3630
3631 // With N <= C and both N, C as powers-of-2, the transformation
3632 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3633 // if wrapping occurs, as the division results remain equivalent for
3634 // all offsets in [[(X - X%N), X).
3635 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3636 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3637 // Only fold if the subtraction can be folded in the start
3638 // expression.
3639 const SCEV *NewStart =
3640 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3641 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3642 !isa<SCEVAddExpr>(NewStart)) {
3643 const SCEV *NewLHS =
3644 getAddRecExpr(NewStart, Step, AR->getLoop(),
3645 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3646 if (LHS != NewLHS) {
3647 LHS = NewLHS;
3648
3649 // Reset the ID to include the new LHS, and check if it is
3650 // already cached.
3651 ID.clear();
3652 ID.AddInteger(scUDivExpr);
3653 ID.AddPointer(LHS);
3654 ID.AddPointer(RHS);
3655 IP = nullptr;
3656 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3657 return S;
3658 }
3659 }
3660 }
3661 }
3662 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3663 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3664 SmallVector<SCEVUse, 4> Operands;
3665 for (const SCEV *Op : M->operands())
3666 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3667 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3668 // Find an operand that's safely divisible.
3669 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3670 const SCEV *Op = M->getOperand(i);
3671 const SCEV *Div = getUDivExpr(Op, RHSC);
// A clean division round-trips: (Op/C)*C == Op and the quotient is
// not left as an opaque udiv node.
3672 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3673 Operands = SmallVector<SCEVUse, 4>(M->operands());
3674 Operands[i] = Div;
3675 return getMulExpr(Operands);
3676 }
3677 }
3678 }
3679
3680 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3681 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3682 if (auto *DivisorConstant =
3683 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3684 bool Overflow = false;
3685 APInt NewRHS =
3686 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
// If B*C overflows the bit width, the combined divisor exceeds any
// possible dividend, so the quotient is zero.
3687 if (Overflow) {
3688 return getConstant(RHSC->getType(), 0, false);
3689 }
3690 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3691 }
3692 }
3693
3694 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3695 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3696 SmallVector<SCEVUse, 4> Operands;
3697 for (const SCEV *Op : A->operands())
3698 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3699 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3700 Operands.clear();
// Every addend must divide cleanly; bail out of the loop on the
// first one that doesn't.
3701 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3702 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3703 if (isa<SCEVUDivExpr>(Op) ||
3704 getMulExpr(Op, RHS) != A->getOperand(i))
3705 break;
3706 Operands.push_back(Op);
3707 }
3708 if (Operands.size() == A->getNumOperands())
3709 return getAddExpr(Operands);
3710 }
3711 }
3712
3713 // Fold if both operands are constant.
3714 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3715 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3716 }
3717 }
3718
3719 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3720 const APInt *NegC, *C;
// NOTE(review): original lines 3722-3723 (the match pattern for the
// (-C + (C smax %x)) form) are elided in this dump.
3721 if (match(LHS,
3724 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3725 return getZero(LHS->getType());
3726
3727 // TODO: Generalize to handle any common factors.
3728 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3729 const SCEV *NewLHS, *NewRHS;
3730 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3731 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3732 return getUDivExpr(NewLHS, NewRHS);
3733
3734 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3735 // changes). Make sure we get a new one.
3736 IP = nullptr;
3737 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3738 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3739 LHS, RHS);
3740 UniqueSCEVs.InsertNode(S, IP);
3741 S->computeAndSetCanonical(*this);
3742 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3743 return S;
3744}
3745
3746APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3747 APInt A = C1->getAPInt().abs();
3748 APInt B = C2->getAPInt().abs();
3749 uint32_t ABW = A.getBitWidth();
3750 uint32_t BBW = B.getBitWidth();
3751
3752 if (ABW > BBW)
3753 B = B.zext(ABW);
3754 else if (ABW < BBW)
3755 A = A.zext(BBW);
3756
3757 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3758}
3759
3760/// Get a canonical unsigned division expression, or something simpler if
3761/// possible. There is no representation for an exact udiv in SCEV IR, but we
3762/// can attempt to remove factors from the LHS and RHS. We can't do this when
3763/// it's not exact because the udiv may be clearing bits.
// NOTE(review): this dump elides original lines 3764 (the function
// signature) and 3769/3796 (the dyn_cast initializations of 'Mul').
3765 // TODO: we could try to find factors in all sorts of things, but for now we
3766 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3767 // end of this file for inspiration.
3768
// If the dividend is not a no-unsigned-wrap multiply, no factor can be
// cancelled; fall back to the generic udiv construction.
3770 if (!Mul || !Mul->hasNoUnsignedWrap())
3771 return getUDivExpr(LHS, RHS);
3772
3773 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3774 // If the mulexpr multiplies by a constant, then that constant must be the
3775 // first element of the mulexpr.
3776 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
// Identical constant factors cancel outright; the quotient is the
// product of the remaining operands.
3777 if (LHSCst == RHSCst) {
3778 SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
3779 return getMulExpr(Operands);
3780 }
3781
3782 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3783 // that there's a factor provided by one of the other terms. We need to
3784 // check.
3785 APInt Factor = gcd(LHSCst, RHSCst);
// isIntN(1) means the gcd fits in one bit (0 or 1), i.e. there is no
// common factor worth cancelling.
3786 if (!Factor.isIntN(1)) {
3787 LHSCst =
3788 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3789 RHSCst =
3790 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3791 SmallVector<SCEVUse, 2> Operands;
3792 Operands.push_back(LHSCst);
3793 append_range(Operands, Mul->operands().drop_front());
3794 LHS = getMulExpr(Operands);
3795 RHS = RHSCst;
// If the reduced LHS no longer folds to a mul, restart the whole
// simplification on the reduced operands (see elided line 3796).
3797 if (!Mul)
3798 return getUDivExactExpr(LHS, RHS);
3799 }
3800 }
3801 }
3802
// Cancel a non-constant factor equal to RHS, if one appears among the
// multiply's operands.
3803 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3804 if (Mul->getOperand(i) == RHS) {
3805 SmallVector<SCEVUse, 2> Operands;
3806 append_range(Operands, Mul->operands().take_front(i));
3807 append_range(Operands, Mul->operands().drop_front(i + 1));
3808 return getMulExpr(Operands);
3809 }
3810 }
3811
3812 return getUDivExpr(LHS, RHS);
3813}
3814
3815/// Get an add recurrence expression for the specified loop. Simplify the
3816/// expression as much as possible.
// NOTE(review): original line 3817 (the start of the signature taking Start
// and Step) is elided in this dump.
3818 const Loop *L,
3819 SCEV::NoWrapFlags Flags) {
3820 SmallVector<SCEVUse, 4> Operands;
3821 Operands.push_back(Start);
// If the step is itself an addrec over the same loop, flatten it into a
// single higher-order recurrence; only the no-self-wrap bit of Flags is
// kept across this combination.
3822 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3823 if (StepChrec->getLoop() == L) {
3824 append_range(Operands, StepChrec->operands());
3825 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3826 }
3827
3828 Operands.push_back(Step);
3829 return getAddRecExpr(Operands, L, Flags);
3830}
3831
3832/// Get an add recurrence expression for the specified loop. Simplify the
3833/// expression as much as possible.
// NOTE(review): original lines 3834 (the signature taking the operand list)
// and 3846 (the loop-availability assertion condition) are elided here.
3837 if (Operands.size() == 1) return Operands[0];
3838#ifndef NDEBUG
3839 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3840 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3841 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3842 "SCEVAddRecExpr operand types don't match!");
3843 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3844 }
3845 for (const SCEV *Op : Operands)
3847 "SCEVAddRecExpr operand is not available at loop entry!");
3848#endif
3849
// A trailing zero step contributes nothing; drop it and retry.
3850 if (Operands.back()->isZero()) {
3851 Operands.pop_back();
3852 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3853 }
3854
3855 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3856 // use that information to infer NUW and NSW flags. However, computing a
3857 // BE count requires calling getAddRecExpr, so we may not yet have a
3858 // meaningful BE count at this point (and if we don't, we'd be stuck
3859 // with a SCEVCouldNotCompute as the cached BE count).
3860
3861 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3862
3863 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3864 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3865 const Loop *NestedLoop = NestedAR->getLoop();
// Swap the nesting when L encloses the nested loop at a shallower depth,
// or when the loops are unrelated but L's header dominates the nested
// loop's header.
3866 if (L->contains(NestedLoop)
3867 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3868 : (!NestedLoop->contains(L) &&
3869 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3870 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3871 Operands[0] = NestedAR->getStart();
3872 // AddRecs require their operands be loop-invariant with respect to their
3873 // loops. Don't perform this transformation if it would break this
3874 // requirement.
3875 bool AllInvariant = all_of(
3876 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3877
3878 if (AllInvariant) {
3879 // Create a recurrence for the outer loop with the same step size.
3880 //
3881 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3882 // inner recurrence has the same property.
3883 SCEV::NoWrapFlags OuterFlags =
3884 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3885
3886 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3887 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3888 return isLoopInvariant(Op, NestedLoop);
3889 });
3890
3891 if (AllInvariant) {
3892 // Ok, both add recurrences are valid after the transformation.
3893 //
3894 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3895 // the outer recurrence has the same property.
3896 SCEV::NoWrapFlags InnerFlags =
3897 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3898 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3899 }
3900 }
3901 // Reset Operands to its original state.
3902 Operands[0] = NestedAR;
3903 }
3904 }
3905
3906 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3907 // already have one, otherwise create a new one.
3908 return getOrCreateAddRecExpr(Operands, L, Flags);
3909}
3910
// Translate an IR GEP into a SCEV over its base pointer and index operands,
// propagating the GEP's no-wrap flags only when provably poison-free.
// NOTE(review): original line 3911 (the start of the signature) is elided
// in this dump.
3912 ArrayRef<SCEVUse> IndexExprs) {
3913 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3914 // getSCEV(Base)->getType() has the same address space as Base->getType()
3915 // because SCEV::getType() preserves the address space.
3916 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3917 if (NW != GEPNoWrapFlags::none()) {
3918 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3919 // but to do that, we have to ensure that said flag is valid in the entire
3920 // defined scope of the SCEV.
3921 // TODO: non-instructions have global scope. We might be able to prove
3922 // some global scope cases
3923 auto *GEPI = dyn_cast<Instruction>(GEP);
3924 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3925 NW = GEPNoWrapFlags::none();
3926 }
3927
3928 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3929}
3930
// Build the SCEV for a GEP from its base expression, index expressions,
// source element type and no-wrap flags: sum of per-index byte offsets,
// added to the base.
// NOTE(review): several original lines are elided in this dump: 3931
// (signature start), 3934 (OffsetWrap initialization), 3943 (the Offsets
// vector declaration), 3963 (the array element-type update), and
// 3986-3987 (the BaseWrap computation).
3932 ArrayRef<SCEVUse> IndexExprs,
3933 Type *SrcElementTy, GEPNoWrapFlags NW) {
// nusw on the GEP lets us mark the offset arithmetic nsw; nuw carries over.
3935 if (NW.hasNoUnsignedSignedWrap())
3936 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3937 if (NW.hasNoUnsignedWrap())
3938 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3939
3940 Type *CurTy = BaseExpr->getType();
3941 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3942 bool FirstIter = true;
3944 for (SCEVUse IndexExpr : IndexExprs) {
3945 // Compute the (potentially symbolic) offset in bytes for this index.
3946 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3947 // For a struct, add the member offset.
3948 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3949 unsigned FieldNo = Index->getZExtValue();
3950 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3951 Offsets.push_back(FieldOffset);
3952
3953 // Update CurTy to the type of the field at Index.
3954 CurTy = STy->getTypeAtIndex(Index);
3955 } else {
3956 // Update CurTy to its element type.
3957 if (FirstIter) {
3958 assert(isa<PointerType>(CurTy) &&
3959 "The first index of a GEP indexes a pointer");
3960 CurTy = SrcElementTy;
3961 FirstIter = false;
3962 } else {
3964 }
3965 // For an array, add the element offset, explicitly scaled.
3966 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3967 // Getelementptr indices are signed.
3968 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3969
3970 // Multiply the index by the element size to compute the element offset.
3971 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3972 Offsets.push_back(LocalOffset);
3973 }
3974 }
3975
3976 // Handle degenerate case of GEP without offsets.
3977 if (Offsets.empty())
3978 return BaseExpr;
3979
3980 // Add the offsets together, assuming nsw if inbounds.
3981 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3982 // Add the base address and the offset. We cannot use the nsw flag, as the
3983 // base address is unsigned. However, if we know that the offset is
3984 // non-negative, we can use nuw.
3985 bool NUW = NW.hasNoUnsignedWrap() ||
3988 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3989 assert(BaseExpr->getType() == GEPExpr->getType() &&
3990 "GEP should not change type mid-flight.");
3991 return GEPExpr;
3992}
3993
3994SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
// Look up a previously-uniqued SCEV with this kind and operand list;
// returns null if none is cached yet. NOTE(review): original lines
// 3995-3996 (parameter continuation and the FoldingSetNodeID declaration)
// are elided in this dump.
3997 ID.AddInteger(SCEVType);
3998 for (const SCEV *Op : Ops)
3999 ID.AddPointer(Op);
4000 void *IP = nullptr;
4001 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4002}
4003
4004SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
// Overload of the cache lookup above, presumably differing only in the
// operand container type — the distinguishing parameter is on elided lines
// 4005-4006 (along with the FoldingSetNodeID declaration); verify against
// the header.
4007 ID.AddInteger(SCEVType);
4008 for (const SCEV *Op : Ops)
4009 ID.AddPointer(Op);
4010 void *IP = nullptr;
4011 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4012}
4013
4014const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
// abs(Op) is built as smax(Op, -Op). NOTE(review): original line 4015,
// which derives 'Flags' for the negation from IsNSW, is elided here.
4016 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
4017}
4018
// Build a canonical (non-sequential) min/max expression of the given kind:
// fold constants, flatten nested same-kind operands, and drop redundant or
// provably-ordered operands.
// NOTE(review): this dump elides original lines 4019-4020 (the signature),
// 4099-4102 (the GEPred/LEPred definitions used below), 4128 (the
// FoldingSetNodeID declaration) and 4137 (the copy of Ops into O).
4021 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4022 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4023 if (Ops.size() == 1) return Ops[0];
4024#ifndef NDEBUG
4025 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4026 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4027 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4028 "Operand types don't match!");
4029 assert(Ops[0]->getType()->isPointerTy() ==
4030 Ops[i]->getType()->isPointerTy() &&
4031 "min/max should be consistently pointerish");
4032 }
4033#endif
4034
4035 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4036 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4037
// Fold constant operands pairwise, and recognize identity constants (which
// can be dropped) and absorbing constants (which decide the result).
4038 const SCEV *Folded = constantFoldAndGroupOps(
4039 *this, LI, DT, Ops,
4040 [&](const APInt &C1, const APInt &C2) {
4041 switch (Kind) {
4042 case scSMaxExpr:
4043 return APIntOps::smax(C1, C2);
4044 case scSMinExpr:
4045 return APIntOps::smin(C1, C2);
4046 case scUMaxExpr:
4047 return APIntOps::umax(C1, C2);
4048 case scUMinExpr:
4049 return APIntOps::umin(C1, C2);
4050 default:
4051 llvm_unreachable("Unknown SCEV min/max opcode");
4052 }
4053 },
4054 [&](const APInt &C) {
4055 // identity
4056 if (IsMax)
4057 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4058 else
4059 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4060 },
4061 [&](const APInt &C) {
4062 // absorber
4063 if (IsMax)
4064 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4065 else
4066 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4067 });
4068 if (Folded)
4069 return Folded;
4070
4071 // Check if we have created the same expression before.
4072 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4073 return S;
4074 }
4075
4076 // Find the first operation of the same kind
4077 unsigned Idx = 0;
4078 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4079 ++Idx;
4080
4081 // Check to see if one of the operands is of the same kind. If so, expand its
4082 // operands onto our operand list, and recurse to simplify.
4083 if (Idx < Ops.size()) {
4084 bool DeletedAny = false;
4085 while (Ops[Idx]->getSCEVType() == Kind) {
4086 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4087 Ops.erase(Ops.begin()+Idx);
4088 append_range(Ops, SMME->operands());
4089 DeletedAny = true;
4090 }
4091
4092 if (DeletedAny)
4093 return getMinMaxExpr(Kind, Ops);
4094 }
4095
4096 // Okay, check to see if the same value occurs in the operand list twice. If
4097 // so, delete one. Since we sorted the list, these values are required to
4098 // be adjacent.
// NOTE(review): GEPred/LEPred come from the elided lines 4099-4102; from
// their use, FirstPred keeps Ops[i] when it already dominates Ops[i+1]
// under the max/min ordering, SecondPred the reverse — confirm against the
// full source.
4103 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4104 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4105 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4106 if (Ops[i] == Ops[i + 1] ||
4107 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4108 // X op Y op Y --> X op Y
4109 // X op Y --> X, if we know X, Y are ordered appropriately
4110 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4111 --i;
4112 --e;
4113 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4114 Ops[i + 1])) {
4115 // X op Y --> Y, if we know X, Y are ordered appropriately
4116 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4117 --i;
4118 --e;
4119 }
4120 }
4121
4122 if (Ops.size() == 1) return Ops[0];
4123
4124 assert(!Ops.empty() && "Reduced smax down to nothing!");
4125
4126 // Okay, it looks like we really DO need an expr. Check to see if we
4127 // already have one, otherwise create a new one.
4129 ID.AddInteger(Kind);
4130 for (const SCEV *Op : Ops)
4131 ID.AddPointer(Op);
4132 void *IP = nullptr;
4133 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4134 if (ExistingSCEV)
4135 return ExistingSCEV;
4136 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4138 SCEV *S = new (SCEVAllocator)
4139 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4140
4141 UniqueSCEVs.InsertNode(S, IP);
4142 S->computeAndSetCanonical(*this);
4143 registerUser(S, Ops);
4144 return S;
4145}
4146
4147namespace {
4148
// Visitor that rewrites a sequential min/max expression tree so each operand
// is kept only on its first occurrence; a visit returning std::nullopt means
// "this operand was eliminated entirely".
// NOTE(review): this dump elides original lines 4153/4158 (the Base alias
// and the SeenOps set declaration), 4167 (an assertion), 4183 (part of the
// rebuild condition) and 4206 (the local Ops vector declaration).
4149class SCEVSequentialMinMaxDeduplicatingVisitor final
4150 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4151 std::optional<const SCEV *>> {
4152 using RetVal = std::optional<const SCEV *>;
4154
4155 ScalarEvolution &SE;
4156 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4157 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4159
4160 bool canRecurseInto(SCEVTypes Kind) const {
4161 // We can only recurse into the SCEV expression of the same effective type
4162 // as the type of our root SCEV expression.
4163 return RootKind == Kind || NonSequentialRootKind == Kind;
4164 };
4165
// Deduplicate within a min/max node; non-matching kinds are left untouched.
4166 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4168 "Only for min/max expressions.");
4169 SCEVTypes Kind = S->getSCEVType();
4170
4171 if (!canRecurseInto(Kind))
4172 return S;
4173
4174 auto *NAry = cast<SCEVNAryExpr>(S);
4175 SmallVector<SCEVUse> NewOps;
4176 bool Changed = visit(Kind, NAry->operands(), NewOps);
4177
4178 if (!Changed)
4179 return S;
// All operands were duplicates of earlier ones: the node vanishes.
4180 if (NewOps.empty())
4181 return std::nullopt;
4182
4184 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4185 : SE.getMinMaxExpr(Kind, NewOps);
4186 }
4187
4188 RetVal visit(const SCEV *S) {
4189 // Has the whole operand been seen already?
4190 if (!SeenOps.insert(S).second)
4191 return std::nullopt;
4192 return Base::visit(S);
4193 }
4194
4195public:
4196 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4197 SCEVTypes RootKind)
4198 : SE(SE), RootKind(RootKind),
4199 NonSequentialRootKind(
4200 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4201 RootKind)) {}
4202
// Entry point: visit every operand in order; returns true (and fills
// NewOps) iff anything was removed or rewritten.
4203 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4204 SmallVectorImpl<SCEVUse> &NewOps) {
4205 bool Changed = false;
4207 Ops.reserve(OrigOps.size());
4208
4209 for (const SCEV *Op : OrigOps) {
4210 RetVal NewOp = visit(Op);
4211 if (NewOp != Op)
4212 Changed = true;
4213 if (NewOp)
4214 Ops.emplace_back(*NewOp);
4215 }
4216
4217 if (Changed)
4218 NewOps = std::move(Ops);
4219 return Changed;
4220 }
4221
// Leaf and non-min/max nodes are returned unchanged; deduplication only
// descends through min/max structure.
4222 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4223
4224 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4225
4226 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4227
4228 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4229
4230 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4231
4232 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4233
4234 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4235
4236 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4237
4238 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4239
4240 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4241
4242 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4243
4244 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4245 return visitAnyMinMaxExpr(Expr);
4246 }
4247
4248 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4249 return visitAnyMinMaxExpr(Expr);
4250 }
4251
4252 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4253 return visitAnyMinMaxExpr(Expr);
4254 }
4255
4256 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4257 return visitAnyMinMaxExpr(Expr);
4258 }
4259
4260 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4261 return visitAnyMinMaxExpr(Expr);
4262 }
4263
4264 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4265
4266 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4267};
4268
4269} // namespace
4270
// Returns true when a SCEV node of this kind is poison whenever any of its
// operands is poison. NOTE(review): the signature line (original 4271) and
// the case label before the FIXME (original 4291, presumably the
// sequential-umin kind) are elided in this dump — confirm against the full
// source.
4272 switch (Kind) {
4273 case scConstant:
4274 case scVScale:
4275 case scTruncate:
4276 case scZeroExtend:
4277 case scSignExtend:
4278 case scPtrToAddr:
4279 case scPtrToInt:
4280 case scAddExpr:
4281 case scMulExpr:
4282 case scUDivExpr:
4283 case scAddRecExpr:
4284 case scUMaxExpr:
4285 case scSMaxExpr:
4286 case scUMinExpr:
4287 case scSMinExpr:
4288 case scUnknown:
4289 // If any operand is poison, the whole expression is poison.
4290 return true;
4292 // FIXME: if the *first* operand is poison, the whole expression is poison.
4293 return false; // Pessimistically, say that it does not propagate poison.
4294 case scCouldNotCompute:
4295 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4296 }
4297 llvm_unreachable("Unknown SCEV kind!");
4298}
4299
4300namespace {
4301// The only way poison may be introduced in a SCEV expression is from a
4302// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4303// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4304// introduce poison -- they encode guaranteed, non-speculated knowledge.
4305//
4306// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4307// with the notable exception of umin_seq, where only poison from the first
4308// operand is (unconditionally) propagated.
4309struct SCEVPoisonCollector {
// When false, do not descend into nodes that may block poison propagation.
4310 bool LookThroughMaybePoisonBlocking;
// SCEVUnknowns seen during the walk whose underlying IR value is not
// guaranteed to be non-poison.
4311 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4312 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4313 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4314
4315 bool follow(const SCEV *S) {
// NOTE(review): original line 4317 — the remainder of this condition,
// presumably a call to scevUnconditionallyPropagatesPoisonFromOperands —
// is elided in this dump.
4316 if (!LookThroughMaybePoisonBlocking &&
4318 return false;
4319
4320 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4321 if (!isGuaranteedNotToBePoison(SU->getValue()))
4322 MaybePoison.insert(SU);
4323 }
4324 return true;
4325 }
// Never terminate the traversal early; visit the entire expression.
4326 bool isDone() const { return false; }
4327};
4328} // namespace
4329
4330/// Return true if V is poison given that AssumedPoison is already poison.
4331static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4332 // First collect all SCEVs that might result in AssumedPoison to be poison.
4333 // We need to look through potentially poison-blocking operations here,
4334 // because we want to find all SCEVs that *might* result in poison, not only
4335 // those that are *required* to.
4336 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4337 visitAll(AssumedPoison, PC1);
4338
4339 // AssumedPoison is never poison. As the assumption is false, the implication
4340 // is true. Don't bother walking the other SCEV in this case.
4341 if (PC1.MaybePoison.empty())
4342 return true;
4343
4344 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4345 // as well. We cannot look through potentially poison-blocking operations
4346 // here, as their arguments only *may* make the result poison.
4347 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4348 visitAll(S, PC2);
4349
4350 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4351 // it will also make S poison by being part of PC2.MaybePoison.
4352 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4353}
4354
// Collect into Result the IR values whose being poison would make S poison.
// NOTE(review): the signature line (original 4355) is elided in this dump.
4356 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4357 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4358 visitAll(S, PC);
4359 for (const SCEVUnknown *SU : PC.MaybePoison)
4360 Result.insert(SU->getValue());
4361}
4362
// Decide whether instruction I can be reused to materialize SCEV S without
// introducing extra poison; instructions whose poison-generating flags or
// metadata would need removing are appended to DropPoisonGeneratingInsts.
// NOTE(review): this dump elides original lines 4363 (signature), 4367 (the
// not-poison check condition), 4374 (the PoisonVals declaration) and 4378
// (the Visited set declaration).
4364 const SCEV *S, Instruction *I,
4365 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4366 // If the instruction cannot be poison, it's always safe to reuse.
4368 return true;
4369
4370 // Otherwise, it is possible that I is more poisonous that S. Collect the
4371 // poison-contributors of S, and then check whether I has any additional
4372 // poison-contributors. Poison that is contributed through poison-generating
4373 // flags is handled by dropping those flags instead.
4375 getPoisonGeneratingValues(PoisonVals, S);
4376
// Worklist traversal over I's operand graph, with a visited set to avoid
// revisiting shared operands.
4377 SmallVector<Value *> Worklist;
4379 Worklist.push_back(I);
4380 while (!Worklist.empty()) {
4381 Value *V = Worklist.pop_back_val();
4382 if (!Visited.insert(V).second)
4383 continue;
4384
4385 // Avoid walking large instruction graphs.
4386 if (Visited.size() > 16)
4387 return false;
4388
4389 // Either the value can't be poison, or the S would also be poison if it
4390 // is.
4391 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4392 continue;
4393
// Non-instruction values (arguments, globals) can't have their flags
// dropped, so any remaining poison potential is disqualifying.
4394 auto *I = dyn_cast<Instruction>(V);
4395 if (!I)
4396 return false;
4397
4398 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4399 // can't replace an arbitrary add with disjoint or, even if we drop the
4400 // flag. We would need to convert the or into an add.
4401 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4402 if (PDI->isDisjoint())
4403 return false;
4404
4405 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4406 // because SCEV currently assumes it can't be poison. Remove this special
4407 // case once we proper model when vscale can be poison.
4408 if (auto *II = dyn_cast<IntrinsicInst>(I);
4409 II && II->getIntrinsicID() == Intrinsic::vscale)
4410 continue;
4411
4412 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4413 return false;
4414
4415 // If the instruction can't create poison, we can recurse to its operands.
4416 if (I->hasPoisonGeneratingAnnotations())
4417 DropPoisonGeneratingInsts.push_back(I);
4418
4419 llvm::append_range(Worklist, I->operands());
4420 }
4421 return true;
4422}
4423
4424const SCEV *
// Build a canonical sequential min/max expression. Unlike getMinMaxExpr,
// operand order is semantically significant here (evaluation short-circuits
// left to right), so operands are never sorted.
// NOTE(review): this dump elides original lines 4425-4426 (rest of the
// signature), 4482/4484 (the Pred declaration and the case label in the
// switch), 4503 (the min/max kind selection for the fold below), 4518 (the
// FoldingSetNodeID declaration) and 4528 (the copy of Ops into O).
4427 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4428 "Not a SCEVSequentialMinMaxExpr!");
4429 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4430 if (Ops.size() == 1)
4431 return Ops[0];
4432#ifndef NDEBUG
4433 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4434 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4435 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4436 "Operand types don't match!");
4437 assert(Ops[0]->getType()->isPointerTy() ==
4438 Ops[i]->getType()->isPointerTy() &&
4439 "min/max should be consistently pointerish");
4440 }
4441#endif
4442
4443 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4444 // so we can *NOT* do any kind of sorting of the expressions!
4445
4446 // Check if we have created the same expression before.
4447 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4448 return S;
4449
4450 // FIXME: there are *some* simplifications that we can do here.
4451
4452 // Keep only the first instance of an operand.
4453 {
4454 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4455 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4456 if (Changed)
4457 return getSequentialMinMaxExpr(Kind, Ops);
4458 }
4459
4460 // Check to see if one of the operands is of the same kind. If so, expand its
4461 // operands onto our operand list, and recurse to simplify.
4462 {
4463 unsigned Idx = 0;
4464 bool DeletedAny = false;
4465 while (Idx < Ops.size()) {
4466 if (Ops[Idx]->getSCEVType() != Kind) {
4467 ++Idx;
4468 continue;
4469 }
// Splice the nested node's operands in place, preserving order.
4470 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4471 Ops.erase(Ops.begin() + Idx);
4472 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4473 SMME->operands().end());
4474 DeletedAny = true;
4475 }
4476
4477 if (DeletedAny)
4478 return getSequentialMinMaxExpr(Kind, Ops);
4479 }
4480
// For umin_seq the saturation point is zero: once an operand hits it, later
// operands can no longer change the result.
4481 const SCEV *SaturationPoint;
4483 switch (Kind) {
4485 SaturationPoint = getZero(Ops[0]->getType());
4486 Pred = ICmpInst::ICMP_ULE;
4487 break;
4488 default:
4489 llvm_unreachable("Not a sequential min/max type.");
4490 }
4491
4492 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
// Only operands that cannot trigger immediate UB may be evaluated
// non-sequentially.
4493 if (!isGuaranteedNotToCauseUB(Ops[i]))
4494 continue;
4495 // We can replace %x umin_seq %y with %x umin %y if either:
4496 // * %y being poison implies %x is also poison.
4497 // * %x cannot be the saturating value (e.g. zero for umin).
4498 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4499 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4500 SaturationPoint)) {
4501 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4502 Ops[i - 1] = getMinMaxExpr(
4504 SeqOps);
4505 Ops.erase(Ops.begin() + i);
4506 return getSequentialMinMaxExpr(Kind, Ops);
4507 }
4508 // Fold %x umin_seq %y to %x if %x ule %y.
4509 // TODO: We might be able to prove the predicate for a later operand.
4510 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4511 Ops.erase(Ops.begin() + i);
4512 return getSequentialMinMaxExpr(Kind, Ops);
4513 }
4514 }
4515
4516 // Okay, it looks like we really DO need an expr. Check to see if we
4517 // already have one, otherwise create a new one.
4519 ID.AddInteger(Kind);
4520 for (const SCEV *Op : Ops)
4521 ID.AddPointer(Op);
4522 void *IP = nullptr;
4523 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4524 if (ExistingSCEV)
4525 return ExistingSCEV;
4526
4527 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4529 SCEV *S = new (SCEVAllocator)
4530 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4531
4532 UniqueSCEVs.InsertNode(S, IP);
4533 S->computeAndSetCanonical(*this);
4534 registerUser(S, Ops);
4535 return S;
4536}
4537
4542
4546
4551
4555
4560
4564
4566 bool Sequential) {
4567 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4568 return getUMinExpr(Ops, Sequential);
4569}
4570
4576
4577const SCEV *
4579 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4580 if (Size.isScalable())
4581 Res = getMulExpr(Res, getVScale(IntTy));
4582 return Res;
4583}
4584
4586 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4587}
4588
4590 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4591}
4592
4594 StructType *STy,
4595 unsigned FieldNo) {
4596 // We can bypass creating a target-independent constant expression and then
4597 // folding it back into a ConstantInt. This is just a compile-time
4598 // optimization.
4599 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4600 assert(!SL->getSizeInBits().isScalable() &&
4601 "Cannot get offset for structure containing scalable vector types");
4602 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4603}
4604
4606 // Don't attempt to do anything other than create a SCEVUnknown object
4607 // here. createSCEV only calls getUnknown after checking for all other
4608 // interesting possibilities, and any other code that calls getUnknown
4609 // is doing so in order to hide a value from SCEV canonicalization.
4610
4612 ID.AddInteger(scUnknown);
4613 ID.AddPointer(V);
4614 void *IP = nullptr;
4615 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4616 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4617 "Stale SCEVUnknown in uniquing map!");
4618 return S;
4619 }
4620 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4621 FirstUnknown);
4622 FirstUnknown = cast<SCEVUnknown>(S);
4623 UniqueSCEVs.InsertNode(S, IP);
4624 S->computeAndSetCanonical(*this);
4625 return S;
4626}
4627
4628//===----------------------------------------------------------------------===//
4629// Basic SCEV Analysis and PHI Idiom Recognition Code
4630//
4631
4632/// Test if values of the given type are analyzable within the SCEV
4633/// framework. This primarily includes integer types, and it can optionally
4634/// include pointer types if the ScalarEvolution class has access to
4635/// target-specific information.
4637 // Integers and pointers are always SCEVable.
4638 return Ty->isIntOrPtrTy();
4639}
4640
4641/// Return the size in bits of the specified type, for which isSCEVable must
4642/// return true.
4644 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4645 if (Ty->isPointerTy())
4647 return getDataLayout().getTypeSizeInBits(Ty);
4648}
4649
4650/// Return a type with the same bitwidth as the given type and which represents
4651/// how SCEV will treat the given type, for which isSCEVable must return
4652/// true. For pointer types, this is the pointer index sized integer type.
4654 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4655
4656 if (Ty->isIntegerTy())
4657 return Ty;
4658
4659 // The only other support type is pointer.
4660 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4661 return getDataLayout().getIndexType(Ty);
4662}
4663
4665 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4666}
4667
4669 const SCEV *B) {
4670 /// For a valid use point to exist, the defining scope of one operand
4671 /// must dominate the other.
4672 bool PreciseA, PreciseB;
4673 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4674 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4675 if (!PreciseA || !PreciseB)
4676 // Can't tell.
4677 return false;
4678 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4679 DT.dominates(ScopeB, ScopeA);
4680}
4681
4683 return CouldNotCompute.get();
4684}
4685
4686bool ScalarEvolution::checkValidity(const SCEV *S) const {
4687 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4688 auto *SU = dyn_cast<SCEVUnknown>(S);
4689 return SU && SU->getValue() == nullptr;
4690 });
4691
4692 return !ContainsNulls;
4693}
4694
4696 HasRecMapType::iterator I = HasRecMap.find(S);
4697 if (I != HasRecMap.end())
4698 return I->second;
4699
4700 bool FoundAddRec =
4701 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4702 HasRecMap.insert({S, FoundAddRec});
4703 return FoundAddRec;
4704}
4705
4706/// Return the ValueOffsetPair set for \p S. \p S can be represented
4707/// by the value and offset from any ValueOffsetPair in the set.
4708ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4709 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4710 if (SI == ExprValueMap.end())
4711 return {};
4712 return SI->second.getArrayRef();
4713}
4714
4715/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4716/// cannot be used separately. eraseValueFromMap should be used to remove
4717/// V from ValueExprMap and ExprValueMap at the same time.
4718void ScalarEvolution::eraseValueFromMap(Value *V) {
4719 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4720 if (I != ValueExprMap.end()) {
4721 auto EVIt = ExprValueMap.find(I->second);
4722 bool Removed = EVIt->second.remove(V);
4723 (void) Removed;
4724 assert(Removed && "Value not in ExprValueMap?");
4725 ValueExprMap.erase(I);
4726 }
4727}
4728
4729void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4730 // A recursive query may have already computed the SCEV. It should be
4731 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4732 // inferred nowrap flags.
4733 auto It = ValueExprMap.find_as(V);
4734 if (It == ValueExprMap.end()) {
4735 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4736 ExprValueMap[S].insert(V);
4737 }
4738}
4739
4740/// Return an existing SCEV if it exists, otherwise analyze the expression and
4741/// create a new one.
4743 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4744
4745 if (const SCEV *S = getExistingSCEV(V))
4746 return S;
4747 return createSCEVIter(V);
4748}
4749
4751 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4752
4753 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4754 if (I != ValueExprMap.end()) {
4755 const SCEV *S = I->second;
4756 assert(checkValidity(S) &&
4757 "existing SCEV has not been properly invalidated");
4758 return S;
4759 }
4760 return nullptr;
4761}
4762
4763/// Return a SCEV corresponding to -V = -1*V
4765 SCEV::NoWrapFlags Flags) {
4766 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4767 return getConstant(
4768 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4769
4770 Type *Ty = V->getType();
4771 Ty = getEffectiveSCEVType(Ty);
4772 return getMulExpr(V, getMinusOne(Ty), Flags);
4773}
4774
4775/// If Expr computes ~A, return A else return nullptr
4776static const SCEV *MatchNotExpr(const SCEV *Expr) {
4777 const SCEV *MulOp;
4778 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4779 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4780 return MulOp;
4781 return nullptr;
4782}
4783
4784/// Return a SCEV corresponding to ~V = -1-V
4786 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4787
4788 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4789 return getConstant(
4790 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4791
4792 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4793 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4794 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4795 SmallVector<SCEVUse, 2> MatchedOperands;
4796 for (const SCEV *Operand : MME->operands()) {
4797 const SCEV *Matched = MatchNotExpr(Operand);
4798 if (!Matched)
4799 return (const SCEV *)nullptr;
4800 MatchedOperands.push_back(Matched);
4801 }
4802 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4803 MatchedOperands);
4804 };
4805 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4806 return Replaced;
4807 }
4808
4809 Type *Ty = V->getType();
4810 Ty = getEffectiveSCEVType(Ty);
4811 return getMinusSCEV(getMinusOne(Ty), V);
4812}
4813
4815 assert(P->getType()->isPointerTy());
4816
4817 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4818 // The base of an AddRec is the first operand.
4819 SmallVector<SCEVUse> Ops{AddRec->operands()};
4820 Ops[0] = removePointerBase(Ops[0]);
4821 // Don't try to transfer nowrap flags for now. We could in some cases
4822 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4823 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4824 }
4825 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4826 // The base of an Add is the pointer operand.
4827 SmallVector<SCEVUse> Ops{Add->operands()};
4828 SCEVUse *PtrOp = nullptr;
4829 for (SCEVUse &AddOp : Ops) {
4830 if (AddOp->getType()->isPointerTy()) {
4831 assert(!PtrOp && "Cannot have multiple pointer ops");
4832 PtrOp = &AddOp;
4833 }
4834 }
4835 *PtrOp = removePointerBase(*PtrOp);
4836 // Don't try to transfer nowrap flags for now. We could in some cases
4837 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4838 return getAddExpr(Ops);
4839 }
4840 // Any other expression must be a pointer base.
4841 return getZero(P->getType());
4842}
4843
4845 SCEV::NoWrapFlags Flags,
4846 unsigned Depth) {
4847 // Fast path: X - X --> 0.
4848 if (LHS == RHS)
4849 return getZero(LHS->getType());
4850
4851 // If we subtract two pointers with different pointer bases, bail.
4852 // Eventually, we're going to add an assertion to getMulExpr that we
4853 // can't multiply by a pointer.
4854 if (RHS->getType()->isPointerTy()) {
4855 if (!LHS->getType()->isPointerTy() ||
4856 getPointerBase(LHS) != getPointerBase(RHS))
4857 return getCouldNotCompute();
4858 LHS = removePointerBase(LHS);
4859 RHS = removePointerBase(RHS);
4860 }
4861
4862 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4863 // makes it so that we cannot make much use of NUW.
4864 auto AddFlags = SCEV::FlagAnyWrap;
4865 const bool RHSIsNotMinSigned =
4867 if (hasFlags(Flags, SCEV::FlagNSW)) {
4868 // Let M be the minimum representable signed value. Then (-1)*RHS
4869 // signed-wraps if and only if RHS is M. That can happen even for
4870 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4871 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4872 // (-1)*RHS, we need to prove that RHS != M.
4873 //
4874 // If LHS is non-negative and we know that LHS - RHS does not
4875 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4876 // either by proving that RHS > M or that LHS >= 0.
4877 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4878 AddFlags = SCEV::FlagNSW;
4879 }
4880 }
4881
4882 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4883 // RHS is NSW and LHS >= 0.
4884 //
4885 // The difficulty here is that the NSW flag may have been proven
4886 // relative to a loop that is to be found in a recurrence in LHS and
4887 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4888 // larger scope than intended.
4889 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4890
4891 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4892}
4893
4895 unsigned Depth) {
4896 Type *SrcTy = V->getType();
4897 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4898 "Cannot truncate or zero extend with non-integer arguments!");
4899 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4900 return V; // No conversion
4901 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4902 return getTruncateExpr(V, Ty, Depth);
4903 return getZeroExtendExpr(V, Ty, Depth);
4904}
4905
4907 unsigned Depth) {
4908 Type *SrcTy = V->getType();
4909 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4910 "Cannot truncate or zero extend with non-integer arguments!");
4911 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4912 return V; // No conversion
4913 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4914 return getTruncateExpr(V, Ty, Depth);
4915 return getSignExtendExpr(V, Ty, Depth);
4916}
4917
4918const SCEV *
4920 Type *SrcTy = V->getType();
4921 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4922 "Cannot noop or zero extend with non-integer arguments!");
4924 "getNoopOrZeroExtend cannot truncate!");
4925 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4926 return V; // No conversion
4927 return getZeroExtendExpr(V, Ty);
4928}
4929
4930const SCEV *
4932 Type *SrcTy = V->getType();
4933 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4934 "Cannot noop or sign extend with non-integer arguments!");
4936 "getNoopOrSignExtend cannot truncate!");
4937 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4938 return V; // No conversion
4939 return getSignExtendExpr(V, Ty);
4940}
4941
4942const SCEV *
4944 Type *SrcTy = V->getType();
4945 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4946 "Cannot noop or any extend with non-integer arguments!");
4948 "getNoopOrAnyExtend cannot truncate!");
4949 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4950 return V; // No conversion
4951 return getAnyExtendExpr(V, Ty);
4952}
4953
4954const SCEV *
4956 Type *SrcTy = V->getType();
4957 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4958 "Cannot truncate or noop with non-integer arguments!");
4960 "getTruncateOrNoop cannot extend!");
4961 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4962 return V; // No conversion
4963 return getTruncateExpr(V, Ty);
4964}
4965
4967 const SCEV *RHS) {
4968 const SCEV *PromotedLHS = LHS;
4969 const SCEV *PromotedRHS = RHS;
4970
4971 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4972 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4973 else
4974 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4975
4976 return getUMaxExpr(PromotedLHS, PromotedRHS);
4977}
4978
4980 const SCEV *RHS,
4981 bool Sequential) {
4982 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4983 return getUMinFromMismatchedTypes(Ops, Sequential);
4984}
4985
4986const SCEV *
4988 bool Sequential) {
4989 assert(!Ops.empty() && "At least one operand must be!");
4990 // Trivial case.
4991 if (Ops.size() == 1)
4992 return Ops[0];
4993
4994 // Find the max type first.
4995 Type *MaxType = nullptr;
4996 for (SCEVUse S : Ops)
4997 if (MaxType)
4998 MaxType = getWiderType(MaxType, S->getType());
4999 else
5000 MaxType = S->getType();
5001 assert(MaxType && "Failed to find maximum type!");
5002
5003 // Extend all ops to max type.
5004 SmallVector<SCEVUse, 2> PromotedOps;
5005 for (SCEVUse S : Ops)
5006 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
5007
5008 // Generate umin.
5009 return getUMinExpr(PromotedOps, Sequential);
5010}
5011
5013 // A pointer operand may evaluate to a nonpointer expression, such as null.
5014 if (!V->getType()->isPointerTy())
5015 return V;
5016
5017 while (true) {
5018 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5019 V = AddRec->getStart();
5020 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5021 const SCEV *PtrOp = nullptr;
5022 for (const SCEV *AddOp : Add->operands()) {
5023 if (AddOp->getType()->isPointerTy()) {
5024 assert(!PtrOp && "Cannot have multiple pointer ops");
5025 PtrOp = AddOp;
5026 }
5027 }
5028 assert(PtrOp && "Must have pointer op");
5029 V = PtrOp;
5030 } else // Not something we can look further into.
5031 return V;
5032 }
5033}
5034
5035/// Push users of the given Instruction onto the given Worklist.
5039 // Push the def-use children onto the Worklist stack.
5040 for (User *U : I->users()) {
5041 auto *UserInsn = cast<Instruction>(U);
5042 if (Visited.insert(UserInsn).second)
5043 Worklist.push_back(UserInsn);
5044 }
5045}
5046
5047namespace {
5048
5049/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5050/// expression in case its Loop is L. If it is not L then
5051/// if IgnoreOtherLoops is true then use AddRec itself
5052/// otherwise rewrite cannot be done.
5053/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5054class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5055public:
5056 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5057 bool IgnoreOtherLoops = true) {
5058 SCEVInitRewriter Rewriter(L, SE);
5059 const SCEV *Result = Rewriter.visit(S);
5060 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5061 return SE.getCouldNotCompute();
5062 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5063 ? SE.getCouldNotCompute()
5064 : Result;
5065 }
5066
5067 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5068 if (!SE.isLoopInvariant(Expr, L))
5069 SeenLoopVariantSCEVUnknown = true;
5070 return Expr;
5071 }
5072
5073 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5074 // Only re-write AddRecExprs for this loop.
5075 if (Expr->getLoop() == L)
5076 return Expr->getStart();
5077 SeenOtherLoops = true;
5078 return Expr;
5079 }
5080
5081 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5082
5083 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5084
5085private:
5086 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5087 : SCEVRewriteVisitor(SE), L(L) {}
5088
5089 const Loop *L;
5090 bool SeenLoopVariantSCEVUnknown = false;
5091 bool SeenOtherLoops = false;
5092};
5093
5094/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5095/// increment expression in case its Loop is L. If it is not L then
5096/// use AddRec itself.
5097/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5098class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5099public:
5100 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5101 SCEVPostIncRewriter Rewriter(L, SE);
5102 const SCEV *Result = Rewriter.visit(S);
5103 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5104 ? SE.getCouldNotCompute()
5105 : Result;
5106 }
5107
5108 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5109 if (!SE.isLoopInvariant(Expr, L))
5110 SeenLoopVariantSCEVUnknown = true;
5111 return Expr;
5112 }
5113
5114 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5115 // Only re-write AddRecExprs for this loop.
5116 if (Expr->getLoop() == L)
5117 return Expr->getPostIncExpr(SE);
5118 SeenOtherLoops = true;
5119 return Expr;
5120 }
5121
5122 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5123
5124 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5125
5126private:
5127 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5128 : SCEVRewriteVisitor(SE), L(L) {}
5129
5130 const Loop *L;
5131 bool SeenLoopVariantSCEVUnknown = false;
5132 bool SeenOtherLoops = false;
5133};
5134
/// This class evaluates the compare condition by matching it against the
/// condition of loop latch. If there is a match we assume a true value
/// for the condition while building SCEV nodes.
class SCEVBackedgeConditionFolder
    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                             ScalarEvolution &SE) {
    bool IsPosBECond = false;
    Value *BECond = nullptr;
    if (BasicBlock *Latch = L->getLoopLatch()) {
      if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
        assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
               "Both outgoing branches should not target same header!");
        BECond = BI->getCondition();
        // The backedge condition is "positive" when the true edge re-enters
        // the loop header.
        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
      } else {
        // Latch does not end in a conditional branch; nothing to fold.
        return S;
      }
    }
    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
    return Rewriter.visit(S);
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    const SCEV *Result = Expr;
    bool InvariantF = SE.isLoopInvariant(Expr, L);

    if (!InvariantF) {
      // NOTE(review): 'I' below is presumably the Instruction defining Expr
      // (cast<Instruction>(Expr->getValue())); its declaration was lost in
      // this excerpt -- confirm against the upstream source.
      switch (I->getOpcode()) {
      case Instruction::Select: {
        SelectInst *SI = cast<SelectInst>(I);
        std::optional<const SCEV *> Res =
            compareWithBackedgeCondition(SI->getCondition());
        if (Res) {
          // A constant-one result means the latch condition is known true on
          // the backedge; fold to the corresponding select arm.
          bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
        }
        break;
      }
      default: {
        std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
        if (Res)
          Result = *Res;
        break;
      }
      }
    }
    return Result;
  }

private:
  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
                                       bool IsPosBECond, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
        IsPositiveBECond(IsPosBECond) {}

  /// Return a boolean-constant SCEV if \p IC matches the latch condition.
  std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);

  const Loop *L;
  /// Loop back condition.
  Value *BackedgeCond = nullptr;
  /// Set to true if loop back is on positive branch condition.
  bool IsPositiveBECond;
};
5201
5202std::optional<const SCEV *>
5203SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5204
5205 // If value matches the backedge condition for loop latch,
5206 // then return a constant evolution node based on loopback
5207 // branch taken.
5208 if (BackedgeCond == IC)
5209 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5211 return std::nullopt;
5212}
5213
5214class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5215public:
5216 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5217 ScalarEvolution &SE) {
5218 SCEVShiftRewriter Rewriter(L, SE);
5219 const SCEV *Result = Rewriter.visit(S);
5220 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5221 }
5222
5223 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5224 // Only allow AddRecExprs for this loop.
5225 if (!SE.isLoopInvariant(Expr, L))
5226 Valid = false;
5227 return Expr;
5228 }
5229
5230 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5231 if (Expr->getLoop() == L && Expr->isAffine())
5232 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5233 Valid = false;
5234 return Expr;
5235 }
5236
5237 bool isValid() { return Valid; }
5238
5239private:
5240 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5241 : SCEVRewriteVisitor(SE), L(L) {}
5242
5243 const Loop *L;
5244 bool Valid = true;
5245};
5246
5247} // end anonymous namespace
5248
5250ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5251 if (!AR->isAffine())
5252 return SCEV::FlagAnyWrap;
5253
5254 using OBO = OverflowingBinaryOperator;
5255
5257
5258 if (!AR->hasNoSelfWrap()) {
5259 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5260 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5261 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5262 const APInt &BECountAP = BECountMax->getAPInt();
5263 unsigned NoOverflowBitWidth =
5264 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5265 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5267 }
5268 }
5269
5270 if (!AR->hasNoSignedWrap()) {
5271 ConstantRange AddRecRange = getSignedRange(AR);
5272 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5273
5275 Instruction::Add, IncRange, OBO::NoSignedWrap);
5276 if (NSWRegion.contains(AddRecRange))
5278 }
5279
5280 if (!AR->hasNoUnsignedWrap()) {
5281 ConstantRange AddRecRange = getUnsignedRange(AR);
5282 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5283
5285 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5286 if (NUWRegion.contains(AddRecRange))
5288 }
5289
5290 return Result;
5291}
5292
5294ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5296
5297 if (AR->hasNoSignedWrap())
5298 return Result;
5299
5300 if (!AR->isAffine())
5301 return Result;
5302
5303 // This function can be expensive, only try to prove NSW once per AddRec.
5304 if (!SignedWrapViaInductionTried.insert(AR).second)
5305 return Result;
5306
5307 const SCEV *Step = AR->getStepRecurrence(*this);
5308 const Loop *L = AR->getLoop();
5309
5310 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5311 // Note that this serves two purposes: It filters out loops that are
5312 // simply not analyzable, and it covers the case where this code is
5313 // being called from within backedge-taken count analysis, such that
5314 // attempting to ask for the backedge-taken count would likely result
5315 // in infinite recursion. In the later case, the analysis code will
5316 // cope with a conservative value, and it will take care to purge
5317 // that value once it has finished.
5318 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5319
5320 // Normally, in the cases we can prove no-overflow via a
5321 // backedge guarding condition, we can also compute a backedge
5322 // taken count for the loop. The exceptions are assumptions and
5323 // guards present in the loop -- SCEV is not great at exploiting
5324 // these to compute max backedge taken counts, but can still use
5325 // these to prove lack of overflow. Use this fact to avoid
5326 // doing extra work that may not pay off.
5327
5328 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5329 AC.assumptions().empty())
5330 return Result;
5331
5332 // If the backedge is guarded by a comparison with the pre-inc value the
5333 // addrec is safe. Also, if the entry is guarded by a comparison with the
5334 // start value and the backedge is guarded by a comparison with the post-inc
5335 // value, the addrec is safe.
5337 const SCEV *OverflowLimit =
5338 getSignedOverflowLimitForStep(Step, &Pred, this);
5339 if (OverflowLimit &&
5340 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5341 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5342 Result = setFlags(Result, SCEV::FlagNSW);
5343 }
5344 return Result;
5345}
5347ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5349
5350 if (AR->hasNoUnsignedWrap())
5351 return Result;
5352
5353 if (!AR->isAffine())
5354 return Result;
5355
5356 // This function can be expensive, only try to prove NUW once per AddRec.
5357 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5358 return Result;
5359
5360 const SCEV *Step = AR->getStepRecurrence(*this);
5361 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5362 const Loop *L = AR->getLoop();
5363
5364 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5365 // Note that this serves two purposes: It filters out loops that are
5366 // simply not analyzable, and it covers the case where this code is
5367 // being called from within backedge-taken count analysis, such that
5368 // attempting to ask for the backedge-taken count would likely result
5369 // in infinite recursion. In the later case, the analysis code will
5370 // cope with a conservative value, and it will take care to purge
5371 // that value once it has finished.
5372 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5373
5374 // Normally, in the cases we can prove no-overflow via a
5375 // backedge guarding condition, we can also compute a backedge
5376 // taken count for the loop. The exceptions are assumptions and
5377 // guards present in the loop -- SCEV is not great at exploiting
5378 // these to compute max backedge taken counts, but can still use
5379 // these to prove lack of overflow. Use this fact to avoid
5380 // doing extra work that may not pay off.
5381
5382 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5383 AC.assumptions().empty())
5384 return Result;
5385
5386 // If the backedge is guarded by a comparison with the pre-inc value the
5387 // addrec is safe. Also, if the entry is guarded by a comparison with the
5388 // start value and the backedge is guarded by a comparison with the post-inc
5389 // value, the addrec is safe.
5390 if (isKnownPositive(Step)) {
5391 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5392 getUnsignedRangeMax(Step));
5395 Result = setFlags(Result, SCEV::FlagNUW);
5396 }
5397 }
5398
5399 return Result;
5400}
5401
5402namespace {
5403
/// Represents an abstract binary operation. This may exist as a
/// normal instruction or constant expression, or may have been
/// derived from an expression tree.
struct BinaryOp {
  /// Instruction opcode (e.g. Instruction::Add) describing the operation.
  unsigned Opcode;
  Value *LHS;
  Value *RHS;
  /// Wrap flags; only meaningful for overflowing operations.
  bool IsNSW = false;
  bool IsNUW = false;

  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
  /// constant expression.
  Operator *Op = nullptr;

  /// Build from a concrete Operator, harvesting nsw/nuw flags when present.
  explicit BinaryOp(Operator *Op)
      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
        Op(Op) {
    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
      IsNSW = OBO->hasNoSignedWrap();
      IsNUW = OBO->hasNoUnsignedWrap();
    }
  }

  /// Build a synthesized operation with no corresponding IR object (the Op
  /// member stays null).
  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
                    bool IsNUW = false)
      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};
5431
5432} // end anonymous namespace
5433
/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
                                             AssumptionCache &AC,
                                             const DominatorTree &DT,
                                             const Instruction *CxtI) {
  auto *Op = dyn_cast<Operator>(V);
  if (!Op)
    return std::nullopt;

  // Implementation detail: all the cleverness here should happen without
  // creating new SCEV expressions -- our caller knowns tricks to avoid creating
  // SCEV expressions when possible, and we should not break that.

  switch (Op->getOpcode()) {
  // These opcodes map to a BinaryOp directly, flags included.
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::And:
  case Instruction::AShr:
  case Instruction::Shl:
    return BinaryOp(Op);

  case Instruction::Or: {
    // Convert or disjoint into add nuw nsw.
    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
                      /*IsNSW=*/true, /*IsNUW=*/true);
    return BinaryOp(Op);
  }

  case Instruction::Xor:
    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
      // If the RHS of the xor is a signmask, then this is just an add.
      // Instcombine turns add of signmask into xor as a strength reduction step.
      if (RHSC->getValue().isSignMask())
        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    // Binary `xor` is a bit-wise `add`.
    if (V->getType()->isIntegerTy(1))
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    return BinaryOp(Op);

  case Instruction::LShr:
    // Turn logical shift right of a constant into a unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();

      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined. Don't try to analyze it, because the
      // resolution chosen here may differ from the resolution chosen in
      // other parts of the compiler.
      if (SA->getValue().ult(BitWidth)) {
        Constant *X =
            ConstantInt::get(SA->getContext(),
                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
      }
    }
    return BinaryOp(Op);

  case Instruction::ExtractValue: {
    // Only the arithmetic-result component (index 0) of a with.overflow
    // intrinsic is interesting here.
    auto *EVI = cast<ExtractValueInst>(Op);
    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
      break;

    auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
    if (!WO)
      break;

    Instruction::BinaryOps BinOp = WO->getBinaryOp();
    bool Signed = WO->isSigned();
    // TODO: Should add nuw/nsw flags for mul as well.
    if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
      return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());

    // Now that we know that all uses of the arithmetic-result component of
    // CI are guarded by the overflow check, we can go ahead and pretend
    // that the arithmetic is non-overflowing.
    return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
                    /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
  }

  default:
    break;
  }

  // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
  // semantics as a Sub, return a binary sub expression.
  if (auto *II = dyn_cast<IntrinsicInst>(V))
    if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
      return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));

  return std::nullopt;
}
5529
5530/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5531/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5532/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5533/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5534/// follows one of the following patterns:
5535/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5536/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5537/// If the SCEV expression of \p Op conforms with one of the expected patterns
5538/// we return the type of the truncation operation, and indicate whether the
5539/// truncated type should be treated as signed/unsigned by setting
5540/// \p Signed to true/false, respectively.
5541static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5542 bool &Signed, ScalarEvolution &SE) {
5543 // The case where Op == SymbolicPHI (that is, with no type conversions on
5544 // the way) is handled by the regular add recurrence creating logic and
5545 // would have already been triggered in createAddRecForPHI. Reaching it here
5546 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5547 // because one of the other operands of the SCEVAddExpr updating this PHI is
5548 // not invariant).
5549 //
5550 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5551 // this case predicates that allow us to prove that Op == SymbolicPHI will
5552 // be added.
5553 if (Op == SymbolicPHI)
5554 return nullptr;
5555
5556 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5557 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5558 if (SourceBits != NewBits)
5559 return nullptr;
5560
5561 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5562 Signed = true;
5563 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5564 }
5565 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5566 Signed = false;
5567 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5568 }
5569 return nullptr;
5570}
5571
5572static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5573 if (!PN->getType()->isIntegerTy())
5574 return nullptr;
5575 const Loop *L = LI.getLoopFor(PN->getParent());
5576 if (!L || L->getHeader() != PN->getParent())
5577 return nullptr;
5578 return L;
5579}
5580
5581// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5582// computation that updates the phi follows the following pattern:
5583// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5584// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5585// If so, try to see if it can be rewritten as an AddRecExpr under some
5586// Predicates. If successful, return them as a pair. Also cache the results
5587// of the analysis.
5588//
5589// Example usage scenario:
5590// Say the Rewriter is called for the following SCEV:
5591// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5592// where:
5593// %X = phi i64 (%Start, %BEValue)
5594// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5595// and call this function with %SymbolicPHI = %X.
5596//
5597// The analysis will find that the value coming around the backedge has
5598// the following SCEV:
5599// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5600// Upon concluding that this matches the desired pattern, the function
5601// will return the pair {NewAddRec, SmallPredsVec} where:
5602// NewAddRec = {%Start,+,%Step}
5603// SmallPredsVec = {P1, P2, P3} as follows:
5604// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5605// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5606// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5607// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5608// under the predicates {P1,P2,P3}.
5609// This predicated rewrite will be cached in PredicatedSCEVRewrites:
// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5611//
5612// TODO's:
5613//
5614// 1) Extend the Induction descriptor to also support inductions that involve
5615// casts: When needed (namely, when we are called in the context of the
5616// vectorizer induction analysis), a Set of cast instructions will be
5617// populated by this method, and provided back to isInductionPHI. This is
5618// needed to allow the vectorizer to properly record them to be ignored by
5619// the cost model and to avoid vectorizing them (otherwise these casts,
5620// which are redundant under the runtime overflow checks, will be
5621// vectorized, which can be costly).
5622//
5623// 2) Support additional induction/PHISCEV patterns: We also want to support
5624// inductions where the sext-trunc / zext-trunc operations (partly) occur
5625// after the induction update operation (the induction increment):
5626//
5627// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5628// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5629//
5630// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5631// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5632//
5633// 3) Outline common code with createAddRecFromPHI to avoid duplication.
std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {

  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
  // return an AddRec expression under some predicate.

  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
  assert(L && "Expecting an integer loop header phi");

  // The loop may have multiple entrances or multiple exits; we can analyze
  // this phi as an addrec if it has a unique entry value and a unique
  // backedge value.
  Value *BEValueV = nullptr, *StartValueV = nullptr;
  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    Value *V = PN->getIncomingValue(i);
    if (L->contains(PN->getIncomingBlock(i))) {
      // Incoming edge from inside the loop: candidate backedge value.
      if (!BEValueV) {
        BEValueV = V;
      } else if (BEValueV != V) {
        BEValueV = nullptr;
        break;
      }
    } else if (!StartValueV) {
      // Incoming edge from outside the loop: candidate start value.
      StartValueV = V;
    } else if (StartValueV != V) {
      StartValueV = nullptr;
      break;
    }
  }
  if (!BEValueV || !StartValueV)
    return std::nullopt;

  const SCEV *BEValue = getSCEV(BEValueV);

  // If the value coming around the backedge is an add with the symbolic
  // value we just inserted, possibly with casts that we can ignore under
  // an appropriate runtime guard, then we found a simple induction variable!
  const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
  if (!Add)
    return std::nullopt;

  // If there is a single occurrence of the symbolic value, possibly
  // casted, replace it with a recurrence.
  unsigned FoundIndex = Add->getNumOperands();
  Type *TruncTy = nullptr;
  bool Signed;
  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
    if ((TruncTy =
             isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
      if (FoundIndex == e) {
        FoundIndex = i;
        break;
      }

  if (FoundIndex == Add->getNumOperands())
    return std::nullopt;

  // Create an add with everything but the specified operand.
  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
    if (i != FoundIndex)
      Ops.push_back(Add->getOperand(i));
  const SCEV *Accum = getAddExpr(Ops);

  // The runtime checks will not be valid if the step amount is
  // varying inside the loop.
  if (!isLoopInvariant(Accum, L))
    return std::nullopt;

  // *** Part2: Create the predicates

  // Analysis was successful: we have a phi-with-cast pattern for which we
  // can return an AddRec expression under the following predicates:
  //
  // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
  //     fits within the truncated type (does not overflow) for i = 0 to n-1.
  // P2: An Equal predicate that guarantees that
  //     Start = (Ext ix (Trunc iy (Start) to ix) to iy)
  // P3: An Equal predicate that guarantees that
  //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
  //
  // As we next prove, the above predicates guarantee that:
  //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
  //
  //
  // More formally, we want to prove that:
  //     Expr(i+1) = Start + (i+1) * Accum
  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
  //
  // Given that:
  // 1) Expr(0) = Start
  // 2) Expr(1) = Start + Accum
  //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
  // 3) Induction hypothesis (step i):
  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
  //
  // Proof:
  //  Expr(i+1) =
  //   = Start + (i+1)*Accum
  //   = (Start + i*Accum) + Accum
  //   = Expr(i) + Accum
  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
  //                                                             :: from step i
  //
  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
  //
  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
  //     + (Ext ix (Trunc iy (Accum) to ix) to iy)
  //     + Accum                                                     :: from P3
  //
  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
  //     + Accum                            :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
  //
  //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
  //
  // By induction, the same applies to all iterations 1<=i<n:
  //

  // Create a truncated addrec for which we will add a no overflow check (P1).
  const SCEV *StartVal = getSCEV(StartValueV);
  const SCEV *PHISCEV =
      getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);

  // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
  // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
  // will be constant.
  //
  // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
  // add P1.
  if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
    Predicates.push_back(AddRecPred);
  }

  // Create the Equal Predicates P2,P3:

  // It is possible that the predicates P2 and/or P3 are computable at
  // compile time due to StartVal and/or Accum being constants.
  // If either one is, then we can check that now and escape if either P2
  // or P3 is false.

  // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
  // for each of StartVal and Accum
  auto getExtendedExpr = [&](const SCEV *Expr,
                             bool CreateSignExtend) -> const SCEV * {
    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
    const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
    const SCEV *ExtendedExpr =
        CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
                         : getZeroExtendExpr(TruncatedExpr, Expr->getType());
    return ExtendedExpr;
  };

  // Given:
  //  ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
  //               = getExtendedExpr(Expr)
  // Determine whether the predicate P: Expr == ExtendedExpr
  // is known to be false at compile time
  auto PredIsKnownFalse = [&](const SCEV *Expr,
                              const SCEV *ExtendedExpr) -> bool {
    return Expr != ExtendedExpr &&
           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
  };

  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
  if (PredIsKnownFalse(StartVal, StartExtended)) {
    LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
    return std::nullopt;
  }

  // The Step is always Signed (because the overflow checks are either
  // NSSW or NUSW)
  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
  if (PredIsKnownFalse(Accum, AccumExtended)) {
    LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
    return std::nullopt;
  }

  // Queue an equality predicate unless it is trivially or provably true.
  auto AppendPredicate = [&](const SCEV *Expr,
                             const SCEV *ExtendedExpr) -> void {
    if (Expr != ExtendedExpr &&
        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
      Predicates.push_back(Pred);
    }
  };

  AppendPredicate(StartVal, StartExtended);
  AppendPredicate(Accum, AccumExtended);

  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
  // which the casts had been folded away. The caller can rewrite SymbolicPHI
  // into NewAR if it will also add the runtime overflow checks specified in
  // Predicates.
  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);

  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
      std::make_pair(NewAR, Predicates);
  // Remember the result of the analysis for this SCEV at this location.
  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
  return PredRewrite;
}
5844
// Caching front-end for createAddRecFromPHIWithCastsImpl: consult
// PredicatedSCEVRewrites first, and record failed analyses so they are not
// repeated. A cached entry whose rewrite equals SymbolicPHI itself marks a
// previous failure.
std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
  if (!L)
    return std::nullopt;

  // Check to see if we already analyzed this PHI.
  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
  if (I != PredicatedSCEVRewrites.end()) {
    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
        I->second;
    // Analysis was done before and failed to create an AddRec:
    if (Rewrite.first == SymbolicPHI)
      return std::nullopt;
    // Analysis was done before and succeeded to create an AddRec under
    // a predicate:
    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
    assert(!(Rewrite.second).empty() && "Expected to find Predicates");
    return Rewrite;
  }

  std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
      Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);

  // Record in the cache that the analysis failed
  if (!Rewrite) {
    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
    return std::nullopt;
  }

  return Rewrite;
}
5879
// FIXME: This utility is currently required because the Rewriter currently
// does not rewrite this expression:
// {0, +, (sext ix (trunc iy to ix) to iy)}
// into {0, +, %step},
// even when the following Equal predicate exists:
// "%step == (sext ix (trunc iy to ix) to iy)".
    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
  // SCEVs are uniqued, so pointer identity means exact equality.
  if (AR1 == AR2)
    return true;

  // Two expressions count as equal if they are literally the same SCEV, or
  // if the recorded predicate set implies their equality in either direction.
  auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
    if (Expr1 != Expr2 &&
        !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
        !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
      return false;
    return true;
  };

  // The addrecs match when both their start values and their steps match
  // under the predicates.
  if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
      !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
    return false;
  return true;
}
5904
/// A helper function for createAddRecFromPHI to handle simple cases.
///
/// This function tries to find an AddRec expression for the simplest (yet most
/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
/// If it fails, createAddRecFromPHI will use a more general, but slow,
/// technique for finding the AddRec expression.
const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
                                                      Value *BEValueV,
                                                      Value *StartValueV) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  assert(L && L->getHeader() == PN->getParent());
  assert(BEValueV && StartValueV);

  auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
  if (!BO)
    return nullptr;

  // Only additive updates are handled here; everything else falls back to
  // the general path in createAddRecFromPHI.
  if (BO->Opcode != Instruction::Add)
    return nullptr;

  // The phi must appear on one side of the add, with a loop-invariant step
  // on the other side.
  const SCEV *Accum = nullptr;
  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
    Accum = getSCEV(BO->RHS);
  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
    Accum = getSCEV(BO->LHS);

  if (!Accum)
    return nullptr;

  // Transfer the add's wrap flags onto the recurrence.
  if (BO->IsNUW)
    Flags = setFlags(Flags, SCEV::FlagNUW);
  if (BO->IsNSW)
    Flags = setFlags(Flags, SCEV::FlagNSW);

  const SCEV *StartVal = getSCEV(StartValueV);
  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
  insertValueToMap(PN, PHISCEV);

  if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
    setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
                   proveNoWrapViaConstantRanges(AR)));
  }

  // We can add Flags to the post-inc expression only if we
  // know that it is *undefined behavior* for BEValueV to
  // overflow.
  if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
    assert(isLoopInvariant(Accum, L) &&
           "Accum is defined outside L, but is not invariant?");
    if (isAddRecNeverPoison(BEInst, L))
      // Built only for its caching side effect; the result is discarded.
      (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
  }

  return PHISCEV;
}
5962
/// Try to express the loop-header phi \p PN as an add recurrence. Returns
/// nullptr when \p PN is not a header phi or no recurrence can be formed.
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  if (!L || L->getHeader() != PN->getParent())
    return nullptr;

  // The loop may have multiple entrances or multiple exits; we can analyze
  // this phi as an addrec if it has a unique entry value and a unique
  // backedge value.
  Value *BEValueV = nullptr, *StartValueV = nullptr;
  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    Value *V = PN->getIncomingValue(i);
    if (L->contains(PN->getIncomingBlock(i))) {
      if (!BEValueV) {
        BEValueV = V;
      } else if (BEValueV != V) {
        BEValueV = nullptr;
        break;
      }
    } else if (!StartValueV) {
      StartValueV = V;
    } else if (StartValueV != V) {
      StartValueV = nullptr;
      break;
    }
  }
  if (!BEValueV || !StartValueV)
    return nullptr;

  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
         "PHI node already processed?");

  // First, try to find AddRec expression without creating a fictitious symbolic
  // value for PN.
  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
    return S;

  // Handle PHI node value symbolically.
  const SCEV *SymbolicName = getUnknown(PN);
  insertValueToMap(PN, SymbolicName);

  // Using this symbolic name for the PHI, analyze the value coming around
  // the back-edge.
  const SCEV *BEValue = getSCEV(BEValueV);

  // NOTE: If BEValue is loop invariant, we know that the PHI node just
  // has a special value for the first iteration of the loop.

  // If the value coming around the backedge is an add with the symbolic
  // value we just inserted, then we found a simple induction variable!
  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
    // If there is a single occurrence of the symbolic value, replace it
    // with a recurrence.
    unsigned FoundIndex = Add->getNumOperands();
    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
      if (Add->getOperand(i) == SymbolicName)
        if (FoundIndex == e) {
          FoundIndex = i;
          break;
        }

    if (FoundIndex != Add->getNumOperands()) {
      // Create an add with everything but the specified operand.
      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
        if (i != FoundIndex)
          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
                                                             L, *this));
      const SCEV *Accum = getAddExpr(Ops);

      // This is not a valid addrec if the step amount is varying each
      // loop iteration, but is not itself an addrec in this loop.
      if (isLoopInvariant(Accum, L) ||
          (isa<SCEVAddRecExpr>(Accum) &&
           cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {

        // Harvest wrap flags from the IR increment (add or GEP) feeding the
        // backedge, where legal.
        if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
          if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
            if (BO->IsNUW)
              Flags = setFlags(Flags, SCEV::FlagNUW);
            if (BO->IsNSW)
              Flags = setFlags(Flags, SCEV::FlagNSW);
          }
        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
          if (GEP->getOperand(0) == PN) {
            GEPNoWrapFlags NW = GEP->getNoWrapFlags();
            // If the increment has any nowrap flags, then we know the address
            // space cannot be wrapped around.
            if (NW != GEPNoWrapFlags::none())
              Flags = setFlags(Flags, SCEV::FlagNW);
            // If the GEP is nuw or nusw with non-negative offset, we know that
            // no unsigned wrap occurs. We cannot set the nsw flag as only the
            // offset is treated as signed, while the base is unsigned.
            if (NW.hasNoUnsignedWrap() ||
              Flags = setFlags(Flags, SCEV::FlagNUW);
          }

          // We cannot transfer nuw and nsw flags from subtraction
          // operations -- sub nuw X, Y is not the same as add nuw X, -Y
          // for instance.
        }

        const SCEV *StartVal = getSCEV(StartValueV);
        const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);

        // Okay, for the entire analysis of this edge we assumed the PHI
        // to be symbolic. We now need to go back and purge all of the
        // entries for the scalars that use the symbolic expression.
        forgetMemoizedResults({SymbolicName});
        insertValueToMap(PN, PHISCEV);

        if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
                         proveNoWrapViaConstantRanges(AR)));
        }

        // We can add Flags to the post-inc expression only if we
        // know that it is *undefined behavior* for BEValueV to
        // overflow.
        if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
          if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);

        return PHISCEV;
      }
    }
  } else {
    // Otherwise, this could be a loop like this:
    //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
    // In this case, j = {1,+,1}  and BEValue is j.
    // Because the other in-value of i (0) fits the evolution of BEValue
    // i really is an addrec evolution.
    //
    // We can generalize this saying that i is the shifted value of BEValue
    // by one iteration:
    //   PHI(f(0), f({1,+,1})) --> f({0,+,1})

    // Do not allow refinement in rewriting of BEValue.
    const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
    const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
    if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
        isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
      const SCEV *StartVal = getSCEV(StartValueV);
      if (Start == StartVal) {
        // Okay, for the entire analysis of this edge we assumed the PHI
        // to be symbolic. We now need to go back and purge all of the
        // entries for the scalars that use the symbolic expression.
        forgetMemoizedResults({SymbolicName});
        insertValueToMap(PN, Shifted);
        return Shifted;
      }
    }
  }

  // Remove the temporary PHI node SCEV that has been inserted while intending
  // to create an AddRecExpr for this PHI node. We can not keep this temporary
  // as it will prevent later (possibly simpler) SCEV expressions to be added
  // to the ValueExprMap.
  eraseValueFromMap(PN);

  return nullptr;
}
6127
// Try to match a control flow sequence that branches out at BI and merges back
// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
// match. On success, C/LHS/RHS are set; on failure only C may have been
// written.
                          Value *&C, Value *&LHS, Value *&RHS) {
  C = BI->getCondition();

  BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
  BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));

  Use &LeftUse = Merge->getOperandUse(0);
  Use &RightUse = Merge->getOperandUse(1);

  // Phi operand order matches successor order: taken-edge value is LHS.
  if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
    LHS = LeftUse;
    RHS = RightUse;
    return true;
  }

  // Phi operand order is swapped relative to successor order.
  if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
    LHS = RightUse;
    RHS = LeftUse;
    return true;
  }

  return false;
}
6155
// Check whether two-input phi \p PN is shaped like a select over the
// conditional branch in its immediate dominator; if so, extract the condition
// and the two select arms into Cond/LHS/RHS and return true.
                                 Value *&Cond, Value *&LHS,
                                 Value *&RHS) {
  auto IsReachable =
      [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
  if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
    // Try to match
    //
    //  br %cond, label %left, label %right
    // left:
    //  br label %merge
    // right:
    //  br label %merge
    // merge:
    //  V = phi [ %x, %left ], [ %y, %right ]
    //
    // as "select %cond, %x, %y"

    BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
    assert(IDom && "At least the entry block should dominate PN");

    auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
    return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
  }
  return false;
}
6182
// Build a SCEV for a phi that is structurally a select over its dominating
// conditional branch, by reusing the select/phi lowering logic.
const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
    return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);

  return nullptr;
}
6192
  // Return the shared binary operator when every incoming value of the phi is
  // a BinaryOperator identical (when defined) to the others; nullptr otherwise.
  BinaryOperator *CommonInst = nullptr;
  // Check if instructions are identical.
  for (Value *Incoming : PN->incoming_values()) {
    auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
    if (!IncomingInst)
      return nullptr;
    if (CommonInst) {
      if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
        return nullptr; // Not identical, give up
    } else {
      // Remember binary operator
      CommonInst = IncomingInst;
    }
  }
  return CommonInst;
}
6210
/// Returns SCEV for the first operand of a phi if all phi operands have
/// identical opcodes and operands
/// eg.
/// a: %add = %a + %b
///    br %c
/// b: %add1 = %a + %b
///    br %c
/// c: %phi = phi [%add, a], [%add1, b]
/// scev(%phi) => scev(%add)
const SCEV *
ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
  BinaryOperator *CommonInst = getCommonInstForPHI(PN);
  if (!CommonInst)
    return nullptr;

  // Check if SCEV exprs for instructions are identical.
  const SCEV *CommonSCEV = getSCEV(CommonInst);
  bool SCEVExprsIdentical =
      [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
  return SCEVExprsIdentical ? CommonSCEV : nullptr;
}
6233
6234const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6235 if (const SCEV *S = createAddRecFromPHI(PN))
6236 return S;
6237
6238 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6239 // phi node for X.
6240 if (Value *V = simplifyInstruction(
6241 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6242 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6243 return getSCEV(V);
6244
6245 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6246 return S;
6247
6248 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6249 return S;
6250
6251 // If it's not a loop phi, we can't handle it yet.
6252 return getUnknown(PN);
6253}
6254
/// Return true if \p OperandToFind occurs inside \p Root while looking only
/// through expressions of the (sequential or non-sequential) min/max kind
/// \p RootKind and through zero-extensions.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
                            SCEVTypes RootKind) {
  // Visitor for visitAll: records whether the operand was seen, and stops
  // descending once found or when a non-recursable node is reached.
  struct FindClosure {
    const SCEV *OperandToFind;
    const SCEVTypes RootKind; // Must be a sequential min/max expression.
    const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.

    bool Found = false;

    bool canRecurseInto(SCEVTypes Kind) const {
      // We can only recurse into the SCEV expression of the same effective type
      // as the type of our root SCEV expression, and into zero-extensions.
      return RootKind == Kind || NonSequentialRootKind == Kind ||
             scZeroExtend == Kind;
    };

    FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
        : OperandToFind(OperandToFind), RootKind(RootKind),
          NonSequentialRootKind(
              RootKind)) {}

    bool follow(const SCEV *S) {
      Found = S == OperandToFind;

      return !isDone() && canRecurseInto(S->getSCEVType());
    }

    bool isDone() const { return Found; }
  };

  FindClosure FC(OperandToFind, RootKind);
  visitAll(Root, FC);
  return FC.Found;
}
6290
6291std::optional<const SCEV *>
6292ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6293 ICmpInst *Cond,
6294 Value *TrueVal,
6295 Value *FalseVal) {
6296 // Try to match some simple smax or umax patterns.
6297 auto *ICI = Cond;
6298
6299 Value *LHS = ICI->getOperand(0);
6300 Value *RHS = ICI->getOperand(1);
6301
6302 switch (ICI->getPredicate()) {
6303 case ICmpInst::ICMP_SLT:
6304 case ICmpInst::ICMP_SLE:
6305 case ICmpInst::ICMP_ULT:
6306 case ICmpInst::ICMP_ULE:
6307 std::swap(LHS, RHS);
6308 [[fallthrough]];
6309 case ICmpInst::ICMP_SGT:
6310 case ICmpInst::ICMP_SGE:
6311 case ICmpInst::ICMP_UGT:
6312 case ICmpInst::ICMP_UGE:
6313 // a > b ? a+x : b+x -> max(a, b)+x
6314 // a > b ? b+x : a+x -> min(a, b)+x
6316 bool Signed = ICI->isSigned();
6317 const SCEV *LA = getSCEV(TrueVal);
6318 const SCEV *RA = getSCEV(FalseVal);
6319 const SCEV *LS = getSCEV(LHS);
6320 const SCEV *RS = getSCEV(RHS);
6321 if (LA->getType()->isPointerTy()) {
6322 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6323 // Need to make sure we can't produce weird expressions involving
6324 // negated pointers.
6325 if (LA == LS && RA == RS)
6326 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6327 if (LA == RS && RA == LS)
6328 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6329 }
6330 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6331 if (Op->getType()->isPointerTy()) {
6334 return Op;
6335 }
6336 if (Signed)
6337 Op = getNoopOrSignExtend(Op, Ty);
6338 else
6339 Op = getNoopOrZeroExtend(Op, Ty);
6340 return Op;
6341 };
6342 LS = CoerceOperand(LS);
6343 RS = CoerceOperand(RS);
6345 break;
6346 const SCEV *LDiff = getMinusSCEV(LA, LS);
6347 const SCEV *RDiff = getMinusSCEV(RA, RS);
6348 if (LDiff == RDiff)
6349 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6350 LDiff);
6351 LDiff = getMinusSCEV(LA, RS);
6352 RDiff = getMinusSCEV(RA, LS);
6353 if (LDiff == RDiff)
6354 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6355 LDiff);
6356 }
6357 break;
6358 case ICmpInst::ICMP_NE:
6359 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6360 std::swap(TrueVal, FalseVal);
6361 [[fallthrough]];
6362 case ICmpInst::ICMP_EQ:
6363 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6366 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6367 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6368 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6369 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6370 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6371 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6372 return getAddExpr(getUMaxExpr(X, C), Y);
6373 }
6374 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6375 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6376 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6377 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6379 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6380 const SCEV *X = getSCEV(LHS);
6381 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6382 X = ZExt->getOperand();
6383 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6384 const SCEV *FalseValExpr = getSCEV(FalseVal);
6385 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6386 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6387 /*Sequential=*/true);
6388 }
6389 }
6390 break;
6391 default:
6392 break;
6393 }
6394
6395 return std::nullopt;
6396}
6397
6398static std::optional<const SCEV *>
6400 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6401 assert(CondExpr->getType()->isIntegerTy(1) &&
6402 TrueExpr->getType() == FalseExpr->getType() &&
6403 TrueExpr->getType()->isIntegerTy(1) &&
6404 "Unexpected operands of a select.");
6405
6406 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6407 // --> C + (umin_seq cond, x - C)
6408 //
6409 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6410 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6411 // --> C + (umin_seq ~cond, x - C)
6412
6413 // FIXME: while we can't legally model the case where both of the hands
6414 // are fully variable, we only require that the *difference* is constant.
6415 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6416 return std::nullopt;
6417
6418 const SCEV *X, *C;
6419 if (isa<SCEVConstant>(TrueExpr)) {
6420 CondExpr = SE->getNotSCEV(CondExpr);
6421 X = FalseExpr;
6422 C = TrueExpr;
6423 } else {
6424 X = TrueExpr;
6425 C = FalseExpr;
6426 }
6427 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6428 /*Sequential=*/true));
6429}
6430
6431static std::optional<const SCEV *>
6433 Value *FalseVal) {
6434 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6435 return std::nullopt;
6436
6437 const auto *SECond = SE->getSCEV(Cond);
6438 const auto *SETrue = SE->getSCEV(TrueVal);
6439 const auto *SEFalse = SE->getSCEV(FalseVal);
6440 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6441}
6442
6443const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6444 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6445 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6446 assert(TrueVal->getType() == FalseVal->getType() &&
6447 V->getType() == TrueVal->getType() &&
6448 "Types of select hands and of the result must match.");
6449
6450 // For now, only deal with i1-typed `select`s.
6451 if (!V->getType()->isIntegerTy(1))
6452 return getUnknown(V);
6453
6454 if (std::optional<const SCEV *> S =
6455 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6456 return *S;
6457
6458 return getUnknown(V);
6459}
6460
6461const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6462 Value *TrueVal,
6463 Value *FalseVal) {
6464 // Handle "constant" branch or select. This can occur for instance when a
6465 // loop pass transforms an inner loop and moves on to process the outer loop.
6466 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6467 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6468
6469 if (auto *I = dyn_cast<Instruction>(V)) {
6470 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6471 if (std::optional<const SCEV *> S =
6472 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6473 TrueVal, FalseVal))
6474 return *S;
6475 }
6476 }
6477
6478 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6479}
6480
6481/// Expand GEP instructions into add and multiply operations. This allows them
6482/// to be analyzed by regular SCEV code.
6483const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6484 assert(GEP->getSourceElementType()->isSized() &&
6485 "GEP source element type must be sized");
6486
6487 SmallVector<SCEVUse, 4> IndexExprs;
6488 for (Value *Index : GEP->indices())
6489 IndexExprs.push_back(getSCEV(Index));
6490 return getGEPExpr(GEP, IndexExprs);
6491}
6492
6493APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6494 const Instruction *CtxI) {
6495 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6496 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6497 return TrailingZeros >= BitWidth
6499 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6500 };
6501 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6502 // The result is GCD of all operands results.
6503 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6504 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6506 Res, getConstantMultiple(N->getOperand(I), CtxI));
6507 return Res;
6508 };
6509
6510 switch (S->getSCEVType()) {
6511 case scConstant:
6512 return cast<SCEVConstant>(S)->getAPInt();
6513 case scPtrToAddr:
6514 case scPtrToInt:
6515 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6516 case scUDivExpr:
6517 case scVScale:
6518 return APInt(BitWidth, 1);
6519 case scTruncate: {
6520 // Only multiples that are a power of 2 will hold after truncation.
6521 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6522 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6523 return GetShiftedByZeros(TZ);
6524 }
6525 case scZeroExtend: {
6526 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6527 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6528 }
6529 case scSignExtend: {
6530 // Only multiples that are a power of 2 will hold after sext.
6531 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6532 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6533 return GetShiftedByZeros(TZ);
6534 }
6535 case scMulExpr: {
6536 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6537 if (M->hasNoUnsignedWrap()) {
6538 // The result is the product of all operand results.
6539 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6540 for (const SCEV *Operand : M->operands().drop_front())
6541 Res = Res * getConstantMultiple(Operand, CtxI);
6542 return Res;
6543 }
6544
6545 // If there are no wrap guarentees, find the trailing zeros, which is the
6546 // sum of trailing zeros for all its operands.
6547 uint32_t TZ = 0;
6548 for (const SCEV *Operand : M->operands())
6549 TZ += getMinTrailingZeros(Operand, CtxI);
6550 return GetShiftedByZeros(TZ);
6551 }
6552 case scAddExpr:
6553 case scAddRecExpr: {
6554 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6555 if (N->hasNoUnsignedWrap())
6556 return GetGCDMultiple(N);
6557 // Find the trailing bits, which is the minimum of its operands.
6558 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6559 for (const SCEV *Operand : N->operands().drop_front())
6560 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6561 return GetShiftedByZeros(TZ);
6562 }
6563 case scUMaxExpr:
6564 case scSMaxExpr:
6565 case scUMinExpr:
6566 case scSMinExpr:
6568 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6569 case scUnknown: {
6570 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6571 // the point their underlying IR instruction has been defined. If CtxI was
6572 // not provided, use:
6573 // * the first instruction in the entry block if it is an argument
6574 // * the instruction itself otherwise.
6575 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6576 if (!CtxI) {
6577 if (isa<Argument>(U->getValue()))
6578 CtxI = &*F.getEntryBlock().begin();
6579 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6580 CtxI = I;
6581 }
6582 unsigned Known =
6583 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6584 .countMinTrailingZeros();
6585 return GetShiftedByZeros(Known);
6586 }
6587 case scCouldNotCompute:
6588 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6589 }
6590 llvm_unreachable("Unknown SCEV kind!");
6591}
6592
6594 const Instruction *CtxI) {
6595 // Skip looking up and updating the cache if there is a context instruction,
6596 // as the result will only be valid in the specified context.
6597 if (CtxI)
6598 return getConstantMultipleImpl(S, CtxI);
6599
6600 auto I = ConstantMultipleCache.find(S);
6601 if (I != ConstantMultipleCache.end())
6602 return I->second;
6603
6604 APInt Result = getConstantMultipleImpl(S, CtxI);
6605 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6606 assert(InsertPair.second && "Should insert a new key");
6607 return InsertPair.first->second;
6608}
6609
6611 APInt Multiple = getConstantMultiple(S);
6612 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6613}
6614
6616 const Instruction *CtxI) {
6617 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6618 (unsigned)getTypeSizeInBits(S->getType()));
6619}
6620
6621/// Helper method to assign a range to V from metadata present in the IR.
6622static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6624 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6625 return getConstantRangeFromMetadata(*MD);
6626 if (const auto *CB = dyn_cast<CallBase>(V))
6627 if (std::optional<ConstantRange> Range = CB->getRange())
6628 return Range;
6629 }
6630 if (auto *A = dyn_cast<Argument>(V))
6631 if (std::optional<ConstantRange> Range = A->getRange())
6632 return Range;
6633
6634 return std::nullopt;
6635}
6636
6638 SCEV::NoWrapFlags Flags) {
6639 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6640 AddRec->setNoWrapFlags(Flags);
6641 UnsignedRanges.erase(AddRec);
6642 SignedRanges.erase(AddRec);
6643 ConstantMultipleCache.erase(AddRec);
6644 }
6645}
6646
/// Compute a conservative range for a SCEVUnknown PHI implementing a shift
/// recurrence (phi = shl/lshr/ashr(phi, step)), bounding the total shift by
/// the loop's constant max trip count. Returns the full set when no useful
/// bound can be derived.
ConstantRange ScalarEvolution::
getRangeForUnknownRecurrence(const SCEVUnknown *U) {
  const DataLayout &DL = getDataLayout();

  unsigned BitWidth = getTypeSizeInBits(U->getType());
  const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);

  // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
  // use information about the trip count to improve our available range. Note
  // that the trip count independent cases are already handled by known bits.
  // WARNING: The definition of recurrence used here is subtly different than
  // the one used by AddRec (and thus most of this file). Step is allowed to
  // be arbitrarily loop varying here, where AddRec allows only loop invariant
  // and other addrecs in the same loop (for non-affine addrecs). The code
  // below intentionally handles the case where step is not loop invariant.
  auto *P = dyn_cast<PHINode>(U->getValue());
  if (!P)
    return FullSet;

  // Make sure that no Phi input comes from an unreachable block. Otherwise,
  // even the values that are not available in these blocks may come from them,
  // and this leads to false-positive recurrence test.
  for (auto *Pred : predecessors(P->getParent()))
    if (!DT.isReachableFromEntry(Pred))
      return FullSet;

  BinaryOperator *BO;
  Value *Start, *Step;
  if (!matchSimpleRecurrence(P, BO, Start, Step))
    return FullSet;

  // If we found a recurrence in reachable code, we must be in a loop. Note
  // that BO might be in some subloop of L, and that's completely okay.
  auto *L = LI.getLoopFor(P->getParent());
  assert(L && L->getHeader() == P->getParent());
  if (!L->contains(BO->getParent()))
    // NOTE: This bailout should be an assert instead. However, asserting
    // the condition here exposes a case where LoopFusion is querying SCEV
    // with malformed loop information during the midst of the transform.
    // There doesn't appear to be an obvious fix, so for the moment bailout
    // until the caller issue can be fixed. PR49566 tracks the bug.
    return FullSet;

  // TODO: Extend to other opcodes such as mul, and div
  switch (BO->getOpcode()) {
  default:
    return FullSet;
  case Instruction::AShr:
  case Instruction::LShr:
  case Instruction::Shl:
    break;
  };

  if (BO->getOperand(0) != P)
    // TODO: Handle the power function forms some day.
    return FullSet;

  // Need a constant trip-count bound strictly below the bit width; a shift
  // of BitWidth or more saturates and tells us nothing beyond known bits.
  unsigned TC = getSmallConstantMaxTripCount(L);
  if (!TC || TC >= BitWidth)
    return FullSet;

  auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
  auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
  assert(KnownStart.getBitWidth() == BitWidth &&
         KnownStep.getBitWidth() == BitWidth);

  // Compute total shift amount, being careful of overflow and bitwidths.
  auto MaxShiftAmt = KnownStep.getMaxValue();
  // TC-1 shifts are applied: the start value itself is iteration 0.
  APInt TCAP(BitWidth, TC-1);
  bool Overflow = false;
  auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
  if (Overflow)
    return FullSet;

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("filtered out above");
  case Instruction::AShr: {
    // For each ashr, three cases:
    //   shift = 0   => unchanged value
    //   saturation  => 0 or -1
    //   other       => a value closer to zero (of the same sign)
    // Thus, the end value is closer to zero than the start.
    auto KnownEnd = KnownBits::ashr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    if (KnownStart.isNonNegative())
      // Analogous to lshr (simply not yet canonicalized)
      return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                        KnownStart.getMaxValue() + 1);
    if (KnownStart.isNegative())
      // End >=u Start && End <=s Start
      return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
                                        KnownEnd.getMaxValue() + 1);
    break;
  }
  case Instruction::LShr: {
    // For each lshr, three cases:
    //   shift = 0   => unchanged value
    //   saturation  => 0
    //   other       => a smaller positive number
    // Thus, the low end of the unsigned range is the last value produced.
    auto KnownEnd = KnownBits::lshr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                      KnownStart.getMaxValue() + 1);
  }
  case Instruction::Shl: {
    // Iff no bits are shifted out, value increases on every shift.
    auto KnownEnd = KnownBits::shl(KnownStart,
                                   KnownBits::makeConstant(TotalShift));
    if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
      return ConstantRange(KnownStart.getMinValue(),
                           KnownEnd.getMaxValue() + 1);
    break;
  }
  };
  return FullSet;
}
6765
// The goal of this function is to check if recursively visiting the operands
// of this PHI might lead to an infinite loop. If we do see such a loop,
// there's no good way to break it, so we avoid analyzing such cases.
//
// getRangeRef previously used a visited set to avoid infinite loops, but this
// caused other issues: the result was dependent on the order of getRangeRef
// calls, and the interaction with createSCEVIter could cause a stack overflow
// in some cases (see issue #148253).
//
// FIXME: The way this is implemented is overly conservative; this checks
// for a few obviously safe patterns, but anything that doesn't lead to
// recursion is fine.
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  // NOTE(review): the function header and the pattern-match condition that
  // guards this early return appear to be missing from this copy of the
  // file -- restore them from upstream before building.
    return true;

  // All incoming values are defined above the PHI (they dominate it), so in
  // particular none of them comes from a back edge through this PHI.
  if (all_of(PHI->operands(),
             [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
    return true;

  return false;
}
6789
6790const ConstantRange &
6791ScalarEvolution::getRangeRefIter(const SCEV *S,
6792 ScalarEvolution::RangeSignHint SignHint) {
6793 DenseMap<const SCEV *, ConstantRange> &Cache =
6794 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6795 : SignedRanges;
6796 SmallVector<SCEVUse> WorkList;
6797 SmallPtrSet<const SCEV *, 8> Seen;
6798
6799 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6800 // SCEVUnknown PHI node.
6801 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6802 if (!Seen.insert(Expr).second)
6803 return;
6804 if (Cache.contains(Expr))
6805 return;
6806 switch (Expr->getSCEVType()) {
6807 case scUnknown:
6808 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6809 break;
6810 [[fallthrough]];
6811 case scConstant:
6812 case scVScale:
6813 case scTruncate:
6814 case scZeroExtend:
6815 case scSignExtend:
6816 case scPtrToAddr:
6817 case scPtrToInt:
6818 case scAddExpr:
6819 case scMulExpr:
6820 case scUDivExpr:
6821 case scAddRecExpr:
6822 case scUMaxExpr:
6823 case scSMaxExpr:
6824 case scUMinExpr:
6825 case scSMinExpr:
6827 WorkList.push_back(Expr);
6828 break;
6829 case scCouldNotCompute:
6830 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6831 }
6832 };
6833 AddToWorklist(S);
6834
6835 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6836 for (unsigned I = 0; I != WorkList.size(); ++I) {
6837 const SCEV *P = WorkList[I];
6838 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6839 // If it is not a `SCEVUnknown`, just recurse into operands.
6840 if (!UnknownS) {
6841 for (const SCEV *Op : P->operands())
6842 AddToWorklist(Op);
6843 continue;
6844 }
6845 // `SCEVUnknown`'s require special treatment.
6846 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6847 if (!RangeRefPHIAllowedOperands(DT, P))
6848 continue;
6849 for (auto &Op : reverse(P->operands()))
6850 AddToWorklist(getSCEV(Op));
6851 }
6852 }
6853
6854 if (!WorkList.empty()) {
6855 // Use getRangeRef to compute ranges for items in the worklist in reverse
6856 // order. This will force ranges for earlier operands to be computed before
6857 // their users in most cases.
6858 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6859 getRangeRef(P, SignHint);
6860 }
6861 }
6862
6863 return getRangeRef(S, SignHint, 0);
6864}
6865
6866/// Determine the range for a particular SCEV. If SignHint is
6867/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6868/// with a "cleaner" unsigned (resp. signed) representation.
6869const ConstantRange &ScalarEvolution::getRangeRef(
6870 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6871 DenseMap<const SCEV *, ConstantRange> &Cache =
6872 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6873 : SignedRanges;
6875 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6877
6878 // See if we've computed this range already.
6880 if (I != Cache.end())
6881 return I->second;
6882
6883 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6884 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6885
6886 // Switch to iteratively computing the range for S, if it is part of a deeply
6887 // nested expression.
6889 return getRangeRefIter(S, SignHint);
6890
6891 unsigned BitWidth = getTypeSizeInBits(S->getType());
6892 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6893 using OBO = OverflowingBinaryOperator;
6894
6895 // If the value has known zeros, the maximum value will have those known zeros
6896 // as well.
6897 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6898 APInt Multiple = getNonZeroConstantMultiple(S);
6899 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6900 if (!Remainder.isZero())
6901 ConservativeResult =
6902 ConstantRange(APInt::getMinValue(BitWidth),
6903 APInt::getMaxValue(BitWidth) - Remainder + 1);
6904 }
6905 else {
6906 uint32_t TZ = getMinTrailingZeros(S);
6907 if (TZ != 0) {
6908 ConservativeResult = ConstantRange(
6910 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6911 }
6912 }
6913
6914 switch (S->getSCEVType()) {
6915 case scConstant:
6916 llvm_unreachable("Already handled above.");
6917 case scVScale:
6918 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6919 case scTruncate: {
6920 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6921 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6922 return setRange(
6923 Trunc, SignHint,
6924 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6925 }
6926 case scZeroExtend: {
6927 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6928 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6929 return setRange(
6930 ZExt, SignHint,
6931 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6932 }
6933 case scSignExtend: {
6934 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6935 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6936 return setRange(
6937 SExt, SignHint,
6938 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6939 }
6940 case scPtrToAddr:
6941 case scPtrToInt: {
6942 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6943 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6944 return setRange(Cast, SignHint, X);
6945 }
6946 case scAddExpr: {
6947 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6948 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6949 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6950 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6951 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6952 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6953 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6954 ConservativeResult =
6955 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6956 }
6957 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6958 unsigned WrapType = OBO::AnyWrap;
6959 if (Add->hasNoSignedWrap())
6960 WrapType |= OBO::NoSignedWrap;
6961 if (Add->hasNoUnsignedWrap())
6962 WrapType |= OBO::NoUnsignedWrap;
6963 for (const SCEV *Op : drop_begin(Add->operands()))
6964 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6965 RangeType);
6966 return setRange(Add, SignHint,
6967 ConservativeResult.intersectWith(X, RangeType));
6968 }
6969 case scMulExpr: {
6970 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6971 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6972 for (const SCEV *Op : drop_begin(Mul->operands()))
6973 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6974 return setRange(Mul, SignHint,
6975 ConservativeResult.intersectWith(X, RangeType));
6976 }
6977 case scUDivExpr: {
6978 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6979 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6980 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6981 return setRange(UDiv, SignHint,
6982 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6983 }
6984 case scAddRecExpr: {
6985 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6986 // If there's no unsigned wrap, the value will never be less than its
6987 // initial value.
6988 if (AddRec->hasNoUnsignedWrap()) {
6989 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6990 if (!UnsignedMinValue.isZero())
6991 ConservativeResult = ConservativeResult.intersectWith(
6992 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6993 }
6994
6995 // If there's no signed wrap, and all the operands except initial value have
6996 // the same sign or zero, the value won't ever be:
6997 // 1: smaller than initial value if operands are non negative,
6998 // 2: bigger than initial value if operands are non positive.
6999 // For both cases, value can not cross signed min/max boundary.
7000 if (AddRec->hasNoSignedWrap()) {
7001 bool AllNonNeg = true;
7002 bool AllNonPos = true;
7003 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
7004 if (!isKnownNonNegative(AddRec->getOperand(i)))
7005 AllNonNeg = false;
7006 if (!isKnownNonPositive(AddRec->getOperand(i)))
7007 AllNonPos = false;
7008 }
7009 if (AllNonNeg)
7010 ConservativeResult = ConservativeResult.intersectWith(
7013 RangeType);
7014 else if (AllNonPos)
7015 ConservativeResult = ConservativeResult.intersectWith(
7017 getSignedRangeMax(AddRec->getStart()) +
7018 1),
7019 RangeType);
7020 }
7021
7022 // TODO: non-affine addrec
7023 if (AddRec->isAffine()) {
7024 const SCEV *MaxBEScev =
7026 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7027 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7028
7029 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7030 // MaxBECount's active bits are all <= AddRec's bit width.
7031 if (MaxBECount.getBitWidth() > BitWidth &&
7032 MaxBECount.getActiveBits() <= BitWidth)
7033 MaxBECount = MaxBECount.trunc(BitWidth);
7034 else if (MaxBECount.getBitWidth() < BitWidth)
7035 MaxBECount = MaxBECount.zext(BitWidth);
7036
7037 if (MaxBECount.getBitWidth() == BitWidth) {
7038 auto RangeFromAffine = getRangeForAffineAR(
7039 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7040 ConservativeResult =
7041 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7042
7043 auto RangeFromFactoring = getRangeViaFactoring(
7044 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7045 ConservativeResult =
7046 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7047 }
7048 }
7049
7050 // Now try symbolic BE count and more powerful methods.
7052 const SCEV *SymbolicMaxBECount =
7054 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7055 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7056 AddRec->hasNoSelfWrap()) {
7057 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7058 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7059 ConservativeResult =
7060 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7061 }
7062 }
7063 }
7064
7065 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7066 }
7067 case scUMaxExpr:
7068 case scSMaxExpr:
7069 case scUMinExpr:
7070 case scSMinExpr:
7071 case scSequentialUMinExpr: {
7073 switch (S->getSCEVType()) {
7074 case scUMaxExpr:
7075 ID = Intrinsic::umax;
7076 break;
7077 case scSMaxExpr:
7078 ID = Intrinsic::smax;
7079 break;
7080 case scUMinExpr:
7082 ID = Intrinsic::umin;
7083 break;
7084 case scSMinExpr:
7085 ID = Intrinsic::smin;
7086 break;
7087 default:
7088 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7089 }
7090
7091 const auto *NAry = cast<SCEVNAryExpr>(S);
7092 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7093 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7094 X = X.intrinsic(
7095 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7096 return setRange(S, SignHint,
7097 ConservativeResult.intersectWith(X, RangeType));
7098 }
7099 case scUnknown: {
7100 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7101 Value *V = U->getValue();
7102
7103 // Check if the IR explicitly contains !range metadata.
7104 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7105 if (MDRange)
7106 ConservativeResult =
7107 ConservativeResult.intersectWith(*MDRange, RangeType);
7108
7109 // Use facts about recurrences in the underlying IR. Note that add
7110 // recurrences are AddRecExprs and thus don't hit this path. This
7111 // primarily handles shift recurrences.
7112 auto CR = getRangeForUnknownRecurrence(U);
7113 ConservativeResult = ConservativeResult.intersectWith(CR);
7114
7115 // See if ValueTracking can give us a useful range.
7116 const DataLayout &DL = getDataLayout();
7117 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7118 if (Known.getBitWidth() != BitWidth)
7119 Known = Known.zextOrTrunc(BitWidth);
7120
7121 // ValueTracking may be able to compute a tighter result for the number of
7122 // sign bits than for the value of those sign bits.
7123 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7124 if (U->getType()->isPointerTy()) {
7125 // If the pointer size is larger than the index size type, this can cause
7126 // NS to be larger than BitWidth. So compensate for this.
7127 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7128 int ptrIdxDiff = ptrSize - BitWidth;
7129 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7130 NS -= ptrIdxDiff;
7131 }
7132
7133 if (NS > 1) {
7134 // If we know any of the sign bits, we know all of the sign bits.
7135 if (!Known.Zero.getHiBits(NS).isZero())
7136 Known.Zero.setHighBits(NS);
7137 if (!Known.One.getHiBits(NS).isZero())
7138 Known.One.setHighBits(NS);
7139 }
7140
7141 if (Known.getMinValue() != Known.getMaxValue() + 1)
7142 ConservativeResult = ConservativeResult.intersectWith(
7143 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7144 RangeType);
7145 if (NS > 1)
7146 ConservativeResult = ConservativeResult.intersectWith(
7147 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7148 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7149 RangeType);
7150
7151 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7152 // Strengthen the range if the underlying IR value is a
7153 // global/alloca/heap allocation using the size of the object.
7154 bool CanBeNull, CanBeFreed;
7155 uint64_t DerefBytes =
7156 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7157 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7158 // The highest address the object can start is DerefBytes bytes before
7159 // the end (unsigned max value). If this value is not a multiple of the
7160 // alignment, the last possible start value is the next lowest multiple
7161 // of the alignment. Note: The computations below cannot overflow,
7162 // because if they would there's no possible start address for the
7163 // object.
7164 APInt MaxVal =
7165 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7166 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7167 uint64_t Rem = MaxVal.urem(Align);
7168 MaxVal -= APInt(BitWidth, Rem);
7169 APInt MinVal = APInt::getZero(BitWidth);
7170 if (llvm::isKnownNonZero(V, DL))
7171 MinVal = Align;
7172 ConservativeResult = ConservativeResult.intersectWith(
7173 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7174 }
7175 }
7176
7177 // A range of Phi is a subset of union of all ranges of its input.
7178 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7179 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7180 // AddRecs; return the range for the corresponding AddRec.
7181 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7182 return getRangeRef(AR, SignHint, Depth + 1);
7183
7184 // Make sure that we do not run over cycled Phis.
7185 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7186 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7187
7188 for (const auto &Op : Phi->operands()) {
7189 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7190 RangeFromOps = RangeFromOps.unionWith(OpRange);
7191 // No point to continue if we already have a full set.
7192 if (RangeFromOps.isFullSet())
7193 break;
7194 }
7195 ConservativeResult =
7196 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7197 }
7198 }
7199
7200 // vscale can't be equal to zero
7201 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7202 if (II->getIntrinsicID() == Intrinsic::vscale) {
7203 ConstantRange Disallowed = APInt::getZero(BitWidth);
7204 ConservativeResult = ConservativeResult.difference(Disallowed);
7205 }
7206
7207 return setRange(U, SignHint, std::move(ConservativeResult));
7208 }
7209 case scCouldNotCompute:
7210 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7211 }
7212
7213 return setRange(S, SignHint, std::move(ConservativeResult));
7214}
7215
7216// Given a StartRange, Step and MaxBECount for an expression compute a range of
7217// values that the expression can take. Initially, the expression has a value
7218// from StartRange and then is changed by Step up to MaxBECount times. Signed
7219// argument defines if we treat Step as signed or unsigned.
7221 const ConstantRange &StartRange,
7222 const APInt &MaxBECount,
7223 bool Signed) {
7224 unsigned BitWidth = Step.getBitWidth();
7225 assert(BitWidth == StartRange.getBitWidth() &&
7226 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7227 // If either Step or MaxBECount is 0, then the expression won't change, and we
7228 // just need to return the initial range.
7229 if (Step == 0 || MaxBECount == 0)
7230 return StartRange;
7231
7232 // If we don't know anything about the initial value (i.e. StartRange is
7233 // FullRange), then we don't know anything about the final range either.
7234 // Return FullRange.
7235 if (StartRange.isFullSet())
7236 return ConstantRange::getFull(BitWidth);
7237
7238 // If Step is signed and negative, then we use its absolute value, but we also
7239 // note that we're moving in the opposite direction.
7240 bool Descending = Signed && Step.isNegative();
7241
7242 if (Signed)
7243 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7244 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7245 // This equations hold true due to the well-defined wrap-around behavior of
7246 // APInt.
7247 Step = Step.abs();
7248
7249 // Check if Offset is more than full span of BitWidth. If it is, the
7250 // expression is guaranteed to overflow.
7251 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7252 return ConstantRange::getFull(BitWidth);
7253
7254 // Offset is by how much the expression can change. Checks above guarantee no
7255 // overflow here.
7256 APInt Offset = Step * MaxBECount;
7257
7258 // Minimum value of the final range will match the minimal value of StartRange
7259 // if the expression is increasing and will be decreased by Offset otherwise.
7260 // Maximum value of the final range will match the maximal value of StartRange
7261 // if the expression is decreasing and will be increased by Offset otherwise.
7262 APInt StartLower = StartRange.getLower();
7263 APInt StartUpper = StartRange.getUpper() - 1;
7264 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7265 : (StartUpper + std::move(Offset));
7266
7267 // It's possible that the new minimum/maximum value will fall into the initial
7268 // range (due to wrap around). This means that the expression can take any
7269 // value in this bitwidth, and we have to return full range.
7270 if (StartRange.contains(MovedBoundary))
7271 return ConstantRange::getFull(BitWidth);
7272
7273 APInt NewLower =
7274 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7275 APInt NewUpper =
7276 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7277 NewUpper += 1;
7278
7279 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7280 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7281}
7282
7283ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7284 const SCEV *Step,
7285 const APInt &MaxBECount) {
7286 assert(getTypeSizeInBits(Start->getType()) ==
7287 getTypeSizeInBits(Step->getType()) &&
7288 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7289 "mismatched bit widths");
7290
7291 // First, consider step signed.
7292 ConstantRange StartSRange = getSignedRange(Start);
7293 ConstantRange StepSRange = getSignedRange(Step);
7294
7295 // If Step can be both positive and negative, we need to find ranges for the
7296 // maximum absolute step values in both directions and union them.
7297 ConstantRange SR = getRangeForAffineARHelper(
7298 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7300 StartSRange, MaxBECount,
7301 /* Signed = */ true));
7302
7303 // Next, consider step unsigned.
7304 ConstantRange UR = getRangeForAffineARHelper(
7305 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7306 /* Signed = */ false);
7307
7308 // Finally, intersect signed and unsigned ranges.
7310}
7311
7312ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7313 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7314 ScalarEvolution::RangeSignHint SignHint) {
7315 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7316 assert(AddRec->hasNoSelfWrap() &&
7317 "This only works for non-self-wrapping AddRecs!");
7318 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7319 const SCEV *Step = AddRec->getStepRecurrence(*this);
7320 // Only deal with constant step to save compile time.
7321 if (!isa<SCEVConstant>(Step))
7322 return ConstantRange::getFull(BitWidth);
7323 // Let's make sure that we can prove that we do not self-wrap during
7324 // MaxBECount iterations. We need this because MaxBECount is a maximum
7325 // iteration count estimate, and we might infer nw from some exit for which we
7326 // do not know max exit count (or any other side reasoning).
7327 // TODO: Turn into assert at some point.
7328 if (getTypeSizeInBits(MaxBECount->getType()) >
7329 getTypeSizeInBits(AddRec->getType()))
7330 return ConstantRange::getFull(BitWidth);
7331 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7332 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7333 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7334 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7335 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7336 MaxItersWithoutWrap))
7337 return ConstantRange::getFull(BitWidth);
7338
7339 ICmpInst::Predicate LEPred =
7341 ICmpInst::Predicate GEPred =
7343 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7344
7345 // We know that there is no self-wrap. Let's take Start and End values and
7346 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7347 // the iteration. They either lie inside the range [Min(Start, End),
7348 // Max(Start, End)] or outside it:
7349 //
7350 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7351 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7352 //
7353 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7354 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7355 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7356 // Start <= End and step is positive, or Start >= End and step is negative.
7357 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7358 ConstantRange StartRange = getRangeRef(Start, SignHint);
7359 ConstantRange EndRange = getRangeRef(End, SignHint);
7360 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7361 // If they already cover full iteration space, we will know nothing useful
7362 // even if we prove what we want to prove.
7363 if (RangeBetween.isFullSet())
7364 return RangeBetween;
7365 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7366 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7367 : RangeBetween.isWrappedSet();
7368 if (IsWrappedSet)
7369 return ConstantRange::getFull(BitWidth);
7370
7371 if (isKnownPositive(Step) &&
7372 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7373 return RangeBetween;
7374 if (isKnownNegative(Step) &&
7375 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7376 return RangeBetween;
7377 return ConstantRange::getFull(BitWidth);
7378}
7379
7380ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7381 const SCEV *Step,
7382 const APInt &MaxBECount) {
7383 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7384 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7385
7386 unsigned BitWidth = MaxBECount.getBitWidth();
7387 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7388 getTypeSizeInBits(Step->getType()) == BitWidth &&
7389 "mismatched bit widths");
7390
7391 struct SelectPattern {
7392 Value *Condition = nullptr;
7393 APInt TrueValue;
7394 APInt FalseValue;
7395
7396 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7397 const SCEV *S) {
7398 std::optional<unsigned> CastOp;
7399 APInt Offset(BitWidth, 0);
7400
7402 "Should be!");
7403
7404 // Peel off a constant offset. In the future we could consider being
7405 // smarter here and handle {Start+Step,+,Step} too.
7406 const APInt *Off;
7407 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7408 Offset = *Off;
7409
7410 // Peel off a cast operation
7411 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7412 CastOp = SCast->getSCEVType();
7413 S = SCast->getOperand();
7414 }
7415
7416 using namespace llvm::PatternMatch;
7417
7418 auto *SU = dyn_cast<SCEVUnknown>(S);
7419 const APInt *TrueVal, *FalseVal;
7420 if (!SU ||
7421 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7422 m_APInt(FalseVal)))) {
7423 Condition = nullptr;
7424 return;
7425 }
7426
7427 TrueValue = *TrueVal;
7428 FalseValue = *FalseVal;
7429
7430 // Re-apply the cast we peeled off earlier
7431 if (CastOp)
7432 switch (*CastOp) {
7433 default:
7434 llvm_unreachable("Unknown SCEV cast type!");
7435
7436 case scTruncate:
7437 TrueValue = TrueValue.trunc(BitWidth);
7438 FalseValue = FalseValue.trunc(BitWidth);
7439 break;
7440 case scZeroExtend:
7441 TrueValue = TrueValue.zext(BitWidth);
7442 FalseValue = FalseValue.zext(BitWidth);
7443 break;
7444 case scSignExtend:
7445 TrueValue = TrueValue.sext(BitWidth);
7446 FalseValue = FalseValue.sext(BitWidth);
7447 break;
7448 }
7449
7450 // Re-apply the constant offset we peeled off earlier
7451 TrueValue += Offset;
7452 FalseValue += Offset;
7453 }
7454
7455 bool isRecognized() { return Condition != nullptr; }
7456 };
7457
7458 SelectPattern StartPattern(*this, BitWidth, Start);
7459 if (!StartPattern.isRecognized())
7460 return ConstantRange::getFull(BitWidth);
7461
7462 SelectPattern StepPattern(*this, BitWidth, Step);
7463 if (!StepPattern.isRecognized())
7464 return ConstantRange::getFull(BitWidth);
7465
7466 if (StartPattern.Condition != StepPattern.Condition) {
7467 // We don't handle this case today; but we could, by considering four
7468 // possibilities below instead of two. I'm not sure if there are cases where
7469 // that will help over what getRange already does, though.
7470 return ConstantRange::getFull(BitWidth);
7471 }
7472
7473 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7474 // construct arbitrary general SCEV expressions here. This function is called
7475 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7476 // say) can end up caching a suboptimal value.
7477
7478 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7479 // C2352 and C2512 (otherwise it isn't needed).
7480
7481 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7482 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7483 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7484 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7485
7486 ConstantRange TrueRange =
7487 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7488 ConstantRange FalseRange =
7489 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7490
7491 return TrueRange.unionWith(FalseRange);
7492}
7493
7494SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7495 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7496 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7497
7498 // Return early if there are no flags to propagate to the SCEV.
7500 if (BinOp->hasNoUnsignedWrap())
7502 if (BinOp->hasNoSignedWrap())
7504 if (Flags == SCEV::FlagAnyWrap)
7505 return SCEV::FlagAnyWrap;
7506
7507 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7508}
7509
7510const Instruction *
7511ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7512 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7513 return &*AddRec->getLoop()->getHeader()->begin();
7514 if (auto *U = dyn_cast<SCEVUnknown>(S))
7515 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7516 return I;
7517 return nullptr;
7518}
7519
7520const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7521 bool &Precise) {
7522 Precise = true;
7523 // Do a bounded search of the def relation of the requested SCEVs.
7524 SmallPtrSet<const SCEV *, 16> Visited;
7525 SmallVector<SCEVUse> Worklist;
7526 auto pushOp = [&](const SCEV *S) {
7527 if (!Visited.insert(S).second)
7528 return;
7529 // Threshold of 30 here is arbitrary.
7530 if (Visited.size() > 30) {
7531 Precise = false;
7532 return;
7533 }
7534 Worklist.push_back(S);
7535 };
7536
7537 for (SCEVUse S : Ops)
7538 pushOp(S);
7539
7540 const Instruction *Bound = nullptr;
7541 while (!Worklist.empty()) {
7542 SCEVUse S = Worklist.pop_back_val();
7543 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7544 if (!Bound || DT.dominates(Bound, DefI))
7545 Bound = DefI;
7546 } else {
7547 for (SCEVUse Op : S->operands())
7548 pushOp(Op);
7549 }
7550 }
7551 return Bound ? Bound : &*F.getEntryBlock().begin();
7552}
7553
7554const Instruction *
7555ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7556 bool Discard;
7557 return getDefiningScopeBound(Ops, Discard);
7558}
7559
7560bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7561 const Instruction *B) {
7562 if (A->getParent() == B->getParent() &&
7564 B->getIterator()))
7565 return true;
7566
7567 auto *BLoop = LI.getLoopFor(B->getParent());
7568 if (BLoop && BLoop->getHeader() == B->getParent() &&
7569 BLoop->getLoopPreheader() == A->getParent() &&
7571 A->getParent()->end()) &&
7572 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7573 B->getIterator()))
7574 return true;
7575 return false;
7576}
7577
7578bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7579 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7580 visitAll(Op, PC);
7581 return PC.MaybePoison.empty();
7582}
7583
7584bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7585 return !SCEVExprContains(Op, [this](const SCEV *S) {
7586 const SCEV *Op1;
7587 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7588 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7589 // is a non-zero constant, we have to assume the UDiv may be UB.
7590 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7591 });
7592}
7593
7594bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7595 // Only proceed if we can prove that I does not yield poison.
7597 return false;
7598
7599 // At this point we know that if I is executed, then it does not wrap
7600 // according to at least one of NSW or NUW. If I is not executed, then we do
7601 // not know if the calculation that I represents would wrap. Multiple
7602 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7603 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7604 // derived from other instructions that map to the same SCEV. We cannot make
7605 // that guarantee for cases where I is not executed. So we need to find a
7606 // upper bound on the defining scope for the SCEV, and prove that I is
7607 // executed every time we enter that scope. When the bounding scope is a
7608 // loop (the common case), this is equivalent to proving I executes on every
7609 // iteration of that loop.
7610 SmallVector<SCEVUse> SCEVOps;
7611 for (const Use &Op : I->operands()) {
7612 // I could be an extractvalue from a call to an overflow intrinsic.
7613 // TODO: We can do better here in some cases.
7614 if (isSCEVable(Op->getType()))
7615 SCEVOps.push_back(getSCEV(Op));
7616 }
7617 auto *DefI = getDefiningScopeBound(SCEVOps);
7618 return isGuaranteedToTransferExecutionTo(DefI, I);
7619}
7620
7621bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7622 // If we know that \c I can never be poison period, then that's enough.
7623 if (isSCEVExprNeverPoison(I))
7624 return true;
7625
7626 // If the loop only has one exit, then we know that, if the loop is entered,
7627 // any instruction dominating that exit will be executed. If any such
7628 // instruction would result in UB, the addrec cannot be poison.
7629 //
7630 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7631 // also handles uses outside the loop header (they just need to dominate the
7632 // single exit).
7633
7634 auto *ExitingBB = L->getExitingBlock();
7635 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7636 return false;
7637
7638 SmallPtrSet<const Value *, 16> KnownPoison;
7640
7641 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7642 // things that are known to be poison under that assumption go on the
7643 // Worklist.
7644 KnownPoison.insert(I);
7645 Worklist.push_back(I);
7646
7647 while (!Worklist.empty()) {
7648 const Instruction *Poison = Worklist.pop_back_val();
7649
7650 for (const Use &U : Poison->uses()) {
7651 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7652 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7653 DT.dominates(PoisonUser->getParent(), ExitingBB))
7654 return true;
7655
7656 if (propagatesPoison(U) && L->contains(PoisonUser))
7657 if (KnownPoison.insert(PoisonUser).second)
7658 Worklist.push_back(PoisonUser);
7659 }
7660 }
7661
7662 return false;
7663}
7664
7665ScalarEvolution::LoopProperties
7666ScalarEvolution::getLoopProperties(const Loop *L) {
7667 using LoopProperties = ScalarEvolution::LoopProperties;
7668
7669 auto Itr = LoopPropertiesCache.find(L);
7670 if (Itr == LoopPropertiesCache.end()) {
7671 auto HasSideEffects = [](Instruction *I) {
7672 if (auto *SI = dyn_cast<StoreInst>(I))
7673 return !SI->isSimple();
7674
7675 if (I->mayThrow())
7676 return true;
7677
7678 // Non-volatile memset / memcpy do not count as side-effect for forward
7679 // progress.
7680 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7681 return false;
7682
7683 return I->mayWriteToMemory();
7684 };
7685
7686 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7687 /*HasNoSideEffects*/ true};
7688
7689 for (auto *BB : L->getBlocks())
7690 for (auto &I : *BB) {
7692 LP.HasNoAbnormalExits = false;
7693 if (HasSideEffects(&I))
7694 LP.HasNoSideEffects = false;
7695 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7696 break; // We're already as pessimistic as we can get.
7697 }
7698
7699 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7700 assert(InsertPair.second && "We just checked!");
7701 Itr = InsertPair.first;
7702 }
7703
7704 return Itr->second;
7705}
7706
7708 // A mustprogress loop without side effects must be finite.
7709 // TODO: The check used here is very conservative. It's only *specific*
7710 // side effects which are well defined in infinite loops.
7711 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7712}
7713
7714const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7715 // Worklist item with a Value and a bool indicating whether all operands have
7716 // been visited already.
7719
7720 Stack.emplace_back(V, true);
7721 Stack.emplace_back(V, false);
7722 while (!Stack.empty()) {
7723 auto E = Stack.pop_back_val();
7724 Value *CurV = E.getPointer();
7725
7726 if (getExistingSCEV(CurV))
7727 continue;
7728
7730 const SCEV *CreatedSCEV = nullptr;
7731 // If all operands have been visited already, create the SCEV.
7732 if (E.getInt()) {
7733 CreatedSCEV = createSCEV(CurV);
7734 } else {
7735 // Otherwise get the operands we need to create SCEV's for before creating
7736 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7737 // just use it.
7738 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7739 }
7740
7741 if (CreatedSCEV) {
7742 insertValueToMap(CurV, CreatedSCEV);
7743 } else {
7744 // Queue CurV for SCEV creation, followed by its's operands which need to
7745 // be constructed first.
7746 Stack.emplace_back(CurV, true);
7747 for (Value *Op : Ops)
7748 Stack.emplace_back(Op, false);
7749 }
7750 }
7751
7752 return getExistingSCEV(V);
7753}
7754
// Collect into Ops the operand Values whose SCEVs must be created before a
// SCEV for V can be built (returning nullptr), or directly return a trivial
// SCEV for V (a constant, or a SCEVUnknown for unanalyzable values).
// NOTE(review): this listing was extracted from a rendered source view; the
// leading numbers on each line are the renderer's line numbers, and a few
// original lines (e.g. continuations of the MatchBinaryOp calls and two
// condition lines in the PHI/Select handling) were dropped by the extraction.
// Code below is kept verbatim.
7755 const SCEV *
7756 ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7757 if (!isSCEVable(V->getType()))
7758 return getUnknown(V);
7759
7760 if (Instruction *I = dyn_cast<Instruction>(V)) {
7761 // Don't attempt to analyze instructions in blocks that aren't
7762 // reachable. Such instructions don't matter, and they aren't required
7763 // to obey basic rules for definitions dominating uses which this
7764 // analysis depends on.
7765 if (!DT.isReachableFromEntry(I->getParent()))
7766 return getUnknown(PoisonValue::get(V->getType()));
7767 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7768 return getConstant(CI);
7769 else if (isa<GlobalAlias>(V))
7770 return getUnknown(V);
7771 else if (!isa<ConstantExpr>(V))
7772 return getUnknown(V);
7773
7775 if (auto BO =
7777 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7778 switch (BO->Opcode) {
7779 case Instruction::Add:
7780 case Instruction::Mul: {
7781 // For additions and multiplications, traverse add/mul chains for which we
7782 // can potentially create a single SCEV, to reduce the number of
7783 // get{Add,Mul}Expr calls.
7784 do {
7785 if (BO->Op) {
7786 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7787 Ops.push_back(BO->Op);
7788 break;
7789 }
7790 }
7791 Ops.push_back(BO->RHS);
7792 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7794 if (!NewBO ||
7795 (BO->Opcode == Instruction::Add &&
7796 (NewBO->Opcode != Instruction::Add &&
7797 NewBO->Opcode != Instruction::Sub)) ||
7798 (BO->Opcode == Instruction::Mul &&
7799 NewBO->Opcode != Instruction::Mul)) {
7800 Ops.push_back(BO->LHS);
7801 break;
7802 }
7803 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7804 // requires a SCEV for the LHS.
7805 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7806 auto *I = dyn_cast<Instruction>(BO->Op);
7807 if (I && programUndefinedIfPoison(I)) {
7808 Ops.push_back(BO->LHS);
7809 break;
7810 }
7811 }
7812 BO = NewBO;
7813 } while (true);
7814 return nullptr;
7815 }
7816 case Instruction::Sub:
7817 case Instruction::UDiv:
7818 case Instruction::URem:
7819 break;
7820 case Instruction::AShr:
7821 case Instruction::Shl:
7822 case Instruction::Xor:
7823 if (!IsConstArg)
7824 return nullptr;
7825 break;
7826 case Instruction::And:
7827 case Instruction::Or:
7828 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7829 return nullptr;
7830 break;
7831 case Instruction::LShr:
7832 return getUnknown(V);
7833 default:
7834 llvm_unreachable("Unhandled binop");
7835 break;
7836 }
7837
7838 Ops.push_back(BO->LHS);
7839 Ops.push_back(BO->RHS);
7840 return nullptr;
7841 }
7842
// Non-binop instructions/operators: decide per-opcode which operands feed the
// SCEV and which values are simply opaque (SCEVUnknown).
7843 switch (U->getOpcode()) {
7844 case Instruction::Trunc:
7845 case Instruction::ZExt:
7846 case Instruction::SExt:
7847 case Instruction::PtrToAddr:
7848 case Instruction::PtrToInt:
7849 Ops.push_back(U->getOperand(0));
7850 return nullptr;
7851
7852 case Instruction::BitCast:
7853 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7854 Ops.push_back(U->getOperand(0));
7855 return nullptr;
7856 }
7857 return getUnknown(V);
7858
7859 case Instruction::SDiv:
7860 case Instruction::SRem:
7861 Ops.push_back(U->getOperand(0));
7862 Ops.push_back(U->getOperand(1));
7863 return nullptr;
7864
7865 case Instruction::GetElementPtr:
7866 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7867 "GEP source element type must be sized");
7868 llvm::append_range(Ops, U->operands());
7869 return nullptr;
7870
7871 case Instruction::IntToPtr:
7872 return getUnknown(V);
7873
7874 case Instruction::PHI:
7875 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7876 // relevant nodes for each of them.
7877 //
7878 // The first is just to call simplifyInstruction, and get something back
7879 // that isn't a PHI.
7880 if (Value *V = simplifyInstruction(
7881 cast<PHINode>(U),
7882 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7883 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7884 assert(V);
7885 Ops.push_back(V);
7886 return nullptr;
7887 }
7888 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7889 // operands which all perform the same operation, but haven't been
7890 // CSE'ed for whatever reason.
7891 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7892 assert(BO);
7893 Ops.push_back(BO);
7894 return nullptr;
7895 }
7896 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7897 // is equivalent to a select, and analyzes it like a select.
7898 {
7899 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7901 assert(Cond);
7902 assert(LHS);
7903 assert(RHS);
7904 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7905 Ops.push_back(CondICmp->getOperand(0));
7906 Ops.push_back(CondICmp->getOperand(1));
7907 }
7908 Ops.push_back(Cond);
7909 Ops.push_back(LHS);
7910 Ops.push_back(RHS);
7911 return nullptr;
7912 }
7913 }
7914 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7915 // so just construct it recursively.
7916 //
7917 // In addition to getNodeForPHI, also construct nodes which might be needed
7918 // by getRangeRef.
7920 for (Value *V : cast<PHINode>(U)->operands())
7921 Ops.push_back(V);
7922 return nullptr;
7923 }
7924 return nullptr;
7925
7926 case Instruction::Select: {
7927 // Check if U is a select that can be simplified to a SCEVUnknown.
7928 auto CanSimplifyToUnknown = [this, U]() {
7929 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7930 return false;
7931
7932 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7933 if (!ICI)
7934 return false;
7935 Value *LHS = ICI->getOperand(0);
7936 Value *RHS = ICI->getOperand(1);
7937 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7938 ICI->getPredicate() == CmpInst::ICMP_NE) {
7940 return true;
7941 } else if (getTypeSizeInBits(LHS->getType()) >
7942 getTypeSizeInBits(U->getType()))
7943 return true;
7944 return false;
7945 };
7946 if (CanSimplifyToUnknown())
7947 return getUnknown(U);
7948
7949 llvm::append_range(Ops, U->operands());
7950 return nullptr;
7951 break;
7952 }
7953 case Instruction::Call:
7954 case Instruction::Invoke:
7955 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7956 Ops.push_back(RV);
7957 return nullptr;
7958 }
7959
7960 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7961 switch (II->getIntrinsicID()) {
7962 case Intrinsic::abs:
7963 Ops.push_back(II->getArgOperand(0));
7964 return nullptr;
7965 case Intrinsic::umax:
7966 case Intrinsic::umin:
7967 case Intrinsic::smax:
7968 case Intrinsic::smin:
7969 case Intrinsic::usub_sat:
7970 case Intrinsic::uadd_sat:
7971 Ops.push_back(II->getArgOperand(0));
7972 Ops.push_back(II->getArgOperand(1));
7973 return nullptr;
7974 case Intrinsic::start_loop_iterations:
7975 case Intrinsic::annotation:
7976 case Intrinsic::ptr_annotation:
7977 Ops.push_back(II->getArgOperand(0));
7978 return nullptr;
7979 default:
7980 break;
7981 }
7982 }
7983 break;
7984 }
7985
7986 return nullptr;
7987 }
7988
7989const SCEV *ScalarEvolution::createSCEV(Value *V) {
7990 if (!isSCEVable(V->getType()))
7991 return getUnknown(V);
7992
7993 if (Instruction *I = dyn_cast<Instruction>(V)) {
7994 // Don't attempt to analyze instructions in blocks that aren't
7995 // reachable. Such instructions don't matter, and they aren't required
7996 // to obey basic rules for definitions dominating uses which this
7997 // analysis depends on.
7998 if (!DT.isReachableFromEntry(I->getParent()))
7999 return getUnknown(PoisonValue::get(V->getType()));
8000 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
8001 return getConstant(CI);
8002 else if (isa<GlobalAlias>(V))
8003 return getUnknown(V);
8004 else if (!isa<ConstantExpr>(V))
8005 return getUnknown(V);
8006
8007 const SCEV *LHS;
8008 const SCEV *RHS;
8009
8011 if (auto BO =
8013 switch (BO->Opcode) {
8014 case Instruction::Add: {
8015 // The simple thing to do would be to just call getSCEV on both operands
8016 // and call getAddExpr with the result. However if we're looking at a
8017 // bunch of things all added together, this can be quite inefficient,
8018 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8019 // Instead, gather up all the operands and make a single getAddExpr call.
8020 // LLVM IR canonical form means we need only traverse the left operands.
8022 do {
8023 if (BO->Op) {
8024 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8025 AddOps.push_back(OpSCEV);
8026 break;
8027 }
8028
8029 // If a NUW or NSW flag can be applied to the SCEV for this
8030 // addition, then compute the SCEV for this addition by itself
8031 // with a separate call to getAddExpr. We need to do that
8032 // instead of pushing the operands of the addition onto AddOps,
8033 // since the flags are only known to apply to this particular
8034 // addition - they may not apply to other additions that can be
8035 // formed with operands from AddOps.
8036 const SCEV *RHS = getSCEV(BO->RHS);
8037 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8038 if (Flags != SCEV::FlagAnyWrap) {
8039 const SCEV *LHS = getSCEV(BO->LHS);
8040 if (BO->Opcode == Instruction::Sub)
8041 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8042 else
8043 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8044 break;
8045 }
8046 }
8047
8048 if (BO->Opcode == Instruction::Sub)
8049 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8050 else
8051 AddOps.push_back(getSCEV(BO->RHS));
8052
8053 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8055 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8056 NewBO->Opcode != Instruction::Sub)) {
8057 AddOps.push_back(getSCEV(BO->LHS));
8058 break;
8059 }
8060 BO = NewBO;
8061 } while (true);
8062
8063 return getAddExpr(AddOps);
8064 }
8065
8066 case Instruction::Mul: {
8068 do {
8069 if (BO->Op) {
8070 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8071 MulOps.push_back(OpSCEV);
8072 break;
8073 }
8074
8075 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8076 if (Flags != SCEV::FlagAnyWrap) {
8077 LHS = getSCEV(BO->LHS);
8078 RHS = getSCEV(BO->RHS);
8079 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8080 break;
8081 }
8082 }
8083
8084 MulOps.push_back(getSCEV(BO->RHS));
8085 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8087 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8088 MulOps.push_back(getSCEV(BO->LHS));
8089 break;
8090 }
8091 BO = NewBO;
8092 } while (true);
8093
8094 return getMulExpr(MulOps);
8095 }
8096 case Instruction::UDiv:
8097 LHS = getSCEV(BO->LHS);
8098 RHS = getSCEV(BO->RHS);
8099 return getUDivExpr(LHS, RHS);
8100 case Instruction::URem:
8101 LHS = getSCEV(BO->LHS);
8102 RHS = getSCEV(BO->RHS);
8103 return getURemExpr(LHS, RHS);
8104 case Instruction::Sub: {
8106 if (BO->Op)
8107 Flags = getNoWrapFlagsFromUB(BO->Op);
8108 LHS = getSCEV(BO->LHS);
8109 RHS = getSCEV(BO->RHS);
8110 return getMinusSCEV(LHS, RHS, Flags);
8111 }
8112 case Instruction::And:
8113 // For an expression like x&255 that merely masks off the high bits,
8114 // use zext(trunc(x)) as the SCEV expression.
8115 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8116 if (CI->isZero())
8117 return getSCEV(BO->RHS);
8118 if (CI->isMinusOne())
8119 return getSCEV(BO->LHS);
8120 const APInt &A = CI->getValue();
8121
8122 // Instcombine's ShrinkDemandedConstant may strip bits out of
8123 // constants, obscuring what would otherwise be a low-bits mask.
8124 // Use computeKnownBits to compute what ShrinkDemandedConstant
8125 // knew about to reconstruct a low-bits mask value.
8126 unsigned LZ = A.countl_zero();
8127 unsigned TZ = A.countr_zero();
8128 unsigned BitWidth = A.getBitWidth();
8129 KnownBits Known(BitWidth);
8130 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8131
8132 APInt EffectiveMask =
8133 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8134 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8135 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8136 const SCEV *LHS = getSCEV(BO->LHS);
8137 const SCEV *ShiftedLHS = nullptr;
8138 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8139 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8140 // For an expression like (x * 8) & 8, simplify the multiply.
8141 unsigned MulZeros = OpC->getAPInt().countr_zero();
8142 unsigned GCD = std::min(MulZeros, TZ);
8143 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8145 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8146 append_range(MulOps, LHSMul->operands().drop_front());
8147 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8148 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8149 }
8150 }
8151 if (!ShiftedLHS)
8152 ShiftedLHS = getUDivExpr(LHS, MulCount);
8153 return getMulExpr(
8155 getTruncateExpr(ShiftedLHS,
8156 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8157 BO->LHS->getType()),
8158 MulCount);
8159 }
8160 }
8161 // Binary `and` is a bit-wise `umin`.
8162 if (BO->LHS->getType()->isIntegerTy(1)) {
8163 LHS = getSCEV(BO->LHS);
8164 RHS = getSCEV(BO->RHS);
8165 return getUMinExpr(LHS, RHS);
8166 }
8167 break;
8168
8169 case Instruction::Or:
8170 // Binary `or` is a bit-wise `umax`.
8171 if (BO->LHS->getType()->isIntegerTy(1)) {
8172 LHS = getSCEV(BO->LHS);
8173 RHS = getSCEV(BO->RHS);
8174 return getUMaxExpr(LHS, RHS);
8175 }
8176 break;
8177
8178 case Instruction::Xor:
8179 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8180 // If the RHS of xor is -1, then this is a not operation.
8181 if (CI->isMinusOne())
8182 return getNotSCEV(getSCEV(BO->LHS));
8183
8184 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8185 // This is a variant of the check for xor with -1, and it handles
8186 // the case where instcombine has trimmed non-demanded bits out
8187 // of an xor with -1.
8188 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8189 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8190 if (LBO->getOpcode() == Instruction::And &&
8191 LCI->getValue() == CI->getValue())
8192 if (const SCEVZeroExtendExpr *Z =
8194 Type *UTy = BO->LHS->getType();
8195 const SCEV *Z0 = Z->getOperand();
8196 Type *Z0Ty = Z0->getType();
8197 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8198
8199 // If C is a low-bits mask, the zero extend is serving to
8200 // mask off the high bits. Complement the operand and
8201 // re-apply the zext.
8202 if (CI->getValue().isMask(Z0TySize))
8203 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8204
8205 // If C is a single bit, it may be in the sign-bit position
8206 // before the zero-extend. In this case, represent the xor
8207 // using an add, which is equivalent, and re-apply the zext.
8208 APInt Trunc = CI->getValue().trunc(Z0TySize);
8209 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8210 Trunc.isSignMask())
8211 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8212 UTy);
8213 }
8214 }
8215 break;
8216
8217 case Instruction::Shl:
8218 // Turn shift left of a constant amount into a multiply.
8219 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8220 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8221
8222 // If the shift count is not less than the bitwidth, the result of
8223 // the shift is undefined. Don't try to analyze it, because the
8224 // resolution chosen here may differ from the resolution chosen in
8225 // other parts of the compiler.
8226 if (SA->getValue().uge(BitWidth))
8227 break;
8228
8229 // We can safely preserve the nuw flag in all cases. It's also safe to
8230 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8231 // requires special handling. It can be preserved as long as we're not
8232 // left shifting by bitwidth - 1.
8233 auto Flags = SCEV::FlagAnyWrap;
8234 if (BO->Op) {
8235 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8236 if ((MulFlags & SCEV::FlagNSW) &&
8237 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8239 if (MulFlags & SCEV::FlagNUW)
8241 }
8242
8243 ConstantInt *X = ConstantInt::get(
8244 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8245 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8246 }
8247 break;
8248
8249 case Instruction::AShr:
8250 // AShr X, C, where C is a constant.
8251 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8252 if (!CI)
8253 break;
8254
8255 Type *OuterTy = BO->LHS->getType();
8256 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8257 // If the shift count is not less than the bitwidth, the result of
8258 // the shift is undefined. Don't try to analyze it, because the
8259 // resolution chosen here may differ from the resolution chosen in
8260 // other parts of the compiler.
8261 if (CI->getValue().uge(BitWidth))
8262 break;
8263
8264 if (CI->isZero())
8265 return getSCEV(BO->LHS); // shift by zero --> noop
8266
8267 uint64_t AShrAmt = CI->getZExtValue();
8268 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8269
8270 Operator *L = dyn_cast<Operator>(BO->LHS);
8271 const SCEV *AddTruncateExpr = nullptr;
8272 ConstantInt *ShlAmtCI = nullptr;
8273 const SCEV *AddConstant = nullptr;
8274
8275 if (L && L->getOpcode() == Instruction::Add) {
8276 // X = Shl A, n
8277 // Y = Add X, c
8278 // Z = AShr Y, m
8279 // n, c and m are constants.
8280
8281 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8282 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8283 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8284 if (AddOperandCI) {
8285 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8286 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8287 // since we truncate to TruncTy, the AddConstant should be of the
8288 // same type, so create a new Constant with type same as TruncTy.
8289 // Also, the Add constant should be shifted right by AShr amount.
8290 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8291 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8292 // we model the expression as sext(add(trunc(A), c << n)), since the
8293 // sext(trunc) part is already handled below, we create a
8294 // AddExpr(TruncExp) which will be used later.
8295 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8296 }
8297 }
8298 } else if (L && L->getOpcode() == Instruction::Shl) {
8299 // X = Shl A, n
8300 // Y = AShr X, m
8301 // Both n and m are constant.
8302
8303 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8304 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8305 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8306 }
8307
8308 if (AddTruncateExpr && ShlAmtCI) {
8309 // We can merge the two given cases into a single SCEV statement,
8310 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8311 // a simpler case. The following code handles the two cases:
8312 //
8313 // 1) For a two-shift sext-inreg, i.e. n = m,
8314 // use sext(trunc(x)) as the SCEV expression.
8315 //
8316 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8317 // expression. We already checked that ShlAmt < BitWidth, so
8318 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8319 // ShlAmt - AShrAmt < Amt.
8320 const APInt &ShlAmt = ShlAmtCI->getValue();
8321 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8322 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8323 ShlAmtCI->getZExtValue() - AShrAmt);
8324 const SCEV *CompositeExpr =
8325 getMulExpr(AddTruncateExpr, getConstant(Mul));
8326 if (L->getOpcode() != Instruction::Shl)
8327 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8328
8329 return getSignExtendExpr(CompositeExpr, OuterTy);
8330 }
8331 }
8332 break;
8333 }
8334 }
8335
8336 switch (U->getOpcode()) {
8337 case Instruction::Trunc:
8338 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8339
8340 case Instruction::ZExt:
8341 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8342
8343 case Instruction::SExt:
8344 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8346 // The NSW flag of a subtract does not always survive the conversion to
8347 // A + (-1)*B. By pushing sign extension onto its operands we are much
8348 // more likely to preserve NSW and allow later AddRec optimisations.
8349 //
8350 // NOTE: This is effectively duplicating this logic from getSignExtend:
8351 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8352 // but by that point the NSW information has potentially been lost.
8353 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8354 Type *Ty = U->getType();
8355 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8356 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8357 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8358 }
8359 }
8360 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8361
8362 case Instruction::BitCast:
8363 // BitCasts are no-op casts so we just eliminate the cast.
8364 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8365 return getSCEV(U->getOperand(0));
8366 break;
8367
8368 case Instruction::PtrToAddr: {
8369 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8370 if (isa<SCEVCouldNotCompute>(IntOp))
8371 return getUnknown(V);
8372 return IntOp;
8373 }
8374
8375 case Instruction::PtrToInt: {
8376 // Pointer to integer cast is straight-forward, so do model it.
8377 const SCEV *Op = getSCEV(U->getOperand(0));
8378 Type *DstIntTy = U->getType();
8379 // But only if effective SCEV (integer) type is wide enough to represent
8380 // all possible pointer values.
8381 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8382 if (isa<SCEVCouldNotCompute>(IntOp))
8383 return getUnknown(V);
8384 return IntOp;
8385 }
8386 case Instruction::IntToPtr:
8387 // Just don't deal with inttoptr casts.
8388 return getUnknown(V);
8389
8390 case Instruction::SDiv:
8391 // If both operands are non-negative, this is just an udiv.
8392 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8393 isKnownNonNegative(getSCEV(U->getOperand(1))))
8394 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8395 break;
8396
8397 case Instruction::SRem:
8398 // If both operands are non-negative, this is just an urem.
8399 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8400 isKnownNonNegative(getSCEV(U->getOperand(1))))
8401 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8402 break;
8403
8404 case Instruction::GetElementPtr:
8405 return createNodeForGEP(cast<GEPOperator>(U));
8406
8407 case Instruction::PHI:
8408 return createNodeForPHI(cast<PHINode>(U));
8409
8410 case Instruction::Select:
8411 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8412 U->getOperand(2));
8413
8414 case Instruction::Call:
8415 case Instruction::Invoke:
8416 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8417 return getSCEV(RV);
8418
8419 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8420 switch (II->getIntrinsicID()) {
8421 case Intrinsic::abs:
8422 return getAbsExpr(
8423 getSCEV(II->getArgOperand(0)),
8424 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8425 case Intrinsic::umax:
8426 LHS = getSCEV(II->getArgOperand(0));
8427 RHS = getSCEV(II->getArgOperand(1));
8428 return getUMaxExpr(LHS, RHS);
8429 case Intrinsic::umin:
8430 LHS = getSCEV(II->getArgOperand(0));
8431 RHS = getSCEV(II->getArgOperand(1));
8432 return getUMinExpr(LHS, RHS);
8433 case Intrinsic::smax:
8434 LHS = getSCEV(II->getArgOperand(0));
8435 RHS = getSCEV(II->getArgOperand(1));
8436 return getSMaxExpr(LHS, RHS);
8437 case Intrinsic::smin:
8438 LHS = getSCEV(II->getArgOperand(0));
8439 RHS = getSCEV(II->getArgOperand(1));
8440 return getSMinExpr(LHS, RHS);
8441 case Intrinsic::usub_sat: {
8442 const SCEV *X = getSCEV(II->getArgOperand(0));
8443 const SCEV *Y = getSCEV(II->getArgOperand(1));
8444 const SCEV *ClampedY = getUMinExpr(X, Y);
8445 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8446 }
8447 case Intrinsic::uadd_sat: {
8448 const SCEV *X = getSCEV(II->getArgOperand(0));
8449 const SCEV *Y = getSCEV(II->getArgOperand(1));
8450 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8451 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8452 }
8453 case Intrinsic::start_loop_iterations:
8454 case Intrinsic::annotation:
8455 case Intrinsic::ptr_annotation:
8456 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8457 // just eqivalent to the first operand for SCEV purposes.
8458 return getSCEV(II->getArgOperand(0));
8459 case Intrinsic::vscale:
8460 return getVScale(II->getType());
8461 default:
8462 break;
8463 }
8464 }
8465 break;
8466 }
8467
8468 return getUnknown(V);
8469}
8470
8471//===----------------------------------------------------------------------===//
8472// Iteration Count Computation Code
8473//
8474
8476 if (isa<SCEVCouldNotCompute>(ExitCount))
8477 return getCouldNotCompute();
8478
8479 auto *ExitCountType = ExitCount->getType();
8480 assert(ExitCountType->isIntegerTy());
8481 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8482 1 + ExitCountType->getScalarSizeInBits());
8483 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8484}
8485
8487 Type *EvalTy,
8488 const Loop *L) {
8489 if (isa<SCEVCouldNotCompute>(ExitCount))
8490 return getCouldNotCompute();
8491
8492 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8493 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8494
8495 auto CanAddOneWithoutOverflow = [&]() {
8496 ConstantRange ExitCountRange =
8497 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8498 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8499 return true;
8500
8501 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8502 getMinusOne(ExitCount->getType()));
8503 };
8504
8505 // If we need to zero extend the backedge count, check if we can add one to
8506 // it prior to zero extending without overflow. Provided this is safe, it
8507 // allows better simplification of the +1.
8508 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8509 return getZeroExtendExpr(
8510 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8511
8512 // Get the total trip count from the count by adding 1. This may wrap.
8513 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8514}
8515
8516static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8517 if (!ExitCount)
8518 return 0;
8519
8520 ConstantInt *ExitConst = ExitCount->getValue();
8521
8522 // Guard against huge trip counts.
8523 if (ExitConst->getValue().getActiveBits() > 32)
8524 return 0;
8525
8526 // In case of integer overflow, this returns 0, which is correct.
8527 return ((unsigned)ExitConst->getZExtValue()) + 1;
8528}
8529
8531 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8532 return getConstantTripCount(ExitCount);
8533}
8534
8535unsigned
8537 const BasicBlock *ExitingBlock) {
8538 assert(ExitingBlock && "Must pass a non-null exiting block!");
8539 assert(L->isLoopExiting(ExitingBlock) &&
8540 "Exiting block must actually branch out of the loop!");
8541 const SCEVConstant *ExitCount =
8542 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8543 return getConstantTripCount(ExitCount);
8544}
8545
8547 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8548
8549 const auto *MaxExitCount =
8550 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8552 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8553}
8554
8556 SmallVector<BasicBlock *, 8> ExitingBlocks;
8557 L->getExitingBlocks(ExitingBlocks);
8558
8559 std::optional<unsigned> Res;
8560 for (auto *ExitingBB : ExitingBlocks) {
8561 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8562 if (!Res)
8563 Res = Multiple;
8564 Res = std::gcd(*Res, Multiple);
8565 }
8566 return Res.value_or(1);
8567}
8568
8570 const SCEV *ExitCount) {
8571 if (isa<SCEVCouldNotCompute>(ExitCount))
8572 return 1;
8573
8574 // Get the trip count
8575 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8576
8577 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8578 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8579 // the greatest power of 2 divisor less than 2^32.
8580 return Multiple.getActiveBits() > 32
8581 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8582 : (unsigned)Multiple.getZExtValue();
8583}
8584
8585/// Returns the largest constant divisor of the trip count of this loop as a
8586/// normal unsigned value, if possible. This means that the actual trip count is
8587/// always a multiple of the returned value (don't forget the trip count could
8588/// very well be zero as well!).
8589///
8590/// Returns 1 if the trip count is unknown or not guaranteed to be the
8591/// multiple of a constant (which is also the case if the trip count is simply
8592/// constant, use getSmallConstantTripCount for that case), Will also return 1
8593/// if the trip count is very large (>= 2^32).
8594///
8595/// As explained in the comments for getSmallConstantTripCount, this assumes
8596/// that control exits the loop via ExitingBlock.
8597unsigned
8599 const BasicBlock *ExitingBlock) {
8600 assert(ExitingBlock && "Must pass a non-null exiting block!");
8601 assert(L->isLoopExiting(ExitingBlock) &&
8602 "Exiting block must actually branch out of the loop!");
8603 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8604 return getSmallConstantTripMultiple(L, ExitCount);
8605}
8606
8608 const BasicBlock *ExitingBlock,
8609 ExitCountKind Kind) {
8610 switch (Kind) {
8611 case Exact:
8612 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8613 case SymbolicMaximum:
8614 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8615 case ConstantMaximum:
8616 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8617 };
8618 llvm_unreachable("Invalid ExitCountKind!");
8619}
8620
8622 const Loop *L, const BasicBlock *ExitingBlock,
8624 switch (Kind) {
8625 case Exact:
8626 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8627 Predicates);
8628 case SymbolicMaximum:
8629 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8630 Predicates);
8631 case ConstantMaximum:
8632 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8633 Predicates);
8634 };
8635 llvm_unreachable("Invalid ExitCountKind!");
8636}
8637
8640 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8641}
8642
8644 ExitCountKind Kind) {
8645 switch (Kind) {
8646 case Exact:
8647 return getBackedgeTakenInfo(L).getExact(L, this);
8648 case ConstantMaximum:
8649 return getBackedgeTakenInfo(L).getConstantMax(this);
8650 case SymbolicMaximum:
8651 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8652 };
8653 llvm_unreachable("Invalid ExitCountKind!");
8654}
8655
8658 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8659}
8660
8663 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8664}
8665
8667 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8668}
8669
8670/// Push PHI nodes in the header of the given loop onto the given Worklist.
8671static void PushLoopPHIs(const Loop *L,
8674 BasicBlock *Header = L->getHeader();
8675
8676 // Push all Loop-header PHIs onto the Worklist stack.
8677 for (PHINode &PN : Header->phis())
8678 if (Visited.insert(&PN).second)
8679 Worklist.push_back(&PN);
8680}
8681
8682ScalarEvolution::BackedgeTakenInfo &
8683ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8684 auto &BTI = getBackedgeTakenInfo(L);
8685 if (BTI.hasFullInfo())
8686 return BTI;
8687
8688 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8689
8690 if (!Pair.second)
8691 return Pair.first->second;
8692
8693 BackedgeTakenInfo Result =
8694 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8695
8696 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8697}
8698
/// Return the (possibly cached) backedge-taken info for \p L, computing it
/// on first request. Also invalidates any memoized SCEVs that depended on
/// this loop, since knowing a trip count can make them more precise.
ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
  // Initially insert an invalid entry for this loop. If the insertion
  // succeeds, proceed to actually compute a backedge-taken count and
  // update the value. The temporary CouldNotCompute value tells SCEV
  // code elsewhere that it shouldn't attempt to request a new
  // backedge-taken count, which could result in infinite recursion.
  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
      BackedgeTakenCounts.try_emplace(L);
  if (!Pair.second)
    return Pair.first->second;

  // computeBackedgeTakenCount may allocate memory for its result. Inserting it
  // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
  // must be cleared in this scope.
  BackedgeTakenInfo Result = computeBackedgeTakenCount(L);

  // Now that we know more about the trip count for this loop, forget any
  // existing SCEV values for PHI nodes in this loop since they are only
  // conservative estimates made without the benefit of trip count
  // information. This invalidation is not necessary for correctness, and is
  // only done to produce more precise results.
  if (Result.hasAnyInfo()) {
    // Invalidate any expression using an addrec in this loop.
    SmallVector<SCEVUse, 8> ToForget;
    auto LoopUsersIt = LoopUsers.find(L);
    if (LoopUsersIt != LoopUsers.end())
      append_range(ToForget, LoopUsersIt->second);
    forgetMemoizedResults(ToForget);

    // Invalidate constant-evolved loop header phis.
    for (PHINode &PN : L->getHeader()->phis())
      ConstantEvolutionLoopExitValue.erase(&PN);
  }

  // Re-lookup the insert position, since the call to
  // computeBackedgeTakenCount above could result in a
  // recursive call to getBackedgeTakenInfo (on a different
  // loop), which would invalidate the iterator computed
  // earlier.
  return BackedgeTakenCounts.find(L)->second = std::move(Result);
}
8741
8743 // This method is intended to forget all info about loops. It should
8744 // invalidate caches as if the following happened:
8745 // - The trip counts of all loops have changed arbitrarily
8746 // - Every llvm::Value has been updated in place to produce a different
8747 // result.
8748 BackedgeTakenCounts.clear();
8749 PredicatedBackedgeTakenCounts.clear();
8750 BECountUsers.clear();
8751 LoopPropertiesCache.clear();
8752 ConstantEvolutionLoopExitValue.clear();
8753 ValueExprMap.clear();
8754 ValuesAtScopes.clear();
8755 ValuesAtScopesUsers.clear();
8756 LoopDispositions.clear();
8757 BlockDispositions.clear();
8758 UnsignedRanges.clear();
8759 SignedRanges.clear();
8760 ExprValueMap.clear();
8761 HasRecMap.clear();
8762 ConstantMultipleCache.clear();
8763 PredicatedSCEVRewrites.clear();
8764 FoldCache.clear();
8765 FoldCacheUser.clear();
8766}
8767void ScalarEvolution::visitAndClearUsers(
8770 SmallVectorImpl<SCEVUse> &ToForget) {
8771 while (!Worklist.empty()) {
8772 Instruction *I = Worklist.pop_back_val();
8773 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8774 continue;
8775
8777 ValueExprMap.find_as(static_cast<Value *>(I));
8778 if (It != ValueExprMap.end()) {
8779 eraseValueFromMap(It->first);
8780 ToForget.push_back(It->second);
8781 if (PHINode *PN = dyn_cast<PHINode>(I))
8782 ConstantEvolutionLoopExitValue.erase(PN);
8783 }
8784
8785 PushDefUseChildren(I, Worklist, Visited);
8786 }
8787}
8788
8790 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8793 SmallVector<SCEVUse, 16> ToForget;
8794
8795 // Iterate over all the loops and sub-loops to drop SCEV information.
8796 while (!LoopWorklist.empty()) {
8797 auto *CurrL = LoopWorklist.pop_back_val();
8798
8799 // Drop any stored trip count value.
8800 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8801 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8802
8803 // Drop information about predicated SCEV rewrites for this loop.
8804 for (auto I = PredicatedSCEVRewrites.begin();
8805 I != PredicatedSCEVRewrites.end();) {
8806 std::pair<const SCEV *, const Loop *> Entry = I->first;
8807 if (Entry.second == CurrL)
8808 PredicatedSCEVRewrites.erase(I++);
8809 else
8810 ++I;
8811 }
8812
8813 auto LoopUsersItr = LoopUsers.find(CurrL);
8814 if (LoopUsersItr != LoopUsers.end())
8815 llvm::append_range(ToForget, LoopUsersItr->second);
8816
8817 // Drop information about expressions based on loop-header PHIs.
8818 PushLoopPHIs(CurrL, Worklist, Visited);
8819 visitAndClearUsers(Worklist, Visited, ToForget);
8820
8821 LoopPropertiesCache.erase(CurrL);
8822 // Forget all contained loops too, to avoid dangling entries in the
8823 // ValuesAtScopes map.
8824 LoopWorklist.append(CurrL->begin(), CurrL->end());
8825 }
8826 forgetMemoizedResults(ToForget);
8827}
8828
8830 forgetLoop(L->getOutermostLoop());
8831}
8832
8835 if (!I) return;
8836
8837 // Drop information about expressions based on loop-header PHIs.
8840 SmallVector<SCEVUse, 8> ToForget;
8841 Worklist.push_back(I);
8842 Visited.insert(I);
8843 visitAndClearUsers(Worklist, Visited, ToForget);
8844
8845 forgetMemoizedResults(ToForget);
8846}
8847
8849 if (!isSCEVable(V->getType()))
8850 return;
8851
8852 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8853 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8854 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8855 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8856 if (const SCEV *S = getExistingSCEV(V)) {
8857 struct InvalidationRootCollector {
8858 Loop *L;
8860
8861 InvalidationRootCollector(Loop *L) : L(L) {}
8862
8863 bool follow(const SCEV *S) {
8864 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8865 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8866 if (L->contains(I))
8867 Roots.push_back(S);
8868 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8869 if (L->contains(AddRec->getLoop()))
8870 Roots.push_back(S);
8871 }
8872 return true;
8873 }
8874 bool isDone() const { return false; }
8875 };
8876
8877 InvalidationRootCollector C(L);
8878 visitAll(S, C);
8879 forgetMemoizedResults(C.Roots);
8880 }
8881
8882 // Also perform the normal invalidation.
8883 forgetValue(V);
8884}
8885
8886void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8887
8889 // Unless a specific value is passed to invalidation, completely clear both
8890 // caches.
8891 if (!V) {
8892 BlockDispositions.clear();
8893 LoopDispositions.clear();
8894 return;
8895 }
8896
8897 if (!isSCEVable(V->getType()))
8898 return;
8899
8900 const SCEV *S = getExistingSCEV(V);
8901 if (!S)
8902 return;
8903
8904 // Invalidate the block and loop dispositions cached for S. Dispositions of
8905 // S's users may change if S's disposition changes (i.e. a user may change to
8906 // loop-invariant, if S changes to loop invariant), so also invalidate
8907 // dispositions of S's users recursively.
8908 SmallVector<SCEVUse, 8> Worklist = {S};
8910 while (!Worklist.empty()) {
8911 const SCEV *Curr = Worklist.pop_back_val();
8912 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8913 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8914 if (!LoopDispoRemoved && !BlockDispoRemoved)
8915 continue;
8916 auto Users = SCEVUsers.find(Curr);
8917 if (Users != SCEVUsers.end())
8918 for (const auto *User : Users->second)
8919 if (Seen.insert(User).second)
8920 Worklist.push_back(User);
8921 }
8922}
8923
8924/// Get the exact loop backedge taken count considering all loop exits. A
8925/// computable result can only be returned for loops with all exiting blocks
8926/// dominating the latch. howFarToZero assumes that the limit of each loop test
8927/// is never skipped. This is a valid assumption as long as the loop exits via
8928/// that test. For precise results, it is the caller's responsibility to specify
8929/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8930const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8931 const Loop *L, ScalarEvolution *SE,
8933 // If any exits were not computable, the loop is not computable.
8934 if (!isComplete() || ExitNotTaken.empty())
8935 return SE->getCouldNotCompute();
8936
8937 const BasicBlock *Latch = L->getLoopLatch();
8938 // All exiting blocks we have collected must dominate the only backedge.
8939 if (!Latch)
8940 return SE->getCouldNotCompute();
8941
8942 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8943 // count is simply a minimum out of all these calculated exit counts.
8945 for (const auto &ENT : ExitNotTaken) {
8946 const SCEV *BECount = ENT.ExactNotTaken;
8947 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8948 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8949 "We should only have known counts for exiting blocks that dominate "
8950 "latch!");
8951
8952 Ops.push_back(BECount);
8953
8954 if (Preds)
8955 append_range(*Preds, ENT.Predicates);
8956
8957 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8958 "Predicate should be always true!");
8959 }
8960
8961 // If an earlier exit exits on the first iteration (exit count zero), then
8962 // a later poison exit count should not propagate into the result. This are
8963 // exactly the semantics provided by umin_seq.
8964 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8965}
8966
8967const ScalarEvolution::ExitNotTakenInfo *
8968ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8969 const BasicBlock *ExitingBlock,
8970 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8971 for (const auto &ENT : ExitNotTaken)
8972 if (ENT.ExitingBlock == ExitingBlock) {
8973 if (ENT.hasAlwaysTruePredicate())
8974 return &ENT;
8975 else if (Predicates) {
8976 append_range(*Predicates, ENT.Predicates);
8977 return &ENT;
8978 }
8979 }
8980
8981 return nullptr;
8982}
8983
8984/// getConstantMax - Get the constant max backedge taken count for the loop.
8985const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8986 ScalarEvolution *SE,
8987 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8988 if (!getConstantMax())
8989 return SE->getCouldNotCompute();
8990
8991 for (const auto &ENT : ExitNotTaken)
8992 if (!ENT.hasAlwaysTruePredicate()) {
8993 if (!Predicates)
8994 return SE->getCouldNotCompute();
8995 append_range(*Predicates, ENT.Predicates);
8996 }
8997
8998 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8999 isa<SCEVConstant>(getConstantMax())) &&
9000 "No point in having a non-constant max backedge taken count!");
9001 return getConstantMax();
9002}
9003
9004const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
9005 const Loop *L, ScalarEvolution *SE,
9006 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
9007 if (!SymbolicMax) {
9008 // Form an expression for the maximum exit count possible for this loop. We
9009 // merge the max and exact information to approximate a version of
9010 // getConstantMaxBackedgeTakenCount which isn't restricted to just
9011 // constants.
9012 SmallVector<SCEVUse, 4> ExitCounts;
9013
9014 for (const auto &ENT : ExitNotTaken) {
9015 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
9016 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
9017 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
9018 "We should only have known counts for exiting blocks that "
9019 "dominate latch!");
9020 ExitCounts.push_back(ExitCount);
9021 if (Predicates)
9022 append_range(*Predicates, ENT.Predicates);
9023
9024 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9025 "Predicate should be always true!");
9026 }
9027 }
9028 if (ExitCounts.empty())
9029 SymbolicMax = SE->getCouldNotCompute();
9030 else
9031 SymbolicMax =
9032 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9033 }
9034 return SymbolicMax;
9035}
9036
9037bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9038 ScalarEvolution *SE) const {
9039 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9040 return !ENT.hasAlwaysTruePredicate();
9041 };
9042 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9043}
9044
// NOTE(review): extraction artifact — the constructor signature and several
// assert-condition lines are missing from this chunk; the surviving lines are
// preserved verbatim below. This appears to be ScalarEvolution::ExitLimit's
// constructor: it normalizes the exact/symbolic counts when the constant max
// is provably zero and de-duplicates leaf predicates — TODO confirm against
// the full file.
9047
9049    const SCEV *E, const SCEV *ConstantMaxNotTaken,
9050    const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9054  // If we prove the max count is zero, so is the symbolic bound.  This happens
9055  // in practice due to differences in a) how context sensitive we've chosen
9056  // to be and b) how we reason about bounds implied by UB.
9057  if (ConstantMaxNotTaken->isZero()) {
9058    this->ExactNotTaken = E = ConstantMaxNotTaken;
9059    this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9060  }
9061
// NOTE(review): the conditions of the following asserts were dropped by the
// extraction; only their message strings remain.
9064         "Exact is not allowed to be less precise than Constant Max");
9067         "Exact is not allowed to be less precise than Symbolic Max");
9070         "Symbolic Max is not allowed to be less precise than Constant Max");
9073         "No point in having a non-constant max backedge taken count!");
// Collect each predicate once; SCEVUnionPredicates must be flattened by the
// caller before being passed in.
9075  for (const auto PredList : PredLists)
9076    for (const auto *P : PredList) {
9077      if (SeenPreds.contains(P))
9078        continue;
9079      assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9080      SeenPreds.insert(P);
9081      Predicates.push_back(P);
9082    }
9083  assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9084         "Backedge count should be int");
9086          !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9087         "Max backedge count should be int");
9088}
9089
9097
9098/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9099/// computable exit into a persistent ExitNotTakenInfo array.
// NOTE(review): the first parameter line (original line 9101, presumably the
// ExitCounts parameter moved-from below) is missing from this chunk — confirm
// against the full file.
9100ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9102    bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9103    : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9104  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9105
// Copy each (exiting block, exit limit) pair into the persistent array.
9106  ExitNotTaken.reserve(ExitCounts.size());
9107  std::transform(ExitCounts.begin(), ExitCounts.end(),
9108                 std::back_inserter(ExitNotTaken),
9109                 [&](const EdgeExitInfo &EEI) {
9110        BasicBlock *ExitBB = EEI.first;
9111        const ExitLimit &EL = EEI.second;
9112        return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9113                                EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9114                                EL.Predicates);
9115      });
9116  assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9117          isa<SCEVConstant>(ConstantMax)) &&
9118         "No point in having a non-constant max backedge taken count!");
9119}
9120
9121/// Compute the number of times the backedge of the specified loop will execute.
9122ScalarEvolution::BackedgeTakenInfo
9123ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9124                                           bool AllowPredicates) {
9125  SmallVector<BasicBlock *, 8> ExitingBlocks;
9126  L->getExitingBlocks(ExitingBlocks);
9127
9128  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9129
// NOTE(review): original line 9130 — presumably the declaration of the
// ExitCounts vector populated below — is missing from this chunk; confirm
// against the full file.
9131  bool CouldComputeBECount = true;
9132  BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9133  const SCEV *MustExitMaxBECount = nullptr;
9134  const SCEV *MayExitMaxBECount = nullptr;
9135  bool MustExitMaxOrZero = false;
9136  bool IsOnlyExit = ExitingBlocks.size() == 1;
9137
9138  // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9139  // and compute maxBECount.
9140  // Do a union of all the predicates here.
9141  for (BasicBlock *ExitBB : ExitingBlocks) {
9142    // We canonicalize untaken exits to br (constant), ignore them so that
9143    // proving an exit untaken doesn't negatively impact our ability to reason
9144    // about the loop as whole.
9145    if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9146      if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9147        bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9148        if (ExitIfTrue == CI->isZero())
9149          continue;
9150      }
9151
9152    ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9153
9154    assert((AllowPredicates || EL.Predicates.empty()) &&
9155           "Predicated exit limit when predicates are not allowed!");
9156
9157    // 1. For each exit that can be computed, add an entry to ExitCounts.
9158    // CouldComputeBECount is true only if all exits can be computed.
9159    if (EL.ExactNotTaken != getCouldNotCompute())
9160      ++NumExitCountsComputed;
9161    else
9162      // We couldn't compute an exact value for this exit, so
9163      // we won't be able to compute an exact value for the loop.
9164      CouldComputeBECount = false;
9165    // Remember exit count if either exact or symbolic is known. Because
9166    // Exact always implies symbolic, only check symbolic.
9167    if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9168      ExitCounts.emplace_back(ExitBB, EL);
9169    else {
9170      assert(EL.ExactNotTaken == getCouldNotCompute() &&
9171             "Exact is known but symbolic isn't?");
9172      ++NumExitCountsNotComputed;
9173    }
9174
9175    // 2. Derive the loop's MaxBECount from each exit's max number of
9176    // non-exiting iterations. Partition the loop exits into two kinds:
9177    // LoopMustExits and LoopMayExits.
9178    //
9179    // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9180    // is a LoopMayExit. If any computable LoopMustExit is found, then
9181    // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9182    // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9183    // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9184    // any
9185    // computable EL.ConstantMaxNotTaken.
9186    if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9187        DT.dominates(ExitBB, Latch)) {
9188      if (!MustExitMaxBECount) {
9189        MustExitMaxBECount = EL.ConstantMaxNotTaken;
9190        MustExitMaxOrZero = EL.MaxOrZero;
9191      } else {
9192        MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9193                                                        EL.ConstantMaxNotTaken);
9194      }
9195    } else if (MayExitMaxBECount != getCouldNotCompute()) {
9196      if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9197        MayExitMaxBECount = EL.ConstantMaxNotTaken;
9198      else {
9199        MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9200                                                       EL.ConstantMaxNotTaken);
9201      }
9202    }
9203  }
9204  const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9205    (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9206  // The loop backedge will be taken the maximum or zero times if there's
9207  // a single exit that must be taken the maximum or zero times.
9208  bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9209
9210  // Remember which SCEVs are used in exit limits for invalidation purposes.
9211  // We only care about non-constant SCEVs here, so we can ignore
9212  // EL.ConstantMaxNotTaken
9213  // and MaxBECount, which must be SCEVConstant.
9214  for (const auto &Pair : ExitCounts) {
9215    if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9216      BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9217    if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9218      BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9219          {L, AllowPredicates});
9220  }
9221  return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9222                           MaxBECount, MaxOrZero);
9223}
9224
9225ScalarEvolution::ExitLimit
9226ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9227 bool IsOnlyExit, bool AllowPredicates) {
9228 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9229 // If our exiting block does not dominate the latch, then its connection with
9230 // loop's exit limit may be far from trivial.
9231 const BasicBlock *Latch = L->getLoopLatch();
9232 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9233 return getCouldNotCompute();
9234
9235 Instruction *Term = ExitingBlock->getTerminator();
9236 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9237 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9238 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9239 "It should have one successor in loop and one exit block!");
9240 // Proceed to the next level to examine the exit condition expression.
9241 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9242 /*ControlsOnlyExit=*/IsOnlyExit,
9243 AllowPredicates);
9244 }
9245
9246 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9247 // For switch, make sure that there is a single exit from the loop.
9248 BasicBlock *Exit = nullptr;
9249 for (auto *SBB : successors(ExitingBlock))
9250 if (!L->contains(SBB)) {
9251 if (Exit) // Multiple exit successors.
9252 return getCouldNotCompute();
9253 Exit = SBB;
9254 }
9255 assert(Exit && "Exiting block must have at least one exit");
9256 return computeExitLimitFromSingleExitSwitch(
9257 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9258 }
9259
9260 return getCouldNotCompute();
9261}
9262
// NOTE(review): the first line of this definition (original line 9263,
// presumably "ScalarEvolution::ExitLimit
// ScalarEvolution::computeExitLimitFromCond(") is missing from this chunk.
// The body simply seeds a fresh per-condition cache and delegates to the
// cached driver.
9264    const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9265    bool AllowPredicates) {
9266  ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9267  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9268                                        ControlsOnlyExit, AllowPredicates);
9269}
9270
9271std::optional<ScalarEvolution::ExitLimit>
9272ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9273 bool ExitIfTrue, bool ControlsOnlyExit,
9274 bool AllowPredicates) {
9275 (void)this->L;
9276 (void)this->ExitIfTrue;
9277 (void)this->AllowPredicates;
9278
9279 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9280 this->AllowPredicates == AllowPredicates &&
9281 "Variance in assumed invariant key components!");
9282 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9283 if (Itr == TripCountMap.end())
9284 return std::nullopt;
9285 return Itr->second;
9286}
9287
9288void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9289 bool ExitIfTrue,
9290 bool ControlsOnlyExit,
9291 bool AllowPredicates,
9292 const ExitLimit &EL) {
9293 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9294 this->AllowPredicates == AllowPredicates &&
9295 "Variance in assumed invariant key components!");
9296
9297 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9298 assert(InsertResult.second && "Expected successful insertion!");
9299 (void)InsertResult;
9300 (void)ExitIfTrue;
9301}
9302
9303ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9304 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9305 bool ControlsOnlyExit, bool AllowPredicates) {
9306
9307 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9308 AllowPredicates))
9309 return *MaybeEL;
9310
9311 ExitLimit EL = computeExitLimitFromCondImpl(
9312 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9313 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9314 return EL;
9315}
9316
// NOTE(review): two lines are missing from this chunk (original 9359 — the
// start of the ConstantRange computation, presumably a
// makeExactNoWrapRegion-style call — and 9368, the body of the Offset
// adjustment). The remaining code is preserved verbatim; confirm against the
// full file.
9317ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9318    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9319    bool ControlsOnlyExit, bool AllowPredicates) {
9320  // Handle BinOp conditions (And, Or).
9321  if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9322          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9323    return *LimitFromBinOp;
9324
9325  // With an icmp, it may be feasible to compute an exact backedge-taken count.
9326  // Proceed to the next level to examine the icmp.
9327  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9328    ExitLimit EL =
9329        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9330    if (EL.hasFullInfo() || !AllowPredicates)
9331      return EL;
9332
9333    // Try again, but use SCEV predicates this time.
9334    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9335                                    ControlsOnlyExit,
9336                                    /*AllowPredicates=*/true);
9337  }
9338
9339  // Check for a constant condition. These are normally stripped out by
9340  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9341  // preserve the CFG and is temporarily leaving constant conditions
9342  // in place.
9343  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9344    if (ExitIfTrue == !CI->getZExtValue())
9345      // The backedge is always taken.
9346      return getCouldNotCompute();
9347    // The backedge is never taken.
9348    return getZero(CI->getType());
9349  }
9350
9351  // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9352  // with a constant step, we can form an equivalent icmp predicate and figure
9353  // out how many iterations will be taken before we exit.
9354  const WithOverflowInst *WO;
9355  const APInt *C;
9356  if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9357      match(WO->getRHS(), m_APInt(C))) {
9358    ConstantRange NWR =
9360                                             WO->getNoWrapKind());
9361    CmpInst::Predicate Pred;
9362    APInt NewRHSC, Offset;
9363    NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9364    if (!ExitIfTrue)
9365      Pred = ICmpInst::getInversePredicate(Pred);
9366    auto *LHS = getSCEV(WO->getLHS());
9367    if (Offset != 0)
9369    auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9370                                       ControlsOnlyExit, AllowPredicates);
9371    if (EL.hasAnyInfo())
9372      return EL;
9373  }
9374
9375  // If it's not an integer or pointer comparison then compute it the hard way.
9376  return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9377}
9378
9379std::optional<ScalarEvolution::ExitLimit>
9380ScalarEvolution::computeExitLimitFromCondFromBinOp(
9381 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9382 bool ControlsOnlyExit, bool AllowPredicates) {
9383 // Check if the controlling expression for this loop is an And or Or.
9384 Value *Op0, *Op1;
9385 bool IsAnd = false;
9386 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9387 IsAnd = true;
9388 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9389 IsAnd = false;
9390 else
9391 return std::nullopt;
9392
9393 // EitherMayExit is true in these two cases:
9394 // br (and Op0 Op1), loop, exit
9395 // br (or Op0 Op1), exit, loop
9396 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9397 ExitLimit EL0 = computeExitLimitFromCondCached(
9398 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9399 AllowPredicates);
9400 ExitLimit EL1 = computeExitLimitFromCondCached(
9401 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9402 AllowPredicates);
9403
9404 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9405 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9406 if (isa<ConstantInt>(Op1))
9407 return Op1 == NeutralElement ? EL0 : EL1;
9408 if (isa<ConstantInt>(Op0))
9409 return Op0 == NeutralElement ? EL1 : EL0;
9410
9411 const SCEV *BECount = getCouldNotCompute();
9412 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9413 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9414 if (EitherMayExit) {
9415 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9416 // Both conditions must be same for the loop to continue executing.
9417 // Choose the less conservative count.
9418 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9419 EL1.ExactNotTaken != getCouldNotCompute()) {
9420 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9421 UseSequentialUMin);
9422 }
9423 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9424 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9425 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9426 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9427 else
9428 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9429 EL1.ConstantMaxNotTaken);
9430 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9431 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9432 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9433 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9434 else
9435 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9436 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9437 } else {
9438 // Both conditions must be same at the same time for the loop to exit.
9439 // For now, be conservative.
9440 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9441 BECount = EL0.ExactNotTaken;
9442 }
9443
9444 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9445 // to be more aggressive when computing BECount than when computing
9446 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9447 // and
9448 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9449 // EL1.ConstantMaxNotTaken to not.
9450 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9451 !isa<SCEVCouldNotCompute>(BECount))
9452 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9453 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9454 SymbolicMaxBECount =
9455 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9456 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9457 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9458}
9459
9460ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9461 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9462 bool AllowPredicates) {
9463 // If the condition was exit on true, convert the condition to exit on false
9464 CmpPredicate Pred;
9465 if (!ExitIfTrue)
9466 Pred = ExitCond->getCmpPredicate();
9467 else
9468 Pred = ExitCond->getInverseCmpPredicate();
9469 const ICmpInst::Predicate OriginalPred = Pred;
9470
9471 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9472 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9473
9474 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9475 AllowPredicates);
9476 if (EL.hasAnyInfo())
9477 return EL;
9478
9479 auto *ExhaustiveCount =
9480 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9481
9482 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9483 return ExhaustiveCount;
9484
9485 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9486 ExitCond->getOperand(1), L, OriginalPred);
9487}
// NOTE(review): several lines were dropped from this chunk by the extraction
// (original 9501 — presumably the predicate swap after std::swap(LHS, RHS);
// 9505 — the tail of the ControllingFiniteLoop computation; and 9569-9570 /
// 9574-9575 — the bodies of the pointer-typed early-outs in the NE/EQ cases).
// The surviving code is preserved verbatim; confirm against the full file.
9488ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9489    const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9490    bool ControlsOnlyExit, bool AllowPredicates) {
9491
9492  // Try to evaluate any dependencies out of the loop.
9493  LHS = getSCEVAtScope(LHS, L);
9494  RHS = getSCEVAtScope(RHS, L);
9495
9496  // At this point, we would like to compute how many iterations of the
9497  // loop the predicate will return true for these inputs.
9498  if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9499    // If there is a loop-invariant, force it into the RHS.
9500    std::swap(LHS, RHS);
9502  }
9503
9504  bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9506  // Simplify the operands before analyzing them.
9507  (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9508
9509  // If we have a comparison of a chrec against a constant, try to use value
9510  // ranges to answer this query.
9511  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9512    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9513      if (AddRec->getLoop() == L) {
9514        // Form the constant range.
9515        ConstantRange CompRange =
9516            ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9517
9518        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9519        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9520      }
9521
9522  // If this loop must exit based on this condition (or execute undefined
9523  // behaviour), see if we can improve wrap flags. This is essentially
9524  // a must execute style proof.
9525  if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9526    // If we can prove the test sequence produced must repeat the same values
9527    // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9528    // because if it did, we'd have an infinite (undefined) loop.
9529    // TODO: We can peel off any functions which are invertible *in L*. Loop
9530    // invariant terms are effectively constants for our purposes here.
9531    SCEVUse InnerLHS = LHS;
9532    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9533      InnerLHS = ZExt->getOperand();
9534    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9535        AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9536        isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9537                               /*OrNegative=*/true)) {
9538      auto Flags = AR->getNoWrapFlags();
9539      Flags = setFlags(Flags, SCEV::FlagNW);
9540      SmallVector<SCEVUse> Operands{AR->operands()};
9541      Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9542      setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9543    }
9544
9545    // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9546    // From no-self-wrap, this follows trivially from the fact that every
9547    // (un)signed-wrapped, but not self-wrapped value must be LT than the
9548    // last value before (un)signed wrap. Since we know that last value
9549    // didn't exit, nor will any smaller one.
9550    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9551      auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9552      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9553          AR && AR->getLoop() == L && AR->isAffine() &&
9554          !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9555          isKnownPositive(AR->getStepRecurrence(*this))) {
9556        auto Flags = AR->getNoWrapFlags();
9557        Flags = setFlags(Flags, WrapType);
9558        SmallVector<SCEVUse> Operands{AR->operands()};
9559        Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9560        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9561      }
9562    }
9563  }
9564
9565  switch (Pred) {
9566  case ICmpInst::ICMP_NE: {                     // while (X != Y)
9567    // Convert to: while (X-Y != 0)
9568    if (LHS->getType()->isPointerTy()) {
9571        return LHS;
9572    }
9573    if (RHS->getType()->isPointerTy()) {
9576        return RHS;
9577    }
9578    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9579                                AllowPredicates);
9580    if (EL.hasAnyInfo())
9581      return EL;
9582    break;
9583  }
9584  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
9585    // Convert to: while (X-Y == 0)
9586    if (LHS->getType()->isPointerTy()) {
9589        return LHS;
9590    }
9591    if (RHS->getType()->isPointerTy()) {
9594        return RHS;
9595    }
9596    ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9597    if (EL.hasAnyInfo()) return EL;
9598    break;
9599  }
9600  case ICmpInst::ICMP_SLE:
9601  case ICmpInst::ICMP_ULE:
9602    // Since the loop is finite, an invariant RHS cannot include the boundary
9603    // value, otherwise it would loop forever.
9604    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9605        !isLoopInvariant(RHS, L)) {
9606      // Otherwise, perform the addition in a wider type, to avoid overflow.
9607      // If the LHS is an addrec with the appropriate nowrap flag, the
9608      // extension will be sunk into it and the exit count can be analyzed.
9609      auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9610      if (!OldType)
9611        break;
9612      // Prefer doubling the bitwidth over adding a single bit to make it more
9613      // likely that we use a legal type.
9614      auto *NewType =
9615          Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9616      if (ICmpInst::isSigned(Pred)) {
9617        LHS = getSignExtendExpr(LHS, NewType);
9618        RHS = getSignExtendExpr(RHS, NewType);
9619      } else {
9620        LHS = getZeroExtendExpr(LHS, NewType);
9621        RHS = getZeroExtendExpr(RHS, NewType);
9622      }
9623    }
9625    [[fallthrough]];
9626  case ICmpInst::ICMP_SLT:
9627  case ICmpInst::ICMP_ULT: {                    // while (X < Y)
9628    bool IsSigned = ICmpInst::isSigned(Pred);
9629    ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9630                                    AllowPredicates);
9631    if (EL.hasAnyInfo())
9632      return EL;
9633    break;
9634  }
9635  case ICmpInst::ICMP_SGE:
9636  case ICmpInst::ICMP_UGE:
9637    // Since the loop is finite, an invariant RHS cannot include the boundary
9638    // value, otherwise it would loop forever.
9639    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9640        !isLoopInvariant(RHS, L))
9641      break;
9643    [[fallthrough]];
9644  case ICmpInst::ICMP_SGT:
9645  case ICmpInst::ICMP_UGT: {                    // while (X > Y)
9646    bool IsSigned = ICmpInst::isSigned(Pred);
9647    ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9648                                       AllowPredicates);
9649    if (EL.hasAnyInfo())
9650      return EL;
9651    break;
9652  }
9653  default:
9654    break;
9655  }
9656
9657  return getCouldNotCompute();
9658}
9659
9660ScalarEvolution::ExitLimit
9661ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9662 SwitchInst *Switch,
9663 BasicBlock *ExitingBlock,
9664 bool ControlsOnlyExit) {
9665 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9666
9667 // Give up if the exit is the default dest of a switch.
9668 if (Switch->getDefaultDest() == ExitingBlock)
9669 return getCouldNotCompute();
9670
9671 assert(L->contains(Switch->getDefaultDest()) &&
9672 "Default case must not exit the loop!");
9673 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9674 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9675
9676 // while (X != Y) --> while (X-Y != 0)
9677 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9678 if (EL.hasAnyInfo())
9679 return EL;
9680
9681 return getCouldNotCompute();
9682}
9683
// NOTE(review): two lines are missing from this chunk — original 9685 (the
// rest of the signature, presumably "(const SCEVAddRecExpr *AddRec,
// ConstantInt *C,") and 9689 (the condition of the assert whose message
// survives below). The remaining code evaluates a constant chrec at a
// constant iteration number and returns the folded constant.
9684static ConstantInt *
9686                                ScalarEvolution &SE) {
9687  const SCEV *InVal = SE.getConstant(C);
9688  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9690         "Evaluation of SCEV at constant didn't fold correctly?");
9691  return cast<SCEVConstant>(Val)->getValue();
9692}
9693
// NOTE(review): three lines are missing from this chunk (original 9741 and
// 9780 — presumably the Instruction::BinaryOps declarations for OpC/OpCode —
// and 9830, the expression assigned to UpperBound). The surviving code is
// preserved verbatim; confirm against the full file.
9694ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9695    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9696  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9697  if (!RHS)
9698    return getCouldNotCompute();
9699
9700  const BasicBlock *Latch = L->getLoopLatch();
9701  if (!Latch)
9702    return getCouldNotCompute();
9703
9704  const BasicBlock *Predecessor = L->getLoopPredecessor();
9705  if (!Predecessor)
9706    return getCouldNotCompute();
9707
9708  // Return true if V is of the form "LHS `shift_op` <positive constant>".
9709  // Return LHS in OutLHS and shift_opt in OutOpCode.
9710  auto MatchPositiveShift =
9711      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9712
9713    using namespace PatternMatch;
9714
9715    ConstantInt *ShiftAmt;
9716    if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9717      OutOpCode = Instruction::LShr;
9718    else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9719      OutOpCode = Instruction::AShr;
9720    else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9721      OutOpCode = Instruction::Shl;
9722    else
9723      return false;
9724
9725    return ShiftAmt->getValue().isStrictlyPositive();
9726  };
9727
9728  // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9729  //
9730  // loop:
9731  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9732  //   %iv.shifted = lshr i32 %iv, <positive constant>
9733  //
9734  // Return true on a successful match.  Return the corresponding PHI node (%iv
9735  // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9736  auto MatchShiftRecurrence =
9737      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9738    std::optional<Instruction::BinaryOps> PostShiftOpCode;
9739
9740    {
9742      Value *V;
9743
9744      // If we encounter a shift instruction, "peel off" the shift operation,
9745      // and remember that we did so.  Later when we inspect %iv's backedge
9746      // value, we will make sure that the backedge value uses the same
9747      // operation.
9748      //
9749      // Note: the peeled shift operation does not have to be the same
9750      // instruction as the one feeding into the PHI's backedge value.  We only
9751      // really care about it being the same *kind* of shift instruction --
9752      // that's all that is required for our later inferences to hold.
9753      if (MatchPositiveShift(LHS, V, OpC)) {
9754        PostShiftOpCode = OpC;
9755        LHS = V;
9756      }
9757    }
9758
9759    PNOut = dyn_cast<PHINode>(LHS);
9760    if (!PNOut || PNOut->getParent() != L->getHeader())
9761      return false;
9762
9763    Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9764    Value *OpLHS;
9765
9766    return
9767        // The backedge value for the PHI node must be a shift by a positive
9768        // amount
9769        MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9770
9771        // of the PHI node itself
9772        OpLHS == PNOut &&
9773
9774        // and the kind of shift should be match the kind of shift we peeled
9775        // off, if any.
9776        (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9777  };
9778
9779  PHINode *PN;
9781  if (!MatchShiftRecurrence(LHS, PN, OpCode))
9782    return getCouldNotCompute();
9783
9784  const DataLayout &DL = getDataLayout();
9785
9786  // The key rationale for this optimization is that for some kinds of shift
9787  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9788  // within a finite number of iterations.  If the condition guarding the
9789  // backedge (in the sense that the backedge is taken if the condition is true)
9790  // is false for the value the shift recurrence stabilizes to, then we know
9791  // that the backedge is taken only a finite number of times.
9792
9793  ConstantInt *StableValue = nullptr;
9794  switch (OpCode) {
9795  default:
9796    llvm_unreachable("Impossible case!");
9797
9798  case Instruction::AShr: {
9799    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9800    // bitwidth(K) iterations.
9801    Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9802    KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9803                                       Predecessor->getTerminator(), &DT);
9804    auto *Ty = cast<IntegerType>(RHS->getType());
9805    if (Known.isNonNegative())
9806      StableValue = ConstantInt::get(Ty, 0);
9807    else if (Known.isNegative())
9808      StableValue = ConstantInt::get(Ty, -1, true);
9809    else
9810      return getCouldNotCompute();
9811
9812    break;
9813  }
9814  case Instruction::LShr:
9815  case Instruction::Shl:
9816    // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9817    // stabilize to 0 in at most bitwidth(K) iterations.
9818    StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9819    break;
9820  }
9821
9822  auto *Result =
9823      ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9824  assert(Result->getType()->isIntegerTy(1) &&
9825         "Otherwise cannot be an operand to a branch instruction");
9826
9827  if (Result->isNullValue()) {
9828    unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9829    const SCEV *UpperBound =
9831    return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9832  }
9833
9834  return getCouldNotCompute();
9835}
9836
9837/// Return true if we can constant fold an instruction of the specified type,
9838/// assuming that all operands were constants.
9839static bool CanConstantFold(const Instruction *I) {
9843 return true;
9844
9845 if (const CallInst *CI = dyn_cast<CallInst>(I))
9846 if (const Function *F = CI->getCalledFunction())
9847 return canConstantFoldCallTo(CI, F);
9848 return false;
9849}
9850
9851/// Determine whether this instruction can constant evolve within this loop
9852/// assuming its operands can all constant evolve.
9853static bool canConstantEvolve(Instruction *I, const Loop *L) {
9854 // An instruction outside of the loop can't be derived from a loop PHI.
9855 if (!L->contains(I)) return false;
9856
9857 if (isa<PHINode>(I)) {
9858 // We don't currently keep track of the control flow needed to evaluate
9859 // PHIs, so we cannot handle PHIs inside of loops.
9860 return L->getHeader() == I->getParent();
9861 }
9862
9863 // If we won't be able to constant fold this expression even if the operands
9864 // are constants, bail early.
9865 return CanConstantFold(I);
9866}
9867
9868/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9869/// recursing through each instruction operand until reaching a loop header phi.
9870static PHINode *
9873 unsigned Depth) {
9875 return nullptr;
9876
9877 // Otherwise, we can evaluate this instruction if all of its operands are
9878 // constant or derived from a PHI node themselves.
9879 PHINode *PHI = nullptr;
9880 for (Value *Op : UseInst->operands()) {
9881 if (isa<Constant>(Op)) continue;
9882
9884 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9885
9886 PHINode *P = dyn_cast<PHINode>(OpInst);
9887 if (!P)
9888 // If this operand is already visited, reuse the prior result.
9889 // We may have P != PHI if this is the deepest point at which the
9890 // inconsistent paths meet.
9891 P = PHIMap.lookup(OpInst);
9892 if (!P) {
9893 // Recurse and memoize the results, whether a phi is found or not.
9894 // This recursive call invalidates pointers into PHIMap.
9895 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9896 PHIMap[OpInst] = P;
9897 }
9898 if (!P)
9899 return nullptr; // Not evolving from PHI
9900 if (PHI && PHI != P)
9901 return nullptr; // Evolving from multiple different PHIs.
9902 PHI = P;
9903 }
9904 // This is a expression evolving from a constant PHI!
9905 return PHI;
9906}
9907
9908/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9909/// in the loop that V is derived from. We allow arbitrary operations along the
9910/// way, but the operands of an operation must either be constants or a value
9911/// derived from a constant PHI. If this expression does not fit with these
9912/// constraints, return null.
9915 if (!I || !canConstantEvolve(I, L)) return nullptr;
9916
9917 if (PHINode *PN = dyn_cast<PHINode>(I))
9918 return PN;
9919
9920 // Record non-constant instructions contained by the loop.
9922 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9923}
9924
9925/// EvaluateExpression - Given an expression that passes the
9926/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9927/// in the loop has the value PHIVal. If we can't fold this expression for some
9928/// reason, return null.
9931 const DataLayout &DL,
9932 const TargetLibraryInfo *TLI) {
9933 // Convenient constant check, but redundant for recursive calls.
9934 if (Constant *C = dyn_cast<Constant>(V)) return C;
9936 if (!I) return nullptr;
9937
9938 if (Constant *C = Vals.lookup(I)) return C;
9939
9940 // An instruction inside the loop depends on a value outside the loop that we
9941 // weren't given a mapping for, or a value such as a call inside the loop.
9942 if (!canConstantEvolve(I, L)) return nullptr;
9943
9944 // An unmapped PHI can be due to a branch or another loop inside this loop,
9945 // or due to this not being the initial iteration through a loop where we
9946 // couldn't compute the evolution of this particular PHI last time.
9947 if (isa<PHINode>(I)) return nullptr;
9948
9949 std::vector<Constant*> Operands(I->getNumOperands());
9950
9951 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9952 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9953 if (!Operand) {
9954 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9955 if (!Operands[i]) return nullptr;
9956 continue;
9957 }
9958 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9959 Vals[Operand] = C;
9960 if (!C) return nullptr;
9961 Operands[i] = C;
9962 }
9963
9964 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9965 /*AllowNonDeterministic=*/false);
9966}
9967
9968
9969// If every incoming value to PN except the one for BB is a specific Constant,
9970// return that, else return nullptr.
9972 Constant *IncomingVal = nullptr;
9973
9974 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9975 if (PN->getIncomingBlock(i) == BB)
9976 continue;
9977
9978 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9979 if (!CurrentVal)
9980 return nullptr;
9981
9982 if (IncomingVal != CurrentVal) {
9983 if (IncomingVal)
9984 return nullptr;
9985 IncomingVal = CurrentVal;
9986 }
9987 }
9988
9989 return IncomingVal;
9990}
9991
9992/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9993/// in the header of its containing loop, we know the loop executes a
9994/// constant number of times, and the PHI node is just a recurrence
9995/// involving constants, fold it.
9996Constant *
9997ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9998 const APInt &BEs,
9999 const Loop *L) {
10000 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
10001 if (!Inserted)
10002 return I->second;
10003
10005 return nullptr; // Not going to evaluate it.
10006
10007 Constant *&RetVal = I->second;
10008
10009 DenseMap<Instruction *, Constant *> CurrentIterVals;
10010 BasicBlock *Header = L->getHeader();
10011 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10012
10013 BasicBlock *Latch = L->getLoopLatch();
10014 if (!Latch)
10015 return nullptr;
10016
10017 for (PHINode &PHI : Header->phis()) {
10018 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10019 CurrentIterVals[&PHI] = StartCST;
10020 }
10021 if (!CurrentIterVals.count(PN))
10022 return RetVal = nullptr;
10023
10024 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10025
10026 // Execute the loop symbolically to determine the exit value.
10027 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10028 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10029
10030 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10031 unsigned IterationNum = 0;
10032 const DataLayout &DL = getDataLayout();
10033 for (; ; ++IterationNum) {
10034 if (IterationNum == NumIterations)
10035 return RetVal = CurrentIterVals[PN]; // Got exit value!
10036
10037 // Compute the value of the PHIs for the next iteration.
10038 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10039 DenseMap<Instruction *, Constant *> NextIterVals;
10040 Constant *NextPHI =
10041 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10042 if (!NextPHI)
10043 return nullptr; // Couldn't evaluate!
10044 NextIterVals[PN] = NextPHI;
10045
10046 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10047
10048 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10049 // cease to be able to evaluate one of them or if they stop evolving,
10050 // because that doesn't necessarily prevent us from computing PN.
10052 for (const auto &I : CurrentIterVals) {
10053 PHINode *PHI = dyn_cast<PHINode>(I.first);
10054 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10055 PHIsToCompute.emplace_back(PHI, I.second);
10056 }
10057 // We use two distinct loops because EvaluateExpression may invalidate any
10058 // iterators into CurrentIterVals.
10059 for (const auto &I : PHIsToCompute) {
10060 PHINode *PHI = I.first;
10061 Constant *&NextPHI = NextIterVals[PHI];
10062 if (!NextPHI) { // Not already computed.
10063 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10064 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10065 }
10066 if (NextPHI != I.second)
10067 StoppedEvolving = false;
10068 }
10069
10070 // If all entries in CurrentIterVals == NextIterVals then we can stop
10071 // iterating, the loop can't continue to change.
10072 if (StoppedEvolving)
10073 return RetVal = CurrentIterVals[PN];
10074
10075 CurrentIterVals.swap(NextIterVals);
10076 }
10077}
10078
10079const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10080 Value *Cond,
10081 bool ExitWhen) {
10082 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10083 if (!PN) return getCouldNotCompute();
10084
10085 // If the loop is canonicalized, the PHI will have exactly two entries.
10086 // That's the only form we support here.
10087 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10088
10089 DenseMap<Instruction *, Constant *> CurrentIterVals;
10090 BasicBlock *Header = L->getHeader();
10091 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10092
10093 BasicBlock *Latch = L->getLoopLatch();
10094 assert(Latch && "Should follow from NumIncomingValues == 2!");
10095
10096 for (PHINode &PHI : Header->phis()) {
10097 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10098 CurrentIterVals[&PHI] = StartCST;
10099 }
10100 if (!CurrentIterVals.count(PN))
10101 return getCouldNotCompute();
10102
10103 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10104 // the loop symbolically to determine when the condition gets a value of
10105 // "ExitWhen".
10106 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10107 const DataLayout &DL = getDataLayout();
10108 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10109 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10110 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10111
10112 // Couldn't symbolically evaluate.
10113 if (!CondVal) return getCouldNotCompute();
10114
10115 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10116 ++NumBruteForceTripCountsComputed;
10117 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10118 }
10119
10120 // Update all the PHI nodes for the next iteration.
10121 DenseMap<Instruction *, Constant *> NextIterVals;
10122
10123 // Create a list of which PHIs we need to compute. We want to do this before
10124 // calling EvaluateExpression on them because that may invalidate iterators
10125 // into CurrentIterVals.
10126 SmallVector<PHINode *, 8> PHIsToCompute;
10127 for (const auto &I : CurrentIterVals) {
10128 PHINode *PHI = dyn_cast<PHINode>(I.first);
10129 if (!PHI || PHI->getParent() != Header) continue;
10130 PHIsToCompute.push_back(PHI);
10131 }
10132 for (PHINode *PHI : PHIsToCompute) {
10133 Constant *&NextPHI = NextIterVals[PHI];
10134 if (NextPHI) continue; // Already computed!
10135
10136 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10137 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10138 }
10139 CurrentIterVals.swap(NextIterVals);
10140 }
10141
10142 // Too many iterations were needed to evaluate.
10143 return getCouldNotCompute();
10144}
10145
10146const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10148 ValuesAtScopes[V];
10149 // Check to see if we've folded this expression at this loop before.
10150 for (auto &LS : Values)
10151 if (LS.first == L)
10152 return LS.second ? LS.second : V;
10153
10154 Values.emplace_back(L, nullptr);
10155
10156 // Otherwise compute it.
10157 const SCEV *C = computeSCEVAtScope(V, L);
10158 for (auto &LS : reverse(ValuesAtScopes[V]))
10159 if (LS.first == L) {
10160 LS.second = C;
10161 if (!isa<SCEVConstant>(C))
10162 ValuesAtScopesUsers[C].push_back({L, V});
10163 break;
10164 }
10165 return C;
10166}
10167
10168/// This builds up a Constant using the ConstantExpr interface. That way, we
10169/// will return Constants for objects which aren't represented by a
10170/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10171/// Returns NULL if the SCEV isn't representable as a Constant.
10173 switch (V->getSCEVType()) {
10174 case scCouldNotCompute:
10175 case scAddRecExpr:
10176 case scVScale:
10177 return nullptr;
10178 case scConstant:
10179 return cast<SCEVConstant>(V)->getValue();
10180 case scUnknown:
10181 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
10182 case scPtrToAddr: {
10184 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10185 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10186
10187 return nullptr;
10188 }
10189 case scPtrToInt: {
10191 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10192 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10193
10194 return nullptr;
10195 }
10196 case scTruncate: {
10198 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10199 return ConstantExpr::getTrunc(CastOp, ST->getType());
10200 return nullptr;
10201 }
10202 case scAddExpr: {
10203 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10204 Constant *C = nullptr;
10205 for (const SCEV *Op : SA->operands()) {
10207 if (!OpC)
10208 return nullptr;
10209 if (!C) {
10210 C = OpC;
10211 continue;
10212 }
10213 assert(!C->getType()->isPointerTy() &&
10214 "Can only have one pointer, and it must be last");
10215 if (OpC->getType()->isPointerTy()) {
10216 // The offsets have been converted to bytes. We can add bytes using
10217 // an i8 GEP.
10218 C = ConstantExpr::getPtrAdd(OpC, C);
10219 } else {
10220 C = ConstantExpr::getAdd(C, OpC);
10221 }
10222 }
10223 return C;
10224 }
10225 case scMulExpr:
10226 case scSignExtend:
10227 case scZeroExtend:
10228 case scUDivExpr:
10229 case scSMaxExpr:
10230 case scUMaxExpr:
10231 case scSMinExpr:
10232 case scUMinExpr:
10234 return nullptr;
10235 }
10236 llvm_unreachable("Unknown SCEV kind!");
10237}
10238
10239const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10240 SmallVectorImpl<SCEVUse> &NewOps) {
10241 switch (S->getSCEVType()) {
10242 case scTruncate:
10243 case scZeroExtend:
10244 case scSignExtend:
10245 case scPtrToAddr:
10246 case scPtrToInt:
10247 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10248 case scAddRecExpr: {
10249 auto *AddRec = cast<SCEVAddRecExpr>(S);
10250 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10251 }
10252 case scAddExpr:
10253 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10254 case scMulExpr:
10255 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10256 case scUDivExpr:
10257 return getUDivExpr(NewOps[0], NewOps[1]);
10258 case scUMaxExpr:
10259 case scSMaxExpr:
10260 case scUMinExpr:
10261 case scSMinExpr:
10262 return getMinMaxExpr(S->getSCEVType(), NewOps);
10264 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10265 case scConstant:
10266 case scVScale:
10267 case scUnknown:
10268 return S;
10269 case scCouldNotCompute:
10270 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10271 }
10272 llvm_unreachable("Unknown SCEV kind!");
10273}
10274
10275const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10276 switch (V->getSCEVType()) {
10277 case scConstant:
10278 case scVScale:
10279 return V;
10280 case scAddRecExpr: {
10281 // If this is a loop recurrence for a loop that does not contain L, then we
10282 // are dealing with the final value computed by the loop.
10283 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10284 // First, attempt to evaluate each operand.
10285 // Avoid performing the look-up in the common case where the specified
10286 // expression has no loop-variant portions.
10287 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10288 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10289 if (OpAtScope == AddRec->getOperand(i))
10290 continue;
10291
10292 // Okay, at least one of these operands is loop variant but might be
10293 // foldable. Build a new instance of the folded commutative expression.
10295 NewOps.reserve(AddRec->getNumOperands());
10296 append_range(NewOps, AddRec->operands().take_front(i));
10297 NewOps.push_back(OpAtScope);
10298 for (++i; i != e; ++i)
10299 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10300
10301 const SCEV *FoldedRec = getAddRecExpr(
10302 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10303 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10304 // The addrec may be folded to a nonrecurrence, for example, if the
10305 // induction variable is multiplied by zero after constant folding. Go
10306 // ahead and return the folded value.
10307 if (!AddRec)
10308 return FoldedRec;
10309 break;
10310 }
10311
10312 // If the scope is outside the addrec's loop, evaluate it by using the
10313 // loop exit value of the addrec.
10314 if (!AddRec->getLoop()->contains(L)) {
10315 // To evaluate this recurrence, we need to know how many times the AddRec
10316 // loop iterates. Compute this now.
10317 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10318 if (BackedgeTakenCount == getCouldNotCompute())
10319 return AddRec;
10320
10321 // Then, evaluate the AddRec.
10322 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10323 }
10324
10325 return AddRec;
10326 }
10327 case scTruncate:
10328 case scZeroExtend:
10329 case scSignExtend:
10330 case scPtrToAddr:
10331 case scPtrToInt:
10332 case scAddExpr:
10333 case scMulExpr:
10334 case scUDivExpr:
10335 case scUMaxExpr:
10336 case scSMaxExpr:
10337 case scUMinExpr:
10338 case scSMinExpr:
10339 case scSequentialUMinExpr: {
10340 ArrayRef<SCEVUse> Ops = V->operands();
10341 // Avoid performing the look-up in the common case where the specified
10342 // expression has no loop-variant portions.
10343 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10344 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10345 if (OpAtScope != Ops[i].getPointer()) {
10346 // Okay, at least one of these operands is loop variant but might be
10347 // foldable. Build a new instance of the folded commutative expression.
10349 NewOps.reserve(Ops.size());
10350 append_range(NewOps, Ops.take_front(i));
10351 NewOps.push_back(OpAtScope);
10352
10353 for (++i; i != e; ++i) {
10354 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10355 NewOps.push_back(OpAtScope);
10356 }
10357
10358 return getWithOperands(V, NewOps);
10359 }
10360 }
10361 // If we got here, all operands are loop invariant.
10362 return V;
10363 }
10364 case scUnknown: {
10365 // If this instruction is evolved from a constant-evolving PHI, compute the
10366 // exit value from the loop without using SCEVs.
10367 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10369 if (!I)
10370 return V; // This is some other type of SCEVUnknown, just return it.
10371
10372 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10373 const Loop *CurrLoop = this->LI[I->getParent()];
10374 // Looking for loop exit value.
10375 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10376 PN->getParent() == CurrLoop->getHeader()) {
10377 // Okay, there is no closed form solution for the PHI node. Check
10378 // to see if the loop that contains it has a known backedge-taken
10379 // count. If so, we may be able to force computation of the exit
10380 // value.
10381 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10382 // This trivial case can show up in some degenerate cases where
10383 // the incoming IR has not yet been fully simplified.
10384 if (BackedgeTakenCount->isZero()) {
10385 Value *InitValue = nullptr;
10386 bool MultipleInitValues = false;
10387 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10388 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10389 if (!InitValue)
10390 InitValue = PN->getIncomingValue(i);
10391 else if (InitValue != PN->getIncomingValue(i)) {
10392 MultipleInitValues = true;
10393 break;
10394 }
10395 }
10396 }
10397 if (!MultipleInitValues && InitValue)
10398 return getSCEV(InitValue);
10399 }
10400 // Do we have a loop invariant value flowing around the backedge
10401 // for a loop which must execute the backedge?
10402 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10403 isKnownNonZero(BackedgeTakenCount) &&
10404 PN->getNumIncomingValues() == 2) {
10405
10406 unsigned InLoopPred =
10407 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10408 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10409 if (CurrLoop->isLoopInvariant(BackedgeVal))
10410 return getSCEV(BackedgeVal);
10411 }
10412 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10413 // Okay, we know how many times the containing loop executes. If
10414 // this is a constant evolving PHI node, get the final value at
10415 // the specified iteration number.
10416 Constant *RV =
10417 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10418 if (RV)
10419 return getSCEV(RV);
10420 }
10421 }
10422 }
10423
10424 // Okay, this is an expression that we cannot symbolically evaluate
10425 // into a SCEV. Check to see if it's possible to symbolically evaluate
10426 // the arguments into constants, and if so, try to constant propagate the
10427 // result. This is particularly useful for computing loop exit values.
10428 if (!CanConstantFold(I))
10429 return V; // This is some other type of SCEVUnknown, just return it.
10430
10431 SmallVector<Constant *, 4> Operands;
10432 Operands.reserve(I->getNumOperands());
10433 bool MadeImprovement = false;
10434 for (Value *Op : I->operands()) {
10435 if (Constant *C = dyn_cast<Constant>(Op)) {
10436 Operands.push_back(C);
10437 continue;
10438 }
10439
10440 // If any of the operands is non-constant and if they are
10441 // non-integer and non-pointer, don't even try to analyze them
10442 // with scev techniques.
10443 if (!isSCEVable(Op->getType()))
10444 return V;
10445
10446 const SCEV *OrigV = getSCEV(Op);
10447 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10448 MadeImprovement |= OrigV != OpV;
10449
10451 if (!C)
10452 return V;
10453 assert(C->getType() == Op->getType() && "Type mismatch");
10454 Operands.push_back(C);
10455 }
10456
10457 // Check to see if getSCEVAtScope actually made an improvement.
10458 if (!MadeImprovement)
10459 return V; // This is some other type of SCEVUnknown, just return it.
10460
10461 Constant *C = nullptr;
10462 const DataLayout &DL = getDataLayout();
10463 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10464 /*AllowNonDeterministic=*/false);
10465 if (!C)
10466 return V;
10467 return getSCEV(C);
10468 }
10469 case scCouldNotCompute:
10470 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10471 }
10472 llvm_unreachable("Unknown SCEV type!");
10473}
10474
10476 return getSCEVAtScope(getSCEV(V), L);
10477}
10478
10479const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10481 return stripInjectiveFunctions(ZExt->getOperand());
10483 return stripInjectiveFunctions(SExt->getOperand());
10484 return S;
10485}
10486
10487/// Finds the minimum unsigned root of the following equation:
10488///
10489/// A * X = B (mod N)
10490///
10491/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10492/// A and B isn't important.
10493///
10494/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10495static const SCEV *
10498 ScalarEvolution &SE, const Loop *L) {
10499 uint32_t BW = A.getBitWidth();
10500 assert(BW == SE.getTypeSizeInBits(B->getType()));
10501 assert(A != 0 && "A must be non-zero.");
10502
10503 // 1. D = gcd(A, N)
10504 //
10505 // The gcd of A and N may have only one prime factor: 2. The number of
10506 // trailing zeros in A is its multiplicity
10507 uint32_t Mult2 = A.countr_zero();
10508 // D = 2^Mult2
10509
10510 // 2. Check if B is divisible by D.
10511 //
10512 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10513 // is not less than multiplicity of this prime factor for D.
10514 unsigned MinTZ = SE.getMinTrailingZeros(B);
10515 // Try again with the terminator of the loop predecessor for context-specific
10516 // result, if MinTZ s too small.
10517 if (MinTZ < Mult2 && L->getLoopPredecessor())
10518 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10519 if (MinTZ < Mult2) {
10520 // Check if we can prove there's no remainder using URem.
10521 const SCEV *URem =
10522 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10523 const SCEV *Zero = SE.getZero(B->getType());
10524 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10525 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10526 if (!Predicates)
10527 return SE.getCouldNotCompute();
10528
10529 // Avoid adding a predicate that is known to be false.
10530 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10531 return SE.getCouldNotCompute();
10532 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10533 }
10534 }
10535
10536 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10537 // modulo (N / D).
10538 //
10539 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10540 // (N / D) in general. The inverse itself always fits into BW bits, though,
10541 // so we immediately truncate it.
10542 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10543 APInt I = AD.multiplicativeInverse().zext(BW);
10544
10545 // 4. Compute the minimum unsigned root of the equation:
10546 // I * (B / D) mod (N / D)
10547 // To simplify the computation, we factor out the divide by D:
10548 // (I * B mod N) / D
10549 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10550 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10551}
10552
10553/// For a given quadratic addrec, generate coefficients of the corresponding
10554/// quadratic equation, multiplied by a common value to ensure that they are
10555/// integers.
10556/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10557/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10558/// were multiplied by, and BitWidth is the bit width of the original addrec
10559/// coefficients.
10560/// This function returns std::nullopt if the addrec coefficients are not
10561/// compile- time constants.
10562static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10564 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10565 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10566 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10567 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10568 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10569 << *AddRec << '\n');
10570
10571 // We currently can only solve this if the coefficients are constants.
10572 if (!LC || !MC || !NC) {
10573 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10574 return std::nullopt;
10575 }
10576
10577 APInt L = LC->getAPInt();
10578 APInt M = MC->getAPInt();
10579 APInt N = NC->getAPInt();
10580 assert(!N.isZero() && "This is not a quadratic addrec");
10581
10582 unsigned BitWidth = LC->getAPInt().getBitWidth();
10583 unsigned NewWidth = BitWidth + 1;
10584 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10585 << BitWidth << '\n');
10586 // The sign-extension (as opposed to a zero-extension) here matches the
10587 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10588 N = N.sext(NewWidth);
10589 M = M.sext(NewWidth);
10590 L = L.sext(NewWidth);
10591
10592 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10593 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10594 // L+M, L+2M+N, L+3M+3N, ...
10595 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10596 //
10597 // The equation Acc = 0 is then
10598 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10599 // In a quadratic form it becomes:
10600 // N n^2 + (2M-N) n + 2L = 0.
10601
10602 APInt A = N;
10603 APInt B = 2 * M - A;
10604 APInt C = 2 * L;
10605 APInt T = APInt(NewWidth, 2);
10606 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10607 << "x + " << C << ", coeff bw: " << NewWidth
10608 << ", multiplied by " << T << '\n');
10609 return std::make_tuple(A, B, C, T, BitWidth);
10610}
10611
10612/// Helper function to compare optional APInts:
10613/// (a) if X and Y both exist, return min(X, Y),
10614/// (b) if neither X nor Y exist, return std::nullopt,
10615/// (c) if exactly one of X and Y exists, return that value.
10616static std::optional<APInt> MinOptional(std::optional<APInt> X,
10617 std::optional<APInt> Y) {
10618 if (X && Y) {
10619 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10620 APInt XW = X->sext(W);
10621 APInt YW = Y->sext(W);
10622 return XW.slt(YW) ? *X : *Y;
10623 }
10624 if (!X && !Y)
10625 return std::nullopt;
10626 return X ? *X : *Y;
10627}
10628
10629/// Helper function to truncate an optional APInt to a given BitWidth.
10630/// When solving addrec-related equations, it is preferable to return a value
10631/// that has the same bit width as the original addrec's coefficients. If the
10632/// solution fits in the original bit width, truncate it (except for i1).
10633/// Returning a value of a different bit width may inhibit some optimizations.
10634///
10635/// In general, a solution to a quadratic equation generated from an addrec
10636/// may require BW+1 bits, where BW is the bit width of the addrec's
10637/// coefficients. The reason is that the coefficients of the quadratic
10638/// equation are BW+1 bits wide (to avoid truncation when converting from
10639/// the addrec to the equation).
10640static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10641 unsigned BitWidth) {
10642 if (!X)
10643 return std::nullopt;
10644 unsigned W = X->getBitWidth();
10646 return X->trunc(BitWidth);
10647 return X;
10648}
10649
10650/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10651/// iterations. The values L, M, N are assumed to be signed, and they
10652/// should all have the same bit widths.
10653/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10654/// where BW is the bit width of the addrec's coefficients.
10655/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10656/// returned as such, otherwise the bit width of the returned value may
10657/// be greater than BW.
10658///
10659/// This function returns std::nullopt if
10660/// (a) the addrec coefficients are not constant, or
10661/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10662/// like x^2 = 5, no integer solutions exist, in other cases an integer
10663/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10664static std::optional<APInt>
10666 APInt A, B, C, M;
10667 unsigned BitWidth;
10668 auto T = GetQuadraticEquation(AddRec);
10669 if (!T)
10670 return std::nullopt;
10671
10672 std::tie(A, B, C, M, BitWidth) = *T;
10673 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10674 std::optional<APInt> X =
10676 if (!X)
10677 return std::nullopt;
10678
10679 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10680 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10681 if (!V->isZero())
10682 return std::nullopt;
10683
10684 return TruncIfPossible(X, BitWidth);
10685}
10686
10687/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10688/// iterations. The values M, N are assumed to be signed, and they
10689/// should all have the same bit widths.
10690/// Find the least n such that c(n) does not belong to the given range,
10691/// while c(n-1) does.
10692///
10693/// This function returns std::nullopt if
10694/// (a) the addrec coefficients are not constant, or
10695/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10696/// bounds of the range.
10697static std::optional<APInt>
10699 const ConstantRange &Range, ScalarEvolution &SE) {
10700 assert(AddRec->getOperand(0)->isZero() &&
10701 "Starting value of addrec should be 0");
10702 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10703 << Range << ", addrec " << *AddRec << '\n');
10704 // This case is handled in getNumIterationsInRange. Here we can assume that
10705 // we start in the range.
10706 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10707 "Addrec's initial value should be in range");
10708
10709 APInt A, B, C, M;
10710 unsigned BitWidth;
10711 auto T = GetQuadraticEquation(AddRec);
10712 if (!T)
10713 return std::nullopt;
10714
10715 // Be careful about the return value: there can be two reasons for not
10716 // returning an actual number. First, if no solutions to the equations
10717 // were found, and second, if the solutions don't leave the given range.
10718 // The first case means that the actual solution is "unknown", the second
10719 // means that it's known, but not valid. If the solution is unknown, we
10720 // cannot make any conclusions.
10721 // Return a pair: the optional solution and a flag indicating if the
10722 // solution was found.
10723 auto SolveForBoundary =
10724 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10725 // Solve for signed overflow and unsigned overflow, pick the lower
10726 // solution.
10727 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10728 << Bound << " (before multiplying by " << M << ")\n");
10729 Bound *= M; // The quadratic equation multiplier.
10730
10731 std::optional<APInt> SO;
10732 if (BitWidth > 1) {
10733 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10734 "signed overflow\n");
10736 }
10737 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10738 "unsigned overflow\n");
10739 std::optional<APInt> UO =
10741
10742 auto LeavesRange = [&] (const APInt &X) {
10743 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10744 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10745 if (Range.contains(V0->getValue()))
10746 return false;
10747 // X should be at least 1, so X-1 is non-negative.
10748 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10749 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10750 if (Range.contains(V1->getValue()))
10751 return true;
10752 return false;
10753 };
10754
10755 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10756 // can be a solution, but the function failed to find it. We cannot treat it
10757 // as "no solution".
10758 if (!SO || !UO)
10759 return {std::nullopt, false};
10760
10761 // Check the smaller value first to see if it leaves the range.
10762 // At this point, both SO and UO must have values.
10763 std::optional<APInt> Min = MinOptional(SO, UO);
10764 if (LeavesRange(*Min))
10765 return { Min, true };
10766 std::optional<APInt> Max = Min == SO ? UO : SO;
10767 if (LeavesRange(*Max))
10768 return { Max, true };
10769
10770 // Solutions were found, but were eliminated, hence the "true".
10771 return {std::nullopt, true};
10772 };
10773
10774 std::tie(A, B, C, M, BitWidth) = *T;
10775 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10776 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10777 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10778 auto SL = SolveForBoundary(Lower);
10779 auto SU = SolveForBoundary(Upper);
10780 // If any of the solutions was unknown, no meaninigful conclusions can
10781 // be made.
10782 if (!SL.second || !SU.second)
10783 return std::nullopt;
10784
10785 // Claim: The correct solution is not some value between Min and Max.
10786 //
10787 // Justification: Assuming that Min and Max are different values, one of
10788 // them is when the first signed overflow happens, the other is when the
10789 // first unsigned overflow happens. Crossing the range boundary is only
10790 // possible via an overflow (treating 0 as a special case of it, modeling
10791 // an overflow as crossing k*2^W for some k).
10792 //
10793 // The interesting case here is when Min was eliminated as an invalid
10794 // solution, but Max was not. The argument is that if there was another
10795 // overflow between Min and Max, it would also have been eliminated if
10796 // it was considered.
10797 //
10798 // For a given boundary, it is possible to have two overflows of the same
10799 // type (signed/unsigned) without having the other type in between: this
10800 // can happen when the vertex of the parabola is between the iterations
10801 // corresponding to the overflows. This is only possible when the two
10802 // overflows cross k*2^W for the same k. In such case, if the second one
10803 // left the range (and was the first one to do so), the first overflow
10804 // would have to enter the range, which would mean that either we had left
10805 // the range before or that we started outside of it. Both of these cases
10806 // are contradictions.
10807 //
10808 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10809 // solution is not some value between the Max for this boundary and the
10810 // Min of the other boundary.
10811 //
10812 // Justification: Assume that we had such Max_A and Min_B corresponding
10813 // to range boundaries A and B and such that Max_A < Min_B. If there was
10814 // a solution between Max_A and Min_B, it would have to be caused by an
10815 // overflow corresponding to either A or B. It cannot correspond to B,
10816 // since Min_B is the first occurrence of such an overflow. If it
10817 // corresponded to A, it would have to be either a signed or an unsigned
10818 // overflow that is larger than both eliminated overflows for A. But
10819 // between the eliminated overflows and this overflow, the values would
10820 // cover the entire value space, thus crossing the other boundary, which
10821 // is a contradiction.
10822
10823 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10824}
10825
10826ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10827 const Loop *L,
10828 bool ControlsOnlyExit,
10829 bool AllowPredicates) {
10830
10831 // This is only used for loops with a "x != y" exit test. The exit condition
10832 // is now expressed as a single expression, V = x-y. So the exit test is
10833 // effectively V != 0. We know and take advantage of the fact that this
10834 // expression only being used in a comparison by zero context.
10835
10837 // If the value is a constant
10838 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10839 // If the value is already zero, the branch will execute zero times.
10840 if (C->getValue()->isZero()) return C;
10841 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10842 }
10843
10844 const SCEVAddRecExpr *AddRec =
10845 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10846
10847 if (!AddRec && AllowPredicates)
10848 // Try to make this an AddRec using runtime tests, in the first X
10849 // iterations of this loop, where X is the SCEV expression found by the
10850 // algorithm below.
10851 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10852
10853 if (!AddRec || AddRec->getLoop() != L)
10854 return getCouldNotCompute();
10855
10856 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10857 // the quadratic equation to solve it.
10858 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10859 // We can only use this value if the chrec ends up with an exact zero
10860 // value at this index. When solving for "X*X != 5", for example, we
10861 // should not accept a root of 2.
10862 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10863 const auto *R = cast<SCEVConstant>(getConstant(*S));
10864 return ExitLimit(R, R, R, false, Predicates);
10865 }
10866 return getCouldNotCompute();
10867 }
10868
10869 // Otherwise we can only handle this if it is affine.
10870 if (!AddRec->isAffine())
10871 return getCouldNotCompute();
10872
10873 // If this is an affine expression, the execution count of this branch is
10874 // the minimum unsigned root of the following equation:
10875 //
10876 // Start + Step*N = 0 (mod 2^BW)
10877 //
10878 // equivalent to:
10879 //
10880 // Step*N = -Start (mod 2^BW)
10881 //
10882 // where BW is the common bit width of Start and Step.
10883
10884 // Get the initial value for the loop.
10885 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10886 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10887
10888 if (!isLoopInvariant(Step, L))
10889 return getCouldNotCompute();
10890
10891 LoopGuards Guards = LoopGuards::collect(L, *this);
10892 // Specialize step for this loop so we get context sensitive facts below.
10893 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10894
10895 // For positive steps (counting up until unsigned overflow):
10896 // N = -Start/Step (as unsigned)
10897 // For negative steps (counting down to zero):
10898 // N = Start/-Step
10899 // First compute the unsigned distance from zero in the direction of Step.
10900 bool CountDown = isKnownNegative(StepWLG);
10901 if (!CountDown && !isKnownNonNegative(StepWLG))
10902 return getCouldNotCompute();
10903
10904 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10905 // Handle unitary steps, which cannot wraparound.
10906 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10907 // N = Distance (as unsigned)
10908
10909 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10910 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10911 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10912
10913 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10914 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10915 // case, and see if we can improve the bound.
10916 //
10917 // Explicitly handling this here is necessary because getUnsignedRange
10918 // isn't context-sensitive; it doesn't know that we only care about the
10919 // range inside the loop.
10920 const SCEV *Zero = getZero(Distance->getType());
10921 const SCEV *One = getOne(Distance->getType());
10922 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10923 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10924 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10925 // as "unsigned_max(Distance + 1) - 1".
10926 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10927 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10928 }
10929 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10930 Predicates);
10931 }
10932
10933 // If the condition controls loop exit (the loop exits only if the expression
10934 // is true) and the addition is no-wrap we can use unsigned divide to
10935 // compute the backedge count. In this case, the step may not divide the
10936 // distance, but we don't care because if the condition is "missed" the loop
10937 // will have undefined behavior due to wrapping.
10938 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10939 loopHasNoAbnormalExits(AddRec->getLoop())) {
10940
10941 // If the stride is zero and the start is non-zero, the loop must be
10942 // infinite. In C++, most loops are finite by assumption, in which case the
10943 // step being zero implies UB must execute if the loop is entered.
10944 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10945 !isKnownNonZero(StepWLG))
10946 return getCouldNotCompute();
10947
10948 const SCEV *Exact =
10949 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10950 const SCEV *ConstantMax = getCouldNotCompute();
10951 if (Exact != getCouldNotCompute()) {
10952 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10953 ConstantMax =
10955 }
10956 const SCEV *SymbolicMax =
10957 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10958 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10959 }
10960
10961 // Solve the general equation.
10962 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10963 if (!StepC || StepC->getValue()->isZero())
10964 return getCouldNotCompute();
10965 const SCEV *E = SolveLinEquationWithOverflow(
10966 StepC->getAPInt(), getNegativeSCEV(Start),
10967 AllowPredicates ? &Predicates : nullptr, *this, L);
10968
10969 const SCEV *M = E;
10970 if (E != getCouldNotCompute()) {
10971 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10972 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10973 }
10974 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10975 return ExitLimit(E, M, S, false, Predicates);
10976}
10977
10978ScalarEvolution::ExitLimit
10979ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10980 // Loops that look like: while (X == 0) are very strange indeed. We don't
10981 // handle them yet except for the trivial case. This could be expanded in the
10982 // future as needed.
10983
10984 // If the value is a constant, check to see if it is known to be non-zero
10985 // already. If so, the backedge will execute zero times.
10986 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10987 if (!C->getValue()->isZero())
10988 return getZero(C->getType());
10989 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10990 }
10991
10992 // We could implement others, but I really doubt anyone writes loops like
10993 // this, and if they did, they would already be constant folded.
10994 return getCouldNotCompute();
10995}
10996
10997std::pair<const BasicBlock *, const BasicBlock *>
10998ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10999 const {
11000 // If the block has a unique predecessor, then there is no path from the
11001 // predecessor to the block that does not go through the direct edge
11002 // from the predecessor to the block.
11003 if (const BasicBlock *Pred = BB->getSinglePredecessor())
11004 return {Pred, BB};
11005
11006 // A loop's header is defined to be a block that dominates the loop.
11007 // If the header has a unique predecessor outside the loop, it must be
11008 // a block that has exactly one successor that can reach the loop.
11009 if (const Loop *L = LI.getLoopFor(BB))
11010 return {L->getLoopPredecessor(), L->getHeader()};
11011
11012 return {nullptr, BB};
11013}
11014
11015/// SCEV structural equivalence is usually sufficient for testing whether two
11016/// expressions are equal, however for the purposes of looking for a condition
11017/// guarding a loop, it can be useful to be a little more general, since a
11018/// front-end may have replicated the controlling expression.
11019static bool HasSameValue(const SCEV *A, const SCEV *B) {
11020 // Quick check to see if they are the same SCEV.
11021 if (A == B) return true;
11022
11023 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11024 // Not all instructions that are "identical" compute the same value. For
11025 // instance, two distinct alloca instructions allocating the same type are
11026 // identical and do not read memory; but compute distinct values.
11027 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11028 };
11029
11030 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11031 // two different instructions with the same value. Check for this case.
11032 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11033 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11034 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11035 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11036 if (ComputesEqualValues(AI, BI))
11037 return true;
11038
11039 // Otherwise assume they may have a different value.
11040 return false;
11041}
11042
11043static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11044 const SCEV *Op0, *Op1;
11045 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11046 return false;
11047 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11048 LHS = Op1;
11049 return true;
11050 }
11051 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11052 LHS = Op0;
11053 return true;
11054 }
11055 return false;
11056}
11057
11059 SCEVUse &RHS, unsigned Depth) {
11060 bool Changed = false;
11061 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11062 // '0 != 0'.
11063 auto TrivialCase = [&](bool TriviallyTrue) {
11065 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11066 return true;
11067 };
11068 // If we hit the max recursion limit bail out.
11069 if (Depth >= 3)
11070 return false;
11071
11072 const SCEV *NewLHS, *NewRHS;
11073 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11074 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11075 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11076 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11077
11078 // (X * vscale) pred (Y * vscale) ==> X pred Y
11079 // when both multiples are NSW.
11080 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11081 // when both multiples are NUW.
11082 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11083 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11084 !ICmpInst::isSigned(Pred))) {
11085 LHS = NewLHS;
11086 RHS = NewRHS;
11087 Changed = true;
11088 }
11089 }
11090
11091 // Canonicalize a constant to the right side.
11092 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11093 // Check for both operands constant.
11094 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11095 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11096 return TrivialCase(false);
11097 return TrivialCase(true);
11098 }
11099 // Otherwise swap the operands to put the constant on the right.
11100 std::swap(LHS, RHS);
11102 Changed = true;
11103 }
11104
11105 // If we're comparing an addrec with a value which is loop-invariant in the
11106 // addrec's loop, put the addrec on the left. Also make a dominance check,
11107 // as both operands could be addrecs loop-invariant in each other's loop.
11108 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11109 const Loop *L = AR->getLoop();
11110 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11111 std::swap(LHS, RHS);
11113 Changed = true;
11114 }
11115 }
11116
11117 // If there's a constant operand, canonicalize comparisons with boundary
11118 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11119 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11120 const APInt &RA = RC->getAPInt();
11121
11122 bool SimplifiedByConstantRange = false;
11123
11124 if (!ICmpInst::isEquality(Pred)) {
11126 if (ExactCR.isFullSet())
11127 return TrivialCase(true);
11128 if (ExactCR.isEmptySet())
11129 return TrivialCase(false);
11130
11131 APInt NewRHS;
11132 CmpInst::Predicate NewPred;
11133 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11134 ICmpInst::isEquality(NewPred)) {
11135 // We were able to convert an inequality to an equality.
11136 Pred = NewPred;
11137 RHS = getConstant(NewRHS);
11138 Changed = SimplifiedByConstantRange = true;
11139 }
11140 }
11141
11142 if (!SimplifiedByConstantRange) {
11143 switch (Pred) {
11144 default:
11145 break;
11146 case ICmpInst::ICMP_EQ:
11147 case ICmpInst::ICMP_NE:
11148 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11149 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11150 Changed = true;
11151 break;
11152
11153 // The "Should have been caught earlier!" messages refer to the fact
11154 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11155 // should have fired on the corresponding cases, and canonicalized the
11156 // check to trivial case.
11157
11158 case ICmpInst::ICMP_UGE:
11159 assert(!RA.isMinValue() && "Should have been caught earlier!");
11160 Pred = ICmpInst::ICMP_UGT;
11161 RHS = getConstant(RA - 1);
11162 Changed = true;
11163 break;
11164 case ICmpInst::ICMP_ULE:
11165 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11166 Pred = ICmpInst::ICMP_ULT;
11167 RHS = getConstant(RA + 1);
11168 Changed = true;
11169 break;
11170 case ICmpInst::ICMP_SGE:
11171 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11172 Pred = ICmpInst::ICMP_SGT;
11173 RHS = getConstant(RA - 1);
11174 Changed = true;
11175 break;
11176 case ICmpInst::ICMP_SLE:
11177 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11178 Pred = ICmpInst::ICMP_SLT;
11179 RHS = getConstant(RA + 1);
11180 Changed = true;
11181 break;
11182 }
11183 }
11184 }
11185
11186 // Check for obvious equality.
11187 if (HasSameValue(LHS, RHS)) {
11188 if (ICmpInst::isTrueWhenEqual(Pred))
11189 return TrivialCase(true);
11191 return TrivialCase(false);
11192 }
11193
11194 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11195 // adding or subtracting 1 from one of the operands.
11196 switch (Pred) {
11197 case ICmpInst::ICMP_SLE:
11198 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11199 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11201 Pred = ICmpInst::ICMP_SLT;
11202 Changed = true;
11203 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11204 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11206 Pred = ICmpInst::ICMP_SLT;
11207 Changed = true;
11208 }
11209 break;
11210 case ICmpInst::ICMP_SGE:
11211 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11212 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11214 Pred = ICmpInst::ICMP_SGT;
11215 Changed = true;
11216 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11217 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11219 Pred = ICmpInst::ICMP_SGT;
11220 Changed = true;
11221 }
11222 break;
11223 case ICmpInst::ICMP_ULE:
11224 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11225 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11227 Pred = ICmpInst::ICMP_ULT;
11228 Changed = true;
11229 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11230 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11231 Pred = ICmpInst::ICMP_ULT;
11232 Changed = true;
11233 }
11234 break;
11235 case ICmpInst::ICMP_UGE:
11236 // If RHS is an op we can fold the -1, try that first.
11237 // Otherwise prefer LHS to preserve the nuw flag.
11238 if ((isa<SCEVConstant>(RHS) ||
11240 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11241 !getUnsignedRangeMin(RHS).isMinValue()) {
11242 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11243 Pred = ICmpInst::ICMP_UGT;
11244 Changed = true;
11245 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11246 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11248 Pred = ICmpInst::ICMP_UGT;
11249 Changed = true;
11250 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11251 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11252 Pred = ICmpInst::ICMP_UGT;
11253 Changed = true;
11254 }
11255 break;
11256 default:
11257 break;
11258 }
11259
11260 // TODO: More simplifications are possible here.
11261
11262 // Recursively simplify until we either hit a recursion limit or nothing
11263 // changes.
11264 if (Changed)
11265 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11266
11267 return Changed;
11268}
11269
11271 return getSignedRangeMax(S).isNegative();
11272}
11273
11277
11279 return !getSignedRangeMin(S).isNegative();
11280}
11281
11285
11287 // Query push down for cases where the unsigned range is
11288 // less than sufficient.
11289 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11290 return isKnownNonZero(SExt->getOperand(0));
11291 return getUnsignedRangeMin(S) != 0;
11292}
11293
11295 bool OrNegative) {
11296 auto NonRecursive = [OrNegative](const SCEV *S) {
11297 if (auto *C = dyn_cast<SCEVConstant>(S))
11298 return C->getAPInt().isPowerOf2() ||
11299 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11300
11301 // vscale is a power-of-two.
11302 return isa<SCEVVScale>(S);
11303 };
11304
11305 if (NonRecursive(S))
11306 return true;
11307
11308 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11309 if (!Mul)
11310 return false;
11311 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11312}
11313
11315 const SCEV *S, uint64_t M,
11317 if (M == 0)
11318 return false;
11319 if (M == 1)
11320 return true;
11321
11322 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11323 // starts with a multiple of M and at every iteration step S only adds
11324 // multiples of M.
11325 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11326 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11327 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11328
11329 // For a constant, check that "S % M == 0".
11330 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11331 APInt C = Cst->getAPInt();
11332 return C.urem(M) == 0;
11333 }
11334
11335 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11336
11337 // Basic tests have failed.
11338 // Check "S % M == 0" at compile time and record runtime Assumptions.
11339 auto *STy = dyn_cast<IntegerType>(S->getType());
11340 const SCEV *SmodM =
11341 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11342 const SCEV *Zero = getZero(STy);
11343
11344 // Check whether "S % M == 0" is known at compile time.
11345 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11346 return true;
11347
11348 // Check whether "S % M != 0" is known at compile time.
11349 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11350 return false;
11351
11353
11354 // Detect redundant predicates.
11355 for (auto *A : Assumptions)
11356 if (A->implies(P, *this))
11357 return true;
11358
11359 // Only record non-redundant predicates.
11360 Assumptions.push_back(P);
11361 return true;
11362}
11363
  // NOTE(review): tail of a sign-agreement helper whose signature line is not
  // visible in this view; presumably the elided second disjunct checks that
  // both operands are known negative — confirm against the full file.
  return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
}
11368
11369std::pair<const SCEV *, const SCEV *>
11371 // Compute SCEV on entry of loop L.
11372 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11373 if (Start == getCouldNotCompute())
11374 return { Start, Start };
11375 // Compute post increment SCEV for loop L.
11376 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11377 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11378 return { Start, PostInc };
11379}
11380
11382 SCEVUse RHS) {
11383 // First collect all loops.
11385 getUsedLoops(LHS, LoopsUsed);
11386 getUsedLoops(RHS, LoopsUsed);
11387
11388 if (LoopsUsed.empty())
11389 return false;
11390
11391 // Domination relationship must be a linear order on collected loops.
11392#ifndef NDEBUG
11393 for (const auto *L1 : LoopsUsed)
11394 for (const auto *L2 : LoopsUsed)
11395 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11396 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11397 "Domination relationship is not a linear order");
11398#endif
11399
11400 const Loop *MDL =
11401 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11402 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11403 });
11404
11405 // Get init and post increment value for LHS.
11406 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11407 // if LHS contains unknown non-invariant SCEV then bail out.
11408 if (SplitLHS.first == getCouldNotCompute())
11409 return false;
11410 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11411 // Get init and post increment value for RHS.
11412 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11413 // if RHS contains unknown non-invariant SCEV then bail out.
11414 if (SplitRHS.first == getCouldNotCompute())
11415 return false;
11416 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11417 // It is possible that init SCEV contains an invariant load but it does
11418 // not dominate MDL and is not available at MDL loop entry, so we should
11419 // check it here.
11420 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11421 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11422 return false;
11423
11424 // It seems backedge guard check is faster than entry one so in some cases
11425 // it can speed up whole estimation by short circuit
11426 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11427 SplitRHS.second) &&
11428 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11429}
11430
11432 SCEVUse RHS) {
11433 // Canonicalize the inputs first.
11434 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11435
11436 if (isKnownViaInduction(Pred, LHS, RHS))
11437 return true;
11438
11439 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11440 return true;
11441
11442 // Otherwise see what can be done with some simple reasoning.
11443 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11444}
11445
11447 const SCEV *LHS,
11448 const SCEV *RHS) {
11449 if (isKnownPredicate(Pred, LHS, RHS))
11450 return true;
11452 return false;
11453 return std::nullopt;
11454}
11455
11457 const SCEV *RHS,
11458 const Instruction *CtxI) {
11459 // TODO: Analyze guards and assumes from Context's block.
11460 return isKnownPredicate(Pred, LHS, RHS) ||
11461 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11462}
11463
11464std::optional<bool>
11466 const SCEV *RHS, const Instruction *CtxI) {
11467 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11468 if (KnownWithoutContext)
11469 return KnownWithoutContext;
11470
11471 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11472 return true;
11474 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11475 return false;
11476 return std::nullopt;
11477}
11478
11480 const SCEVAddRecExpr *LHS,
11481 const SCEV *RHS) {
11482 const Loop *L = LHS->getLoop();
11483 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11484 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11485}
11486
11487std::optional<ScalarEvolution::MonotonicPredicateType>
11489 ICmpInst::Predicate Pred) {
11490 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11491
11492#ifndef NDEBUG
11493 // Verify an invariant: inverting the predicate should turn a monotonically
11494 // increasing change to a monotonically decreasing one, and vice versa.
11495 if (Result) {
11496 auto ResultSwapped =
11497 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11498
11499 assert(*ResultSwapped != *Result &&
11500 "monotonicity should flip as we flip the predicate");
11501 }
11502#endif
11503
11504 return Result;
11505}
11506
11507std::optional<ScalarEvolution::MonotonicPredicateType>
11508ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11509 ICmpInst::Predicate Pred) {
11510 // A zero step value for LHS means the induction variable is essentially a
11511 // loop invariant value. We don't really depend on the predicate actually
11512 // flipping from false to true (for increasing predicates, and the other way
11513 // around for decreasing predicates), all we care about is that *if* the
11514 // predicate changes then it only changes from false to true.
11515 //
11516 // A zero step value in itself is not very useful, but there may be places
11517 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11518 // as general as possible.
11519
11520 // Only handle LE/LT/GE/GT predicates.
11521 if (!ICmpInst::isRelational(Pred))
11522 return std::nullopt;
11523
11524 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11525 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11526 "Should be greater or less!");
11527
11528 // Check that AR does not wrap.
11529 if (ICmpInst::isUnsigned(Pred)) {
11530 if (!LHS->hasNoUnsignedWrap())
11531 return std::nullopt;
11533 }
11534 assert(ICmpInst::isSigned(Pred) &&
11535 "Relational predicate is either signed or unsigned!");
11536 if (!LHS->hasNoSignedWrap())
11537 return std::nullopt;
11538
11539 const SCEV *Step = LHS->getStepRecurrence(*this);
11540
11541 if (isKnownNonNegative(Step))
11543
11544 if (isKnownNonPositive(Step))
11546
11547 return std::nullopt;
11548}
11549
11550std::optional<ScalarEvolution::LoopInvariantPredicate>
11552 const SCEV *RHS, const Loop *L,
11553 const Instruction *CtxI) {
11554 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11555 if (!isLoopInvariant(RHS, L)) {
11556 if (!isLoopInvariant(LHS, L))
11557 return std::nullopt;
11558
11559 std::swap(LHS, RHS);
11561 }
11562
11563 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11564 if (!ArLHS || ArLHS->getLoop() != L)
11565 return std::nullopt;
11566
11567 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11568 if (!MonotonicType)
11569 return std::nullopt;
11570 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11571 // true as the loop iterates, and the backedge is control dependent on
11572 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11573 //
11574 // * if the predicate was false in the first iteration then the predicate
11575 // is never evaluated again, since the loop exits without taking the
11576 // backedge.
11577 // * if the predicate was true in the first iteration then it will
11578 // continue to be true for all future iterations since it is
11579 // monotonically increasing.
11580 //
11581 // For both the above possibilities, we can replace the loop varying
11582 // predicate with its value on the first iteration of the loop (which is
11583 // loop invariant).
11584 //
11585 // A similar reasoning applies for a monotonically decreasing predicate, by
11586 // replacing true with false and false with true in the above two bullets.
11588 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11589
11590 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11592 RHS);
11593
11594 if (!CtxI)
11595 return std::nullopt;
11596 // Try to prove via context.
11597 // TODO: Support other cases.
11598 switch (Pred) {
11599 default:
11600 break;
11601 case ICmpInst::ICMP_ULE:
11602 case ICmpInst::ICMP_ULT: {
11603 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11604 // Given preconditions
11605 // (1) ArLHS does not cross the border of positive and negative parts of
11606 // range because of:
11607 // - Positive step; (TODO: lift this limitation)
11608 // - nuw - does not cross zero boundary;
11609 // - nsw - does not cross SINT_MAX boundary;
11610 // (2) ArLHS <s RHS
11611 // (3) RHS >=s 0
11612 // we can replace the loop variant ArLHS <u RHS condition with loop
11613 // invariant Start(ArLHS) <u RHS.
11614 //
11615 // Because of (1) there are two options:
11616 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11617 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11618 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11619 // Because of (2) ArLHS <u RHS is trivially true.
11620 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11621 // We can strengthen this to Start(ArLHS) <u RHS.
11622 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11623 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11624 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11625 isKnownNonNegative(RHS) &&
11626 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11628 RHS);
11629 }
11630 }
11631
11632 return std::nullopt;
11633}
11634
11635std::optional<ScalarEvolution::LoopInvariantPredicate>
11637 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11638 const Instruction *CtxI, const SCEV *MaxIter) {
11640 Pred, LHS, RHS, L, CtxI, MaxIter))
11641 return LIP;
11642 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11643 // Number of iterations expressed as UMIN isn't always great for expressing
11644 // the value on the last iteration. If the straightforward approach didn't
11645 // work, try the following trick: if the a predicate is invariant for X, it
11646 // is also invariant for umin(X, ...). So try to find something that works
11647 // among subexpressions of MaxIter expressed as umin.
11648 for (SCEVUse Op : UMin->operands())
11650 Pred, LHS, RHS, L, CtxI, Op))
11651 return LIP;
11652 return std::nullopt;
11653}
11654
11655std::optional<ScalarEvolution::LoopInvariantPredicate>
11657 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11658 const Instruction *CtxI, const SCEV *MaxIter) {
11659 // Try to prove the following set of facts:
11660 // - The predicate is monotonic in the iteration space.
11661 // - If the check does not fail on the 1st iteration:
11662 // - No overflow will happen during first MaxIter iterations;
11663 // - It will not fail on the MaxIter'th iteration.
11664 // If the check does fail on the 1st iteration, we leave the loop and no
11665 // other checks matter.
11666
11667 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11668 if (!isLoopInvariant(RHS, L)) {
11669 if (!isLoopInvariant(LHS, L))
11670 return std::nullopt;
11671
11672 std::swap(LHS, RHS);
11674 }
11675
11676 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11677 if (!AR || AR->getLoop() != L)
11678 return std::nullopt;
11679
11680 // Even if both are valid, we need to consistently chose the unsigned or the
11681 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11682 // predicate.
11683 Pred = Pred.dropSameSign();
11684
11685 // The predicate must be relational (i.e. <, <=, >=, >).
11686 if (!ICmpInst::isRelational(Pred))
11687 return std::nullopt;
11688
11689 // TODO: Support steps other than +/- 1.
11690 const SCEV *Step = AR->getStepRecurrence(*this);
11691 auto *One = getOne(Step->getType());
11692 auto *MinusOne = getNegativeSCEV(One);
11693 if (Step != One && Step != MinusOne)
11694 return std::nullopt;
11695
11696 // Type mismatch here means that MaxIter is potentially larger than max
11697 // unsigned value in start type, which mean we cannot prove no wrap for the
11698 // indvar.
11699 if (AR->getType() != MaxIter->getType())
11700 return std::nullopt;
11701
11702 // Value of IV on suggested last iteration.
11703 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11704 // Does it still meet the requirement?
11705 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11706 return std::nullopt;
11707 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11708 // not exceed max unsigned value of this type), this effectively proves
11709 // that there is no wrap during the iteration. To prove that there is no
11710 // signed/unsigned wrap, we need to check that
11711 // Start <= Last for step = 1 or Start >= Last for step = -1.
11712 ICmpInst::Predicate NoOverflowPred =
11714 if (Step == MinusOne)
11715 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11716 const SCEV *Start = AR->getStart();
11717 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11718 return std::nullopt;
11719
11720 // Everything is fine.
11721 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11722}
11723
11724bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11725 SCEVUse LHS,
11726 SCEVUse RHS) {
11727 if (HasSameValue(LHS, RHS))
11728 return ICmpInst::isTrueWhenEqual(Pred);
11729
11730 auto CheckRange = [&](bool IsSigned) {
11731 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11732 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11733 return RangeLHS.icmp(Pred, RangeRHS);
11734 };
11735
11736 // The check at the top of the function catches the case where the values are
11737 // known to be equal.
11738 if (Pred == CmpInst::ICMP_EQ)
11739 return false;
11740
11741 if (Pred == CmpInst::ICMP_NE) {
11742 if (CheckRange(true) || CheckRange(false))
11743 return true;
11744 auto *Diff = getMinusSCEV(LHS, RHS);
11745 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11746 }
11747
11748 return CheckRange(CmpInst::isSigned(Pred));
11749}
11750
/// Try to prove Pred(LHS, RHS) by matching both sides to (A + C1) and
/// (A + C2) with suitable no-wrap flags and comparing the constants: with
/// nsw/nuw on the adds, the comparison of the sums reduces to a comparison
/// of C1 and C2.
bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
                                                    SCEVUse LHS, SCEVUse RHS) {
  // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
  // C1 and C2 are constant integers. If either X or Y are not add expressions,
  // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
  // OutC1 and OutC2.
  auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
                                      APInt &OutC2,
                                      SCEV::NoWrapFlags ExpectedFlags) {
    SCEVUse XNonConstOp, XConstOp;
    SCEVUse YNonConstOp, YConstOp;
    SCEV::NoWrapFlags XFlagsPresent;
    SCEV::NoWrapFlags YFlagsPresent;

    // Not an add: treat X as (X + 0), which trivially carries any flags.
    if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
      XConstOp = getZero(X->getType());
      XNonConstOp = X;
      XFlagsPresent = ExpectedFlags;
    }
    if (!isa<SCEVConstant>(XConstOp))
      return false;

    if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
      YConstOp = getZero(Y->getType());
      YNonConstOp = Y;
      YFlagsPresent = ExpectedFlags;
    }

    // Both sides must share the same non-constant operand A.
    if (YNonConstOp != XNonConstOp)
      return false;

    if (!isa<SCEVConstant>(YConstOp))
      return false;

    // When matching ADDs with NUW flags (and unsigned predicates), only the
    // second ADD (with the larger constant) requires NUW.
    if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
      return false;
    if (ExpectedFlags != SCEV::FlagNUW &&
        (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
      return false;
    }

    OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
    OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();

    return true;
  };

  APInt C1;
  APInt C2;

  // The GE/GT cases are reduced to LE/LT by swapping the operands; each pair
  // then shares the LE/LT matching logic via fallthrough.
  switch (Pred) {
  default:
    break;

  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE:
    // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
      return true;

    break;

  case ICmpInst::ICMP_SGT:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLT:
    // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
      return true;

    break;

  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE:
    // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
      return true;

    break;

  case ICmpInst::ICMP_UGT:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULT:
    // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
    if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
      return true;
    break;
  }

  return false;
}
11849
11850bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11851 SCEVUse LHS, SCEVUse RHS) {
11852 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11853 return false;
11854
11855 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11856 // the stack can result in exponential time complexity.
11857 SaveAndRestore Restore(ProvingSplitPredicate, true);
11858
11859 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11860 //
11861 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11862 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11863 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11864 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11865 // use isKnownPredicate later if needed.
11866 return isKnownNonNegative(RHS) &&
11869}
11870
11871bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11872 const SCEV *LHS, const SCEV *RHS) {
11873 // No need to even try if we know the module has no guards.
11874 if (!HasGuards)
11875 return false;
11876
11877 return any_of(*BB, [&](const Instruction &I) {
11878 using namespace llvm::PatternMatch;
11879
11880 Value *Condition;
11882 m_Value(Condition))) &&
11883 isImpliedCond(Pred, LHS, RHS, Condition, false);
11884 });
11885}
11886
11887/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11888/// protected by a conditional between LHS and RHS. This is used to
11889/// to eliminate casts.
11891 CmpPredicate Pred,
11892 const SCEV *LHS,
11893 const SCEV *RHS) {
11894 // Interpret a null as meaning no loop, where there is obviously no guard
11895 // (interprocedural conditions notwithstanding). Do not bother about
11896 // unreachable loops.
11897 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11898 return true;
11899
11900 if (VerifyIR)
11901 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11902 "This cannot be done on broken IR!");
11903
11904
11905 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11906 return true;
11907
11908 BasicBlock *Latch = L->getLoopLatch();
11909 if (!Latch)
11910 return false;
11911
11912 CondBrInst *LoopContinuePredicate =
11914 if (LoopContinuePredicate &&
11915 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11916 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11917 return true;
11918
11919 // We don't want more than one activation of the following loops on the stack
11920 // -- that can lead to O(n!) time complexity.
11921 if (WalkingBEDominatingConds)
11922 return false;
11923
11924 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11925
11926 // See if we can exploit a trip count to prove the predicate.
11927 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11928 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11929 if (LatchBECount != getCouldNotCompute()) {
11930 // We know that Latch branches back to the loop header exactly
11931 // LatchBECount times. This means the backdege condition at Latch is
11932 // equivalent to "{0,+,1} u< LatchBECount".
11933 Type *Ty = LatchBECount->getType();
11934 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11935 const SCEV *LoopCounter =
11936 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11937 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11938 LatchBECount))
11939 return true;
11940 }
11941
11942 // Check conditions due to any @llvm.assume intrinsics.
11943 for (auto &AssumeVH : AC.assumptions()) {
11944 if (!AssumeVH)
11945 continue;
11946 auto *CI = cast<CallInst>(AssumeVH);
11947 if (!DT.dominates(CI, Latch->getTerminator()))
11948 continue;
11949
11950 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11951 return true;
11952 }
11953
11954 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11955 return true;
11956
11957 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11958 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11959 assert(DTN && "should reach the loop header before reaching the root!");
11960
11961 BasicBlock *BB = DTN->getBlock();
11962 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11963 return true;
11964
11965 BasicBlock *PBB = BB->getSinglePredecessor();
11966 if (!PBB)
11967 continue;
11968
11970 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11971 continue;
11972
11973 // If we have an edge `E` within the loop body that dominates the only
11974 // latch, the condition guarding `E` also guards the backedge. This
11975 // reasoning works only for loops with a single latch.
11976 // We're constructively (and conservatively) enumerating edges within the
11977 // loop body that dominate the latch. The dominator tree better agree
11978 // with us on this:
11979 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11980 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11981 BB != ContBr->getSuccessor(0)))
11982 return true;
11983 }
11984
11985 return false;
11986}
11987
11989 CmpPredicate Pred,
11990 const SCEV *LHS,
11991 const SCEV *RHS) {
11992 // Do not bother proving facts for unreachable code.
11993 if (!DT.isReachableFromEntry(BB))
11994 return true;
11995 if (VerifyIR)
11996 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11997 "This cannot be done on broken IR!");
11998
11999 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
12000 // the facts (a >= b && a != b) separately. A typical situation is when the
12001 // non-strict comparison is known from ranges and non-equality is known from
12002 // dominating predicates. If we are proving strict comparison, we always try
12003 // to prove non-equality and non-strict comparison separately.
12004 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
12005 const bool ProvingStrictComparison =
12006 Pred != NonStrictPredicate.dropSameSign();
12007 bool ProvedNonStrictComparison = false;
12008 bool ProvedNonEquality = false;
12009
12010 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
12011 if (!ProvedNonStrictComparison)
12012 ProvedNonStrictComparison = Fn(NonStrictPredicate);
12013 if (!ProvedNonEquality)
12014 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
12015 if (ProvedNonStrictComparison && ProvedNonEquality)
12016 return true;
12017 return false;
12018 };
12019
12020 if (ProvingStrictComparison) {
12021 auto ProofFn = [&](CmpPredicate P) {
12022 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12023 };
12024 if (SplitAndProve(ProofFn))
12025 return true;
12026 }
12027
12028 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12029 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12030 const Instruction *CtxI = &BB->front();
12031 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12032 return true;
12033 if (ProvingStrictComparison) {
12034 auto ProofFn = [&](CmpPredicate P) {
12035 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12036 };
12037 if (SplitAndProve(ProofFn))
12038 return true;
12039 }
12040 return false;
12041 };
12042
12043 // Starting at the block's predecessor, climb up the predecessor chain, as long
12044 // as there are predecessors that can be found that have unique successors
12045 // leading to the original block.
12046 const Loop *ContainingLoop = LI.getLoopFor(BB);
12047 const BasicBlock *PredBB;
12048 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12049 PredBB = ContainingLoop->getLoopPredecessor();
12050 else
12051 PredBB = BB->getSinglePredecessor();
12052 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12053 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12054 const CondBrInst *BlockEntryPredicate =
12055 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12056 if (!BlockEntryPredicate)
12057 continue;
12058
12059 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12060 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12061 return true;
12062 }
12063
12064 // Check conditions due to any @llvm.assume intrinsics.
12065 for (auto &AssumeVH : AC.assumptions()) {
12066 if (!AssumeVH)
12067 continue;
12068 auto *CI = cast<CallInst>(AssumeVH);
12069 if (!DT.dominates(CI, BB))
12070 continue;
12071
12072 if (ProveViaCond(CI->getArgOperand(0), false))
12073 return true;
12074 }
12075
12076 // Check conditions due to any @llvm.experimental.guard intrinsics.
12077 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12078 F.getParent(), Intrinsic::experimental_guard);
12079 if (GuardDecl)
12080 for (const auto *GU : GuardDecl->users())
12081 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12082 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12083 if (ProveViaCond(Guard->getArgOperand(0), false))
12084 return true;
12085 return false;
12086}
12087
12089 const SCEV *LHS,
12090 const SCEV *RHS) {
12091 // Interpret a null as meaning no loop, where there is obviously no guard
12092 // (interprocedural conditions notwithstanding).
12093 if (!L)
12094 return false;
12095
12096 // Both LHS and RHS must be available at loop entry.
12098 "LHS is not available at Loop Entry");
12100 "RHS is not available at Loop Entry");
12101
12102 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12103 return true;
12104
12105 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12106}
12107
12108bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12109 const SCEV *RHS,
12110 const Value *FoundCondValue, bool Inverse,
12111 const Instruction *CtxI) {
12112 // False conditions implies anything. Do not bother analyzing it further.
12113 if (FoundCondValue ==
12114 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12115 return true;
12116
12117 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12118 return false;
12119
12120 llvm::scope_exit ClearOnExit(
12121 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12122
12123 // Recursively handle And and Or conditions.
12124 const Value *Op0, *Op1;
12125 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12126 if (!Inverse)
12127 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12128 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12129 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12130 if (Inverse)
12131 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12132 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12133 }
12134
12135 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12136 if (!ICI) return false;
12137
12138 // Now that we found a conditional branch that dominates the loop or controls
12139 // the loop latch. Check to see if it is the comparison we are looking for.
12140 CmpPredicate FoundPred;
12141 if (Inverse)
12142 FoundPred = ICI->getInverseCmpPredicate();
12143 else
12144 FoundPred = ICI->getCmpPredicate();
12145
12146 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12147 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12148
12149 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12150}
12151
12152bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12153 const SCEV *RHS, CmpPredicate FoundPred,
12154 const SCEV *FoundLHS, const SCEV *FoundRHS,
12155 const Instruction *CtxI) {
12156 // Balance the types.
12157 if (getTypeSizeInBits(LHS->getType()) <
12158 getTypeSizeInBits(FoundLHS->getType())) {
12159 // For unsigned and equality predicates, try to prove that both found
12160 // operands fit into narrow unsigned range. If so, try to prove facts in
12161 // narrow types.
12162 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12163 !FoundRHS->getType()->isPointerTy()) {
12164 auto *NarrowType = LHS->getType();
12165 auto *WideType = FoundLHS->getType();
12166 auto BitWidth = getTypeSizeInBits(NarrowType);
12167 const SCEV *MaxValue = getZeroExtendExpr(
12169 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12170 MaxValue) &&
12171 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12172 MaxValue)) {
12173 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12174 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12175 // We cannot preserve samesign after truncation.
12176 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12177 TruncFoundLHS, TruncFoundRHS, CtxI))
12178 return true;
12179 }
12180 }
12181
12182 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12183 return false;
12184 if (CmpInst::isSigned(Pred)) {
12185 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12186 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12187 } else {
12188 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12189 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12190 }
12191 } else if (getTypeSizeInBits(LHS->getType()) >
12192 getTypeSizeInBits(FoundLHS->getType())) {
12193 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12194 return false;
12195 if (CmpInst::isSigned(FoundPred)) {
12196 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12197 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12198 } else {
12199 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12200 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12201 }
12202 }
12203 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12204 FoundRHS, CtxI);
12205}
12206
12207bool ScalarEvolution::isImpliedCondBalancedTypes(
12208 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12209 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12211 getTypeSizeInBits(FoundLHS->getType()) &&
12212 "Types should be balanced!");
12213 // Canonicalize the query to match the way instcombine will have
12214 // canonicalized the comparison.
12215 if (SimplifyICmpOperands(Pred, LHS, RHS))
12216 if (LHS == RHS)
12217 return CmpInst::isTrueWhenEqual(Pred);
12218 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12219 if (FoundLHS == FoundRHS)
12220 return CmpInst::isFalseWhenEqual(FoundPred);
12221
12222 // Check to see if we can make the LHS or RHS match.
12223 if (LHS == FoundRHS || RHS == FoundLHS) {
12224 if (isa<SCEVConstant>(RHS)) {
12225 std::swap(FoundLHS, FoundRHS);
12226 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12227 } else {
12228 std::swap(LHS, RHS);
12230 }
12231 }
12232
12233 // Check whether the found predicate is the same as the desired predicate.
12234 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12235 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12236
12237 // Check whether swapping the found predicate makes it the same as the
12238 // desired predicate.
12239 if (auto P = CmpPredicate::getMatching(
12240 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12241 // We can write the implication
12242 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12243 // using one of the following ways:
12244 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12245 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12246 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12247 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12248 // Forms 1. and 2. require swapping the operands of one condition. Don't
12249 // do this if it would break canonical constant/addrec ordering.
12251 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12252 LHS, FoundLHS, FoundRHS, CtxI);
12253 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12254 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12255
12256 // There's no clear preference between forms 3. and 4., try both. Avoid
12257 // forming getNotSCEV of pointer values as the resulting subtract is
12258 // not legal.
12259 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12260 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12261 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12262 FoundRHS, CtxI))
12263 return true;
12264
12265 if (!FoundLHS->getType()->isPointerTy() &&
12266 !FoundRHS->getType()->isPointerTy() &&
12267 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12268 getNotSCEV(FoundRHS), CtxI))
12269 return true;
12270
12271 return false;
12272 }
12273
12274 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12276 assert(P1 != P2 && "Handled earlier!");
12277 return CmpInst::isRelational(P2) &&
12279 };
12280 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12281 // Unsigned comparison is the same as signed comparison when both the
12282 // operands are non-negative or negative.
12283 if (haveSameSign(FoundLHS, FoundRHS))
12284 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12285 // Create local copies that we can freely swap and canonicalize our
12286 // conditions to "le/lt".
12287 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12288 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12289 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12290 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12291 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12292 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12293 std::swap(CanonicalLHS, CanonicalRHS);
12294 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12295 }
12296 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12297 "Must be!");
12298 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12299 ICmpInst::isLE(CanonicalFoundPred)) &&
12300 "Must be!");
12301 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12302 // Use implication:
12303 // x <u y && y >=s 0 --> x <s y.
12304 // If we can prove the left part, the right part is also proven.
12305 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12306 CanonicalRHS, CanonicalFoundLHS,
12307 CanonicalFoundRHS);
12308 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12309 // Use implication:
12310 // x <s y && y <s 0 --> x <u y.
12311 // If we can prove the left part, the right part is also proven.
12312 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12313 CanonicalRHS, CanonicalFoundLHS,
12314 CanonicalFoundRHS);
12315 }
12316
12317 // Check if we can make progress by sharpening ranges.
12318 if (FoundPred == ICmpInst::ICMP_NE &&
12319 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12320
12321 const SCEVConstant *C = nullptr;
12322 const SCEV *V = nullptr;
12323
12324 if (isa<SCEVConstant>(FoundLHS)) {
12325 C = cast<SCEVConstant>(FoundLHS);
12326 V = FoundRHS;
12327 } else {
12328 C = cast<SCEVConstant>(FoundRHS);
12329 V = FoundLHS;
12330 }
12331
12332 // The guarding predicate tells us that C != V. If the known range
12333 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12334 // range we consider has to correspond to same signedness as the
12335 // predicate we're interested in folding.
12336
12337 APInt Min = ICmpInst::isSigned(Pred) ?
12339
12340 if (Min == C->getAPInt()) {
12341 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12342 // This is true even if (Min + 1) wraps around -- in case of
12343 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12344
12345 APInt SharperMin = Min + 1;
12346
12347 switch (Pred) {
12348 case ICmpInst::ICMP_SGE:
12349 case ICmpInst::ICMP_UGE:
12350 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12351 // RHS, we're done.
12352 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12353 CtxI))
12354 return true;
12355 [[fallthrough]];
12356
12357 case ICmpInst::ICMP_SGT:
12358 case ICmpInst::ICMP_UGT:
12359 // We know from the range information that (V `Pred` Min ||
12360 // V == Min). We know from the guarding condition that !(V
12361 // == Min). This gives us
12362 //
12363 // V `Pred` Min || V == Min && !(V == Min)
12364 // => V `Pred` Min
12365 //
12366 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12367
12368 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12369 return true;
12370 break;
12371
12372 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12373 case ICmpInst::ICMP_SLE:
12374 case ICmpInst::ICMP_ULE:
12375 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12376 LHS, V, getConstant(SharperMin), CtxI))
12377 return true;
12378 [[fallthrough]];
12379
12380 case ICmpInst::ICMP_SLT:
12381 case ICmpInst::ICMP_ULT:
12382 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12383 LHS, V, getConstant(Min), CtxI))
12384 return true;
12385 break;
12386
12387 default:
12388 // No change
12389 break;
12390 }
12391 }
12392 }
12393
12394 // Check whether the actual condition is beyond sufficient.
12395 if (FoundPred == ICmpInst::ICMP_EQ)
12396 if (ICmpInst::isTrueWhenEqual(Pred))
12397 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12398 return true;
12399 if (Pred == ICmpInst::ICMP_NE)
12400 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12401 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12402 return true;
12403
12404 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12405 return true;
12406
12407 // Otherwise assume the worst.
12408 return false;
12409}
12410
12411bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12412 SCEV::NoWrapFlags &Flags) {
12413 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12414 return false;
12415
12416 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12417 return true;
12418}
12419
12420std::optional<APInt>
12422 // We avoid subtracting expressions here because this function is usually
12423 // fairly deep in the call stack (i.e. is called many times).
12424
12425 unsigned BW = getTypeSizeInBits(More->getType());
12426 APInt Diff(BW, 0);
12427 APInt DiffMul(BW, 1);
12428 // Try various simplifications to reduce the difference to a constant. Limit
12429 // the number of allowed simplifications to keep compile-time low.
12430 for (unsigned I = 0; I < 8; ++I) {
12431 if (More == Less)
12432 return Diff;
12433
12434 // Reduce addrecs with identical steps to their start value.
12436 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12437 const auto *MAR = cast<SCEVAddRecExpr>(More);
12438
12439 if (LAR->getLoop() != MAR->getLoop())
12440 return std::nullopt;
12441
12442 // We look at affine expressions only; not for correctness but to keep
12443 // getStepRecurrence cheap.
12444 if (!LAR->isAffine() || !MAR->isAffine())
12445 return std::nullopt;
12446
12447 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12448 return std::nullopt;
12449
12450 Less = LAR->getStart();
12451 More = MAR->getStart();
12452 continue;
12453 }
12454
12455 // Try to match a common constant multiply.
12456 auto MatchConstMul =
12457 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12458 const APInt *C;
12459 const SCEV *Op;
12460 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12461 return {{Op, *C}};
12462 return std::nullopt;
12463 };
12464 if (auto MatchedMore = MatchConstMul(More)) {
12465 if (auto MatchedLess = MatchConstMul(Less)) {
12466 if (MatchedMore->second == MatchedLess->second) {
12467 More = MatchedMore->first;
12468 Less = MatchedLess->first;
12469 DiffMul *= MatchedMore->second;
12470 continue;
12471 }
12472 }
12473 }
12474
12475 // Try to cancel out common factors in two add expressions.
12477 auto Add = [&](const SCEV *S, int Mul) {
12478 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12479 if (Mul == 1) {
12480 Diff += C->getAPInt() * DiffMul;
12481 } else {
12482 assert(Mul == -1);
12483 Diff -= C->getAPInt() * DiffMul;
12484 }
12485 } else
12486 Multiplicity[S] += Mul;
12487 };
12488 auto Decompose = [&](const SCEV *S, int Mul) {
12489 if (isa<SCEVAddExpr>(S)) {
12490 for (const SCEV *Op : S->operands())
12491 Add(Op, Mul);
12492 } else
12493 Add(S, Mul);
12494 };
12495 Decompose(More, 1);
12496 Decompose(Less, -1);
12497
12498 // Check whether all the non-constants cancel out, or reduce to new
12499 // More/Less values.
12500 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12501 for (const auto &[S, Mul] : Multiplicity) {
12502 if (Mul == 0)
12503 continue;
12504 if (Mul == 1) {
12505 if (NewMore)
12506 return std::nullopt;
12507 NewMore = S;
12508 } else if (Mul == -1) {
12509 if (NewLess)
12510 return std::nullopt;
12511 NewLess = S;
12512 } else
12513 return std::nullopt;
12514 }
12515
12516 // Values stayed the same, no point in trying further.
12517 if (NewMore == More || NewLess == Less)
12518 return std::nullopt;
12519
12520 More = NewMore;
12521 Less = NewLess;
12522
12523 // Reduced to constant.
12524 if (!More && !Less)
12525 return Diff;
12526
12527 // Left with variable on only one side, bail out.
12528 if (!More || !Less)
12529 return std::nullopt;
12530 }
12531
12532 // Did not reduce to constant.
12533 return std::nullopt;
12534}
12535
12536bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12537 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12538 const SCEV *FoundRHS, const Instruction *CtxI) {
12539 // Try to recognize the following pattern:
12540 //
12541 // FoundRHS = ...
12542 // ...
12543 // loop:
12544 // FoundLHS = {Start,+,W}
12545 // context_bb: // Basic block from the same loop
12546 // known(Pred, FoundLHS, FoundRHS)
12547 //
12548 // If some predicate is known in the context of a loop, it is also known on
12549 // each iteration of this loop, including the first iteration. Therefore, in
12550 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12551 // prove the original pred using this fact.
12552 if (!CtxI)
12553 return false;
12554 const BasicBlock *ContextBB = CtxI->getParent();
12555 // Make sure AR varies in the context block.
12556 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12557 const Loop *L = AR->getLoop();
12558 const auto *Latch = L->getLoopLatch();
12559 // Make sure that context belongs to the loop and executes on 1st iteration
12560 // (if it ever executes at all).
12561 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12562 return false;
12563 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12564 return false;
12565 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12566 }
12567
12568 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12569 const Loop *L = AR->getLoop();
12570 const auto *Latch = L->getLoopLatch();
12571 // Make sure that context belongs to the loop and executes on 1st iteration
12572 // (if it ever executes at all).
12573 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12574 return false;
12575 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12576 return false;
12577 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12578 }
12579
12580 return false;
12581}
12582
12583bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12584 const SCEV *LHS,
12585 const SCEV *RHS,
12586 const SCEV *FoundLHS,
12587 const SCEV *FoundRHS) {
12588 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12589 return false;
12590
12591 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12592 if (!AddRecLHS)
12593 return false;
12594
12595 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12596 if (!AddRecFoundLHS)
12597 return false;
12598
12599 // We'd like to let SCEV reason about control dependencies, so we constrain
12600 // both the inequalities to be about add recurrences on the same loop. This
12601 // way we can use isLoopEntryGuardedByCond later.
12602
12603 const Loop *L = AddRecFoundLHS->getLoop();
12604 if (L != AddRecLHS->getLoop())
12605 return false;
12606
12607 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12608 //
12609 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12610 // ... (2)
12611 //
12612 // Informal proof for (2), assuming (1) [*]:
12613 //
12614 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12615 //
12616 // Then
12617 //
12618 // FoundLHS s< FoundRHS s< INT_MIN - C
12619 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12620 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12621 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12622 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12623 // <=> FoundLHS + C s< FoundRHS + C
12624 //
12625 // [*]: (1) can be proved by ruling out overflow.
12626 //
12627 // [**]: This can be proved by analyzing all the four possibilities:
12628 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12629 // (A s>= 0, B s>= 0).
12630 //
12631 // Note:
12632 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12633 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12634 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12635 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12636 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12637 // C)".
12638
12639 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12640 if (!LDiff)
12641 return false;
12642 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12643 if (!RDiff || *LDiff != *RDiff)
12644 return false;
12645
12646 if (LDiff->isMinValue())
12647 return true;
12648
12649 APInt FoundRHSLimit;
12650
12651 if (Pred == CmpInst::ICMP_ULT) {
12652 FoundRHSLimit = -(*RDiff);
12653 } else {
12654 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12655 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12656 }
12657
12658 // Try to prove (1) or (2), as needed.
12659 return isAvailableAtLoopEntry(FoundRHS, L) &&
12660 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12661 getConstant(FoundRHSLimit));
12662}
12663
12664bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12665 const SCEV *RHS, const SCEV *FoundLHS,
12666 const SCEV *FoundRHS, unsigned Depth) {
12667 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12668
12669 llvm::scope_exit ClearOnExit([&]() {
12670 if (LPhi) {
12671 bool Erased = PendingMerges.erase(LPhi);
12672 assert(Erased && "Failed to erase LPhi!");
12673 (void)Erased;
12674 }
12675 if (RPhi) {
12676 bool Erased = PendingMerges.erase(RPhi);
12677 assert(Erased && "Failed to erase RPhi!");
12678 (void)Erased;
12679 }
12680 });
12681
12682 // Find respective Phis and check that they are not being pending.
12683 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12684 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12685 if (!PendingMerges.insert(Phi).second)
12686 return false;
12687 LPhi = Phi;
12688 }
12689 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12690 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12691 // If we detect a loop of Phi nodes being processed by this method, for
12692 // example:
12693 //
12694 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12695 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12696 //
12697 // we don't want to deal with a case that complex, so return conservative
12698 // answer false.
12699 if (!PendingMerges.insert(Phi).second)
12700 return false;
12701 RPhi = Phi;
12702 }
12703
12704 // If none of LHS, RHS is a Phi, nothing to do here.
12705 if (!LPhi && !RPhi)
12706 return false;
12707
12708 // If there is a SCEVUnknown Phi we are interested in, make it left.
12709 if (!LPhi) {
12710 std::swap(LHS, RHS);
12711 std::swap(FoundLHS, FoundRHS);
12712 std::swap(LPhi, RPhi);
12714 }
12715
12716 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12717 const BasicBlock *LBB = LPhi->getParent();
12718 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12719
12720 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12721 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12722 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12723 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12724 };
12725
12726 if (RPhi && RPhi->getParent() == LBB) {
12727 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12728 // If we compare two Phis from the same block, and for each entry block
12729 // the predicate is true for incoming values from this block, then the
12730 // predicate is also true for the Phis.
12731 for (const BasicBlock *IncBB : predecessors(LBB)) {
12732 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12733 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12734 if (!ProvedEasily(L, R))
12735 return false;
12736 }
12737 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12738 // Case two: RHS is also a Phi from the same basic block, and it is an
12739 // AddRec. It means that there is a loop which has both AddRec and Unknown
12740 // PHIs, for it we can compare incoming values of AddRec from above the loop
12741 // and latch with their respective incoming values of LPhi.
12742 // TODO: Generalize to handle loops with many inputs in a header.
12743 if (LPhi->getNumIncomingValues() != 2) return false;
12744
12745 auto *RLoop = RAR->getLoop();
12746 auto *Predecessor = RLoop->getLoopPredecessor();
12747 assert(Predecessor && "Loop with AddRec with no predecessor?");
12748 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12749 if (!ProvedEasily(L1, RAR->getStart()))
12750 return false;
12751 auto *Latch = RLoop->getLoopLatch();
12752 assert(Latch && "Loop with AddRec with no latch?");
12753 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12754 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12755 return false;
12756 } else {
12757 // In all other cases go over inputs of LHS and compare each of them to RHS,
12758 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12759 // At this point RHS is either a non-Phi, or it is a Phi from some block
12760 // different from LBB.
12761 for (const BasicBlock *IncBB : predecessors(LBB)) {
12762 // Check that RHS is available in this block.
12763 if (!dominates(RHS, IncBB))
12764 return false;
12765 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12766 // Make sure L does not refer to a value from a potentially previous
12767 // iteration of a loop.
12768 if (!properlyDominates(L, LBB))
12769 return false;
12770 // Addrecs are considered to properly dominate their loop, so are missed
12771 // by the previous check. Discard any values that have computable
12772 // evolution in this loop.
12773 if (auto *Loop = LI.getLoopFor(LBB))
12774 if (hasComputableLoopEvolution(L, Loop))
12775 return false;
12776 if (!ProvedEasily(L, RHS))
12777 return false;
12778 }
12779 }
12780 return true;
12781}
12782
12783bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12784 const SCEV *LHS,
12785 const SCEV *RHS,
12786 const SCEV *FoundLHS,
12787 const SCEV *FoundRHS) {
12788 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12789 // sure that we are dealing with same LHS.
12790 if (RHS == FoundRHS) {
12791 std::swap(LHS, RHS);
12792 std::swap(FoundLHS, FoundRHS);
12794 }
12795 if (LHS != FoundLHS)
12796 return false;
12797
12798 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12799 if (!SUFoundRHS)
12800 return false;
12801
12802 Value *Shiftee, *ShiftValue;
12803
12804 using namespace PatternMatch;
12805 if (match(SUFoundRHS->getValue(),
12806 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12807 auto *ShifteeS = getSCEV(Shiftee);
12808 // Prove one of the following:
12809 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12810 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12811 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12812 // ---> LHS <s RHS
12813 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12814 // ---> LHS <=s RHS
12815 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12816 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12817 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12818 if (isKnownNonNegative(ShifteeS))
12819 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12820 }
12821
12822 return false;
12823}
12824
12825bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12826 const SCEV *RHS,
12827 const SCEV *FoundLHS,
12828 const SCEV *FoundRHS,
12829 const Instruction *CtxI) {
12830 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12831 FoundRHS) ||
12832 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12833 FoundRHS) ||
12834 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12835 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12836 CtxI) ||
12837 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12838}
12839
12840/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12841template <typename MinMaxExprType>
12842static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12843 const SCEV *Candidate) {
12844 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12845 if (!MinMaxExpr)
12846 return false;
12847
12848 return is_contained(MinMaxExpr->operands(), Candidate);
12849}
12850
12852 CmpPredicate Pred, const SCEV *LHS,
12853 const SCEV *RHS) {
12854 // If both sides are affine addrecs for the same loop, with equal
12855 // steps, and we know the recurrences don't wrap, then we only
12856 // need to check the predicate on the starting values.
12857
12858 if (!ICmpInst::isRelational(Pred))
12859 return false;
12860
12861 const SCEV *LStart, *RStart, *Step;
12862 const Loop *L;
12863 if (!match(LHS,
12864 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12866 m_SpecificLoop(L))))
12867 return false;
12872 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12873 return false;
12874
12875 return SE.isKnownPredicate(Pred, LStart, RStart);
12876}
12877
12878/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12879/// expression?
12881 const SCEV *LHS, const SCEV *RHS) {
12882 switch (Pred) {
12883 default:
12884 return false;
12885
12886 case ICmpInst::ICMP_SGE:
12887 std::swap(LHS, RHS);
12888 [[fallthrough]];
12889 case ICmpInst::ICMP_SLE:
12890 return
12891 // min(A, ...) <= A
12893 // A <= max(A, ...)
12895
12896 case ICmpInst::ICMP_UGE:
12897 std::swap(LHS, RHS);
12898 [[fallthrough]];
12899 case ICmpInst::ICMP_ULE:
12900 return
12901 // min(A, ...) <= A
12902 // FIXME: what about umin_seq?
12904 // A <= max(A, ...)
12906 }
12907
12908 llvm_unreachable("covered switch fell through?!");
12909}
12910
12911bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12912 const SCEV *RHS,
12913 const SCEV *FoundLHS,
12914 const SCEV *FoundRHS,
12915 unsigned Depth) {
12918 "LHS and RHS have different sizes?");
12919 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12920 getTypeSizeInBits(FoundRHS->getType()) &&
12921 "FoundLHS and FoundRHS have different sizes?");
12922 // We want to avoid hurting the compile time with analysis of too big trees.
12924 return false;
12925
12926 // We only want to work with GT comparison so far.
12927 if (ICmpInst::isLT(Pred)) {
12929 std::swap(LHS, RHS);
12930 std::swap(FoundLHS, FoundRHS);
12931 }
12932
12934
12935 // For unsigned, try to reduce it to corresponding signed comparison.
12936 if (P == ICmpInst::ICMP_UGT)
12937 // We can replace unsigned predicate with its signed counterpart if all
12938 // involved values are non-negative.
12939 // TODO: We could have better support for unsigned.
12940 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12941 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12942 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12943 // use this fact to prove that LHS and RHS are non-negative.
12944 const SCEV *MinusOne = getMinusOne(LHS->getType());
12945 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12946 FoundRHS) &&
12947 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12948 FoundRHS))
12950 }
12951
12952 if (P != ICmpInst::ICMP_SGT)
12953 return false;
12954
12955 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12956 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12957 return Ext->getOperand();
12958 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12959 // the constant in some cases.
12960 return S;
12961 };
12962
12963 // Acquire values from extensions.
12964 auto *OrigLHS = LHS;
12965 auto *OrigFoundLHS = FoundLHS;
12966 LHS = GetOpFromSExt(LHS);
12967 FoundLHS = GetOpFromSExt(FoundLHS);
12968
12969 // Is the SGT predicate can be proved trivially or using the found context.
12970 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12971 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12972 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12973 FoundRHS, Depth + 1);
12974 };
12975
12976 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12977 // We want to avoid creation of any new non-constant SCEV. Since we are
12978 // going to compare the operands to RHS, we should be certain that we don't
12979 // need any size extensions for this. So let's decline all cases when the
12980 // sizes of types of LHS and RHS do not match.
12981 // TODO: Maybe try to get RHS from sext to catch more cases?
12983 return false;
12984
12985 // Should not overflow.
12986 if (!LHSAddExpr->hasNoSignedWrap())
12987 return false;
12988
12989 SCEVUse LL = LHSAddExpr->getOperand(0);
12990 SCEVUse LR = LHSAddExpr->getOperand(1);
12991 auto *MinusOne = getMinusOne(RHS->getType());
12992
12993 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12994 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12995 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12996 };
12997 // Try to prove the following rule:
12998 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12999 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
13000 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
13001 return true;
13002 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
13003 Value *LL, *LR;
13004 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
13005
13006 using namespace llvm::PatternMatch;
13007
13008 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
13009 // Rules for division.
13010 // We are going to perform some comparisons with Denominator and its
13011 // derivative expressions. In general case, creating a SCEV for it may
13012 // lead to a complex analysis of the entire graph, and in particular it
13013 // can request trip count recalculation for the same loop. This would
13014 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
13015 // this, we only want to create SCEVs that are constants in this section.
13016 // So we bail if Denominator is not a constant.
13017 if (!isa<ConstantInt>(LR))
13018 return false;
13019
13020 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13021
13022 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13023 // then a SCEV for the numerator already exists and matches with FoundLHS.
13024 auto *Numerator = getExistingSCEV(LL);
13025 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13026 return false;
13027
13028 // Make sure that the numerator matches with FoundLHS and the denominator
13029 // is positive.
13030 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13031 return false;
13032
13033 auto *DTy = Denominator->getType();
13034 auto *FRHSTy = FoundRHS->getType();
13035 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13036 // One of types is a pointer and another one is not. We cannot extend
13037 // them properly to a wider type, so let us just reject this case.
13038 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13039 // to avoid this check.
13040 return false;
13041
13042 // Given that:
13043 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13044 auto *WTy = getWiderType(DTy, FRHSTy);
13045 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13046 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13047
13048 // Try to prove the following rule:
13049 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13050 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13051 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13052 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13053 if (isKnownNonPositive(RHS) &&
13054 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13055 return true;
13056
13057 // Try to prove the following rule:
13058 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13059 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13060 // If we divide it by Denominator > 2, then:
13061 // 1. If FoundLHS is negative, then the result is 0.
13062 // 2. If FoundLHS is non-negative, then the result is non-negative.
13063 // Anyways, the result is non-negative.
13064 auto *MinusOne = getMinusOne(WTy);
13065 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13066 if (isKnownNegative(RHS) &&
13067 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13068 return true;
13069 }
13070 }
13071
13072 // If our expression contained SCEVUnknown Phis, and we split it down and now
13073 // need to prove something for them, try to prove the predicate for every
13074 // possible incoming values of those Phis.
13075 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13076 return true;
13077
13078 return false;
13079}
13080
13082 const SCEV *RHS) {
13083 // zext x u<= sext x, sext x s<= zext x
13084 const SCEV *Op;
13085 switch (Pred) {
13086 case ICmpInst::ICMP_SGE:
13087 std::swap(LHS, RHS);
13088 [[fallthrough]];
13089 case ICmpInst::ICMP_SLE: {
13090 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13091 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13093 }
13094 case ICmpInst::ICMP_UGE:
13095 std::swap(LHS, RHS);
13096 [[fallthrough]];
13097 case ICmpInst::ICMP_ULE: {
13098 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13099 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13101 }
13102 default:
13103 return false;
13104 };
13105 llvm_unreachable("unhandled case");
13106}
13107
13108bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13109 SCEVUse LHS,
13110 SCEVUse RHS) {
13111 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13112 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13113 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13114 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13115 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13116}
13117
13118bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13119 const SCEV *LHS,
13120 const SCEV *RHS,
13121 const SCEV *FoundLHS,
13122 const SCEV *FoundRHS) {
13123 switch (Pred) {
13124 default:
13125 llvm_unreachable("Unexpected CmpPredicate value!");
13126 case ICmpInst::ICMP_EQ:
13127 case ICmpInst::ICMP_NE:
13128 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13129 return true;
13130 break;
13131 case ICmpInst::ICMP_SLT:
13132 case ICmpInst::ICMP_SLE:
13133 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13134 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13135 return true;
13136 break;
13137 case ICmpInst::ICMP_SGT:
13138 case ICmpInst::ICMP_SGE:
13139 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13140 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13141 return true;
13142 break;
13143 case ICmpInst::ICMP_ULT:
13144 case ICmpInst::ICMP_ULE:
13145 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13146 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13147 return true;
13148 break;
13149 case ICmpInst::ICMP_UGT:
13150 case ICmpInst::ICMP_UGE:
13151 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13152 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13153 return true;
13154 break;
13155 }
13156
13157 // Maybe it can be proved via operations?
13158 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13159 return true;
13160
13161 return false;
13162}
13163
13164bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13165 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13166 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13167 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13168 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13169 // reduce the compile time impact of this optimization.
13170 return false;
13171
13172 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13173 if (!Addend)
13174 return false;
13175
13176 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13177
13178 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13179 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13180 ConstantRange FoundLHSRange =
13181 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13182
13183 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13184 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13185
13186 // We can also compute the range of values for `LHS` that satisfy the
13187 // consequent, "`LHS` `Pred` `RHS`":
13188 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13189 // The antecedent implies the consequent if every value of `LHS` that
13190 // satisfies the antecedent also satisfies the consequent.
13191 return LHSRange.icmp(Pred, ConstRHS);
13192}
13193
13194bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13195 bool IsSigned) {
13196 assert(isKnownPositive(Stride) && "Positive stride expected!");
13197
13198 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13199 const SCEV *One = getOne(Stride->getType());
13200
13201 if (IsSigned) {
13202 APInt MaxRHS = getSignedRangeMax(RHS);
13203 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13204 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13205
13206 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13207 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13208 }
13209
13210 APInt MaxRHS = getUnsignedRangeMax(RHS);
13211 APInt MaxValue = APInt::getMaxValue(BitWidth);
13212 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13213
13214 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13215 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13216}
13217
13218bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13219 bool IsSigned) {
13220
13221 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13222 const SCEV *One = getOne(Stride->getType());
13223
13224 if (IsSigned) {
13225 APInt MinRHS = getSignedRangeMin(RHS);
13226 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13227 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13228
13229 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13230 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13231 }
13232
13233 APInt MinRHS = getUnsignedRangeMin(RHS);
13234 APInt MinValue = APInt::getMinValue(BitWidth);
13235 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13236
13237 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13238 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13239}
13240
  // umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N=0.
  // Net effect: computes ceil(N /u D) without forming N + D - 1, which could
  // overflow.
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
13249
13250const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13251 const SCEV *Stride,
13252 const SCEV *End,
13253 unsigned BitWidth,
13254 bool IsSigned) {
13255 // The logic in this function assumes we can represent a positive stride.
13256 // If we can't, the backedge-taken count must be zero.
13257 if (IsSigned && BitWidth == 1)
13258 return getZero(Stride->getType());
13259
13260 // This code below only been closely audited for negative strides in the
13261 // unsigned comparison case, it may be correct for signed comparison, but
13262 // that needs to be established.
13263 if (IsSigned && isKnownNegative(Stride))
13264 return getCouldNotCompute();
13265
13266 // Calculate the maximum backedge count based on the range of values
13267 // permitted by Start, End, and Stride.
13268 APInt MinStart =
13269 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13270
13271 APInt MinStride =
13272 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13273
13274 // We assume either the stride is positive, or the backedge-taken count
13275 // is zero. So force StrideForMaxBECount to be at least one.
13276 APInt One(BitWidth, 1);
13277 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13278 : APIntOps::umax(One, MinStride);
13279
13280 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13281 : APInt::getMaxValue(BitWidth);
13282 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13283
13284 // Although End can be a MAX expression we estimate MaxEnd considering only
13285 // the case End = RHS of the loop termination condition. This is safe because
13286 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13287 // taken count.
13288 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13289 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13290
13291 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13292 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13293 : APIntOps::umax(MaxEnd, MinStart);
13294
13295 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13296 getConstant(StrideForMaxBECount) /* Step */);
13297}
13298
// Compute exit-count information for a loop exiting when "LHS < RHS" stops
// holding (signed or unsigned per IsSigned): an exact backedge-taken count
// when derivable, plus constant/symbolic upper bounds.
// NOTE(review): the surrounding declarations of IV, Predicates and Cond are
// not visible in this excerpt; they are referenced below.
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsOnlyExit, bool AllowPredicates) {

  bool PredicatedIV = false;
  if (!IV) {
    // LHS isn't an addrec itself; if it is a zext of an affine addrec of this
    // loop, try to prove NUW on the inner addrec so the zext can be folded
    // into the recurrence.
    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
      if (AR && AR->getLoop() == L && AR->isAffine()) {
        auto canProveNUW = [&]() {
          // We can use the comparison to infer no-wrap flags only if it fully
          // controls the loop exit.
          if (!ControlsOnlyExit)
            return false;

          if (!isLoopInvariant(RHS, L))
            return false;

          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
            // We need the sequence defined by AR to strictly increase in the
            // unsigned integer domain for the logic below to hold.
            return false;

          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
          // If RHS <=u Limit, then there must exist a value V in the sequence
          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
          // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
          // overflow occurs. This limit also implies that a signed comparison
          // (in the wide bitwidth) is equivalent to an unsigned comparison as
          // the high bits on both sides must be zero.
          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
          Limit = Limit.zext(OuterBitWidth);
          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
        };
        auto Flags = AR->getNoWrapFlags();
        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
          Flags = setFlags(Flags, SCEV::FlagNUW);

        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
        if (AR->hasNoUnsignedWrap()) {
          // Emulate what getZeroExtendExpr would have done during construction
          // if we'd been able to infer the fact just above at that time.
          const SCEV *Step = AR->getStepRecurrence(*this);
          Type *Ty = ZExt->getType();
          auto *S = getAddRecExpr(
              getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
        }
      }
    }
  }


  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch. Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);

  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects. Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula is as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    NoWrap flag).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop)
    // c) loop has a single static exit (with no abnormal exits)
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop. The backedge taken count formula reduces to zero in this case.
    //
    // Precondition b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride) returning
    // true (original behavior of the function).
    //
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
      return getCouldNotCompute();

    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration. We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below. Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
      // we know the numerator in the divides below must be zero, so we can
      // pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction. Suppose the stride were zero. If we can
        // prove that the backedge *is* taken on the first iteration, then since
        // we know this condition controls the sole exit, we must have an
        // infinite loop. We can't have a (well defined) infinite loop per
        // check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride). Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!NoWrap) {
    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned))
      return getCouldNotCompute();
  }

  // On all paths just preceding, we established the following invariant:
  // IV can be assumed not to overflow up to and including the exiting
  // iteration. We proved this in one of two ways:
  // 1) We can show overflow doesn't occur before the exiting iteration
  //    1a) canIVOverflowOnLT, and b) step of one
  // 2) We can show that if overflow occurs, the loop must execute UB
  //    before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).

  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
    return RHS;
  }

  const SCEV *End = nullptr, *BECount = nullptr,
             *BECountIfBackedgeTaken = nullptr;
  if (!isLoopInvariant(RHS, L)) {
    const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
    if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
        RHSAddRec->getNoWrapFlags()) {
      // The structure of loop we are trying to calculate backedge count of:
      //
      //  left = left_start
      //  right = right_start
      //
      //  while(left < right){
      //    ... do something here ...
      //    left += s1; // stride of left is s1 (s1 > 0)
      //    right += s2; // stride of right is s2 (s2 < 0)
      //  }
      //

      const SCEV *RHSStart = RHSAddRec->getStart();
      const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);

      // If Stride - RHSStride is positive and does not overflow, we can write
      // backedge count as ->
      //    ceil((End - Start) /u (Stride - RHSStride))
      //    Where, End = max(RHSStart, Start)

      // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
      if (isKnownNegative(RHSStride) &&
          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
                          RHSStride)) {

        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
        if (isKnownPositive(Denominator)) {
          End = IsSigned ? getSMaxExpr(RHSStart, Start)
                         : getUMaxExpr(RHSStart, Start);

          // We can do this because End >= Start, as End = max(RHSStart, Start)
          const SCEV *Delta = getMinusSCEV(End, Start);

          BECount = getUDivCeilSCEV(Delta, Denominator);
          BECountIfBackedgeTaken =
              getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
        }
      }
    }
    if (BECount == nullptr) {
      // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
      // given the start, stride and max value for the end bound of the
      // loop (RHS), and the fact that IV does not overflow (which is
      // checked above).
      const SCEV *MaxBECount = computeMaxBECountForLT(
          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                       MaxBECount, false /*MaxOrZero*/, Predicates);
    }
  } else {
    // We use the expression (max(End,Start)-Start)/Stride to describe the
    // backedge count, as if the backedge is taken at least once
    // max(End,Start) is End and so the result is as above, and if not
    // max(End,Start) is Start so we get a backedge count of zero.
    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
    assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
    // Can we prove (max(RHS,Start) > Start - Stride?
    if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
        isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
      // In this case, we can use a refined formula for computing backedge
      // taken count. The general formula remains:
      //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
      // We want to use the alternate formula:
      //   "((End - 1) - (Start - Stride)) /u Stride"
      // Let's do a quick case analysis to show these are equivalent under
      // our precondition that max(RHS,Start) > Start - Stride.
      // * For RHS <= Start, the backedge-taken count must be zero.
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
      //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
      //   of Stride. For 0 stride, we've used umin(1,Stride) above,
      //   reducing this to the stride of 1 case.
      // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
      //   Stride".
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
      //   "((RHS - (Start - Stride) - 1) /u Stride".
      //   Our preconditions trivially imply no overflow in that form.
      const SCEV *MinusOne = getMinusOne(Stride->getType());
      const SCEV *Numerator =
          getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
      BECount = getUDivExpr(Numerator, Stride);
    }

    if (!BECount) {
      auto canProveRHSGreaterThanEqualStart = [&]() {
        auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

        if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
            isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
          return true;

        // (RHS > Start - 1) implies RHS >= Start.
        // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
        //   "Start - 1" doesn't overflow.
        // * For signed comparison, if Start - 1 does overflow, it's equal
        //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
        // * For unsigned comparison, if Start - 1 does overflow, it's equal
        //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
        //
        // FIXME: Should isLoopEntryGuardedByCond do this for us?
        auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
        auto *StartMinusOne =
            getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
        return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
      };

      // If we know that RHS >= Start in the context of loop, then we know
      // that max(RHS, Start) = RHS at this point.
      if (canProveRHSGreaterThanEqualStart()) {
        End = RHS;
      } else {
        // If RHS < Start, the backedge will be taken zero times. So in
        // general, we can write the backedge-taken count as:
        //
        //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
        //
        // We convert it to the following to make it more convenient for SCEV:
        //
        //     ceil(max(RHS, Start) - Start) / Stride
        End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

        // See what would happen if we assume the backedge is taken. This is
        // used to compute MaxBECount.
        BECountIfBackedgeTaken =
            getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
      }

      // At this point, we know:
      //
      // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
      // 2. The index variable doesn't overflow.
      //
      // Therefore, we know N exists such that
      // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
      // doesn't overflow.
      //
      // Using this information, try to prove whether the addition in
      // "(Start - End) + (Stride - 1)" has unsigned overflow.
      const SCEV *One = getOne(Stride->getType());
      bool MayAddOverflow = [&] {
        if (isKnownToBeAPowerOfTwo(Stride)) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers. Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride. Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by
          // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
          // write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride - 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride - 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
          // to use signed max instead of unsigned max. Note that we're
          // trying to prove a lack of unsigned overflow in either case.
          return false;
        }
        if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
          // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
          // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
          // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
          // 1 <s End.
          //
          // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
          // End.
          return false;
        }
        return true;
      }();

      const SCEV *Delta = getMinusSCEV(End, Start);
      if (!MayAddOverflow) {
        // floor((D + (S - 1)) / S)
        // We prefer this formulation if it's legal because it's fewer
        // operations.
        BECount =
            getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
      } else {
        BECount = getUDivCeilSCEV(Delta, Stride);
      }
    }
  }

  const SCEV *ConstantMaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    ConstantMaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    ConstantMaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    ConstantMaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
                   Predicates);
}
13739
// Compute exit-count information for a loop exiting when "LHS > RHS" stops
// holding; only the case where RHS is loop-invariant is handled.
// NOTE(review): the declarations of Predicates and Cond, and a few interior
// lines, are not visible in this excerpt.
ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // We handle only IV > Invariant
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!IV && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);

  // The IV counts down; negate the step so Stride is the positive decrement.
  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

  // Avoid negative or zero stride values
  if (!isKnownPositive(Stride))
    return getCouldNotCompute();

  // Avoid proven overflow cases: this will ensure that the backedge taken count
  // will not generate any unsigned overflow. Relaxed no-overflow conditions
  // exploit NoWrapFlags, allowing to optimize in presence of undefined
  // behaviors like the case of C language.
  if (!Stride->isOne() && !NoWrap)
    if (canIVOverflowOnGT(RHS, Stride, IsSigned))
      return getCouldNotCompute();

  const SCEV *Start = IV->getStart();
  const SCEV *End = RHS;
  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
    // If we know that Start >= RHS in the context of loop, then we know that
    // min(RHS, Start) = RHS at this point.
        L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
      End = RHS;
    else
      End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
  }

  if (Start->getType()->isPointerTy()) {
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (End->getType()->isPointerTy()) {
    End = getLosslessPtrToIntExpr(End);
    if (isa<SCEVCouldNotCompute>(End))
      return End;
  }

  // Compute ((Start - End) + (Stride - 1)) / Stride.
  // FIXME: This can overflow. Holding off on fixing this for now;
  // howManyGreaterThans will hopefully be gone soon.
  const SCEV *One = getOne(Stride->getType());
  const SCEV *BECount = getUDivExpr(
      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);

  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)

  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
                             : getUnsignedRangeMin(Stride);

  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
                         : APInt::getMinValue(BitWidth) + (MinStride - 1);

  // Although End can be a MIN expression we estimate MinEnd considering only
  // the case End = RHS. This is safe because in the other case (Start - End)
  // is zero, leading to a zero maximum backedge taken count.
  APInt MinEnd =
      IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
               : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

  const SCEV *ConstantMaxBECount =
      isa<SCEVConstant>(BECount)
          ? BECount
          : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
                            getConstant(MinStride));

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
    ConstantMaxBECount = BECount;
  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;

  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   Predicates);
}
13838
// NOTE(review): the first line of this signature (return type, Range
// parameter) falls outside this excerpt.  Computes the number of iterations
// for which this (constant-operand) addrec stays inside Range.
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      Operands[0] = SE.getZero(SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
            Range.subtract(SC->getAPInt()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
    return SE.getCouldNotCompute();

  // Okay at this point we know that all elements of the chrec are constants and
  // that the start element is zero.

  // First check to see if the range contains zero.  If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getZero(getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range

    // We know that zero is in the range.  If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value.  Also note that we already checked for a full range.
    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);

    // Evaluate at the exit value.  If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened

    // Ensure that the previous value is in the range.
    assert(Range.contains(
           ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  }

  if (isQuadratic()) {
    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
      return SE.getConstant(*S);
  }

  return SE.getCouldNotCompute();
}
13909
// Build the post-increment form of this addrec (this + its step recurrence)
// explicitly, operand by operand.  NOTE(review): the signature line and the
// declaration of Ops / the final return fall outside this excerpt.
const SCEVAddRecExpr *
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
  // may happen if we reach arithmetic depth limit while simplifying. So we
  // construct the returned value explicitly.
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+...,+,N}.
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier). This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrency with zero step?");
  Ops.push_back(Last);
}
13934
// Return true when S contains at least an undef value.
// (Walks the whole expression tree looking for undef/poison leaves.)
  return SCEVExprContains(
      S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
}
13940
// Return true when S contains a value that is a nullptr.
// SCEVUnknown nodes whose underlying Value was erased carry a null pointer.
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return SU->getValue() == nullptr;
    return false;
  });
}
13949
/// Return the size of an element read or written by Inst, or nullptr when
/// Inst is neither a load nor a store.
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  return getSizeOfExpr(ETy, Ty);
}
13963
13964//===----------------------------------------------------------------------===//
13965// SCEVCallbackVH Class Implementation
13966//===----------------------------------------------------------------------===//
13967
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  // The tracked Value is being deleted: purge cached per-PHI constant
  // evolution results, then remove the value's SCEV from the value map so
  // nothing can reach the dead Value through it.
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->eraseValueFromMap(getValPtr());
  // this now dangles!
}
13975
void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.  Note the replacement value V itself is not needed here.
  SE->forgetValue(getValPtr());
  // this now dangles!
}
13985
/// Construct a callback handle that tracks \p V and notifies the owning
/// ScalarEvolution \p se when V is deleted or has all uses replaced.
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}
13988
13989//===----------------------------------------------------------------------===//
13990// ScalarEvolution Class Implementation
13991//===----------------------------------------------------------------------===//
13992
13995 LoopInfo &LI)
13996 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13997 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13998 LoopDispositions(64), BlockDispositions(64) {
13999 // To use guards for proving predicates, we need to scan every instruction in
14000 // relevant basic blocks, and not just terminators. Doing this is a waste of
14001 // time if the IR does not actually contain any calls to
14002 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
14003 //
14004 // This pessimizes the case where a pass that preserves ScalarEvolution wants
14005 // to _add_ guards to the module when there weren't any before, and wants
14006 // ScalarEvolution to optimize based on those guards. For now we prefer to be
14007 // efficient in lieu of being smart in that rather obscure case.
14008
14009 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
14010 F.getParent(), Intrinsic::experimental_guard);
14011 HasGuards = GuardDecl && !GuardDecl->use_empty();
14012}
14013
14015 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
14016 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
14017 ValueExprMap(std::move(Arg.ValueExprMap)),
14018 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14019 PendingMerges(std::move(Arg.PendingMerges)),
14020 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14021 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14022 PredicatedBackedgeTakenCounts(
14023 std::move(Arg.PredicatedBackedgeTakenCounts)),
14024 BECountUsers(std::move(Arg.BECountUsers)),
14025 ConstantEvolutionLoopExitValue(
14026 std::move(Arg.ConstantEvolutionLoopExitValue)),
14027 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14028 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14029 LoopDispositions(std::move(Arg.LoopDispositions)),
14030 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14031 BlockDispositions(std::move(Arg.BlockDispositions)),
14032 SCEVUsers(std::move(Arg.SCEVUsers)),
14033 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14034 SignedRanges(std::move(Arg.SignedRanges)),
14035 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14036 UniquePreds(std::move(Arg.UniquePreds)),
14037 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14038 LoopUsers(std::move(Arg.LoopUsers)),
14039 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14040 FirstUnknown(Arg.FirstUnknown) {
14041 Arg.FirstUnknown = nullptr;
14042}
14043
14045 // Iterate through all the SCEVUnknown instances and call their
14046 // destructors, so that they release their references to their values.
14047 for (SCEVUnknown *U = FirstUnknown; U;) {
14048 SCEVUnknown *Tmp = U;
14049 U = U->Next;
14050 Tmp->~SCEVUnknown();
14051 }
14052 FirstUnknown = nullptr;
14053
14054 ExprValueMap.clear();
14055 ValueExprMap.clear();
14056 HasRecMap.clear();
14057 BackedgeTakenCounts.clear();
14058 PredicatedBackedgeTakenCounts.clear();
14059
14060 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14061 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14062 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14063 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14064}
14065
14069
14070/// When printing a top-level SCEV for trip counts, it's helpful to include
14071/// a type for constants which are otherwise hard to disambiguate.
14072static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14073 if (isa<SCEVConstant>(S))
14074 OS << *S->getType() << " ";
14075 OS << *S;
14076}
14077
14079 const Loop *L) {
14080 // Print all inner loops first
14081 for (Loop *I : *L)
14082 PrintLoopInfo(OS, SE, I);
14083
14084 OS << "Loop ";
14085 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14086 OS << ": ";
14087
14088 SmallVector<BasicBlock *, 8> ExitingBlocks;
14089 L->getExitingBlocks(ExitingBlocks);
14090 if (ExitingBlocks.size() != 1)
14091 OS << "<multiple exits> ";
14092
14093 auto *BTC = SE->getBackedgeTakenCount(L);
14094 if (!isa<SCEVCouldNotCompute>(BTC)) {
14095 OS << "backedge-taken count is ";
14096 PrintSCEVWithTypeHint(OS, BTC);
14097 } else
14098 OS << "Unpredictable backedge-taken count.";
14099 OS << "\n";
14100
14101 if (ExitingBlocks.size() > 1)
14102 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14103 OS << " exit count for " << ExitingBlock->getName() << ": ";
14104 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14105 PrintSCEVWithTypeHint(OS, EC);
14106 if (isa<SCEVCouldNotCompute>(EC)) {
14107 // Retry with predicates.
14109 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14110 if (!isa<SCEVCouldNotCompute>(EC)) {
14111 OS << "\n predicated exit count for " << ExitingBlock->getName()
14112 << ": ";
14113 PrintSCEVWithTypeHint(OS, EC);
14114 OS << "\n Predicates:\n";
14115 for (const auto *P : Predicates)
14116 P->print(OS, 4);
14117 }
14118 }
14119 OS << "\n";
14120 }
14121
14122 OS << "Loop ";
14123 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14124 OS << ": ";
14125
14126 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14127 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14128 OS << "constant max backedge-taken count is ";
14129 PrintSCEVWithTypeHint(OS, ConstantBTC);
14131 OS << ", actual taken count either this or zero.";
14132 } else {
14133 OS << "Unpredictable constant max backedge-taken count. ";
14134 }
14135
14136 OS << "\n"
14137 "Loop ";
14138 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14139 OS << ": ";
14140
14141 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14142 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14143 OS << "symbolic max backedge-taken count is ";
14144 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14146 OS << ", actual taken count either this or zero.";
14147 } else {
14148 OS << "Unpredictable symbolic max backedge-taken count. ";
14149 }
14150 OS << "\n";
14151
14152 if (ExitingBlocks.size() > 1)
14153 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14154 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14155 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14157 PrintSCEVWithTypeHint(OS, ExitBTC);
14158 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14159 // Retry with predicates.
14161 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14163 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14164 OS << "\n predicated symbolic max exit count for "
14165 << ExitingBlock->getName() << ": ";
14166 PrintSCEVWithTypeHint(OS, ExitBTC);
14167 OS << "\n Predicates:\n";
14168 for (const auto *P : Predicates)
14169 P->print(OS, 4);
14170 }
14171 }
14172 OS << "\n";
14173 }
14174
14176 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14177 if (PBT != BTC) {
14178 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14179 OS << "Loop ";
14180 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14181 OS << ": ";
14182 if (!isa<SCEVCouldNotCompute>(PBT)) {
14183 OS << "Predicated backedge-taken count is ";
14184 PrintSCEVWithTypeHint(OS, PBT);
14185 } else
14186 OS << "Unpredictable predicated backedge-taken count.";
14187 OS << "\n";
14188 OS << " Predicates:\n";
14189 for (const auto *P : Preds)
14190 P->print(OS, 4);
14191 }
14192 Preds.clear();
14193
14194 auto *PredConstantMax =
14196 if (PredConstantMax != ConstantBTC) {
14197 assert(!Preds.empty() &&
14198 "different predicated constant max BTC but no predicates");
14199 OS << "Loop ";
14200 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14201 OS << ": ";
14202 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14203 OS << "Predicated constant max backedge-taken count is ";
14204 PrintSCEVWithTypeHint(OS, PredConstantMax);
14205 } else
14206 OS << "Unpredictable predicated constant max backedge-taken count.";
14207 OS << "\n";
14208 OS << " Predicates:\n";
14209 for (const auto *P : Preds)
14210 P->print(OS, 4);
14211 }
14212 Preds.clear();
14213
14214 auto *PredSymbolicMax =
14216 if (SymbolicBTC != PredSymbolicMax) {
14217 assert(!Preds.empty() &&
14218 "Different predicated symbolic max BTC, but no predicates");
14219 OS << "Loop ";
14220 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14221 OS << ": ";
14222 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14223 OS << "Predicated symbolic max backedge-taken count is ";
14224 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14225 } else
14226 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14227 OS << "\n";
14228 OS << " Predicates:\n";
14229 for (const auto *P : Preds)
14230 P->print(OS, 4);
14231 }
14232
14234 OS << "Loop ";
14235 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14236 OS << ": ";
14237 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14238 }
14239}
14240
14241namespace llvm {
14242// Note: these overloaded operators need to be in the llvm namespace for them
14243// to be resolved correctly. If we put them outside the llvm namespace, the
14244//
14245// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14246//
14247// code below "breaks" and start printing raw enum values as opposed to the
14248// string values.
14251 switch (LD) {
14253 OS << "Variant";
14254 break;
14256 OS << "Invariant";
14257 break;
14259 OS << "Computable";
14260 break;
14261 }
14262 return OS;
14263}
14264
14267 switch (BD) {
14269 OS << "DoesNotDominate";
14270 break;
14272 OS << "Dominates";
14273 break;
14275 OS << "ProperlyDominates";
14276 break;
14277 }
14278 return OS;
14279}
14280} // namespace llvm
14281
14283 // ScalarEvolution's implementation of the print method is to print
14284 // out SCEV values of all instructions that are interesting. Doing
14285 // this potentially causes it to create new SCEV objects though,
14286 // which technically conflicts with the const qualifier. This isn't
14287 // observable from outside the class though, so casting away the
14288 // const isn't dangerous.
14289 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14290
14291 if (ClassifyExpressions) {
14292 OS << "Classifying expressions for: ";
14293 F.printAsOperand(OS, /*PrintType=*/false);
14294 OS << "\n";
14295 for (Instruction &I : instructions(F))
14296 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14297 OS << I << '\n';
14298 OS << " --> ";
14299 const SCEV *SV = SE.getSCEV(&I);
14300 SV->print(OS);
14301 if (!isa<SCEVCouldNotCompute>(SV)) {
14302 OS << " U: ";
14303 SE.getUnsignedRange(SV).print(OS);
14304 OS << " S: ";
14305 SE.getSignedRange(SV).print(OS);
14306 }
14307
14308 const Loop *L = LI.getLoopFor(I.getParent());
14309
14310 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14311 if (AtUse != SV) {
14312 OS << " --> ";
14313 AtUse->print(OS);
14314 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14315 OS << " U: ";
14316 SE.getUnsignedRange(AtUse).print(OS);
14317 OS << " S: ";
14318 SE.getSignedRange(AtUse).print(OS);
14319 }
14320 }
14321
14322 if (L) {
14323 OS << "\t\t" "Exits: ";
14324 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14325 if (!SE.isLoopInvariant(ExitValue, L)) {
14326 OS << "<<Unknown>>";
14327 } else {
14328 OS << *ExitValue;
14329 }
14330
14331 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14332 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14333 OS << LS;
14334 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14335 OS << ": " << SE.getLoopDisposition(SV, Iter);
14336 }
14337
14338 for (const auto *InnerL : depth_first(L)) {
14339 if (InnerL == L)
14340 continue;
14341 OS << LS;
14342 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14343 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14344 }
14345
14346 OS << " }";
14347 }
14348
14349 OS << "\n";
14350 }
14351 }
14352
14353 OS << "Determining loop execution counts for: ";
14354 F.printAsOperand(OS, /*PrintType=*/false);
14355 OS << "\n";
14356 for (Loop *I : LI)
14357 PrintLoopInfo(OS, &SE, I);
14358}
14359
14362 auto &Values = LoopDispositions[S];
14363 for (auto &V : Values) {
14364 if (V.getPointer() == L)
14365 return V.getInt();
14366 }
14367 Values.emplace_back(L, LoopVariant);
14368 LoopDisposition D = computeLoopDisposition(S, L);
14369 auto &Values2 = LoopDispositions[S];
14370 for (auto &V : llvm::reverse(Values2)) {
14371 if (V.getPointer() == L) {
14372 V.setInt(D);
14373 break;
14374 }
14375 }
14376 return D;
14377}
14378
14380ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14381 switch (S->getSCEVType()) {
14382 case scConstant:
14383 case scVScale:
14384 return LoopInvariant;
14385 case scAddRecExpr: {
14386 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14387
14388 // If L is the addrec's loop, it's computable.
14389 if (AR->getLoop() == L)
14390 return LoopComputable;
14391
14392 // Add recurrences are never invariant in the function-body (null loop).
14393 if (!L)
14394 return LoopVariant;
14395
14396 // Everything that is not defined at loop entry is variant.
14397 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14398 return LoopVariant;
14399 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14400 " dominate the contained loop's header?");
14401
14402 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14403 if (AR->getLoop()->contains(L))
14404 return LoopInvariant;
14405
14406 // This recurrence is variant w.r.t. L if any of its operands
14407 // are variant.
14408 for (SCEVUse Op : AR->operands())
14409 if (!isLoopInvariant(Op, L))
14410 return LoopVariant;
14411
14412 // Otherwise it's loop-invariant.
14413 return LoopInvariant;
14414 }
14415 case scTruncate:
14416 case scZeroExtend:
14417 case scSignExtend:
14418 case scPtrToAddr:
14419 case scPtrToInt:
14420 case scAddExpr:
14421 case scMulExpr:
14422 case scUDivExpr:
14423 case scUMaxExpr:
14424 case scSMaxExpr:
14425 case scUMinExpr:
14426 case scSMinExpr:
14427 case scSequentialUMinExpr: {
14428 bool HasVarying = false;
14429 for (SCEVUse Op : S->operands()) {
14431 if (D == LoopVariant)
14432 return LoopVariant;
14433 if (D == LoopComputable)
14434 HasVarying = true;
14435 }
14436 return HasVarying ? LoopComputable : LoopInvariant;
14437 }
14438 case scUnknown:
14439 // All non-instruction values are loop invariant. All instructions are loop
14440 // invariant if they are not contained in the specified loop.
14441 // Instructions are never considered invariant in the function body
14442 // (null loop) because they are defined within the "loop".
14443 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14444 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14445 return LoopInvariant;
14446 case scCouldNotCompute:
14447 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14448 }
14449 llvm_unreachable("Unknown SCEV kind!");
14450}
14451
14453 return getLoopDisposition(S, L) == LoopInvariant;
14454}
14455
14457 return getLoopDisposition(S, L) == LoopComputable;
14458}
14459
14462 auto &Values = BlockDispositions[S];
14463 for (auto &V : Values) {
14464 if (V.getPointer() == BB)
14465 return V.getInt();
14466 }
14467 Values.emplace_back(BB, DoesNotDominateBlock);
14468 BlockDisposition D = computeBlockDisposition(S, BB);
14469 auto &Values2 = BlockDispositions[S];
14470 for (auto &V : llvm::reverse(Values2)) {
14471 if (V.getPointer() == BB) {
14472 V.setInt(D);
14473 break;
14474 }
14475 }
14476 return D;
14477}
14478
14480ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14481 switch (S->getSCEVType()) {
14482 case scConstant:
14483 case scVScale:
14485 case scAddRecExpr: {
14486 // This uses a "dominates" query instead of "properly dominates" query
14487 // to test for proper dominance too, because the instruction which
14488 // produces the addrec's value is a PHI, and a PHI effectively properly
14489 // dominates its entire containing block.
14490 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14491 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14492 return DoesNotDominateBlock;
14493
14494 // Fall through into SCEVNAryExpr handling.
14495 [[fallthrough]];
14496 }
14497 case scTruncate:
14498 case scZeroExtend:
14499 case scSignExtend:
14500 case scPtrToAddr:
14501 case scPtrToInt:
14502 case scAddExpr:
14503 case scMulExpr:
14504 case scUDivExpr:
14505 case scUMaxExpr:
14506 case scSMaxExpr:
14507 case scUMinExpr:
14508 case scSMinExpr:
14509 case scSequentialUMinExpr: {
14510 bool Proper = true;
14511 for (const SCEV *NAryOp : S->operands()) {
14513 if (D == DoesNotDominateBlock)
14514 return DoesNotDominateBlock;
14515 if (D == DominatesBlock)
14516 Proper = false;
14517 }
14518 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14519 }
14520 case scUnknown:
14521 if (Instruction *I =
14522 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14523 if (I->getParent() == BB)
14524 return DominatesBlock;
14525 if (DT.properlyDominates(I->getParent(), BB))
14527 return DoesNotDominateBlock;
14528 }
14530 case scCouldNotCompute:
14531 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14532 }
14533 llvm_unreachable("Unknown SCEV kind!");
14534}
14535
14536bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14537 return getBlockDisposition(S, BB) >= DominatesBlock;
14538}
14539
14542}
14543
14544bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14545 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14546}
14547
14548void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14549 bool Predicated) {
14550 auto &BECounts =
14551 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14552 auto It = BECounts.find(L);
14553 if (It != BECounts.end()) {
14554 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14555 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14556 if (!isa<SCEVConstant>(S)) {
14557 auto UserIt = BECountUsers.find(S);
14558 assert(UserIt != BECountUsers.end());
14559 UserIt->second.erase({L, Predicated});
14560 }
14561 }
14562 }
14563 BECounts.erase(It);
14564 }
14565}
14566
// Invalidate all memoized results for the given SCEVs and, transitively, for
// every SCEV expression that uses them (per the SCEVUsers map), including any
// predicated rewrites keyed on a forgotten expression.
void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
  SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
  SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());

  // Transitive closure over SCEVUsers: anything built on top of a forgotten
  // expression must be forgotten as well.
  while (!Worklist.empty()) {
    const SCEV *Curr = Worklist.pop_back_val();
    auto Users = SCEVUsers.find(Curr);
    if (Users != SCEVUsers.end())
      for (const auto *User : Users->second)
        if (ToForget.insert(User).second)
          Worklist.push_back(User);
  }

  for (const auto *S : ToForget)
    forgetMemoizedResultsImpl(S);

  // Drop predicated rewrite cache entries whose source expression is being
  // forgotten. Note the post-increment erase idiom: the iterator must be
  // advanced before the pointed-to entry is removed.
  for (auto I = PredicatedSCEVRewrites.begin();
       I != PredicatedSCEVRewrites.end();) {
    std::pair<const SCEV *, const Loop *> Entry = I->first;
    if (ToForget.count(Entry.first))
      PredicatedSCEVRewrites.erase(I++);
    else
      ++I;
  }
}
14592
// Erase every per-expression cache entry for S. This is the non-transitive
// worker; callers (forgetMemoizedResults) handle propagation to users.
void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
  // Simple per-SCEV caches: dispositions, ranges, rec/multiple info.
  LoopDispositions.erase(S);
  BlockDispositions.erase(S);
  UnsignedRanges.erase(S);
  SignedRanges.erase(S);
  HasRecMap.erase(S);
  ConstantMultipleCache.erase(S);

  // Wrap-flag inference caches are keyed on add-recurrences only.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    UnsignedWrapViaInductionTried.erase(AR);
    SignedWrapViaInductionTried.erase(AR);
  }

  // Keep ExprValueMap and ValueExprMap in sync: when S is forgotten, every
  // IR value mapped to S loses its forward-map entry too.
  auto ExprIt = ExprValueMap.find(S);
  if (ExprIt != ExprValueMap.end()) {
    for (Value *V : ExprIt->second) {
      auto ValueIt = ValueExprMap.find_as(V);
      if (ValueIt != ValueExprMap.end())
        ValueExprMap.erase(ValueIt);
    }
    ExprValueMap.erase(ExprIt);
  }

  // S as a key in ValuesAtScopes: remove the reverse (users) entries for each
  // non-constant value-at-scope it referenced.
  auto ScopeIt = ValuesAtScopes.find(S);
  if (ScopeIt != ValuesAtScopes.end()) {
    for (const auto &Pair : ScopeIt->second)
      if (!isa_and_nonnull<SCEVConstant>(Pair.second))
        llvm::erase(ValuesAtScopesUsers[Pair.second],
                    std::make_pair(Pair.first, S));
    ValuesAtScopes.erase(ScopeIt);
  }

  // S as a value-at-scope: remove the forward entries that pointed at it.
  auto ScopeUserIt = ValuesAtScopesUsers.find(S);
  if (ScopeUserIt != ValuesAtScopesUsers.end()) {
    for (const auto &Pair : ScopeUserIt->second)
      llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
    ValuesAtScopesUsers.erase(ScopeUserIt);
  }

  // If S participates in any cached backedge-taken counts, drop those counts.
  auto BEUsersIt = BECountUsers.find(S);
  if (BEUsersIt != BECountUsers.end()) {
    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
    auto Copy = BEUsersIt->second;
    for (const auto &Pair : Copy)
      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
    BECountUsers.erase(BEUsersIt);
  }

  // Finally, purge fold-cache entries that produced S, in both directions.
  auto FoldUser = FoldCacheUser.find(S);
  if (FoldUser != FoldCacheUser.end())
    for (auto &KV : FoldUser->second)
      FoldCache.erase(KV);
  FoldCacheUser.erase(S);
}
14647
14648void
14649ScalarEvolution::getUsedLoops(const SCEV *S,
14650 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14651 struct FindUsedLoops {
14652 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14653 : LoopsUsed(LoopsUsed) {}
14654 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14655 bool follow(const SCEV *S) {
14656 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14657 LoopsUsed.insert(AR->getLoop());
14658 return true;
14659 }
14660
14661 bool isDone() const { return false; }
14662 };
14663
14664 FindUsedLoops F(LoopsUsed);
14665 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14666}
14667
14668void ScalarEvolution::getReachableBlocks(
14671 Worklist.push_back(&F.getEntryBlock());
14672 while (!Worklist.empty()) {
14673 BasicBlock *BB = Worklist.pop_back_val();
14674 if (!Reachable.insert(BB).second)
14675 continue;
14676
14677 Value *Cond;
14678 BasicBlock *TrueBB, *FalseBB;
14679 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14680 m_BasicBlock(FalseBB)))) {
14681 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14682 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14683 continue;
14684 }
14685
14686 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14687 const SCEV *L = getSCEV(Cmp->getOperand(0));
14688 const SCEV *R = getSCEV(Cmp->getOperand(1));
14689 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14690 Worklist.push_back(TrueBB);
14691 continue;
14692 }
14693 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14694 R)) {
14695 Worklist.push_back(FalseBB);
14696 continue;
14697 }
14698 }
14699 }
14700
14701 append_range(Worklist, successors(BB));
14702 }
14703}
14704
14706 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14707 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14708
14709 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14710
  // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14712 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14713 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14714
14715 const SCEV *visitConstant(const SCEVConstant *Constant) {
14716 return SE.getConstant(Constant->getAPInt());
14717 }
14718
14719 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14720 return SE.getUnknown(Expr->getValue());
14721 }
14722
14723 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14724 return SE.getCouldNotCompute();
14725 }
14726 };
14727
14728 SCEVMapper SCM(SE2);
14729 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14730 SE2.getReachableBlocks(ReachableBlocks, F);
14731
14732 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14733 if (containsUndefs(Old) || containsUndefs(New)) {
14734 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14735 // not propagate undef aggressively). This means we can (and do) fail
14736 // verification in cases where a transform makes a value go from "undef"
14737 // to "undef+1" (say). The transform is fine, since in both cases the
14738 // result is "undef", but SCEV thinks the value increased by 1.
14739 return nullptr;
14740 }
14741
14742 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14743 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14744 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14745 return nullptr;
14746
14747 return Delta;
14748 };
14749
14750 while (!LoopStack.empty()) {
14751 auto *L = LoopStack.pop_back_val();
14752 llvm::append_range(LoopStack, *L);
14753
14754 // Only verify BECounts in reachable loops. For an unreachable loop,
14755 // any BECount is legal.
14756 if (!ReachableBlocks.contains(L->getHeader()))
14757 continue;
14758
14759 // Only verify cached BECounts. Computing new BECounts may change the
14760 // results of subsequent SCEV uses.
14761 auto It = BackedgeTakenCounts.find(L);
14762 if (It == BackedgeTakenCounts.end())
14763 continue;
14764
14765 auto *CurBECount =
14766 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14767 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14768
14769 if (CurBECount == SE2.getCouldNotCompute() ||
14770 NewBECount == SE2.getCouldNotCompute()) {
14771 // NB! This situation is legal, but is very suspicious -- whatever pass
      // changed the loop to make a trip count go from could not compute to
14773 // computable or vice-versa *should have* invalidated SCEV. However, we
14774 // choose not to assert here (for now) since we don't want false
14775 // positives.
14776 continue;
14777 }
14778
14779 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14780 SE.getTypeSizeInBits(NewBECount->getType()))
14781 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14782 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14783 SE.getTypeSizeInBits(NewBECount->getType()))
14784 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14785
14786 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14787 if (Delta && !Delta->isZero()) {
14788 dbgs() << "Trip Count for " << *L << " Changed!\n";
14789 dbgs() << "Old: " << *CurBECount << "\n";
14790 dbgs() << "New: " << *NewBECount << "\n";
14791 dbgs() << "Delta: " << *Delta << "\n";
14792 std::abort();
14793 }
14794 }
14795
14796 // Collect all valid loops currently in LoopInfo.
14797 SmallPtrSet<Loop *, 32> ValidLoops;
14798 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14799 while (!Worklist.empty()) {
14800 Loop *L = Worklist.pop_back_val();
14801 if (ValidLoops.insert(L).second)
14802 Worklist.append(L->begin(), L->end());
14803 }
14804 for (const auto &KV : ValueExprMap) {
14805#ifndef NDEBUG
14806 // Check for SCEV expressions referencing invalid/deleted loops.
14807 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14808 assert(ValidLoops.contains(AR->getLoop()) &&
14809 "AddRec references invalid loop");
14810 }
14811#endif
14812
14813 // Check that the value is also part of the reverse map.
14814 auto It = ExprValueMap.find(KV.second);
14815 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14816 dbgs() << "Value " << *KV.first
14817 << " is in ValueExprMap but not in ExprValueMap\n";
14818 std::abort();
14819 }
14820
14821 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14822 if (!ReachableBlocks.contains(I->getParent()))
14823 continue;
14824 const SCEV *OldSCEV = SCM.visit(KV.second);
14825 const SCEV *NewSCEV = SE2.getSCEV(I);
14826 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14827 if (Delta && !Delta->isZero()) {
14828 dbgs() << "SCEV for value " << *I << " changed!\n"
14829 << "Old: " << *OldSCEV << "\n"
14830 << "New: " << *NewSCEV << "\n"
14831 << "Delta: " << *Delta << "\n";
14832 std::abort();
14833 }
14834 }
14835 }
14836
14837 for (const auto &KV : ExprValueMap) {
14838 for (Value *V : KV.second) {
14839 const SCEV *S = ValueExprMap.lookup(V);
14840 if (!S) {
14841 dbgs() << "Value " << *V
14842 << " is in ExprValueMap but not in ValueExprMap\n";
14843 std::abort();
14844 }
14845 if (S != KV.first) {
14846 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14847 << *KV.first << "\n";
14848 std::abort();
14849 }
14850 }
14851 }
14852
14853 // Verify integrity of SCEV users.
14854 for (const auto &S : UniqueSCEVs) {
14855 for (SCEVUse Op : S.operands()) {
14856 // We do not store dependencies of constants.
14857 if (isa<SCEVConstant>(Op))
14858 continue;
14859 auto It = SCEVUsers.find(Op);
14860 if (It != SCEVUsers.end() && It->second.count(&S))
14861 continue;
14862 dbgs() << "Use of operand " << *Op << " by user " << S
14863 << " is not being tracked!\n";
14864 std::abort();
14865 }
14866 }
14867
14868 // Verify integrity of ValuesAtScopes users.
14869 for (const auto &ValueAndVec : ValuesAtScopes) {
14870 const SCEV *Value = ValueAndVec.first;
14871 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14872 const Loop *L = LoopAndValueAtScope.first;
14873 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14874 if (!isa<SCEVConstant>(ValueAtScope)) {
14875 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14876 if (It != ValuesAtScopesUsers.end() &&
14877 is_contained(It->second, std::make_pair(L, Value)))
14878 continue;
14879 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14880 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14881 std::abort();
14882 }
14883 }
14884 }
14885
14886 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14887 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14888 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14889 const Loop *L = LoopAndValue.first;
14890 const SCEV *Value = LoopAndValue.second;
14892 auto It = ValuesAtScopes.find(Value);
14893 if (It != ValuesAtScopes.end() &&
14894 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14895 continue;
14896 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14897 << *ValueAtScope << " missing in ValuesAtScopes\n";
14898 std::abort();
14899 }
14900 }
14901
14902 // Verify integrity of BECountUsers.
14903 auto VerifyBECountUsers = [&](bool Predicated) {
14904 auto &BECounts =
14905 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14906 for (const auto &LoopAndBEInfo : BECounts) {
14907 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14908 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14909 if (!isa<SCEVConstant>(S)) {
14910 auto UserIt = BECountUsers.find(S);
14911 if (UserIt != BECountUsers.end() &&
14912 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14913 continue;
14914 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14915 << " missing from BECountUsers\n";
14916 std::abort();
14917 }
14918 }
14919 }
14920 }
14921 };
14922 VerifyBECountUsers(/* Predicated */ false);
14923 VerifyBECountUsers(/* Predicated */ true);
14924
  // Verify integrity of the loop disposition cache.
14926 for (auto &[S, Values] : LoopDispositions) {
14927 for (auto [Loop, CachedDisposition] : Values) {
14928 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14929 if (CachedDisposition != RecomputedDisposition) {
14930 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14931 << " is incorrect: cached " << CachedDisposition << ", actual "
14932 << RecomputedDisposition << "\n";
14933 std::abort();
14934 }
14935 }
14936 }
14937
14938 // Verify integrity of the block disposition cache.
14939 for (auto &[S, Values] : BlockDispositions) {
14940 for (auto [BB, CachedDisposition] : Values) {
14941 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14942 if (CachedDisposition != RecomputedDisposition) {
14943 dbgs() << "Cached disposition of " << *S << " for block %"
14944 << BB->getName() << " is incorrect: cached " << CachedDisposition
14945 << ", actual " << RecomputedDisposition << "\n";
14946 std::abort();
14947 }
14948 }
14949 }
14950
14951 // Verify FoldCache/FoldCacheUser caches.
14952 for (auto [FoldID, Expr] : FoldCache) {
14953 auto I = FoldCacheUser.find(Expr);
14954 if (I == FoldCacheUser.end()) {
14955 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14956 << "!\n";
14957 std::abort();
14958 }
14959 if (!is_contained(I->second, FoldID)) {
14960 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14961 std::abort();
14962 }
14963 }
14964 for (auto [Expr, IDs] : FoldCacheUser) {
14965 for (auto &FoldID : IDs) {
14966 const SCEV *S = FoldCache.lookup(FoldID);
14967 if (!S) {
14968 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14969 << "!\n";
14970 std::abort();
14971 }
14972 if (S != Expr) {
14973 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14974 << " != " << *Expr << "!\n";
14975 std::abort();
14976 }
14977 }
14978 }
14979
14980 // Verify that ConstantMultipleCache computations are correct. We check that
14981 // cached multiples and recomputed multiples are multiples of each other to
14982 // verify correctness. It is possible that a recomputed multiple is different
14983 // from the cached multiple due to strengthened no wrap flags or changes in
14984 // KnownBits computations.
14985 for (auto [S, Multiple] : ConstantMultipleCache) {
14986 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14987 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14988 Multiple.urem(RecomputedMultiple) != 0 &&
14989 RecomputedMultiple.urem(Multiple) != 0)) {
14990 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14991 << *S << " : Computed " << RecomputedMultiple
14992 << " but cache contains " << Multiple << "!\n";
14993 std::abort();
14994 }
14995 }
14996}
14997
14999 Function &F, const PreservedAnalyses &PA,
15000 FunctionAnalysisManager::Invalidator &Inv) {
15001 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
15002 // of its dependencies is invalidated.
15003 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
15004 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
15005 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
15006 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
15007 Inv.invalidate<LoopAnalysis>(F, PA);
15008}
15009
15010AnalysisKey ScalarEvolutionAnalysis::Key;
15011
15014 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
15015 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15016 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15017 auto &LI = AM.getResult<LoopAnalysis>(F);
15018 return ScalarEvolution(F, TLI, AC, DT, LI);
15019}
15020
15026
15029 // For compatibility with opt's -analyze feature under legacy pass manager
15030 // which was not ported to NPM. This keeps tests using
15031 // update_analyze_test_checks.py working.
15032 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15033 << F.getName() << "':\n";
15035 return PreservedAnalyses::all();
15036}
15037
15039 "Scalar Evolution Analysis", false, true)
15045 "Scalar Evolution Analysis", false, true)
15046
15048
15050
15052 SE.reset(new ScalarEvolution(
15054 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15056 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15057 return false;
15058}
15059
15061
15063 SE->print(OS);
15064}
15065
15067 if (!VerifySCEV)
15068 return;
15069
15070 SE->verify();
15071}
15072
15080
15082 const SCEV *RHS) {
15083 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15084}
15085
15086const SCEVPredicate *
15088 const SCEV *LHS, const SCEV *RHS) {
15090 assert(LHS->getType() == RHS->getType() &&
15091 "Type mismatch between LHS and RHS");
15092 // Unique this node based on the arguments
15093 ID.AddInteger(SCEVPredicate::P_Compare);
15094 ID.AddInteger(Pred);
15095 ID.AddPointer(LHS);
15096 ID.AddPointer(RHS);
15097 void *IP = nullptr;
15098 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15099 return S;
15100 SCEVComparePredicate *Eq = new (SCEVAllocator)
15101 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15102 UniquePreds.InsertNode(Eq, IP);
15103 return Eq;
15104}
15105
15107 const SCEVAddRecExpr *AR,
15110 // Unique this node based on the arguments
15111 ID.AddInteger(SCEVPredicate::P_Wrap);
15112 ID.AddPointer(AR);
15113 ID.AddInteger(AddedFlags);
15114 void *IP = nullptr;
15115 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15116 return S;
15117 auto *OF = new (SCEVAllocator)
15118 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15119 UniquePreds.InsertNode(OF, IP);
15120 return OF;
15121}
15122
15123namespace {
15124
15125class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15126public:
15127
15128 /// Rewrites \p S in the context of a loop L and the SCEV predication
15129 /// infrastructure.
15130 ///
15131 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15132 /// equivalences present in \p Pred.
15133 ///
15134 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15135 /// \p NewPreds such that the result will be an AddRecExpr.
15136 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15138 const SCEVPredicate *Pred) {
15139 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15140 return Rewriter.visit(S);
15141 }
15142
15143 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15144 if (Pred) {
15145 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15146 for (const auto *Pred : U->getPredicates())
15147 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15148 if (IPred->getLHS() == Expr &&
15149 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15150 return IPred->getRHS();
15151 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15152 if (IPred->getLHS() == Expr &&
15153 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15154 return IPred->getRHS();
15155 }
15156 }
15157 return convertToAddRecWithPreds(Expr);
15158 }
15159
15160 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15161 const SCEV *Operand = visit(Expr->getOperand());
15162 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15163 if (AR && AR->getLoop() == L && AR->isAffine()) {
15164 // This couldn't be folded because the operand didn't have the nuw
15165 // flag. Add the nusw flag as an assumption that we could make.
15166 const SCEV *Step = AR->getStepRecurrence(SE);
15167 Type *Ty = Expr->getType();
15168 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15169 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15170 SE.getSignExtendExpr(Step, Ty), L,
15171 AR->getNoWrapFlags());
15172 }
15173 return SE.getZeroExtendExpr(Operand, Expr->getType());
15174 }
15175
15176 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15177 const SCEV *Operand = visit(Expr->getOperand());
15178 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15179 if (AR && AR->getLoop() == L && AR->isAffine()) {
15180 // This couldn't be folded because the operand didn't have the nsw
15181 // flag. Add the nssw flag as an assumption that we could make.
15182 const SCEV *Step = AR->getStepRecurrence(SE);
15183 Type *Ty = Expr->getType();
15184 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15185 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15186 SE.getSignExtendExpr(Step, Ty), L,
15187 AR->getNoWrapFlags());
15188 }
15189 return SE.getSignExtendExpr(Operand, Expr->getType());
15190 }
15191
15192private:
15193 explicit SCEVPredicateRewriter(
15194 const Loop *L, ScalarEvolution &SE,
15195 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15196 const SCEVPredicate *Pred)
15197 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15198
15199 bool addOverflowAssumption(const SCEVPredicate *P) {
15200 if (!NewPreds) {
15201 // Check if we've already made this assumption.
15202 return Pred && Pred->implies(P, SE);
15203 }
15204 NewPreds->push_back(P);
15205 return true;
15206 }
15207
15208 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15210 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15211 return addOverflowAssumption(A);
15212 }
15213
15214 // If \p Expr represents a PHINode, we try to see if it can be represented
15215 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15216 // to add this predicate as a runtime overflow check, we return the AddRec.
15217 // If \p Expr does not meet these conditions (is not a PHI node, or we
15218 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15219 // return \p Expr.
15220 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15221 if (!isa<PHINode>(Expr->getValue()))
15222 return Expr;
15223 std::optional<
15224 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15225 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15226 if (!PredicatedRewrite)
15227 return Expr;
15228 for (const auto *P : PredicatedRewrite->second){
15229 // Wrap predicates from outer loops are not supported.
15230 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15231 if (L != WP->getExpr()->getLoop())
15232 return Expr;
15233 }
15234 if (!addOverflowAssumption(P))
15235 return Expr;
15236 }
15237 return PredicatedRewrite->first;
15238 }
15239
15240 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15241 const SCEVPredicate *Pred;
15242 const Loop *L;
15243};
15244
15245} // end anonymous namespace
15246
15247const SCEV *
15249 const SCEVPredicate &Preds) {
15250 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15251}
15252
15254 const SCEV *S, const Loop *L,
15257 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15258 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15259
15260 if (!AddRec)
15261 return nullptr;
15262
15263 // Check if any of the transformed predicates is known to be false. In that
15264 // case, it doesn't make sense to convert to a predicated AddRec, as the
15265 // versioned loop will never execute.
15266 for (const SCEVPredicate *Pred : TransformPreds) {
15267 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15268 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15269 continue;
15270
15271 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15272 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15273 if (isa<SCEVCouldNotCompute>(ExitCount))
15274 continue;
15275
15276 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15277 if (!Step->isOne())
15278 continue;
15279
15280 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15281 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15282 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15283 return nullptr;
15284 }
15285
15286 // Since the transformation was successful, we can now transfer the SCEV
15287 // predicates.
15288 Preds.append(TransformPreds.begin(), TransformPreds.end());
15289
15290 return AddRec;
15291}
15292
15293/// SCEV predicates
15297
15299 const ICmpInst::Predicate Pred,
15300 const SCEV *LHS, const SCEV *RHS)
15301 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15302 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15303 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15304}
15305
15307 ScalarEvolution &SE) const {
15308 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15309
15310 if (!Op)
15311 return false;
15312
15313 if (Pred != ICmpInst::ICMP_EQ)
15314 return false;
15315
15316 return Op->LHS == LHS && Op->RHS == RHS;
15317}
15318
15319bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15320
15322 if (Pred == ICmpInst::ICMP_EQ)
15323 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15324 else
15325 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15326 << *RHS << "\n";
15327
15328}
15329
15331 const SCEVAddRecExpr *AR,
15332 IncrementWrapFlags Flags)
15333 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15334
15335const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15336
15338 ScalarEvolution &SE) const {
15339 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15340 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15341 return false;
15342
15343 if (Op->AR == AR)
15344 return true;
15345
15346 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15348 return false;
15349
15350 const SCEV *Start = AR->getStart();
15351 const SCEV *OpStart = Op->AR->getStart();
15352 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15353 return false;
15354
15355 // Reject pointers to different address spaces.
15356 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15357 return false;
15358
15359 const SCEV *Step = AR->getStepRecurrence(SE);
15360 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15361 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15362 return false;
15363
15364 // If both steps are positive, this implies N, if N's start and step are
15365 // ULE/SLE (for NSUW/NSSW) than this'.
15366 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15367 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15368 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15369
15370 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15371 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15372 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15373 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15374 : SE.getNoopOrSignExtend(Start, WiderTy);
15376 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15377 SE.isKnownPredicate(Pred, OpStart, Start);
15378}
15379
15381 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15382 IncrementWrapFlags IFlags = Flags;
15383
15384 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15385 IFlags = clearFlags(IFlags, IncrementNSSW);
15386
15387 return IFlags == IncrementAnyWrap;
15388}
15389
15390void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15391 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15393 OS << "<nusw>";
15395 OS << "<nssw>";
15396 OS << "\n";
15397}
15398
15401 ScalarEvolution &SE) {
15402 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15403 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15404
15405 // We can safely transfer the NSW flag as NSSW.
15406 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15407 ImpliedFlags = IncrementNSSW;
15408
15409 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15410 // If the increment is positive, the SCEV NUW flag will also imply the
15411 // WrapPredicate NUSW flag.
15412 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15413 if (Step->getValue()->getValue().isNonNegative())
15414 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15415 }
15416
15417 return ImpliedFlags;
15418}
15419
15420/// Union predicates don't get cached so create a dummy set ID for it.
15422 ScalarEvolution &SE)
15423 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15424 for (const auto *P : Preds)
15425 add(P, SE);
15426}
15427
15429 return all_of(Preds,
15430 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15431}
15432
15434 ScalarEvolution &SE) const {
15435 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15436 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15437 return this->implies(I, SE);
15438 });
15439
15440 return any_of(Preds,
15441 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15442}
15443
15445 for (const auto *Pred : Preds)
15446 Pred->print(OS, Depth);
15447}
15448
15449void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15450 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15451 for (const auto *Pred : Set->Preds)
15452 add(Pred, SE);
15453 return;
15454 }
15455
15456 // Implication checks are quadratic in the number of predicates. Stop doing
15457 // them if there are many predicates, as they should be too expensive to use
15458 // anyway at that point.
15459 bool CheckImplies = Preds.size() < 16;
15460
15461 // Only add predicate if it is not already implied by this union predicate.
15462 if (CheckImplies && implies(N, SE))
15463 return;
15464
15465 // Build a new vector containing the current predicates, except the ones that
15466 // are implied by the new predicate N.
15468 for (auto *P : Preds) {
15469 if (CheckImplies && N->implies(P, SE))
15470 continue;
15471 PrunedPreds.push_back(P);
15472 }
15473 Preds = std::move(PrunedPreds);
15474 Preds.push_back(N);
15475}
15476
15478 Loop &L)
15479 : SE(SE), L(L) {
15481 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15482}
15483
15486 for (const auto *Op : Ops)
15487 // We do not expect that forgetting cached data for SCEVConstants will ever
15488 // open any prospects for sharpening or introduce any correctness issues,
15489 // so we don't bother storing their dependencies.
15490 if (!isa<SCEVConstant>(Op))
15491 SCEVUsers[Op].insert(User);
15492}
15493
15495 for (const SCEV *Op : Ops)
15496 // We do not expect that forgetting cached data for SCEVConstants will ever
15497 // open any prospects for sharpening or introduce any correctness issues,
15498 // so we don't bother storing their dependencies.
15499 if (!isa<SCEVConstant>(Op))
15500 SCEVUsers[Op].insert(User);
15501}
15502
15504 const SCEV *Expr = SE.getSCEV(V);
15505 return getPredicatedSCEV(Expr);
15506}
15507
15509 RewriteEntry &Entry = RewriteMap[Expr];
15510
15511 // If we already have an entry and the version matches, return it.
15512 if (Entry.second && Generation == Entry.first)
15513 return Entry.second;
15514
15515 // We found an entry but it's stale. Rewrite the stale entry
15516 // according to the current predicate.
15517 if (Entry.second)
15518 Expr = Entry.second;
15519
15520 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15521 Entry = {Generation, NewSCEV};
15522
15523 return NewSCEV;
15524}
15525
15527 if (!BackedgeCount) {
15529 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15530 for (const auto *P : Preds)
15531 addPredicate(*P);
15532 }
15533 return BackedgeCount;
15534}
15535
15537 if (!SymbolicMaxBackedgeCount) {
15539 SymbolicMaxBackedgeCount =
15540 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15541 for (const auto *P : Preds)
15542 addPredicate(*P);
15543 }
15544 return SymbolicMaxBackedgeCount;
15545}
15546
15548 if (!SmallConstantMaxTripCount) {
15550 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15551 for (const auto *P : Preds)
15552 addPredicate(*P);
15553 }
15554 return *SmallConstantMaxTripCount;
15555}
15556
15558 if (Preds->implies(&Pred, SE))
15559 return;
15560
15561 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15562 NewPreds.push_back(&Pred);
15563 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15564 updateGeneration();
15565}
15566
15568 return *Preds;
15569}
15570
15571void PredicatedScalarEvolution::updateGeneration() {
15572 // If the generation number wrapped recompute everything.
15573 if (++Generation == 0) {
15574 for (auto &II : RewriteMap) {
15575 const SCEV *Rewritten = II.second.second;
15576 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15577 }
15578 }
15579}
15580
15583 const SCEV *Expr = getSCEV(V);
15584 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15585
15586 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15587
15588 // Clear the statically implied flags.
15589 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15590 addPredicate(*SE.getWrapPredicate(AR, Flags));
15591
15592 auto II = FlagsMap.insert({V, Flags});
15593 if (!II.second)
15594 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15595}
15596
15599 const SCEV *Expr = getSCEV(V);
15600 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15601
15603 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15604
15605 auto II = FlagsMap.find(V);
15606
15607 if (II != FlagsMap.end())
15608 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15609
15611}
15612
15614 const SCEV *Expr = this->getSCEV(V);
15616 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15617
15618 if (!New)
15619 return nullptr;
15620
15621 for (const auto *P : NewPreds)
15622 addPredicate(*P);
15623
15624 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15625 return New;
15626}
15627
15630 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15631 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15632 SE)),
15633 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15634 for (auto I : Init.FlagsMap)
15635 FlagsMap.insert(I);
15636}
15637
15639 // For each block.
15640 for (auto *BB : L.getBlocks())
15641 for (auto &I : *BB) {
15642 if (!SE.isSCEVable(I.getType()))
15643 continue;
15644
15645 auto *Expr = SE.getSCEV(&I);
15646 auto II = RewriteMap.find(Expr);
15647
15648 if (II == RewriteMap.end())
15649 continue;
15650
15651 // Don't print things that are not interesting.
15652 if (II->second.second == Expr)
15653 continue;
15654
15655 OS.indent(Depth) << "[PSE]" << I << ":\n";
15656 OS.indent(Depth + 2) << *Expr << "\n";
15657 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15658 }
15659}
15660
15663 BasicBlock *Header = L->getHeader();
15664 BasicBlock *Pred = L->getLoopPredecessor();
15665 LoopGuards Guards(SE);
15666 if (!Pred)
15667 return Guards;
15669 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15670 return Guards;
15671}
15672
15673void ScalarEvolution::LoopGuards::collectFromPHI(
15677 unsigned Depth) {
15678 if (!SE.isSCEVable(Phi.getType()))
15679 return;
15680
15681 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15682 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15683 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15684 if (!VisitedBlocks.insert(InBlock).second)
15685 return {nullptr, scCouldNotCompute};
15686
15687 // Avoid analyzing unreachable blocks so that we don't get trapped
15688 // traversing cycles with ill-formed dominance or infinite cycles
15689 if (!SE.DT.isReachableFromEntry(InBlock))
15690 return {nullptr, scCouldNotCompute};
15691
15692 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15693 if (Inserted)
15694 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15695 Depth + 1);
15696 auto &RewriteMap = G->second.RewriteMap;
15697 if (RewriteMap.empty())
15698 return {nullptr, scCouldNotCompute};
15699 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15700 if (S == RewriteMap.end())
15701 return {nullptr, scCouldNotCompute};
15702 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15703 if (!SM)
15704 return {nullptr, scCouldNotCompute};
15705 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15706 return {C0, SM->getSCEVType()};
15707 return {nullptr, scCouldNotCompute};
15708 };
15709 auto MergeMinMaxConst = [](MinMaxPattern P1,
15710 MinMaxPattern P2) -> MinMaxPattern {
15711 auto [C1, T1] = P1;
15712 auto [C2, T2] = P2;
15713 if (!C1 || !C2 || T1 != T2)
15714 return {nullptr, scCouldNotCompute};
15715 switch (T1) {
15716 case scUMaxExpr:
15717 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15718 case scSMaxExpr:
15719 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15720 case scUMinExpr:
15721 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15722 case scSMinExpr:
15723 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15724 default:
15725 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15726 }
15727 };
15728 auto P = GetMinMaxConst(0);
15729 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15730 if (!P.first)
15731 break;
15732 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15733 }
15734 if (P.first) {
15735 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15736 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15737 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15738 Guards.RewriteMap.insert({LHS, RHS});
15739 }
15740}
15741
 15742// Return a new SCEV that modifies \p Expr to the closest number that is
 15743// divisible by \p Divisor and less than or equal to Expr. For now, only
 15744// handle constant Expr.
15746 const APInt &DivisorVal,
15747 ScalarEvolution &SE) {
15748 const APInt *ExprVal;
15749 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15750 DivisorVal.isNonPositive())
15751 return Expr;
15752 APInt Rem = ExprVal->urem(DivisorVal);
15753 // return the SCEV: Expr - Expr % Divisor
15754 return SE.getConstant(*ExprVal - Rem);
15755}
15756
15757// Return a new SCEV that modifies \p Expr to the closest number divides by
15758// \p Divisor and greater or equal than Expr. For now, only handle constant
15759// Expr.
15760static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15761 const APInt &DivisorVal,
15762 ScalarEvolution &SE) {
15763 const APInt *ExprVal;
15764 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15765 DivisorVal.isNonPositive())
15766 return Expr;
15767 APInt Rem = ExprVal->urem(DivisorVal);
15768 if (Rem.isZero())
15769 return Expr;
15770 // return the SCEV: Expr + Divisor - Expr % Divisor
15771 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15772}
15773
15775 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15778 // If we have LHS == 0, check if LHS is computing a property of some unknown
15779 // SCEV %v which we can rewrite %v to express explicitly.
15781 return false;
15782 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15783 // explicitly express that.
15784 const SCEVUnknown *URemLHS = nullptr;
15785 const SCEV *URemRHS = nullptr;
15786 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15787 return false;
15788
15789 const SCEV *Multiple =
15790 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15791 DivInfo[URemLHS] = Multiple;
15792 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15793 Multiples[URemLHS] = C->getAPInt();
15794 return true;
15795}
15796
15797// Check if the condition is a divisibility guard (A % B == 0).
15798static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15799 ScalarEvolution &SE) {
15800 const SCEV *X, *Y;
15801 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15802}
15803
15804// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15805// recursively. This is done by aligning up/down the constant value to the
15806// Divisor.
15807static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15808 APInt Divisor,
15809 ScalarEvolution &SE) {
15810 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15811 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15812 // the non-constant operand and in \p LHS the constant operand.
15813 auto IsMinMaxSCEVWithNonNegativeConstant =
15814 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15815 const SCEV *&RHS) {
15816 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15817 if (MinMax->getNumOperands() != 2)
15818 return false;
15819 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15820 if (C->getAPInt().isNegative())
15821 return false;
15822 SCTy = MinMax->getSCEVType();
15823 LHS = MinMax->getOperand(0);
15824 RHS = MinMax->getOperand(1);
15825 return true;
15826 }
15827 }
15828 return false;
15829 };
15830
15831 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15832 SCEVTypes SCTy;
15833 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15834 MinMaxRHS))
15835 return MinMaxExpr;
15836 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15837 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15838 auto *DivisibleExpr =
15839 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15840 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15842 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15843 return SE.getMinMaxExpr(SCTy, Ops);
15844}
15845
15846void ScalarEvolution::LoopGuards::collectFromBlock(
15847 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15848 const BasicBlock *Block, const BasicBlock *Pred,
15849 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15850
15852
15853 SmallVector<SCEVUse> ExprsToRewrite;
15854 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15855 const SCEV *RHS,
15856 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15857 const LoopGuards &DivGuards) {
15858 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15859 // replacement SCEV which isn't directly implied by the structure of that
15860 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15861 // legal. See the scoping rules for flags in the header to understand why.
15862
15863 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15864 // create this form when combining two checks of the form (X u< C2 + C1) and
15865 // (X >=u C1).
15866 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15867 &ExprsToRewrite]() {
15868 const SCEVConstant *C1;
15869 const SCEVUnknown *LHSUnknown;
15870 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15871 if (!match(LHS,
15872 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15873 !C2)
15874 return false;
15875
15876 auto ExactRegion =
15877 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15878 .sub(C1->getAPInt());
15879
15880 // Bail out, unless we have a non-wrapping, monotonic range.
15881 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15882 return false;
15883 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15884 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15885 I->second = SE.getUMaxExpr(
15886 SE.getConstant(ExactRegion.getUnsignedMin()),
15887 SE.getUMinExpr(RewrittenLHS,
15888 SE.getConstant(ExactRegion.getUnsignedMax())));
15889 ExprsToRewrite.push_back(LHSUnknown);
15890 return true;
15891 };
15892 if (MatchRangeCheckIdiom())
15893 return;
15894
15895 // Do not apply information for constants or if RHS contains an AddRec.
15897 return;
15898
15899 // If RHS is SCEVUnknown, make sure the information is applied to it.
15901 std::swap(LHS, RHS);
15903 }
15904
15905 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15906 // and \p FromRewritten are the same (i.e. there has been no rewrite
15907 // registered for \p From), then puts this value in the list of rewritten
15908 // expressions.
15909 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15910 const SCEV *To) {
15911 if (From == FromRewritten)
15912 ExprsToRewrite.push_back(From);
15913 RewriteMap[From] = To;
15914 };
15915
15916 // Checks whether \p S has already been rewritten. In that case returns the
15917 // existing rewrite because we want to chain further rewrites onto the
15918 // already rewritten value. Otherwise returns \p S.
15919 auto GetMaybeRewritten = [&](const SCEV *S) {
15920 return RewriteMap.lookup_or(S, S);
15921 };
15922
15923 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15924 // Apply divisibility information when computing the constant multiple.
15925 const APInt &DividesBy =
15926 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15927
15928 // Collect rewrites for LHS and its transitive operands based on the
15929 // condition.
15930 // For min/max expressions, also apply the guard to its operands:
15931 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15932 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15933 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15934 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15935
15936 // We cannot express strict predicates in SCEV, so instead we replace them
15937 // with non-strict ones against plus or minus one of RHS depending on the
15938 // predicate.
15939 const SCEV *One = SE.getOne(RHS->getType());
15940 switch (Predicate) {
15941 case CmpInst::ICMP_ULT:
15942 if (RHS->getType()->isPointerTy())
15943 return;
15944 RHS = SE.getUMaxExpr(RHS, One);
15945 [[fallthrough]];
15946 case CmpInst::ICMP_SLT: {
15947 RHS = SE.getMinusSCEV(RHS, One);
15948 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15949 break;
15950 }
15951 case CmpInst::ICMP_UGT:
15952 case CmpInst::ICMP_SGT:
15953 RHS = SE.getAddExpr(RHS, One);
15954 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15955 break;
15956 case CmpInst::ICMP_ULE:
15957 case CmpInst::ICMP_SLE:
15958 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15959 break;
15960 case CmpInst::ICMP_UGE:
15961 case CmpInst::ICMP_SGE:
15962 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15963 break;
15964 default:
15965 break;
15966 }
15967
15968 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15969 SmallPtrSet<const SCEV *, 16> Visited;
15970
15971 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15972 append_range(Worklist, S->operands());
15973 };
15974
15975 while (!Worklist.empty()) {
15976 const SCEV *From = Worklist.pop_back_val();
15977 if (isa<SCEVConstant>(From))
15978 continue;
15979 if (!Visited.insert(From).second)
15980 continue;
15981 const SCEV *FromRewritten = GetMaybeRewritten(From);
15982 const SCEV *To = nullptr;
15983
15984 switch (Predicate) {
15985 case CmpInst::ICMP_ULT:
15986 case CmpInst::ICMP_ULE:
15987 To = SE.getUMinExpr(FromRewritten, RHS);
15988 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15989 EnqueueOperands(UMax);
15990 break;
15991 case CmpInst::ICMP_SLT:
15992 case CmpInst::ICMP_SLE:
15993 To = SE.getSMinExpr(FromRewritten, RHS);
15994 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15995 EnqueueOperands(SMax);
15996 break;
15997 case CmpInst::ICMP_UGT:
15998 case CmpInst::ICMP_UGE:
15999 To = SE.getUMaxExpr(FromRewritten, RHS);
16000 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
16001 EnqueueOperands(UMin);
16002 break;
16003 case CmpInst::ICMP_SGT:
16004 case CmpInst::ICMP_SGE:
16005 To = SE.getSMaxExpr(FromRewritten, RHS);
16006 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
16007 EnqueueOperands(SMin);
16008 break;
16009 case CmpInst::ICMP_EQ:
16011 To = RHS;
16012 break;
16013 case CmpInst::ICMP_NE:
16014 if (match(RHS, m_scev_Zero())) {
16015 const SCEV *OneAlignedUp =
16016 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16017 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16018 } else {
16019 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16020 // but creating the subtraction eagerly is expensive. Track the
16021 // inequalities in a separate map, and materialize the rewrite lazily
16022 // when encountering a suitable subtraction while re-writing.
16023 if (LHS->getType()->isPointerTy()) {
16027 break;
16028 }
16029 const SCEVConstant *C;
16030 const SCEV *A, *B;
16033 RHS = A;
16034 LHS = B;
16035 }
16036 if (LHS > RHS)
16037 std::swap(LHS, RHS);
16038 Guards.NotEqual.insert({LHS, RHS});
16039 continue;
16040 }
16041 break;
16042 default:
16043 break;
16044 }
16045
16046 if (To)
16047 AddRewrite(From, FromRewritten, To);
16048 }
16049 };
16050
16052 // First, collect information from assumptions dominating the loop.
16053 for (auto &AssumeVH : SE.AC.assumptions()) {
16054 if (!AssumeVH)
16055 continue;
16056 auto *AssumeI = cast<CallInst>(AssumeVH);
16057 if (!SE.DT.dominates(AssumeI, Block))
16058 continue;
16059 Terms.emplace_back(AssumeI->getOperand(0), true);
16060 }
16061
16062 // Second, collect information from llvm.experimental.guards dominating the loop.
16063 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16064 SE.F.getParent(), Intrinsic::experimental_guard);
16065 if (GuardDecl)
16066 for (const auto *GU : GuardDecl->users())
16067 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16068 if (Guard->getFunction() == Block->getParent() &&
16069 SE.DT.dominates(Guard, Block))
16070 Terms.emplace_back(Guard->getArgOperand(0), true);
16071
16072 // Third, collect conditions from dominating branches. Starting at the loop
16073 // predecessor, climb up the predecessor chain, as long as there are
16074 // predecessors that can be found that have unique successors leading to the
16075 // original header.
16076 // TODO: share this logic with isLoopEntryGuardedByCond.
16077 unsigned NumCollectedConditions = 0;
16079 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16080 for (; Pair.first;
16081 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16082 VisitedBlocks.insert(Pair.second);
16083 const CondBrInst *LoopEntryPredicate =
16084 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16085 if (!LoopEntryPredicate)
16086 continue;
16087
16088 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16089 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16090 NumCollectedConditions++;
16091
16092 // If we are recursively collecting guards stop after 2
16093 // conditions to limit compile-time impact for now.
16094 if (Depth > 0 && NumCollectedConditions == 2)
16095 break;
16096 }
16097 // Finally, if we stopped climbing the predecessor chain because
16098 // there wasn't a unique one to continue, try to collect conditions
16099 // for PHINodes by recursively following all of their incoming
16100 // blocks and try to merge the found conditions to build a new one
16101 // for the Phi.
16102 if (Pair.second->hasNPredecessorsOrMore(2) &&
16104 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16105 for (auto &Phi : Pair.second->phis())
16106 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16107 }
16108
16109 // Now apply the information from the collected conditions to
16110 // Guards.RewriteMap. Conditions are processed in reverse order, so the
 16111 // earliest condition is processed first, except guards with divisibility
16112 // information, which are moved to the back. This ensures the SCEVs with the
16113 // shortest dependency chains are constructed first.
16115 GuardsToProcess;
16116 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16117 SmallVector<Value *, 8> Worklist;
16118 SmallPtrSet<Value *, 8> Visited;
16119 Worklist.push_back(Term);
16120 while (!Worklist.empty()) {
16121 Value *Cond = Worklist.pop_back_val();
16122 if (!Visited.insert(Cond).second)
16123 continue;
16124
16125 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16126 auto Predicate =
16127 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16128 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16129 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16130 // If LHS is a constant, apply information to the other expression.
16131 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16132 // can improve results.
16133 if (isa<SCEVConstant>(LHS)) {
16134 std::swap(LHS, RHS);
16136 }
16137 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16138 continue;
16139 }
16140
16141 Value *L, *R;
16142 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16143 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16144 Worklist.push_back(L);
16145 Worklist.push_back(R);
16146 }
16147 }
16148 }
16149
16150 // Process divisibility guards in reverse order to populate DivGuards early.
16151 DenseMap<const SCEV *, APInt> Multiples;
16152 LoopGuards DivGuards(SE);
16153 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16154 if (!isDivisibilityGuard(LHS, RHS, SE))
16155 continue;
16156 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16157 Multiples, SE);
16158 }
16159
16160 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16161 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16162
16163 // Apply divisibility information last. This ensures it is applied to the
16164 // outermost expression after other rewrites for the given value.
16165 for (const auto &[K, Divisor] : Multiples) {
16166 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16167 Guards.RewriteMap[K] =
16169 Guards.rewrite(K), Divisor, SE),
16170 DivisorSCEV),
16171 DivisorSCEV);
16172 ExprsToRewrite.push_back(K);
16173 }
16174
16175 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16176 // the replacement expressions are contained in the ranges of the replaced
16177 // expressions.
16178 Guards.PreserveNUW = true;
16179 Guards.PreserveNSW = true;
16180 for (const SCEV *Expr : ExprsToRewrite) {
16181 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16182 Guards.PreserveNUW &=
16183 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16184 Guards.PreserveNSW &=
16185 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16186 }
16187
 16188 // Now that all rewrite information is collected, rewrite the collected
16189 // expressions with the information in the map. This applies information to
16190 // sub-expressions.
16191 if (ExprsToRewrite.size() > 1) {
16192 for (const SCEV *Expr : ExprsToRewrite) {
16193 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16194 Guards.RewriteMap.erase(Expr);
16195 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16196 }
16197 }
16198}
16199
16201 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16202 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16203 /// replacement is loop invariant in the loop of the AddRec.
16204 class SCEVLoopGuardRewriter
16205 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16208
16210
16211 public:
16212 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16213 const ScalarEvolution::LoopGuards &Guards)
16214 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16215 NotEqual(Guards.NotEqual) {
16216 if (Guards.PreserveNUW)
16217 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16218 if (Guards.PreserveNSW)
16219 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16220 }
16221
16222 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16223
16224 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16225 return Map.lookup_or(Expr, Expr);
16226 }
16227
16228 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16229 if (const SCEV *S = Map.lookup(Expr))
16230 return S;
16231
 16232 // If we didn't find the exact ZExt expr in the map, check if there's
16233 // an entry for a smaller ZExt we can use instead.
16234 Type *Ty = Expr->getType();
16235 const SCEV *Op = Expr->getOperand(0);
16236 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16237 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16238 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16239 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16240 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16241 if (const SCEV *S = Map.lookup(NarrowExt))
16242 return SE.getZeroExtendExpr(S, Ty);
16243 Bitwidth = Bitwidth / 2;
16244 }
16245
16247 Expr);
16248 }
16249
16250 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16251 if (const SCEV *S = Map.lookup(Expr))
16252 return S;
16254 Expr);
16255 }
16256
16257 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16258 if (const SCEV *S = Map.lookup(Expr))
16259 return S;
16261 }
16262
16263 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16264 if (const SCEV *S = Map.lookup(Expr))
16265 return S;
16267 }
16268
16269 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16270 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16271 // return UMax(S, 1).
16272 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16273 SCEVUse LHS, RHS;
16274 if (MatchBinarySub(S, LHS, RHS)) {
16275 if (LHS > RHS)
16276 std::swap(LHS, RHS);
16277 if (NotEqual.contains({LHS, RHS})) {
16278 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16279 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16280 return SE.getUMaxExpr(OneAlignedUp, S);
16281 }
16282 }
16283 return nullptr;
16284 };
16285
16286 // Check if Expr itself is a subtraction pattern with guard info.
16287 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16288 return Rewritten;
16289
16290 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16291 // (Const + A + B). There may be guard info for A + B, and if so, apply
16292 // it.
16293 // TODO: Could more generally apply guards to Add sub-expressions.
16294 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16295 Expr->getNumOperands() == 3) {
16296 const SCEV *Add =
16297 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16298 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16299 return SE.getAddExpr(
16300 Expr->getOperand(0), Rewritten,
16301 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16302 if (const SCEV *S = Map.lookup(Add))
16303 return SE.getAddExpr(Expr->getOperand(0), S);
16304 }
16305 SmallVector<SCEVUse, 2> Operands;
16306 bool Changed = false;
16307 for (SCEVUse Op : Expr->operands()) {
16308 Operands.push_back(
16310 Changed |= Op != Operands.back();
16311 }
16312 // We are only replacing operands with equivalent values, so transfer the
16313 // flags from the original expression.
16314 return !Changed ? Expr
16315 : SE.getAddExpr(Operands,
16317 Expr->getNoWrapFlags(), FlagMask));
16318 }
16319
16320 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16321 SmallVector<SCEVUse, 2> Operands;
16322 bool Changed = false;
16323 for (SCEVUse Op : Expr->operands()) {
16324 Operands.push_back(
16326 Changed |= Op != Operands.back();
16327 }
16328 // We are only replacing operands with equivalent values, so transfer the
16329 // flags from the original expression.
16330 return !Changed ? Expr
16331 : SE.getMulExpr(Operands,
16333 Expr->getNoWrapFlags(), FlagMask));
16334 }
16335 };
16336
16337 if (RewriteMap.empty() && NotEqual.empty())
16338 return Expr;
16339
16340 SCEVLoopGuardRewriter Rewriter(SE, *this);
16341 return Rewriter.visit(Expr);
16342}
16343
16344const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16345 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16346}
16347
16349 const LoopGuards &Guards) {
16350 return Guards.rewrite(Expr);
16351}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:851
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2011
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1043
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1555
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1406
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:956
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1810
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1697
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1503
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1662
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1776
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1305
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1016
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1472
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:784
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:736
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:767
TypeSize getSizeInBits() const
Definition DataLayout.h:747
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2266
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2271
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2276
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2852
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2281
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:818
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:368
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:317
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:202
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
void dump() const
This method is used for debugging.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.