LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
1076 ScalarEvolution &SE) const {
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1113 Type *TargetTy;
1114 ConversionFn CreatePtrCast;
1115
1116public:
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
1165 return SE.getZero(TargetTy);
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1185 getDataLayout().getTypeSizeInBits(IntPtrTy))
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although we checked at the beginning that ID is not in the cache, it is
1320 // possible that ID was inserted into the cache by the recursive calls
1321 // above. So if we find it, just return it.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
1385struct ExtendOpTraitsBase {
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
/// Traits that allow code to be written once for both signed and unsigned
/// overflow reasoning.
///
/// The primary template is deliberately empty. Each explicit specialization
/// is expected to provide the following members:
///
///   static const SCEV::NoWrapFlags WrapType;
///
///   static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
///
///   static const SCEV *getOverflowLimitForStep(const SCEV *Step,
///                                              ICmpInst::Predicate *Pred,
///                                              ScalarEvolution *SE);
template <typename ExtendOp> struct ExtendOpTraits {};
1402
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1441// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1442// expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)".
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not sign/unsign-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Give up if we don't already have the add recurrence we need because
1597 // actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1622 uint32_t TZ = BitWidth;
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the latter case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two possible issue causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 const APInt *C, *C2;
1914 // zext (C + A)<nsw> -> (sext(C) + sext(A))<nsw> if zext (C + A)<nsw> >=s 0.
1915 // Currently the non-negative check is done manually, as isKnownNonNegative
1916 // is too expensive.
1917 if (SA->hasNoSignedWrap() &&
1919 m_scev_SMax(m_scev_APInt(C2), m_SCEV()))) &&
1920 C->isNegative() && !C->isMinSignedValue() && C2->sge(C->abs())) {
1921 assert(isKnownNonNegative(SA) && "incorrectly determined non-negative");
1922 return getAddExpr(getSignExtendExpr(SA->getOperand(0), Ty, Depth + 1),
1923 getSignExtendExpr(SA->getOperand(1), Ty, Depth + 1),
1924 SCEV::FlagNSW, Depth + 1);
1925 }
1926
1927 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1928 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1929 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1930 //
1931 // Often address arithmetics contain expressions like
1932 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1933 // This transformation is useful while proving that such expressions are
1934 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1935 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1936 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1937 if (D != 0) {
1938 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1939 const SCEV *SResidual =
1941 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1942 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1943 Depth + 1);
1944 }
1945 }
1946 }
1947
1948 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1949 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1950 if (SM->hasNoUnsignedWrap()) {
1951 // If the multiply does not unsign overflow then we can, by definition,
1952 // commute the zero extension with the multiply operation.
1954 for (SCEVUse Op : SM->operands())
1955 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1956 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1957 }
1958
1959 // zext(2^K * (trunc X to iN)) to iM ->
1960 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1961 //
1962 // Proof:
1963 //
1964 // zext(2^K * (trunc X to iN)) to iM
1965 // = zext((trunc X to iN) << K) to iM
1966 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1967 // (because shl removes the top K bits)
1968 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1969 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1970 //
1971 const APInt *C;
1972 const SCEV *TruncRHS;
1973 if (match(SM,
1974 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1975 C->isPowerOf2()) {
1976 int NewTruncBits =
1977 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1978 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1979 return getMulExpr(
1980 getZeroExtendExpr(SM->getOperand(0), Ty),
1981 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1982 SCEV::FlagNUW, Depth + 1);
1983 }
1984 }
1985
1986 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1987 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1990 SmallVector<SCEVUse, 4> Operands;
1991 for (SCEVUse Operand : MinMax->operands())
1992 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1994 return getUMinExpr(Operands);
1995 return getUMaxExpr(Operands);
1996 }
1997
1998 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
2000 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
2001 SmallVector<SCEVUse, 4> Operands;
2002 for (SCEVUse Operand : MinMax->operands())
2003 Operands.push_back(getZeroExtendExpr(Operand, Ty));
2004 return getUMinExpr(Operands, /*Sequential*/ true);
2005 }
2006
2007 // The cast wasn't folded; create an explicit cast node.
2008 // Recompute the insert position, as it may have been invalidated.
2009 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2010 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
2011 Op, Ty);
2012 UniqueSCEVs.InsertNode(S, IP);
2013 S->computeAndSetCanonical(*this);
2014 registerUser(S, Op);
2015 return S;
2016}
2017
// Memoizing front end for sign-extending Op to Ty: validates the request,
// consults FoldCache, and otherwise defers to getSignExtendExprImpl.
// NOTE(review): listing line 2019 (the function name and parameter list) is
// absent from this copy; the body uses Op, Ty and Depth -- restore it from
// upstream before building.
2018const SCEV *
// Preconditions: Ty is strictly wider than the operand type, Ty is SCEVable,
// and the operand is not a pointer.
2020 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2021 "This is not an extending conversion!");
2022 assert(isSCEVable(Ty) &&
2023 "This is not a conversion to a SCEVable type!");
2024 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2025 Ty = getEffectiveSCEVType(Ty);
2026
// Fast path: reuse a result already recorded for (scSignExtend, Op, Ty).
2027 FoldID ID(scSignExtend, Op, Ty);
2028 if (const SCEV *S = FoldCache.lookup(ID))
2029 return S;
2030
// Slow path: run the full folding logic.
2031 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
// NOTE(review): listing line 2032 is absent here; it may be a condition
// guarding the cache insertion below -- confirm against upstream.
2033 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2034 return S;
2035}
2036
// Worker that builds the sign-extension of Op to Ty. It tries a sequence of
// folds (constants, nested casts, adds, affine addrecs, min/max) and only
// creates an explicit SCEVSignExtendExpr node when none of them applies.
// NOTE(review): this copy is missing listing lines 2037 (start of the
// signature), 2050, 2054, 2059, 2076, 2080, 2093, 2113, 2124, 2135, 2163,
// 2165, 2175-2176, 2192-2193, 2222, 2246, 2254, 2259-2260 and 2264. The
// statements around those gaps are incomplete as shown; restore them from
// upstream before building.
2038 unsigned Depth) {
2039 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2040 "This is not an extending conversion!");
2041 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2042 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2043 Ty = getEffectiveSCEVType(Ty);
2044
2045 // Fold if the operand is constant.
2046 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2047 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2048
2049 // sext(sext(x)) --> sext(x)
// (the test that binds SS is on the missing listing line 2050)
2051 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2052
2053 // sext(zext(x)) --> zext(x)
// (the test that binds SZ is on the missing listing line 2054)
2055 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2056
2057 // Before doing any expensive analysis, check to see if we've already
2058 // computed a SCEV for this Op and Ty.
// (the declaration of ID is on the missing listing line 2059)
2060 ID.AddInteger(scSignExtend);
2061 ID.AddPointer(Op);
2062 ID.AddPointer(Ty);
2063 void *IP = nullptr;
2064 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2065 // Limit recursion depth.
// Past MaxCastDepth no folding is attempted: the explicit cast node is
// created, uniqued and registered as a user of Op right away.
2066 if (Depth > MaxCastDepth) {
2067 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2068 Op, Ty);
2069 UniqueSCEVs.InsertNode(S, IP);
2070 S->computeAndSetCanonical(*this);
2071 registerUser(S, Op);
2072 return S;
2073 }
2074
2075 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2077 // It's possible the bits taken off by the truncate were all sign bits. If
2078 // so, we should be able to simplify this further.
2079 const SCEV *X = ST->getOperand();
// (CR is defined on the missing listing line 2080)
2081 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2082 unsigned NewBits = getTypeSizeInBits(Ty);
// The fold is taken only when truncating CR to TruncBits and sign-extending
// it to NewBits still contains CR converted directly to NewBits.
2083 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2084 CR.sextOrTrunc(NewBits)))
2085 return getTruncateOrSignExtend(X, Ty, Depth);
2086 }
2087
2088 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2089 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2090 if (SA->hasNoSignedWrap()) {
2091 // If the addition does not sign overflow then we can, by definition,
2092 // commute the sign extension with the addition operation.
// (the declaration of Ops is on the missing listing line 2093)
2094 for (SCEVUse Op : SA->operands())
2095 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2096 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2097 }
2098
2099 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2100 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2101 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2102 //
2103 // For instance, this will bring two seemingly different expressions:
2104 // 1 + sext(5 + 20 * %x + 24 * %y) and
2105 // sext(6 + 20 * %x + 24 * %y)
2106 // to the same form:
2107 // 2 + sext(4 + 20 * %x + 24 * %y)
2108 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2109 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2110 if (D != 0) {
2111 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2112 const SCEV *SResidual =
// (the initializer of SResidual is on the missing listing line 2113)
2114 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2115 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2116 Depth + 1);
2117 }
2118 }
2119 }
2120 // If the input value is a chrec scev, and we can prove that the value
2121 // did not overflow the old, smaller, value, we can sign extend all of the
2122 // operands (often constants). This allows analysis of something like
2123 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
// (the test that binds AR is on the missing listing line 2124)
2125 if (AR->isAffine()) {
2126 const SCEV *Start = AR->getStart();
2127 const SCEV *Step = AR->getStepRecurrence(*this);
2128 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2129 const Loop *L = AR->getLoop();
2130
2131 // If we have special knowledge that this addrec won't overflow,
2132 // we don't need to do any further analysis.
2133 if (AR->hasNoSignedWrap()) {
2134 Start =
2136 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2137 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2138 }
2139
2140 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2141 // Note that this serves two purposes: It filters out loops that are
2142 // simply not analyzable, and it covers the case where this code is
2143 // being called from within backedge-taken count analysis, such that
2144 // attempting to ask for the backedge-taken count would likely result
2145 // in infinite recursion. In the latter case, the analysis code will
2146 // cope with a conservative value, and it will take care to purge
2147 // that value once it has finished.
2148 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2149 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2150 // Manually compute the final value for AR, checking for
2151 // overflow.
2152
2153 // Check whether the backedge-taken count can be losslessly casted to
2154 // the addrec's type. The count is always unsigned.
2155 const SCEV *CastedMaxBECount =
2156 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2157 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2158 CastedMaxBECount, MaxBECount->getType(), Depth);
// The round trip through the addrec type must reproduce MaxBECount exactly.
2159 if (MaxBECount == RecastedMaxBECount) {
// Work in a type twice as wide as the addrec so the comparison below can
// expose an overflow of the narrow computation.
2160 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2161 // Check whether Start+Step*MaxBECount has no signed overflow.
2162 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2164 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2166 Depth + 1),
2167 WideTy, Depth + 1);
2168 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2169 const SCEV *WideMaxBECount =
2170 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2171 const SCEV *OperandExtendedAdd =
2172 getAddExpr(WideStart,
2173 getMulExpr(WideMaxBECount,
2174 getSignExtendExpr(Step, WideTy, Depth + 1),
// SAdd extends the narrow result; OperandExtendedAdd extends the operands
// first. Pointer equality of the two uniqued SCEVs is the no-overflow proof.
2177 if (SAdd == OperandExtendedAdd) {
2178 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2179 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2180 // Return the expression with the addrec on the outside.
2181 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2182 Depth + 1);
2183 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2184 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2185 }
2186 // Similar to above, only this time treat the step value as unsigned.
2187 // This covers loops that count up with an unsigned step.
2188 OperandExtendedAdd =
2189 getAddExpr(WideStart,
2190 getMulExpr(WideMaxBECount,
2191 getZeroExtendExpr(Step, WideTy, Depth + 1),
2194 if (SAdd == OperandExtendedAdd) {
2195 // If AR wraps around then
2196 //
2197 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2198 // => SAdd != OperandExtendedAdd
2199 //
2200 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2201 // (SAdd == OperandExtendedAdd => AR is NW)
2202
2203 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2204
2205 // Return the expression with the addrec on the outside.
2206 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2207 Depth + 1);
2208 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2209 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2210 }
2211 }
2212 }
2213
// Fall back to the induction-based proof and record whatever flags it finds
// on AR before re-testing for NSW.
2214 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2215 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2216 if (AR->hasNoSignedWrap()) {
2217 // Same as nsw case above - duplicated here to avoid a compile time
2218 // issue. It's not clear that the order of checks matters, but
2219 // it's one of two possible causes of an issue with a change which was
2220 // reverted. Be conservative for the moment.
2221 Start =
2223 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2224 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2225 }
2226
2227 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2228 // if D + (C - D + Step * n) could be proven to not signed wrap
2229 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2230 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2231 const APInt &C = SC->getAPInt();
2232 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2233 if (D != 0) {
2234 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2235 const SCEV *SResidual =
2236 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2237 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2238 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2239 Depth + 1);
2240 }
2241 }
2242
2243 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2244 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2245 Start =
2247 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2248 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2249 }
2250 }
2251
2252 // If the input value is provably positive and we could not simplify
2253 // away the sext build a zext instead.
// (the guarding condition is on the missing listing line 2254)
2255 return getZeroExtendExpr(Op, Ty, Depth + 1);
2256
2257 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2258 // sext(smax(x, y)) -> smax(sext(x), sext(y))
// (the test that binds MinMax is on the missing listing lines 2259-2260, and
// the condition choosing between smin and smax is on the missing line 2264)
2261 SmallVector<SCEVUse, 4> Operands;
2262 for (SCEVUse Operand : MinMax->operands())
2263 Operands.push_back(getSignExtendExpr(Operand, Ty));
2265 return getSMinExpr(Operands);
2266 return getSMaxExpr(Operands);
2267 }
2268
2269 // The cast wasn't folded; create an explicit cast node.
2270 // Recompute the insert position, as it may have been invalidated.
2271 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2272 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2273 Op, Ty);
2274 UniqueSCEVs.InsertNode(S, IP);
2275 S->computeAndSetCanonical(*this);
2276 registerUser(S, Op);
2277 return S;
2278}
2279
// Dispatches on Kind to the matching cast builder (truncate, zero-extend,
// sign-extend or ptrtoint) for Op and Ty. Any other kind is treated as
// unreachable.
// NOTE(review): listing line 2280 (return type, function name and the leading
// parameters, including Kind and Op) is absent from this copy; restore it
// from upstream before building.
2281 Type *Ty) {
2282 switch (Kind) {
2283 case scTruncate:
2284 return getTruncateExpr(Op, Ty);
2285 case scZeroExtend:
2286 return getZeroExtendExpr(Op, Ty);
2287 case scSignExtend:
2288 return getSignExtendExpr(Op, Ty);
2289 case scPtrToInt:
2290 return getPtrToIntExpr(Op, Ty);
2291 default:
2292 llvm_unreachable("Not a SCEV cast expression!");
2293 }
2294}
2295
2296/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2297/// unspecified bits out to the given type.
///
/// The strategies are tried in order: sign-extend negative constants, peel a
/// truncate, use a zext that folds, use a sext that folds, push the extension
/// into addrec operands, and finally pick sext for smax operands or zext
/// otherwise.
// NOTE(review): this copy is missing listing lines 2298 (start of the
// signature), 2312 (the test that binds T) and 2331 (the declaration of Ops);
// restore them from upstream before building.
2299 Type *Ty) {
2300 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2301 "This is not an extending conversion!");
2302 assert(isSCEVable(Ty) &&
2303 "This is not a conversion to a SCEVable type!");
2304 Ty = getEffectiveSCEVType(Ty);
2305
2306 // Sign-extend negative constants.
2307 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2308 if (SC->getAPInt().isNegative())
2309 return getSignExtendExpr(Op, Ty);
2310
2311 // Peel off a truncate cast.
2313 const SCEV *NewOp = T->getOperand();
// If the truncate's operand is still narrower than Ty, keep extending it;
// otherwise it is at least as wide as Ty, so truncate (or do nothing).
2314 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2315 return getAnyExtendExpr(NewOp, Ty);
2316 return getTruncateOrNoop(NewOp, Ty);
2317 }
2318
2319 // Next try a zext cast. If the cast is folded, use it.
// "Folded" means the result is not itself an explicit SCEVZeroExtendExpr.
2320 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2321 if (!isa<SCEVZeroExtendExpr>(ZExt))
2322 return ZExt;
2323
2324 // Next try a sext cast. If the cast is folded, use it.
2325 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2326 if (!isa<SCEVSignExtendExpr>(SExt))
2327 return SExt;
2328
2329 // Force the cast to be folded into the operands of an addrec.
2330 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2332 for (const SCEV *Op : AR->operands())
2333 Ops.push_back(getAnyExtendExpr(Op, Ty));
2334 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2335 }
2336
2337 // If the expression is obviously signed, use the sext cast value.
2338 if (isa<SCEVSMaxExpr>(Op))
2339 return SExt;
2340
2341 // Absent any other information, use the zext cast value.
2342 return ZExt;
2343}
2344
2345/// Process the given Ops list, which is a list of operands to be added under
2346/// the given scale, update the given map. This is a helper function for
2347/// getAddRecExpr. As an example of what it does, given a sequence of operands
2348/// that would form an add expression like this:
2349///
2350/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2351///
2352/// where A and B are constants, update the map with these values:
2353///
2354/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2355///
2356/// and add 13 + A*B*29 to AccumulatedConstant.
2357/// This will allow getAddRecExpr to produce this:
2358///
2359/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2360///
2361/// This form often exposes folding opportunities that are hidden in
2362/// the original operand list.
2363///
2364/// Return true iff it appears that any interesting folding opportunities
2365/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2366/// the common case where no interesting opportunities are present, and
2367/// is also used as a check to avoid infinite recursion.
///
/// As used in the body: M maps each non-constant operand to its accumulated
/// scale, NewOps receives each key the first time it is inserted into M,
/// AccumulatedConstant collects Scale times every leading constant, and SE is
/// used to rebuild multiply keys.
// NOTE(review): this copy is missing listing lines 2368-2369 (start of the
// signature, including the M and NewOps parameters), 2371 (the Ops parameter)
// and 2389 (the statement that binds Mul); restore them from upstream.
2370 APInt &AccumulatedConstant,
2372 const APInt &Scale,
2373 ScalarEvolution &SE) {
2374 bool Interesting = false;
2375
2376 // Iterate over the add operands. They are sorted, with constants first.
// NOTE(review): this loop indexes Ops[i] without comparing i to Ops.size();
// it appears to rely on Ops holding at least one non-constant operand --
// confirm that every caller guarantees this.
2377 unsigned i = 0;
2378 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2379 ++i;
2380 // Pull a buried constant out to the outside.
2381 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2382 Interesting = true;
2383 AccumulatedConstant += Scale * C->getAPInt();
2384 }
2385
2386 // Next comes everything else. We're especially interested in multiplies
2387 // here, but they're in the middle, so just visit the rest with one loop.
2388 for (; i != Ops.size(); ++i) {
2390 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
// A multiply with a leading constant: fold that constant into the scale.
2391 APInt NewScale =
2392 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2393 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2394 // A multiplication of a constant with another add; recurse.
2395 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2396 Interesting |= CollectAddOperandsWithScales(
2397 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2398 } else {
2399 // A multiplication of a constant with some other value. Update
2400 // the map.
// The key is the product of the remaining operands, without the constant.
2401 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2402 const SCEV *Key = SE.getMulExpr(MulOps);
2403 auto Pair = M.insert({Key, NewScale});
2404 if (Pair.second) {
2405 NewOps.push_back(Pair.first->first);
2406 } else {
2407 Pair.first->second += NewScale;
2408 // The map already had an entry for this value, which may indicate
2409 // a folding opportunity.
2410 Interesting = true;
2411 }
2412 }
2413 } else {
2414 // An ordinary operand. Update the map.
2415 auto Pair = M.insert({Ops[i], Scale});
2416 if (Pair.second) {
2417 NewOps.push_back(Pair.first->first);
2418 } else {
2419 Pair.first->second += Scale;
2420 // The map already had an entry for this value, which may indicate
2421 // a folding opportunity.
2422 Interesting = true;
2423 }
2424 }
2425 }
2426
2427 return Interesting;
2428}
2429
// Returns true when "LHS BinOp RHS" can be shown not to overflow, in the
// signed or unsigned sense selected by Signed. First it compares the extended
// narrow result against the operation on extended operands; failing that, and
// only for add/sub with a constant RHS, it tries a range check on LHS that
// must hold at CtxI.
// NOTE(review): this copy is missing listing lines 2430 (start of the
// signature, including BinOp and Signed), 2433 (declaration of Operation),
// 2439, 2442 and 2445 (the per-opcode assignments), 2450-2451 (the
// initializer of Extension) and 2490 (the definition of Pred); restore them
// from upstream before building.
2431 const SCEV *LHS, const SCEV *RHS,
2432 const Instruction *CtxI) {
2434 unsigned);
2435 switch (BinOp) {
2436 default:
2437 llvm_unreachable("Unsupported binary op");
2438 case Instruction::Add:
2440 break;
2441 case Instruction::Sub:
2443 break;
2444 case Instruction::Mul:
2446 break;
2447 }
2448
2449 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2452
2453 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
// The wide type has twice the bit width of the operands.
2454 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2455 auto *WideTy =
2456 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2457
2458 const SCEV *A = (this->*Extension)(
2459 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2460 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2461 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2462 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2463 if (A == B)
2464 return true;
2465 // Can we use context to prove the fact we need?
2466 if (!CtxI)
2467 return false;
2468 // TODO: Support mul.
2469 if (BinOp == Instruction::Mul)
2470 return false;
2471 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2472 // TODO: Lift this limitation.
2473 if (!RHSC)
2474 return false;
2475 APInt C = RHSC->getAPInt();
2476 unsigned NumBits = C.getBitWidth();
2477 bool IsSub = (BinOp == Instruction::Sub);
2478 bool IsNegativeConst = (Signed && C.isNegative());
2479 // Compute the direction and magnitude by which we need to check overflow.
// Subtracting a non-negative constant or adding a negative one moves the
// value down; the other two combinations move it up.
2480 bool OverflowDown = IsSub ^ IsNegativeConst;
2481 APInt Magnitude = C;
2482 if (IsNegativeConst) {
2483 if (C == APInt::getSignedMinValue(NumBits))
2484 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2485 // want to deal with that.
2486 return false;
2487 Magnitude = -C;
2488 }
2489
2491 if (OverflowDown) {
2492 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2493 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2494 : APInt::getMinValue(NumBits);
2495 APInt Limit = Min + Magnitude;
2496 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2497 } else {
2498 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2499 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2500 : APInt::getMaxValue(NumBits);
2501 APInt Limit = Max - Magnitude;
2502 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2503 }
2504}
2505
// Tries to deduce no-wrap flags for OBO beyond those it already carries.
// Returns the combined flag set when at least one new flag was deduced, and
// std::nullopt when OBO already has both flags, is not an add/sub/mul, or
// nothing new could be proven.
// NOTE(review): this copy is missing listing lines 2507 (function name),
// 2516 and 2518 (the statements under the two flag checks), 2531 (the
// initializer of CtxI), 2533 and 2540 (the calls whose trailing arguments are
// shown), and 2535 and 2542 (the statements preceding "Deduced = true");
// restore them from upstream before building.
2506std::optional<SCEV::NoWrapFlags>
2508 const OverflowingBinaryOperator *OBO) {
2509 // It cannot be done any better.
2510 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2511 return std::nullopt;
2512
2513 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2514
2515 if (OBO->hasNoUnsignedWrap())
2517 if (OBO->hasNoSignedWrap())
2519
2520 bool Deduced = false;
2521
// Only add, sub and mul are analyzed.
2522 if (OBO->getOpcode() != Instruction::Add &&
2523 OBO->getOpcode() != Instruction::Sub &&
2524 OBO->getOpcode() != Instruction::Mul)
2525 return std::nullopt;
2526
2527 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2528 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2529
2530 const Instruction *CtxI =
// Unsigned proof, attempted only when OBO lacks the unsigned no-wrap flag.
2532 if (!OBO->hasNoUnsignedWrap() &&
2534 /* Signed */ false, LHS, RHS, CtxI)) {
2536 Deduced = true;
2537 }
2538
// Signed proof, attempted only when OBO lacks the signed no-wrap flag.
2539 if (!OBO->hasNoSignedWrap() &&
2541 /* Signed */ true, LHS, RHS, CtxI)) {
2543 Deduced = true;
2544 }
2545
2546 if (Deduced)
2547 return Flags;
2548 return std::nullopt;
2549}
2550
2551// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2552// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2553// can't-overflow flags for the operation if possible.
// Returns the possibly strengthened flag set; the incoming Flags are the
// starting point for what is returned.
// NOTE(review): this copy is missing listing lines 2554-2556 (start of the
// signature, including the SE, Type and Ops parameters), 2563 (the
// initializer of CanAnalyze), 2600 and 2608 (the definitions of NSWRegion and
// NUWRegion), 2603, 2611, 2620, 2627 and 2630 (the statements that update
// Flags), and 2617 and 2623 (the openings of two if-conditions); restore them
// from upstream before building.
2557 SCEV::NoWrapFlags Flags) {
2558 using namespace std::placeholders;
2559
2560 using OBO = OverflowingBinaryOperator;
2561
2562 bool CanAnalyze =
2564 (void)CanAnalyze;
2565 assert(CanAnalyze && "don't call from other places!");
2566
2567 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2568 SCEV::NoWrapFlags SignOrUnsignWrap =
2569 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2570
2571 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2572 auto IsKnownNonNegative = [&](SCEVUse U) {
2573 return SE->isKnownNonNegative(U);
2574 };
2575
2576 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2577 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2578
// Re-read the masked flags, since the step above may have changed Flags.
2579 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2580
// Range-based inference: only for a two-operand add or mul whose first
// operand is a constant, and only while a flag is still missing.
2581 if (SignOrUnsignWrap != SignOrUnsignMask &&
2582 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2583 isa<SCEVConstant>(Ops[0])) {
2584
2585 auto Opcode = [&] {
2586 switch (Type) {
2587 case scAddExpr:
2588 return Instruction::Add;
2589 case scMulExpr:
2590 return Instruction::Mul;
2591 default:
2592 llvm_unreachable("Unexpected SCEV op.");
2593 }
2594 }();
2595
2596 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2597
2598 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2599 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2601 Opcode, C, OBO::NoSignedWrap);
2602 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2604 }
2605
2606 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2607 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2609 Opcode, C, OBO::NoUnsignedWrap);
2610 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2612 }
2613 }
2614
2615 // <0,+,nonnegative><nw> is also nuw
2616 // TODO: Add corresponding nsw case
2618 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2619 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2621
2622 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2624 Ops.size() == 2) {
2625 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2626 if (UDiv->getOperand(1) == Ops[1])
2628 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2629 if (UDiv->getOperand(1) == Ops[0])
2631 }
2632
2633 return Flags;
2634}
2635
// True when S is invariant in L and also properly dominates the header of L.
// NOTE(review): listing line 2636 (the signature, which declares S and L) is
// absent from this copy; restore it from upstream before building.
2637 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2638}
2639
2640/// Get a canonical add expression, or something simpler if possible.
2642 SCEV::NoWrapFlags OrigFlags,
2643 unsigned Depth) {
2644 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2645 "only nuw or nsw allowed");
2646 assert(!Ops.empty() && "Cannot get empty add!");
2647 if (Ops.size() == 1) return Ops[0];
2648#ifndef NDEBUG
2649 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2650 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2651 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2652 "SCEVAddExpr operand types don't match!");
2653 unsigned NumPtrs = count_if(
2654 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2655 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2656#endif
2657
2658 const SCEV *Folded = constantFoldAndGroupOps(
2659 *this, LI, DT, Ops,
2660 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2661 [](const APInt &C) { return C.isZero(); }, // identity
2662 [](const APInt &C) { return false; }); // absorber
2663 if (Folded)
2664 return Folded;
2665
2666 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2667
2668 // Delay expensive flag strengthening until necessary.
2669 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2670 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2671 };
2672
2673 // Limit recursion calls depth.
2675 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2676
2677 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2678 // Don't strengthen flags if we have no new information.
2679 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2680 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2681 Add->setNoWrapFlags(ComputeFlags(Ops));
2682 return S;
2683 }
2684
2685 // Okay, check to see if the same value occurs in the operand list more than
2686 // once. If so, merge them together into a multiply expression. Since we
2687 // sorted the list, these values are required to be adjacent.
2688 Type *Ty = Ops[0]->getType();
2689 bool FoundMatch = false;
2690 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2691 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2692 // Scan ahead to count how many equal operands there are.
2693 unsigned Count = 2;
2694 while (i+Count != e && Ops[i+Count] == Ops[i])
2695 ++Count;
2696 // Merge the values into a multiply.
2697 SCEVUse Scale = getConstant(Ty, Count);
2698 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2699 if (Ops.size() == Count)
2700 return Mul;
2701 Ops[i] = Mul;
2702 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2703 --i; e -= Count - 1;
2704 FoundMatch = true;
2705 }
2706 if (FoundMatch)
2707 return getAddExpr(Ops, OrigFlags, Depth + 1);
2708
2709 // Check for truncates. If all the operands are truncated from the same
2710 // type, see if factoring out the truncate would permit the result to be
2711 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2712 // if the contents of the resulting outer trunc fold to something simple.
2713 auto FindTruncSrcType = [&]() -> Type * {
2714 // We're ultimately looking to fold an addrec of truncs and muls of only
2715 // constants and truncs, so if we find any other types of SCEV
2716 // as operands of the addrec then we bail and return nullptr here.
2717 // Otherwise, we return the type of the operand of a trunc that we find.
2718 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2719 return T->getOperand()->getType();
2720 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2721 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2722 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2723 return T->getOperand()->getType();
2724 }
2725 return nullptr;
2726 };
2727 if (auto *SrcType = FindTruncSrcType()) {
2728 SmallVector<SCEVUse, 8> LargeOps;
2729 bool Ok = true;
2730 // Check all the operands to see if they can be represented in the
2731 // source type of the truncate.
2732 for (const SCEV *Op : Ops) {
2734 if (T->getOperand()->getType() != SrcType) {
2735 Ok = false;
2736 break;
2737 }
2738 LargeOps.push_back(T->getOperand());
2739 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2740 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2741 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2742 SmallVector<SCEVUse, 8> LargeMulOps;
2743 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2744 if (const SCEVTruncateExpr *T =
2745 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2746 if (T->getOperand()->getType() != SrcType) {
2747 Ok = false;
2748 break;
2749 }
2750 LargeMulOps.push_back(T->getOperand());
2751 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2752 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2753 } else {
2754 Ok = false;
2755 break;
2756 }
2757 }
2758 if (Ok)
2759 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2760 } else {
2761 Ok = false;
2762 break;
2763 }
2764 }
2765 if (Ok) {
2766 // Evaluate the expression in the larger type.
2767 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2768 // If it folds to something simple, use it. Otherwise, don't.
2769 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2770 return getTruncateExpr(Fold, Ty);
2771 }
2772 }
2773
2774 if (Ops.size() == 2) {
2775 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2776 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2777 // C1).
2778 const SCEV *A = Ops[0];
2779 const SCEV *B = Ops[1];
2780 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2781 auto *C = dyn_cast<SCEVConstant>(A);
2782 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2783 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2784 auto C2 = C->getAPInt();
2785 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2786
2787 APInt ConstAdd = C1 + C2;
2788 auto AddFlags = AddExpr->getNoWrapFlags();
2789 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2791 ConstAdd.ule(C1)) {
2792 PreservedFlags =
2794 }
2795
2796 // Adding a constant with the same sign and small magnitude is NSW, if the
2797 // original AddExpr was NSW.
2799 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2800 ConstAdd.abs().ule(C1.abs())) {
2801 PreservedFlags =
2803 }
2804
2805 if (PreservedFlags != SCEV::FlagAnyWrap) {
2806 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2807 NewOps[0] = getConstant(ConstAdd);
2808 return getAddExpr(NewOps, PreservedFlags);
2809 }
2810 }
2811
2812 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2813 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2814 const SCEVAddExpr *InnerAdd;
2815 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2816 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2817 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2818 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2819 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2821 SCEV::FlagNUW)) {
2822 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2823 }
2824 }
2825 }
2826
2827 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2828 const SCEV *Y;
2829 if (Ops.size() == 2 &&
2830 match(Ops[0],
2832 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2833 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2834
2835 // Skip past any other cast SCEVs.
2836 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2837 ++Idx;
2838
2839 // If there are add operands they would be next.
2840 if (Idx < Ops.size()) {
2841 bool DeletedAdd = false;
2842 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2843 // common NUW flag for expression after inlining. Other flags cannot be
2844 // preserved, because they may depend on the original order of operations.
2845 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2846 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2847 if (Ops.size() > AddOpsInlineThreshold ||
2848 Add->getNumOperands() > AddOpsInlineThreshold)
2849 break;
2850 // If we have an add, expand the add operands onto the end of the operands
2851 // list.
2852 Ops.erase(Ops.begin()+Idx);
2853 append_range(Ops, Add->operands());
2854 DeletedAdd = true;
2855 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2856 }
2857
2858 // If we deleted at least one add, we added operands to the end of the list,
2859 // and they are not necessarily sorted. Recurse to resort and resimplify
2860 // any operands we just acquired.
2861 if (DeletedAdd)
2862 return getAddExpr(Ops, CommonFlags, Depth + 1);
2863 }
2864
2865 // Skip over the add expression until we get to a multiply.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2867 ++Idx;
2868
2869 // Check to see if there are any folding opportunities present with
2870 // operands multiplied by constant values.
2871 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2875 APInt AccumulatedConstant(BitWidth, 0);
2876 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2877 Ops, APInt(BitWidth, 1), *this)) {
2878 struct APIntCompare {
2879 bool operator()(const APInt &LHS, const APInt &RHS) const {
2880 return LHS.ult(RHS);
2881 }
2882 };
2883
2884 // Some interesting folding opportunity is present, so it's worthwhile to
2885 // re-generate the operands list. Group the operands by constant scale,
2886 // to avoid multiplying by the same constant scale multiple times.
2887 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2888 for (const SCEV *NewOp : NewOps)
2889 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2890 // Re-generate the operands list.
2891 Ops.clear();
2892 if (AccumulatedConstant != 0)
2893 Ops.push_back(getConstant(AccumulatedConstant));
2894 for (auto &MulOp : MulOpLists) {
2895 if (MulOp.first == 1) {
2896 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2897 } else if (MulOp.first != 0) {
2898 Ops.push_back(getMulExpr(
2899 getConstant(MulOp.first),
2900 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2901 SCEV::FlagAnyWrap, Depth + 1));
2902 }
2903 }
2904 if (Ops.empty())
2905 return getZero(Ty);
2906 if (Ops.size() == 1)
2907 return Ops[0];
2908 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2909 }
2910 }
2911
2912 // Given a SCEVMulExpr and an operand index, return the product of all
2913 // operands except the one at OpIdx.
2914 auto StripFactor = [&](const SCEVMulExpr *M, unsigned OpIdx) -> SCEVUse {
2915 if (M->getNumOperands() == 2)
2916 return M->getOperand(OpIdx == 0);
2917 SmallVector<SCEVUse, 4> Remaining(M->operands().take_front(OpIdx));
2918 append_range(Remaining, M->operands().drop_front(OpIdx + 1));
2919 return getMulExpr(Remaining, SCEV::FlagAnyWrap, Depth + 1);
2920 };
2921
2922 // If we are adding something to a multiply expression, make sure the
2923 // something is not already an operand of the multiply. If so, merge it into
2924 // the multiply.
2925 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2926 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2927 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2928 // Scan all terms to find every occurrence of common factor MulOpSCEV
2929 // and fold them in one shot:
2930 // A1*X + A2*X + ... + An*X --> X * (A1 + A2 + ... + An)
2931 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2932 if (isa<SCEVConstant>(MulOpSCEV))
2933 continue;
2934
2935 // Cofactors: 1 for bare addends matching MulOpSCEV, or the
2936 // remaining product for multiply terms containing MulOpSCEV.
2937 SmallVector<SCEVUse, 4> Cofactors;
2938 SmallVector<unsigned, 4> DeadIndices;
2939 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) {
2940 if (MulOpSCEV == Ops[AddOp]) {
2941 // W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2942 Cofactors.push_back(getOne(Ty));
2943 DeadIndices.push_back(AddOp);
2944 continue;
2945 }
2946
2947 if (AddOp <= Idx || !isa<SCEVMulExpr>(Ops[AddOp]))
2948 continue;
2949
2950 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[AddOp]);
2951 for (unsigned OMulOp = 0, OE = OtherMul->getNumOperands(); OMulOp != OE;
2952 ++OMulOp) {
2953 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2954 // (A*B*C) + (A*D*E) --> A * (B*C + D*E)
2955 Cofactors.push_back(StripFactor(OtherMul, OMulOp));
2956 DeadIndices.push_back(AddOp);
2957 break;
2958 }
2959 }
2960 }
2961
2962 // Fold all collected cofactors with the anchor multiply's cofactor:
2963 // MulOpSCEV * (Cofactor_1 + ... + Cofactor_n + AnchorCofactor)
2964 if (!Cofactors.empty()) {
2965 Cofactors.push_back(StripFactor(Mul, MulOp));
2966
2967 SCEVUse InnerSum = getAddExpr(Cofactors, SCEV::FlagAnyWrap, Depth + 1);
2968 SCEVUse OuterMul =
2969 getMulExpr(MulOpSCEV, InnerSum, SCEV::FlagAnyWrap, Depth + 1);
2970
2971 // DeadIndices does not include Idx (the anchor), hence +1.
2972 if (Ops.size() == DeadIndices.size() + 1)
2973 return OuterMul;
2974
2975 // Erase Ops[Idx] first, then erase DeadIndices in reverse order.
2976 // The -1 adjustment accounts for the shift from removing Idx;
2977 // reverse order means each erasure only shifts later positions,
2978 // which have already been processed.
2979 Ops.erase(Ops.begin() + Idx);
2980 for (unsigned Dead : reverse(DeadIndices))
2981 Ops.erase(Ops.begin() + (Dead > Idx ? Dead - 1 : Dead));
2982
2983 Ops.push_back(OuterMul);
2984 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2985 }
2986 }
2987 }
2988
2989 // If there are any add recurrences in the operands list, see if any other
2990 // added values are loop invariant. If so, we can fold them into the
2991 // recurrence.
2992 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2993 ++Idx;
2994
2995 // Scan over all recurrences, trying to fold loop invariants into them.
2996 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2997 // Scan all of the other operands to this add and add them to the vector if
2998 // they are loop invariant w.r.t. the recurrence.
3000 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3001 const Loop *AddRecLoop = AddRec->getLoop();
3002 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3003 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3004 LIOps.push_back(Ops[i]);
3005 Ops.erase(Ops.begin()+i);
3006 --i; --e;
3007 }
3008
3009 // If we found some loop invariants, fold them into the recurrence.
3010 if (!LIOps.empty()) {
3011 // Compute nowrap flags for the addition of the loop-invariant ops and
3012 // the addrec. Temporarily push it as an operand for that purpose. These
3013 // flags are valid in the scope of the addrec only.
3014 LIOps.push_back(AddRec);
3015 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3016 LIOps.pop_back();
3017
3018 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3019 LIOps.push_back(AddRec->getStart());
3020
3021 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3022
3023 // It is not in general safe to propagate flags valid on an add within
3024 // the addrec scope to one outside it. We must prove that the inner
3025 // scope is guaranteed to execute if the outer one does to be able to
3026 // safely propagate. We know the program is undefined if poison is
3027 // produced on the inner scoped addrec. We also know that *for this use*
3028 // the outer scoped add can't overflow (because of the flags we just
3029 // computed for the inner scoped add) without the program being undefined.
3030 // Proving that entry to the outer scope necessitates entry to the inner
3031 // scope, thus proves the program undefined if the flags would be violated
3032 // in the outer scope.
3033 SCEV::NoWrapFlags AddFlags = Flags;
3034 if (AddFlags != SCEV::FlagAnyWrap) {
3035 auto *DefI = getDefiningScopeBound(LIOps);
3036 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3037 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3038 AddFlags = SCEV::FlagAnyWrap;
3039 }
3040 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3041
3042 // Build the new addrec. Propagate the NUW and NSW flags if both the
3043 // outer add and the inner addrec are guaranteed to have no overflow.
3044 // Always propagate NW.
3045 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3046 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3047
3048 // If all of the other operands were loop invariant, we are done.
3049 if (Ops.size() == 1) return NewRec;
3050
3051 // Otherwise, add the folded AddRec by the non-invariant parts.
3052 for (unsigned i = 0;; ++i)
3053 if (Ops[i] == AddRec) {
3054 Ops[i] = NewRec;
3055 break;
3056 }
3057 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3058 }
3059
3060 // Okay, if there weren't any loop invariants to be folded, check to see if
3061 // there are multiple AddRec's with the same loop induction variable being
3062 // added together. If so, we can fold them.
3063 for (unsigned OtherIdx = Idx+1;
3064 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3065 ++OtherIdx) {
3066 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3067 // so that the 1st found AddRecExpr is dominated by all others.
3068 assert(DT.dominates(
3069 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3070 AddRec->getLoop()->getHeader()) &&
3071 "AddRecExprs are not sorted in reverse dominance order?");
3072 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3073 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3074 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3075 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3076 ++OtherIdx) {
3077 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3078 if (OtherAddRec->getLoop() == AddRecLoop) {
3079 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3080 i != e; ++i) {
3081 if (i >= AddRecOps.size()) {
3082 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3083 break;
3084 }
3085 AddRecOps[i] =
3086 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3088 }
3089 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3090 }
3091 }
3092 // Step size has changed, so we cannot guarantee no self-wraparound.
3093 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3094 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3095 }
3096 }
3097
3098 // Otherwise couldn't fold anything into this recurrence. Move onto the
3099 // next one.
3100 }
3101
3102 // Okay, it looks like we really DO need an add expr. Check to see if we
3103 // already have one, otherwise create a new one.
3104 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3105}
3106
/// Return the uniqued SCEVAddExpr for the operand list \p Ops, creating it if
/// UniqueSCEVs has no node for that list yet. The lookup key is the scAddExpr
/// tag followed by the operand pointers in order. \p Flags are applied to the
/// returned node whether it was found or newly created.
///
/// NOTE(review): `ID` is used below but its declaration is not visible in this
/// listing, and nothing visible fills `O` before it is handed to the
/// SCEVAddExpr constructor -- confirm both against the full file.
3107const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3108 SCEV::NoWrapFlags Flags) {
3110 ID.AddInteger(scAddExpr);
3111 for (const SCEV *Op : Ops)
3112 ID.AddPointer(Op);
// IP receives the insertion position when the lookup misses; it is passed to
// InsertNode below.
3113 void *IP = nullptr;
3114 SCEVAddExpr *S =
3115 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3116 if (!S) {
// Storage for the operand array and for the node itself both come from
// SCEVAllocator.
3117 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3119 S = new (SCEVAllocator)
3120 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3121 UniqueSCEVs.InsertNode(S, IP);
3122 S->computeAndSetCanonical(*this);
3123 registerUser(S, Ops);
3124 }
// Applied on both the cache-hit and the creation path.
3125 S->setNoWrapFlags(Flags);
3126 return S;
3127}
3128
/// Return the uniqued SCEVAddRecExpr with operands \p Ops in loop \p L,
/// creating it if UniqueSCEVs has no node for that combination yet. The lookup
/// key is the scAddRecExpr tag, the operand pointers in order, and the loop
/// pointer. A newly created node is also appended to LoopUsers[L]. \p Flags
/// are applied to the returned node whether it was found or newly created.
///
/// NOTE(review): nothing visible in this listing fills `O` before it is handed
/// to the SCEVAddRecExpr constructor -- confirm against the full file.
3129const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3130 const Loop *L,
3131 SCEV::NoWrapFlags Flags) {
3132 FoldingSetNodeID ID;
3133 ID.AddInteger(scAddRecExpr);
3134 for (const SCEV *Op : Ops)
3135 ID.AddPointer(Op);
// The loop is part of the identity: the same operands in a different loop
// produce a different key.
3136 ID.AddPointer(L);
3137 void *IP = nullptr;
3138 SCEVAddRecExpr *S =
3139 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3140 if (!S) {
3141 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3143 S = new (SCEVAllocator)
3144 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3145 UniqueSCEVs.InsertNode(S, IP);
3146 S->computeAndSetCanonical(*this);
3147 LoopUsers[L].push_back(S);
3148 registerUser(S, Ops);
3149 }
// Unlike the add/mul variants, the flags go through the setNoWrapFlags(S, ...)
// helper rather than S->setNoWrapFlags(...).
3150 setNoWrapFlags(S, Flags);
3151 return S;
3152}
3153
/// Return the uniqued SCEVMulExpr for the operand list \p Ops, creating it if
/// UniqueSCEVs has no node for that list yet. The lookup key is the scMulExpr
/// tag followed by the operand pointers in order. \p Flags are applied to the
/// returned node whether it was found or newly created.
///
/// NOTE(review): nothing visible in this listing fills `O` before it is handed
/// to the SCEVMulExpr constructor -- confirm against the full file.
3154const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3155 SCEV::NoWrapFlags Flags) {
3156 FoldingSetNodeID ID;
3157 ID.AddInteger(scMulExpr);
3158 for (const SCEV *Op : Ops)
3159 ID.AddPointer(Op);
// IP receives the insertion position when the lookup misses; it is passed to
// InsertNode below.
3160 void *IP = nullptr;
3161 SCEVMulExpr *S =
3162 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3163 if (!S) {
// Storage for the operand array and for the node itself both come from
// SCEVAllocator.
3164 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3166 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3167 O, Ops.size());
3168 UniqueSCEVs.InsertNode(S, IP);
3169 S->computeAndSetCanonical(*this);
3170 registerUser(S, Ops);
3171 }
// Applied on both the cache-hit and the creation path.
3172 S->setNoWrapFlags(Flags);
3173 return S;
3174}
3175
/// Multiply \p i by \p j modulo 2^64. \p Overflow is set to true when the
/// exact product does not fit in 64 bits; it is never cleared, so a single
/// flag can accumulate the outcome of several calls.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  // With j <= 1 the product cannot wrap. Otherwise a wrapped product is
  // strictly smaller than the exact one, so dividing it by j no longer
  // yields i.
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient.  If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At iteration i we multiply by the i-th factor of the numerator, n-(i-1),
  // and divide by the i-th factor of the denominator, i. After that step the
  // running value equals "n choose i", so (absent overflow) the division is
  // always exact. Dividing as we go reduces the chance of overflow in the
  // intermediate computations. However, we can still overflow even when the
  // final result would fit.

  // Choosing more elements than are available is impossible. This check must
  // come before the shortcut below so that Choose(0, k) is 0 for every k > 0.
  if (k > n) return 0;
  if (k == 0 || k == n) return 1;

  // "n choose k" equals "n choose n-k"; iterate over the smaller of the two.
  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3207
3208/// Determine if any of the operands in this SCEV are a constant or if
3209/// any of the add or multiply expressions in this SCEV contain a constant.
///
/// \param StartExpr root expression handed to the traversal's visitAll().
/// \returns the visitor's FoundConstant flag once the traversal has finished.
3210static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
// Visitor state: FoundConstant latches to true on the first SCEVConstant seen.
3211 struct FindConstantInAddMulChain {
3212 bool FoundConstant = false;
3213
// Records whether S is a constant and answers true only for add and multiply
// expressions. Presumably the traversal descends into S's operands only when
// this returns true -- confirm against the traversal class.
3214 bool follow(const SCEV *S) {
3215 FoundConstant |= isa<SCEVConstant>(S);
3216 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3217 }
3218
// Reports completion as soon as a constant has been found.
3219 bool isDone() const {
3220 return FoundConstant;
3221 }
3222 };
3223
3224 FindConstantInAddMulChain F;
// NOTE(review): the declaration of `ST` is not visible in this listing;
// presumably it is a traversal object constructed over F -- confirm against
// the full file.
3226 ST.visitAll(StartExpr);
3227 return F.FoundConstant;
3228}
3229
3230/// Get a canonical multiply expression, or something simpler if possible.
///
/// \p Ops is the list of factors. It must be non-empty, and in asserts builds
/// all operands must share one non-pointer type. The list is mutated below
/// (operands are erased, appended and replaced), so presumably it is taken by
/// reference -- NOTE(review): the first line of the signature is not visible
/// in this listing; confirm the parameter type against the full file.
/// \p OrigFlags may contain only NUW and/or NSW (asserted).
/// \p Depth is the current simplification depth; most recursive
/// getMulExpr/getAddExpr calls below pass Depth + 1.
3232 SCEV::NoWrapFlags OrigFlags,
3233 unsigned Depth) {
3234 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3235 "only nuw or nsw allowed");
3236 assert(!Ops.empty() && "Cannot get empty mul!");
// A single operand is returned unchanged.
3237 if (Ops.size() == 1) return Ops[0];
3238#ifndef NDEBUG
3239 Type *ETy = Ops[0]->getType();
3240 assert(!ETy->isPointerTy());
3241 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3242 assert(Ops[i]->getType() == ETy &&
3243 "SCEVMulExpr operand types don't match!");
3244#endif
3245
// Combine constant operands by multiplication; 1 is the identity and 0 the
// absorbing element. A non-null result is returned directly.
3246 const SCEV *Folded = constantFoldAndGroupOps(
3247 *this, LI, DT, Ops,
3248 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3249 [](const APInt &C) { return C.isOne(); }, // identity
3250 [](const APInt &C) { return C.isZero(); }); // absorber
3251 if (Folded)
3252 return Folded;
3253
3254 // Delay expensive flag strengthening until necessary.
3255 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3256 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3257 };
3258
3259 // Limit recursion calls depth.
// NOTE(review): the condition guarding the early return below is not visible
// in this listing -- confirm against the full file.
3261 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3262
3263 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3264 // Don't strengthen flags if we have no new information.
3265 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3266 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3267 Mul->setNoWrapFlags(ComputeFlags(Ops));
3268 return S;
3269 }
3270
// The folds in this block apply only when the first operand is a constant and
// there are exactly two operands.
3271 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3272 if (Ops.size() == 2) {
3273 // C1*(C2+V) -> C1*C2 + C1*V
3274 // If any of Add's ops are Adds or Muls with a constant, apply this
3275 // transformation as well.
3276 //
3277 // TODO: There are some cases where this transformation is not
3278 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3279 // this transformation should be narrowed down.
3280 const SCEV *Op0, *Op1;
3281 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3283 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3284 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3285 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3286 }
3287
3288 if (Ops[0]->isAllOnesValue()) {
3289 // If we have a mul by -1 of an add, try distributing the -1 among the
3290 // add operands.
3291 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3293 bool AnyFolded = false;
3294 for (const SCEV *AddOp : Add->operands()) {
3295 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
// Distribution is only kept if at least one product simplified to something
// other than a multiply expression.
3297 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3298 NewOps.push_back(Mul);
3299 }
3300 if (AnyFolded)
3301 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3302 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3303 // Negation preserves a recurrence's no self-wrap property.
3304 SmallVector<SCEVUse, 4> Operands;
3305 for (const SCEV *AddRecOp : AddRec->operands())
3306 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3307 SCEV::FlagAnyWrap, Depth + 1));
3308 // Let M be the minimum representable signed value. AddRec with nsw
3309 // multiplied by -1 can have signed overflow if and only if it takes a
3310 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3311 // maximum signed value. In all other cases signed overflow is
3312 // impossible.
3313 auto FlagsMask = SCEV::FlagNW;
3314 if (AddRec->hasNoSignedWrap()) {
3315 auto MinInt =
3316 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3317 if (getSignedRangeMin(AddRec) != MinInt)
3318 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3319 }
3320 return getAddRecExpr(Operands, AddRec->getLoop(),
3321 AddRec->getNoWrapFlags(FlagsMask));
3322 }
3323 }
3324
3325 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3326 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3327 const SCEVAddExpr *InnerAdd;
3328 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3329 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3330 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3331 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3332 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3334 SCEV::FlagNUW)) {
3335 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3336 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3337 };
3338 }
3339
3340 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3341 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3342 // of C1, fold to (D /u (C2 /u C1)).
3343 const SCEV *D;
3344 APInt C1V = LHSC->getAPInt();
3345 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3346 // as -1 * 1, as it won't enable additional folds.
3347 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3348 C1V = C1V.abs();
3349 const SCEVConstant *C2;
3350 if (C1V.isPowerOf2() &&
3352 C2->getAPInt().isPowerOf2() &&
3353 C1V.logBase2() <= getMinTrailingZeros(D)) {
3354 const SCEV *NewMul = nullptr;
3355 if (C1V.uge(C2->getAPInt())) {
3356 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3357 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3358 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3359 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3360 }
// If C1V was replaced by its absolute value above, negate the folded result.
3361 if (NewMul)
3362 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3363 }
3364 }
3365 }
3366
3367 // Skip over the add expression until we get to a multiply.
3368 unsigned Idx = 0;
3369 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3370 ++Idx;
3371
3372 // If there are mul operands inline them all into this expression.
3373 if (Idx < Ops.size()) {
3374 bool DeletedMul = false;
3375 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3376 if (Ops.size() > MulOpsInlineThreshold)
3377 break;
3378 // If we have a mul, expand the mul operands onto the end of the
3379 // operands list.
3380 Ops.erase(Ops.begin()+Idx);
3381 append_range(Ops, Mul->operands());
3382 DeletedMul = true;
3383 }
3384
3385 // If we deleted at least one mul, we added operands to the end of the
3386 // list, and they are not necessarily sorted. Recurse to resort and
3387 // resimplify any operands we just acquired.
3388 if (DeletedMul)
3389 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3390 }
3391
3392 // If there are any add recurrences in the operands list, see if any other
3393 // multiplied values are loop invariant. If so, we can fold them into the
3394 // recurrence.
3395 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3396 ++Idx;
3397
3398 // Scan over all recurrences, trying to fold loop invariants into them.
3399 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3400 // Scan all of the other operands to this mul and add them to the vector
3401 // if they are loop invariant w.r.t. the recurrence.
3403 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
// Operands available at loop entry are moved from Ops into LIOps; i and e are
// adjusted to account for each erased element.
3404 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3405 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3406 LIOps.push_back(Ops[i]);
3407 Ops.erase(Ops.begin()+i);
3408 --i; --e;
3409 }
3410
3411 // If we found some loop invariants, fold them into the recurrence.
3412 if (!LIOps.empty()) {
3413 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3415 NewOps.reserve(AddRec->getNumOperands());
3416 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3417
3418 // If both the mul and addrec are nuw, we can preserve nuw.
3419 // If both the mul and addrec are nsw, we can only preserve nsw if either
3420 // a) they are also nuw, or
3421 // b) all multiplications of addrec operands with scale are nsw.
3422 SCEV::NoWrapFlags Flags =
3423 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3424
3425 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3426 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3427 SCEV::FlagAnyWrap, Depth + 1));
3428
// Case b) above: NSW is dropped as soon as one operand's signed range is not
// contained in NSWRegion.
3429 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3431 Instruction::Mul, getSignedRange(Scale),
3433 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3434 Flags = clearFlags(Flags, SCEV::FlagNSW);
3435 }
3436 }
3437
3438 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3439
3440 // If all of the other operands were loop invariant, we are done.
3441 if (Ops.size() == 1) return NewRec;
3442
3443 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3444 for (unsigned i = 0;; ++i)
3445 if (Ops[i] == AddRec) {
3446 Ops[i] = NewRec;
3447 break;
3448 }
3449 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3450 }
3451
3452 // Okay, if there weren't any loop invariants to be folded, check to see
3453 // if there are multiple AddRec's with the same loop induction variable
3454 // being multiplied together. If so, we can fold them.
3455
3456 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3457 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3458 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3459 // ]]],+,...up to x=2n}.
3460 // Note that the arguments to choose() are always integers with values
3461 // known at compile time, never SCEV objects.
3462 //
3463 // The implementation avoids pointless extra computations when the two
3464 // addrec's are of different length (mathematically, it's equivalent to
3465 // an infinite stream of zeros on the right).
// NOTE(review): the loops below start at x = 0 and compute the first
// coefficient as Choose(x, 2x-y), not choose(x, 2x) as written in the formula
// above -- the formula comment looks out of date; confirm.
3466 bool OpsModified = false;
3467 for (unsigned OtherIdx = Idx+1;
3468 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3469 ++OtherIdx) {
3470 const SCEVAddRecExpr *OtherAddRec =
3471 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3472 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3473 continue;
3474
3475 // Limit max number of arguments to avoid creation of unreasonably big
3476 // SCEVAddRecs with very complex operands.
3477 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3478 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3479 continue;
3480
// Coefficients are computed in uint64_t. Once Overflow is set every loop
// below stops, and the fold for this pair is abandoned (see `if (!Overflow)`).
3481 bool Overflow = false;
3482 Type *Ty = AddRec->getType();
3483 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3484 SmallVector<SCEVUse, 7> AddRecOps;
3485 for (int x = 0, xe = AddRec->getNumOperands() +
3486 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3488 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3489 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3490 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3491 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3492 z < ze && !Overflow; ++z) {
3493 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3494 uint64_t Coeff;
// The product of the two coefficients is overflow-checked only for types
// wider than 64 bits; for narrower types a wrapped 64-bit product is used
// as-is (presumably because the constant is reduced to the type's width
// anyway -- confirm).
3495 if (LargerThan64Bits)
3496 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3497 else
3498 Coeff = Coeff1*Coeff2;
3499 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3500 const SCEV *Term1 = AddRec->getOperand(y-z);
3501 const SCEV *Term2 = OtherAddRec->getOperand(z);
3502 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3503 SCEV::FlagAnyWrap, Depth + 1));
3504 }
3505 }
3506 if (SumOps.empty())
3507 SumOps.push_back(getZero(Ty));
3508 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3509 }
3510 if (!Overflow) {
3511 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3513 if (Ops.size() == 2) return NewAddRec;
3514 Ops[Idx] = NewAddRec;
3515 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3516 OpsModified = true;
// The folded result may no longer be an add recurrence; stop pairing if so.
3517 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3518 if (!AddRec)
3519 break;
3520 }
3521 }
3522 if (OpsModified)
3523 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3524
3525 // Otherwise couldn't fold anything into this recurrence. Move onto the
3526 // next one.
3527 }
3528
3529 // Okay, it looks like we really DO need an mul expr. Check to see if we
3530 // already have one, otherwise create a new one.
3531 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3532}
3533
3534/// Represents an unsigned remainder expression based on unsigned division.
///
/// Builds an expression for LHS urem RHS. Two constant-RHS cases are folded
/// directly (RHS == 1 and RHS a power of two); everything else is expressed
/// as LHS - (LHS udiv RHS) * RHS, with NUW on both the multiply and the
/// subtraction.
///
/// NOTE(review): the signature line is not visible in this listing; the body
/// uses operands named LHS and RHS -- confirm parameter types against the
/// full file.
3536 assert(getEffectiveSCEVType(LHS->getType()) ==
3537 getEffectiveSCEVType(RHS->getType()) &&
3538 "SCEVURemExpr operand types don't match!");
3539
3540 // Short-circuit easy cases
3541 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3542 // If constant is one, the result is trivial
3543 if (RHSC->getValue()->isOne())
3544 return getZero(LHS->getType()); // X urem 1 --> 0
3545
3546 // If constant is a power of two, fold into a zext(trunc(LHS)).
// For RHS == 2^k the truncation type is k bits wide (logBase2 of the
// constant), and the result is zero-extended back to LHS's type.
3547 if (RHSC->getAPInt().isPowerOf2()) {
3548 Type *FullTy = LHS->getType();
3549 Type *TruncTy =
3550 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3551 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3552 }
3553 }
3554
3555 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3556 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3557 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3558 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3559}
3560
3561/// Get a canonical unsigned division expression, or something simpler if
3562/// possible.
3564 assert(!LHS->getType()->isPointerTy() &&
3565 "SCEVUDivExpr operand can't be pointer!");
3566 assert(LHS->getType() == RHS->getType() &&
3567 "SCEVUDivExpr operand types don't match!");
3568
3570 ID.AddInteger(scUDivExpr);
3571 ID.AddPointer(LHS);
3572 ID.AddPointer(RHS);
3573 void *IP = nullptr;
3574 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3575 return S;
3576
3577 // 0 udiv Y == 0
3578 if (match(LHS, m_scev_Zero()))
3579 return LHS;
3580
3581 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3582 if (RHSC->getValue()->isOne())
3583 return LHS; // X udiv 1 --> x
3584 // If the denominator is zero, the result of the udiv is undefined. Don't
3585 // try to analyze it, because the resolution chosen here may differ from
3586 // the resolution chosen in other parts of the compiler.
3587 if (!RHSC->getValue()->isZero()) {
3588 // Determine if the division can be folded into the operands of
3589 // its left-hand side.
3590 // TODO: Generalize this to non-constants by using known-bits information.
3591 Type *Ty = LHS->getType();
3592 unsigned LZ = RHSC->getAPInt().countl_zero();
3593 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3594 // For non-power-of-two values, effectively round the value up to the
3595 // nearest power of two.
3596 if (!RHSC->getAPInt().isPowerOf2())
3597 ++MaxShiftAmt;
3598 IntegerType *ExtTy =
3599 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3600 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3601 if (const SCEVConstant *Step =
3602 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3603 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3604 const APInt &StepInt = Step->getAPInt();
3605 const APInt &DivInt = RHSC->getAPInt();
3606 if (!StepInt.urem(DivInt) &&
3607 getZeroExtendExpr(AR, ExtTy) ==
3608 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3609 getZeroExtendExpr(Step, ExtTy),
3610 AR->getLoop(), SCEV::FlagAnyWrap)) {
3611 SmallVector<SCEVUse, 4> Operands;
3612 for (const SCEV *Op : AR->operands())
3613 Operands.push_back(getUDivExpr(Op, RHS));
3614 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3615 }
3616 /// Get a canonical UDivExpr for a recurrence.
3617 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3618 const APInt *StartRem;
3619 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3620 m_scev_APInt(StartRem))) {
3621 bool NoWrap =
3622 getZeroExtendExpr(AR, ExtTy) ==
3623 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3624 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3626
3627 // With N <= C and both N, C as powers-of-2, the transformation
3628 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3629 // if wrapping occurs, as the division results remain equivalent for
3630 // all offsets in [(X - X%N), X).
3631 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3632 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3633 // Only fold if the subtraction can be folded in the start
3634 // expression.
3635 const SCEV *NewStart =
3636 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3637 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3638 !isa<SCEVAddExpr>(NewStart)) {
3639 const SCEV *NewLHS =
3640 getAddRecExpr(NewStart, Step, AR->getLoop(),
3641 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3642 if (LHS != NewLHS) {
3643 LHS = NewLHS;
3644
3645 // Reset the ID to include the new LHS, and check if it is
3646 // already cached.
3647 ID.clear();
3648 ID.AddInteger(scUDivExpr);
3649 ID.AddPointer(LHS);
3650 ID.AddPointer(RHS);
3651 IP = nullptr;
3652 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3653 return S;
3654 }
3655 }
3656 }
3657 }
3658 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3659 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3660 SmallVector<SCEVUse, 4> Operands;
3661 for (const SCEV *Op : M->operands())
3662 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3663 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) {
3664 // Find an operand that's safely divisible.
3665 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3666 const SCEV *Op = M->getOperand(i);
3667 const SCEV *Div = getUDivExpr(Op, RHSC);
3668 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3669 Operands = SmallVector<SCEVUse, 4>(M->operands());
3670 Operands[i] = Div;
3671 return getMulExpr(Operands);
3672 }
3673 }
3674
3675 // Even if it's not divisible, try to remove a common factor.
3676 if (const auto *LHSC = dyn_cast<SCEVConstant>(M->getOperand(0))) {
3677 APInt Factor = APIntOps::GreatestCommonDivisor(LHSC->getAPInt(),
3678 RHSC->getAPInt());
3679 if (!Factor.isIntN(1)) {
3680 SmallVector<SCEVUse, 2> NewOperands;
3681 NewOperands.push_back(getConstant(LHSC->getAPInt().udiv(Factor)));
3682 append_range(NewOperands, M->operands().drop_front());
3683 const SCEV *NewMul = getMulExpr(NewOperands);
3684 return getUDivExpr(NewMul,
3685 getConstant(RHSC->getAPInt().udiv(Factor)));
3686 }
3687 }
3688 }
3689 }
3690
3691 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3692 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3693 if (auto *DivisorConstant =
3694 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3695 bool Overflow = false;
3696 APInt NewRHS =
3697 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3698 if (Overflow) {
3699 return getConstant(RHSC->getType(), 0, false);
3700 }
3701 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3702 }
3703 }
3704
3705 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3706 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3707 SmallVector<SCEVUse, 4> Operands;
3708 for (const SCEV *Op : A->operands())
3709 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3710 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3711 Operands.clear();
3712 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3713 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3714 if (isa<SCEVUDivExpr>(Op) ||
3715 getMulExpr(Op, RHS) != A->getOperand(i))
3716 break;
3717 Operands.push_back(Op);
3718 }
3719 if (Operands.size() == A->getNumOperands())
3720 return getAddExpr(Operands);
3721 }
3722 }
3723
3724 // ((N - M) + (M * A)) / N --> ((N - 1) + (M * A)) / N
3725 // This is an idiom for rounding A up to the next multiple of N, where A
3726 // is already known to be a multiple of M. In this case, instcombine can
3727 // see that some low bits of the added constant are unused, so can clear
3728 // them, but we want to canonicalise to set the low bits. This makes the
3729 // pattern easier to match, without needing to check for known bits in
3730 // A*M.
3731 const APInt &N = RHSC->getAPInt();
3732 const APInt *NMinusM, *M;
3733 const SCEV *A;
3734 if (match(LHS, m_scev_Add(m_scev_APInt(NMinusM),
3735 m_scev_Mul(m_scev_APInt(M), m_SCEV(A))))) {
3736 if (N.isPowerOf2() && M->isPowerOf2() && M->ult(N) &&
3737 *NMinusM == N - *M) {
3738 return getUDivExpr(
3740 RHS);
3741 }
3742 }
3743
3744 // Fold if both operands are constant.
3745 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3746 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3747 }
3748 }
3749
3750 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3751 const APInt *NegC, *C;
3752 if (match(LHS,
3755 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3756 return getZero(LHS->getType());
3757
3758 // (%a * %b)<nuw> / %b -> %a
3759 const auto *Mul = dyn_cast<SCEVMulExpr>(LHS);
3760 if (Mul && Mul->hasNoUnsignedWrap()) {
3761 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3762 if (Mul->getOperand(i) == RHS) {
3763 SmallVector<SCEVUse, 2> Operands;
3764 append_range(Operands, Mul->operands().take_front(i));
3765 append_range(Operands, Mul->operands().drop_front(i + 1));
3766 return getMulExpr(Operands);
3767 }
3768 }
3769 }
3770
3771 // TODO: Generalize to handle any common factors.
3772 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3773 const SCEV *NewLHS, *NewRHS;
3774 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3775 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3776 return getUDivExpr(NewLHS, NewRHS);
3777
3778 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3779 // changes). Make sure we get a new one.
3780 IP = nullptr;
3781 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3782 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3783 LHS, RHS);
3784 UniqueSCEVs.InsertNode(S, IP);
3785 S->computeAndSetCanonical(*this);
3786 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3787 return S;
3788}
3789
3790APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3791 APInt A = C1->getAPInt().abs();
3792 APInt B = C2->getAPInt().abs();
3793 uint32_t ABW = A.getBitWidth();
3794 uint32_t BBW = B.getBitWidth();
3795
3796 if (ABW > BBW)
3797 B = B.zext(ABW);
3798 else if (ABW < BBW)
3799 A = A.zext(BBW);
3800
3801 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3802}
3803
3804/// Get a canonical unsigned division expression, or something simpler if
3805/// possible. There is no representation for an exact udiv in SCEV IR, but we
3806/// can attempt to optimize it prior to construction.
3808 // Currently there is no exact-specific logic; the operands are forwarded
   // unchanged to getUDivExpr.
3809
3810 return getUDivExpr(LHS, RHS);
3811}
3812
3813/// Get an add recurrence expression for the specified loop. Simplify the
3814/// expression as much as possible.
///
/// This form takes a start and a step, assembles them into an operand list
/// and forwards to the operand-list form of getAddRecExpr.
3816 const Loop *L,
3817 SCEV::NoWrapFlags Flags) {
3818 SmallVector<SCEVUse, 4> Operands;
3819 Operands.push_back(Start);
   // If the step is itself an add recurrence over the same loop L, append its
   // operands directly after Start so that one flat operand list is built
   // instead of a recurrence nested in the step position. In that case Flags
   // is passed through maskFlags with SCEV::FlagNW.
3820 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3821 if (StepChrec->getLoop() == L) {
3822 append_range(Operands, StepChrec->operands());
3823 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3824 }
3825
   // Otherwise the recurrence is simply {Start,+,Step} with the caller's flags.
3826 Operands.push_back(Step);
3827 return getAddRecExpr(Operands, L, Flags);
3828}
3829
3830/// Get an add recurrence expression for the specified loop. Simplify the
3831/// expression as much as possible.
3833 const Loop *L,
3834 SCEV::NoWrapFlags Flags) {
3835 if (Operands.size() == 1) return Operands[0];
3836#ifndef NDEBUG
3837 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3838 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3839 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3840 "SCEVAddRecExpr operand types don't match!");
3841 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3842 }
3843 for (const SCEV *Op : Operands)
3845 "SCEVAddRecExpr operand is not available at loop entry!");
3846#endif
3847
3848 if (Operands.back()->isZero()) {
3849 Operands.pop_back();
3850 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3851 }
3852
3853 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3854 // use that information to infer NUW and NSW flags. However, computing a
3855 // BE count requires calling getAddRecExpr, so we may not yet have a
3856 // meaningful BE count at this point (and if we don't, we'd be stuck
3857 // with a SCEVCouldNotCompute as the cached BE count).
3858
3859 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3860
3861 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3862 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3863 const Loop *NestedLoop = NestedAR->getLoop();
3864 if (L->contains(NestedLoop)
3865 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3866 : (!NestedLoop->contains(L) &&
3867 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3868 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3869 Operands[0] = NestedAR->getStart();
3870 // AddRecs require their operands be loop-invariant with respect to their
3871 // loops. Don't perform this transformation if it would break this
3872 // requirement.
3873 bool AllInvariant = all_of(
3874 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3875
3876 if (AllInvariant) {
3877 // Create a recurrence for the outer loop with the same step size.
3878 //
3879 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3880 // inner recurrence has the same property.
3881 SCEV::NoWrapFlags OuterFlags =
3882 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3883
3884 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3885 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3886 return isLoopInvariant(Op, NestedLoop);
3887 });
3888
3889 if (AllInvariant) {
3890 // Ok, both add recurrences are valid after the transformation.
3891 //
3892 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3893 // the outer recurrence has the same property.
3894 SCEV::NoWrapFlags InnerFlags =
3895 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3896 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3897 }
3898 }
3899 // Reset Operands to its original state.
3900 Operands[0] = NestedAR;
3901 }
3902 }
3903
3904 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3905 // already have one, otherwise create a new one.
3906 return getOrCreateAddRecExpr(Operands, L, Flags);
3907}
3908
3910 ArrayRef<SCEVUse> IndexExprs) {
   // The base of the expression is the SCEV of the GEP's pointer operand.
3911 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3912 // getSCEV(Base)->getType() has the same address space as Base->getType()
3913 // because SCEV::getType() preserves the address space.
3914 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3915 if (NW != GEPNoWrapFlags::none()) {
3916 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3917 // but to do that, we have to ensure that said flag is valid in the entire
3918 // defined scope of the SCEV.
3919 // TODO: non-instructions have global scope. We might be able to prove
3920 // some global scope cases.
     // Until then, all flags are dropped unless GEP is an Instruction for
     // which isSCEVExprNeverPoison holds.
3921 auto *GEPI = dyn_cast<Instruction>(GEP);
3922 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3923 NW = GEPNoWrapFlags::none();
3924 }
3925
   // Forward to the form that takes the base SCEV, the index SCEVs, the source
   // element type and the (possibly cleared) no-wrap flags explicitly.
3926 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3927}
3928
3930 ArrayRef<SCEVUse> IndexExprs,
3931 Type *SrcElementTy, GEPNoWrapFlags NW) {
3933 if (NW.hasNoUnsignedSignedWrap())
3934 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3935 if (NW.hasNoUnsignedWrap())
3936 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3937
3938 Type *CurTy = BaseExpr->getType();
3939 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3940 bool FirstIter = true;
3942 for (SCEVUse IndexExpr : IndexExprs) {
3943 // Compute the (potentially symbolic) offset in bytes for this index.
3944 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3945 // For a struct, add the member offset.
3946 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3947 unsigned FieldNo = Index->getZExtValue();
3948 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3949 Offsets.push_back(FieldOffset);
3950
3951 // Update CurTy to the type of the field at Index.
3952 CurTy = STy->getTypeAtIndex(Index);
3953 } else {
3954 // Update CurTy to its element type.
3955 if (FirstIter) {
3956 assert(isa<PointerType>(CurTy) &&
3957 "The first index of a GEP indexes a pointer");
3958 CurTy = SrcElementTy;
3959 FirstIter = false;
3960 } else {
3962 }
3963 // For an array, add the element offset, explicitly scaled.
3964 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3965 // Getelementptr indices are signed.
3966 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3967
3968 // Multiply the index by the element size to compute the element offset.
3969 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3970 Offsets.push_back(LocalOffset);
3971 }
3972 }
3973
3974 // Handle degenerate case of GEP without offsets.
3975 if (Offsets.empty())
3976 return BaseExpr;
3977
3978 // Add the offsets together, using the wrap flags derived from the GEP's
     // no-wrap flags above (nsw for nusw, nuw for nuw).
3979 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3980 // Add the base address and the offset. We cannot use the nsw flag, as the
3981 // base address is unsigned. However, if we know that the offset is
3982 // non-negative, we can use nuw.
3983 bool NUW = NW.hasNoUnsignedWrap() ||
3986 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3987 assert(BaseExpr->getType() == GEPExpr->getType() &&
3988 "GEP should not change type mid-flight.");
3989 return GEPExpr;
3990}
3991
3992SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3995 ID.AddInteger(SCEVType);
3996 for (const SCEV *Op : Ops)
3997 ID.AddPointer(Op);
3998 void *IP = nullptr;
3999 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4000}
4001
4002SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
4005 ID.AddInteger(SCEVType);
4006 for (const SCEV *Op : Ops)
4007 ID.AddPointer(Op);
4008 void *IP = nullptr;
4009 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4010}
4011
4012const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
4014 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
4015}
4016
4019 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4020 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4021 if (Ops.size() == 1) return Ops[0];
4022#ifndef NDEBUG
4023 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4024 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4025 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4026 "Operand types don't match!");
4027 assert(Ops[0]->getType()->isPointerTy() ==
4028 Ops[i]->getType()->isPointerTy() &&
4029 "min/max should be consistently pointerish");
4030 }
4031#endif
4032
4033 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4034 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4035
4036 const SCEV *Folded = constantFoldAndGroupOps(
4037 *this, LI, DT, Ops,
4038 [&](const APInt &C1, const APInt &C2) {
4039 switch (Kind) {
4040 case scSMaxExpr:
4041 return APIntOps::smax(C1, C2);
4042 case scSMinExpr:
4043 return APIntOps::smin(C1, C2);
4044 case scUMaxExpr:
4045 return APIntOps::umax(C1, C2);
4046 case scUMinExpr:
4047 return APIntOps::umin(C1, C2);
4048 default:
4049 llvm_unreachable("Unknown SCEV min/max opcode");
4050 }
4051 },
4052 [&](const APInt &C) {
4053 // identity
4054 if (IsMax)
4055 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4056 else
4057 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4058 },
4059 [&](const APInt &C) {
4060 // absorber
4061 if (IsMax)
4062 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4063 else
4064 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4065 });
4066 if (Folded)
4067 return Folded;
4068
4069 // Check if we have created the same expression before.
4070 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4071 return S;
4072 }
4073
4074 // Find the first operation of the same kind
4075 unsigned Idx = 0;
4076 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4077 ++Idx;
4078
4079 // Check to see if one of the operands is of the same kind. If so, expand its
4080 // operands onto our operand list, and recurse to simplify.
4081 if (Idx < Ops.size()) {
4082 bool DeletedAny = false;
4083 while (Ops[Idx]->getSCEVType() == Kind) {
4084 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4085 Ops.erase(Ops.begin()+Idx);
4086 append_range(Ops, SMME->operands());
4087 DeletedAny = true;
4088 }
4089
4090 if (DeletedAny)
4091 return getMinMaxExpr(Kind, Ops);
4092 }
4093
4094 // Okay, check to see if the same value occurs in the operand list twice. If
4095 // so, delete one. Since we sorted the list, these values are required to
4096 // be adjacent.
4101 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4102 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4103 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4104 if (Ops[i] == Ops[i + 1] ||
4105 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4106 // X op Y op Y --> X op Y
4107 // X op Y --> X, if we know X, Y are ordered appropriately
4108 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4109 --i;
4110 --e;
4111 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4112 Ops[i + 1])) {
4113 // X op Y --> Y, if we know X, Y are ordered appropriately
4114 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4115 --i;
4116 --e;
4117 }
4118 }
4119
4120 if (Ops.size() == 1) return Ops[0];
4121
4122 assert(!Ops.empty() && "Reduced smax down to nothing!");
4123
4124 // Okay, it looks like we really DO need an expr. Check to see if we
4125 // already have one, otherwise create a new one.
4127 ID.AddInteger(Kind);
4128 for (const SCEV *Op : Ops)
4129 ID.AddPointer(Op);
4130 void *IP = nullptr;
4131 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4132 if (ExistingSCEV)
4133 return ExistingSCEV;
4134 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4136 SCEV *S = new (SCEVAllocator)
4137 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4138
4139 UniqueSCEVs.InsertNode(S, IP);
4140 S->computeAndSetCanonical(*this);
4141 registerUser(S, Ops);
4142 return S;
4143}
4144
4145namespace {
4146
4147class SCEVSequentialMinMaxDeduplicatingVisitor final
4148 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4149 std::optional<const SCEV *>> {
4150 using RetVal = std::optional<const SCEV *>;
4152
4153 ScalarEvolution &SE;
4154 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4155 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4157
4158 bool canRecurseInto(SCEVTypes Kind) const {
4159 // We can only recurse into the SCEV expression of the same effective type
4160 // as the type of our root SCEV expression.
4161 return RootKind == Kind || NonSequentialRootKind == Kind;
4162 };
4163
4164 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4166 "Only for min/max expressions.");
4167 SCEVTypes Kind = S->getSCEVType();
4168
4169 if (!canRecurseInto(Kind))
4170 return S;
4171
4172 auto *NAry = cast<SCEVNAryExpr>(S);
4173 SmallVector<SCEVUse> NewOps;
4174 bool Changed = visit(Kind, NAry->operands(), NewOps);
4175
4176 if (!Changed)
4177 return S;
4178 if (NewOps.empty())
4179 return std::nullopt;
4180
4182 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4183 : SE.getMinMaxExpr(Kind, NewOps);
4184 }
4185
4186 RetVal visit(const SCEV *S) {
4187 // Has the whole operand been seen already?
4188 if (!SeenOps.insert(S).second)
4189 return std::nullopt;
4190 return Base::visit(S);
4191 }
4192
4193public:
4194 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4195 SCEVTypes RootKind)
4196 : SE(SE), RootKind(RootKind),
4197 NonSequentialRootKind(
4198 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4199 RootKind)) {}
4200
4201 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4202 SmallVectorImpl<SCEVUse> &NewOps) {
4203 bool Changed = false;
4205 Ops.reserve(OrigOps.size());
4206
4207 for (const SCEV *Op : OrigOps) {
4208 RetVal NewOp = visit(Op);
4209 if (NewOp != Op)
4210 Changed = true;
4211 if (NewOp)
4212 Ops.emplace_back(*NewOp);
4213 }
4214
4215 if (Changed)
4216 NewOps = std::move(Ops);
4217 return Changed;
4218 }
4219
4220 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4221
4222 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4223
4224 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4225
4226 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4227
4228 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4229
4230 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4231
4232 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4233
4234 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4235
4236 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4237
4238 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4239
4240 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4241
4242 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4243 return visitAnyMinMaxExpr(Expr);
4244 }
4245
4246 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4247 return visitAnyMinMaxExpr(Expr);
4248 }
4249
4250 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4251 return visitAnyMinMaxExpr(Expr);
4252 }
4253
4254 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4255 return visitAnyMinMaxExpr(Expr);
4256 }
4257
4258 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4259 return visitAnyMinMaxExpr(Expr);
4260 }
4261
4262 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4263
4264 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4265};
4266
4267} // namespace
4268
4270 switch (Kind) {
4271 case scConstant:
4272 case scVScale:
4273 case scTruncate:
4274 case scZeroExtend:
4275 case scSignExtend:
4276 case scPtrToAddr:
4277 case scPtrToInt:
4278 case scAddExpr:
4279 case scMulExpr:
4280 case scUDivExpr:
4281 case scAddRecExpr:
4282 case scUMaxExpr:
4283 case scSMaxExpr:
4284 case scUMinExpr:
4285 case scSMinExpr:
4286 case scUnknown:
4287 // If any operand is poison, the whole expression is poison.
4288 return true;
4290 // FIXME: if the *first* operand is poison, the whole expression is poison.
4291 return false; // Pessimistically, say that it does not propagate poison.
4292 case scCouldNotCompute:
4293 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4294 }
4295 llvm_unreachable("Unknown SCEV kind!");
4296}
4297
4298namespace {
4299// The only way poison may be introduced in a SCEV expression is from a
4300// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4301// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4302// introduce poison -- they encode guaranteed, non-speculated knowledge.
4303//
4304// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4305// with the notable exception of umin_seq, where only poison from the first
4306// operand is (unconditionally) propagated.
4307struct SCEVPoisonCollector {
4308 bool LookThroughMaybePoisonBlocking;
4309 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4310 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4311 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4312
4313 bool follow(const SCEV *S) {
4314 if (!LookThroughMaybePoisonBlocking &&
4316 return false;
4317
4318 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4319 if (!isGuaranteedNotToBePoison(SU->getValue()))
4320 MaybePoison.insert(SU);
4321 }
4322 return true;
4323 }
4324 bool isDone() const { return false; }
4325};
4326} // namespace
4327
4328/// Return true if V is poison given that AssumedPoison is already poison.
4329static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4330 // First collect all SCEVs that might result in AssumedPoison to be poison.
4331 // We need to look through potentially poison-blocking operations here,
4332 // because we want to find all SCEVs that *might* result in poison, not only
4333 // those that are *required* to.
4334 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4335 visitAll(AssumedPoison, PC1);
4336
4337 // AssumedPoison is never poison. As the assumption is false, the implication
4338 // is true. Don't bother walking the other SCEV in this case.
4339 if (PC1.MaybePoison.empty())
4340 return true;
4341
4342 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4343 // as well. We cannot look through potentially poison-blocking operations
4344 // here, as their arguments only *may* make the result poison.
4345 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4346 visitAll(S, PC2);
4347
4348 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4349 // it will also make S poison by being part of PC2.MaybePoison.
4350 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4351}
4352
4354 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
   // Walk S without looking through potentially poison-blocking operations and
   // collect the SCEVUnknowns that are not guaranteed to be non-poison.
4355 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4356 visitAll(S, PC);
   // Report the underlying IR value of each collected SCEVUnknown. Result is
   // only added to, never cleared.
4357 for (const SCEVUnknown *SU : PC.MaybePoison)
4358 Result.insert(SU->getValue());
4359}
4360
4362 const SCEV *S, Instruction *I,
4363 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4364 // If the instruction cannot be poison, it's always safe to reuse.
4366 return true;
4367
4368 // Otherwise, it is possible that I is more poisonous than S. Collect the
4369 // poison-contributors of S, and then check whether I has any additional
4370 // poison-contributors. Poison that is contributed through poison-generating
4371 // flags is handled by dropping those flags instead.
4373 getPoisonGeneratingValues(PoisonVals, S);
4374
4375 SmallVector<Value *> Worklist;
4377 Worklist.push_back(I);
4378 while (!Worklist.empty()) {
4379 Value *V = Worklist.pop_back_val();
4380 if (!Visited.insert(V).second)
4381 continue;
4382
4383 // Avoid walking large instruction graphs.
4384 if (Visited.size() > 16)
4385 return false;
4386
4387 // Either the value can't be poison, or S would also be poison if it
4388 // is.
4389 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4390 continue;
4391
4392 auto *I = dyn_cast<Instruction>(V);
4393 if (!I)
4394 return false;
4395
4396 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4397 // can't replace an arbitrary add with disjoint or, even if we drop the
4398 // flag. We would need to convert the or into an add.
4399 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4400 if (PDI->isDisjoint())
4401 return false;
4402
4403 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4404 // because SCEV currently assumes it can't be poison. Remove this special
4405 // case once we properly model when vscale can be poison.
4406 if (auto *II = dyn_cast<IntrinsicInst>(I);
4407 II && II->getIntrinsicID() == Intrinsic::vscale)
4408 continue;
4409
4410 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4411 return false;
4412
4413 // If the instruction can't create poison, we can recurse to its operands.
4414 if (I->hasPoisonGeneratingAnnotations())
4415 DropPoisonGeneratingInsts.push_back(I);
4416
4417 llvm::append_range(Worklist, I->operands());
4418 }
4419 return true;
4420}
4421
4422const SCEV *
4425 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4426 "Not a SCEVSequentialMinMaxExpr!");
4427 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4428 if (Ops.size() == 1)
4429 return Ops[0];
4430#ifndef NDEBUG
4431 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4432 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4433 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4434 "Operand types don't match!");
4435 assert(Ops[0]->getType()->isPointerTy() ==
4436 Ops[i]->getType()->isPointerTy() &&
4437 "min/max should be consistently pointerish");
4438 }
4439#endif
4440
4441 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4442 // so we can *NOT* do any kind of sorting of the expressions!
4443
4444 // Check if we have created the same expression before.
4445 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4446 return S;
4447
4448 // FIXME: there are *some* simplifications that we can do here.
4449
4450 // Keep only the first instance of an operand.
4451 {
4452 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4453 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4454 if (Changed)
4455 return getSequentialMinMaxExpr(Kind, Ops);
4456 }
4457
4458 // Check to see if one of the operands is of the same kind. If so, expand its
4459 // operands onto our operand list, and recurse to simplify.
4460 {
4461 unsigned Idx = 0;
4462 bool DeletedAny = false;
4463 while (Idx < Ops.size()) {
4464 if (Ops[Idx]->getSCEVType() != Kind) {
4465 ++Idx;
4466 continue;
4467 }
4468 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4469 Ops.erase(Ops.begin() + Idx);
4470 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4471 SMME->operands().end());
4472 DeletedAny = true;
4473 }
4474
4475 if (DeletedAny)
4476 return getSequentialMinMaxExpr(Kind, Ops);
4477 }
4478
4479 const SCEV *SaturationPoint;
4481 switch (Kind) {
4483 SaturationPoint = getZero(Ops[0]->getType());
4484 Pred = ICmpInst::ICMP_ULE;
4485 break;
4486 default:
4487 llvm_unreachable("Not a sequential min/max type.");
4488 }
4489
4490 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4491 if (!isGuaranteedNotToCauseUB(Ops[i]))
4492 continue;
4493 // We can replace %x umin_seq %y with %x umin %y if either:
4494 // * %y being poison implies %x is also poison.
4495 // * %x cannot be the saturating value (e.g. zero for umin).
4496 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4497 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4498 SaturationPoint)) {
4499 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4500 Ops[i - 1] = getMinMaxExpr(
4502 SeqOps);
4503 Ops.erase(Ops.begin() + i);
4504 return getSequentialMinMaxExpr(Kind, Ops);
4505 }
4506 // Fold %x umin_seq %y to %x if %x ule %y.
4507 // TODO: We might be able to prove the predicate for a later operand.
4508 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4509 Ops.erase(Ops.begin() + i);
4510 return getSequentialMinMaxExpr(Kind, Ops);
4511 }
4512 }
4513
4514 // Okay, it looks like we really DO need an expr. Check to see if we
4515 // already have one, otherwise create a new one.
4517 ID.AddInteger(Kind);
4518 for (const SCEV *Op : Ops)
4519 ID.AddPointer(Op);
4520 void *IP = nullptr;
4521 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4522 if (ExistingSCEV)
4523 return ExistingSCEV;
4524
4525 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4527 SCEV *S = new (SCEVAllocator)
4528 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4529
4530 UniqueSCEVs.InsertNode(S, IP);
4531 S->computeAndSetCanonical(*this);
4532 registerUser(S, Ops);
4533 return S;
4534}
4535
4540
4544
4549
4553
4558
4562
4564 bool Sequential) {
4565 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4566 return getUMinExpr(Ops, Sequential);
4567}
4568
4574
4575const SCEV *
4577 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4578 if (Size.isScalable())
4579 Res = getMulExpr(Res, getVScale(IntTy));
4580 return Res;
4581}
4582
4584 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4585}
4586
4588 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4589}
4590
4592 StructType *STy,
4593 unsigned FieldNo) {
4594 // We can bypass creating a target-independent constant expression and then
4595 // folding it back into a ConstantInt. This is just a compile-time
4596 // optimization.
4597 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4598 assert(!SL->getSizeInBits().isScalable() &&
4599 "Cannot get offset for structure containing scalable vector types");
4600 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4601}
4602
4604 // Don't attempt to do anything other than create a SCEVUnknown object
4605 // here. createSCEV only calls getUnknown after checking for all other
4606 // interesting possibilities, and any other code that calls getUnknown
4607 // is doing so in order to hide a value from SCEV canonicalization.
4608
4610 ID.AddInteger(scUnknown);
4611 ID.AddPointer(V);
4612 void *IP = nullptr;
4613 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4614 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4615 "Stale SCEVUnknown in uniquing map!");
4616 return S;
4617 }
4618 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4619 FirstUnknown);
4620 FirstUnknown = cast<SCEVUnknown>(S);
4621 UniqueSCEVs.InsertNode(S, IP);
4622 S->computeAndSetCanonical(*this);
4623 return S;
4624}
4625
4626//===----------------------------------------------------------------------===//
4627// Basic SCEV Analysis and PHI Idiom Recognition Code
4628//
4629
4630/// Test if values of the given type are analyzable within the SCEV
4631/// framework. Integer and pointer types are always SCEVable; no
4632/// target-specific information is needed to decide this.
4635 // Integers and pointers are always SCEVable.
4636 return Ty->isIntOrPtrTy();
4637}
4638
4639/// Return the size in bits of the specified type, for which isSCEVable must
4640/// return true.
4642 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4643 if (Ty->isPointerTy())
4645 return getDataLayout().getTypeSizeInBits(Ty);
4646}
4647
4648/// Return a type with the same bitwidth as the given type and which represents
4649/// how SCEV will treat the given type, for which isSCEVable must return
4650/// true. For pointer types, this is the pointer index sized integer type.
4652 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4653
4654 if (Ty->isIntegerTy())
4655 return Ty;
4656
4657 // The only other supported type is pointer.
4658 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4659 return getDataLayout().getIndexType(Ty);
4660}
4661
4663 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4664}
4665
4667 const SCEV *B) {
4668 /// For a valid use point to exist, the defining scope of one operand
4669 /// must dominate the other.
4670 bool PreciseA, PreciseB;
4671 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4672 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4673 if (!PreciseA || !PreciseB)
4674 // Can't tell.
4675 return false;
4676 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4677 DT.dominates(ScopeB, ScopeA);
4678}
4679
4681 return CouldNotCompute.get();
4682}
4683
4684bool ScalarEvolution::checkValidity(const SCEV *S) const {
4685 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4686 auto *SU = dyn_cast<SCEVUnknown>(S);
4687 return SU && SU->getValue() == nullptr;
4688 });
4689
4690 return !ContainsNulls;
4691}
4692
4694 HasRecMapType::iterator I = HasRecMap.find(S);
4695 if (I != HasRecMap.end())
4696 return I->second;
4697
4698 bool FoundAddRec =
4699 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4700 HasRecMap.insert({S, FoundAddRec});
4701 return FoundAddRec;
4702}
4703
4704/// Return the ValueOffsetPair set for \p S. \p S can be represented
4705/// by the value and offset from any ValueOffsetPair in the set.
4706ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4707 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4708 if (SI == ExprValueMap.end())
4709 return {};
4710 return SI->second.getArrayRef();
4711}
4712
4713/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4714/// cannot be used separately. eraseValueFromMap should be used to remove
4715/// V from ValueExprMap and ExprValueMap at the same time.
4716void ScalarEvolution::eraseValueFromMap(Value *V) {
4717 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4718 if (I != ValueExprMap.end()) {
4719 auto EVIt = ExprValueMap.find(I->second);
4720 bool Removed = EVIt->second.remove(V);
4721 (void) Removed;
4722 assert(Removed && "Value not in ExprValueMap?");
4723 ValueExprMap.erase(I);
4724 }
4725}
4726
4727void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4728 // A recursive query may have already computed the SCEV. It should be
4729 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4730 // inferred nowrap flags.
4731 auto It = ValueExprMap.find_as(V);
4732 if (It == ValueExprMap.end()) {
4733 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4734 ExprValueMap[S].insert(V);
4735 }
4736}
4737
4738/// Return an existing SCEV if it exists, otherwise analyze the expression and
4739/// create a new one.
4741 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4742
4743 if (const SCEV *S = getExistingSCEV(V))
4744 return S;
4745 return createSCEVIter(V);
4746}
4747
4749 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4750
4751 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4752 if (I != ValueExprMap.end()) {
4753 const SCEV *S = I->second;
4754 assert(checkValidity(S) &&
4755 "existing SCEV has not been properly invalidated");
4756 return S;
4757 }
4758 return nullptr;
4759}
4760
4761/// Return a SCEV corresponding to -V = -1*V
4763 SCEV::NoWrapFlags Flags) {
4764 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4765 return getConstant(
4766 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4767
4768 Type *Ty = V->getType();
4769 Ty = getEffectiveSCEVType(Ty);
4770 return getMulExpr(V, getMinusOne(Ty), Flags);
4771}
4772
4773/// If Expr computes ~A, return A else return nullptr
4774static const SCEV *MatchNotExpr(const SCEV *Expr) {
4775 const SCEV *MulOp;
4776 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4777 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4778 return MulOp;
4779 return nullptr;
4780}
4781
4782/// Return a SCEV corresponding to ~V = -1-V
4784 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4785
4786 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4787 return getConstant(
4788 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4789
4790 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4791 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4792 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4793 SmallVector<SCEVUse, 2> MatchedOperands;
4794 for (const SCEV *Operand : MME->operands()) {
4795 const SCEV *Matched = MatchNotExpr(Operand);
4796 if (!Matched)
4797 return (const SCEV *)nullptr;
4798 MatchedOperands.push_back(Matched);
4799 }
4800 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4801 MatchedOperands);
4802 };
4803 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4804 return Replaced;
4805 }
4806
4807 Type *Ty = V->getType();
4808 Ty = getEffectiveSCEVType(Ty);
4809 return getMinusSCEV(getMinusOne(Ty), V);
4810}
4811
4813 assert(P->getType()->isPointerTy());
4814
4815 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4816 // The base of an AddRec is the first operand.
4817 SmallVector<SCEVUse> Ops{AddRec->operands()};
4818 Ops[0] = removePointerBase(Ops[0]);
4819 // Don't try to transfer nowrap flags for now. We could in some cases
4820 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4821 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4822 }
4823 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4824 // The base of an Add is the pointer operand.
4825 SmallVector<SCEVUse> Ops{Add->operands()};
4826 SCEVUse *PtrOp = nullptr;
4827 for (SCEVUse &AddOp : Ops) {
4828 if (AddOp->getType()->isPointerTy()) {
4829 assert(!PtrOp && "Cannot have multiple pointer ops");
4830 PtrOp = &AddOp;
4831 }
4832 }
4833 *PtrOp = removePointerBase(*PtrOp);
4834 // Don't try to transfer nowrap flags for now. We could in some cases
4835 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4836 return getAddExpr(Ops);
4837 }
4838 // Any other expression must be a pointer base.
4839 return getZero(P->getType());
4840}
4841
4843 SCEV::NoWrapFlags Flags,
4844 unsigned Depth) {
4845 // Fast path: X - X --> 0.
4846 if (LHS == RHS)
4847 return getZero(LHS->getType());
4848
4849 // If we subtract two pointers with different pointer bases, bail.
4850 // Eventually, we're going to add an assertion to getMulExpr that we
4851 // can't multiply by a pointer.
4852 if (RHS->getType()->isPointerTy()) {
4853 if (!LHS->getType()->isPointerTy() ||
4854 getPointerBase(LHS) != getPointerBase(RHS))
4855 return getCouldNotCompute();
4856 LHS = removePointerBase(LHS);
4857 RHS = removePointerBase(RHS);
4858 }
4859
4860 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4861 // makes it so that we cannot make much use of NUW.
4862 auto AddFlags = SCEV::FlagAnyWrap;
4863 const bool RHSIsNotMinSigned =
4865 if (hasFlags(Flags, SCEV::FlagNSW)) {
4866 // Let M be the minimum representable signed value. Then (-1)*RHS
4867 // signed-wraps if and only if RHS is M. That can happen even for
4868 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4869 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4870 // (-1)*RHS, we need to prove that RHS != M.
4871 //
4872 // If LHS is non-negative and we know that LHS - RHS does not
4873 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4874 // either by proving that RHS > M or that LHS >= 0.
4875 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4876 AddFlags = SCEV::FlagNSW;
4877 }
4878 }
4879
4880 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4881 // RHS is NSW and LHS >= 0.
4882 //
4883 // The difficulty here is that the NSW flag may have been proven
4884 // relative to a loop that is to be found in a recurrence in LHS and
4885 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4886 // larger scope than intended.
4887 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4888
4889 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4890}
4891
4893 unsigned Depth) {
4894 Type *SrcTy = V->getType();
4895 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4896 "Cannot truncate or zero extend with non-integer arguments!");
4897 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4898 return V; // No conversion
4899 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4900 return getTruncateExpr(V, Ty, Depth);
4901 return getZeroExtendExpr(V, Ty, Depth);
4902}
4903
4905 unsigned Depth) {
4906 Type *SrcTy = V->getType();
4907 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4908 "Cannot truncate or zero extend with non-integer arguments!");
4909 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4910 return V; // No conversion
4911 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4912 return getTruncateExpr(V, Ty, Depth);
4913 return getSignExtendExpr(V, Ty, Depth);
4914}
4915
4916const SCEV *
4918 Type *SrcTy = V->getType();
4919 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4920 "Cannot noop or zero extend with non-integer arguments!");
4922 "getNoopOrZeroExtend cannot truncate!");
4923 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4924 return V; // No conversion
4925 return getZeroExtendExpr(V, Ty);
4926}
4927
4928const SCEV *
4930 Type *SrcTy = V->getType();
4931 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4932 "Cannot noop or sign extend with non-integer arguments!");
4934 "getNoopOrSignExtend cannot truncate!");
4935 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4936 return V; // No conversion
4937 return getSignExtendExpr(V, Ty);
4938}
4939
4940const SCEV *
4942 Type *SrcTy = V->getType();
4943 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4944 "Cannot noop or any extend with non-integer arguments!");
4946 "getNoopOrAnyExtend cannot truncate!");
4947 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4948 return V; // No conversion
4949 return getAnyExtendExpr(V, Ty);
4950}
4951
4952const SCEV *
4954 Type *SrcTy = V->getType();
4955 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4956 "Cannot truncate or noop with non-integer arguments!");
4958 "getTruncateOrNoop cannot extend!");
4959 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4960 return V; // No conversion
4961 return getTruncateExpr(V, Ty);
4962}
4963
4965 const SCEV *RHS) {
4966 const SCEV *PromotedLHS = LHS;
4967 const SCEV *PromotedRHS = RHS;
4968
4969 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4970 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4971 else
4972 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4973
4974 return getUMaxExpr(PromotedLHS, PromotedRHS);
4975}
4976
4978 const SCEV *RHS,
4979 bool Sequential) {
4980 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4981 return getUMinFromMismatchedTypes(Ops, Sequential);
4982}
4983
4984const SCEV *
4986 bool Sequential) {
4987 assert(!Ops.empty() && "At least one operand must be!");
4988 // Trivial case.
4989 if (Ops.size() == 1)
4990 return Ops[0];
4991
4992 // Find the max type first.
4993 Type *MaxType = nullptr;
4994 for (SCEVUse S : Ops)
4995 if (MaxType)
4996 MaxType = getWiderType(MaxType, S->getType());
4997 else
4998 MaxType = S->getType();
4999 assert(MaxType && "Failed to find maximum type!");
5000
5001 // Extend all ops to max type.
5002 SmallVector<SCEVUse, 2> PromotedOps;
5003 for (SCEVUse S : Ops)
5004 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
5005
5006 // Generate umin.
5007 return getUMinExpr(PromotedOps, Sequential);
5008}
5009
5011 // A pointer operand may evaluate to a nonpointer expression, such as null.
5012 if (!V->getType()->isPointerTy())
5013 return V;
5014
5015 while (true) {
5016 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5017 V = AddRec->getStart();
5018 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5019 const SCEV *PtrOp = nullptr;
5020 for (const SCEV *AddOp : Add->operands()) {
5021 if (AddOp->getType()->isPointerTy()) {
5022 assert(!PtrOp && "Cannot have multiple pointer ops");
5023 PtrOp = AddOp;
5024 }
5025 }
5026 assert(PtrOp && "Must have pointer op");
5027 V = PtrOp;
5028 } else // Not something we can look further into.
5029 return V;
5030 }
5031}
5032
5033/// Push users of the given Instruction onto the given Worklist.
5037 // Push the def-use children onto the Worklist stack.
5038 for (User *U : I->users()) {
5039 auto *UserInsn = cast<Instruction>(U);
5040 if (Visited.insert(UserInsn).second)
5041 Worklist.push_back(UserInsn);
5042 }
5043}
5044
5045namespace {
5046
5047/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5048/// expression in case its Loop is L. If it is not L then
5049/// if IgnoreOtherLoops is true then use AddRec itself
5050/// otherwise rewrite cannot be done.
5051/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5052class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5053public:
5054 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5055 bool IgnoreOtherLoops = true) {
5056 SCEVInitRewriter Rewriter(L, SE);
5057 const SCEV *Result = Rewriter.visit(S);
5058 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5059 return SE.getCouldNotCompute();
5060 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5061 ? SE.getCouldNotCompute()
5062 : Result;
5063 }
5064
5065 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5066 if (!SE.isLoopInvariant(Expr, L))
5067 SeenLoopVariantSCEVUnknown = true;
5068 return Expr;
5069 }
5070
5071 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5072 // Only re-write AddRecExprs for this loop.
5073 if (Expr->getLoop() == L)
5074 return Expr->getStart();
5075 SeenOtherLoops = true;
5076 return Expr;
5077 }
5078
5079 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5080
5081 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5082
5083private:
5084 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5085 : SCEVRewriteVisitor(SE), L(L) {}
5086
5087 const Loop *L;
5088 bool SeenLoopVariantSCEVUnknown = false;
5089 bool SeenOtherLoops = false;
5090};
5091
5092/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5093/// increment expression in case its Loop is L. If it is not L then
5094/// use AddRec itself.
5095/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5096class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5097public:
5098 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5099 SCEVPostIncRewriter Rewriter(L, SE);
5100 const SCEV *Result = Rewriter.visit(S);
5101 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5102 ? SE.getCouldNotCompute()
5103 : Result;
5104 }
5105
5106 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5107 if (!SE.isLoopInvariant(Expr, L))
5108 SeenLoopVariantSCEVUnknown = true;
5109 return Expr;
5110 }
5111
5112 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5113 // Only re-write AddRecExprs for this loop.
5114 if (Expr->getLoop() == L)
5115 return Expr->getPostIncExpr(SE);
5116 SeenOtherLoops = true;
5117 return Expr;
5118 }
5119
5120 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5121
5122 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5123
5124private:
5125 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5126 : SCEVRewriteVisitor(SE), L(L) {}
5127
5128 const Loop *L;
5129 bool SeenLoopVariantSCEVUnknown = false;
5130 bool SeenOtherLoops = false;
5131};
5132
5133/// This class evaluates the compare condition by matching it against the
5134/// condition of loop latch. If there is a match we assume a true value
5135/// for the condition while building SCEV nodes.
5136class SCEVBackedgeConditionFolder
5137 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5138public:
 /// Rewrite \p S for loop \p L. When L has a latch whose terminator is not a
 /// CondBrInst, \p S is returned unchanged. Otherwise the latch's branch
 /// condition and whether successor 0 is the loop header are recorded and
 /// \p S is visited. When L has no latch, the visit still happens with a
 /// null backedge condition.
5139 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5140 ScalarEvolution &SE) {
5141 bool IsPosBECond = false;
5142 Value *BECond = nullptr;
5143 if (BasicBlock *Latch = L->getLoopLatch()) {
5144 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5145 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5146 "Both outgoing branches should not target same header!");
5147 BECond = BI->getCondition();
 // The backedge is the "true" edge when successor 0 is the header.
5148 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5149 } else {
5150 return S;
5151 }
5152 }
5153 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5154 return Rewriter.visit(S);
5155 }
5156
 /// Loop-invariant unknowns are returned unchanged. For a loop-variant one:
 /// a select whose condition matches the backedge condition is replaced by
 /// the SCEV of the arm chosen by the constant that
 /// compareWithBackedgeCondition returned; any other instruction that is
 /// itself the backedge condition is replaced by that constant.
5157 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5158 const SCEV *Result = Expr;
5159 bool InvariantF = SE.isLoopInvariant(Expr, L);
5160
5161 if (!InvariantF) {
 // NOTE(review): the declaration of I is not visible in this listing;
 // presumably it is the Instruction behind Expr->getValue() -- confirm.
5163 switch (I->getOpcode()) {
5164 case Instruction::Select: {
5165 SelectInst *SI = cast<SelectInst>(I);
5166 std::optional<const SCEV *> Res =
5167 compareWithBackedgeCondition(SI->getCondition());
5168 if (Res) {
 // The folded condition is a constant; pick the matching select arm.
5169 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5170 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5171 }
5172 break;
5173 }
5174 default: {
5175 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5176 if (Res)
5177 Result = *Res;
5178 break;
5179 }
5180 }
5181 }
5182 return Result;
5183 }
5184
5185private:
5186 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5187 bool IsPosBECond, ScalarEvolution &SE)
5188 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5189 IsPositiveBECond(IsPosBECond) {}
5190
 /// Returns a constant SCEV when \p IC is the recorded backedge condition,
 /// std::nullopt otherwise.
5191 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5192
5193 const Loop *L;
5194 /// Loop back condition.
5195 Value *BackedgeCond = nullptr;
5196 /// Set to true if loop back is on positive branch condition.
5197 bool IsPositiveBECond;
5198};
5199
/// Compare \p IC against the recorded backedge condition. Returns
/// std::nullopt when they are different values. On a match with a positive
/// backedge condition the result is the i1 constant one.
/// NOTE(review): the arm for a negative backedge condition is not visible in
/// this listing -- presumably the i1 constant zero; confirm.
5200std::optional<const SCEV *>
5201SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5202
5203 // If value matches the backedge condition for loop latch,
5204 // then return a constant evolution node based on loopback
5205 // branch taken.
5206 if (BackedgeCond == IC)
5207 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5209 return std::nullopt;
5210}
5211
5212class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5213public:
5214 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5215 ScalarEvolution &SE) {
5216 SCEVShiftRewriter Rewriter(L, SE);
5217 const SCEV *Result = Rewriter.visit(S);
5218 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5219 }
5220
5221 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5222 // Only allow AddRecExprs for this loop.
5223 if (!SE.isLoopInvariant(Expr, L))
5224 Valid = false;
5225 return Expr;
5226 }
5227
5228 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5229 if (Expr->getLoop() == L && Expr->isAffine())
5230 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5231 Valid = false;
5232 return Expr;
5233 }
5234
5235 bool isValid() { return Valid; }
5236
5237private:
5238 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5239 : SCEVRewriteVisitor(SE), L(L) {}
5240
5241 const Loop *L;
5242 bool Valid = true;
5243};
5244
5245} // end anonymous namespace
5246
5248ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5249 if (!AR->isAffine())
5250 return SCEV::FlagAnyWrap;
5251
5252 using OBO = OverflowingBinaryOperator;
5253
5255
5256 if (!AR->hasNoSelfWrap()) {
5257 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5258 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5259 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5260 const APInt &BECountAP = BECountMax->getAPInt();
5261 unsigned NoOverflowBitWidth =
5262 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5263 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5265 }
5266 }
5267
5268 if (!AR->hasNoSignedWrap()) {
5269 ConstantRange AddRecRange = getSignedRange(AR);
5270 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5271
5273 Instruction::Add, IncRange, OBO::NoSignedWrap);
5274 if (NSWRegion.contains(AddRecRange))
5276 }
5277
5278 if (!AR->hasNoUnsignedWrap()) {
5279 ConstantRange AddRecRange = getUnsignedRange(AR);
5280 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5281
5283 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5284 if (NUWRegion.contains(AddRecRange))
5286 }
5287
5288 return Result;
5289}
5290
5292ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5294
5295 if (AR->hasNoSignedWrap())
5296 return Result;
5297
5298 if (!AR->isAffine())
5299 return Result;
5300
5301 // This function can be expensive, only try to prove NSW once per AddRec.
5302 if (!SignedWrapViaInductionTried.insert(AR).second)
5303 return Result;
5304
5305 const SCEV *Step = AR->getStepRecurrence(*this);
5306 const Loop *L = AR->getLoop();
5307
5308 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5309 // Note that this serves two purposes: It filters out loops that are
5310 // simply not analyzable, and it covers the case where this code is
5311 // being called from within backedge-taken count analysis, such that
5312 // attempting to ask for the backedge-taken count would likely result
5313 // in infinite recursion. In the latter case, the analysis code will
5314 // cope with a conservative value, and it will take care to purge
5315 // that value once it has finished.
5316 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5317
5318 // Normally, in the cases we can prove no-overflow via a
5319 // backedge guarding condition, we can also compute a backedge
5320 // taken count for the loop. The exceptions are assumptions and
5321 // guards present in the loop -- SCEV is not great at exploiting
5322 // these to compute max backedge taken counts, but can still use
5323 // these to prove lack of overflow. Use this fact to avoid
5324 // doing extra work that may not pay off.
5325
5326 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5327 AC.assumptions().empty())
5328 return Result;
5329
5330 // If the backedge is guarded by a comparison with the pre-inc value the
5331 // addrec is safe. Also, if the entry is guarded by a comparison with the
5332 // start value and the backedge is guarded by a comparison with the post-inc
5333 // value, the addrec is safe.
5335 const SCEV *OverflowLimit =
5336 getSignedOverflowLimitForStep(Step, &Pred, this);
5337 if (OverflowLimit &&
5338 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5339 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5340 Result = setFlags(Result, SCEV::FlagNSW);
5341 }
5342 return Result;
5343}
5345ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5347
5348 if (AR->hasNoUnsignedWrap())
5349 return Result;
5350
5351 if (!AR->isAffine())
5352 return Result;
5353
5354 // This function can be expensive, only try to prove NUW once per AddRec.
5355 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5356 return Result;
5357
5358 const SCEV *Step = AR->getStepRecurrence(*this);
5359 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5360 const Loop *L = AR->getLoop();
5361
5362 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5363 // Note that this serves two purposes: It filters out loops that are
5364 // simply not analyzable, and it covers the case where this code is
5365 // being called from within backedge-taken count analysis, such that
5366 // attempting to ask for the backedge-taken count would likely result
5367 // in infinite recursion. In the latter case, the analysis code will
5368 // cope with a conservative value, and it will take care to purge
5369 // that value once it has finished.
5370 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5371
5372 // Normally, in the cases we can prove no-overflow via a
5373 // backedge guarding condition, we can also compute a backedge
5374 // taken count for the loop. The exceptions are assumptions and
5375 // guards present in the loop -- SCEV is not great at exploiting
5376 // these to compute max backedge taken counts, but can still use
5377 // these to prove lack of overflow. Use this fact to avoid
5378 // doing extra work that may not pay off.
5379
5380 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5381 AC.assumptions().empty())
5382 return Result;
5383
5384 // If the backedge is guarded by a comparison with the pre-inc value the
5385 // addrec is safe. Also, if the entry is guarded by a comparison with the
5386 // start value and the backedge is guarded by a comparison with the post-inc
5387 // value, the addrec is safe.
5388 if (isKnownPositive(Step)) {
5389 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5390 getUnsignedRangeMax(Step));
5393 Result = setFlags(Result, SCEV::FlagNUW);
5394 }
5395 }
5396
5397 return Result;
5398}
5399
5400namespace {
5401
5402/// Represents an abstract binary operation. This may exist as a
5403/// normal instruction or constant expression, or may have been
5404/// derived from an expression tree.
5405struct BinaryOp {
5406 unsigned Opcode;
5407 Value *LHS;
5408 Value *RHS;
5409 bool IsNSW = false;
5410 bool IsNUW = false;
5411
5412 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5413 /// constant expression.
5414 Operator *Op = nullptr;
5415
5416 explicit BinaryOp(Operator *Op)
5417 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5418 Op(Op) {
5419 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5420 IsNSW = OBO->hasNoSignedWrap();
5421 IsNUW = OBO->hasNoUnsignedWrap();
5422 }
5423 }
5424
5425 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5426 bool IsNUW = false)
5427 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5428};
5429
5430} // end anonymous namespace
5431
5432/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5433static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5434 AssumptionCache &AC,
5435 const DominatorTree &DT,
5436 const Instruction *CxtI) {
5437 auto *Op = dyn_cast<Operator>(V);
5438 if (!Op)
5439 return std::nullopt;
5440
5441 // Implementation detail: all the cleverness here should happen without
5442 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5443 // SCEV expressions when possible, and we should not break that.
5444
5445 switch (Op->getOpcode()) {
5446 case Instruction::Add:
5447 case Instruction::Sub:
5448 case Instruction::Mul:
5449 case Instruction::UDiv:
5450 case Instruction::URem:
5451 case Instruction::And:
5452 case Instruction::AShr:
5453 case Instruction::Shl:
5454 return BinaryOp(Op);
5455
5456 case Instruction::Or: {
5457 // Convert or disjoint into add nuw nsw.
5458 if (cast<PossiblyDisjointInst>(Op)->isDisjoint()) {
5459 BinaryOp BinOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5460 /*IsNSW=*/true, /*IsNUW=*/true);
5461 // Keep the reference to the original instruction so that we can later
5462 // check whether it can produce poison value or not.
5463 BinOp.Op = Op;
5464 return BinOp;
5465 }
5466 return BinaryOp(Op);
5467 }
5468
5469 case Instruction::Xor:
5470 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5471 // If the RHS of the xor is a signmask, then this is just an add.
5472 // Instcombine turns add of signmask into xor as a strength reduction step.
5473 if (RHSC->getValue().isSignMask())
5474 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5475 // Binary `xor` is a bit-wise `add`.
5476 if (V->getType()->isIntegerTy(1))
5477 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5478 return BinaryOp(Op);
5479
5480 case Instruction::LShr:
5481 // Turn logical shift right of a constant into a unsigned divide.
5482 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5483 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5484
5485 // If the shift count is not less than the bitwidth, the result of
5486 // the shift is undefined. Don't try to analyze it, because the
5487 // resolution chosen here may differ from the resolution chosen in
5488 // other parts of the compiler.
5489 if (SA->getValue().ult(BitWidth)) {
5490 Constant *X =
5491 ConstantInt::get(SA->getContext(),
5492 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5493 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5494 }
5495 }
5496 return BinaryOp(Op);
5497
5498 case Instruction::ExtractValue: {
5499 auto *EVI = cast<ExtractValueInst>(Op);
5500 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5501 break;
5502
5503 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5504 if (!WO)
5505 break;
5506
5507 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5508 bool Signed = WO->isSigned();
5509 // TODO: Should add nuw/nsw flags for mul as well.
5510 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5511 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5512
5513 // Now that we know that all uses of the arithmetic-result component of
5514 // CI are guarded by the overflow check, we can go ahead and pretend
5515 // that the arithmetic is non-overflowing.
5516 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5517 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5518 }
5519
5520 default:
5521 break;
5522 }
5523
5524 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5525 // semantics as a Sub, return a binary sub expression.
5526 if (auto *II = dyn_cast<IntrinsicInst>(V))
5527 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5528 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5529
5530 return std::nullopt;
5531}
5532
5533/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5534/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5535/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5536/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5537/// follows one of the following patterns:
5538/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5539/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5540/// If the SCEV expression of \p Op conforms with one of the expected patterns
5541/// we return the type of the truncation operation, and indicate whether the
5542/// truncated type should be treated as signed/unsigned by setting
5543/// \p Signed to true/false, respectively.
5544static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5545 bool &Signed, ScalarEvolution &SE) {
5546 // The case where Op == SymbolicPHI (that is, with no type conversions on
5547 // the way) is handled by the regular add recurrence creating logic and
5548 // would have already been triggered in createAddRecForPHI. Reaching it here
5549 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5550 // because one of the other operands of the SCEVAddExpr updating this PHI is
5551 // not invariant).
5552 //
5553 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5554 // this case predicates that allow us to prove that Op == SymbolicPHI will
5555 // be added.
5556 if (Op == SymbolicPHI)
5557 return nullptr;
5558
5559 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5560 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5561 if (SourceBits != NewBits)
5562 return nullptr;
5563
5564 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5565 Signed = true;
5566 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5567 }
5568 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5569 Signed = false;
5570 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5571 }
5572 return nullptr;
5573}
5574
5575static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5576 if (!PN->getType()->isIntegerTy())
5577 return nullptr;
5578 const Loop *L = LI.getLoopFor(PN->getParent());
5579 if (!L || L->getHeader() != PN->getParent())
5580 return nullptr;
5581 return L;
5582}
5583
5584// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5585// computation that updates the phi follows the following pattern:
5586// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5587// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5588// If so, try to see if it can be rewritten as an AddRecExpr under some
5589// Predicates. If successful, return them as a pair. Also cache the results
5590// of the analysis.
5591//
5592// Example usage scenario:
5593// Say the Rewriter is called for the following SCEV:
5594// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5595// where:
5596// %X = phi i64 (%Start, %BEValue)
5597// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5598// and call this function with %SymbolicPHI = %X.
5599//
5600// The analysis will find that the value coming around the backedge has
5601// the following SCEV:
5602// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5603// Upon concluding that this matches the desired pattern, the function
5604// will return the pair {NewAddRec, SmallPredsVec} where:
5605// NewAddRec = {%Start,+,%Step}
5606// SmallPredsVec = {P1, P2, P3} as follows:
5607// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5608// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5609// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5610// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5611// under the predicates {P1,P2,P3}.
5612// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5613 // PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5614//
5615// TODO's:
5616//
5617// 1) Extend the Induction descriptor to also support inductions that involve
5618// casts: When needed (namely, when we are called in the context of the
5619// vectorizer induction analysis), a Set of cast instructions will be
5620// populated by this method, and provided back to isInductionPHI. This is
5621// needed to allow the vectorizer to properly record them to be ignored by
5622// the cost model and to avoid vectorizing them (otherwise these casts,
5623// which are redundant under the runtime overflow checks, will be
5624// vectorized, which can be costly).
5625//
5626// 2) Support additional induction/PHISCEV patterns: We also want to support
5627// inductions where the sext-trunc / zext-trunc operations (partly) occur
5628// after the induction update operation (the induction increment):
5629//
5630// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5631// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5632//
5633// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5634// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5635//
5636// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5637std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5638ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5640
5641 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5642 // return an AddRec expression under some predicate.
5643
5644 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5645 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5646 assert(L && "Expecting an integer loop header phi");
5647
5648 // The loop may have multiple entrances or multiple exits; we can analyze
5649 // this phi as an addrec if it has a unique entry value and a unique
5650 // backedge value.
5651 Value *BEValueV = nullptr, *StartValueV = nullptr;
5652 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5653 Value *V = PN->getIncomingValue(i);
5654 if (L->contains(PN->getIncomingBlock(i))) {
5655 if (!BEValueV) {
5656 BEValueV = V;
5657 } else if (BEValueV != V) {
5658 BEValueV = nullptr;
5659 break;
5660 }
5661 } else if (!StartValueV) {
5662 StartValueV = V;
5663 } else if (StartValueV != V) {
5664 StartValueV = nullptr;
5665 break;
5666 }
5667 }
5668 if (!BEValueV || !StartValueV)
5669 return std::nullopt;
5670
5671 const SCEV *BEValue = getSCEV(BEValueV);
5672
5673 // If the value coming around the backedge is an add with the symbolic
5674 // value we just inserted, possibly with casts that we can ignore under
5675 // an appropriate runtime guard, then we found a simple induction variable!
5676 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5677 if (!Add)
5678 return std::nullopt;
5679
5680 // If there is a single occurrence of the symbolic value, possibly
5681 // casted, replace it with a recurrence.
5682 unsigned FoundIndex = Add->getNumOperands();
5683 Type *TruncTy = nullptr;
5684 bool Signed;
5685 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5686 if ((TruncTy =
5687 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5688 if (FoundIndex == e) {
5689 FoundIndex = i;
5690 break;
5691 }
5692
5693 if (FoundIndex == Add->getNumOperands())
5694 return std::nullopt;
5695
5696 // Create an add with everything but the specified operand.
5698 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5699 if (i != FoundIndex)
5700 Ops.push_back(Add->getOperand(i));
5701 const SCEV *Accum = getAddExpr(Ops);
5702
5703 // The runtime checks will not be valid if the step amount is
5704 // varying inside the loop.
5705 if (!isLoopInvariant(Accum, L))
5706 return std::nullopt;
5707
5708 // *** Part2: Create the predicates
5709
5710 // Analysis was successful: we have a phi-with-cast pattern for which we
5711 // can return an AddRec expression under the following predicates:
5712 //
5713 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5714 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5715 // P2: An Equal predicate that guarantees that
5716 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5717 // P3: An Equal predicate that guarantees that
5718 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5719 //
5720 // As we next prove, the above predicates guarantee that:
5721 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5722 //
5723 //
5724 // More formally, we want to prove that:
5725 // Expr(i+1) = Start + (i+1) * Accum
5726 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5727 //
5728 // Given that:
5729 // 1) Expr(0) = Start
5730 // 2) Expr(1) = Start + Accum
5731 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5732 // 3) Induction hypothesis (step i):
5733 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5734 //
5735 // Proof:
5736 // Expr(i+1) =
5737 // = Start + (i+1)*Accum
5738 // = (Start + i*Accum) + Accum
5739 // = Expr(i) + Accum
5740 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5741 // :: from step i
5742 //
5743 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5744 //
5745 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5746 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5747 // + Accum :: from P3
5748 //
5749 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5750 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5751 //
5752 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5753 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5754 //
5755 // By induction, the same applies to all iterations 1<=i<n:
5756 //
5757
5758 // Create a truncated addrec for which we will add a no overflow check (P1).
5759 const SCEV *StartVal = getSCEV(StartValueV);
5760 const SCEV *PHISCEV =
5761 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5762 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5763
5764 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5765 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5766 // will be constant.
5767 //
5768 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5769 // add P1.
5770 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5774 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5775 Predicates.push_back(AddRecPred);
5776 }
5777
5778 // Create the Equal Predicates P2,P3:
5779
5780 // It is possible that the predicates P2 and/or P3 are computable at
5781 // compile time due to StartVal and/or Accum being constants.
5782 // If either one is, then we can check that now and escape if either P2
5783 // or P3 is false.
5784
5785 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5786 // for each of StartVal and Accum
5787 auto getExtendedExpr = [&](const SCEV *Expr,
5788 bool CreateSignExtend) -> const SCEV * {
5789 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5790 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5791 const SCEV *ExtendedExpr =
5792 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5793 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5794 return ExtendedExpr;
5795 };
5796
5797 // Given:
5798 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5799 // = getExtendedExpr(Expr)
5800 // Determine whether the predicate P: Expr == ExtendedExpr
5801 // is known to be false at compile time
5802 auto PredIsKnownFalse = [&](const SCEV *Expr,
5803 const SCEV *ExtendedExpr) -> bool {
5804 return Expr != ExtendedExpr &&
5805 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5806 };
5807
5808 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5809 if (PredIsKnownFalse(StartVal, StartExtended)) {
5810 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5811 return std::nullopt;
5812 }
5813
5814 // The Step is always Signed (because the overflow checks are either
5815 // NSSW or NUSW)
5816 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5817 if (PredIsKnownFalse(Accum, AccumExtended)) {
5818 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5819 return std::nullopt;
5820 }
5821
5822 auto AppendPredicate = [&](const SCEV *Expr,
5823 const SCEV *ExtendedExpr) -> void {
5824 if (Expr != ExtendedExpr &&
5825 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5826 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5827 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5828 Predicates.push_back(Pred);
5829 }
5830 };
5831
5832 AppendPredicate(StartVal, StartExtended);
5833 AppendPredicate(Accum, AccumExtended);
5834
5835 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5836 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5837 // into NewAR if it will also add the runtime overflow checks specified in
5838 // Predicates.
5839 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5840
5841 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5842 std::make_pair(NewAR, Predicates);
5843 // Remember the result of the analysis for this SCEV at this locayyytion.
5844 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5845 return PredRewrite;
5846}
5847
5848std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5850 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5851 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5852 if (!L)
5853 return std::nullopt;
5854
5855 // Check to see if we already analyzed this PHI.
5856 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5857 if (I != PredicatedSCEVRewrites.end()) {
5858 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5859 I->second;
5860 // Analysis was done before and failed to create an AddRec:
5861 if (Rewrite.first == SymbolicPHI)
5862 return std::nullopt;
5863 // Analysis was done before and succeeded to create an AddRec under
5864 // a predicate:
5865 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5866 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5867 return Rewrite;
5868 }
5869
5870 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5871 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5872
5873 // Record in the cache that the analysis failed
5874 if (!Rewrite) {
5876 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5877 return std::nullopt;
5878 }
5879
5880 return Rewrite;
5881}
5882
5883// FIXME: This utility is currently required because the Rewriter currently
5884// does not rewrite this expression:
5885// {0, +, (sext ix (trunc iy to ix) to iy)}
5886// into {0, +, %step},
5887// even when the following Equal predicate exists:
5888// "%step == (sext ix (trunc iy to ix) to iy)".
5890 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2,
5891 ArrayRef<const SCEVPredicate *> NoWrapPreds) const {
5892 if (AR1 == AR2)
5893 return true;
5894
5895 SCEVUnionPredicate NoWrapUnionPred(NoWrapPreds, SE);
5896 SCEVUnionPredicate AllPreds = Preds->getUnionWith(&NoWrapUnionPred, SE);
5897 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5898 if (Expr1 != Expr2 &&
5899 !AllPreds.implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5900 !AllPreds.implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5901 return false;
5902 return true;
5903 };
5904
5905 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5906 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5907 return false;
5908 return true;
5909}
5910
5911/// A helper function for createAddRecFromPHI to handle simple cases.
5912///
5913/// This function tries to find an AddRec expression for the simplest (yet most
5914/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5915/// If it fails, createAddRecFromPHI will use a more general, but slow,
5916/// technique for finding the AddRec expression.
5917const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5918 Value *BEValueV,
5919 Value *StartValueV) {
5920 const Loop *L = LI.getLoopFor(PN->getParent());
5921 assert(L && L->getHeader() == PN->getParent());
5922 assert(BEValueV && StartValueV);
5923
5924 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5925 if (!BO)
5926 return nullptr;
5927
5928 if (BO->Opcode != Instruction::Add)
5929 return nullptr;
5930
5931 const SCEV *Accum = nullptr;
5932 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5933 Accum = getSCEV(BO->RHS);
5934 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5935 Accum = getSCEV(BO->LHS);
5936
5937 if (!Accum)
5938 return nullptr;
5939
5941 if (BO->IsNUW)
5942 Flags = setFlags(Flags, SCEV::FlagNUW);
5943 if (BO->IsNSW)
5944 Flags = setFlags(Flags, SCEV::FlagNSW);
5945
5946 const SCEV *StartVal = getSCEV(StartValueV);
5947 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5948 insertValueToMap(PN, PHISCEV);
5949
5950 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5951 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5952 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
5953 }
5954
5955 // We can add Flags to the post-inc expression only if we
5956 // know that it is *undefined behavior* for BEValueV to
5957 // overflow.
5958 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5959 assert(isLoopInvariant(Accum, L) &&
5960 "Accum is defined outside L, but is not invariant?");
5961 if (isAddRecNeverPoison(BEInst, L))
5962 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5963 }
5964
5965 return PHISCEV;
5966}
5967
5968const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5969 const Loop *L = LI.getLoopFor(PN->getParent());
5970 if (!L || L->getHeader() != PN->getParent())
5971 return nullptr;
5972
5973 // The loop may have multiple entrances or multiple exits; we can analyze
5974 // this phi as an addrec if it has a unique entry value and a unique
5975 // backedge value.
5976 Value *BEValueV = nullptr, *StartValueV = nullptr;
5977 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5978 Value *V = PN->getIncomingValue(i);
5979 if (L->contains(PN->getIncomingBlock(i))) {
5980 if (!BEValueV) {
5981 BEValueV = V;
5982 } else if (BEValueV != V) {
5983 BEValueV = nullptr;
5984 break;
5985 }
5986 } else if (!StartValueV) {
5987 StartValueV = V;
5988 } else if (StartValueV != V) {
5989 StartValueV = nullptr;
5990 break;
5991 }
5992 }
5993 if (!BEValueV || !StartValueV)
5994 return nullptr;
5995
5996 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5997 "PHI node already processed?");
5998
5999 // First, try to find AddRec expression without creating a fictituos symbolic
6000 // value for PN.
6001 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
6002 return S;
6003
6004 // Handle PHI node value symbolically.
6005 const SCEV *SymbolicName = getUnknown(PN);
6006 insertValueToMap(PN, SymbolicName);
6007
6008 // Using this symbolic name for the PHI, analyze the value coming around
6009 // the back-edge.
6010 const SCEV *BEValue = getSCEV(BEValueV);
6011
6012 // NOTE: If BEValue is loop invariant, we know that the PHI node just
6013 // has a special value for the first iteration of the loop.
6014
6015 // If the value coming around the backedge is an add with the symbolic
6016 // value we just inserted, then we found a simple induction variable!
6017 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
6018 // If there is a single occurrence of the symbolic value, replace it
6019 // with a recurrence.
6020 unsigned FoundIndex = Add->getNumOperands();
6021 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6022 if (Add->getOperand(i) == SymbolicName)
6023 if (FoundIndex == e) {
6024 FoundIndex = i;
6025 break;
6026 }
6027
6028 if (FoundIndex != Add->getNumOperands()) {
6029 // Create an add with everything but the specified operand.
6031 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6032 if (i != FoundIndex)
6033 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6034 L, *this));
6035 const SCEV *Accum = getAddExpr(Ops);
6036
6037 // This is not a valid addrec if the step amount is varying each
6038 // loop iteration, but is not itself an addrec in this loop.
6039 if (isLoopInvariant(Accum, L) ||
6040 (isa<SCEVAddRecExpr>(Accum) &&
6041 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6043
6044 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6045 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6046 if (BO->IsNUW)
6047 Flags = setFlags(Flags, SCEV::FlagNUW);
6048 if (BO->IsNSW)
6049 Flags = setFlags(Flags, SCEV::FlagNSW);
6050 }
6051 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6052 if (GEP->getOperand(0) == PN) {
6053 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6054 // If the increment has any nowrap flags, then we know the address
6055 // space cannot be wrapped around.
6056 if (NW != GEPNoWrapFlags::none())
6057 Flags = setFlags(Flags, SCEV::FlagNW);
6058 // If the GEP is nuw or nusw with non-negative offset, we know that
6059 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6060 // offset is treated as signed, while the base is unsigned.
6061 if (NW.hasNoUnsignedWrap() ||
6063 Flags = setFlags(Flags, SCEV::FlagNUW);
6064 }
6065
6066 // We cannot transfer nuw and nsw flags from subtraction
6067 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6068 // for instance.
6069 }
6070
6071 const SCEV *StartVal = getSCEV(StartValueV);
6072 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6073
6074 // Okay, for the entire analysis of this edge we assumed the PHI
6075 // to be symbolic. We now need to go back and purge all of the
6076 // entries for the scalars that use the symbolic expression.
6077 forgetMemoizedResults({SymbolicName});
6078 insertValueToMap(PN, PHISCEV);
6079
6080 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
6082 const_cast<SCEVAddRecExpr *>(AR),
6083 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
6084 }
6085
6086 // We can add Flags to the post-inc expression only if we
6087 // know that it is *undefined behavior* for BEValueV to
6088 // overflow.
6089 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6090 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6091 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6092
6093 return PHISCEV;
6094 }
6095 }
6096 } else {
6097 // Otherwise, this could be a loop like this:
6098 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6099 // In this case, j = {1,+,1} and BEValue is j.
6100 // Because the other in-value of i (0) fits the evolution of BEValue
6101 // i really is an addrec evolution.
6102 //
6103 // We can generalize this saying that i is the shifted value of BEValue
6104 // by one iteration:
6105 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6106
6107 // Do not allow refinement in rewriting of BEValue.
6108 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6109 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6110 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6111 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6112 const SCEV *StartVal = getSCEV(StartValueV);
6113 if (Start == StartVal) {
6114 // Okay, for the entire analysis of this edge we assumed the PHI
6115 // to be symbolic. We now need to go back and purge all of the
6116 // entries for the scalars that use the symbolic expression.
6117 forgetMemoizedResults({SymbolicName});
6118 insertValueToMap(PN, Shifted);
6119 return Shifted;
6120 }
6121 }
6122 }
6123
6124 // Remove the temporary PHI node SCEV that has been inserted while intending
6125 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6126 // as it will prevent later (possibly simpler) SCEV expressions to be added
6127 // to the ValueExprMap.
6128 eraseValueFromMap(PN);
6129
6130 return nullptr;
6131}
6132
6133// Try to match a control flow sequence that branches out at BI and merges back
6134// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6135// match.
6137 Value *&C, Value *&LHS, Value *&RHS) {
6138 C = BI->getCondition();
6139
6140 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6141 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6142
6143 Use &LeftUse = Merge->getOperandUse(0);
6144 Use &RightUse = Merge->getOperandUse(1);
6145
6146 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6147 LHS = LeftUse;
6148 RHS = RightUse;
6149 return true;
6150 }
6151
6152 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6153 LHS = RightUse;
6154 RHS = LeftUse;
6155 return true;
6156 }
6157
6158 return false;
6159}
6160
6162 Value *&Cond, Value *&LHS,
6163 Value *&RHS) {
6164 auto IsReachable =
6165 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6166 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6167 // Try to match
6168 //
6169 // br %cond, label %left, label %right
6170 // left:
6171 // br label %merge
6172 // right:
6173 // br label %merge
6174 // merge:
6175 // V = phi [ %x, %left ], [ %y, %right ]
6176 //
6177 // as "select %cond, %x, %y"
6178
6179 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6180 assert(IDom && "At least the entry block should dominate PN");
6181
6182 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6183 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6184 }
6185 return false;
6186}
6187
6188const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6189 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6190 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6193 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6194
6195 return nullptr;
6196}
6197
6199 BinaryOperator *CommonInst = nullptr;
6200 // Check if instructions are identical.
6201 for (Value *Incoming : PN->incoming_values()) {
6202 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6203 if (!IncomingInst)
6204 return nullptr;
6205 if (CommonInst) {
6206 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6207 return nullptr; // Not identical, give up
6208 } else {
6209 // Remember binary operator
6210 CommonInst = IncomingInst;
6211 }
6212 }
6213 return CommonInst;
6214}
6215
6216/// Returns SCEV for the first operand of a phi if all phi operands have
6217/// identical opcodes and operands
6218/// eg.
6219/// a: %add = %a + %b
6220/// br %c
6221/// b: %add1 = %a + %b
6222/// br %c
6223/// c: %phi = phi [%add, a], [%add1, b]
6224/// scev(%phi) => scev(%add)
6225const SCEV *
6226ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6227 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6228 if (!CommonInst)
6229 return nullptr;
6230
6231 // Check if SCEV exprs for instructions are identical.
6232 const SCEV *CommonSCEV = getSCEV(CommonInst);
6233 bool SCEVExprsIdentical =
6235 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6236 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6237}
6238
6239const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6240 if (const SCEV *S = createAddRecFromPHI(PN))
6241 return S;
6242
6243 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6244 // phi node for X.
6245 if (Value *V = simplifyInstruction(
6246 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6247 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6248 return getSCEV(V);
6249
6250 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6251 return S;
6252
6253 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6254 return S;
6255
6256 // If it's not a loop phi, we can't handle it yet.
6257 return getUnknown(PN);
6258}
6259
6260bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6261 SCEVTypes RootKind) {
6262 struct FindClosure {
6263 const SCEV *OperandToFind;
6264 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6265 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6266
6267 bool Found = false;
6268
6269 bool canRecurseInto(SCEVTypes Kind) const {
6270 // We can only recurse into the SCEV expression of the same effective type
6271 // as the type of our root SCEV expression, and into zero-extensions.
6272 return RootKind == Kind || NonSequentialRootKind == Kind ||
6273 scZeroExtend == Kind;
6274 };
6275
6276 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6277 : OperandToFind(OperandToFind), RootKind(RootKind),
6278 NonSequentialRootKind(
6280 RootKind)) {}
6281
6282 bool follow(const SCEV *S) {
6283 Found = S == OperandToFind;
6284
6285 return !isDone() && canRecurseInto(S->getSCEVType());
6286 }
6287
6288 bool isDone() const { return Found; }
6289 };
6290
6291 FindClosure FC(OperandToFind, RootKind);
6292 visitAll(Root, FC);
6293 return FC.Found;
6294}
6295
6296std::optional<const SCEV *>
6297ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6298 ICmpInst *Cond,
6299 Value *TrueVal,
6300 Value *FalseVal) {
6301 // Try to match some simple smax or umax patterns.
6302 auto *ICI = Cond;
6303
6304 Value *LHS = ICI->getOperand(0);
6305 Value *RHS = ICI->getOperand(1);
6306
6307 switch (ICI->getPredicate()) {
6308 case ICmpInst::ICMP_SLT:
6309 case ICmpInst::ICMP_SLE:
6310 case ICmpInst::ICMP_ULT:
6311 case ICmpInst::ICMP_ULE:
6312 std::swap(LHS, RHS);
6313 [[fallthrough]];
6314 case ICmpInst::ICMP_SGT:
6315 case ICmpInst::ICMP_SGE:
6316 case ICmpInst::ICMP_UGT:
6317 case ICmpInst::ICMP_UGE:
6318 // a > b ? a+x : b+x -> max(a, b)+x
6319 // a > b ? b+x : a+x -> min(a, b)+x
6321 bool Signed = ICI->isSigned();
6322 const SCEV *LA = getSCEV(TrueVal);
6323 const SCEV *RA = getSCEV(FalseVal);
6324 const SCEV *LS = getSCEV(LHS);
6325 const SCEV *RS = getSCEV(RHS);
6326 if (LA->getType()->isPointerTy()) {
6327 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6328 // Need to make sure we can't produce weird expressions involving
6329 // negated pointers.
6330 if (LA == LS && RA == RS)
6331 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6332 if (LA == RS && RA == LS)
6333 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6334 }
6335 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6336 if (Op->getType()->isPointerTy()) {
6339 return Op;
6340 }
6341 if (Signed)
6342 Op = getNoopOrSignExtend(Op, Ty);
6343 else
6344 Op = getNoopOrZeroExtend(Op, Ty);
6345 return Op;
6346 };
6347 LS = CoerceOperand(LS);
6348 RS = CoerceOperand(RS);
6350 break;
6351 const SCEV *LDiff = getMinusSCEV(LA, LS);
6352 const SCEV *RDiff = getMinusSCEV(RA, RS);
6353 if (LDiff == RDiff)
6354 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6355 LDiff);
6356 LDiff = getMinusSCEV(LA, RS);
6357 RDiff = getMinusSCEV(RA, LS);
6358 if (LDiff == RDiff)
6359 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6360 LDiff);
6361 }
6362 break;
6363 case ICmpInst::ICMP_NE:
6364 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6365 std::swap(TrueVal, FalseVal);
6366 [[fallthrough]];
6367 case ICmpInst::ICMP_EQ:
6368 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6371 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6372 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6373 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6374 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6375 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6376 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6377 return getAddExpr(getUMaxExpr(X, C), Y);
6378 }
6379 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6380 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6381 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6382 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6384 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6385 const SCEV *X = getSCEV(LHS);
6386 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6387 X = ZExt->getOperand();
6388 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6389 const SCEV *FalseValExpr = getSCEV(FalseVal);
6390 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6391 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6392 /*Sequential=*/true);
6393 }
6394 }
6395 break;
6396 default:
6397 break;
6398 }
6399
6400 return std::nullopt;
6401}
6402
6403static std::optional<const SCEV *>
6405 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6406 assert(CondExpr->getType()->isIntegerTy(1) &&
6407 TrueExpr->getType() == FalseExpr->getType() &&
6408 TrueExpr->getType()->isIntegerTy(1) &&
6409 "Unexpected operands of a select.");
6410
6411 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6412 // --> C + (umin_seq cond, x - C)
6413 //
6414 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6415 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6416 // --> C + (umin_seq ~cond, x - C)
6417
6418 // FIXME: while we can't legally model the case where both of the hands
6419 // are fully variable, we only require that the *difference* is constant.
6420 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6421 return std::nullopt;
6422
6423 const SCEV *X, *C;
6424 if (isa<SCEVConstant>(TrueExpr)) {
6425 CondExpr = SE->getNotSCEV(CondExpr);
6426 X = FalseExpr;
6427 C = TrueExpr;
6428 } else {
6429 X = TrueExpr;
6430 C = FalseExpr;
6431 }
6432 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6433 /*Sequential=*/true));
6434}
6435
6436static std::optional<const SCEV *>
6438 Value *FalseVal) {
6439 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6440 return std::nullopt;
6441
6442 const auto *SECond = SE->getSCEV(Cond);
6443 const auto *SETrue = SE->getSCEV(TrueVal);
6444 const auto *SEFalse = SE->getSCEV(FalseVal);
6445 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6446}
6447
6448const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6449 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6450 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6451 assert(TrueVal->getType() == FalseVal->getType() &&
6452 V->getType() == TrueVal->getType() &&
6453 "Types of select hands and of the result must match.");
6454
6455 // For now, only deal with i1-typed `select`s.
6456 if (!V->getType()->isIntegerTy(1))
6457 return getUnknown(V);
6458
6459 if (std::optional<const SCEV *> S =
6460 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6461 return *S;
6462
6463 return getUnknown(V);
6464}
6465
6466const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6467 Value *TrueVal,
6468 Value *FalseVal) {
6469 // Handle "constant" branch or select. This can occur for instance when a
6470 // loop pass transforms an inner loop and moves on to process the outer loop.
6471 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6472 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6473
6474 if (auto *I = dyn_cast<Instruction>(V)) {
6475 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6476 if (std::optional<const SCEV *> S =
6477 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6478 TrueVal, FalseVal))
6479 return *S;
6480 }
6481 }
6482
6483 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6484}
6485
6486/// Expand GEP instructions into add and multiply operations. This allows them
6487/// to be analyzed by regular SCEV code.
6488const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6489 assert(GEP->getSourceElementType()->isSized() &&
6490 "GEP source element type must be sized");
6491
6492 SmallVector<SCEVUse, 4> IndexExprs;
6493 for (Value *Index : GEP->indices())
6494 IndexExprs.push_back(getSCEV(Index));
6495 return getGEPExpr(GEP, IndexExprs);
6496}
6497
6498APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6499 const Instruction *CtxI) {
6500 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6501 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6502 return TrailingZeros >= BitWidth
6504 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6505 };
6506 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6507 // The result is GCD of all operands results.
6508 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6509 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6511 Res, getConstantMultiple(N->getOperand(I), CtxI));
6512 return Res;
6513 };
6514
6515 switch (S->getSCEVType()) {
6516 case scConstant:
6517 return cast<SCEVConstant>(S)->getAPInt();
6518 case scPtrToAddr:
6519 case scPtrToInt:
6520 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6521 case scUDivExpr:
6522 case scVScale:
6523 return APInt(BitWidth, 1);
6524 case scTruncate: {
6525 // Only multiples that are a power of 2 will hold after truncation.
6526 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6527 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6528 return GetShiftedByZeros(TZ);
6529 }
6530 case scZeroExtend: {
6531 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6532 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6533 }
6534 case scSignExtend: {
6535 // Only multiples that are a power of 2 will hold after sext.
6536 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6537 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6538 return GetShiftedByZeros(TZ);
6539 }
6540 case scMulExpr: {
6541 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6542 if (M->hasNoUnsignedWrap()) {
6543 // The result is the product of all operand results.
6544 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6545 for (const SCEV *Operand : M->operands().drop_front())
6546 Res = Res * getConstantMultiple(Operand, CtxI);
6547 return Res;
6548 }
6549
6550 // If there are no wrap guarentees, find the trailing zeros, which is the
6551 // sum of trailing zeros for all its operands.
6552 uint32_t TZ = 0;
6553 for (const SCEV *Operand : M->operands())
6554 TZ += getMinTrailingZeros(Operand, CtxI);
6555 return GetShiftedByZeros(TZ);
6556 }
6557 case scAddExpr:
6558 case scAddRecExpr: {
6559 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6560 if (N->hasNoUnsignedWrap())
6561 return GetGCDMultiple(N);
6562 // Find the trailing bits, which is the minimum of its operands.
6563 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6564 for (const SCEV *Operand : N->operands().drop_front())
6565 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6566 return GetShiftedByZeros(TZ);
6567 }
6568 case scUMaxExpr:
6569 case scSMaxExpr:
6570 case scUMinExpr:
6571 case scSMinExpr:
6573 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6574 case scUnknown: {
6575 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6576 // the point their underlying IR instruction has been defined. If CtxI was
6577 // not provided, use:
6578 // * the first instruction in the entry block if it is an argument
6579 // * the instruction itself otherwise.
6580 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6581 if (!CtxI) {
6582 if (isa<Argument>(U->getValue()))
6583 CtxI = &*F.getEntryBlock().begin();
6584 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6585 CtxI = I;
6586 }
6587 unsigned Known =
6588 computeKnownBits(U->getValue(),
6589 SimplifyQuery(getDataLayout(), &DT, &AC, CtxI)
6590 .allowEphemerals(true))
6591 .countMinTrailingZeros();
6592 return GetShiftedByZeros(Known);
6593 }
6594 case scCouldNotCompute:
6595 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6596 }
6597 llvm_unreachable("Unknown SCEV kind!");
6598}
6599
6601 const Instruction *CtxI) {
6602 // Skip looking up and updating the cache if there is a context instruction,
6603 // as the result will only be valid in the specified context.
6604 if (CtxI)
6605 return getConstantMultipleImpl(S, CtxI);
6606
6607 auto I = ConstantMultipleCache.find(S);
6608 if (I != ConstantMultipleCache.end())
6609 return I->second;
6610
6611 APInt Result = getConstantMultipleImpl(S, CtxI);
6612 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6613 assert(InsertPair.second && "Should insert a new key");
6614 return InsertPair.first->second;
6615}
6616
6618 APInt Multiple = getConstantMultiple(S);
6619 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6620}
6621
6623 const Instruction *CtxI) {
6624 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6625 (unsigned)getTypeSizeInBits(S->getType()));
6626}
6627
6628/// Helper method to assign a range to V from metadata present in the IR.
6629static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6631 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6632 return getConstantRangeFromMetadata(*MD);
6633 if (const auto *CB = dyn_cast<CallBase>(V))
6634 if (std::optional<ConstantRange> Range = CB->getRange())
6635 return Range;
6636 }
6637 if (auto *A = dyn_cast<Argument>(V))
6638 if (std::optional<ConstantRange> Range = A->getRange())
6639 return Range;
6640
6641 return std::nullopt;
6642}
6643
6645 SCEV::NoWrapFlags Flags) {
6646 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6647 AddRec->setNoWrapFlags(Flags);
6648 UnsignedRanges.erase(AddRec);
6649 SignedRanges.erase(AddRec);
6650 ConstantMultipleCache.erase(AddRec);
6651 }
6652}
6653
6654ConstantRange ScalarEvolution::
6655getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6656 const DataLayout &DL = getDataLayout();
6657
6658 unsigned BitWidth = getTypeSizeInBits(U->getType());
6659 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6660
6661 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6662 // use information about the trip count to improve our available range. Note
6663 // that the trip count independent cases are already handled by known bits.
6664 // WARNING: The definition of recurrence used here is subtly different than
6665 // the one used by AddRec (and thus most of this file). Step is allowed to
6666 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6667 // and other addrecs in the same loop (for non-affine addrecs). The code
6668 // below intentionally handles the case where step is not loop invariant.
6669 auto *P = dyn_cast<PHINode>(U->getValue());
6670 if (!P)
6671 return FullSet;
6672
6673 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6674 // even the values that are not available in these blocks may come from them,
6675 // and this leads to false-positive recurrence test.
6676 for (auto *Pred : predecessors(P->getParent()))
6677 if (!DT.isReachableFromEntry(Pred))
6678 return FullSet;
6679
6680 BinaryOperator *BO;
6681 Value *Start, *Step;
6682 if (!matchSimpleRecurrence(P, BO, Start, Step))
6683 return FullSet;
6684
6685 // If we found a recurrence in reachable code, we must be in a loop. Note
6686 // that BO might be in some subloop of L, and that's completely okay.
6687 auto *L = LI.getLoopFor(P->getParent());
6688 assert(L && L->getHeader() == P->getParent());
6689 if (!L->contains(BO->getParent()))
6690 // NOTE: This bailout should be an assert instead. However, asserting
6691 // the condition here exposes a case where LoopFusion is querying SCEV
6692 // with malformed loop information during the midst of the transform.
6693 // There doesn't appear to be an obvious fix, so for the moment bailout
6694 // until the caller issue can be fixed. PR49566 tracks the bug.
6695 return FullSet;
6696
6697 // TODO: Extend to other opcodes such as mul, and div
6698 switch (BO->getOpcode()) {
6699 default:
6700 return FullSet;
6701 case Instruction::AShr:
6702 case Instruction::LShr:
6703 case Instruction::Shl:
6704 break;
6705 };
6706
6707 if (BO->getOperand(0) != P)
6708 // TODO: Handle the power function forms some day.
6709 return FullSet;
6710
6711 unsigned TC = getSmallConstantMaxTripCount(L);
6712 if (!TC || TC >= BitWidth)
6713 return FullSet;
6714
6715 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6716 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6717 assert(KnownStart.getBitWidth() == BitWidth &&
6718 KnownStep.getBitWidth() == BitWidth);
6719
6720 // Compute total shift amount, being careful of overflow and bitwidths.
6721 auto MaxShiftAmt = KnownStep.getMaxValue();
6722 APInt TCAP(BitWidth, TC-1);
6723 bool Overflow = false;
6724 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6725 if (Overflow)
6726 return FullSet;
6727
6728 switch (BO->getOpcode()) {
6729 default:
6730 llvm_unreachable("filtered out above");
6731 case Instruction::AShr: {
6732 // For each ashr, three cases:
6733 // shift = 0 => unchanged value
6734 // saturation => 0 or -1
6735 // other => a value closer to zero (of the same sign)
6736 // Thus, the end value is closer to zero than the start.
6737 auto KnownEnd = KnownBits::ashr(KnownStart,
6738 KnownBits::makeConstant(TotalShift));
6739 if (KnownStart.isNonNegative())
6740 // Analogous to lshr (simply not yet canonicalized)
6741 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6742 KnownStart.getMaxValue() + 1);
6743 if (KnownStart.isNegative())
6744 // End >=u Start && End <=s Start
6745 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6746 KnownEnd.getMaxValue() + 1);
6747 break;
6748 }
6749 case Instruction::LShr: {
6750 // For each lshr, three cases:
6751 // shift = 0 => unchanged value
6752 // saturation => 0
6753 // other => a smaller positive number
6754 // Thus, the low end of the unsigned range is the last value produced.
6755 auto KnownEnd = KnownBits::lshr(KnownStart,
6756 KnownBits::makeConstant(TotalShift));
6757 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6758 KnownStart.getMaxValue() + 1);
6759 }
6760 case Instruction::Shl: {
6761 // Iff no bits are shifted out, value increases on every shift.
6762 auto KnownEnd = KnownBits::shl(KnownStart,
6763 KnownBits::makeConstant(TotalShift));
6764 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6765 return ConstantRange(KnownStart.getMinValue(),
6766 KnownEnd.getMaxValue() + 1);
6767 break;
6768 }
6769 };
6770 return FullSet;
6771}
6772
6773// The goal of this function is to check if recursively visiting the operands
6774// of this PHI might lead to an infinite loop. If we do see such a loop,
6775// there's no good way to break it, so we avoid analyzing such cases.
6776//
6777// getRangeRef previously used a visited set to avoid infinite loops, but this
6778// caused other issues: the result was dependent on the order of getRangeRef
6779// calls, and the interaction with createSCEVIter could cause a stack overflow
6780// in some cases (see issue #148253).
6781//
6782// FIXME: The way this is implemented is overly conservative; this checks
6783// for a few obviously safe patterns, but anything that doesn't lead to
6784// recursion is fine.
6786 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6788 return true;
6789
6790 if (all_of(PHI->operands(),
6791 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6792 return true;
6793
6794 return false;
6795}
6796
6797const ConstantRange &
6798ScalarEvolution::getRangeRefIter(const SCEV *S,
6799 ScalarEvolution::RangeSignHint SignHint) {
6800 DenseMap<const SCEV *, ConstantRange> &Cache =
6801 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6802 : SignedRanges;
6803 SmallVector<SCEVUse> WorkList;
6804 SmallPtrSet<const SCEV *, 8> Seen;
6805
6806 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6807 // SCEVUnknown PHI node.
6808 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6809 if (!Seen.insert(Expr).second)
6810 return;
6811 if (Cache.contains(Expr))
6812 return;
6813 switch (Expr->getSCEVType()) {
6814 case scUnknown:
6816 break;
6817 [[fallthrough]];
6818 case scConstant:
6819 case scVScale:
6820 case scTruncate:
6821 case scZeroExtend:
6822 case scSignExtend:
6823 case scPtrToAddr:
6824 case scPtrToInt:
6825 case scAddExpr:
6826 case scMulExpr:
6827 case scUDivExpr:
6828 case scAddRecExpr:
6829 case scUMaxExpr:
6830 case scSMaxExpr:
6831 case scUMinExpr:
6832 case scSMinExpr:
6834 WorkList.push_back(Expr);
6835 break;
6836 case scCouldNotCompute:
6837 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6838 }
6839 };
6840 AddToWorklist(S);
6841
6842 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6843 for (unsigned I = 0; I != WorkList.size(); ++I) {
6844 const SCEV *P = WorkList[I];
6845 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6846 // If it is not a `SCEVUnknown`, just recurse into operands.
6847 if (!UnknownS) {
6848 for (const SCEV *Op : P->operands())
6849 AddToWorklist(Op);
6850 continue;
6851 }
6852 // `SCEVUnknown`'s require special treatment.
6853 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6854 if (!RangeRefPHIAllowedOperands(DT, P))
6855 continue;
6856 for (auto &Op : reverse(P->operands()))
6857 AddToWorklist(getSCEV(Op));
6858 }
6859 }
6860
6861 if (!WorkList.empty()) {
6862 // Use getRangeRef to compute ranges for items in the worklist in reverse
6863 // order. This will force ranges for earlier operands to be computed before
6864 // their users in most cases.
6865 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6866 getRangeRef(P, SignHint);
6867 }
6868 }
6869
6870 return getRangeRef(S, SignHint, 0);
6871}
6872
6873/// Determine the range for a particular SCEV. If SignHint is
6874/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6875/// with a "cleaner" unsigned (resp. signed) representation.
6876const ConstantRange &ScalarEvolution::getRangeRef(
6877 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6878 DenseMap<const SCEV *, ConstantRange> &Cache =
6879 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6880 : SignedRanges;
6882 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6884
6885 // See if we've computed this range already.
6886 auto I = Cache.find(S);
6887 if (I != Cache.end())
6888 return I->second;
6889
6890 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6891 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6892
6893 // Switch to iteratively computing the range for S, if it is part of a deeply
6894 // nested expression.
6896 return getRangeRefIter(S, SignHint);
6897
6898 unsigned BitWidth = getTypeSizeInBits(S->getType());
6899 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6900 using OBO = OverflowingBinaryOperator;
6901
6902 // If the value has known zeros, the maximum value will have those known zeros
6903 // as well.
6904 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6905 APInt Multiple = getNonZeroConstantMultiple(S);
6906 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6907 if (!Remainder.isZero())
6908 ConservativeResult =
6909 ConstantRange(APInt::getMinValue(BitWidth),
6910 APInt::getMaxValue(BitWidth) - Remainder + 1);
6911 }
6912 else {
6913 uint32_t TZ = getMinTrailingZeros(S);
6914 if (TZ != 0) {
6915 ConservativeResult = ConstantRange(
6917 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6918 }
6919 }
6920
6921 switch (S->getSCEVType()) {
6922 case scConstant:
6923 llvm_unreachable("Already handled above.");
6924 case scVScale:
6925 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6926 case scTruncate: {
6927 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6928 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6929 return setRange(
6930 Trunc, SignHint,
6931 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6932 }
6933 case scZeroExtend: {
6934 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6935 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6936 return setRange(
6937 ZExt, SignHint,
6938 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6939 }
6940 case scSignExtend: {
6941 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6942 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6943 return setRange(
6944 SExt, SignHint,
6945 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6946 }
6947 case scPtrToAddr:
6948 case scPtrToInt: {
6949 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6950 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6951 return setRange(Cast, SignHint, X);
6952 }
6953 case scAddExpr: {
6954 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6955 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6956 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6957 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6958 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6959 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6960 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6961 ConservativeResult =
6962 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6963 }
6964 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6965 unsigned WrapType = OBO::AnyWrap;
6966 if (Add->hasNoSignedWrap())
6967 WrapType |= OBO::NoSignedWrap;
6968 if (Add->hasNoUnsignedWrap())
6969 WrapType |= OBO::NoUnsignedWrap;
6970 for (const SCEV *Op : drop_begin(Add->operands()))
6971 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6972 RangeType);
6973 return setRange(Add, SignHint,
6974 ConservativeResult.intersectWith(X, RangeType));
6975 }
6976 case scMulExpr: {
6977 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6978 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6979 for (const SCEV *Op : drop_begin(Mul->operands()))
6980 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6981 return setRange(Mul, SignHint,
6982 ConservativeResult.intersectWith(X, RangeType));
6983 }
6984 case scUDivExpr: {
6985 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6986 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6987 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6988 return setRange(UDiv, SignHint,
6989 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6990 }
6991 case scAddRecExpr: {
6992 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6993 // If there's no unsigned wrap, the value will never be less than its
6994 // initial value.
6995 if (AddRec->hasNoUnsignedWrap()) {
6996 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6997 if (!UnsignedMinValue.isZero())
6998 ConservativeResult = ConservativeResult.intersectWith(
6999 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
7000 }
7001
7002 // If there's no signed wrap, and all the operands except initial value have
7003 // the same sign or zero, the value won't ever be:
7004 // 1: smaller than initial value if operands are non negative,
7005 // 2: bigger than initial value if operands are non positive.
7006 // For both cases, value can not cross signed min/max boundary.
7007 if (AddRec->hasNoSignedWrap()) {
7008 bool AllNonNeg = true;
7009 bool AllNonPos = true;
7010 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
7011 if (!isKnownNonNegative(AddRec->getOperand(i)))
7012 AllNonNeg = false;
7013 if (!isKnownNonPositive(AddRec->getOperand(i)))
7014 AllNonPos = false;
7015 }
7016 if (AllNonNeg)
7017 ConservativeResult = ConservativeResult.intersectWith(
7020 RangeType);
7021 else if (AllNonPos)
7022 ConservativeResult = ConservativeResult.intersectWith(
7024 getSignedRangeMax(AddRec->getStart()) +
7025 1),
7026 RangeType);
7027 }
7028
7029 // TODO: non-affine addrec
7030 if (AddRec->isAffine()) {
7031 const SCEV *MaxBEScev =
7033 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7034 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7035
7036 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7037 // MaxBECount's active bits are all <= AddRec's bit width.
7038 if (MaxBECount.getBitWidth() > BitWidth &&
7039 MaxBECount.getActiveBits() <= BitWidth)
7040 MaxBECount = MaxBECount.trunc(BitWidth);
7041 else if (MaxBECount.getBitWidth() < BitWidth)
7042 MaxBECount = MaxBECount.zext(BitWidth);
7043
7044 if (MaxBECount.getBitWidth() == BitWidth) {
7045 auto RangeFromAffine = getRangeForAffineAR(
7046 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7047 ConservativeResult =
7048 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7049
7050 auto RangeFromFactoring = getRangeViaFactoring(
7051 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7052 ConservativeResult =
7053 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7054 }
7055 }
7056
7057 // Now try symbolic BE count and more powerful methods.
7059 const SCEV *SymbolicMaxBECount =
7061 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7062 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7063 AddRec->hasNoSelfWrap()) {
7064 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7065 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7066 ConservativeResult =
7067 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7068 }
7069 }
7070 }
7071
7072 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7073 }
7074 case scUMaxExpr:
7075 case scSMaxExpr:
7076 case scUMinExpr:
7077 case scSMinExpr:
7078 case scSequentialUMinExpr: {
7080 switch (S->getSCEVType()) {
7081 case scUMaxExpr:
7082 ID = Intrinsic::umax;
7083 break;
7084 case scSMaxExpr:
7085 ID = Intrinsic::smax;
7086 break;
7087 case scUMinExpr:
7089 ID = Intrinsic::umin;
7090 break;
7091 case scSMinExpr:
7092 ID = Intrinsic::smin;
7093 break;
7094 default:
7095 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7096 }
7097
7098 const auto *NAry = cast<SCEVNAryExpr>(S);
7099 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7100 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7101 X = X.intrinsic(
7102 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7103 return setRange(S, SignHint,
7104 ConservativeResult.intersectWith(X, RangeType));
7105 }
7106 case scUnknown: {
7107 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7108 Value *V = U->getValue();
7109
7110 // Check if the IR explicitly contains !range metadata.
7111 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7112 if (MDRange)
7113 ConservativeResult =
7114 ConservativeResult.intersectWith(*MDRange, RangeType);
7115
7116 // Use facts about recurrences in the underlying IR. Note that add
7117 // recurrences are AddRecExprs and thus don't hit this path. This
7118 // primarily handles shift recurrences.
7119 auto CR = getRangeForUnknownRecurrence(U);
7120 ConservativeResult = ConservativeResult.intersectWith(CR);
7121
7122 // See if ValueTracking can give us a useful range.
7123 const DataLayout &DL = getDataLayout();
7124 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7125 if (Known.getBitWidth() != BitWidth)
7126 Known = Known.zextOrTrunc(BitWidth);
7127
7128 // ValueTracking may be able to compute a tighter result for the number of
7129 // sign bits than for the value of those sign bits.
7130 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7131 if (U->getType()->isPointerTy()) {
7132 // If the pointer size is larger than the index size type, this can cause
7133 // NS to be larger than BitWidth. So compensate for this.
7134 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7135 int ptrIdxDiff = ptrSize - BitWidth;
7136 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7137 NS -= ptrIdxDiff;
7138 }
7139
7140 if (NS > 1) {
7141 // If we know any of the sign bits, we know all of the sign bits.
7142 if (!Known.Zero.getHiBits(NS).isZero())
7143 Known.Zero.setHighBits(NS);
7144 if (!Known.One.getHiBits(NS).isZero())
7145 Known.One.setHighBits(NS);
7146 }
7147
7148 if (Known.getMinValue() != Known.getMaxValue() + 1)
7149 ConservativeResult = ConservativeResult.intersectWith(
7150 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7151 RangeType);
7152 if (NS > 1)
7153 ConservativeResult = ConservativeResult.intersectWith(
7154 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7155 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7156 RangeType);
7157
7158 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7159 // Strengthen the range if the underlying IR value is a
7160 // global/alloca/heap allocation using the size of the object.
7161 bool CanBeNull;
7162 uint64_t DerefBytes = V->getPointerDereferenceableBytes(
7163 DL, CanBeNull, /*CanBeFreed=*/nullptr);
7164 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7165 // The highest address the object can start is DerefBytes bytes before
7166 // the end (unsigned max value). If this value is not a multiple of the
7167 // alignment, the last possible start value is the next lowest multiple
7168 // of the alignment. Note: The computations below cannot overflow,
7169 // because if they would there's no possible start address for the
7170 // object.
7171 APInt MaxVal =
7172 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7173 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7174 uint64_t Rem = MaxVal.urem(Align);
7175 MaxVal -= APInt(BitWidth, Rem);
7176 APInt MinVal = APInt::getZero(BitWidth);
7177 if (llvm::isKnownNonZero(V, DL))
7178 MinVal = Align;
7179 ConservativeResult = ConservativeResult.intersectWith(
7180 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7181 }
7182 }
7183
7184 // A range of Phi is a subset of union of all ranges of its input.
7185 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7186 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7187 // AddRecs; return the range for the corresponding AddRec.
7188 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7189 return getRangeRef(AR, SignHint, Depth + 1);
7190
7191 // Make sure that we do not run over cycled Phis.
7192 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7193 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7194
7195 for (const auto &Op : Phi->operands()) {
7196 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7197 RangeFromOps = RangeFromOps.unionWith(OpRange);
7198 // No point to continue if we already have a full set.
7199 if (RangeFromOps.isFullSet())
7200 break;
7201 }
7202 ConservativeResult =
7203 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7204 }
7205 }
7206
7207 // vscale can't be equal to zero
7208 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7209 if (II->getIntrinsicID() == Intrinsic::vscale) {
7210 ConstantRange Disallowed = APInt::getZero(BitWidth);
7211 ConservativeResult = ConservativeResult.difference(Disallowed);
7212 }
7213
7214 return setRange(U, SignHint, std::move(ConservativeResult));
7215 }
7216 case scCouldNotCompute:
7217 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7218 }
7219
7220 return setRange(S, SignHint, std::move(ConservativeResult));
7221}
7222
7223// Given a StartRange, Step and MaxBECount for an expression compute a range of
7224// values that the expression can take. Initially, the expression has a value
7225// from StartRange and then is changed by Step up to MaxBECount times. Signed
7226// argument defines if we treat Step as signed or unsigned.
//
// The result is StartRange itself when Step or MaxBECount is zero, and the full
// set when StartRange is already full, when MaxBECount exceeds
// UINT_MAX / |Step| for this bit width, or when the moved boundary lands back
// inside StartRange.
//
// NOTE(review): the line that opens the signature (listing line 7227) is
// missing from this extract. The call sites in getRangeForAffineAR name this
// function getRangeForAffineARHelper and pass the step as the first argument;
// the body reassigns Step, so it is presumably taken by value -- confirm
// against upstream.
7228 const ConstantRange &StartRange,
7229 const APInt &MaxBECount,
7230 bool Signed) {
7231 unsigned BitWidth = Step.getBitWidth();
7232 assert(BitWidth == StartRange.getBitWidth() &&
7233 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7234 // If either Step or MaxBECount is 0, then the expression won't change, and we
7235 // just need to return the initial range.
7236 if (Step == 0 || MaxBECount == 0)
7237 return StartRange;
7238
7239 // If we don't know anything about the initial value (i.e. StartRange is
7240 // FullRange), then we don't know anything about the final range either.
7241 // Return FullRange.
7242 if (StartRange.isFullSet())
7243 return ConstantRange::getFull(BitWidth);
7244
7245 // If Step is signed and negative, then we use its absolute value, but we also
7246 // note that we're moving in the opposite direction.
7247 bool Descending = Signed && Step.isNegative();
7248
7249 if (Signed)
7250 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7251 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7252 // These equations hold true due to the well-defined wrap-around behavior of
7253 // APInt.
7254 Step = Step.abs();
7255
7256 // Check if Offset is more than full span of BitWidth. If it is, the
7257 // expression is guaranteed to overflow.
7258 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7259 return ConstantRange::getFull(BitWidth);
7260
7261 // Offset is by how much the expression can change. Checks above guarantee no
7262 // overflow here.
7263 APInt Offset = Step * MaxBECount;
7264
7265 // Minimum value of the final range will match the minimal value of StartRange
7266 // if the expression is increasing and will be decreased by Offset otherwise.
7267 // Maximum value of the final range will match the maximal value of StartRange
7268 // if the expression is decreasing and will be increased by Offset otherwise.
7269 APInt StartLower = StartRange.getLower();
7270 APInt StartUpper = StartRange.getUpper() - 1;
7271 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7272 : (StartUpper + std::move(Offset));
7273
7274 // It's possible that the new minimum/maximum value will fall into the initial
7275 // range (due to wrap around). This means that the expression can take any
7276 // value in this bitwidth, and we have to return full range.
7277 if (StartRange.contains(MovedBoundary))
7278 return ConstantRange::getFull(BitWidth);
7279
// Only one end moves: the lower bound when descending, the upper bound
// otherwise. StartUpper is inclusive, hence the +1 below.
7280 APInt NewLower =
7281 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7282 APInt NewUpper =
7283 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7284 NewUpper += 1;
7285
7286 // No overflow detected. Return [StartLower, StartUpper + Offset + 1) when
// ascending, or [StartLower - Offset, StartUpper + 1) when descending.
7287 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7288}
7289
// Computes a range for the affine recurrence {Start,+,Step} executed at most
// MaxBECount times. A range is built once from the signed ranges of Start and
// Step and once from their unsigned ranges, using getRangeForAffineARHelper
// for each.
// Start, Step and MaxBECount must all have the same bit width (asserted).
7290ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7291 const SCEV *Step,
7292 const APInt &MaxBECount) {
7293 assert(getTypeSizeInBits(Start->getType()) ==
7294 getTypeSizeInBits(Step->getType()) &&
7295 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7296 "mismatched bit widths");
7297
7298 // First, consider step signed.
7299 ConstantRange StartSRange = getSignedRange(Start);
7300 ConstantRange StepSRange = getSignedRange(Step);
7301
7302 // If Step can be both positive and negative, we need to find ranges for the
7303 // maximum absolute step values in both directions and union them.
7304 ConstantRange SR = getRangeForAffineARHelper(
7305 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
// NOTE(review): listing line 7306 is missing from this extract. Given the
// comment above and the trailing `));`, it presumably unions SR with the
// helper's result for the signed maximum of the step -- confirm against
// upstream.
7307 StartSRange, MaxBECount,
7308 /* Signed = */ true));
7309
7310 // Next, consider step unsigned.
7311 ConstantRange UR = getRangeForAffineARHelper(
7312 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7313 /* Signed = */ false);
7314
7315 // Finally, intersect signed and unsigned ranges.
// NOTE(review): listing line 7316, which should hold the return statement
// combining SR and UR, is missing from this extract.
7317}
7318
// Computes a range for an affine AddRec that carries the no-self-wrap flag,
// from the ranges of its start value and of its value at iteration
// MaxBECount. The union of those two ranges is returned only when the step is
// a constant, MaxBECount iterations provably cannot self-wrap, the union does
// not wrap under SignHint, and Start/End are provably ordered consistently
// with the sign of the step. In every other case the full set of width
// BitWidth is returned (or the union itself if it is already full).
7319ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7320 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7321 ScalarEvolution::RangeSignHint SignHint) {
7322 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7323 assert(AddRec->hasNoSelfWrap() &&
7324 "This only works for non-self-wrapping AddRecs!");
7325 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7326 const SCEV *Step = AddRec->getStepRecurrence(*this);
7327 // Only deal with constant step to save compile time.
7328 if (!isa<SCEVConstant>(Step))
7329 return ConstantRange::getFull(BitWidth);
7330 // Let's make sure that we can prove that we do not self-wrap during
7331 // MaxBECount iterations. We need this because MaxBECount is a maximum
7332 // iteration count estimate, and we might infer nw from some exit for which we
7333 // do not know max exit count (or any other side reasoning).
7334 // TODO: Turn into assert at some point.
// A MaxBECount wider than the AddRec cannot be zero-extended to its type, so
// give up in that case.
7335 if (getTypeSizeInBits(MaxBECount->getType()) >
7336 getTypeSizeInBits(AddRec->getType()))
7337 return ConstantRange::getFull(BitWidth);
7338 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
// |Step| is taken as umin(Step, -Step); the all-ones value divided by it
// bounds how many steps fit before the recurrence could wrap onto itself.
7339 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7340 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7341 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7342 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7343 MaxItersWithoutWrap))
7344 return ConstantRange::getFull(BitWidth);
7345
// NOTE(review): the initializers of LEPred and GEPred (listing lines 7347 and
// 7349) are missing from this extract; presumably they select the signed or
// unsigned <= / >= predicate according to IsSigned -- confirm against
// upstream.
7346 ICmpInst::Predicate LEPred =
7348 ICmpInst::Predicate GEPred =
7350 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7351
7352 // We know that there is no self-wrap. Let's take Start and End values and
7353 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7354 // the iteration. They either lie inside the range [Min(Start, End),
7355 // Max(Start, End)] or outside it:
7356 //
7357 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7358 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7359 //
7360 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7361 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7362 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7363 // Start <= End and step is positive, or Start >= End and step is negative.
7364 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7365 ConstantRange StartRange = getRangeRef(Start, SignHint);
7366 ConstantRange EndRange = getRangeRef(End, SignHint);
7367 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7368 // If they already cover full iteration space, we will know nothing useful
7369 // even if we prove what we want to prove.
7370 if (RangeBetween.isFullSet())
7371 return RangeBetween;
7372 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7373 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7374 : RangeBetween.isWrappedSet();
7375 if (IsWrappedSet)
7376 return ConstantRange::getFull(BitWidth);
7377
7378 if (isKnownPositive(Step) &&
7379 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7380 return RangeBetween;
7381 if (isKnownNegative(Step) &&
7382 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7383 return RangeBetween;
7384 return ConstantRange::getFull(BitWidth);
7385}
7386
// Computes a range for {Start,+,Step} when Start and Step are both selects
// between two constants on the same condition (each optionally wrapped in one
// integral cast and a constant offset). The recurrence is then split into the
// "true" and "false" constant recurrences and the union of their
// getRangeForAffineAR ranges is returned. If either operand does not match the
// pattern, or the two conditions differ, the full set is returned.
7387ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7388 const SCEV *Step,
7389 const APInt &MaxBECount) {
7390 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7391 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7392
7393 unsigned BitWidth = MaxBECount.getBitWidth();
7394 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7395 getTypeSizeInBits(Step->getType()) == BitWidth &&
7396 "mismatched bit widths");
7397
// Matches `offset + cast(select Cond, TrueConst, FalseConst)` (offset and
// cast both optional) and records the condition and the two resulting
// constants. Condition stays null when the expression does not match.
7398 struct SelectPattern {
7399 Value *Condition = nullptr;
7400 APInt TrueValue;
7401 APInt FalseValue;
7402
7403 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7404 const SCEV *S) {
7405 std::optional<unsigned> CastOp;
7406 APInt Offset(BitWidth, 0);
7407
// NOTE(review): listing line 7408 is missing from this extract; the message
// string below suggests it opens an assert -- confirm against upstream.
7409 "Should be!");
7410
7411 // Peel off a constant offset. In the future we could consider being
7412 // smarter here and handle {Start+Step,+,Step} too.
7413 const APInt *Off;
7414 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7415 Offset = *Off;
7416
7417 // Peel off a cast operation
7418 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7419 CastOp = SCast->getSCEVType();
7420 S = SCast->getOperand();
7421 }
7422
7423 using namespace llvm::PatternMatch;
7424
7425 auto *SU = dyn_cast<SCEVUnknown>(S);
7426 const APInt *TrueVal, *FalseVal;
7427 if (!SU ||
7428 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7429 m_APInt(FalseVal)))) {
// m_Value(Condition) may already have bound Condition during a failed match,
// so reset it to keep isRecognized() false.
7430 Condition = nullptr;
7431 return;
7432 }
7433
7434 TrueValue = *TrueVal;
7435 FalseValue = *FalseVal;
7436
7437 // Re-apply the cast we peeled off earlier
7438 if (CastOp)
7439 switch (*CastOp) {
7440 default:
7441 llvm_unreachable("Unknown SCEV cast type!");
7442
7443 case scTruncate:
7444 TrueValue = TrueValue.trunc(BitWidth);
7445 FalseValue = FalseValue.trunc(BitWidth);
7446 break;
7447 case scZeroExtend:
7448 TrueValue = TrueValue.zext(BitWidth);
7449 FalseValue = FalseValue.zext(BitWidth);
7450 break;
7451 case scSignExtend:
7452 TrueValue = TrueValue.sext(BitWidth);
7453 FalseValue = FalseValue.sext(BitWidth);
7454 break;
7455 }
7456
7457 // Re-apply the constant offset we peeled off earlier
7458 TrueValue += Offset;
7459 FalseValue += Offset;
7460 }
7461
7462 bool isRecognized() { return Condition != nullptr; }
7463 };
7464
7465 SelectPattern StartPattern(*this, BitWidth, Start);
7466 if (!StartPattern.isRecognized())
7467 return ConstantRange::getFull(BitWidth);
7468
7469 SelectPattern StepPattern(*this, BitWidth, Step);
7470 if (!StepPattern.isRecognized())
7471 return ConstantRange::getFull(BitWidth);
7472
7473 if (StartPattern.Condition != StepPattern.Condition) {
7474 // We don't handle this case today; but we could, by considering four
7475 // possibilities below instead of two. I'm not sure if there are cases where
7476 // that will help over what getRange already does, though.
7477 return ConstantRange::getFull(BitWidth);
7478 }
7479
7480 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7481 // construct arbitrary general SCEV expressions here. This function is called
7482 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7483 // say) can end up caching a suboptimal value.
7484
7485 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7486 // C2352 and C2512 (otherwise it isn't needed).
7487
7488 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7489 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7490 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7491 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7492
7493 ConstantRange TrueRange =
7494 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7495 ConstantRange FalseRange =
7496 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7497
7498 return TrueRange.unionWith(FalseRange);
7499}
7500
// Derives the no-wrap flags that may be attached to the SCEV for V from the
// IR-level flags on the binary operator V. Returns FlagAnyWrap for constant
// expressions, when no flag was collected, and when isSCEVExprNeverPoison
// cannot be established for the operator. V must otherwise be a
// BinaryOperator (enforced by cast<>).
// NOTE(review): listing lines 7506, 7509, 7512 and 7514 are missing from this
// extract. They presumably declare Flags and set it in the disjoint, nuw and
// nsw branches respectively -- confirm against upstream.
7501SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7502 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7503 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7504
7505 // Return early if there are no flags to propagate to the SCEV.
7507 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BinOp);
7508 PDI && PDI->isDisjoint()) {
7510 } else {
7511 if (BinOp->hasNoUnsignedWrap())
7513 if (BinOp->hasNoSignedWrap())
7515 }
7516 if (Flags == SCEV::FlagAnyWrap)
7517 return SCEV::FlagAnyWrap;
7518
// The collected flags are only returned if isSCEVExprNeverPoison holds.
7519 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7520}
7521
7522const Instruction *
7523ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7524 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7525 return &*AddRec->getLoop()->getHeader()->begin();
7526 if (auto *U = dyn_cast<SCEVUnknown>(S))
7527 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7528 return I;
7529 return nullptr;
7530}
7531
7532const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7533 bool &Precise) {
7534 Precise = true;
7535 // Do a bounded search of the def relation of the requested SCEVs.
7536 SmallPtrSet<const SCEV *, 16> Visited;
7537 SmallVector<SCEVUse> Worklist;
7538 auto pushOp = [&](const SCEV *S) {
7539 if (!Visited.insert(S).second)
7540 return;
7541 // Threshold of 30 here is arbitrary.
7542 if (Visited.size() > 30) {
7543 Precise = false;
7544 return;
7545 }
7546 Worklist.push_back(S);
7547 };
7548
7549 for (SCEVUse S : Ops)
7550 pushOp(S);
7551
7552 const Instruction *Bound = nullptr;
7553 while (!Worklist.empty()) {
7554 SCEVUse S = Worklist.pop_back_val();
7555 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7556 if (!Bound || DT.dominates(Bound, DefI))
7557 Bound = DefI;
7558 } else {
7559 for (SCEVUse Op : S->operands())
7560 pushOp(Op);
7561 }
7562 }
7563 return Bound ? Bound : &*F.getEntryBlock().begin();
7564}
7565
7566const Instruction *
7567ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7568 bool Discard;
7569 return getDefiningScopeBound(Ops, Discard);
7570}
7571
// Returns true when executing A is known to lead to executing B, for the two
// layouts handled here: A and B in the same block, or A in the preheader of
// the loop whose header contains B. Anything else conservatively yields
// false.
// NOTE(review): listing lines 7575 and 7582 are missing from this extract;
// each presumably opens a range-based
// isGuaranteedToTransferExecutionToSuccessor call starting at A -- confirm
// against upstream.
7572bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7573 const Instruction *B) {
// Case 1: same basic block.
7574 if (A->getParent() == B->getParent() &&
7576 B->getIterator()))
7577 return true;
7578
// Case 2: A sits in the preheader and B in the header of the same loop. The
// check is split into "A to the end of the preheader" and "start of the
// header to B".
7579 auto *BLoop = LI.getLoopFor(B->getParent());
7580 if (BLoop && BLoop->getHeader() == B->getParent() &&
7581 BLoop->getLoopPreheader() == A->getParent() &&
7583 A->getParent()->end()) &&
7584 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7585 B->getIterator()))
7586 return true;
7587 return false;
7588}
7589
// Runs a SCEVPoisonCollector (constructed with LookThroughMaybePoisonBlocking
// enabled) over Op via visitAll and returns true when the collector's
// MaybePoison set ends up empty.
// NOTE(review): the signature line (listing line 7590) is missing from this
// extract. The body reads a parameter named Op, and the call in
// isGuaranteedNotToCauseUB suggests the name isGuaranteedNotToBePoison --
// confirm against upstream.
7591 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7592 visitAll(Op, PC);
7593 return PC.MaybePoison.empty();
7594}
7595
7596bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7597 return !SCEVExprContains(Op, [this](const SCEV *S) {
7598 const SCEV *Op1;
7599 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7600 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7601 // is a non-zero constant, we have to assume the UDiv may be UB.
7602 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7603 });
7604}
7605
// Decides whether the no-wrap flags of instruction I may be transferred to
// the SCEV built from it. After the initial poison check, the SCEVs of I's
// SCEVable operands are collected, their defining scope bound is computed,
// and the result is whether execution is guaranteed to transfer from that
// bound to I.
7606bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7607 // Only proceed if we can prove that I does not yield poison.
// NOTE(review): listing line 7608, the condition guarding this early return,
// is missing from this extract.
7609 return false;
7610
7611 // At this point we know that if I is executed, then it does not wrap
7612 // according to at least one of NSW or NUW. If I is not executed, then we do
7613 // not know if the calculation that I represents would wrap. Multiple
7614 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7615 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7616 // derived from other instructions that map to the same SCEV. We cannot make
7617 // that guarantee for cases where I is not executed. So we need to find an
7618 // upper bound on the defining scope for the SCEV, and prove that I is
7619 // executed every time we enter that scope. When the bounding scope is a
7620 // loop (the common case), this is equivalent to proving I executes on every
7621 // iteration of that loop.
7622 SmallVector<SCEVUse> SCEVOps;
7623 for (const Use &Op : I->operands()) {
7624 // I could be an extractvalue from a call to an overflow intrinsic.
7625 // TODO: We can do better here in some cases.
7626 if (isSCEVable(Op->getType()))
7627 SCEVOps.push_back(getSCEV(Op));
7628 }
7629 auto *DefI = getDefiningScopeBound(SCEVOps);
7630 return isGuaranteedToTransferExecutionTo(DefI, I);
7631}
7632
// Returns true when the add recurrence formed by instruction I in loop L
// cannot be poison without the program hitting UB. This holds either because
// isSCEVExprNeverPoison(I) holds, or because L has a single exiting block and
// no abnormal exits, and some transitive user of I that must trigger UB on
// poison sits in a block dominating that exiting block.
7633bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7634 // If we know that \c I can never be poison period, then that's enough.
7635 if (isSCEVExprNeverPoison(I))
7636 return true;
7637
7638 // If the loop only has one exit, then we know that, if the loop is entered,
7639 // any instruction dominating that exit will be executed. If any such
7640 // instruction would result in UB, the addrec cannot be poison.
7641 //
7642 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7643 // also handles uses outside the loop header (they just need to dominate the
7644 // single exit).
7645
7646 auto *ExitingBB = L->getExitingBlock();
7647 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7648 return false;
7649
7650 SmallPtrSet<const Value *, 16> KnownPoison;
// NOTE(review): listing line 7651 is missing from this extract; it presumably
// declares the Worklist of instructions used below.
7652
7653 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7654 // things that are known to be poison under that assumption go on the
7655 // Worklist.
7656 KnownPoison.insert(I);
7657 Worklist.push_back(I);
7658
7659 while (!Worklist.empty()) {
7660 const Instruction *Poison = Worklist.pop_back_val();
7661
7662 for (const Use &U : Poison->uses()) {
7663 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7664 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7665 DT.dominates(PoisonUser->getParent(), ExitingBB))
7666 return true;
7667
// Follow the poison only through users inside the loop that propagate it,
// visiting each user once.
7668 if (propagatesPoison(U) && L->contains(PoisonUser))
7669 if (KnownPoison.insert(PoisonUser).second)
7670 Worklist.push_back(PoisonUser);
7671 }
7672 }
7673
7674 return false;
7675}
7676
// Returns the cached LoopProperties for L, computing and caching them on the
// first query. Both properties start out true and are cleared while scanning
// every instruction of every block in the loop.
7677ScalarEvolution::LoopProperties
7678ScalarEvolution::getLoopProperties(const Loop *L) {
7679 using LoopProperties = ScalarEvolution::LoopProperties;
7680
7681 auto Itr = LoopPropertiesCache.find(L);
7682 if (Itr == LoopPropertiesCache.end()) {
// Classifies an instruction as a side effect for this query: stores count
// only when not simple, anything that may throw counts, non-volatile memory
// intrinsics do not, and everything else counts if it may write to memory.
7683 auto HasSideEffects = [](Instruction *I) {
7684 if (auto *SI = dyn_cast<StoreInst>(I))
7685 return !SI->isSimple();
7686
7687 if (I->mayThrow())
7688 return true;
7689
7690 // Non-volatile memset / memcpy do not count as side-effect for forward
7691 // progress.
7692 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7693 return false;
7694
7695 return I->mayWriteToMemory();
7696 };
7697
7698 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7699 /*HasNoSideEffects*/ true};
7700
7701 for (auto *BB : L->getBlocks())
7702 for (auto &I : *BB) {
// NOTE(review): listing line 7703, the condition guarding the assignment
// below, is missing from this extract.
7704 LP.HasNoAbnormalExits = false;
7705 if (HasSideEffects(&I))
7706 LP.HasNoSideEffects = false;
// This break leaves only the inner (per-block) loop.
7707 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7708 break; // We're already as pessimistic as we can get.
7709 }
7710
7711 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7712 assert(InsertPair.second && "We just checked!");
7713 Itr = InsertPair.first;
7714 }
7715
7716 return Itr->second;
7717}
7718
// Returns true when L is finite per isFinite(L), or when L is mustprogress
// and has no side effects.
// NOTE(review): the signature line (listing line 7719) is missing from this
// extract; the body reads a loop parameter named L -- confirm the name and
// signature against upstream.
7720 // A mustprogress loop without side effects must be finite.
7721 // TODO: The check used here is very conservative. It's only *specific*
7722 // side effects which are well defined in infinite loops.
7723 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7724}
7725
// Builds the SCEV for V with an explicit stack instead of recursion. Each
// stack entry pairs a Value with a flag saying whether its operands have
// already been queued. An entry is visited first to collect the operands it
// needs (getOperandsToCreate) and, once those have been processed, a second
// time to build its SCEV (createSCEV). Returns the SCEV recorded for V.
7726const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7727 // Worklist item with a Value and a bool indicating whether all operands have
7728 // been visited already.
// NOTE(review): listing lines 7729 and 7730 are missing from this extract;
// they presumably declare the entry type and the Stack container used below.
7731
7732 Stack.emplace_back(V, false);
7733 while (!Stack.empty()) {
7734 auto E = Stack.back();
7735 Value *CurV = E.getPointer();
7736
// The value may have received a SCEV since it was queued.
7737 if (getExistingSCEV(CurV)) {
7738 Stack.pop_back();
7739 continue;
7740 }
7741
// NOTE(review): listing line 7742 is missing from this extract; it presumably
// declares the per-iteration Ops vector filled by getOperandsToCreate.
7743 const SCEV *CreatedSCEV = nullptr;
7744 // If all operands have been visited already, create the SCEV.
7745 if (E.getInt()) {
7746 CreatedSCEV = createSCEV(CurV);
7747 } else {
7748 // Otherwise get the operands we need to create SCEV's for before creating
7749 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7750 // just use it.
7751 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7752 }
7753
7754 if (CreatedSCEV) {
7755 insertValueToMap(CurV, CreatedSCEV);
7756 Stack.pop_back();
7757 } else {
// Keep CurV on the stack, marked as "operands queued", so that it is
// revisited after the operands pushed on top of it.
7758 Stack.back().setInt(true);
7759 // Queue its operands which need to be constructed.
7760 for (Value *Op : Ops)
7761 Stack.emplace_back(Op, false);
7762 }
7763 }
7764
7765 return getExistingSCEV(V);
7766}
7767
// Helper for createSCEVIter. For a value V it either returns a SCEV that can
// be produced on the spot (constants, unknowns, unreachable code), or returns
// nullptr after appending to Ops the operand Values whose SCEVs should exist
// before createSCEV(V) is invoked. A nullptr return with Ops left empty is
// also possible; the caller then simply calls createSCEV on its next visit.
7768const SCEV *
7769ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7770 if (!isSCEVable(V->getType()))
7771 return getUnknown(V);
7772
7773 if (Instruction *I = dyn_cast<Instruction>(V)) {
7774 // Don't attempt to analyze instructions in blocks that aren't
7775 // reachable. Such instructions don't matter, and they aren't required
7776 // to obey basic rules for definitions dominating uses which this
7777 // analysis depends on.
7778 if (!DT.isReachableFromEntry(I->getParent()))
7779 return getUnknown(PoisonValue::get(V->getType()));
7780 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7781 return getConstant(CI);
7782 else if (isa<GlobalAlias>(V))
7783 return getUnknown(V);
7784 else if (!isa<ConstantExpr>(V))
7785 return getUnknown(V);
7786
// NOTE(review): listing lines 7787 and 7789 are missing from this extract.
// They presumably define U (V viewed as an operator, used by the switch
// further down) and the binary-operator match that initializes BO -- confirm
// against upstream.
7788 if (auto BO =
7790 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7791 switch (BO->Opcode) {
7792 case Instruction::Add:
7793 case Instruction::Mul: {
7794 // For additions and multiplications, traverse add/mul chains for which we
7795 // can potentially create a single SCEV, to reduce the number of
7796 // get{Add,Mul}Expr calls.
7797 do {
7798 if (BO->Op) {
7799 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7800 Ops.push_back(BO->Op);
7801 break;
7802 }
7803 }
7804 Ops.push_back(BO->RHS);
// NOTE(review): listing line 7806, the tail of this MatchBinaryOp call, is
// missing from this extract.
7805 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
// Stop following the chain when the LHS is not a compatible binary operator:
// add chains continue through add/sub, mul chains only through mul.
7807 if (!NewBO ||
7808 (BO->Opcode == Instruction::Add &&
7809 (NewBO->Opcode != Instruction::Add &&
7810 NewBO->Opcode != Instruction::Sub)) ||
7811 (BO->Opcode == Instruction::Mul &&
7812 NewBO->Opcode != Instruction::Mul)) {
7813 Ops.push_back(BO->LHS);
7814 break;
7815 }
7816 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7817 // requires a SCEV for the LHS.
7818 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7819 auto *I = dyn_cast<Instruction>(BO->Op);
7820 if (I && programUndefinedIfPoison(I)) {
7821 Ops.push_back(BO->LHS);
7822 break;
7823 }
7824 }
7825 BO = NewBO;
7826 } while (true);
7827 return nullptr;
7828 }
7829 case Instruction::Sub:
7830 case Instruction::UDiv:
7831 case Instruction::URem:
7832 break;
7833 case Instruction::AShr:
7834 case Instruction::Shl:
7835 case Instruction::Xor:
// With a non-constant RHS no operands are requested for these opcodes.
7836 if (!IsConstArg)
7837 return nullptr;
7838 break;
7839 case Instruction::And:
7840 case Instruction::Or:
// Operands are requested only for a constant RHS or for i1 operands.
7841 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7842 return nullptr;
7843 break;
7844 case Instruction::LShr:
7845 return getUnknown(V);
7846 default:
7847 llvm_unreachable("Unhandled binop");
7848 break;
7849 }
7850
7851 Ops.push_back(BO->LHS);
7852 Ops.push_back(BO->RHS);
7853 return nullptr;
7854 }
7855
7856 switch (U->getOpcode()) {
7857 case Instruction::Trunc:
7858 case Instruction::ZExt:
7859 case Instruction::SExt:
7860 case Instruction::PtrToAddr:
7861 case Instruction::PtrToInt:
7862 Ops.push_back(U->getOperand(0));
7863 return nullptr;
7864
7865 case Instruction::BitCast:
7866 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7867 Ops.push_back(U->getOperand(0));
7868 return nullptr;
7869 }
7870 return getUnknown(V);
7871
7872 case Instruction::SDiv:
7873 case Instruction::SRem:
7874 Ops.push_back(U->getOperand(0));
7875 Ops.push_back(U->getOperand(1));
7876 return nullptr;
7877
7878 case Instruction::GetElementPtr:
7879 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7880 "GEP source element type must be sized");
7881 llvm::append_range(Ops, U->operands());
7882 return nullptr;
7883
7884 case Instruction::IntToPtr:
7885 return getUnknown(V);
7886
7887 case Instruction::PHI:
7888 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7889 // relevant nodes for each of them.
7890 //
7891 // The first is just to call simplifyInstruction, and get something back
7892 // that isn't a PHI.
7893 if (Value *V = simplifyInstruction(
7894 cast<PHINode>(U),
7895 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7896 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7897 assert(V);
7898 Ops.push_back(V);
7899 return nullptr;
7900 }
7901 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7902 // operands which all perform the same operation, but haven't been
7903 // CSE'ed for whatever reason.
7904 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7905 assert(BO);
7906 Ops.push_back(BO);
7907 return nullptr;
7908 }
7909 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7910 // is equivalent to a select, and analyzes it like a select.
7911 {
7912 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
// NOTE(review): listing line 7913, the condition that fills Cond/LHS/RHS and
// opens the block closed at listing line 7925, is missing from this extract.
7914 assert(Cond);
7915 assert(LHS);
7916 assert(RHS);
7917 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7918 Ops.push_back(CondICmp->getOperand(0));
7919 Ops.push_back(CondICmp->getOperand(1));
7920 }
7921 Ops.push_back(Cond);
7922 Ops.push_back(LHS);
7923 Ops.push_back(RHS);
7924 return nullptr;
7925 }
7926 }
7927 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7928 // so just construct it recursively.
7929 //
7930 // In addition to getNodeForPHI, also construct nodes which might be needed
7931 // by getRangeRef.
// NOTE(review): listing line 7932, the condition guarding the loop below, is
// missing from this extract.
7933 for (Value *V : cast<PHINode>(U)->operands())
7934 Ops.push_back(V);
7935 return nullptr;
7936 }
7937 return nullptr;
7938
7939 case Instruction::Select: {
7940 // Check if U is a select that can be simplified to a SCEVUnknown.
7941 auto CanSimplifyToUnknown = [this, U]() {
7942 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7943 return false;
7944
7945 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7946 if (!ICI)
7947 return false;
7948 Value *LHS = ICI->getOperand(0);
7949 Value *RHS = ICI->getOperand(1);
7950 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7951 ICI->getPredicate() == CmpInst::ICMP_NE) {
// NOTE(review): listing line 7952, the condition guarding this return, is
// missing from this extract.
7953 return true;
7954 } else if (getTypeSizeInBits(LHS->getType()) >
7955 getTypeSizeInBits(U->getType()))
7956 return true;
7957 return false;
7958 };
7959 if (CanSimplifyToUnknown())
7960 return getUnknown(U);
7961
7962 llvm::append_range(Ops, U->operands());
7963 return nullptr;
7964 break;
7965 }
7966 case Instruction::Call:
7967 case Instruction::Invoke:
// A call with a "returned" argument requests just that argument.
7968 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7969 Ops.push_back(RV);
7970 return nullptr;
7971 }
7972
7973 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7974 switch (II->getIntrinsicID()) {
7975 case Intrinsic::abs:
7976 Ops.push_back(II->getArgOperand(0));
7977 return nullptr;
7978 case Intrinsic::umax:
7979 case Intrinsic::umin:
7980 case Intrinsic::smax:
7981 case Intrinsic::smin:
7982 case Intrinsic::usub_sat:
7983 case Intrinsic::uadd_sat:
7984 Ops.push_back(II->getArgOperand(0));
7985 Ops.push_back(II->getArgOperand(1));
7986 return nullptr;
7987 case Intrinsic::start_loop_iterations:
7988 case Intrinsic::annotation:
7989 case Intrinsic::ptr_annotation:
7990 Ops.push_back(II->getArgOperand(0));
7991 return nullptr;
7992 default:
7993 break;
7994 }
7995 }
7996 break;
7997 }
7998
7999 return nullptr;
8000}
8001
8002const SCEV *ScalarEvolution::createSCEV(Value *V) {
8003 if (!isSCEVable(V->getType()))
8004 return getUnknown(V);
8005
8006 if (Instruction *I = dyn_cast<Instruction>(V)) {
8007 // Don't attempt to analyze instructions in blocks that aren't
8008 // reachable. Such instructions don't matter, and they aren't required
8009 // to obey basic rules for definitions dominating uses which this
8010 // analysis depends on.
8011 if (!DT.isReachableFromEntry(I->getParent()))
8012 return getUnknown(PoisonValue::get(V->getType()));
8013 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
8014 return getConstant(CI);
8015 else if (isa<GlobalAlias>(V))
8016 return getUnknown(V);
8017 else if (!isa<ConstantExpr>(V))
8018 return getUnknown(V);
8019
8020 const SCEV *LHS;
8021 const SCEV *RHS;
8022
8024 if (auto BO =
8026 switch (BO->Opcode) {
8027 case Instruction::Add: {
8028 // The simple thing to do would be to just call getSCEV on both operands
8029 // and call getAddExpr with the result. However if we're looking at a
8030 // bunch of things all added together, this can be quite inefficient,
8031 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8032 // Instead, gather up all the operands and make a single getAddExpr call.
8033 // LLVM IR canonical form means we need only traverse the left operands.
8035 do {
8036 if (BO->Op) {
8037 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8038 AddOps.push_back(OpSCEV);
8039 break;
8040 }
8041
8042 // If a NUW or NSW flag can be applied to the SCEV for this
8043 // addition, then compute the SCEV for this addition by itself
8044 // with a separate call to getAddExpr. We need to do that
8045 // instead of pushing the operands of the addition onto AddOps,
8046 // since the flags are only known to apply to this particular
8047 // addition - they may not apply to other additions that can be
8048 // formed with operands from AddOps.
8049 const SCEV *RHS = getSCEV(BO->RHS);
8050 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8051 if (Flags != SCEV::FlagAnyWrap) {
8052 const SCEV *LHS = getSCEV(BO->LHS);
8053 if (BO->Opcode == Instruction::Sub)
8054 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8055 else
8056 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8057 break;
8058 }
8059 }
8060
8061 if (BO->Opcode == Instruction::Sub)
8062 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8063 else
8064 AddOps.push_back(getSCEV(BO->RHS));
8065
8066 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8068 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8069 NewBO->Opcode != Instruction::Sub)) {
8070 AddOps.push_back(getSCEV(BO->LHS));
8071 break;
8072 }
8073 BO = NewBO;
8074 } while (true);
8075
8076 return getAddExpr(AddOps);
8077 }
8078
8079 case Instruction::Mul: {
8081 do {
8082 if (BO->Op) {
8083 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8084 MulOps.push_back(OpSCEV);
8085 break;
8086 }
8087
8088 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8089 if (Flags != SCEV::FlagAnyWrap) {
8090 LHS = getSCEV(BO->LHS);
8091 RHS = getSCEV(BO->RHS);
8092 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8093 break;
8094 }
8095 }
8096
8097 MulOps.push_back(getSCEV(BO->RHS));
8098 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8100 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8101 MulOps.push_back(getSCEV(BO->LHS));
8102 break;
8103 }
8104 BO = NewBO;
8105 } while (true);
8106
8107 return getMulExpr(MulOps);
8108 }
8109 case Instruction::UDiv:
8110 LHS = getSCEV(BO->LHS);
8111 RHS = getSCEV(BO->RHS);
8112 return getUDivExpr(LHS, RHS);
8113 case Instruction::URem:
8114 LHS = getSCEV(BO->LHS);
8115 RHS = getSCEV(BO->RHS);
8116 return getURemExpr(LHS, RHS);
8117 case Instruction::Sub: {
8119 if (BO->Op)
8120 Flags = getNoWrapFlagsFromUB(BO->Op);
8121 LHS = getSCEV(BO->LHS);
8122 RHS = getSCEV(BO->RHS);
8123 return getMinusSCEV(LHS, RHS, Flags);
8124 }
8125 case Instruction::And:
8126 // For an expression like x&255 that merely masks off the high bits,
8127 // use zext(trunc(x)) as the SCEV expression.
8128 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8129 if (CI->isZero())
8130 return getSCEV(BO->RHS);
8131 if (CI->isMinusOne())
8132 return getSCEV(BO->LHS);
8133 const APInt &A = CI->getValue();
8134
8135 // Instcombine's ShrinkDemandedConstant may strip bits out of
8136 // constants, obscuring what would otherwise be a low-bits mask.
8137 // Use computeKnownBits to compute what ShrinkDemandedConstant
8138 // knew about to reconstruct a low-bits mask value.
8139 unsigned LZ = A.countl_zero();
8140 unsigned TZ = A.countr_zero();
8141 unsigned BitWidth = A.getBitWidth();
8142 KnownBits Known(BitWidth);
8143 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8144
8145 APInt EffectiveMask =
8146 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8147 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8148 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8149 const SCEV *LHS = getSCEV(BO->LHS);
8150 const SCEV *ShiftedLHS = nullptr;
8151 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8152 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8153 // For an expression like (x * 8) & 8, simplify the multiply.
8154 unsigned MulZeros = OpC->getAPInt().countr_zero();
8155 unsigned GCD = std::min(MulZeros, TZ);
8156 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8158 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8159 append_range(MulOps, LHSMul->operands().drop_front());
8160 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8161 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8162 }
8163 }
8164 if (!ShiftedLHS)
8165 ShiftedLHS = getUDivExpr(LHS, MulCount);
8166 return getMulExpr(
8168 getTruncateExpr(ShiftedLHS,
8169 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8170 BO->LHS->getType()),
8171 MulCount);
8172 }
8173 }
8174 // Binary `and` is a bit-wise `umin`.
8175 if (BO->LHS->getType()->isIntegerTy(1)) {
8176 LHS = getSCEV(BO->LHS);
8177 RHS = getSCEV(BO->RHS);
8178 return getUMinExpr(LHS, RHS);
8179 }
8180 break;
8181
8182 case Instruction::Or:
8183 // Binary `or` is a bit-wise `umax`.
8184 if (BO->LHS->getType()->isIntegerTy(1)) {
8185 LHS = getSCEV(BO->LHS);
8186 RHS = getSCEV(BO->RHS);
8187 return getUMaxExpr(LHS, RHS);
8188 }
8189 break;
8190
8191 case Instruction::Xor:
8192 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8193 // If the RHS of xor is -1, then this is a not operation.
8194 if (CI->isMinusOne())
8195 return getNotSCEV(getSCEV(BO->LHS));
8196
8197 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8198 // This is a variant of the check for xor with -1, and it handles
8199 // the case where instcombine has trimmed non-demanded bits out
8200 // of an xor with -1.
8201 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8202 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8203 if (LBO->getOpcode() == Instruction::And &&
8204 LCI->getValue() == CI->getValue())
8205 if (const SCEVZeroExtendExpr *Z =
8207 Type *UTy = BO->LHS->getType();
8208 const SCEV *Z0 = Z->getOperand();
8209 Type *Z0Ty = Z0->getType();
8210 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8211
8212 // If C is a low-bits mask, the zero extend is serving to
8213 // mask off the high bits. Complement the operand and
8214 // re-apply the zext.
8215 if (CI->getValue().isMask(Z0TySize))
8216 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8217
8218 // If C is a single bit, it may be in the sign-bit position
8219 // before the zero-extend. In this case, represent the xor
8220 // using an add, which is equivalent, and re-apply the zext.
8221 APInt Trunc = CI->getValue().trunc(Z0TySize);
8222 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8223 Trunc.isSignMask())
8224 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8225 UTy);
8226 }
8227 }
8228 break;
8229
8230 case Instruction::Shl:
8231 // Turn shift left of a constant amount into a multiply.
8232 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8233 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8234
8235 // If the shift count is not less than the bitwidth, the result of
8236 // the shift is undefined. Don't try to analyze it, because the
8237 // resolution chosen here may differ from the resolution chosen in
8238 // other parts of the compiler.
8239 if (SA->getValue().uge(BitWidth))
8240 break;
8241
8242 // We can safely preserve the nuw flag in all cases. It's also safe to
8243 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8244 // requires special handling. It can be preserved as long as we're not
8245 // left shifting by bitwidth - 1.
8246 auto Flags = SCEV::FlagAnyWrap;
8247 if (BO->Op) {
8248 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8249 if (any(MulFlags & SCEV::FlagNSW) &&
8250 (any(MulFlags & SCEV::FlagNUW) ||
8251 SA->getValue().ult(BitWidth - 1)))
8253 if (any(MulFlags & SCEV::FlagNUW))
8255 }
8256
8257 ConstantInt *X = ConstantInt::get(
8258 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8259 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8260 }
8261 break;
8262
8263 case Instruction::AShr:
8264 // AShr X, C, where C is a constant.
8265 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8266 if (!CI)
8267 break;
8268
8269 Type *OuterTy = BO->LHS->getType();
8270 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8271 // If the shift count is not less than the bitwidth, the result of
8272 // the shift is undefined. Don't try to analyze it, because the
8273 // resolution chosen here may differ from the resolution chosen in
8274 // other parts of the compiler.
8275 if (CI->getValue().uge(BitWidth))
8276 break;
8277
8278 if (CI->isZero())
8279 return getSCEV(BO->LHS); // shift by zero --> noop
8280
8281 uint64_t AShrAmt = CI->getZExtValue();
8282 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8283
8284 Operator *L = dyn_cast<Operator>(BO->LHS);
8285 const SCEV *AddTruncateExpr = nullptr;
8286 ConstantInt *ShlAmtCI = nullptr;
8287 const SCEV *AddConstant = nullptr;
8288
8289 if (L && L->getOpcode() == Instruction::Add) {
8290 // X = Shl A, n
8291 // Y = Add X, c
8292 // Z = AShr Y, m
8293 // n, c and m are constants.
8294
8295 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8296 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8297 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8298 if (AddOperandCI) {
8299 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8300 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8301 // since we truncate to TruncTy, the AddConstant should be of the
8302 // same type, so create a new Constant with type same as TruncTy.
8303 // Also, the Add constant should be shifted right by AShr amount.
8304 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8305 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8306 // we model the expression as sext(add(trunc(A), c << n)), since the
8307 // sext(trunc) part is already handled below, we create a
8308 // AddExpr(TruncExp) which will be used later.
8309 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8310 }
8311 }
8312 } else if (L && L->getOpcode() == Instruction::Shl) {
8313 // X = Shl A, n
8314 // Y = AShr X, m
8315 // Both n and m are constant.
8316
8317 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8318 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8319 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8320 }
8321
8322 if (AddTruncateExpr && ShlAmtCI) {
8323 // We can merge the two given cases into a single SCEV statement,
8324 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8325 // a simpler case. The following code handles the two cases:
8326 //
8327 // 1) For a two-shift sext-inreg, i.e. n = m,
8328 // use sext(trunc(x)) as the SCEV expression.
8329 //
8330 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8331 // expression. We already checked that ShlAmt < BitWidth, so
8332 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8333 // ShlAmt - AShrAmt < Amt.
8334 const APInt &ShlAmt = ShlAmtCI->getValue();
8335 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8336 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8337 ShlAmtCI->getZExtValue() - AShrAmt);
8338 const SCEV *CompositeExpr =
8339 getMulExpr(AddTruncateExpr, getConstant(Mul));
8340 if (L->getOpcode() != Instruction::Shl)
8341 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8342
8343 return getSignExtendExpr(CompositeExpr, OuterTy);
8344 }
8345 }
8346 break;
8347 }
8348 }
8349
8350 switch (U->getOpcode()) {
8351 case Instruction::Trunc:
8352 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8353
8354 case Instruction::ZExt:
8355 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8356
8357 case Instruction::SExt:
8358 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8360 // The NSW flag of a subtract does not always survive the conversion to
8361 // A + (-1)*B. By pushing sign extension onto its operands we are much
8362 // more likely to preserve NSW and allow later AddRec optimisations.
8363 //
8364 // NOTE: This is effectively duplicating this logic from getSignExtend:
8365 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8366 // but by that point the NSW information has potentially been lost.
8367 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8368 Type *Ty = U->getType();
8369 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8370 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8371 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8372 }
8373 }
8374 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8375
8376 case Instruction::BitCast:
8377 // BitCasts are no-op casts so we just eliminate the cast.
8378 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8379 return getSCEV(U->getOperand(0));
8380 break;
8381
8382 case Instruction::PtrToAddr: {
8383 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8384 if (isa<SCEVCouldNotCompute>(IntOp))
8385 return getUnknown(V);
8386 return IntOp;
8387 }
8388
8389 case Instruction::PtrToInt: {
8390 // Pointer to integer cast is straight-forward, so do model it.
8391 const SCEV *Op = getSCEV(U->getOperand(0));
8392 Type *DstIntTy = U->getType();
8393 // But only if effective SCEV (integer) type is wide enough to represent
8394 // all possible pointer values.
8395 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8396 if (isa<SCEVCouldNotCompute>(IntOp))
8397 return getUnknown(V);
8398 return IntOp;
8399 }
8400 case Instruction::IntToPtr:
8401 // Just don't deal with inttoptr casts.
8402 return getUnknown(V);
8403
8404 case Instruction::SDiv:
8405 // If both operands are non-negative, this is just an udiv.
8406 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8407 isKnownNonNegative(getSCEV(U->getOperand(1))))
8408 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8409 break;
8410
8411 case Instruction::SRem:
8412 // If both operands are non-negative, this is just an urem.
8413 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8414 isKnownNonNegative(getSCEV(U->getOperand(1))))
8415 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8416 break;
8417
8418 case Instruction::GetElementPtr:
8419 return createNodeForGEP(cast<GEPOperator>(U));
8420
8421 case Instruction::PHI:
8422 return createNodeForPHI(cast<PHINode>(U));
8423
8424 case Instruction::Select:
8425 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8426 U->getOperand(2));
8427
8428 case Instruction::Call:
8429 case Instruction::Invoke:
8430 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8431 return getSCEV(RV);
8432
8433 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8434 switch (II->getIntrinsicID()) {
8435 case Intrinsic::abs:
8436 return getAbsExpr(
8437 getSCEV(II->getArgOperand(0)),
8438 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8439 case Intrinsic::umax:
8440 LHS = getSCEV(II->getArgOperand(0));
8441 RHS = getSCEV(II->getArgOperand(1));
8442 return getUMaxExpr(LHS, RHS);
8443 case Intrinsic::umin:
8444 LHS = getSCEV(II->getArgOperand(0));
8445 RHS = getSCEV(II->getArgOperand(1));
8446 return getUMinExpr(LHS, RHS);
8447 case Intrinsic::smax:
8448 LHS = getSCEV(II->getArgOperand(0));
8449 RHS = getSCEV(II->getArgOperand(1));
8450 return getSMaxExpr(LHS, RHS);
8451 case Intrinsic::smin:
8452 LHS = getSCEV(II->getArgOperand(0));
8453 RHS = getSCEV(II->getArgOperand(1));
8454 return getSMinExpr(LHS, RHS);
8455 case Intrinsic::usub_sat: {
8456 const SCEV *X = getSCEV(II->getArgOperand(0));
8457 const SCEV *Y = getSCEV(II->getArgOperand(1));
8458 const SCEV *ClampedY = getUMinExpr(X, Y);
8459 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8460 }
8461 case Intrinsic::uadd_sat: {
8462 const SCEV *X = getSCEV(II->getArgOperand(0));
8463 const SCEV *Y = getSCEV(II->getArgOperand(1));
8464 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8465 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8466 }
8467 case Intrinsic::start_loop_iterations:
8468 case Intrinsic::annotation:
8469 case Intrinsic::ptr_annotation:
8470 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8471 // just equivalent to the first operand for SCEV purposes.
8472 return getSCEV(II->getArgOperand(0));
8473 case Intrinsic::vscale:
8474 return getVScale(II->getType());
8475 default:
8476 break;
8477 }
8478 }
8479 break;
8480 }
8481
8482 return getUnknown(V);
8483}
8484
8485//===----------------------------------------------------------------------===//
8486// Iteration Count Computation Code
8487//
8488
// Body of the trip-count helper that takes only `ExitCount` (its signature
// line is not visible in this listing). A non-computable exit count is
// propagated as CouldNotCompute; otherwise the work is delegated to the
// three-argument form with a freshly chosen evaluation type and no loop.
8490 if (isa<SCEVCouldNotCompute>(ExitCount))
8491 return getCouldNotCompute();
8492
8493 auto *ExitCountType = ExitCount->getType();
8494 assert(ExitCountType->isIntegerTy());
     // Evaluate in an integer type exactly one bit wider than the exit count's
     // type; nullptr is passed for the loop argument of the delegate.
8495 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8496 1 + ExitCountType->getScalarSizeInBits());
8497 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8498}
8499
8501 Type *EvalTy,
8502 const Loop *L) {
     // Builds the trip count (exit count + 1) as an expression of type EvalTy.
     // L may be null; it is only consulted for the loop-entry guard query in
     // the lambda below. (The first line of the signature is not visible in
     // this listing.)
8503 if (isa<SCEVCouldNotCompute>(ExitCount))
8504 return getCouldNotCompute();
8505
8506 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8507 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8508
     // Answers whether ExitCount + 1 is known not to wrap in ExitCount's own
     // type: either the unsigned range excludes the all-ones value, or a loop
     // was supplied and its entry is guarded by `ExitCount != -1`.
8509 auto CanAddOneWithoutOverflow = [&]() {
8510 ConstantRange ExitCountRange =
8511 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8512 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8513 return true;
8514
8515 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8516 getMinusOne(ExitCount->getType()));
8517 };
8518
8519 // If we need to zero extend the backedge count, check if we can add one to
8520 // it prior to zero extending without overflow. Provided this is safe, it
8521 // allows better simplification of the +1.
8522 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8523 return getZeroExtendExpr(
8524 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8525
8526 // Get the total trip count from the count by adding 1. This may wrap.
     // Here the exit count is first truncated or zero-extended to EvalTy and
     // the 1 is added in EvalTy.
8527 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8528}
8529
8530static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8531 if (!ExitCount)
8532 return 0;
8533
8534 ConstantInt *ExitConst = ExitCount->getValue();
8535
8536 // Guard against huge trip counts.
8537 if (ExitConst->getValue().getActiveBits() > 32)
8538 return 0;
8539
8540 // In case of integer overflow, this returns 0, which is correct.
8541 return ((unsigned)ExitConst->getZExtValue()) + 1;
8542}
8543
// Only an exact backedge-taken count that is a SCEVConstant is usable here;
// dyn_cast yields null for anything else, and getConstantTripCount maps a
// null argument to 0.
8545 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8546 return getConstantTripCount(ExitCount);
8547}
8548
// Per-exit variant: derives the small constant trip count from the exit
// count of one specific exiting block of L. (One line of the signature is not
// visible in this listing.)
8549unsigned
8551 const BasicBlock *ExitingBlock) {
8552 assert(ExitingBlock && "Must pass a non-null exiting block!");
8553 assert(L->isLoopExiting(ExitingBlock) &&
8554 "Exiting block must actually branch out of the loop!");
     // A non-constant exit count becomes null here, which
     // getConstantTripCount reports as 0.
8555 const SCEVConstant *ExitCount =
8556 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8557 return getConstantTripCount(ExitCount);
8558}
8559
8561 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8562
     // When a predicate list is supplied, the predicated constant-max
     // backedge-taken count is used.
     // NOTE(review): the ':' branch of this conditional is on a line that is
     // not visible in this listing; confirm the unpredicated fallback there.
8563 const auto *MaxExitCount =
8564 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
     // A non-constant maximum turns into null and therefore into a result of 0.
8566 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8567}
8568
// Whole-loop variant (signature line not visible in this listing): combines
// the per-exit trip multiples of every exiting block of L with a GCD.
8570 SmallVector<BasicBlock *, 8> ExitingBlocks;
8571 L->getExitingBlocks(ExitingBlocks);
8572
     // Res stays empty until the first exiting block is seen, so that the
     // first multiple seeds the GCD.
8573 std::optional<unsigned> Res;
8574 for (auto *ExitingBB : ExitingBlocks) {
8575 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8576 if (!Res)
8577 Res = Multiple;
8578 Res = std::gcd(*Res, Multiple);
8579 }
     // With no exiting blocks at all the answer defaults to 1.
8580 return Res.value_or(1);
8581}
8582
8584 const SCEV *ExitCount) {
     // Computes a constant multiple of the trip count implied by ExitCount.
     // A non-computable exit count yields 1.
8585 if (isa<SCEVCouldNotCompute>(ExitCount))
8586 return 1;
8587
8588 // Get the trip count
     // Loop guards of L are applied to the exit count before the conversion.
8589 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8590
8591 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8592 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8593 // the greatest power of 2 divisor less than 2^32.
     // The shift amount is capped at 31 so that `1U << amount` stays within
     // 32 bits.
     // NOTE(review): other code in this file spells this query
     // countr_zero(); consider using the same spelling here.
8594 return Multiple.getActiveBits() > 32
8595 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8596 : (unsigned)Multiple.getZExtValue();
8597}
8598
8599/// Returns the largest constant divisor of the trip count of this loop as a
8600/// normal unsigned value, if possible. This means that the actual trip count is
8601/// always a multiple of the returned value (don't forget the trip count could
8602/// very well be zero as well!).
8603///
8604/// Returns 1 if the trip count is unknown or not guaranteed to be the
8605/// multiple of a constant (which is also the case if the trip count is simply
8606/// constant; use getSmallConstantTripCount for that case). Will also return 1
8607/// if the trip count is very large (>= 2^32).
8608///
8609/// As explained in the comments for getSmallConstantTripCount, this assumes
8610/// that control exits the loop via ExitingBlock.
8611unsigned
8613 const BasicBlock *ExitingBlock) {
8614 assert(ExitingBlock && "Must pass a non-null exiting block!");
8615 assert(L->isLoopExiting(ExitingBlock) &&
8616 "Exiting block must actually branch out of the loop!");
     // Look up this block's exit count and defer to the overload that works
     // on an exit-count expression.
8617 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8618 return getSmallConstantTripMultiple(L, ExitCount);
8619}
8620
8622 const BasicBlock *ExitingBlock,
8623 ExitCountKind Kind) {
     // Dispatches on Kind to the matching per-exit accessor of the
     // (unpredicated) backedge-taken info for L.
8624 switch (Kind) {
8625 case Exact:
8626 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8627 case SymbolicMaximum:
8628 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8629 case ConstantMaximum:
8630 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8631 };
     // Every case listed above returns; reaching this point means Kind held a
     // value not handled by the switch.
8632 llvm_unreachable("Invalid ExitCountKind!");
8633}
8634
8636 const Loop *L, const BasicBlock *ExitingBlock,
     // Predicated counterpart of the per-exit query: same dispatch on Kind,
     // but against the predicated backedge-taken info and forwarding
     // `Predicates` to each accessor.
     // NOTE(review): the lines declaring `Predicates` and `Kind` are not
     // visible in this listing.
8638 switch (Kind) {
8639 case Exact:
8640 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8641 Predicates);
8642 case SymbolicMaximum:
8643 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8644 Predicates);
8645 case ConstantMaximum:
8646 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8647 Predicates);
8648 };
8649 llvm_unreachable("Invalid ExitCountKind!");
8650}
8651
// Exact whole-loop count from the predicated info; the address of `Preds` is
// handed to getExact as its predicate output list.
8654 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8655}
8656
8658 ExitCountKind Kind) {
     // Whole-loop query: dispatches on Kind to the matching accessor of the
     // (unpredicated) backedge-taken info for L. Note that the constant-max
     // accessor is called without the loop argument.
8659 switch (Kind) {
8660 case Exact:
8661 return getBackedgeTakenInfo(L).getExact(L, this);
8662 case ConstantMaximum:
8663 return getBackedgeTakenInfo(L).getConstantMax(this);
8664 case SymbolicMaximum:
8665 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8666 };
8667 llvm_unreachable("Invalid ExitCountKind!");
8668}
8669
// Symbolic maximum from the predicated info; the address of `Preds` is passed
// as the predicate output list.
8672 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8673}
8674
// Constant maximum from the predicated info; the address of `Preds` is passed
// as the predicate output list.
8677 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8678}
8679
// Forwards to isConstantMaxOrZero on the (unpredicated) backedge-taken info
// for L.
8681 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8682}
8683
8684ScalarEvolution::BackedgeTakenInfo &
8685ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8686 auto &BTI = getBackedgeTakenInfo(L);
8687 if (BTI.hasFullInfo())
8688 return BTI;
8689
8690 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8691
8692 if (!Pair.second)
8693 return Pair.first->second;
8694
8695 BackedgeTakenInfo Result =
8696 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8697
8698 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8699}
8700
8701ScalarEvolution::BackedgeTakenInfo &
8702ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8703 // Initially insert an invalid entry for this loop. If the insertion
8704 // succeeds, proceed to actually compute a backedge-taken count and
8705 // update the value. The temporary CouldNotCompute value tells SCEV
8706 // code elsewhere that it shouldn't attempt to request a new
8707 // backedge-taken count, which could result in infinite recursion.
8708 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8709 BackedgeTakenCounts.try_emplace(L);
8710 if (!Pair.second)
8711 return Pair.first->second;
8712
8713 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8714 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8715 // must be cleared in this scope.
8716 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8717
8718 // Now that we know more about the trip count for this loop, forget any
8719 // existing SCEV values for PHI nodes in this loop since they are only
8720 // conservative estimates made without the benefit of trip count
8721 // information. This invalidation is not necessary for correctness, and is
8722 // only done to produce more precise results.
8723 if (Result.hasAnyInfo()) {
8724 // Invalidate any expression using an addrec in this loop.
8725 SmallVector<SCEVUse, 8> ToForget;
8726 auto LoopUsersIt = LoopUsers.find(L);
8727 if (LoopUsersIt != LoopUsers.end())
8728 append_range(ToForget, LoopUsersIt->second);
8729 forgetMemoizedResults(ToForget);
8730
8731 // Invalidate constant-evolved loop header phis.
8732 for (PHINode &PN : L->getHeader()->phis())
8733 ConstantEvolutionLoopExitValue.erase(&PN);
8734 }
8735
8736 // Re-lookup the insert position, since the call to
8737 // computeBackedgeTakenCount above could result in a
8738 // recusive call to getBackedgeTakenInfo (on a different
8739 // loop), which would invalidate the iterator computed
8740 // earlier.
8741 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8742}
8743
8745 // This method is intended to forget all info about loops. It should
8746 // invalidate caches as if the following happened:
8747 // - The trip counts of all loops have changed arbitrarily
8748 // - Every llvm::Value has been updated in place to produce a different
8749 // result.
     // Each cache below is emptied wholesale; nothing is pruned selectively.
8750 BackedgeTakenCounts.clear();
8751 PredicatedBackedgeTakenCounts.clear();
8752 BECountUsers.clear();
8753 LoopPropertiesCache.clear();
8754 ConstantEvolutionLoopExitValue.clear();
8755 ValueExprMap.clear();
8756 ValuesAtScopes.clear();
8757 ValuesAtScopesUsers.clear();
8758 LoopDispositions.clear();
8759 BlockDispositions.clear();
8760 UnsignedRanges.clear();
8761 SignedRanges.clear();
8762 ExprValueMap.clear();
8763 HasRecMap.clear();
8764 ConstantMultipleCache.clear();
8765 PredicatedSCEVRewrites.clear();
8766 FoldCache.clear();
8767 FoldCacheUser.clear();
8768}
// Drains Worklist, and for every visited instruction that has a cached SCEV,
// records that SCEV in ToForget and removes the value's cache entries.
// (The parameter lines for `Worklist` and `Visited` are not visible in this
// listing.)
8769void ScalarEvolution::visitAndClearUsers(
8772 SmallVectorImpl<SCEVUse> &ToForget) {
8773 while (!Worklist.empty()) {
8774 Instruction *I = Worklist.pop_back_val();
     // Instructions that are neither SCEVable nor a WithOverflowInst are
     // skipped, and their users are not followed.
8775 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8776 continue;
8777
     // NOTE(review): the declaration of `It` sits on a line that is not
     // visible in this listing; it receives the result of this lookup.
8779 ValueExprMap.find_as(static_cast<Value *>(I));
8780 if (It != ValueExprMap.end()) {
8781 ToForget.push_back(It->second);
8782 eraseValueFromMap(It->first);
     // A PHI additionally loses its cached constant-evolution exit value.
8783 if (PHINode *PN = dyn_cast<PHINode>(I))
8784 ConstantEvolutionLoopExitValue.erase(PN);
8785 }
8786
     // Continue with the users of I, whether or not I itself had an entry.
8787 PushDefUseChildren(I, Worklist, Visited);
8788 }
8789}
8790
// Forgets cached information for L and, via the worklist, for every loop
// nested inside it. SCEVs to invalidate are collected first and forgotten in
// one call at the end. (The signature line is not visible in this listing.)
8792 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8793 SmallVector<SCEVUse, 16> ToForget;
8794
8795 // Iterate over all the loops and sub-loops to drop SCEV information.
8796 while (!LoopWorklist.empty()) {
8797 auto *CurrL = LoopWorklist.pop_back_val();
8798
8799 // Drop any stored trip count value.
     // Both the unpredicated and the predicated caches are cleared.
8800 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8801 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8802
8803 // Drop information about predicated SCEV rewrites for this loop.
8804 PredicatedSCEVRewrites.remove_if(
8805 [&](const auto &Entry) { return Entry.first.second == CurrL; });
8806
     // Everything registered as a user of this loop is queued for forgetting.
8807 auto LoopUsersItr = LoopUsers.find(CurrL);
8808 if (LoopUsersItr != LoopUsers.end())
8809 llvm::append_range(ToForget, LoopUsersItr->second);
8810
8811 // Drop information about expressions based on loop-header PHIs.
8812 for (PHINode &PN : CurrL->getHeader()->phis()) {
8813 ConstantEvolutionLoopExitValue.erase(&PN);
8814 auto VIt = ValueExprMap.find_as(static_cast<Value *>(&PN));
8815 if (VIt != ValueExprMap.end())
8816 ToForget.push_back(VIt->second);
8817 }
8818
8819 LoopPropertiesCache.erase(CurrL);
8820 // Forget all contained loops too, to avoid dangling entries in the
8821 // ValuesAtScopes map.
8822 LoopWorklist.append(CurrL->begin(), CurrL->end());
8823 }
8824 forgetMemoizedResults(ToForget);
8825}
8826
// Delegates to forgetLoop on the outermost loop enclosing L; forgetLoop
// itself also walks the loops contained in the one it is given.
8828 forgetLoop(L->getOutermostLoop());
8829}
8830
// NOTE(review): the signature and the lines that declare `I`, `Worklist` and
// `Visited` are not visible in this listing. Nothing is done when `I` is null.
8833 if (!I) return;
8834
8835 // Drop information about expressions based on loop-header PHIs.
     // Seed the traversal with I itself; visitAndClearUsers then erases the
     // cached entries it finds, follows users, and fills ToForget.
8838 SmallVector<SCEVUse, 8> ToForget;
8839 Worklist.push_back(I);
8840 Visited.insert(I);
8841 visitAndClearUsers(Worklist, Visited, ToForget);
8842
8843 forgetMemoizedResults(ToForget);
8844}
8845
// Invalidation for a value V in the context of loop L (the signature line is
// not visible in this listing). Values whose type is not SCEVable are ignored.
8847 if (!isSCEVable(V->getType()))
8848 return;
8849
8850 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8851 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8852 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8853 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8854 if (const SCEV *S = getExistingSCEV(V)) {
     // Visitor for visitAll: collects into Roots every SCEVUnknown whose
     // value is an instruction inside L and every addrec whose loop is
     // contained in L.
     // NOTE(review): the declaration of the `Roots` member is on a line not
     // visible in this listing.
8855 struct InvalidationRootCollector {
8856 Loop *L;
8858
8859 InvalidationRootCollector(Loop *L) : L(L) {}
8860
8861 bool follow(const SCEV *S) {
8862 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8863 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8864 if (L->contains(I))
8865 Roots.push_back(S);
8866 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8867 if (L->contains(AddRec->getLoop()))
8868 Roots.push_back(S);
8869 }
     // Always descend into operands, including those of a collected root.
8870 return true;
8871 }
     // The traversal is never cut short.
8872 bool isDone() const { return false; }
8873 };
8874
8875 InvalidationRootCollector C(L);
8876 visitAll(S, C);
8877 forgetMemoizedResults(C.Roots);
8878 }
8879
8880 // Also perform the normal invalidation.
8881 forgetValue(V);
8882}
8883
8884void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8885
8887 // Unless a specific value is passed to invalidation, completely clear both
8888 // caches.
8889 if (!V) {
8890 BlockDispositions.clear();
8891 LoopDispositions.clear();
8892 return;
8893 }
8894
     // Nothing is cached for values that are not SCEVable or that have no
     // existing SCEV, so both cases return without touching the caches.
8895 if (!isSCEVable(V->getType()))
8896 return;
8897
8898 const SCEV *S = getExistingSCEV(V);
8899 if (!S)
8900 return;
8901
8902 // Invalidate the block and loop dispositions cached for S. Dispositions of
8903 // S's users may change if S's disposition changes (i.e. a user may change to
8904 // loop-invariant, if S changes to loop invariant), so also invalidate
8905 // dispositions of S's users recursively.
     // NOTE(review): the declaration of `Seen` is on a line not visible in
     // this listing.
8906 SmallVector<SCEVUse, 8> Worklist = {S};
8908 while (!Worklist.empty()) {
8909 const SCEV *Curr = Worklist.pop_back_val();
8910 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8911 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
     // If neither cache held an entry for Curr, its users are not visited.
8912 if (!LoopDispoRemoved && !BlockDispoRemoved)
8913 continue;
8914 auto Users = SCEVUsers.find(Curr);
8915 if (Users != SCEVUsers.end())
8916 for (const auto *User : Users->second)
     // Seen ensures each user is pushed onto the worklist at most once.
8917 if (Seen.insert(User).second)
8918 Worklist.push_back(User);
8919 }
8920}
8921
8922/// Get the exact loop backedge taken count considering all loop exits. A
8923/// computable result can only be returned for loops with all exiting blocks
8924/// dominating the latch. howFarToZero assumes that the limit of each loop test
8925/// is never skipped. This is a valid assumption as long as the loop exits via
8926/// that test. For precise results, it is the caller's responsibility to specify
8927/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8928const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8929 const Loop *L, ScalarEvolution *SE,
     // NOTE(review): the parameter line declaring `Preds` is not visible in
     // this listing; the body treats it as a nullable pointer to a predicate
     // list.
8931 // If any exits were not computable, the loop is not computable.
8932 if (!isComplete() || ExitNotTaken.empty())
8933 return SE->getCouldNotCompute();
8934
8935 const BasicBlock *Latch = L->getLoopLatch();
8936 // All exiting blocks we have collected must dominate the only backedge.
8937 if (!Latch)
8938 return SE->getCouldNotCompute();
8939
8940 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8941 // count is simply a minimum out of all these calculated exit counts.
     // NOTE(review): the declaration of `Ops` is on a line not visible in
     // this listing.
8943 for (const auto &ENT : ExitNotTaken) {
8944 const SCEV *BECount = ENT.ExactNotTaken;
8945 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8946 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8947 "We should only have known counts for exiting blocks that dominate "
8948 "latch!");
8949
8950 Ops.push_back(BECount);
8951
     // When the caller supplied a list, each exit's predicates are added to
     // it; without a list every exit must have an always-true predicate.
8952 if (Preds)
8953 append_range(*Preds, ENT.Predicates);
8954
8955 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8956 "Predicate should be always true!");
8957 }
8958
8959 // If an earlier exit exits on the first iteration (exit count zero), then
8960 // a later poison exit count should not propagate into the result. These are
8961 // exactly the semantics provided by umin_seq.
8962 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8963}
8964
8965const ScalarEvolution::ExitNotTakenInfo *
8966ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8967 const BasicBlock *ExitingBlock,
8968 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8969 for (const auto &ENT : ExitNotTaken)
8970 if (ENT.ExitingBlock == ExitingBlock) {
8971 if (ENT.hasAlwaysTruePredicate())
8972 return &ENT;
8973 else if (Predicates) {
8974 append_range(*Predicates, ENT.Predicates);
8975 return &ENT;
8976 }
8977 }
8978
8979 return nullptr;
8980}
8981
8982/// getConstantMax - Get the constant max backedge taken count for the loop.
8983const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8984 ScalarEvolution *SE,
8985 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8986 if (!getConstantMax())
8987 return SE->getCouldNotCompute();
8988
8989 for (const auto &ENT : ExitNotTaken)
8990 if (!ENT.hasAlwaysTruePredicate()) {
8991 if (!Predicates)
8992 return SE->getCouldNotCompute();
8993 append_range(*Predicates, ENT.Predicates);
8994 }
8995
8996 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8997 isa<SCEVConstant>(getConstantMax())) &&
8998 "No point in having a non-constant max backedge taken count!");
8999 return getConstantMax();
9000}
9001
9002const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
9003 const Loop *L, ScalarEvolution *SE,
9004 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
9005 if (!SymbolicMax) {
9006 // Form an expression for the maximum exit count possible for this loop. We
9007 // merge the max and exact information to approximate a version of
9008 // getConstantMaxBackedgeTakenCount which isn't restricted to just
9009 // constants.
9010 SmallVector<SCEVUse, 4> ExitCounts;
9011
9012 for (const auto &ENT : ExitNotTaken) {
9013 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
9014 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
9015 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
9016 "We should only have known counts for exiting blocks that "
9017 "dominate latch!");
9018 ExitCounts.push_back(ExitCount);
9019 if (Predicates)
9020 append_range(*Predicates, ENT.Predicates);
9021
9022 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9023 "Predicate should be always true!");
9024 }
9025 }
9026 if (ExitCounts.empty())
9027 SymbolicMax = SE->getCouldNotCompute();
9028 else
9029 SymbolicMax =
9030 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9031 }
9032 return SymbolicMax;
9033}
9034
9035bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9036 ScalarEvolution *SE) const {
9037 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9038 return !ENT.hasAlwaysTruePredicate();
9039 };
9040 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9041}
9042
9045
 // NOTE(review): the first signature line of this constructor (and listing
 // lines 9049-9051, 9060-9061, 9063-9064, 9066-9067, 9069-9070, 9072 and 9083)
 // are absent from this excerpt. Judging by the five-argument ExitLimit(...)
 // call elsewhere in this file, this is presumably the ExitLimit constructor
 // taking (E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero, PredLists)
 // -- confirm against the full source.
9047 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9048 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9052 // If we prove the max count is zero, so is the symbolic bound. This happens
9053 // in practice due to differences in a) how context sensitive we've chosen
9054 // to be and b) how we reason about bounds implied by UB.
9055 if (ConstantMaxNotTaken->isZero()) {
 // Overwrite both the members and the local parameter copies so that the
 // remaining checks in this constructor see the tightened values.
9056 this->ExactNotTaken = E = ConstantMaxNotTaken;
9057 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9058 }
9059
 // The next four lines are the message tails of assertions whose conditions
 // are on listing lines missing from this excerpt.
9062 "Exact is not allowed to be less precise than Constant Max");
9065 "Exact is not allowed to be less precise than Symbolic Max");
9068 "Symbolic Max is not allowed to be less precise than Constant Max");
9071 "No point in having a non-constant max backedge taken count!");
 // Flatten all predicate lists into Predicates, skipping any predicate that
 // has already been recorded in SeenPreds.
9073 for (const auto PredList : PredLists)
9074 for (const auto *P : PredList) {
9075 if (SeenPreds.contains(P))
9076 continue;
9077 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9078 SeenPreds.insert(P);
9079 Predicates.push_back(P);
9080 }
 // Counts must not be pointer-typed unless they are CouldNotCompute.
9081 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9082 "Backedge count should be int");
9084 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9085 "Max backedge count should be int");
9086}
9087
9095
9096/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9097/// computable exit into a persistent ExitNotTakenInfo array.
///
/// Stores \p ConstantMax, \p IsComplete and \p MaxOrZero in the members of the
/// same name; \p ConstantMax must be either CouldNotCompute or a constant.
9098ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
 // NOTE(review): listing line 9099 is absent from this excerpt; it presumably
 // declares the `ExitCounts` parameter used below -- confirm.
9100 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9101 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9102 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9103
 // Reserve up front, then convert each (exiting block, ExitLimit) pair into
 // an ExitNotTakenInfo appended to ExitNotTaken.
9104 ExitNotTaken.reserve(ExitCounts.size());
9105 std::transform(ExitCounts.begin(), ExitCounts.end(),
9106 std::back_inserter(ExitNotTaken),
9107 [&](const EdgeExitInfo &EEI) {
9108 BasicBlock *ExitBB = EEI.first;
9109 const ExitLimit &EL = EEI.second;
9110 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9111 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9112 EL.Predicates);
9113 });
9114 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9115 isa<SCEVConstant>(ConstantMax)) &&
9116 "No point in having a non-constant max backedge taken count!");
9117}
9118
9119/// Compute the number of times the backedge of the specified loop will execute.
///
/// Every exiting block of \p L is analysed with computeExitLimit. Exits with a
/// known symbolic maximum are collected into ExitCounts, a constant maximum
/// for the whole loop is derived from the per-exit constant maxima, and the
/// result is packaged as a BackedgeTakenInfo. \p AllowPredicates is forwarded
/// to computeExitLimit and recorded in BECountUsers.
9120ScalarEvolution::BackedgeTakenInfo
9121ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9122 bool AllowPredicates) {
9123 SmallVector<BasicBlock *, 8> ExitingBlocks;
9124 L->getExitingBlocks(ExitingBlocks);
9125
9126 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9127
 // NOTE(review): listing line 9128 is absent from this excerpt; it presumably
 // declares the `ExitCounts` container filled below -- confirm.
9129 bool CouldComputeBECount = true;
9130 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
 // Running constant maxima: the "must" value is a umin over computable exits
 // that dominate the latch, the "may" value a umax over all remaining exits.
 // nullptr means "no exit of that kind seen yet".
9131 const SCEV *MustExitMaxBECount = nullptr;
9132 const SCEV *MayExitMaxBECount = nullptr;
9133 bool MustExitMaxOrZero = false;
9134 bool IsOnlyExit = ExitingBlocks.size() == 1;
9135
9136 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9137 // and compute maxBECount.
9138 // Do a union of all the predicates here.
9139 for (BasicBlock *ExitBB : ExitingBlocks) {
9140 // We canonicalize untaken exits to br (constant), ignore them so that
9141 // proving an exit untaken doesn't negatively impact our ability to reason
9142 // about the loop as whole.
9143 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9144 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9145 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
 // The constant condition always selects the in-loop successor.
9146 if (ExitIfTrue == CI->isZero())
9147 continue;
9148 }
9149
9150 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9151
9152 assert((AllowPredicates || EL.Predicates.empty()) &&
9153 "Predicated exit limit when predicates are not allowed!");
9154
9155 // 1. For each exit that can be computed, add an entry to ExitCounts.
9156 // CouldComputeBECount is true only if all exits can be computed.
9157 if (EL.ExactNotTaken != getCouldNotCompute())
9158 ++NumExitCountsComputed;
9159 else
9160 // We couldn't compute an exact value for this exit, so
9161 // we won't be able to compute an exact value for the loop.
9162 CouldComputeBECount = false;
9163 // Remember exit count if either exact or symbolic is known. Because
9164 // Exact always implies symbolic, only check symbolic.
9165 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9166 ExitCounts.emplace_back(ExitBB, EL);
9167 else {
9168 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9169 "Exact is known but symbolic isn't?");
9170 ++NumExitCountsNotComputed;
9171 }
9172
9173 // 2. Derive the loop's MaxBECount from each exit's max number of
9174 // non-exiting iterations. Partition the loop exits into two kinds:
9175 // LoopMustExits and LoopMayExits.
9176 //
9177 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9178 // is a LoopMayExit. If any computable LoopMustExit is found, then
9179 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9180 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9181 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9182 // any
9183 // computable EL.ConstantMaxNotTaken.
9184 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9185 DT.dominates(ExitBB, Latch)) {
9186 if (!MustExitMaxBECount) {
9187 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9188 MustExitMaxOrZero = EL.MaxOrZero;
9189 } else {
9190 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9191 EL.ConstantMaxNotTaken);
9192 }
 // Once the "may exit" bound has become CouldNotCompute it stays that way.
9193 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9194 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9195 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9196 else {
9197 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9198 EL.ConstantMaxNotTaken);
9199 }
9200 }
9201 }
 // Prefer the "must exit" bound; fall back to the "may exit" bound, and to
 // CouldNotCompute when neither was established.
9202 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9203 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9204 // The loop backedge will be taken the maximum or zero times if there's
9205 // a single exit that must be taken the maximum or zero times.
9206 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9207
9208 // Remember which SCEVs are used in exit limits for invalidation purposes.
9209 // We only care about non-constant SCEVs here, so we can ignore
9210 // EL.ConstantMaxNotTaken
9211 // and MaxBECount, which must be SCEVConstant.
9212 for (const auto &Pair : ExitCounts) {
9213 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9214 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9215 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9216 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9217 {L, AllowPredicates});
9218 }
9219 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9220 MaxBECount, MaxOrZero);
9221}
9222
9223ScalarEvolution::ExitLimit
9224ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9225 bool IsOnlyExit, bool AllowPredicates) {
9226 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9227 // If our exiting block does not dominate the latch, then its connection with
9228 // loop's exit limit may be far from trivial.
9229 const BasicBlock *Latch = L->getLoopLatch();
9230 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9231 return getCouldNotCompute();
9232
9233 Instruction *Term = ExitingBlock->getTerminator();
9234 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9235 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9236 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9237 "It should have one successor in loop and one exit block!");
9238 // Proceed to the next level to examine the exit condition expression.
9239 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9240 /*ControlsOnlyExit=*/IsOnlyExit,
9241 AllowPredicates);
9242 }
9243
9244 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9245 // For switch, make sure that there is a single exit from the loop.
9246 BasicBlock *Exit = nullptr;
9247 for (auto *SBB : successors(ExitingBlock))
9248 if (!L->contains(SBB)) {
9249 if (Exit) // Multiple exit successors.
9250 return getCouldNotCompute();
9251 Exit = SBB;
9252 }
9253 assert(Exit && "Exiting block must have at least one exit");
9254 return computeExitLimitFromSingleExitSwitch(
9255 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9256 }
9257
9258 return getCouldNotCompute();
9259}
9260
 // NOTE(review): the first signature line of this function (listing line
 // 9261) is absent from this excerpt; judging by the call in
 // computeExitLimit this is presumably computeExitLimitFromCond -- confirm.
 //
 // Entry point for analysing an exit condition: builds a fresh ExitLimitCache
 // keyed on (L, ExitIfTrue, AllowPredicates) and delegates to the cached
 // implementation.
9262 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9263 bool AllowPredicates) {
9264 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9265 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9266 ControlsOnlyExit, AllowPredicates);
9267}
9268
9269std::optional<ScalarEvolution::ExitLimit>
9270ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9271 bool ExitIfTrue, bool ControlsOnlyExit,
9272 bool AllowPredicates) {
9273 (void)this->L;
9274 (void)this->ExitIfTrue;
9275 (void)this->AllowPredicates;
9276
9277 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9278 this->AllowPredicates == AllowPredicates &&
9279 "Variance in assumed invariant key components!");
9280 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9281 if (Itr == TripCountMap.end())
9282 return std::nullopt;
9283 return Itr->second;
9284}
9285
9286void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9287 bool ExitIfTrue,
9288 bool ControlsOnlyExit,
9289 bool AllowPredicates,
9290 const ExitLimit &EL) {
9291 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9292 this->AllowPredicates == AllowPredicates &&
9293 "Variance in assumed invariant key components!");
9294
9295 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9296 assert(InsertResult.second && "Expected successful insertion!");
9297 (void)InsertResult;
9298 (void)ExitIfTrue;
9299}
9300
9301ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9302 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9303 bool ControlsOnlyExit, bool AllowPredicates) {
9304
9305 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9306 AllowPredicates))
9307 return *MaybeEL;
9308
9309 ExitLimit EL = computeExitLimitFromCondImpl(
9310 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9311 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9312 return EL;
9313}
9314
/// Uncached worker behind computeExitLimitFromCondCached. Tries, in order:
/// logical and/or conditions, icmp conditions, constant conditions, the
/// overflow bit of a with.overflow intrinsic with a constant right-hand side,
/// and finally computeExitCountExhaustively.
9315ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9316 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9317 bool ControlsOnlyExit, bool AllowPredicates) {
9318 // Handle BinOp conditions (And, Or).
9319 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9320 Cache, L, ExitCond, ExitIfTrue, AllowPredicates))
9321 return *LimitFromBinOp;
9322
9323 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9324 // Proceed to the next level to examine the icmp.
9325 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
 // First attempt: AllowPredicates is not passed here (presumably it
 // defaults to false -- confirm against the declaration).
9326 ExitLimit EL =
9327 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9328 if (EL.hasFullInfo() || !AllowPredicates)
9329 return EL;
9330
9331 // Try again, but use SCEV predicates this time.
9332 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9333 ControlsOnlyExit,
9334 /*AllowPredicates=*/true);
9335 }
9336
9337 // Check for a constant condition. These are normally stripped out by
9338 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9339 // preserve the CFG and is temporarily leaving constant conditions
9340 // in place.
9341 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9342 if (ExitIfTrue == !CI->getZExtValue())
9343 // The backedge is always taken.
9344 return getCouldNotCompute();
9345 // The backedge is never taken.
9346 return getZero(CI->getType());
9347 }
9348
9349 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9350 // with a constant step, we can form an equivalent icmp predicate and figure
9351 // out how many iterations will be taken before we exit.
9352 const WithOverflowInst *WO;
9353 const APInt *C;
9354 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9355 match(WO->getRHS(), m_APInt(C))) {
 // NOTE(review): listing line 9357 (the start of the NWR initializer) is
 // absent from this excerpt.
9356 ConstantRange NWR =
9358 WO->getNoWrapKind());
9359 CmpInst::Predicate Pred;
9360 APInt NewRHSC, Offset;
9361 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9362 if (!ExitIfTrue)
9363 Pred = ICmpInst::getInversePredicate(Pred);
9364 auto *LHS = getSCEV(WO->getLHS());
 // NOTE(review): listing line 9366 (the statement guarded by this `if`) is
 // absent from this excerpt; presumably it folds Offset into LHS -- confirm.
9365 if (Offset != 0)
9367 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9368 ControlsOnlyExit, AllowPredicates);
9369 if (EL.hasAnyInfo())
9370 return EL;
9371 }
9372
9373 // If it's not an integer or pointer comparison then compute it the hard way.
9374 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9375}
9376
9377std::optional<ScalarEvolution::ExitLimit>
9378ScalarEvolution::computeExitLimitFromCondFromBinOp(ExitLimitCacheTy &Cache,
9379 const Loop *L,
9380 Value *ExitCond,
9381 bool ExitIfTrue,
9382 bool AllowPredicates) {
9383 // Check if the controlling expression for this loop is an And or Or.
9384 Value *Op0, *Op1;
9385 bool IsAnd;
9386 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9387 IsAnd = true;
9388 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9389 IsAnd = false;
9390 else
9391 return std::nullopt;
9392
9393 // A sub-condition of a non-trivial binop never solely controls the exit,
9394 // whether we exit always depends on both conditions.
9395 ExitLimit EL0 = computeExitLimitFromCondCached(
9396 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9397 ExitLimit EL1 = computeExitLimitFromCondCached(
9398 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9399
9400 // EitherMayExit is true in these two cases:
9401 // br (and Op0 Op1), loop, exit
9402 // br (or Op0 Op1), exit, loop
9403 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9404
9405 const SCEV *BECount = getCouldNotCompute();
9406 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9407 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9408 if (EitherMayExit) {
9409 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9410 // Both conditions must be same for the loop to continue executing.
9411 // Choose the less conservative count.
9412 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9413 EL1.ExactNotTaken != getCouldNotCompute()) {
9414 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9415 UseSequentialUMin);
9416 }
9417 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9418 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9419 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9420 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9421 else
9422 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9423 EL1.ConstantMaxNotTaken);
9424 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9425 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9426 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9427 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9428 else
9429 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9430 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9431 } else {
9432 // Both conditions must be same at the same time for the loop to exit.
9433 // For now, be conservative.
9434 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9435 BECount = EL0.ExactNotTaken;
9436 }
9437
9438 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9439 // to be more aggressive when computing BECount than when computing
9440 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9441 // and
9442 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9443 // EL1.ConstantMaxNotTaken to not.
9444 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9445 !isa<SCEVCouldNotCompute>(BECount))
9446 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9447 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9448 SymbolicMaxBECount =
9449 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9450 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9451 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9452}
9453
9454ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9455 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9456 bool AllowPredicates) {
9457 // If the condition was exit on true, convert the condition to exit on false
9458 CmpPredicate Pred;
9459 if (!ExitIfTrue)
9460 Pred = ExitCond->getCmpPredicate();
9461 else
9462 Pred = ExitCond->getInverseCmpPredicate();
9463 const ICmpInst::Predicate OriginalPred = Pred;
9464
9465 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9466 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9467
9468 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9469 AllowPredicates);
9470 if (EL.hasAnyInfo())
9471 return EL;
9472
9473 auto *ExhaustiveCount =
9474 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9475
9476 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9477 return ExhaustiveCount;
9478
9479 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9480 ExitCond->getOperand(1), L, OriginalPred);
9481}
/// SCEV-level worker for icmp-controlled exits: compute how many times the
/// backedge is taken while "\p LHS \p Pred \p RHS" holds in loop \p L.
///
/// Operands are evaluated at the scope of \p L and simplified, no-wrap flags
/// on the induction variable may be strengthened when this condition controls
/// a finite loop, and the query is then dispatched on \p Pred to
/// howFarToZero, howFarToNonZero, howManyLessThans or howManyGreaterThans.
/// CouldNotCompute is returned when none of them yields any information.
///
/// NOTE(review): several listing lines (9495, 9499, 9563-9564, 9568-9569,
/// 9581-9582, 9586-9587, 9618, 9636) are absent from this excerpt; the
/// affected spots are marked below.
9482ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9483 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9484 bool ControlsOnlyExit, bool AllowPredicates) {
9485
9486 // Try to evaluate any dependencies out of the loop.
9487 LHS = getSCEVAtScope(LHS, L);
9488 RHS = getSCEVAtScope(RHS, L);
9489
9490 // At this point, we would like to compute how many iterations of the
9491 // loop the predicate will return true for these inputs.
9492 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9493 // If there is a loop-invariant, force it into the RHS.
9494 std::swap(LHS, RHS);
 // NOTE(review): listing line 9495 is missing here; presumably it adjusts
 // Pred to match the operand swap -- confirm.
9496 }
9497
 // NOTE(review): the rest of this initializer (listing line 9499) is missing.
9498 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9500 // Simplify the operands before analyzing them.
9501 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9502
9503 // If we have a comparison of a chrec against a constant, try to use value
9504 // ranges to answer this query.
9505 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9506 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9507 if (AddRec->getLoop() == L) {
9508 // Form the constant range.
9509 ConstantRange CompRange =
9510 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9511
9512 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9513 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9514 }
9515
9516 // If this loop must exit based on this condition (or execute undefined
9517 // behaviour), see if we can improve wrap flags. This is essentially
9518 // a must execute style proof.
9519 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9520 // If we can prove the test sequence produced must repeat the same values
9521 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9522 // because if it did, we'd have an infinite (undefined) loop.
9523 // TODO: We can peel off any functions which are invertible *in L*. Loop
9524 // invariant terms are effectively constants for our purposes here.
9525 SCEVUse InnerLHS = LHS;
9526 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9527 InnerLHS = ZExt->getOperand();
9528 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9529 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9530 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9531 /*OrNegative=*/true)) {
 // Record the newly proven no-self-wrap fact on the addrec itself.
9532 auto Flags = AR->getNoWrapFlags();
9533 Flags = setFlags(Flags, SCEV::FlagNW);
9534 SmallVector<SCEVUse> Operands{AR->operands()};
9535 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9536 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9537 }
9538
9539 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9540 // From no-self-wrap, this follows trivially from the fact that every
9541 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9542 // last value before (un)signed wrap. Since we know that last value
9543 // didn't exit, nor will any smaller one.
9544 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9545 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9546 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9547 AR && AR->getLoop() == L && AR->isAffine() &&
9548 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9549 isKnownPositive(AR->getStepRecurrence(*this))) {
9550 auto Flags = AR->getNoWrapFlags();
9551 Flags = setFlags(Flags, WrapType);
9552 SmallVector<SCEVUse> Operands{AR->operands()};
9553 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9554 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9555 }
9556 }
9557 }
9558
9559 switch (Pred) {
9560 case ICmpInst::ICMP_NE: { // while (X != Y)
9561 // Convert to: while (X-Y != 0)
 // NOTE(review): in each pointer-typed branch below, the two listing lines
 // between the `if` and the `return` are missing from this excerpt.
9562 if (LHS->getType()->isPointerTy()) {
9565 return LHS;
9566 }
9567 if (RHS->getType()->isPointerTy()) {
9570 return RHS;
9571 }
9572 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9573 AllowPredicates);
9574 if (EL.hasAnyInfo())
9575 return EL;
9576 break;
9577 }
9578 case ICmpInst::ICMP_EQ: { // while (X == Y)
9579 // Convert to: while (X-Y == 0)
9580 if (LHS->getType()->isPointerTy()) {
9583 return LHS;
9584 }
9585 if (RHS->getType()->isPointerTy()) {
9588 return RHS;
9589 }
9590 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9591 if (EL.hasAnyInfo()) return EL;
9592 break;
9593 }
9594 case ICmpInst::ICMP_SLE:
9595 case ICmpInst::ICMP_ULE:
9596 // Since the loop is finite, an invariant RHS cannot include the boundary
9597 // value, otherwise it would loop forever.
9598 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9599 !isLoopInvariant(RHS, L)) {
9600 // Otherwise, perform the addition in a wider type, to avoid overflow.
9601 // If the LHS is an addrec with the appropriate nowrap flag, the
9602 // extension will be sunk into it and the exit count can be analyzed.
9603 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9604 if (!OldType)
9605 break;
9606 // Prefer doubling the bitwidth over adding a single bit to make it more
9607 // likely that we use a legal type.
9608 auto *NewType =
9609 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9610 if (ICmpInst::isSigned(Pred)) {
9611 LHS = getSignExtendExpr(LHS, NewType);
9612 RHS = getSignExtendExpr(RHS, NewType);
9613 } else {
9614 LHS = getZeroExtendExpr(LHS, NewType);
9615 RHS = getZeroExtendExpr(RHS, NewType);
9616 }
9617 }
 // NOTE(review): listing line 9618 is missing here; presumably it turns the
 // inclusive bound into an exclusive one before falling through -- confirm.
9619 [[fallthrough]];
9620 case ICmpInst::ICMP_SLT:
9621 case ICmpInst::ICMP_ULT: { // while (X < Y)
9622 bool IsSigned = ICmpInst::isSigned(Pred);
9623 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9624 AllowPredicates);
9625 if (EL.hasAnyInfo())
9626 return EL;
9627 break;
9628 }
9629 case ICmpInst::ICMP_SGE:
9630 case ICmpInst::ICMP_UGE:
9631 // Since the loop is finite, an invariant RHS cannot include the boundary
9632 // value, otherwise it would loop forever.
9633 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9634 !isLoopInvariant(RHS, L))
9635 break;
 // NOTE(review): listing line 9636 is missing here; presumably it adjusts RHS
 // for the inclusive bound before falling through -- confirm.
9637 [[fallthrough]];
9638 case ICmpInst::ICMP_SGT:
9639 case ICmpInst::ICMP_UGT: { // while (X > Y)
9640 bool IsSigned = ICmpInst::isSigned(Pred);
9641 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9642 AllowPredicates);
9643 if (EL.hasAnyInfo())
9644 return EL;
9645 break;
9646 }
9647 default:
9648 break;
9649 }
9650
9651 return getCouldNotCompute();
9652}
9653
9654ScalarEvolution::ExitLimit
9655ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9656 SwitchInst *Switch,
9657 BasicBlock *ExitingBlock,
9658 bool ControlsOnlyExit) {
9659 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9660
9661 // Give up if the exit is the default dest of a switch.
9662 if (Switch->getDefaultDest() == ExitingBlock)
9663 return getCouldNotCompute();
9664
9665 assert(L->contains(Switch->getDefaultDest()) &&
9666 "Default case must not exit the loop!");
9667 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9668 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9669
9670 // while (X != Y) --> while (X-Y != 0)
9671 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9672 if (EL.hasAnyInfo())
9673 return EL;
9674
9675 return getCouldNotCompute();
9676}
9677
/// File-local helper: evaluate the recurrence \c AddRec at the constant
/// iteration number \c C and return the resulting constant.
///
/// NOTE(review): listing lines 9679 (function name and leading parameters)
/// and 9683 (the start of the assertion) are absent from this excerpt; the
/// result is unconditionally cast to SCEVConstant, so callers presumably only
/// pass recurrences that fold to a constant -- confirm.
9678static ConstantInt *
9680 ScalarEvolution &SE) {
9681 const SCEV *InVal = SE.getConstant(C);
9682 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9684 "Evaluation of SCEV at constant didn't fold correctly?");
9685 return cast<SCEVConstant>(Val)->getValue();
9686}
9687
9688ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9689 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9690 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9691 if (!RHS)
9692 return getCouldNotCompute();
9693
9694 const BasicBlock *Latch = L->getLoopLatch();
9695 if (!Latch)
9696 return getCouldNotCompute();
9697
9698 const BasicBlock *Predecessor = L->getLoopPredecessor();
9699 if (!Predecessor)
9700 return getCouldNotCompute();
9701
9702 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9703 // Return LHS in OutLHS, shift_op in OutOpCode, and the shift amount in
9704 // OutShiftAmt.
9705 auto MatchPositiveShift = [](Value *V, Value *&OutLHS,
9706 Instruction::BinaryOps &OutOpCode,
9707 unsigned &OutShiftAmt) {
9708 using namespace PatternMatch;
9709
9710 ConstantInt *ShiftAmt;
9711 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9712 OutOpCode = Instruction::LShr;
9713 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9714 OutOpCode = Instruction::AShr;
9715 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9716 OutOpCode = Instruction::Shl;
9717 else
9718 return false;
9719
9720 uint64_t Amt = ShiftAmt->getValue().getLimitedValue();
9721 if (Amt == 0 || Amt >= OutLHS->getType()->getScalarSizeInBits())
9722 return false;
9723 OutShiftAmt = Amt;
9724 return true;
9725 };
9726
9727 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9728 //
9729 // loop:
9730 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9731 // %iv.shifted = lshr i32 %iv, <positive constant>
9732 //
9733 // Return true on a successful match. Return the corresponding PHI node (%iv
9734 // above) in PNOut, the opcode of the shift operation in OpCodeOut, and the
9735 // shift amount in ShiftAmtOut.
9736 auto MatchShiftRecurrence = [&](Value *V, PHINode *&PNOut,
9737 Instruction::BinaryOps &OpCodeOut,
9738 unsigned &ShiftAmtOut) {
9739 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9740
9741 {
9743 Value *V;
9744 unsigned Amt;
9745
9746 // If we encounter a shift instruction, "peel off" the shift operation,
9747 // and remember that we did so. Later when we inspect %iv's backedge
9748 // value, we will make sure that the backedge value uses the same
9749 // operation.
9750 //
9751 // Note: the peeled shift operation does not have to be the same
9752 // instruction as the one feeding into the PHI's backedge value. We only
9753 // really care about it being the same *kind* of shift instruction --
9754 // that's all that is required for our later inferences to hold.
9755 if (MatchPositiveShift(LHS, V, OpC, Amt)) {
9756 PostShiftOpCode = OpC;
9757 LHS = V;
9758 }
9759 }
9760
9761 PNOut = dyn_cast<PHINode>(LHS);
9762 if (!PNOut || PNOut->getParent() != L->getHeader())
9763 return false;
9764
9765 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9766 Value *OpLHS;
9767
9768 return
9769 // The backedge value for the PHI node must be a shift by a positive
9770 // amount
9771 MatchPositiveShift(BEValue, OpLHS, OpCodeOut, ShiftAmtOut) &&
9772
9773 // of the PHI node itself
9774 OpLHS == PNOut &&
9775
9776 // and the kind of shift should match the kind of shift we peeled
9777 // off, if any.
9778 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9779 };
9780
9781 PHINode *PN;
9783 unsigned ShiftAmt;
9784 if (!MatchShiftRecurrence(LHS, PN, OpCode, ShiftAmt))
9785 return getCouldNotCompute();
9786
9787 const DataLayout &DL = getDataLayout();
9788
9789 // The key rationale for this optimization is that for some kinds of shift
9790 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9791 // within a finite number of iterations. If the condition guarding the
9792 // backedge (in the sense that the backedge is taken if the condition is true)
9793 // is false for the value the shift recurrence stabilizes to, then we know
9794 // that the backedge is taken only a finite number of times.
9795
9796 ConstantInt *StableValue = nullptr;
9797 switch (OpCode) {
9798 default:
9799 llvm_unreachable("Impossible case!");
9800
9801 case Instruction::AShr: {
9802 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9803 // bitwidth(K) iterations.
9804 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9805 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9806 Predecessor->getTerminator(), &DT);
9807 auto *Ty = cast<IntegerType>(RHS->getType());
9808 if (Known.isNonNegative())
9809 StableValue = ConstantInt::get(Ty, 0);
9810 else if (Known.isNegative())
9811 StableValue = ConstantInt::get(Ty, -1, true);
9812 else
9813 return getCouldNotCompute();
9814
9815 break;
9816 }
9817 case Instruction::LShr:
9818 case Instruction::Shl:
9819 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9820 // stabilize to 0 in at most bitwidth(K) iterations.
9821 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9822 break;
9823 }
9824
9825 auto *Result =
9826 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9827 assert(Result->getType()->isIntegerTy(1) &&
9828 "Otherwise cannot be an operand to a branch instruction");
9829
9830 if (Result->isNullValue()) {
9831 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9832 unsigned MaxBTC = BitWidth;
9833
9834 // For right-shift recurrences (lshr/ashr with non-negative start), we can
9835 // compute a tighter max backedge-taken count from the range of the start
9836 // value. After k shifts of ShiftAmt, value = start >> (k * ShiftAmt).
9837 // The value reaches 0 (the stable value) when k * ShiftAmt >=
9838 // activeBits(start), so max BTC = ceil(activeBits(maxStart) / ShiftAmt).
9839 if (OpCode == Instruction::LShr || OpCode == Instruction::AShr) {
9840 Value *StartValue = PN->getIncomingValueForBlock(Predecessor);
9841 const SCEV *StartSCEV = getSCEV(StartValue);
9842 APInt MaxStart = getUnsignedRangeMax(StartSCEV);
9843 if (MaxStart.isStrictlyPositive()) {
9844 unsigned ActiveBits = MaxStart.getActiveBits();
9845 unsigned RangeBTC = divideCeil(ActiveBits, ShiftAmt);
9846 MaxBTC = std::min(MaxBTC, RangeBTC);
9847 }
9848 }
9849
9850 const SCEV *UpperBound =
9852 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9853 }
9854
9855 return getCouldNotCompute();
9856}
9857
9858/// Return true if we can constant fold an instruction of the specified type,
9859/// assuming that all operands were constants.
9860static bool CanConstantFold(const Instruction *I) {
 // NOTE(review): listing lines 9861-9863 are absent from this copy, so the
 // condition that guarded this early return is not visible here. As the text
 // stands the function returns true for every instruction and the call check
 // below is unreachable -- confirm against the upstream file.
9864 return true;
9865
 // For a call with a known callee, defer to canConstantFoldCallTo.
9866 if (const CallInst *CI = dyn_cast<CallInst>(I))
9867 if (const Function *F = CI->getCalledFunction())
9868 return canConstantFoldCallTo(CI, F);
 // Anything else is treated as not foldable.
9869 return false;
9870}
9871
9872/// Determine whether this instruction can constant evolve within this loop
9873/// assuming its operands can all constant evolve.
9874static bool canConstantEvolve(Instruction *I, const Loop *L) {
9875 // An instruction outside of the loop can't be derived from a loop PHI.
9876 if (!L->contains(I)) return false;
9877
9878 if (isa<PHINode>(I)) {
9879 // We don't currently keep track of the control flow needed to evaluate
9880 // PHIs, so we cannot handle PHIs inside of loops.
9881 return L->getHeader() == I->getParent();
9882 }
9883
9884 // If we won't be able to constant fold this expression even if the operands
9885 // are constants, bail early.
9886 return CanConstantFold(I);
9887}
9888
9889/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9890/// recursing through each instruction operand until reaching a loop header phi.
9891static PHINode *
9894 unsigned Depth) {
9896 return nullptr;
9897
9898 // Otherwise, we can evaluate this instruction if all of its operands are
9899 // constant or derived from a PHI node themselves.
9900 PHINode *PHI = nullptr;
9901 for (Value *Op : UseInst->operands()) {
9902 if (isa<Constant>(Op)) continue;
9903
9905 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9906
9907 PHINode *P = dyn_cast<PHINode>(OpInst);
9908 if (!P)
9909 // If this operand is already visited, reuse the prior result.
9910 // We may have P != PHI if this is the deepest point at which the
9911 // inconsistent paths meet.
9912 P = PHIMap.lookup(OpInst);
9913 if (!P) {
9914 // Recurse and memoize the results, whether a phi is found or not.
9915 // This recursive call invalidates pointers into PHIMap.
9916 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9917 PHIMap[OpInst] = P;
9918 }
9919 if (!P)
9920 return nullptr; // Not evolving from PHI
9921 if (PHI && PHI != P)
9922 return nullptr; // Evolving from multiple different PHIs.
9923 PHI = P;
9924 }
9925 // This is an expression evolving from a constant PHI!
9926 return PHI;
9927}
9928
9929/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9930/// in the loop that V is derived from. We allow arbitrary operations along the
9931/// way, but the operands of an operation must either be constants or a value
9932/// derived from a constant PHI. If this expression does not fit with these
9933/// constraints, return null.
9936 if (!I || !canConstantEvolve(I, L)) return nullptr;
9937
9938 if (PHINode *PN = dyn_cast<PHINode>(I))
9939 return PN;
9940
9941 // Record non-constant instructions contained by the loop.
9943 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9944}
9945
9946/// EvaluateExpression - Given an expression that passes the
9947/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9948/// in the loop has the value PHIVal. If we can't fold this expression for some
9949/// reason, return null.
9952 const DataLayout &DL,
9953 const TargetLibraryInfo *TLI) {
9954 // Convenient constant check, but redundant for recursive calls.
9955 if (Constant *C = dyn_cast<Constant>(V)) return C;
9957 if (!I) return nullptr;
9958
9959 if (Constant *C = Vals.lookup(I)) return C;
9960
9961 // An instruction inside the loop depends on a value outside the loop that we
9962 // weren't given a mapping for, or a value such as a call inside the loop.
9963 if (!canConstantEvolve(I, L)) return nullptr;
9964
9965 // An unmapped PHI can be due to a branch or another loop inside this loop,
9966 // or due to this not being the initial iteration through a loop where we
9967 // couldn't compute the evolution of this particular PHI last time.
9968 if (isa<PHINode>(I)) return nullptr;
9969
9970 std::vector<Constant*> Operands(I->getNumOperands());
9971
9972 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9973 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9974 if (!Operand) {
9975 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9976 if (!Operands[i]) return nullptr;
9977 continue;
9978 }
9979 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9980 Vals[Operand] = C;
9981 if (!C) return nullptr;
9982 Operands[i] = C;
9983 }
9984
9985 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9986 /*AllowNonDeterministic=*/false);
9987}
9988
9989
9990// If every incoming value to PN except the one for BB is a specific Constant,
9991// return that, else return nullptr.
9993 Constant *IncomingVal = nullptr;
9994
9995 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9996 if (PN->getIncomingBlock(i) == BB)
9997 continue;
9998
9999 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
10000 if (!CurrentVal)
10001 return nullptr;
10002
10003 if (IncomingVal != CurrentVal) {
10004 if (IncomingVal)
10005 return nullptr;
10006 IncomingVal = CurrentVal;
10007 }
10008 }
10009
10010 return IncomingVal;
10011}
10012
10013/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
10014/// in the header of its containing loop, we know the loop executes a
10015/// constant number of times, and the PHI node is just a recurrence
10016/// involving constants, fold it.
10017Constant *
10018ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
10019 const APInt &BEs,
10020 const Loop *L) {
10021 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
10022 if (!Inserted)
10023 return I->second;
10024
10026 return nullptr; // Not going to evaluate it.
10027
10028 Constant *&RetVal = I->second;
10029
10030 DenseMap<Instruction *, Constant *> CurrentIterVals;
10031 BasicBlock *Header = L->getHeader();
10032 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10033
10034 BasicBlock *Latch = L->getLoopLatch();
10035 if (!Latch)
10036 return nullptr;
10037
10038 for (PHINode &PHI : Header->phis()) {
10039 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10040 CurrentIterVals[&PHI] = StartCST;
10041 }
10042 if (!CurrentIterVals.count(PN))
10043 return RetVal = nullptr;
10044
10045 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10046
10047 // Execute the loop symbolically to determine the exit value.
10048 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10049 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10050
10051 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10052 unsigned IterationNum = 0;
10053 const DataLayout &DL = getDataLayout();
10054 for (; ; ++IterationNum) {
10055 if (IterationNum == NumIterations)
10056 return RetVal = CurrentIterVals[PN]; // Got exit value!
10057
10058 // Compute the value of the PHIs for the next iteration.
10059 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10060 DenseMap<Instruction *, Constant *> NextIterVals;
10061 Constant *NextPHI =
10062 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10063 if (!NextPHI)
10064 return nullptr; // Couldn't evaluate!
10065 NextIterVals[PN] = NextPHI;
10066
10067 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10068
10069 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10070 // cease to be able to evaluate one of them or if they stop evolving,
10071 // because that doesn't necessarily prevent us from computing PN.
10073 for (const auto &I : CurrentIterVals) {
10074 PHINode *PHI = dyn_cast<PHINode>(I.first);
10075 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10076 PHIsToCompute.emplace_back(PHI, I.second);
10077 }
10078 // We use two distinct loops because EvaluateExpression may invalidate any
10079 // iterators into CurrentIterVals.
10080 for (const auto &I : PHIsToCompute) {
10081 PHINode *PHI = I.first;
10082 Constant *&NextPHI = NextIterVals[PHI];
10083 if (!NextPHI) { // Not already computed.
10084 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10085 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10086 }
10087 if (NextPHI != I.second)
10088 StoppedEvolving = false;
10089 }
10090
10091 // If all entries in CurrentIterVals == NextIterVals then we can stop
10092 // iterating, the loop can't continue to change.
10093 if (StoppedEvolving)
10094 return RetVal = CurrentIterVals[PN];
10095
10096 CurrentIterVals.swap(NextIterVals);
10097 }
10098}
10099
10100const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10101 Value *Cond,
10102 bool ExitWhen) {
10103 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10104 if (!PN) return getCouldNotCompute();
10105
10106 // If the loop is canonicalized, the PHI will have exactly two entries.
10107 // That's the only form we support here.
10108 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10109
10110 DenseMap<Instruction *, Constant *> CurrentIterVals;
10111 BasicBlock *Header = L->getHeader();
10112 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10113
10114 BasicBlock *Latch = L->getLoopLatch();
10115 assert(Latch && "Should follow from NumIncomingValues == 2!");
10116
10117 for (PHINode &PHI : Header->phis()) {
10118 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10119 CurrentIterVals[&PHI] = StartCST;
10120 }
10121 if (!CurrentIterVals.count(PN))
10122 return getCouldNotCompute();
10123
10124 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10125 // the loop symbolically to determine when the condition gets a value of
10126 // "ExitWhen".
10127 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10128 const DataLayout &DL = getDataLayout();
10129 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10130 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10131 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10132
10133 // Couldn't symbolically evaluate.
10134 if (!CondVal) return getCouldNotCompute();
10135
10136 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10137 ++NumBruteForceTripCountsComputed;
10138 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10139 }
10140
10141 // Update all the PHI nodes for the next iteration.
10142 DenseMap<Instruction *, Constant *> NextIterVals;
10143
10144 // Create a list of which PHIs we need to compute. We want to do this before
10145 // calling EvaluateExpression on them because that may invalidate iterators
10146 // into CurrentIterVals.
10147 SmallVector<PHINode *, 8> PHIsToCompute;
10148 for (const auto &I : CurrentIterVals) {
10149 PHINode *PHI = dyn_cast<PHINode>(I.first);
10150 if (!PHI || PHI->getParent() != Header) continue;
10151 PHIsToCompute.push_back(PHI);
10152 }
10153 for (PHINode *PHI : PHIsToCompute) {
10154 Constant *&NextPHI = NextIterVals[PHI];
10155 if (NextPHI) continue; // Already computed!
10156
10157 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10158 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10159 }
10160 CurrentIterVals.swap(NextIterVals);
10161 }
10162
10163 // Too many iterations were needed to evaluate.
10164 return getCouldNotCompute();
10165}
10166
10167const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10169 ValuesAtScopes[V];
10170 // Check to see if we've folded this expression at this loop before.
10171 for (auto &LS : Values)
10172 if (LS.first == L)
10173 return LS.second ? LS.second : V;
10174
10175 Values.emplace_back(L, nullptr);
10176
10177 // Otherwise compute it.
10178 const SCEV *C = computeSCEVAtScope(V, L);
10179 for (auto &LS : reverse(ValuesAtScopes[V]))
10180 if (LS.first == L) {
10181 LS.second = C;
10182 if (!isa<SCEVConstant>(C))
10183 ValuesAtScopesUsers[C].push_back({L, V});
10184 break;
10185 }
10186 return C;
10187}
10188
10189/// This builds up a Constant using the ConstantExpr interface. That way, we
10190/// will return Constants for objects which aren't represented by a
10191/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10192/// Returns NULL if the SCEV isn't representable as a Constant.
10194 switch (V->getSCEVType()) {
10195 case scCouldNotCompute:
10196 case scAddRecExpr:
10197 case scVScale:
10198 return nullptr;
10199 case scConstant:
10200 return cast<SCEVConstant>(V)->getValue();
10201 case scUnknown:
10203 case scPtrToAddr: {
10205 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10206 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10207
10208 return nullptr;
10209 }
10210 case scPtrToInt: {
10212 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10213 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10214
10215 return nullptr;
10216 }
10217 case scTruncate: {
10219 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10220 return ConstantExpr::getTrunc(CastOp, ST->getType());
10221 return nullptr;
10222 }
10223 case scAddExpr: {
10224 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10225 Constant *C = nullptr;
10226 for (const SCEV *Op : SA->operands()) {
10228 if (!OpC)
10229 return nullptr;
10230 if (!C) {
10231 C = OpC;
10232 continue;
10233 }
10234 assert(!C->getType()->isPointerTy() &&
10235 "Can only have one pointer, and it must be last");
10236 if (OpC->getType()->isPointerTy()) {
10237 // The offsets have been converted to bytes. We can add bytes using
10238 // an i8 GEP.
10239 C = ConstantExpr::getPtrAdd(OpC, C);
10240 } else {
10241 C = ConstantExpr::getAdd(C, OpC);
10242 }
10243 }
10244 return C;
10245 }
10246 case scMulExpr:
10247 case scSignExtend:
10248 case scZeroExtend:
10249 case scUDivExpr:
10250 case scSMaxExpr:
10251 case scUMaxExpr:
10252 case scSMinExpr:
10253 case scUMinExpr:
10255 return nullptr;
10256 }
10257 llvm_unreachable("Unknown SCEV kind!");
10258}
10259
10260const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10261 SmallVectorImpl<SCEVUse> &NewOps) {
10262 switch (S->getSCEVType()) {
10263 case scTruncate:
10264 case scZeroExtend:
10265 case scSignExtend:
10266 case scPtrToAddr:
10267 case scPtrToInt:
10268 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10269 case scAddRecExpr: {
10270 auto *AddRec = cast<SCEVAddRecExpr>(S);
10271 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10272 }
10273 case scAddExpr:
10274 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10275 case scMulExpr:
10276 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10277 case scUDivExpr:
10278 return getUDivExpr(NewOps[0], NewOps[1]);
10279 case scUMaxExpr:
10280 case scSMaxExpr:
10281 case scUMinExpr:
10282 case scSMinExpr:
10283 return getMinMaxExpr(S->getSCEVType(), NewOps);
10285 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10286 case scConstant:
10287 case scVScale:
10288 case scUnknown:
10289 return S;
10290 case scCouldNotCompute:
10291 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10292 }
10293 llvm_unreachable("Unknown SCEV kind!");
10294}
10295
10296const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10297 switch (V->getSCEVType()) {
10298 case scConstant:
10299 case scVScale:
10300 return V;
10301 case scAddRecExpr: {
10302 // If this is a loop recurrence for a loop that does not contain L, then we
10303 // are dealing with the final value computed by the loop.
10304 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10305 // First, attempt to evaluate each operand.
10306 // Avoid performing the look-up in the common case where the specified
10307 // expression has no loop-variant portions.
10308 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10309 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10310 if (OpAtScope == AddRec->getOperand(i))
10311 continue;
10312
10313 // Okay, at least one of these operands is loop variant but might be
10314 // foldable. Build a new instance of the folded commutative expression.
10316 NewOps.reserve(AddRec->getNumOperands());
10317 append_range(NewOps, AddRec->operands().take_front(i));
10318 NewOps.push_back(OpAtScope);
10319 for (++i; i != e; ++i)
10320 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10321
10322 const SCEV *FoldedRec = getAddRecExpr(
10323 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10324 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10325 // The addrec may be folded to a nonrecurrence, for example, if the
10326 // induction variable is multiplied by zero after constant folding. Go
10327 // ahead and return the folded value.
10328 if (!AddRec)
10329 return FoldedRec;
10330 break;
10331 }
10332
10333 // If the scope is outside the addrec's loop, evaluate it by using the
10334 // loop exit value of the addrec.
10335 if (!AddRec->getLoop()->contains(L)) {
10336 // To evaluate this recurrence, we need to know how many times the AddRec
10337 // loop iterates. Compute this now.
10338 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10339 if (BackedgeTakenCount == getCouldNotCompute())
10340 return AddRec;
10341
10342 // Then, evaluate the AddRec.
10343 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10344 }
10345
10346 return AddRec;
10347 }
10348 case scTruncate:
10349 case scZeroExtend:
10350 case scSignExtend:
10351 case scPtrToAddr:
10352 case scPtrToInt:
10353 case scAddExpr:
10354 case scMulExpr:
10355 case scUDivExpr:
10356 case scUMaxExpr:
10357 case scSMaxExpr:
10358 case scUMinExpr:
10359 case scSMinExpr:
10360 case scSequentialUMinExpr: {
10361 ArrayRef<SCEVUse> Ops = V->operands();
10362 // Avoid performing the look-up in the common case where the specified
10363 // expression has no loop-variant portions.
10364 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10365 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10366 if (OpAtScope != Ops[i].getPointer()) {
10367 // Okay, at least one of these operands is loop variant but might be
10368 // foldable. Build a new instance of the folded commutative expression.
10370 NewOps.reserve(Ops.size());
10371 append_range(NewOps, Ops.take_front(i));
10372 NewOps.push_back(OpAtScope);
10373
10374 for (++i; i != e; ++i) {
10375 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10376 NewOps.push_back(OpAtScope);
10377 }
10378
10379 return getWithOperands(V, NewOps);
10380 }
10381 }
10382 // If we got here, all operands are loop invariant.
10383 return V;
10384 }
10385 case scUnknown: {
10386 // If this instruction is evolved from a constant-evolving PHI, compute the
10387 // exit value from the loop without using SCEVs.
10388 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10390 if (!I)
10391 return V; // This is some other type of SCEVUnknown, just return it.
10392
10393 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10394 const Loop *CurrLoop = this->LI[I->getParent()];
10395 // Looking for loop exit value.
10396 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10397 PN->getParent() == CurrLoop->getHeader()) {
10398 // Okay, there is no closed form solution for the PHI node. Check
10399 // to see if the loop that contains it has a known backedge-taken
10400 // count. If so, we may be able to force computation of the exit
10401 // value.
10402 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10403 // This trivial case can show up in some degenerate cases where
10404 // the incoming IR has not yet been fully simplified.
10405 if (BackedgeTakenCount->isZero()) {
10406 Value *InitValue = nullptr;
10407 bool MultipleInitValues = false;
10408 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10409 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10410 if (!InitValue)
10411 InitValue = PN->getIncomingValue(i);
10412 else if (InitValue != PN->getIncomingValue(i)) {
10413 MultipleInitValues = true;
10414 break;
10415 }
10416 }
10417 }
10418 if (!MultipleInitValues && InitValue)
10419 return getSCEV(InitValue);
10420 }
10421 // Do we have a loop invariant value flowing around the backedge
10422 // for a loop which must execute the backedge?
10423 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10424 isKnownNonZero(BackedgeTakenCount) &&
10425 PN->getNumIncomingValues() == 2) {
10426
10427 unsigned InLoopPred =
10428 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10429 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10430 if (CurrLoop->isLoopInvariant(BackedgeVal))
10431 return getSCEV(BackedgeVal);
10432 }
10433 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10434 // Okay, we know how many times the containing loop executes. If
10435 // this is a constant evolving PHI node, get the final value at
10436 // the specified iteration number.
10437 Constant *RV =
10438 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10439 if (RV)
10440 return getSCEV(RV);
10441 }
10442 }
10443 }
10444
10445 // Okay, this is an expression that we cannot symbolically evaluate
10446 // into a SCEV. Check to see if it's possible to symbolically evaluate
10447 // the arguments into constants, and if so, try to constant propagate the
10448 // result. This is particularly useful for computing loop exit values.
10449 if (!CanConstantFold(I))
10450 return V; // This is some other type of SCEVUnknown, just return it.
10451
10452 SmallVector<Constant *, 4> Operands;
10453 Operands.reserve(I->getNumOperands());
10454 bool MadeImprovement = false;
10455 for (Value *Op : I->operands()) {
10456 if (Constant *C = dyn_cast<Constant>(Op)) {
10457 Operands.push_back(C);
10458 continue;
10459 }
10460
10461 // If any of the operands is non-constant and if they are
10462 // non-integer and non-pointer, don't even try to analyze them
10463 // with scev techniques.
10464 if (!isSCEVable(Op->getType()))
10465 return V;
10466
10467 const SCEV *OrigV = getSCEV(Op);
10468 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10469 MadeImprovement |= OrigV != OpV;
10470
10472 if (!C)
10473 return V;
10474 assert(C->getType() == Op->getType() && "Type mismatch");
10475 Operands.push_back(C);
10476 }
10477
10478 // Check to see if getSCEVAtScope actually made an improvement.
10479 if (!MadeImprovement)
10480 return V; // This is some other type of SCEVUnknown, just return it.
10481
10482 Constant *C = nullptr;
10483 const DataLayout &DL = getDataLayout();
10484 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10485 /*AllowNonDeterministic=*/false);
10486 if (!C)
10487 return V;
10488 return getSCEV(C);
10489 }
10490 case scCouldNotCompute:
10491 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10492 }
10493 llvm_unreachable("Unknown SCEV type!");
10494}
10495
10497 return getSCEVAtScope(getSCEV(V), L);
10498}
10499
10500const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10502 return stripInjectiveFunctions(ZExt->getOperand());
10504 return stripInjectiveFunctions(SExt->getOperand());
10505 return S;
10506}
10507
10508/// Finds the minimum unsigned root of the following equation:
10509///
10510/// A * X = B (mod N)
10511///
10512/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10513/// A and B isn't important.
10514///
10515/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10516static const SCEV *
10519 ScalarEvolution &SE, const Loop *L) {
10520 uint32_t BW = A.getBitWidth();
10521 assert(BW == SE.getTypeSizeInBits(B->getType()));
10522 assert(A != 0 && "A must be non-zero.");
10523
10524 // 1. D = gcd(A, N)
10525 //
10526 // The gcd of A and N may have only one prime factor: 2. The number of
10527 // trailing zeros in A is its multiplicity
10528 uint32_t Mult2 = A.countr_zero();
10529 // D = 2^Mult2
10530
10531 // 2. Check if B is divisible by D.
10532 //
10533 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10534 // is not less than multiplicity of this prime factor for D.
10535 unsigned MinTZ = SE.getMinTrailingZeros(B);
10536 // Try again with the terminator of the loop predecessor for context-specific
10537 // result, if MinTZ is too small.
10538 if (MinTZ < Mult2 && L->getLoopPredecessor())
10539 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10540 if (MinTZ < Mult2) {
10541 // Check if we can prove there's no remainder using URem.
10542 const SCEV *URem =
10543 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10544 const SCEV *Zero = SE.getZero(B->getType());
10545 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10546 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10547 if (!Predicates)
10548 return SE.getCouldNotCompute();
10549
10550 // Avoid adding a predicate that is known to be false.
10551 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10552 return SE.getCouldNotCompute();
10553 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10554 }
10555 }
10556
10557 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10558 // modulo (N / D).
10559 //
10560 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10561 // (N / D) in general. The inverse itself always fits into BW bits, though,
10562 // so we immediately truncate it.
10563 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10564 APInt I = AD.multiplicativeInverse().zext(BW);
10565
10566 // 4. Compute the minimum unsigned root of the equation:
10567 // I * (B / D) mod (N / D)
10568 // To simplify the computation, we factor out the divide by D:
10569 // (I * B mod N) / D
10570 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10571 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10572}
10573
10574/// For a given quadratic addrec, generate coefficients of the corresponding
10575/// quadratic equation, multiplied by a common value to ensure that they are
10576/// integers.
10577/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10578/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10579/// were multiplied by, and BitWidth is the bit width of the original addrec
10580/// coefficients.
10581/// This function returns std::nullopt if the addrec coefficients are not
10582/// compile-time constants.
10583static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10585 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10586 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10587 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10588 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10589 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10590 << *AddRec << '\n');
10591
10592 // We currently can only solve this if the coefficients are constants.
10593 if (!LC || !MC || !NC) {
10594 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10595 return std::nullopt;
10596 }
10597
10598 APInt L = LC->getAPInt();
10599 APInt M = MC->getAPInt();
10600 APInt N = NC->getAPInt();
10601 assert(!N.isZero() && "This is not a quadratic addrec");
10602
10603 unsigned BitWidth = LC->getAPInt().getBitWidth();
10604 unsigned NewWidth = BitWidth + 1;
10605 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10606 << BitWidth << '\n');
10607 // The sign-extension (as opposed to a zero-extension) here matches the
10608 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10609 N = N.sext(NewWidth);
10610 M = M.sext(NewWidth);
10611 L = L.sext(NewWidth);
10612
10613 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10614 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10615 // L+M, L+2M+N, L+3M+3N, ...
10616 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10617 //
10618 // The equation Acc = 0 is then
10619 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10620 // In a quadratic form it becomes:
10621 // N n^2 + (2M-N) n + 2L = 0.
10622
10623 APInt A = N;
10624 APInt B = 2 * M - A;
10625 APInt C = 2 * L;
10626 APInt T = APInt(NewWidth, 2);
10627 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10628 << "x + " << C << ", coeff bw: " << NewWidth
10629 << ", multiplied by " << T << '\n');
10630 return std::make_tuple(A, B, C, T, BitWidth);
10631}
10632
10633/// Helper function to compare optional APInts:
10634/// (a) if X and Y both exist, return min(X, Y),
10635/// (b) if neither X nor Y exist, return std::nullopt,
10636/// (c) if exactly one of X and Y exists, return that value.
10637static std::optional<APInt> MinOptional(std::optional<APInt> X,
10638 std::optional<APInt> Y) {
10639 if (X && Y) {
10640 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10641 APInt XW = X->sext(W);
10642 APInt YW = Y->sext(W);
10643 return XW.slt(YW) ? *X : *Y;
10644 }
10645 if (!X && !Y)
10646 return std::nullopt;
10647 return X ? *X : *Y;
10648}
10649
10650/// Helper function to truncate an optional APInt to a given BitWidth.
10651/// When solving addrec-related equations, it is preferable to return a value
10652/// that has the same bit width as the original addrec's coefficients. If the
10653/// solution fits in the original bit width, truncate it (except for i1).
10654/// Returning a value of a different bit width may inhibit some optimizations.
10655///
10656/// In general, a solution to a quadratic equation generated from an addrec
10657/// may require BW+1 bits, where BW is the bit width of the addrec's
10658/// coefficients. The reason is that the coefficients of the quadratic
10659/// equation are BW+1 bits wide (to avoid truncation when converting from
10660/// the addrec to the equation).
10661static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10662 unsigned BitWidth) {
10663 if (!X)
10664 return std::nullopt;
10665 unsigned W = X->getBitWidth();
10667 return X->trunc(BitWidth);
10668 return X;
10669}
10670
10671/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10672/// iterations. The values L, M, N are assumed to be signed, and they
10673/// should all have the same bit widths.
10674/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10675/// where BW is the bit width of the addrec's coefficients.
10676/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10677/// returned as such, otherwise the bit width of the returned value may
10678/// be greater than BW.
10679///
10680/// This function returns std::nullopt if
10681/// (a) the addrec coefficients are not constant, or
10682/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10683/// like x^2 = 5, no integer solutions exist, in other cases an integer
10684/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10685static std::optional<APInt>
10687 APInt A, B, C, M;
10688 unsigned BitWidth;
10689 auto T = GetQuadraticEquation(AddRec);
10690 if (!T)
10691 return std::nullopt;
10692
10693 std::tie(A, B, C, M, BitWidth) = *T;
10694 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10695 std::optional<APInt> X =
10697 if (!X)
10698 return std::nullopt;
10699
10700 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10701 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10702 if (!V->isZero())
10703 return std::nullopt;
10704
10705 return TruncIfPossible(X, BitWidth);
10706}
10707
10708/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10709/// iterations. The values M, N are assumed to be signed, and they
10710/// should all have the same bit widths.
10711/// Find the least n such that c(n) does not belong to the given range,
10712/// while c(n-1) does.
10713///
10714/// This function returns std::nullopt if
10715/// (a) the addrec coefficients are not constant, or
10716/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10717/// bounds of the range.
10718static std::optional<APInt>
10720 const ConstantRange &Range, ScalarEvolution &SE) {
10721 assert(AddRec->getOperand(0)->isZero() &&
10722 "Starting value of addrec should be 0");
10723 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10724 << Range << ", addrec " << *AddRec << '\n');
10725 // This case is handled in getNumIterationsInRange. Here we can assume that
10726 // we start in the range.
10727 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10728 "Addrec's initial value should be in range");
10729
10730 APInt A, B, C, M;
10731 unsigned BitWidth;
10732 auto T = GetQuadraticEquation(AddRec);
10733 if (!T)
10734 return std::nullopt;
10735
10736 // Be careful about the return value: there can be two reasons for not
10737 // returning an actual number. First, if no solutions to the equations
10738 // were found, and second, if the solutions don't leave the given range.
10739 // The first case means that the actual solution is "unknown", the second
10740 // means that it's known, but not valid. If the solution is unknown, we
10741 // cannot make any conclusions.
10742 // Return a pair: the optional solution and a flag indicating if the
10743 // solution was found.
10744 auto SolveForBoundary =
10745 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10746 // Solve for signed overflow and unsigned overflow, pick the lower
10747 // solution.
10748 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10749 << Bound << " (before multiplying by " << M << ")\n");
10750 Bound *= M; // The quadratic equation multiplier.
10751
10752 std::optional<APInt> SO;
10753 if (BitWidth > 1) {
10754 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10755 "signed overflow\n");
10757 }
10758 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10759 "unsigned overflow\n");
10760 std::optional<APInt> UO =
10762
10763 auto LeavesRange = [&] (const APInt &X) {
10764 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10765 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10766 if (Range.contains(V0->getValue()))
10767 return false;
10768 // X should be at least 1, so X-1 is non-negative.
10769 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10771 if (Range.contains(V1->getValue()))
10772 return true;
10773 return false;
10774 };
10775
10776 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10777 // can be a solution, but the function failed to find it. We cannot treat it
10778 // as "no solution".
10779 if (!SO || !UO)
10780 return {std::nullopt, false};
10781
10782 // Check the smaller value first to see if it leaves the range.
10783 // At this point, both SO and UO must have values.
10784 std::optional<APInt> Min = MinOptional(SO, UO);
10785 if (LeavesRange(*Min))
10786 return { Min, true };
10787 std::optional<APInt> Max = Min == SO ? UO : SO;
10788 if (LeavesRange(*Max))
10789 return { Max, true };
10790
10791 // Solutions were found, but were eliminated, hence the "true".
10792 return {std::nullopt, true};
10793 };
10794
10795 std::tie(A, B, C, M, BitWidth) = *T;
10796 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10797 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10798 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10799 auto SL = SolveForBoundary(Lower);
10800 auto SU = SolveForBoundary(Upper);
10801 // If any of the solutions was unknown, no meaninigful conclusions can
10802 // be made.
10803 if (!SL.second || !SU.second)
10804 return std::nullopt;
10805
10806 // Claim: The correct solution is not some value between Min and Max.
10807 //
10808 // Justification: Assuming that Min and Max are different values, one of
10809 // them is when the first signed overflow happens, the other is when the
10810 // first unsigned overflow happens. Crossing the range boundary is only
10811 // possible via an overflow (treating 0 as a special case of it, modeling
10812 // an overflow as crossing k*2^W for some k).
10813 //
10814 // The interesting case here is when Min was eliminated as an invalid
10815 // solution, but Max was not. The argument is that if there was another
10816 // overflow between Min and Max, it would also have been eliminated if
10817 // it was considered.
10818 //
10819 // For a given boundary, it is possible to have two overflows of the same
10820 // type (signed/unsigned) without having the other type in between: this
10821 // can happen when the vertex of the parabola is between the iterations
10822 // corresponding to the overflows. This is only possible when the two
10823 // overflows cross k*2^W for the same k. In such case, if the second one
10824 // left the range (and was the first one to do so), the first overflow
10825 // would have to enter the range, which would mean that either we had left
10826 // the range before or that we started outside of it. Both of these cases
10827 // are contradictions.
10828 //
10829 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10830 // solution is not some value between the Max for this boundary and the
10831 // Min of the other boundary.
10832 //
10833 // Justification: Assume that we had such Max_A and Min_B corresponding
10834 // to range boundaries A and B and such that Max_A < Min_B. If there was
10835 // a solution between Max_A and Min_B, it would have to be caused by an
10836 // overflow corresponding to either A or B. It cannot correspond to B,
10837 // since Min_B is the first occurrence of such an overflow. If it
10838 // corresponded to A, it would have to be either a signed or an unsigned
10839 // overflow that is larger than both eliminated overflows for A. But
10840 // between the eliminated overflows and this overflow, the values would
10841 // cover the entire value space, thus crossing the other boundary, which
10842 // is a contradiction.
10843
10844 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10845}
10846
10847ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10848 const Loop *L,
10849 bool ControlsOnlyExit,
10850 bool AllowPredicates) {
10851
10852 // This is only used for loops with a "x != y" exit test. The exit condition
10853 // is now expressed as a single expression, V = x-y. So the exit test is
10854 // effectively V != 0. We know and take advantage of the fact that this
10855 // expression only being used in a comparison by zero context.
10856
10858 // If the value is a constant
10859 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10860 // If the value is already zero, the branch will execute zero times.
10861 if (C->getValue()->isZero()) return C;
10862 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10863 }
10864
10865 const SCEVAddRecExpr *AddRec =
10866 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10867
10868 if (!AddRec && AllowPredicates)
10869 // Try to make this an AddRec using runtime tests, in the first X
10870 // iterations of this loop, where X is the SCEV expression found by the
10871 // algorithm below.
10872 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10873
10874 if (!AddRec || AddRec->getLoop() != L)
10875 return getCouldNotCompute();
10876
10877 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10878 // the quadratic equation to solve it.
10879 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10880 // We can only use this value if the chrec ends up with an exact zero
10881 // value at this index. When solving for "X*X != 5", for example, we
10882 // should not accept a root of 2.
10883 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10884 const auto *R = cast<SCEVConstant>(getConstant(*S));
10885 return ExitLimit(R, R, R, false, Predicates);
10886 }
10887 return getCouldNotCompute();
10888 }
10889
10890 // Otherwise we can only handle this if it is affine.
10891 if (!AddRec->isAffine())
10892 return getCouldNotCompute();
10893
10894 // If this is an affine expression, the execution count of this branch is
10895 // the minimum unsigned root of the following equation:
10896 //
10897 // Start + Step*N = 0 (mod 2^BW)
10898 //
10899 // equivalent to:
10900 //
10901 // Step*N = -Start (mod 2^BW)
10902 //
10903 // where BW is the common bit width of Start and Step.
10904
10905 // Get the initial value for the loop.
10906 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10907 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10908
10909 if (!isLoopInvariant(Step, L))
10910 return getCouldNotCompute();
10911
10912 LoopGuards Guards = LoopGuards::collect(L, *this);
10913 // Specialize step for this loop so we get context sensitive facts below.
10914 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10915
10916 // For positive steps (counting up until unsigned overflow):
10917 // N = -Start/Step (as unsigned)
10918 // For negative steps (counting down to zero):
10919 // N = Start/-Step
10920 // First compute the unsigned distance from zero in the direction of Step.
10921 bool CountDown = isKnownNegative(StepWLG);
10922 if (!CountDown && !isKnownNonNegative(StepWLG))
10923 return getCouldNotCompute();
10924
10925 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10926 // Handle unitary steps, which cannot wraparound.
10927 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10928 // N = Distance (as unsigned)
10929
10930 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10931 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10932 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10933
10934 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10935 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10936 // case, and see if we can improve the bound.
10937 //
10938 // Explicitly handling this here is necessary because getUnsignedRange
10939 // isn't context-sensitive; it doesn't know that we only care about the
10940 // range inside the loop.
10941 const SCEV *Zero = getZero(Distance->getType());
10942 const SCEV *One = getOne(Distance->getType());
10943 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10944 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10945 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10946 // as "unsigned_max(Distance + 1) - 1".
10947 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10948 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10949 }
10950 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10951 Predicates);
10952 }
10953
10954 // If the condition controls loop exit (the loop exits only if the expression
10955 // is true) and the addition is no-wrap we can use unsigned divide to
10956 // compute the backedge count. In this case, the step may not divide the
10957 // distance, but we don't care because if the condition is "missed" the loop
10958 // will have undefined behavior due to wrapping.
10959 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10960 loopHasNoAbnormalExits(AddRec->getLoop())) {
10961
10962 // If the stride is zero and the start is non-zero, the loop must be
10963 // infinite. In C++, most loops are finite by assumption, in which case the
10964 // step being zero implies UB must execute if the loop is entered.
10965 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10966 !isKnownNonZero(StepWLG))
10967 return getCouldNotCompute();
10968
10969 const SCEV *Exact =
10970 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10971 const SCEV *ConstantMax = getCouldNotCompute();
10972 if (Exact != getCouldNotCompute()) {
10973 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10974 ConstantMax =
10976 }
10977 const SCEV *SymbolicMax =
10978 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10979 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10980 }
10981
10982 // Solve the general equation.
10983 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10984 if (!StepC || StepC->getValue()->isZero())
10985 return getCouldNotCompute();
10986 const SCEV *E = SolveLinEquationWithOverflow(
10987 StepC->getAPInt(), getNegativeSCEV(Start),
10988 AllowPredicates ? &Predicates : nullptr, *this, L);
10989
10990 const SCEV *M = E;
10991 if (E != getCouldNotCompute()) {
10992 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10993 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10994 }
10995 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10996 return ExitLimit(E, M, S, false, Predicates);
10997}
10998
10999ScalarEvolution::ExitLimit
11000ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
11001 // Loops that look like: while (X == 0) are very strange indeed. We don't
11002 // handle them yet except for the trivial case. This could be expanded in the
11003 // future as needed.
11004
11005 // If the value is a constant, check to see if it is known to be non-zero
11006 // already. If so, the backedge will execute zero times.
11007 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
11008 if (!C->getValue()->isZero())
11009 return getZero(C->getType());
11010 return getCouldNotCompute(); // Otherwise it will loop infinitely.
11011 }
11012
11013 // We could implement others, but I really doubt anyone writes loops like
11014 // this, and if they did, they would already be constant folded.
11015 return getCouldNotCompute();
11016}
11017
11018std::pair<const BasicBlock *, const BasicBlock *>
11019ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
11020 const {
11021 // If the block has a unique predecessor, then there is no path from the
11022 // predecessor to the block that does not go through the direct edge
11023 // from the predecessor to the block.
11024 if (const BasicBlock *Pred = BB->getSinglePredecessor())
11025 return {Pred, BB};
11026
11027 // A loop's header is defined to be a block that dominates the loop.
11028 // If the header has a unique predecessor outside the loop, it must be
11029 // a block that has exactly one successor that can reach the loop.
11030 if (const Loop *L = LI.getLoopFor(BB))
11031 return {L->getLoopPredecessor(), L->getHeader()};
11032
11033 return {nullptr, BB};
11034}
11035
11036/// SCEV structural equivalence is usually sufficient for testing whether two
11037/// expressions are equal, however for the purposes of looking for a condition
11038/// guarding a loop, it can be useful to be a little more general, since a
11039/// front-end may have replicated the controlling expression.
11040static bool HasSameValue(const SCEV *A, const SCEV *B) {
11041 // Quick check to see if they are the same SCEV.
11042 if (A == B) return true;
11043
11044 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11045 // Not all instructions that are "identical" compute the same value. For
11046 // instance, two distinct alloca instructions allocating the same type are
11047 // identical and do not read memory; but compute distinct values.
11048 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11049 };
11050
11051 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11052 // two different instructions with the same value. Check for this case.
11053 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11054 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11055 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11056 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11057 if (ComputesEqualValues(AI, BI))
11058 return true;
11059
11060 // Otherwise assume they may have a different value.
11061 return false;
11062}
11063
11064static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11065 const SCEV *Op0, *Op1;
11066 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11067 return false;
11068 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11069 LHS = Op1;
11070 return true;
11071 }
11072 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11073 LHS = Op0;
11074 return true;
11075 }
11076 return false;
11077}
11078
11080 SCEVUse &RHS, unsigned Depth) {
11081 bool Changed = false;
11082 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11083 // '0 != 0'.
11084 auto TrivialCase = [&](bool TriviallyTrue) {
11086 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11087 return true;
11088 };
11089 // If we hit the max recursion limit bail out.
11090 if (Depth >= 3)
11091 return false;
11092
11093 const SCEV *NewLHS, *NewRHS;
11094 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11095 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11096 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11097 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11098
11099 // (X * vscale) pred (Y * vscale) ==> X pred Y
11100 // when both multiples are NSW.
11101 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11102 // when both multiples are NUW.
11103 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11104 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11105 !ICmpInst::isSigned(Pred))) {
11106 LHS = NewLHS;
11107 RHS = NewRHS;
11108 Changed = true;
11109 }
11110 }
11111
11112 // Canonicalize a constant to the right side.
11113 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11114 // Check for both operands constant.
11115 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11116 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11117 return TrivialCase(false);
11118 return TrivialCase(true);
11119 }
11120 // Otherwise swap the operands to put the constant on the right.
11121 std::swap(LHS, RHS);
11123 Changed = true;
11124 }
11125
11126 // If we're comparing an addrec with a value which is loop-invariant in the
11127 // addrec's loop, put the addrec on the left. Also make a dominance check,
11128 // as both operands could be addrecs loop-invariant in each other's loop.
11129 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11130 const Loop *L = AR->getLoop();
11131 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11132 std::swap(LHS, RHS);
11134 Changed = true;
11135 }
11136 }
11137
11138 // If there's a constant operand, canonicalize comparisons with boundary
11139 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11140 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11141 const APInt &RA = RC->getAPInt();
11142
11143 bool SimplifiedByConstantRange = false;
11144
11145 if (!ICmpInst::isEquality(Pred)) {
11147 if (ExactCR.isFullSet())
11148 return TrivialCase(true);
11149 if (ExactCR.isEmptySet())
11150 return TrivialCase(false);
11151
11152 APInt NewRHS;
11153 CmpInst::Predicate NewPred;
11154 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11155 ICmpInst::isEquality(NewPred)) {
11156 // We were able to convert an inequality to an equality.
11157 Pred = NewPred;
11158 RHS = getConstant(NewRHS);
11159 Changed = SimplifiedByConstantRange = true;
11160 }
11161 }
11162
11163 if (!SimplifiedByConstantRange) {
11164 switch (Pred) {
11165 default:
11166 break;
11167 case ICmpInst::ICMP_EQ:
11168 case ICmpInst::ICMP_NE:
11169 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11170 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11171 Changed = true;
11172 break;
11173
11174 // The "Should have been caught earlier!" messages refer to the fact
11175 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11176 // should have fired on the corresponding cases, and canonicalized the
11177 // check to trivial case.
11178
11179 case ICmpInst::ICMP_UGE:
11180 assert(!RA.isMinValue() && "Should have been caught earlier!");
11181 Pred = ICmpInst::ICMP_UGT;
11182 RHS = getConstant(RA - 1);
11183 Changed = true;
11184 break;
11185 case ICmpInst::ICMP_ULE:
11186 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11187 Pred = ICmpInst::ICMP_ULT;
11188 RHS = getConstant(RA + 1);
11189 Changed = true;
11190 break;
11191 case ICmpInst::ICMP_SGE:
11192 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11193 Pred = ICmpInst::ICMP_SGT;
11194 RHS = getConstant(RA - 1);
11195 Changed = true;
11196 break;
11197 case ICmpInst::ICMP_SLE:
11198 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11199 Pred = ICmpInst::ICMP_SLT;
11200 RHS = getConstant(RA + 1);
11201 Changed = true;
11202 break;
11203 }
11204 }
11205 }
11206
11207 // Check for obvious equality.
11208 if (HasSameValue(LHS, RHS)) {
11209 if (ICmpInst::isTrueWhenEqual(Pred))
11210 return TrivialCase(true);
11212 return TrivialCase(false);
11213 }
11214
11215 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11216 // adding or subtracting 1 from one of the operands.
11217 switch (Pred) {
11218 case ICmpInst::ICMP_SLE:
11219 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11220 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11222 Pred = ICmpInst::ICMP_SLT;
11223 Changed = true;
11224 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11225 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11227 Pred = ICmpInst::ICMP_SLT;
11228 Changed = true;
11229 }
11230 break;
11231 case ICmpInst::ICMP_SGE:
11232 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11233 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11235 Pred = ICmpInst::ICMP_SGT;
11236 Changed = true;
11237 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11238 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11240 Pred = ICmpInst::ICMP_SGT;
11241 Changed = true;
11242 }
11243 break;
11244 case ICmpInst::ICMP_ULE:
11245 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11246 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11248 Pred = ICmpInst::ICMP_ULT;
11249 Changed = true;
11250 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11251 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11252 Pred = ICmpInst::ICMP_ULT;
11253 Changed = true;
11254 }
11255 break;
11256 case ICmpInst::ICMP_UGE:
11257 // If RHS is an op we can fold the -1, try that first.
11258 // Otherwise prefer LHS to preserve the nuw flag.
11259 if ((isa<SCEVConstant>(RHS) ||
11261 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11262 !getUnsignedRangeMin(RHS).isMinValue()) {
11263 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11264 Pred = ICmpInst::ICMP_UGT;
11265 Changed = true;
11266 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11267 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11269 Pred = ICmpInst::ICMP_UGT;
11270 Changed = true;
11271 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11272 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11273 Pred = ICmpInst::ICMP_UGT;
11274 Changed = true;
11275 }
11276 break;
11277 default:
11278 break;
11279 }
11280
11281 // TODO: More simplifications are possible here.
11282
11283 // Recursively simplify until we either hit a recursion limit or nothing
11284 // changes.
11285 if (Changed)
11286 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11287
11288 return Changed;
11289}
11290
11292 return getSignedRangeMax(S).isNegative();
11293}
11294
11298
11300 return !getSignedRangeMin(S).isNegative();
11301}
11302
11306
11308 // Query push down for cases where the unsigned range is
11309 // less than sufficient.
11310 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11311 return isKnownNonZero(SExt->getOperand(0));
11312 return getUnsignedRangeMin(S) != 0;
11313}
11314
11316 bool OrNegative) {
11317 auto NonRecursive = [OrNegative](const SCEV *S) {
11318 if (auto *C = dyn_cast<SCEVConstant>(S))
11319 return C->getAPInt().isPowerOf2() ||
11320 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11321
11322 // vscale is a power-of-two.
11323 return isa<SCEVVScale>(S);
11324 };
11325
11326 if (NonRecursive(S))
11327 return true;
11328
11329 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11330 if (!Mul)
11331 return false;
11332 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11333}
11334
11336 const SCEV *S, uint64_t M,
11338 if (M == 0)
11339 return false;
11340 if (M == 1)
11341 return true;
11342
11343 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11344 // starts with a multiple of M and at every iteration step S only adds
11345 // multiples of M.
11346 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11347 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11348 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11349
11350 // For a constant, check that "S % M == 0".
11351 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11352 APInt C = Cst->getAPInt();
11353 return C.urem(M) == 0;
11354 }
11355
11356 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11357
11358 // Basic tests have failed.
11359 // Check "S % M == 0" at compile time and record runtime Assumptions.
11360 auto *STy = dyn_cast<IntegerType>(S->getType());
11361 const SCEV *SmodM =
11362 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11363 const SCEV *Zero = getZero(STy);
11364
11365 // Check whether "S % M == 0" is known at compile time.
11366 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11367 return true;
11368
11369 // Check whether "S % M != 0" is known at compile time.
11370 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11371 return false;
11372
11374
11375 // Detect redundant predicates.
11376 for (auto *A : Assumptions)
11377 if (A->implies(P, *this))
11378 return true;
11379
11380 // Only record non-redundant predicates.
11381 Assumptions.push_back(P);
11382 return true;
11383}
11384
11386 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11388}
11389
11390std::pair<const SCEV *, const SCEV *>
11392 // Compute SCEV on entry of loop L.
11393 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11394 if (Start == getCouldNotCompute())
11395 return { Start, Start };
11396 // Compute post increment SCEV for loop L.
11397 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11398 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11399 return { Start, PostInc };
11400}
11401
11403 SCEVUse RHS) {
11404 // First collect all loops.
11406 getUsedLoops(LHS, LoopsUsed);
11407 getUsedLoops(RHS, LoopsUsed);
11408
11409 if (LoopsUsed.empty())
11410 return false;
11411
11412 // Domination relationship must be a linear order on collected loops.
11413#ifndef NDEBUG
11414 for (const auto *L1 : LoopsUsed)
11415 for (const auto *L2 : LoopsUsed)
11416 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11417 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11418 "Domination relationship is not a linear order");
11419#endif
11420
11421 const Loop *MDL =
11422 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11423 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11424 });
11425
11426 // Get init and post increment value for LHS.
11427 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11428 // if LHS contains unknown non-invariant SCEV then bail out.
11429 if (SplitLHS.first == getCouldNotCompute())
11430 return false;
11431 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11432 // Get init and post increment value for RHS.
11433 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11434 // if RHS contains unknown non-invariant SCEV then bail out.
11435 if (SplitRHS.first == getCouldNotCompute())
11436 return false;
11437 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11438 // It is possible that init SCEV contains an invariant load but it does
11439 // not dominate MDL and is not available at MDL loop entry, so we should
11440 // check it here.
11441 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11442 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11443 return false;
11444
11445 // It seems backedge guard check is faster than entry one so in some cases
11446 // it can speed up whole estimation by short circuit
11447 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11448 SplitRHS.second) &&
11449 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11450}
11451
11453 SCEVUse RHS) {
11454 // Canonicalize the inputs first.
11455 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11456
11457 if (isKnownViaInduction(Pred, LHS, RHS))
11458 return true;
11459
11460 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11461 return true;
11462
11463 // Otherwise see what can be done with some simple reasoning.
11464 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11465}
11466
11468 const SCEV *LHS,
11469 const SCEV *RHS) {
11470 if (isKnownPredicate(Pred, LHS, RHS))
11471 return true;
11473 return false;
11474 return std::nullopt;
11475}
11476
11478 const SCEV *RHS,
11479 const Instruction *CtxI) {
11480 // TODO: Analyze guards and assumes from Context's block.
11481 return isKnownPredicate(Pred, LHS, RHS) ||
11482 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11483}
11484
11485std::optional<bool>
11487 const SCEV *RHS, const Instruction *CtxI) {
11488 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11489 if (KnownWithoutContext)
11490 return KnownWithoutContext;
11491
11492 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11493 return true;
11495 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11496 return false;
11497 return std::nullopt;
11498}
11499
11501 const SCEVAddRecExpr *LHS,
11502 const SCEV *RHS) {
11503 const Loop *L = LHS->getLoop();
11504 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11505 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11506}
11507
11508std::optional<ScalarEvolution::MonotonicPredicateType>
11510 ICmpInst::Predicate Pred) {
11511 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11512
11513#ifndef NDEBUG
11514 // Verify an invariant: inverting the predicate should turn a monotonically
11515 // increasing change to a monotonically decreasing one, and vice versa.
11516 if (Result) {
11517 auto ResultSwapped =
11518 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11519
11520 assert(*ResultSwapped != *Result &&
11521 "monotonicity should flip as we flip the predicate");
11522 }
11523#endif
11524
11525 return Result;
11526}
11527
11528std::optional<ScalarEvolution::MonotonicPredicateType>
11529ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11530 ICmpInst::Predicate Pred) {
11531 // A zero step value for LHS means the induction variable is essentially a
11532 // loop invariant value. We don't really depend on the predicate actually
11533 // flipping from false to true (for increasing predicates, and the other way
11534 // around for decreasing predicates), all we care about is that *if* the
11535 // predicate changes then it only changes from false to true.
11536 //
11537 // A zero step value in itself is not very useful, but there may be places
11538 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11539 // as general as possible.
11540
11541 // Only handle LE/LT/GE/GT predicates.
11542 if (!ICmpInst::isRelational(Pred))
11543 return std::nullopt;
11544
11545 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11546 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11547 "Should be greater or less!");
11548
11549 // Check that AR does not wrap.
11550 if (ICmpInst::isUnsigned(Pred)) {
11551 if (!LHS->hasNoUnsignedWrap())
11552 return std::nullopt;
11554 }
11555 assert(ICmpInst::isSigned(Pred) &&
11556 "Relational predicate is either signed or unsigned!");
11557 if (!LHS->hasNoSignedWrap())
11558 return std::nullopt;
11559
11560 const SCEV *Step = LHS->getStepRecurrence(*this);
11561
11562 if (isKnownNonNegative(Step))
11564
11565 if (isKnownNonPositive(Step))
11567
11568 return std::nullopt;
11569}
11570
11571std::optional<ScalarEvolution::LoopInvariantPredicate>
11573 const SCEV *RHS, const Loop *L,
11574 const Instruction *CtxI) {
11575 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11576 if (!isLoopInvariant(RHS, L)) {
11577 if (!isLoopInvariant(LHS, L))
11578 return std::nullopt;
11579
11580 std::swap(LHS, RHS);
11582 }
11583
11584 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11585 if (!ArLHS || ArLHS->getLoop() != L)
11586 return std::nullopt;
11587
11588 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11589 if (!MonotonicType)
11590 return std::nullopt;
11591 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11592 // true as the loop iterates, and the backedge is control dependent on
11593 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11594 //
11595 // * if the predicate was false in the first iteration then the predicate
11596 // is never evaluated again, since the loop exits without taking the
11597 // backedge.
11598 // * if the predicate was true in the first iteration then it will
11599 // continue to be true for all future iterations since it is
11600 // monotonically increasing.
11601 //
11602 // For both the above possibilities, we can replace the loop varying
11603 // predicate with its value on the first iteration of the loop (which is
11604 // loop invariant).
11605 //
11606 // A similar reasoning applies for a monotonically decreasing predicate, by
11607 // replacing true with false and false with true in the above two bullets.
11609 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11610
11611 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11613 RHS);
11614
11615 if (!CtxI)
11616 return std::nullopt;
11617 // Try to prove via context.
11618 // TODO: Support other cases.
11619 switch (Pred) {
11620 default:
11621 break;
11622 case ICmpInst::ICMP_ULE:
11623 case ICmpInst::ICMP_ULT: {
11624 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11625 // Given preconditions
11626 // (1) ArLHS does not cross the border of positive and negative parts of
11627 // range because of:
11628 // - Positive step; (TODO: lift this limitation)
11629 // - nuw - does not cross zero boundary;
11630 // - nsw - does not cross SINT_MAX boundary;
11631 // (2) ArLHS <s RHS
11632 // (3) RHS >=s 0
11633 // we can replace the loop variant ArLHS <u RHS condition with loop
11634 // invariant Start(ArLHS) <u RHS.
11635 //
11636 // Because of (1) there are two options:
11637 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11638 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11639 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11640 // Because of (2) ArLHS <u RHS is trivially true.
11641 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11642 // We can strengthen this to Start(ArLHS) <u RHS.
11643 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11644 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11645 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11646 isKnownNonNegative(RHS) &&
11647 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11649 RHS);
11650 }
11651 }
11652
11653 return std::nullopt;
11654}
11655
11656std::optional<ScalarEvolution::LoopInvariantPredicate>
11658 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11659 const Instruction *CtxI, const SCEV *MaxIter) {
11661 Pred, LHS, RHS, L, CtxI, MaxIter))
11662 return LIP;
11663 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11664 // Number of iterations expressed as UMIN isn't always great for expressing
11665 // the value on the last iteration. If the straightforward approach didn't
11666 // work, try the following trick: if the a predicate is invariant for X, it
11667 // is also invariant for umin(X, ...). So try to find something that works
11668 // among subexpressions of MaxIter expressed as umin.
11669 for (SCEVUse Op : UMin->operands())
11671 Pred, LHS, RHS, L, CtxI, Op))
11672 return LIP;
11673 return std::nullopt;
11674}
11675
11676std::optional<ScalarEvolution::LoopInvariantPredicate>
11678 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11679 const Instruction *CtxI, const SCEV *MaxIter) {
11680 // Try to prove the following set of facts:
11681 // - The predicate is monotonic in the iteration space.
11682 // - If the check does not fail on the 1st iteration:
11683 // - No overflow will happen during first MaxIter iterations;
11684 // - It will not fail on the MaxIter'th iteration.
11685 // If the check does fail on the 1st iteration, we leave the loop and no
11686 // other checks matter.
11687
11688 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11689 if (!isLoopInvariant(RHS, L)) {
11690 if (!isLoopInvariant(LHS, L))
11691 return std::nullopt;
11692
11693 std::swap(LHS, RHS);
11695 }
11696
11697 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11698 if (!AR || AR->getLoop() != L)
11699 return std::nullopt;
11700
11701 // Even if both are valid, we need to consistently chose the unsigned or the
11702 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11703 // predicate.
11704 Pred = Pred.dropSameSign();
11705
11706 // The predicate must be relational (i.e. <, <=, >=, >).
11707 if (!ICmpInst::isRelational(Pred))
11708 return std::nullopt;
11709
11710 // TODO: Support steps other than +/- 1.
11711 const SCEV *Step = AR->getStepRecurrence(*this);
11712 auto *One = getOne(Step->getType());
11713 auto *MinusOne = getNegativeSCEV(One);
11714 if (Step != One && Step != MinusOne)
11715 return std::nullopt;
11716
11717 // Type mismatch here means that MaxIter is potentially larger than max
11718 // unsigned value in start type, which mean we cannot prove no wrap for the
11719 // indvar.
11720 if (AR->getType() != MaxIter->getType())
11721 return std::nullopt;
11722
11723 // Value of IV on suggested last iteration.
11724 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11725 // Does it still meet the requirement?
11726 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11727 return std::nullopt;
11728 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11729 // not exceed max unsigned value of this type), this effectively proves
11730 // that there is no wrap during the iteration. To prove that there is no
11731 // signed/unsigned wrap, we need to check that
11732 // Start <= Last for step = 1 or Start >= Last for step = -1.
11733 ICmpInst::Predicate NoOverflowPred =
11735 if (Step == MinusOne)
11736 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11737 const SCEV *Start = AR->getStart();
11738 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11739 return std::nullopt;
11740
11741 // Everything is fine.
11742 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11743}
11744
11745bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11746 SCEVUse LHS,
11747 SCEVUse RHS) {
11748 if (HasSameValue(LHS, RHS))
11749 return ICmpInst::isTrueWhenEqual(Pred);
11750
11751 auto CheckRange = [&](bool IsSigned) {
11752 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11753 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11754 return RangeLHS.icmp(Pred, RangeRHS);
11755 };
11756
11757 // The check at the top of the function catches the case where the values are
11758 // known to be equal.
11759 if (Pred == CmpInst::ICMP_EQ)
11760 return false;
11761
11762 if (Pred == CmpInst::ICMP_NE) {
11763 if (CheckRange(true) || CheckRange(false))
11764 return true;
11765 auto *Diff = getMinusSCEV(LHS, RHS);
11766 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11767 }
11768
11769 return CheckRange(CmpInst::isSigned(Pred));
11770}
11771
11772bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11774 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11775 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11776 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11777 // OutC1 and OutC2.
11778 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11779 APInt &OutC2,
11780 SCEV::NoWrapFlags ExpectedFlags) {
11781 SCEVUse XNonConstOp, XConstOp;
11782 SCEVUse YNonConstOp, YConstOp;
11783 SCEV::NoWrapFlags XFlagsPresent;
11784 SCEV::NoWrapFlags YFlagsPresent;
11785
11786 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11787 XConstOp = getZero(X->getType());
11788 XNonConstOp = X;
11789 XFlagsPresent = ExpectedFlags;
11790 }
11791 if (!isa<SCEVConstant>(XConstOp))
11792 return false;
11793
11794 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11795 YConstOp = getZero(Y->getType());
11796 YNonConstOp = Y;
11797 YFlagsPresent = ExpectedFlags;
11798 }
11799
11800 if (YNonConstOp != XNonConstOp)
11801 return false;
11802
11803 if (!isa<SCEVConstant>(YConstOp))
11804 return false;
11805
11806 // When matching ADDs with NUW flags (and unsigned predicates), only the
11807 // second ADD (with the larger constant) requires NUW.
11808 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11809 return false;
11810 if (ExpectedFlags != SCEV::FlagNUW &&
11811 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11812 return false;
11813 }
11814
11815 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11816 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11817
11818 return true;
11819 };
11820
11821 APInt C1;
11822 APInt C2;
11823
11824 switch (Pred) {
11825 default:
11826 break;
11827
11828 case ICmpInst::ICMP_SGE:
11829 std::swap(LHS, RHS);
11830 [[fallthrough]];
11831 case ICmpInst::ICMP_SLE:
11832 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11833 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11834 return true;
11835
11836 break;
11837
11838 case ICmpInst::ICMP_SGT:
11839 std::swap(LHS, RHS);
11840 [[fallthrough]];
11841 case ICmpInst::ICMP_SLT:
11842 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11843 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11844 return true;
11845
11846 break;
11847
11848 case ICmpInst::ICMP_UGE:
11849 std::swap(LHS, RHS);
11850 [[fallthrough]];
11851 case ICmpInst::ICMP_ULE:
11852 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11853 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11854 return true;
11855
11856 break;
11857
11858 case ICmpInst::ICMP_UGT:
11859 std::swap(LHS, RHS);
11860 [[fallthrough]];
11861 case ICmpInst::ICMP_ULT:
11862 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11863 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11864 return true;
11865 break;
11866 }
11867
11868 return false;
11869}
11870
11871bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11873 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11874 return false;
11875
11876 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11877 // the stack can result in exponential time complexity.
11878 SaveAndRestore Restore(ProvingSplitPredicate, true);
11879
11880 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11881 //
11882 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11883 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11884 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11885 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11886 // use isKnownPredicate later if needed.
11887 return isKnownNonNegative(RHS) &&
11890}
11891
11892bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11893 const SCEV *LHS, const SCEV *RHS) {
11894 // No need to even try if we know the module has no guards.
11895 if (!HasGuards)
11896 return false;
11897
11898 return any_of(*BB, [&](const Instruction &I) {
11899 using namespace llvm::PatternMatch;
11900
11901 Value *Condition;
11903 m_Value(Condition))) &&
11904 isImpliedCond(Pred, LHS, RHS, Condition, false);
11905 });
11906}
11907
11908/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11909/// protected by a conditional between LHS and RHS. This is used to
11910/// to eliminate casts.
11912 CmpPredicate Pred,
11913 const SCEV *LHS,
11914 const SCEV *RHS) {
11915 // Interpret a null as meaning no loop, where there is obviously no guard
11916 // (interprocedural conditions notwithstanding). Do not bother about
11917 // unreachable loops.
11918 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11919 return true;
11920
11921 if (VerifyIR)
11922 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11923 "This cannot be done on broken IR!");
11924
11925
11926 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11927 return true;
11928
11929 BasicBlock *Latch = L->getLoopLatch();
11930 if (!Latch)
11931 return false;
11932
11933 CondBrInst *LoopContinuePredicate =
11935 if (LoopContinuePredicate &&
11936 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11937 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11938 return true;
11939
11940 // We don't want more than one activation of the following loops on the stack
11941 // -- that can lead to O(n!) time complexity.
11942 if (WalkingBEDominatingConds)
11943 return false;
11944
11945 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11946
11947 // See if we can exploit a trip count to prove the predicate.
11948 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11949 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11950 if (LatchBECount != getCouldNotCompute()) {
11951 // We know that Latch branches back to the loop header exactly
11952 // LatchBECount times. This means the backdege condition at Latch is
11953 // equivalent to "{0,+,1} u< LatchBECount".
11954 Type *Ty = LatchBECount->getType();
11955 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11956 const SCEV *LoopCounter =
11957 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11958 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11959 LatchBECount))
11960 return true;
11961 }
11962
11963 // Check conditions due to any @llvm.assume intrinsics.
11964 for (auto &AssumeVH : AC.assumptions()) {
11965 if (!AssumeVH)
11966 continue;
11967 auto *CI = cast<CallInst>(AssumeVH);
11968 if (!DT.dominates(CI, Latch->getTerminator()))
11969 continue;
11970
11971 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11972 return true;
11973 }
11974
11975 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11976 return true;
11977
11978 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11979 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11980 assert(DTN && "should reach the loop header before reaching the root!");
11981
11982 BasicBlock *BB = DTN->getBlock();
11983 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11984 return true;
11985
11986 BasicBlock *PBB = BB->getSinglePredecessor();
11987 if (!PBB)
11988 continue;
11989
11991 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11992 continue;
11993
11994 // If we have an edge `E` within the loop body that dominates the only
11995 // latch, the condition guarding `E` also guards the backedge. This
11996 // reasoning works only for loops with a single latch.
11997 // We're constructively (and conservatively) enumerating edges within the
11998 // loop body that dominate the latch. The dominator tree better agree
11999 // with us on this:
12000 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
12001 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
12002 BB != ContBr->getSuccessor(0)))
12003 return true;
12004 }
12005
12006 return false;
12007}
12008
12010 CmpPredicate Pred,
12011 const SCEV *LHS,
12012 const SCEV *RHS) {
12013 // Do not bother proving facts for unreachable code.
12014 if (!DT.isReachableFromEntry(BB))
12015 return true;
12016 if (VerifyIR)
12017 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
12018 "This cannot be done on broken IR!");
12019
12020 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
12021 // the facts (a >= b && a != b) separately. A typical situation is when the
12022 // non-strict comparison is known from ranges and non-equality is known from
12023 // dominating predicates. If we are proving strict comparison, we always try
12024 // to prove non-equality and non-strict comparison separately.
12025 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
12026 const bool ProvingStrictComparison =
12027 Pred != NonStrictPredicate.dropSameSign();
12028 bool ProvedNonStrictComparison = false;
12029 bool ProvedNonEquality = false;
12030
12031 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
12032 if (!ProvedNonStrictComparison)
12033 ProvedNonStrictComparison = Fn(NonStrictPredicate);
12034 if (!ProvedNonEquality)
12035 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
12036 if (ProvedNonStrictComparison && ProvedNonEquality)
12037 return true;
12038 return false;
12039 };
12040
12041 if (ProvingStrictComparison) {
12042 auto ProofFn = [&](CmpPredicate P) {
12043 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12044 };
12045 if (SplitAndProve(ProofFn))
12046 return true;
12047 }
12048
12049 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12050 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12051 const Instruction *CtxI = &BB->front();
12052 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12053 return true;
12054 if (ProvingStrictComparison) {
12055 auto ProofFn = [&](CmpPredicate P) {
12056 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12057 };
12058 if (SplitAndProve(ProofFn))
12059 return true;
12060 }
12061 return false;
12062 };
12063
12064 // Starting at the block's predecessor, climb up the predecessor chain, as long
12065 // as there are predecessors that can be found that have unique successors
12066 // leading to the original block.
12067 const Loop *ContainingLoop = LI.getLoopFor(BB);
12068 const BasicBlock *PredBB;
12069 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12070 PredBB = ContainingLoop->getLoopPredecessor();
12071 else
12072 PredBB = BB->getSinglePredecessor();
12073 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12074 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12075 const CondBrInst *BlockEntryPredicate =
12076 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12077 if (!BlockEntryPredicate)
12078 continue;
12079
12080 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12081 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12082 return true;
12083 }
12084
12085 // Check conditions due to any @llvm.assume intrinsics.
12086 for (auto &AssumeVH : AC.assumptions()) {
12087 if (!AssumeVH)
12088 continue;
12089 auto *CI = cast<CallInst>(AssumeVH);
12090 if (!DT.dominates(CI, BB))
12091 continue;
12092
12093 if (ProveViaCond(CI->getArgOperand(0), false))
12094 return true;
12095 }
12096
12097 // Check conditions due to any @llvm.experimental.guard intrinsics.
12098 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12099 F.getParent(), Intrinsic::experimental_guard);
12100 if (GuardDecl)
12101 for (const auto *GU : GuardDecl->users())
12102 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12103 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12104 if (ProveViaCond(Guard->getArgOperand(0), false))
12105 return true;
12106 return false;
12107}
12108
12110 const SCEV *LHS,
12111 const SCEV *RHS) {
12112 // Interpret a null as meaning no loop, where there is obviously no guard
12113 // (interprocedural conditions notwithstanding).
12114 if (!L)
12115 return false;
12116
12117 // Both LHS and RHS must be available at loop entry.
12119 "LHS is not available at Loop Entry");
12121 "RHS is not available at Loop Entry");
12122
12123 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12124 return true;
12125
12126 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12127}
12128
12129bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12130 const SCEV *RHS,
12131 const Value *FoundCondValue, bool Inverse,
12132 const Instruction *CtxI) {
12133 // False conditions implies anything. Do not bother analyzing it further.
12134 if (FoundCondValue ==
12135 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12136 return true;
12137
12138 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12139 return false;
12140
12141 llvm::scope_exit ClearOnExit(
12142 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12143
12144 // Recursively handle And and Or conditions.
12145 const Value *Op0, *Op1;
12146 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12147 if (!Inverse)
12148 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12149 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12150 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12151 if (Inverse)
12152 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12153 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12154 }
12155
12156 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12157 if (!ICI) return false;
12158
12159 // Now that we found a conditional branch that dominates the loop or controls
12160 // the loop latch. Check to see if it is the comparison we are looking for.
12161 CmpPredicate FoundPred;
12162 if (Inverse)
12163 FoundPred = ICI->getInverseCmpPredicate();
12164 else
12165 FoundPred = ICI->getCmpPredicate();
12166
12167 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12168 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12169
12170 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12171}
12172
12173bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12174 const SCEV *RHS, CmpPredicate FoundPred,
12175 const SCEV *FoundLHS, const SCEV *FoundRHS,
12176 const Instruction *CtxI) {
12177 // Balance the types.
12178 if (getTypeSizeInBits(LHS->getType()) <
12179 getTypeSizeInBits(FoundLHS->getType())) {
12180 // For unsigned and equality predicates, try to prove that both found
12181 // operands fit into narrow unsigned range. If so, try to prove facts in
12182 // narrow types.
12183 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12184 !FoundRHS->getType()->isPointerTy()) {
12185 auto *NarrowType = LHS->getType();
12186 auto *WideType = FoundLHS->getType();
12187 auto BitWidth = getTypeSizeInBits(NarrowType);
12188 const SCEV *MaxValue = getZeroExtendExpr(
12190 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12191 MaxValue) &&
12192 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12193 MaxValue)) {
12194 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12195 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12196 // We cannot preserve samesign after truncation.
12197 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12198 TruncFoundLHS, TruncFoundRHS, CtxI))
12199 return true;
12200 }
12201 }
12202
12203 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12204 return false;
12205 if (CmpInst::isSigned(Pred)) {
12206 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12207 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12208 } else {
12209 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12210 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12211 }
12212 } else if (getTypeSizeInBits(LHS->getType()) >
12213 getTypeSizeInBits(FoundLHS->getType())) {
12214 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12215 return false;
12216 if (CmpInst::isSigned(FoundPred)) {
12217 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12218 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12219 } else {
12220 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12221 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12222 }
12223 }
12224 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12225 FoundRHS, CtxI);
12226}
12227
12228bool ScalarEvolution::isImpliedCondBalancedTypes(
12229 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12230 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12232 getTypeSizeInBits(FoundLHS->getType()) &&
12233 "Types should be balanced!");
12234 // Canonicalize the query to match the way instcombine will have
12235 // canonicalized the comparison.
12236 if (SimplifyICmpOperands(Pred, LHS, RHS))
12237 if (LHS == RHS)
12238 return CmpInst::isTrueWhenEqual(Pred);
12239 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12240 if (FoundLHS == FoundRHS)
12241 return CmpInst::isFalseWhenEqual(FoundPred);
12242
12243 // Check to see if we can make the LHS or RHS match.
12244 if (LHS == FoundRHS || RHS == FoundLHS) {
12245 if (isa<SCEVConstant>(RHS)) {
12246 std::swap(FoundLHS, FoundRHS);
12247 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12248 } else {
12249 std::swap(LHS, RHS);
12251 }
12252 }
12253
12254 // Check whether the found predicate is the same as the desired predicate.
12255 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12256 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12257
12258 // Check whether swapping the found predicate makes it the same as the
12259 // desired predicate.
12260 if (auto P = CmpPredicate::getMatching(
12261 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12262 // We can write the implication
12263 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12264 // using one of the following ways:
12265 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12266 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12267 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12268 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12269 // Forms 1. and 2. require swapping the operands of one condition. Don't
12270 // do this if it would break canonical constant/addrec ordering.
12272 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12273 LHS, FoundLHS, FoundRHS, CtxI);
12274 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12275 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12276
12277 // There's no clear preference between forms 3. and 4., try both. Avoid
12278 // forming getNotSCEV of pointer values as the resulting subtract is
12279 // not legal.
12280 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12281 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12282 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12283 FoundRHS, CtxI))
12284 return true;
12285
12286 if (!FoundLHS->getType()->isPointerTy() &&
12287 !FoundRHS->getType()->isPointerTy() &&
12288 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12289 getNotSCEV(FoundRHS), CtxI))
12290 return true;
12291
12292 return false;
12293 }
12294
12295 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12297 assert(P1 != P2 && "Handled earlier!");
12298 return CmpInst::isRelational(P2) &&
12300 };
12301 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12302 // Unsigned comparison is the same as signed comparison when both the
12303 // operands are non-negative or negative.
12304 if (haveSameSign(FoundLHS, FoundRHS))
12305 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12306 // Create local copies that we can freely swap and canonicalize our
12307 // conditions to "le/lt".
12308 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12309 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12310 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12311 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12312 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12313 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12314 std::swap(CanonicalLHS, CanonicalRHS);
12315 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12316 }
12317 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12318 "Must be!");
12319 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12320 ICmpInst::isLE(CanonicalFoundPred)) &&
12321 "Must be!");
12322 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12323 // Use implication:
12324 // x <u y && y >=s 0 --> x <s y.
12325 // If we can prove the left part, the right part is also proven.
12326 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12327 CanonicalRHS, CanonicalFoundLHS,
12328 CanonicalFoundRHS);
12329 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12330 // Use implication:
12331 // x <s y && y <s 0 --> x <u y.
12332 // If we can prove the left part, the right part is also proven.
12333 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12334 CanonicalRHS, CanonicalFoundLHS,
12335 CanonicalFoundRHS);
12336 }
12337
12338 // Check if we can make progress by sharpening ranges.
12339 if (FoundPred == ICmpInst::ICMP_NE &&
12340 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12341
12342 const SCEVConstant *C = nullptr;
12343 const SCEV *V = nullptr;
12344
12345 if (isa<SCEVConstant>(FoundLHS)) {
12346 C = cast<SCEVConstant>(FoundLHS);
12347 V = FoundRHS;
12348 } else {
12349 C = cast<SCEVConstant>(FoundRHS);
12350 V = FoundLHS;
12351 }
12352
12353 // The guarding predicate tells us that C != V. If the known range
12354 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12355 // range we consider has to correspond to same signedness as the
12356 // predicate we're interested in folding.
12357
12358 APInt Min = ICmpInst::isSigned(Pred) ?
12360
12361 if (Min == C->getAPInt()) {
12362 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12363 // This is true even if (Min + 1) wraps around -- in case of
12364 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12365
12366 APInt SharperMin = Min + 1;
12367
12368 switch (Pred) {
12369 case ICmpInst::ICMP_SGE:
12370 case ICmpInst::ICMP_UGE:
12371 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12372 // RHS, we're done.
12373 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12374 CtxI))
12375 return true;
12376 [[fallthrough]];
12377
12378 case ICmpInst::ICMP_SGT:
12379 case ICmpInst::ICMP_UGT:
12380 // We know from the range information that (V `Pred` Min ||
12381 // V == Min). We know from the guarding condition that !(V
12382 // == Min). This gives us
12383 //
12384 // V `Pred` Min || V == Min && !(V == Min)
12385 // => V `Pred` Min
12386 //
12387 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12388
12389 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12390 return true;
12391 break;
12392
12393 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12394 case ICmpInst::ICMP_SLE:
12395 case ICmpInst::ICMP_ULE:
12396 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12397 LHS, V, getConstant(SharperMin), CtxI))
12398 return true;
12399 [[fallthrough]];
12400
12401 case ICmpInst::ICMP_SLT:
12402 case ICmpInst::ICMP_ULT:
12403 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12404 LHS, V, getConstant(Min), CtxI))
12405 return true;
12406 break;
12407
12408 default:
12409 // No change
12410 break;
12411 }
12412 }
12413 }
12414
12415 // Check whether the actual condition is beyond sufficient.
12416 if (FoundPred == ICmpInst::ICMP_EQ)
12417 if (ICmpInst::isTrueWhenEqual(Pred))
12418 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12419 return true;
12420 if (Pred == ICmpInst::ICMP_NE)
12421 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12422 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12423 return true;
12424
12425 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12426 return true;
12427
12428 // Otherwise assume the worst.
12429 return false;
12430}
12431
12432bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12433 SCEV::NoWrapFlags &Flags) {
12434 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12435 return false;
12436
12437 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12438 return true;
12439}
12440
// Tries to compute `More - Less` as a constant. Returns the constant when the
// two expressions can be reduced to a constant offset within a bounded number
// of simplification steps, and std::nullopt otherwise.
// NOTE(review): the embedded numbering jumps from 12441 to 12443, so the line
// carrying the function name and parameter list is missing from this listing;
// the body refers to its operands as `More` and `Less` -- confirm against the
// upstream file.
12441std::optional<APInt>
12443 // We avoid subtracting expressions here because this function is usually
12444 // fairly deep in the call stack (i.e. is called many times).
12445
12446 unsigned BW = getTypeSizeInBits(More->getType());
// Diff accumulates the constant part of `More - Less`. DiffMul is the product
// of the common constant multipliers stripped so far; constants are scaled by
// it when they are folded into Diff.
12447 APInt Diff(BW, 0);
12448 APInt DiffMul(BW, 1);
12449 // Try various simplifications to reduce the difference to a constant. Limit
12450 // the number of allowed simplifications to keep compile-time low.
12451 for (unsigned I = 0; I < 8; ++I) {
12452 if (More == Less)
12453 return Diff;
12454
12455 // Reduce addrecs with identical steps to their start value.
// NOTE(review): the embedded numbering skips 12456; the condition that opens
// this block is missing from this listing (presumably a check that both
// operands are addrecs, given the unconditional casts below) -- confirm.
12457 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12458 const auto *MAR = cast<SCEVAddRecExpr>(More);
12459
12460 if (LAR->getLoop() != MAR->getLoop())
12461 return std::nullopt;
12462
12463 // We look at affine expressions only; not for correctness but to keep
12464 // getStepRecurrence cheap.
12465 if (!LAR->isAffine() || !MAR->isAffine())
12466 return std::nullopt;
12467
12468 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12469 return std::nullopt;
12470
12471 Less = LAR->getStart();
12472 More = MAR->getStart();
12473 continue;
12474 }
12475
12476 // Try to match a common constant multiply.
12477 auto MatchConstMul =
12478 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12479 const APInt *C;
12480 const SCEV *Op;
12481 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12482 return {{Op, *C}};
12483 return std::nullopt;
12484 };
12485 if (auto MatchedMore = MatchConstMul(More)) {
12486 if (auto MatchedLess = MatchConstMul(Less)) {
12487 if (MatchedMore->second == MatchedLess->second) {
// Both sides are `C * X` with the same C: continue with the non-constant
// factors and remember C so later constants are scaled accordingly.
12488 More = MatchedMore->first;
12489 Less = MatchedLess->first;
12490 DiffMul *= MatchedMore->second;
12491 continue;
12492 }
12493 }
12494 }
12495
12496 // Try to cancel out common factors in two add expressions.
// NOTE(review): the embedded numbering skips 12497; the declaration of
// `Multiplicity` (used below as a map from a SCEV to a signed count) is
// missing from this listing -- confirm against the upstream file.
// Add: fold constants into Diff (scaled by DiffMul, sign given by Mul) and
// count every other operand in Multiplicity.
12498 auto Add = [&](const SCEV *S, int Mul) {
12499 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12500 if (Mul == 1) {
12501 Diff += C->getAPInt() * DiffMul;
12502 } else {
12503 assert(Mul == -1);
12504 Diff -= C->getAPInt() * DiffMul;
12505 }
12506 } else
12507 Multiplicity[S] += Mul;
12508 };
// Decompose: feed each operand of an add expression (or the expression itself
// when it is not an add) to Add with the given sign.
12509 auto Decompose = [&](const SCEV *S, int Mul) {
12510 if (isa<SCEVAddExpr>(S)) {
12511 for (const SCEV *Op : S->operands())
12512 Add(Op, Mul);
12513 } else
12514 Add(S, Mul);
12515 };
12516 Decompose(More, 1);
12517 Decompose(Less, -1);
12518
12519 // Check whether all the non-constants cancel out, or reduce to new
12520 // More/Less values.
12521 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12522 for (const auto &[S, Mul] : Multiplicity) {
12523 if (Mul == 0)
12524 continue;
12525 if (Mul == 1) {
12526 if (NewMore)
12527 return std::nullopt;
12528 NewMore = S;
12529 } else if (Mul == -1) {
12530 if (NewLess)
12531 return std::nullopt;
12532 NewLess = S;
12533 } else
12534 return std::nullopt;
12535 }
12536
12537 // Values stayed the same, no point in trying further.
12538 if (NewMore == More || NewLess == Less)
12539 return std::nullopt;
12540
12541 More = NewMore;
12542 Less = NewLess;
12543
12544 // Reduced to constant.
12545 if (!More && !Less)
12546 return Diff;
12547
12548 // Left with variable on only one side, bail out.
12549 if (!More || !Less)
12550 return std::nullopt;
12551 }
12552
12553 // Did not reduce to constant.
12554 return std::nullopt;
12555}
12556
12557bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12558 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12559 const SCEV *FoundRHS, const Instruction *CtxI) {
12560 // Try to recognize the following pattern:
12561 //
12562 // FoundRHS = ...
12563 // ...
12564 // loop:
12565 // FoundLHS = {Start,+,W}
12566 // context_bb: // Basic block from the same loop
12567 // known(Pred, FoundLHS, FoundRHS)
12568 //
12569 // If some predicate is known in the context of a loop, it is also known on
12570 // each iteration of this loop, including the first iteration. Therefore, in
12571 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12572 // prove the original pred using this fact.
12573 if (!CtxI)
12574 return false;
12575 const BasicBlock *ContextBB = CtxI->getParent();
12576 // Make sure AR varies in the context block.
12577 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12578 const Loop *L = AR->getLoop();
12579 const auto *Latch = L->getLoopLatch();
12580 // Make sure that context belongs to the loop and executes on 1st iteration
12581 // (if it ever executes at all).
12582 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12583 return false;
12584 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12585 return false;
12586 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12587 }
12588
12589 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12590 const Loop *L = AR->getLoop();
12591 const auto *Latch = L->getLoopLatch();
12592 // Make sure that context belongs to the loop and executes on 1st iteration
12593 // (if it ever executes at all).
12594 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12595 return false;
12596 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12597 return false;
12598 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12599 }
12600
12601 return false;
12602}
12603
12604bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12605 const SCEV *LHS,
12606 const SCEV *RHS,
12607 const SCEV *FoundLHS,
12608 const SCEV *FoundRHS) {
12609 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12610 return false;
12611
12612 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12613 if (!AddRecLHS)
12614 return false;
12615
12616 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12617 if (!AddRecFoundLHS)
12618 return false;
12619
12620 // We'd like to let SCEV reason about control dependencies, so we constrain
12621 // both the inequalities to be about add recurrences on the same loop. This
12622 // way we can use isLoopEntryGuardedByCond later.
12623
12624 const Loop *L = AddRecFoundLHS->getLoop();
12625 if (L != AddRecLHS->getLoop())
12626 return false;
12627
12628 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12629 //
12630 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12631 // ... (2)
12632 //
12633 // Informal proof for (2), assuming (1) [*]:
12634 //
12635 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12636 //
12637 // Then
12638 //
12639 // FoundLHS s< FoundRHS s< INT_MIN - C
12640 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12641 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12642 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12643 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12644 // <=> FoundLHS + C s< FoundRHS + C
12645 //
12646 // [*]: (1) can be proved by ruling out overflow.
12647 //
12648 // [**]: This can be proved by analyzing all the four possibilities:
12649 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12650 // (A s>= 0, B s>= 0).
12651 //
12652 // Note:
12653 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12654 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12655 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12656 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12657 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12658 // C)".
12659
12660 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12661 if (!LDiff)
12662 return false;
12663 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12664 if (!RDiff || *LDiff != *RDiff)
12665 return false;
12666
12667 if (LDiff->isMinValue())
12668 return true;
12669
12670 APInt FoundRHSLimit;
12671
12672 if (Pred == CmpInst::ICMP_ULT) {
12673 FoundRHSLimit = -(*RDiff);
12674 } else {
12675 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12676 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12677 }
12678
12679 // Try to prove (1) or (2), as needed.
12680 return isAvailableAtLoopEntry(FoundRHS, L) &&
12681 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12682 getConstant(FoundRHSLimit));
12683}
12684
// Tries to prove `LHS Pred RHS` when at least one of LHS/RHS is a SCEVUnknown
// wrapping a PHI node, by proving the predicate for the PHI's incoming values.
// Returns false when neither side is such a PHI, when a PHI is already being
// processed by an outer invocation (tracked in PendingMerges), or when any of
// the per-incoming-value proofs fails.
12685bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12686 const SCEV *RHS, const SCEV *FoundLHS,
12687 const SCEV *FoundRHS, unsigned Depth) {
12688 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12689
// On every exit path, remove from PendingMerges the PHIs this invocation
// registered (LPhi/RPhi are only set after a successful insert below).
12690 llvm::scope_exit ClearOnExit([&]() {
12691 if (LPhi) {
12692 bool Erased = PendingMerges.erase(LPhi);
12693 assert(Erased && "Failed to erase LPhi!");
12694 (void)Erased;
12695 }
12696 if (RPhi) {
12697 bool Erased = PendingMerges.erase(RPhi);
12698 assert(Erased && "Failed to erase RPhi!");
12699 (void)Erased;
12700 }
12701 });
12702
12703 // Find the respective Phis and check that they are not already pending.
12704 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12705 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12706 if (!PendingMerges.insert(Phi).second)
12707 return false;
12708 LPhi = Phi;
12709 }
12710 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12711 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12712 // If we detect a loop of Phi nodes being processed by this method, for
12713 // example:
12714 //
12715 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12716 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12717 //
12718 // we don't want to deal with a case that complex, so return conservative
12719 // answer false.
12720 if (!PendingMerges.insert(Phi).second)
12721 return false;
12722 RPhi = Phi;
12723 }
12724
12725 // If none of LHS, RHS is a Phi, nothing to do here.
12726 if (!LPhi && !RPhi)
12727 return false;
12728
12729 // If there is a SCEVUnknown Phi we are interested in, make it left.
12730 if (!LPhi) {
12731 std::swap(LHS, RHS);
12732 std::swap(FoundLHS, FoundRHS);
12733 std::swap(LPhi, RPhi);
// NOTE(review): the embedded numbering skips 12734 here. Since the operands
// are exchanged above, Pred presumably has to be swapped as well on the
// missing line -- confirm against the upstream file.
12735 }
12736
12737 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12738 const BasicBlock *LBB = LPhi->getParent();
12739 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12740
// Proof helper for one pair of values: non-recursive reasoning first, then the
// range-based and operation-based implication checks against the found fact.
12741 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12742 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12743 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12744 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12745 };
12746
12747 if (RPhi && RPhi->getParent() == LBB) {
12748 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12749 // If we compare two Phis from the same block, and for each entry block
12750 // the predicate is true for incoming values from this block, then the
12751 // predicate is also true for the Phis.
12752 for (const BasicBlock *IncBB : predecessors(LBB)) {
12753 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12754 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12755 if (!ProvedEasily(L, R))
12756 return false;
12757 }
12758 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12759 // Case two: RHS is also a Phi from the same basic block, and it is an
12760 // AddRec. It means that there is a loop which has both AddRec and Unknown
12761 // PHIs, for it we can compare incoming values of AddRec from above the loop
12762 // and latch with their respective incoming values of LPhi.
12763 // TODO: Generalize to handle loops with many inputs in a header.
12764 if (LPhi->getNumIncomingValues() != 2) return false;
12765
12766 auto *RLoop = RAR->getLoop();
12767 auto *Predecessor = RLoop->getLoopPredecessor();
12768 assert(Predecessor && "Loop with AddRec with no predecessor?");
// Entry edge: compare LPhi's incoming value with the AddRec's start value.
12769 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12770 if (!ProvedEasily(L1, RAR->getStart()))
12771 return false;
12772 auto *Latch = RLoop->getLoopLatch();
12773 assert(Latch && "Loop with AddRec with no latch?");
// Back edge: compare LPhi's incoming value with the post-increment AddRec.
12774 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12775 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12776 return false;
12777 } else {
12778 // In all other cases go over inputs of LHS and compare each of them to RHS,
12779 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12780 // At this point RHS is either a non-Phi, or it is a Phi from some block
12781 // different from LBB.
12782 for (const BasicBlock *IncBB : predecessors(LBB)) {
12783 // Check that RHS is available in this block.
12784 if (!dominates(RHS, IncBB))
12785 return false;
12786 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12787 // Make sure L does not refer to a value from a potentially previous
12788 // iteration of a loop.
12789 if (!properlyDominates(L, LBB))
12790 return false;
12791 // Addrecs are considered to properly dominate their loop, so are missed
12792 // by the previous check. Discard any values that have computable
12793 // evolution in this loop.
12794 if (auto *Loop = LI.getLoopFor(LBB))
12795 if (hasComputableLoopEvolution(L, Loop))
12796 return false;
12797 if (!ProvedEasily(L, RHS))
12798 return false;
12799 }
12800 }
12801 return true;
12802}
12803
// Tries to prove `LHS Pred RHS` from a known comparison of the same LHS
// against a logical right shift (`shiftee >> shiftvalue`), using the fact that
// the shifted value does not exceed the shiftee. Only u<, u<=, s< and s<= are
// handled; every other case returns false.
12804bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12805 const SCEV *LHS,
12806 const SCEV *RHS,
12807 const SCEV *FoundLHS,
12808 const SCEV *FoundRHS) {
12809 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12810 // sure that we are dealing with same LHS.
12811 if (RHS == FoundRHS) {
12812 std::swap(LHS, RHS);
12813 std::swap(FoundLHS, FoundRHS);
// NOTE(review): the embedded numbering skips 12814 here. Since the operands
// are exchanged above, Pred presumably has to be swapped as well on the
// missing line -- confirm against the upstream file.
12815 }
12816 if (LHS != FoundLHS)
12817 return false;
12818
// The shift is matched on the underlying IR value, so FoundRHS must be a
// SCEVUnknown.
12819 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12820 if (!SUFoundRHS)
12821 return false;
12822
12823 Value *Shiftee, *ShiftValue;
12824
12825 using namespace PatternMatch;
12826 if (match(SUFoundRHS->getValue(),
12827 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12828 auto *ShifteeS = getSCEV(Shiftee);
12829 // Prove one of the following:
12830 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12831 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12832 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12833 // ---> LHS <s RHS
12834 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12835 // ---> LHS <=s RHS
12836 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12837 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12838 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12839 if (isKnownNonNegative(ShifteeS))
12840 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12841 }
12842
12843 return false;
12844}
12845
12846bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12847 const SCEV *RHS,
12848 const SCEV *FoundLHS,
12849 const SCEV *FoundRHS,
12850 const Instruction *CtxI) {
12851 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12852 FoundRHS) ||
12853 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12854 FoundRHS) ||
12855 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12856 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12857 CtxI) ||
12858 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12859}
12860
12861/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12862template <typename MinMaxExprType>
12863static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12864 const SCEV *Candidate) {
12865 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12866 if (!MinMaxExpr)
12867 return false;
12868
12869 return is_contained(MinMaxExpr->operands(), Candidate);
12870}
12871
// NOTE(review): the embedded numbering starts at 12873, so the line with the
// return type, function name and leading parameter(s) is missing from this
// listing. The body uses `SE`, and a call elsewhere in the file passes
// (*this, Pred, LHS, RHS) to IsKnownPredicateViaAddRecStart, so this is
// presumably that static helper -- confirm against the upstream file.
12873 CmpPredicate Pred, const SCEV *LHS,
12874 const SCEV *RHS) {
12875 // If both sides are affine addrecs for the same loop, with equal
12876 // steps, and we know the recurrences don't wrap, then we only
12877 // need to check the predicate on the starting values.
12878
// Equality predicates are not handled by this reasoning.
12879 if (!ICmpInst::isRelational(Pred))
12880 return false;
12881
12882 const SCEV *LStart, *RStart, *Step;
12883 const Loop *L;
12884 if (!match(LHS,
12885 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
// NOTE(review): line 12886 (the start of the second match, on RHS) is missing
// from this listing.
12887 m_SpecificLoop(L))))
12888 return false;
// NOTE(review): lines 12889-12892 are missing from this listing; they
// presumably declare `LAR`, `RAR` and the no-wrap flag `NW` used below.
12893 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12894 return false;
12895
12896 return SE.isKnownPredicate(Pred, LStart, RStart);
12897}
12898
12899/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12900/// expression?
// NOTE(review): the embedded numbering skips 12901, so the line with the
// function name and leading parameter(s) is missing from this listing. A call
// elsewhere in the file passes (*this, Pred, LHS, RHS) to
// IsKnownPredicateViaMinOrMax, so this is presumably that helper -- confirm.
12902 const SCEV *LHS, const SCEV *RHS) {
12903 switch (Pred) {
12904 default:
12905 return false;
12906
// The ">=" forms are reduced to "<=" by swapping the operands.
12907 case ICmpInst::ICMP_SGE:
12908 std::swap(LHS, RHS);
12909 [[fallthrough]];
12910 case ICmpInst::ICMP_SLE:
// NOTE(review): the expressions following the two comments below (lines 12913
// and 12915) are missing from this listing.
12911 return
12912 // min(A, ...) <= A
12914 // A <= max(A, ...)
12916
12917 case ICmpInst::ICMP_UGE:
12918 std::swap(LHS, RHS);
12919 [[fallthrough]];
12920 case ICmpInst::ICMP_ULE:
// NOTE(review): the expressions following the comments below (lines 12924 and
// 12926) are missing from this listing.
12921 return
12922 // min(A, ...) <= A
12923 // FIXME: what about umin_seq?
12925 // A <= max(A, ...)
12927 }
12928
12929 llvm_unreachable("covered switch fell through?!");
12930}
12931
// Tries to prove `LHS Pred RHS` from the known fact on FoundLHS/FoundRHS by
// looking through the operations that form LHS: sign extensions, adds with the
// no-signed-wrap flag, and signed division by a constant. As a last step it
// falls back to isImpliedViaMerge for PHI operands. Depth is incremented on
// each recursive query.
// NOTE(review): several lines are missing from this listing (the embedded
// numbering skips 12937-12938, 12944, 12949, 12954, 12970 and 13003); each gap
// is marked below. Confirm the missing statements against the upstream file.
12932bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12933 const SCEV *RHS,
12934 const SCEV *FoundLHS,
12935 const SCEV *FoundRHS,
12936 unsigned Depth) {
// NOTE(review): the start of this assertion (lines 12937-12938) is missing.
12939 "LHS and RHS have different sizes?");
12940 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12941 getTypeSizeInBits(FoundRHS->getType()) &&
12942 "FoundLHS and FoundRHS have different sizes?");
12943 // We want to avoid hurting the compile time with analysis of too big trees.
// NOTE(review): the condition guarding this return (line 12944) is missing;
// presumably a limit check on Depth.
12945 return false;
12946
12947 // We only want to work with GT comparison so far.
12948 if (ICmpInst::isLT(Pred)) {
// NOTE(review): line 12949 is missing; presumably Pred is swapped here to
// match the operand swaps below.
12950 std::swap(LHS, RHS);
12951 std::swap(FoundLHS, FoundRHS);
12952 }
12953
// NOTE(review): line 12954 is missing; it presumably declares the predicate
// `P` that is tested below.
12955
12956 // For unsigned, try to reduce it to corresponding signed comparison.
12957 if (P == ICmpInst::ICMP_UGT)
12958 // We can replace unsigned predicate with its signed counterpart if all
12959 // involved values are non-negative.
12960 // TODO: We could have better support for unsigned.
12961 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12962 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12963 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12964 // use this fact to prove that LHS and RHS are non-negative.
12965 const SCEV *MinusOne = getMinusOne(LHS->getType());
12966 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12967 FoundRHS) &&
12968 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12969 FoundRHS))
// NOTE(review): the statement controlled by the `if` above (line 12970) is
// missing; presumably it switches P to the signed predicate.
12971 }
12972
12973 if (P != ICmpInst::ICMP_SGT)
12974 return false;
12975
// Strips one sign extension, if present; otherwise returns S unchanged.
12976 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12977 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12978 return Ext->getOperand();
12979 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12980 // the constant in some cases.
12981 return S;
12982 };
12983
12984 // Acquire values from extensions.
12985 auto *OrigLHS = LHS;
12986 auto *OrigFoundLHS = FoundLHS;
12987 LHS = GetOpFromSExt(LHS);
12988 FoundLHS = GetOpFromSExt(FoundLHS);
12989
12990 // Can the SGT predicate be proved trivially or using the found context?
12991 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12992 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12993 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12994 FoundRHS, Depth + 1);
12995 };
12996
12997 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12998 // We want to avoid creation of any new non-constant SCEV. Since we are
12999 // going to compare the operands to RHS, we should be certain that we don't
13000 // need any size extensions for this. So let's decline all cases when the
13001 // sizes of types of LHS and RHS do not match.
13002 // TODO: Maybe try to get RHS from sext to catch more cases?
// NOTE(review): the size-comparison condition guarding this return (line
// 13003) is missing from this listing.
13004 return false;
13005
13006 // Should not overflow.
13007 if (!LHSAddExpr->hasNoSignedWrap())
13008 return false;
13009
13010 SCEVUse LL = LHSAddExpr->getOperand(0);
13011 SCEVUse LR = LHSAddExpr->getOperand(1);
13012 auto *MinusOne = getMinusOne(RHS->getType());
13013
13014 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
13015 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
13016 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
13017 };
13018 // Try to prove the following rule:
13019 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
13020 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
13021 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
13022 return true;
13023 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
13024 Value *LL, *LR;
13025 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
13026
13027 using namespace llvm::PatternMatch;
13028
13029 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
13030 // Rules for division.
13031 // We are going to perform some comparisons with Denominator and its
13032 // derivative expressions. In general case, creating a SCEV for it may
13033 // lead to a complex analysis of the entire graph, and in particular it
13034 // can request trip count recalculation for the same loop. This would
13035 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
13036 // this, we only want to create SCEVs that are constants in this section.
13037 // So we bail if Denominator is not a constant.
13038 if (!isa<ConstantInt>(LR))
13039 return false;
13040
13041 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13042
13043 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13044 // then a SCEV for the numerator already exists and matches with FoundLHS.
13045 auto *Numerator = getExistingSCEV(LL);
13046 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13047 return false;
13048
13049 // Make sure that the numerator matches with FoundLHS and the denominator
13050 // is positive.
13051 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13052 return false;
13053
13054 auto *DTy = Denominator->getType();
13055 auto *FRHSTy = FoundRHS->getType();
13056 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13057 // One of types is a pointer and another one is not. We cannot extend
13058 // them properly to a wider type, so let us just reject this case.
13059 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13060 // to avoid this check.
13061 return false;
13062
13063 // Given that:
13064 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13065 auto *WTy = getWiderType(DTy, FRHSTy);
13066 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13067 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13068
13069 // Try to prove the following rule:
13070 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13071 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13072 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13073 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13074 if (isKnownNonPositive(RHS) &&
13075 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13076 return true;
13077
13078 // Try to prove the following rule:
13079 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13080 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13081 // If we divide it by Denominator > 2, then:
13082 // 1. If FoundLHS is negative, then the result is 0.
13083 // 2. If FoundLHS is non-negative, then the result is non-negative.
13084 // Anyways, the result is non-negative.
13085 auto *MinusOne = getMinusOne(WTy);
13086 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13087 if (isKnownNegative(RHS) &&
13088 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13089 return true;
13090 }
13091 }
13092
13093 // If our expression contained SCEVUnknown Phis, and we split it down and now
13094 // need to prove something for them, try to prove the predicate for every
13095 // possible incoming values of those Phis.
13096 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13097 return true;
13098
13099 return false;
13100}
13101
// NOTE(review): the embedded numbering skips 13102, so the line with the
// return type, function name and leading parameters is missing from this
// listing. A call elsewhere in the file passes (Pred, LHS, RHS) to
// isKnownPredicateExtendIdiom, so this is presumably that function -- confirm.
13103 const SCEV *RHS) {
13104 // zext x u<= sext x, sext x s<= zext x
13105 const SCEV *Op;
13106 switch (Pred) {
// The ">=" forms are reduced to "<=" by swapping the operands.
13107 case ICmpInst::ICMP_SGE:
13108 std::swap(LHS, RHS);
13109 [[fallthrough]];
13110 case ICmpInst::ICMP_SLE: {
13111 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
// NOTE(review): the second operand of this `&&` (line 13113) is missing from
// this listing; presumably it matches RHS as a zext of the same Op.
13112 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13114 }
13115 case ICmpInst::ICMP_UGE:
13116 std::swap(LHS, RHS);
13117 [[fallthrough]];
13118 case ICmpInst::ICMP_ULE: {
13119 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
// NOTE(review): the second operand of this `&&` (line 13121) is missing from
// this listing; presumably it matches RHS as a sext of the same Op.
13120 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13122 }
13123 default:
13124 return false;
13125 };
13126 llvm_unreachable("unhandled case");
13127}
13128
13129bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13130 SCEVUse LHS,
13131 SCEVUse RHS) {
13132 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13133 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13134 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13135 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13136 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13137}
13138
13139bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13140 const SCEV *LHS,
13141 const SCEV *RHS,
13142 const SCEV *FoundLHS,
13143 const SCEV *FoundRHS) {
13144 switch (Pred) {
13145 default:
13146 llvm_unreachable("Unexpected CmpPredicate value!");
13147 case ICmpInst::ICMP_EQ:
13148 case ICmpInst::ICMP_NE:
13149 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13150 return true;
13151 break;
13152 case ICmpInst::ICMP_SLT:
13153 case ICmpInst::ICMP_SLE:
13154 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13155 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13156 return true;
13157 break;
13158 case ICmpInst::ICMP_SGT:
13159 case ICmpInst::ICMP_SGE:
13160 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13161 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13162 return true;
13163 break;
13164 case ICmpInst::ICMP_ULT:
13165 case ICmpInst::ICMP_ULE:
13166 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13167 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13168 return true;
13169 break;
13170 case ICmpInst::ICMP_UGT:
13171 case ICmpInst::ICMP_UGE:
13172 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13173 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13174 return true;
13175 break;
13176 }
13177
13178 // Maybe it can be proved via operations?
13179 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13180 return true;
13181
13182 return false;
13183}
13184
13185bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13186 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13187 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13188 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13189 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13190 // reduce the compile time impact of this optimization.
13191 return false;
13192
13193 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13194 if (!Addend)
13195 return false;
13196
13197 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13198
13199 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13200 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13201 ConstantRange FoundLHSRange =
13202 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13203
13204 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13205 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13206
13207 // We can also compute the range of values for `LHS` that satisfy the
13208 // consequent, "`LHS` `Pred` `RHS`":
13209 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13210 // The antecedent implies the consequent if every value of `LHS` that
13211 // satisfies the antecedent also satisfies the consequent.
13212 return LHSRange.icmp(Pred, ConstRHS);
13213}
13214
13215bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13216 bool IsSigned) {
13217 assert(isKnownPositive(Stride) && "Positive stride expected!");
13218
13219 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13220 const SCEV *One = getOne(Stride->getType());
13221
13222 if (IsSigned) {
13223 APInt MaxRHS = getSignedRangeMax(RHS);
13224 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13225 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13226
13227 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13228 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13229 }
13230
13231 APInt MaxRHS = getUnsignedRangeMax(RHS);
13232 APInt MaxValue = APInt::getMaxValue(BitWidth);
13233 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13234
13235 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13236 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13237}
13238
13239bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13240 bool IsSigned) {
13241
13242 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13243 const SCEV *One = getOne(Stride->getType());
13244
13245 if (IsSigned) {
13246 APInt MinRHS = getSignedRangeMin(RHS);
13247 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13248 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13249
13250 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13251 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13252 }
13253
13254 APInt MinRHS = getUnsignedRangeMin(RHS);
13255 APInt MinValue = APInt::getMinValue(BitWidth);
13256 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13257
13258 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13259 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13260}
13261
// Builds the expression umin(N, 1) + ((N - umin(N, 1)) /u D) from the
// operands `N` and `D`, i.e. an unsigned ceiling division written so that the
// intermediate arithmetic does not rely on N + (D - 1).
// NOTE(review): listing line 13262, which presumably carries this function's
// signature (callers in this file invoke it as getUDivCeilSCEV(N, D)), is
// absent from this extract -- confirm against upstream.
13263 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13264 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13265 // expression fixes the case of N=0.
13266 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13267 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13268 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13269}
13270
13271const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13272 const SCEV *Stride,
13273 const SCEV *End,
13274 unsigned BitWidth,
13275 bool IsSigned) {
13276 // The logic in this function assumes we can represent a positive stride.
13277 // If we can't, the backedge-taken count must be zero.
13278 if (IsSigned && BitWidth == 1)
13279 return getZero(Stride->getType());
13280
13281 // This code below only been closely audited for negative strides in the
13282 // unsigned comparison case, it may be correct for signed comparison, but
13283 // that needs to be established.
13284 if (IsSigned && isKnownNegative(Stride))
13285 return getCouldNotCompute();
13286
13287 // Calculate the maximum backedge count based on the range of values
13288 // permitted by Start, End, and Stride.
13289 APInt MinStart =
13290 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13291
13292 APInt MinStride =
13293 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13294
13295 // We assume either the stride is positive, or the backedge-taken count
13296 // is zero. So force StrideForMaxBECount to be at least one.
13297 APInt One(BitWidth, 1);
13298 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13299 : APIntOps::umax(One, MinStride);
13300
13301 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13302 : APInt::getMaxValue(BitWidth);
13303 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13304
13305 // Although End can be a MAX expression we estimate MaxEnd considering only
13306 // the case End = RHS of the loop termination condition. This is safe because
13307 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13308 // taken count.
13309 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13310 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13311
13312 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13313 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13314 : APIntOps::umax(MaxEnd, MinStart);
13315
13316 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13317 getConstant(StrideForMaxBECount) /* Step */);
13318}
13319
// Computes an ExitLimit (exact, constant-max and symbolic-max backedge-taken
// counts) for a loop exit governed by a "less than" comparison of LHS against
// RHS in loop L.
//
//   LHS, RHS          operands of the comparison; LHS is expected to be (or to
//                     be convertible to) an affine add recurrence on L.
//   IsSigned          selects signed versus unsigned comparison semantics.
//   ControlsOnlyExit  when true, no-wrap flags on the IV may be relied upon
//                     (see the NoWrap computation below).
//   AllowPredicates   when true, LHS may be turned into an add recurrence
//                     under runtime predicates, which are then attached to
//                     the returned ExitLimit.
//
// Returns getCouldNotCompute() whenever one of the preconditions checked
// below cannot be established.
//
// NOTE(review): several listing lines are absent from this extract (13320,
// 13324, 13326, 13370, 13372, 13403, 13444, 13503, 13508, 13509). The names
// `IV`, `Predicates` and `Cond` used below are presumably declared on some of
// those lines -- confirm against upstream before relying on this text.
13321ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13322 const Loop *L, bool IsSigned,
13323 bool ControlsOnlyExit, bool AllowPredicates) {
13325
13327 bool PredicatedIV = false;
13328 if (!IV) {
13329 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13330 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13331 if (AR && AR->getLoop() == L && AR->isAffine()) {
13332 auto canProveNUW = [&]() {
13333 // We can use the comparison to infer no-wrap flags only if it fully
13334 // controls the loop exit.
13335 if (!ControlsOnlyExit)
13336 return false;
13337
13338 if (!isLoopInvariant(RHS, L))
13339 return false;
13340
13341 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13342 // We need the sequence defined by AR to strictly increase in the
13343 // unsigned integer domain for the logic below to hold.
13344 return false;
13345
13346 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13347 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13348 // If RHS <=u Limit, then there must exist a value V in the sequence
13349 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13350 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13351 // overflow occurs. This limit also implies that a signed comparison
13352 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13353 // the high bits on both sides must be zero.
13354 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13355 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13356 Limit = Limit.zext(OuterBitWidth);
13357 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13358 };
13359 auto Flags = AR->getNoWrapFlags();
13360 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13361 Flags = setFlags(Flags, SCEV::FlagNUW);
13362
13363 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13364 if (AR->hasNoUnsignedWrap()) {
13365 // Emulate what getZeroExtendExpr would have done during construction
13366 // if we'd been able to infer the fact just above at that time.
13367 const SCEV *Step = AR->getStepRecurrence(*this);
13368 Type *Ty = ZExt->getType();
13369 auto *S = getAddRecExpr(
13371 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13373 }
13374 }
13375 }
13376 }
13377
13378
13379 if (!IV && AllowPredicates) {
13380 // Try to make this an AddRec using runtime tests, in the first X
13381 // iterations of this loop, where X is the SCEV expression found by the
13382 // algorithm below.
13383 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13384 PredicatedIV = true;
13385 }
13386
13387 // Avoid weird loops
13388 if (!IV || IV->getLoop() != L || !IV->isAffine())
13389 return getCouldNotCompute();
13390
13391 // A precondition of this method is that the condition being analyzed
13392 // reaches an exiting branch which dominates the latch. Given that, we can
13393 // assume that an increment which violates the nowrap specification and
13394 // produces poison must cause undefined behavior when the resulting poison
13395 // value is branched upon and thus we can conclude that the backedge is
13396 // taken no more often than would be required to produce that poison value.
13397 // Note that a well defined loop can exit on the iteration which violates
13398 // the nowrap specification if there is another exit (either explicit or
13399 // implicit/exceptional) which causes the loop to exit before the
13400 // exiting instruction we're analyzing would trigger UB.
13401 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13402 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13404
13405 const SCEV *Stride = IV->getStepRecurrence(*this);
13406
13407 bool PositiveStride = isKnownPositive(Stride);
13408
13409 // Avoid negative or zero stride values.
13410 if (!PositiveStride) {
13411 // We can compute the correct backedge taken count for loops with unknown
13412 // strides if we can prove that the loop is not an infinite loop with side
13413 // effects. Here's the loop structure we are trying to handle -
13414 //
13415 // i = start
13416 // do {
13417 // A[i] = i;
13418 // i += s;
13419 // } while (i < end);
13420 //
13421 // The backedge taken count for such loops is evaluated as -
13422 // (max(end, start + stride) - start - 1) /u stride
13423 //
13424 // The additional preconditions that we need to check to prove correctness
13425 // of the above formula is as follows -
13426 //
13427 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13428 // NoWrap flag).
13429 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13430 // no side effects within the loop)
13431 // c) loop has a single static exit (with no abnormal exits)
13432 //
13433 // Precondition a) implies that if the stride is negative, this is a single
13434 // trip loop. The backedge taken count formula reduces to zero in this case.
13435 //
13436 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13437 // then a zero stride means the backedge can't be taken without executing
13438 // undefined behavior.
13439 //
13440 // The positive stride case is the same as isKnownPositive(Stride) returning
13441 // true (original behavior of the function).
13442 //
13443 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13445 return getCouldNotCompute();
13446
13447 if (!isKnownNonZero(Stride)) {
13448 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13449 // if it might eventually be greater than start and if so, on which
13450 // iteration. We can't even produce a useful upper bound.
13451 if (!isLoopInvariant(RHS, L))
13452 return getCouldNotCompute();
13453
13454 // We allow a potentially zero stride, but we need to divide by stride
13455 // below. Since the loop can't be infinite and this check must control
13456 // the sole exit, we can infer the exit must be taken on the first
13457 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13458 // we know the numerator in the divides below must be zero, so we can
13459 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13460 // and produce the right result.
13461 // FIXME: Handle the case where Stride is poison?
13462 auto wouldZeroStrideBeUB = [&]() {
13463 // Proof by contradiction. Suppose the stride were zero. If we can
13464 // prove that the backedge *is* taken on the first iteration, then since
13465 // we know this condition controls the sole exit, we must have an
13466 // infinite loop. We can't have a (well defined) infinite loop per
13467 // check just above.
13468 // Note: The (Start - Stride) term is used to get the start' term from
13469 // (start' + stride,+,stride). Remember that we only care about the
13470 // result of this expression when stride == 0 at runtime.
13471 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13472 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13473 };
13474 if (!wouldZeroStrideBeUB()) {
13475 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13476 }
13477 }
13478 } else if (!NoWrap) {
13479 // Avoid proven overflow cases: this will ensure that the backedge taken
13480 // count will not generate any unsigned overflow.
13481 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13482 return getCouldNotCompute();
13483 }
13484
13485 // On all paths just preceding, we established the following invariant:
13486 // IV can be assumed not to overflow up to and including the exiting
13487 // iteration. We proved this in one of two ways:
13488 // 1) We can show overflow doesn't occur before the exiting iteration
13489 // 1a) canIVOverflowOnLT, and b) step of one
13490 // 2) We can show that if overflow occurs, the loop must execute UB
13491 // before any possible exit.
13492 // Note that we have not yet proved RHS invariant (in general).
13493
13494 const SCEV *Start = IV->getStart();
13495
13496 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13497 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13498 // Use integer-typed versions for actual computation; we can't subtract
13499 // pointers in general.
13500 const SCEV *OrigStart = Start;
13501 const SCEV *OrigRHS = RHS;
13502 if (Start->getType()->isPointerTy()) {
13504 if (isa<SCEVCouldNotCompute>(Start))
13505 return Start;
13506 }
13507 if (RHS->getType()->isPointerTy()) {
13510 return RHS;
13511 }
13512
13513 const SCEV *End = nullptr, *BECount = nullptr,
13514 *BECountIfBackedgeTaken = nullptr;
13515 if (!isLoopInvariant(RHS, L)) {
13516 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13517 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13518 any(RHSAddRec->getNoWrapFlags())) {
13519 // The structure of loop we are trying to calculate backedge count of:
13520 //
13521 // left = left_start
13522 // right = right_start
13523 //
13524 // while(left < right){
13525 // ... do something here ...
13526 // left += s1; // stride of left is s1 (s1 > 0)
13527 // right += s2; // stride of right is s2 (s2 < 0)
13528 // }
13529 //
13530
13531 const SCEV *RHSStart = RHSAddRec->getStart();
13532 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13533
13534 // If Stride - RHSStride is positive and does not overflow, we can write
13535 // backedge count as ->
13536 // ceil((End - Start) /u (Stride - RHSStride))
13537 // Where, End = max(RHSStart, Start)
13538
13539 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13540 if (isKnownNegative(RHSStride) &&
13541 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13542 RHSStride)) {
13543
13544 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13545 if (isKnownPositive(Denominator)) {
13546 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13547 : getUMaxExpr(RHSStart, Start);
13548
13549 // We can do this because End >= Start, as End = max(RHSStart, Start)
13550 const SCEV *Delta = getMinusSCEV(End, Start);
13551
13552 BECount = getUDivCeilSCEV(Delta, Denominator);
13553 BECountIfBackedgeTaken =
13554 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13555 }
13556 }
13557 }
13558 if (BECount == nullptr) {
13559 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13560 // given the start, stride and max value for the end bound of the
13561 // loop (RHS), and the fact that IV does not overflow (which is
13562 // checked above).
13563 const SCEV *MaxBECount = computeMaxBECountForLT(
13564 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13565 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13566 MaxBECount, false /*MaxOrZero*/, Predicates);
13567 }
13568 } else {
13569 // We use the expression (max(End,Start)-Start)/Stride to describe the
13570 // backedge count, as if the backedge is taken at least once
13571 // max(End,Start) is End and so the result is as above, and if not
13572 // max(End,Start) is Start so we get a backedge count of zero.
13573 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13574 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13575 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13576 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13577 // Can we prove max(RHS,Start) > Start - Stride?
13578 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13579 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13580 // In this case, we can use a refined formula for computing backedge
13581 // taken count. The general formula remains:
13582 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13583 // We want to use the alternate formula:
13584 // "((End - 1) - (Start - Stride)) /u Stride"
13585 // Let's do a quick case analysis to show these are equivalent under
13586 // our precondition that max(RHS,Start) > Start - Stride.
13587 // * For RHS <= Start, the backedge-taken count must be zero.
13588 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13589 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13590 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13591 // of Stride. For 0 stride, we've used umax(1,Stride) above,
13592 // reducing this to the stride of 1 case.
13593 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13594 // Stride".
13595 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13596 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13597 // "((RHS - (Start - Stride) - 1) /u Stride".
13598 // Our preconditions trivially imply no overflow in that form.
13599 const SCEV *MinusOne = getMinusOne(Stride->getType());
13600 const SCEV *Numerator =
13601 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13602 BECount = getUDivExpr(Numerator, Stride);
13603 }
13604
13605 if (!BECount) {
13606 auto canProveRHSGreaterThanEqualStart = [&]() {
13607 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13608 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13609 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13610
13611 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13612 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13613 return true;
13614
13615 // (RHS > Start - 1) implies RHS >= Start.
13616 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13617 // "Start - 1" doesn't overflow.
13618 // * For signed comparison, if Start - 1 does overflow, it's equal
13619 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13620 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13621 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13622 //
13623 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13624 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13625 auto *StartMinusOne =
13626 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13627 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13628 };
13629
13630 // If we know that RHS >= Start in the context of loop, then we know
13631 // that max(RHS, Start) = RHS at this point.
13632 if (canProveRHSGreaterThanEqualStart()) {
13633 End = RHS;
13634 } else {
13635 // If RHS < Start, the backedge will be taken zero times. So in
13636 // general, we can write the backedge-taken count as:
13637 //
13638 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13639 //
13640 // We convert it to the following to make it more convenient for SCEV:
13641 //
13642 // ceil(max(RHS, Start) - Start) / Stride
13643 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13644
13645 // See what would happen if we assume the backedge is taken. This is
13646 // used to compute MaxBECount.
13647 BECountIfBackedgeTaken =
13648 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13649 }
13650
13651 // At this point, we know:
13652 //
13653 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13654 // 2. The index variable doesn't overflow.
13655 //
13656 // Therefore, we know N exists such that
13657 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13658 // doesn't overflow.
13659 //
13660 // Using this information, try to prove whether the addition in
13661 // "(End - Start) + (Stride - 1)" has unsigned overflow.
13662 const SCEV *One = getOne(Stride->getType());
13663 bool MayAddOverflow = [&] {
13664 if (isKnownToBeAPowerOfTwo(Stride)) {
13665 // Suppose Stride is a power of two, and Start/End are unsigned
13666 // integers. Let UMAX be the largest representable unsigned
13667 // integer.
13668 //
13669 // By the preconditions of this function, we know
13670 // "(Start + Stride * N) >= End", and this doesn't overflow.
13671 // As a formula:
13672 //
13673 // End <= (Start + Stride * N) <= UMAX
13674 //
13675 // Subtracting Start from all the terms:
13676 //
13677 // End - Start <= Stride * N <= UMAX - Start
13678 //
13679 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13680 //
13681 // End - Start <= Stride * N <= UMAX
13682 //
13683 // Stride * N is a multiple of Stride. Therefore,
13684 //
13685 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13686 //
13687 // Since Stride is a power of two, UMAX + 1 is divisible by
13688 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13689 // write:
13690 //
13691 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13692 //
13693 // Dropping the middle term:
13694 //
13695 // End - Start <= UMAX - (Stride - 1)
13696 //
13697 // Adding Stride - 1 to both sides:
13698 //
13699 // (End - Start) + (Stride - 1) <= UMAX
13700 //
13701 // In other words, the addition doesn't have unsigned overflow.
13702 //
13703 // A similar proof works if we treat Start/End as signed values.
13704 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13705 // to use signed max instead of unsigned max. Note that we're
13706 // trying to prove a lack of unsigned overflow in either case.
13707 return false;
13708 }
13709 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13710 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13711 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13712 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13713 // 1 <s End.
13714 //
13715 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13716 // End.
13717 return false;
13718 }
13719 return true;
13720 }();
13721
13722 const SCEV *Delta = getMinusSCEV(End, Start);
13723 if (!MayAddOverflow) {
13724 // floor((D + (S - 1)) / S)
13725 // We prefer this formulation if it's legal because it's fewer
13726 // operations.
13727 BECount =
13728 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13729 } else {
13730 BECount = getUDivCeilSCEV(Delta, Stride);
13731 }
13732 }
13733 }
13734
// Derive the constant maximum: prefer an already-constant exact count, then a
// constant "count if the backedge is taken at all" (flagged MaxOrZero), and
// otherwise fall back to the range-based bound.
13735 const SCEV *ConstantMaxBECount;
13736 bool MaxOrZero = false;
13737 if (isa<SCEVConstant>(BECount)) {
13738 ConstantMaxBECount = BECount;
13739 } else if (BECountIfBackedgeTaken &&
13740 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13741 // If we know exactly how many times the backedge will be taken if it's
13742 // taken at least once, then the backedge count will either be that or
13743 // zero.
13744 ConstantMaxBECount = BECountIfBackedgeTaken;
13745 MaxOrZero = true;
13746 } else {
13747 ConstantMaxBECount = computeMaxBECountForLT(
13748 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13749 }
13750
// If the range-based bound could not be computed but the exact count could,
// use the unsigned range maximum of the exact count instead.
13751 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13752 !isa<SCEVCouldNotCompute>(BECount))
13753 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13754
13755 const SCEV *SymbolicMaxBECount =
13756 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13757 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13758 Predicates);
13759}
13760
// Computes an ExitLimit (exact, constant-max and symbolic-max backedge-taken
// counts) for a loop exit governed by a "greater than" comparison of LHS
// against a loop-invariant RHS in loop L.
//
//   LHS, RHS          operands of the comparison; RHS must be invariant in L
//                     and LHS must be (or be convertible to) an affine add
//                     recurrence on L whose negated step is known positive.
//   IsSigned          selects signed versus unsigned comparison semantics.
//   ControlsOnlyExit  when true, no-wrap flags on the IV may be relied upon.
//   AllowPredicates   when true, LHS may be turned into an add recurrence
//                     under runtime predicates.
//
// Returns getCouldNotCompute() whenever one of those preconditions fails.
//
// NOTE(review): listing lines 13764, 13782, 13803, 13811 and 13829 are absent
// from this extract; `Predicates` and `Cond` used below are presumably
// declared on some of them -- confirm against upstream.
13761ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13762 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13763 bool ControlsOnlyExit, bool AllowPredicates) {
13765 // We handle only IV > Invariant
13766 if (!isLoopInvariant(RHS, L))
13767 return getCouldNotCompute();
13768
13769 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13770 if (!IV && AllowPredicates)
13771 // Try to make this an AddRec using runtime tests, in the first X
13772 // iterations of this loop, where X is the SCEV expression found by the
13773 // algorithm below.
13774 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13775
13776 // Avoid weird loops
13777 if (!IV || IV->getLoop() != L || !IV->isAffine())
13778 return getCouldNotCompute();
13779
13780 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13781 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13783
// The IV counts downwards, so work with the negated step as the stride.
13784 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13785
13786 // Avoid negative or zero stride values
13787 if (!isKnownPositive(Stride))
13788 return getCouldNotCompute();
13789
13790 // Avoid proven overflow cases: this will ensure that the backedge taken count
13791 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13792 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13793 // behaviors like the case of C language.
13794 if (!Stride->isOne() && !NoWrap)
13795 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13796 return getCouldNotCompute();
13797
13798 const SCEV *Start = IV->getStart();
13799 const SCEV *End = RHS;
13800 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13801 // If we know that Start >= RHS in the context of loop, then we know that
13802 // min(RHS, Start) = RHS at this point.
13804 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13805 End = RHS;
13806 else
13807 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13808 }
13809
13810 if (Start->getType()->isPointerTy()) {
13812 if (isa<SCEVCouldNotCompute>(Start))
13813 return Start;
13814 }
13815 if (End->getType()->isPointerTy()) {
13816 End = getLosslessPtrToIntExpr(End);
13817 if (isa<SCEVCouldNotCompute>(End))
13818 return End;
13819 }
13820
13821 // Compute ((Start - End) + (Stride - 1)) / Stride.
13822 // FIXME: This can overflow. Holding off on fixing this for now;
13823 // howManyGreaterThans will hopefully be gone soon.
13824 const SCEV *One = getOne(Stride->getType());
13825 const SCEV *BECount = getUDivExpr(
13826 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13827
13828 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13830
13831 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13832 : getUnsignedRangeMin(Stride);
13833
13834 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13835 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13836 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13837
13838 // Although End can be a MIN expression we estimate MinEnd considering only
13839 // the case End = RHS. This is safe because in the other case (Start - End)
13840 // is zero, leading to a zero maximum backedge taken count.
13841 APInt MinEnd =
13842 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13843 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13844
// A constant exact count doubles as the constant maximum; otherwise bound it
// by ceil((MaxStart - MinEnd) / MinStride).
13845 const SCEV *ConstantMaxBECount =
13846 isa<SCEVConstant>(BECount)
13847 ? BECount
13848 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13849 getConstant(MinStride));
13850
13851 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13852 ConstantMaxBECount = BECount;
13853 const SCEV *SymbolicMaxBECount =
13854 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13855
13856 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13857 Predicates);
13858}
13859
// Computes the number of iterations for which this add recurrence stays
// inside `Range`, returning SE.getCouldNotCompute() when that cannot be
// determined (full-set range, non-constant operands, or a failed sanity
// check). Only affine and quadratic recurrences are solved.
// NOTE(review): listing lines 13860, 13868, 13871 and 13917 are absent from
// this extract (the first presumably carries the start of the signature, and
// `Operands` is presumably declared on 13868) -- confirm against upstream.
13861 ScalarEvolution &SE) const {
13862 if (Range.isFullSet()) // Infinite loop.
13863 return SE.getCouldNotCompute();
13864
13865 // If the start is a non-zero constant, shift the range to simplify things.
13866 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13867 if (!SC->getValue()->isZero()) {
13869 Operands[0] = SE.getZero(SC->getType());
13870 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13872 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13873 return ShiftedAddRec->getNumIterationsInRange(
13874 Range.subtract(SC->getAPInt()), SE);
13875 // This is strange and shouldn't happen.
13876 return SE.getCouldNotCompute();
13877 }
13878
13879 // The only time we can solve this is when we have all constant indices.
13880 // Otherwise, we cannot determine the overflow conditions.
13881 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13882 return SE.getCouldNotCompute();
13883
13884 // Okay at this point we know that all elements of the chrec are constants and
13885 // that the start element is zero.
13886
13887 // First check to see if the range contains zero. If not, the first
13888 // iteration exits.
13889 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13890 if (!Range.contains(APInt(BitWidth, 0)))
13891 return SE.getZero(getType());
13892
13893 if (isAffine()) {
13894 // If this is an affine expression then we have this situation:
13895 // Solve {0,+,A} in Range === Ax in Range
13896
13897 // We know that zero is in the range. If A is positive then we know that
13898 // the upper value of the range must be the first possible exit value.
13899 // If A is negative then the lower of the range is the last possible loop
13900 // value. Also note that we already checked for a full range.
13901 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13902 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13903
13904 // The exit value should be (End+A)/A.
13905 APInt ExitVal = (End + A).udiv(A);
13906 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13907
13908 // Evaluate at the exit value. If we really did fall out of the valid
13909 // range, then we computed our trip count, otherwise wrap around or other
13910 // things must have happened.
13911 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13912 if (Range.contains(Val->getValue()))
13913 return SE.getCouldNotCompute(); // Something strange happened
13914
13915 // Ensure that the previous value is in the range.
13916 assert(Range.contains(
13918 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13919 "Linear scev computation is off in a bad way!");
13920 return SE.getConstant(ExitValue);
13921 }
13922
13923 if (isQuadratic()) {
13924 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13925 return SE.getConstant(*S);
13926 }
13927
13928 return SE.getCouldNotCompute();
13929}
13930
// Builds the operand list of the recurrence one step ahead of this one: each
// operand is added to its successor and the last operand is carried over
// unchanged.
// NOTE(review): listing lines 13932, 13940, 13952 and 13953 are absent from
// this extract; they presumably hold the function name/parameters, the
// declaration of `Ops`, and the statement that builds and returns the result
// -- confirm against upstream.
13931const SCEVAddRecExpr *
13933 assert(getNumOperands() > 1 && "AddRec with zero step?");
13934 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13935 // but in this case we cannot guarantee that the value returned will be an
13936 // AddRec because SCEV does not have a fixed point where it stops
13937 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13938 // may happen if we reach arithmetic depth limit while simplifying. So we
13939 // construct the returned value explicitly.
13941 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13942 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13943 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13944 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13945 // We know that the last operand is not a constant zero (otherwise it would
13946 // have been popped out earlier). This guarantees us that if the result has
13947 // the same last operand, then it will also not be popped out, meaning that
13948 // the returned value will be an AddRec.
13949 const SCEV *Last = getOperand(getNumOperands() - 1);
13950 assert(!Last->isZero() && "Recurrency with zero step?");
13951 Ops.push_back(Last);
13954}
13955
13956// Return true when S, or any sub-expression of S, matches
// m_scev_UndefOrPoison().
// NOTE(review): listing line 13957, presumably this function's signature, is
// absent from this extract -- confirm against upstream.
13958 return SCEVExprContains(
13959 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13960}
13961
13962// Return true when S contains a value that is a nullptr.
13964 return SCEVExprContains(S, [](const SCEV *S) {
13965 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13966 return SU->getValue() == nullptr;
13967 return false;
13968 });
13969}
13970
13971/// Return the size of an element read or written by Inst.
13973 Type *Ty;
13974 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13975 Ty = Store->getValueOperand()->getType();
13976 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13977 Ty = Load->getType();
13978 else
13979 return nullptr;
13980
13982 return getSizeOfExpr(ETy, Ty);
13983}
13984
13985//===----------------------------------------------------------------------===//
13986// SCEVCallbackVH Class Implementation
13987//===----------------------------------------------------------------------===//
13988
13990 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13991 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13992 SE->ConstantEvolutionLoopExitValue.erase(PN);
13993 SE->eraseValueFromMap(getValPtr());
13994 // this now dangles!
13995}
13996
13997void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13998 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13999
14000 // Forget all the expressions associated with users of the old value,
14001 // so that future queries will recompute the expressions using the new
14002 // value.
14003 SE->forgetValue(getValPtr());
14004 // this now dangles!
14005}
14006
14007ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
14008 : CallbackVH(V), SE(se) {}
14009
14010//===----------------------------------------------------------------------===//
14011// ScalarEvolution Class Implementation
14012//===----------------------------------------------------------------------===//
14013
14016 LoopInfo &LI)
14017 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
14018 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
14019 LoopDispositions(64), BlockDispositions(64) {
14020 // To use guards for proving predicates, we need to scan every instruction in
14021 // relevant basic blocks, and not just terminators. Doing this is a waste of
14022 // time if the IR does not actually contain any calls to
14023 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
14024 //
14025 // This pessimizes the case where a pass that preserves ScalarEvolution wants
14026 // to _add_ guards to the module when there weren't any before, and wants
14027 // ScalarEvolution to optimize based on those guards. For now we prefer to be
14028 // efficient in lieu of being smart in that rather obscure case.
14029
14030 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
14031 F.getParent(), Intrinsic::experimental_guard);
14032 HasGuards = GuardDecl && !GuardDecl->use_empty();
14033}
14034
14036 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
14037 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
14038 ValueExprMap(std::move(Arg.ValueExprMap)),
14039 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14040 PendingMerges(std::move(Arg.PendingMerges)),
14041 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14042 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14043 PredicatedBackedgeTakenCounts(
14044 std::move(Arg.PredicatedBackedgeTakenCounts)),
14045 BECountUsers(std::move(Arg.BECountUsers)),
14046 ConstantEvolutionLoopExitValue(
14047 std::move(Arg.ConstantEvolutionLoopExitValue)),
14048 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14049 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14050 LoopDispositions(std::move(Arg.LoopDispositions)),
14051 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14052 BlockDispositions(std::move(Arg.BlockDispositions)),
14053 SCEVUsers(std::move(Arg.SCEVUsers)),
14054 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14055 SignedRanges(std::move(Arg.SignedRanges)),
14056 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14057 UniquePreds(std::move(Arg.UniquePreds)),
14058 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14059 LoopUsers(std::move(Arg.LoopUsers)),
14060 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14061 FirstUnknown(Arg.FirstUnknown) {
14062 Arg.FirstUnknown = nullptr;
14063}
14064
14066 // Iterate through all the SCEVUnknown instances and call their
14067 // destructors, so that they release their references to their values.
14068 for (SCEVUnknown *U = FirstUnknown; U;) {
14069 SCEVUnknown *Tmp = U;
14070 U = U->Next;
14071 Tmp->~SCEVUnknown();
14072 }
14073 FirstUnknown = nullptr;
14074
14075 ExprValueMap.clear();
14076 ValueExprMap.clear();
14077 HasRecMap.clear();
14078 BackedgeTakenCounts.clear();
14079 PredicatedBackedgeTakenCounts.clear();
14080
14081 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14082 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14083 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14084 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14085}
14086
14090
14091/// When printing a top-level SCEV for trip counts, it's helpful to include
14092/// a type for constants which are otherwise hard to disambiguate.
14093static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14094 if (isa<SCEVConstant>(S))
14095 OS << *S->getType() << " ";
14096 OS << *S;
14097}
14098
14100 const Loop *L) {
14101 // Print all inner loops first
14102 for (Loop *I : *L)
14103 PrintLoopInfo(OS, SE, I);
14104
14105 OS << "Loop ";
14106 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14107 OS << ": ";
14108
14109 SmallVector<BasicBlock *, 8> ExitingBlocks;
14110 L->getExitingBlocks(ExitingBlocks);
14111 if (ExitingBlocks.size() != 1)
14112 OS << "<multiple exits> ";
14113
14114 auto *BTC = SE->getBackedgeTakenCount(L);
14115 if (!isa<SCEVCouldNotCompute>(BTC)) {
14116 OS << "backedge-taken count is ";
14117 PrintSCEVWithTypeHint(OS, BTC);
14118 } else
14119 OS << "Unpredictable backedge-taken count.";
14120 OS << "\n";
14121
14122 if (ExitingBlocks.size() > 1)
14123 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14124 OS << " exit count for " << ExitingBlock->getName() << ": ";
14125 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14126 PrintSCEVWithTypeHint(OS, EC);
14127 if (isa<SCEVCouldNotCompute>(EC)) {
14128 // Retry with predicates.
14130 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14131 if (!isa<SCEVCouldNotCompute>(EC)) {
14132 OS << "\n predicated exit count for " << ExitingBlock->getName()
14133 << ": ";
14134 PrintSCEVWithTypeHint(OS, EC);
14135 OS << "\n Predicates:\n";
14136 for (const auto *P : Predicates)
14137 P->print(OS, 4);
14138 }
14139 }
14140 OS << "\n";
14141 }
14142
14143 OS << "Loop ";
14144 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14145 OS << ": ";
14146
14147 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14148 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14149 OS << "constant max backedge-taken count is ";
14150 PrintSCEVWithTypeHint(OS, ConstantBTC);
14152 OS << ", actual taken count either this or zero.";
14153 } else {
14154 OS << "Unpredictable constant max backedge-taken count. ";
14155 }
14156
14157 OS << "\n"
14158 "Loop ";
14159 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14160 OS << ": ";
14161
14162 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14163 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14164 OS << "symbolic max backedge-taken count is ";
14165 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14167 OS << ", actual taken count either this or zero.";
14168 } else {
14169 OS << "Unpredictable symbolic max backedge-taken count. ";
14170 }
14171 OS << "\n";
14172
14173 if (ExitingBlocks.size() > 1)
14174 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14175 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14176 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14178 PrintSCEVWithTypeHint(OS, ExitBTC);
14179 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14180 // Retry with predicates.
14182 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14184 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14185 OS << "\n predicated symbolic max exit count for "
14186 << ExitingBlock->getName() << ": ";
14187 PrintSCEVWithTypeHint(OS, ExitBTC);
14188 OS << "\n Predicates:\n";
14189 for (const auto *P : Predicates)
14190 P->print(OS, 4);
14191 }
14192 }
14193 OS << "\n";
14194 }
14195
14197 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14198 if (PBT != BTC) {
14199 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14200 OS << "Loop ";
14201 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14202 OS << ": ";
14203 if (!isa<SCEVCouldNotCompute>(PBT)) {
14204 OS << "Predicated backedge-taken count is ";
14205 PrintSCEVWithTypeHint(OS, PBT);
14206 } else
14207 OS << "Unpredictable predicated backedge-taken count.";
14208 OS << "\n";
14209 OS << " Predicates:\n";
14210 for (const auto *P : Preds)
14211 P->print(OS, 4);
14212 }
14213 Preds.clear();
14214
14215 auto *PredConstantMax =
14217 if (PredConstantMax != ConstantBTC) {
14218 assert(!Preds.empty() &&
14219 "different predicated constant max BTC but no predicates");
14220 OS << "Loop ";
14221 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14222 OS << ": ";
14223 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14224 OS << "Predicated constant max backedge-taken count is ";
14225 PrintSCEVWithTypeHint(OS, PredConstantMax);
14226 } else
14227 OS << "Unpredictable predicated constant max backedge-taken count.";
14228 OS << "\n";
14229 OS << " Predicates:\n";
14230 for (const auto *P : Preds)
14231 P->print(OS, 4);
14232 }
14233 Preds.clear();
14234
14235 auto *PredSymbolicMax =
14237 if (SymbolicBTC != PredSymbolicMax) {
14238 assert(!Preds.empty() &&
14239 "Different predicated symbolic max BTC, but no predicates");
14240 OS << "Loop ";
14241 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14242 OS << ": ";
14243 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14244 OS << "Predicated symbolic max backedge-taken count is ";
14245 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14246 } else
14247 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14248 OS << "\n";
14249 OS << " Predicates:\n";
14250 for (const auto *P : Preds)
14251 P->print(OS, 4);
14252 }
14253
14255 OS << "Loop ";
14256 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14257 OS << ": ";
14258 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14259 }
14260}
14261
14262namespace llvm {
14263// Note: these overloaded operators need to be in the llvm namespace for them
14264// to be resolved correctly. If we put them outside the llvm namespace, the
14265//
14266// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14267//
14268// code below "breaks" and starts printing raw enum values as opposed to the
14269// string values.
14272 switch (LD) {
14274 OS << "Variant";
14275 break;
14277 OS << "Invariant";
14278 break;
14280 OS << "Uniform";
14281 break;
14283 OS << "Computable";
14284 break;
14285 }
14286 return OS;
14287}
14288
14291 switch (BD) {
14293 OS << "DoesNotDominate";
14294 break;
14296 OS << "Dominates";
14297 break;
14299 OS << "ProperlyDominates";
14300 break;
14301 }
14302 return OS;
14303}
14304} // namespace llvm
14305
14307 // ScalarEvolution's implementation of the print method is to print
14308 // out SCEV values of all instructions that are interesting. Doing
14309 // this potentially causes it to create new SCEV objects though,
14310 // which technically conflicts with the const qualifier. This isn't
14311 // observable from outside the class though, so casting away the
14312 // const isn't dangerous.
14313 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14314
14315 if (ClassifyExpressions) {
14316 OS << "Classifying expressions for: ";
14317 F.printAsOperand(OS, /*PrintType=*/false);
14318 OS << "\n";
14319 for (Instruction &I : instructions(F))
14320 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14321 OS << I << '\n';
14322 OS << " --> ";
14323 const SCEV *SV = SE.getSCEV(&I);
14324 SV->print(OS);
14325 if (!isa<SCEVCouldNotCompute>(SV)) {
14326 OS << " U: ";
14327 SE.getUnsignedRange(SV).print(OS);
14328 OS << " S: ";
14329 SE.getSignedRange(SV).print(OS);
14330 }
14331
14332 const Loop *L = LI.getLoopFor(I.getParent());
14333
14334 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14335 if (AtUse != SV) {
14336 OS << " --> ";
14337 AtUse->print(OS);
14338 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14339 OS << " U: ";
14340 SE.getUnsignedRange(AtUse).print(OS);
14341 OS << " S: ";
14342 SE.getSignedRange(AtUse).print(OS);
14343 }
14344 }
14345
14346 if (L) {
14347 OS << "\t\t" "Exits: ";
14348 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14349 if (!SE.isLoopInvariant(ExitValue, L)) {
14350 OS << "<<Unknown>>";
14351 } else {
14352 OS << *ExitValue;
14353 }
14354
14355 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14356 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14357 OS << LS;
14358 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14359 OS << ": " << SE.getLoopDisposition(SV, Iter);
14360 }
14361
14362 for (const auto *InnerL : depth_first(L)) {
14363 if (InnerL == L)
14364 continue;
14365 OS << LS;
14366 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14367 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14368 }
14369
14370 OS << " }";
14371 }
14372
14373 OS << "\n";
14374 }
14375 }
14376
14377 OS << "Determining loop execution counts for: ";
14378 F.printAsOperand(OS, /*PrintType=*/false);
14379 OS << "\n";
14380 for (Loop *I : LI)
14381 PrintLoopInfo(OS, &SE, I);
14382}
14383
14386 auto &Values = LoopDispositions[S];
14387 for (auto &V : Values) {
14388 if (V.getPointer() == L)
14389 return V.getInt();
14390 }
14391 Values.emplace_back(L, LoopVariant);
14392 LoopDisposition D = computeLoopDisposition(S, L);
14393 auto &Values2 = LoopDispositions[S];
14394 for (auto &V : llvm::reverse(Values2)) {
14395 if (V.getPointer() == L) {
14396 V.setInt(D);
14397 break;
14398 }
14399 }
14400 return D;
14401}
14402
14404ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14405 switch (S->getSCEVType()) {
14406 case scConstant:
14407 case scVScale:
14408 return LoopInvariant;
14409 case scAddRecExpr: {
14410 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14411
14412 // If L is the addrec's loop, it's computable.
14413 if (AR->getLoop() == L)
14414 return LoopComputable;
14415
14416 // Add recurrences are never invariant in the function-body (null loop).
14417 if (!L)
14418 return LoopVariant;
14419
14420 // Everything that is not defined at loop entry is variant.
14421 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) {
14422 if (L->contains(AR->getLoop()) &&
14423 llvm::all_of(AR->operands(),
14424 [&](const SCEV *Op) { return isLoopUniform(Op, L); }))
14425 return LoopUniform;
14426
14427 return LoopVariant;
14428 }
14429 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14430 " dominate the contained loop's header?");
14431
14432 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14433 if (AR->getLoop()->contains(L))
14434 return LoopInvariant;
14435
14436 // This recurrence is variant w.r.t. L if any of its operands
14437 // are variant.
14438 for (SCEVUse Op : AR->operands())
14439 if (!isLoopInvariant(Op, L))
14440 return LoopVariant;
14441
14442 // Otherwise it's loop-invariant.
14443 return LoopInvariant;
14444 }
14445 case scTruncate:
14446 case scZeroExtend:
14447 case scSignExtend:
14448 case scPtrToAddr:
14449 case scPtrToInt:
14450 case scAddExpr:
14451 case scMulExpr:
14452 case scUDivExpr:
14453 case scUMaxExpr:
14454 case scSMaxExpr:
14455 case scUMinExpr:
14456 case scSMinExpr:
14457 case scSequentialUMinExpr: {
14458 bool HasVarying = false;
14459 bool HasUniform = false;
14460 for (SCEVUse Op : S->operands()) {
14462 if (D == LoopVariant)
14463 return LoopVariant;
14464 if (D == LoopComputable)
14465 HasVarying = true;
14466 if (D == LoopUniform)
14467 HasUniform = true;
14468 }
14469 return HasVarying ? (HasUniform ? LoopVariant : LoopComputable)
14470 : (HasUniform ? LoopUniform : LoopInvariant);
14471 }
14472 case scUnknown:
14473 // All non-instruction values are loop invariant. All instructions are loop
14474 // invariant if they are not contained in the specified loop.
14475 // Instructions are never considered invariant in the function body
14476 // (null loop) because they are defined within the "loop".
14478 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14479 return LoopInvariant;
14480 case scCouldNotCompute:
14481 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14482 }
14483 llvm_unreachable("Unknown SCEV kind!");
14484}
14485
14486bool ScalarEvolution::isLoopUniform(const SCEV *S, const Loop *L) {
14488 return D == LoopUniform || D == LoopInvariant;
14489}
14490
14492 return getLoopDisposition(S, L) == LoopInvariant;
14493}
14494
14496 return getLoopDisposition(S, L) == LoopComputable;
14497}
14498
14501 auto &Values = BlockDispositions[S];
14502 for (auto &V : Values) {
14503 if (V.getPointer() == BB)
14504 return V.getInt();
14505 }
14506 Values.emplace_back(BB, DoesNotDominateBlock);
14507 BlockDisposition D = computeBlockDisposition(S, BB);
14508 auto &Values2 = BlockDispositions[S];
14509 for (auto &V : llvm::reverse(Values2)) {
14510 if (V.getPointer() == BB) {
14511 V.setInt(D);
14512 break;
14513 }
14514 }
14515 return D;
14516}
14517
14519ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14520 switch (S->getSCEVType()) {
14521 case scConstant:
14522 case scVScale:
14524 case scAddRecExpr: {
14525 // This uses a "dominates" query instead of "properly dominates" query
14526 // to test for proper dominance too, because the instruction which
14527 // produces the addrec's value is a PHI, and a PHI effectively properly
14528 // dominates its entire containing block.
14529 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14530 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14531 return DoesNotDominateBlock;
14532
14533 // Fall through into SCEVNAryExpr handling.
14534 [[fallthrough]];
14535 }
14536 case scTruncate:
14537 case scZeroExtend:
14538 case scSignExtend:
14539 case scPtrToAddr:
14540 case scPtrToInt:
14541 case scAddExpr:
14542 case scMulExpr:
14543 case scUDivExpr:
14544 case scUMaxExpr:
14545 case scSMaxExpr:
14546 case scUMinExpr:
14547 case scSMinExpr:
14548 case scSequentialUMinExpr: {
14549 bool Proper = true;
14550 for (const SCEV *NAryOp : S->operands()) {
14552 if (D == DoesNotDominateBlock)
14553 return DoesNotDominateBlock;
14554 if (D == DominatesBlock)
14555 Proper = false;
14556 }
14557 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14558 }
14559 case scUnknown:
14560 if (Instruction *I =
14562 if (I->getParent() == BB)
14563 return DominatesBlock;
14564 if (DT.properlyDominates(I->getParent(), BB))
14566 return DoesNotDominateBlock;
14567 }
14569 case scCouldNotCompute:
14570 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14571 }
14572 llvm_unreachable("Unknown SCEV kind!");
14573}
14574
14575bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14576 return getBlockDisposition(S, BB) >= DominatesBlock;
14577}
14578
14581}
14582
14583bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14584 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14585}
14586
14587void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14588 bool Predicated) {
14589 auto &BECounts =
14590 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14591 auto It = BECounts.find(L);
14592 if (It != BECounts.end()) {
14593 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14594 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14595 if (!isa<SCEVConstant>(S)) {
14596 auto UserIt = BECountUsers.find(S);
14597 assert(UserIt != BECountUsers.end());
14598 UserIt->second.erase({L, Predicated});
14599 }
14600 }
14601 }
14602 BECounts.erase(It);
14603 }
14604}
14605
14606void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14607 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14608 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14609
14610 while (!Worklist.empty()) {
14611 const SCEV *Curr = Worklist.pop_back_val();
14612 auto Users = SCEVUsers.find(Curr);
14613 if (Users != SCEVUsers.end())
14614 for (const auto *User : Users->second)
14615 if (ToForget.insert(User).second)
14616 Worklist.push_back(User);
14617 }
14618
14619 for (const auto *S : ToForget)
14620 forgetMemoizedResultsImpl(S);
14621
14622 PredicatedSCEVRewrites.remove_if(
14623 [&](const auto &Entry) { return ToForget.count(Entry.first.first); });
14624}
14625
14626void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14627 LoopDispositions.erase(S);
14628 BlockDispositions.erase(S);
14629 UnsignedRanges.erase(S);
14630 SignedRanges.erase(S);
14631 HasRecMap.erase(S);
14632 ConstantMultipleCache.erase(S);
14633
14634 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14635 UnsignedWrapViaInductionTried.erase(AR);
14636 SignedWrapViaInductionTried.erase(AR);
14637 }
14638
14639 auto ExprIt = ExprValueMap.find(S);
14640 if (ExprIt != ExprValueMap.end()) {
14641 for (Value *V : ExprIt->second) {
14642 auto ValueIt = ValueExprMap.find_as(V);
14643 if (ValueIt != ValueExprMap.end())
14644 ValueExprMap.erase(ValueIt);
14645 }
14646 ExprValueMap.erase(ExprIt);
14647 }
14648
14649 auto ScopeIt = ValuesAtScopes.find(S);
14650 if (ScopeIt != ValuesAtScopes.end()) {
14651 for (const auto &Pair : ScopeIt->second)
14652 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14653 llvm::erase(ValuesAtScopesUsers[Pair.second],
14654 std::make_pair(Pair.first, S));
14655 ValuesAtScopes.erase(ScopeIt);
14656 }
14657
14658 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14659 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14660 for (const auto &Pair : ScopeUserIt->second)
14661 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14662 ValuesAtScopesUsers.erase(ScopeUserIt);
14663 }
14664
14665 auto BEUsersIt = BECountUsers.find(S);
14666 if (BEUsersIt != BECountUsers.end()) {
14667 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14668 auto Copy = BEUsersIt->second;
14669 for (const auto &Pair : Copy)
14670 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14671 BECountUsers.erase(BEUsersIt);
14672 }
14673
14674 auto FoldUser = FoldCacheUser.find(S);
14675 if (FoldUser != FoldCacheUser.end())
14676 for (auto &KV : FoldUser->second)
14677 FoldCache.erase(KV);
14678 FoldCacheUser.erase(S);
14679}
14680
14681void
14682ScalarEvolution::getUsedLoops(const SCEV *S,
14683 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14684 struct FindUsedLoops {
14685 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14686 : LoopsUsed(LoopsUsed) {}
14687 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14688 bool follow(const SCEV *S) {
14689 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14690 LoopsUsed.insert(AR->getLoop());
14691 return true;
14692 }
14693
14694 bool isDone() const { return false; }
14695 };
14696
14697 FindUsedLoops F(LoopsUsed);
14698 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14699}
14700
14701void ScalarEvolution::getReachableBlocks(
14704 Worklist.push_back(&F.getEntryBlock());
14705 while (!Worklist.empty()) {
14706 BasicBlock *BB = Worklist.pop_back_val();
14707 if (!Reachable.insert(BB).second)
14708 continue;
14709
14710 Value *Cond;
14711 BasicBlock *TrueBB, *FalseBB;
14712 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14713 m_BasicBlock(FalseBB)))) {
14714 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14715 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14716 continue;
14717 }
14718
14719 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14720 const SCEV *L = getSCEV(Cmp->getOperand(0));
14721 const SCEV *R = getSCEV(Cmp->getOperand(1));
14722 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14723 Worklist.push_back(TrueBB);
14724 continue;
14725 }
14726 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14727 R)) {
14728 Worklist.push_back(FalseBB);
14729 continue;
14730 }
14731 }
14732 }
14733
14734 append_range(Worklist, successors(BB));
14735 }
14736}
14737
14739 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14740 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14741
14742 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14743
14744 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14745 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14746 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14747
14748 const SCEV *visitConstant(const SCEVConstant *Constant) {
14749 return SE.getConstant(Constant->getAPInt());
14750 }
14751
14752 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14753 return SE.getUnknown(Expr->getValue());
14754 }
14755
14756 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14757 return SE.getCouldNotCompute();
14758 }
14759 };
14760
14761 SCEVMapper SCM(SE2);
14762 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14763 SE2.getReachableBlocks(ReachableBlocks, F);
14764
14765 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14766 if (containsUndefs(Old) || containsUndefs(New)) {
14767 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14768 // not propagate undef aggressively). This means we can (and do) fail
14769 // verification in cases where a transform makes a value go from "undef"
14770 // to "undef+1" (say). The transform is fine, since in both cases the
14771 // result is "undef", but SCEV thinks the value increased by 1.
14772 return nullptr;
14773 }
14774
14775 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14776 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14777 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14778 return nullptr;
14779
14780 return Delta;
14781 };
14782
14783 while (!LoopStack.empty()) {
14784 auto *L = LoopStack.pop_back_val();
14785 llvm::append_range(LoopStack, *L);
14786
14787 // Only verify BECounts in reachable loops. For an unreachable loop,
14788 // any BECount is legal.
14789 if (!ReachableBlocks.contains(L->getHeader()))
14790 continue;
14791
14792 // Only verify cached BECounts. Computing new BECounts may change the
14793 // results of subsequent SCEV uses.
14794 auto It = BackedgeTakenCounts.find(L);
14795 if (It == BackedgeTakenCounts.end())
14796 continue;
14797
14798 auto *CurBECount =
14799 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14800 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14801
14802 if (CurBECount == SE2.getCouldNotCompute() ||
14803 NewBECount == SE2.getCouldNotCompute()) {
14804 // NB! This situation is legal, but is very suspicious -- whatever pass
14805 // changed the loop to make a trip count go from could not compute to
14806 // computable or vice-versa *should have* invalidated SCEV. However, we
14807 // choose not to assert here (for now) since we don't want false
14808 // positives.
14809 continue;
14810 }
14811
14812 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14813 SE.getTypeSizeInBits(NewBECount->getType()))
14814 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14815 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14816 SE.getTypeSizeInBits(NewBECount->getType()))
14817 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14818
14819 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14820 if (Delta && !Delta->isZero()) {
14821 dbgs() << "Trip Count for " << *L << " Changed!\n";
14822 dbgs() << "Old: " << *CurBECount << "\n";
14823 dbgs() << "New: " << *NewBECount << "\n";
14824 dbgs() << "Delta: " << *Delta << "\n";
14825 std::abort();
14826 }
14827 }
14828
14829 // Collect all valid loops currently in LoopInfo.
14830 SmallPtrSet<Loop *, 32> ValidLoops;
14831 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14832 while (!Worklist.empty()) {
14833 Loop *L = Worklist.pop_back_val();
14834 if (ValidLoops.insert(L).second)
14835 Worklist.append(L->begin(), L->end());
14836 }
14837 for (const auto &KV : ValueExprMap) {
14838#ifndef NDEBUG
14839 // Check for SCEV expressions referencing invalid/deleted loops.
14840 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14841 assert(ValidLoops.contains(AR->getLoop()) &&
14842 "AddRec references invalid loop");
14843 }
14844#endif
14845
14846 // Check that the value is also part of the reverse map.
14847 auto It = ExprValueMap.find(KV.second);
14848 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14849 dbgs() << "Value " << *KV.first
14850 << " is in ValueExprMap but not in ExprValueMap\n";
14851 std::abort();
14852 }
14853
14854 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14855 if (!ReachableBlocks.contains(I->getParent()))
14856 continue;
14857 const SCEV *OldSCEV = SCM.visit(KV.second);
14858 const SCEV *NewSCEV = SE2.getSCEV(I);
14859 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14860 if (Delta && !Delta->isZero()) {
14861 dbgs() << "SCEV for value " << *I << " changed!\n"
14862 << "Old: " << *OldSCEV << "\n"
14863 << "New: " << *NewSCEV << "\n"
14864 << "Delta: " << *Delta << "\n";
14865 std::abort();
14866 }
14867 }
14868 }
14869
14870 for (const auto &KV : ExprValueMap) {
14871 for (Value *V : KV.second) {
14872 const SCEV *S = ValueExprMap.lookup(V);
14873 if (!S) {
14874 dbgs() << "Value " << *V
14875 << " is in ExprValueMap but not in ValueExprMap\n";
14876 std::abort();
14877 }
14878 if (S != KV.first) {
14879 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14880 << *KV.first << "\n";
14881 std::abort();
14882 }
14883 }
14884 }
14885
14886 // Verify integrity of SCEV users.
14887 for (const auto &S : UniqueSCEVs) {
14888 for (SCEVUse Op : S.operands()) {
14889 // We do not store dependencies of constants.
14890 if (isa<SCEVConstant>(Op))
14891 continue;
14892 auto It = SCEVUsers.find(Op);
14893 if (It != SCEVUsers.end() && It->second.count(&S))
14894 continue;
14895 dbgs() << "Use of operand " << *Op << " by user " << S
14896 << " is not being tracked!\n";
14897 std::abort();
14898 }
14899 }
14900
14901 // Verify integrity of ValuesAtScopes users.
14902 for (const auto &ValueAndVec : ValuesAtScopes) {
14903 const SCEV *Value = ValueAndVec.first;
14904 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14905 const Loop *L = LoopAndValueAtScope.first;
14906 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14907 if (!isa<SCEVConstant>(ValueAtScope)) {
14908 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14909 if (It != ValuesAtScopesUsers.end() &&
14910 is_contained(It->second, std::make_pair(L, Value)))
14911 continue;
14912 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14913 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14914 std::abort();
14915 }
14916 }
14917 }
14918
14919 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14920 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14921 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14922 const Loop *L = LoopAndValue.first;
14923 const SCEV *Value = LoopAndValue.second;
14925 auto It = ValuesAtScopes.find(Value);
14926 if (It != ValuesAtScopes.end() &&
14927 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14928 continue;
14929 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14930 << *ValueAtScope << " missing in ValuesAtScopes\n";
14931 std::abort();
14932 }
14933 }
14934
14935 // Verify integrity of BECountUsers.
14936 auto VerifyBECountUsers = [&](bool Predicated) {
14937 auto &BECounts =
14938 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14939 for (const auto &LoopAndBEInfo : BECounts) {
14940 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14941 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14942 if (!isa<SCEVConstant>(S)) {
14943 auto UserIt = BECountUsers.find(S);
14944 if (UserIt != BECountUsers.end() &&
14945 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14946 continue;
14947 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14948 << " missing from BECountUsers\n";
14949 std::abort();
14950 }
14951 }
14952 }
14953 }
14954 };
14955 VerifyBECountUsers(/* Predicated */ false);
14956 VerifyBECountUsers(/* Predicated */ true);
14957
14958 // Verify integrity of the loop disposition cache.
14959 for (auto &[S, Values] : LoopDispositions) {
14960 for (auto [Loop, CachedDisposition] : Values) {
14961 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14962 if (CachedDisposition != RecomputedDisposition) {
14963 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14964 << " is incorrect: cached " << CachedDisposition << ", actual "
14965 << RecomputedDisposition << "\n";
14966 std::abort();
14967 }
14968 }
14969 }
14970
14971 // Verify integrity of the block disposition cache.
14972 for (auto &[S, Values] : BlockDispositions) {
14973 for (auto [BB, CachedDisposition] : Values) {
14974 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14975 if (CachedDisposition != RecomputedDisposition) {
14976 dbgs() << "Cached disposition of " << *S << " for block %"
14977 << BB->getName() << " is incorrect: cached " << CachedDisposition
14978 << ", actual " << RecomputedDisposition << "\n";
14979 std::abort();
14980 }
14981 }
14982 }
14983
14984 // Verify FoldCache/FoldCacheUser caches.
14985 for (auto [FoldID, Expr] : FoldCache) {
14986 auto I = FoldCacheUser.find(Expr);
14987 if (I == FoldCacheUser.end()) {
14988 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14989 << "!\n";
14990 std::abort();
14991 }
14992 if (!is_contained(I->second, FoldID)) {
14993 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14994 std::abort();
14995 }
14996 }
14997 for (auto [Expr, IDs] : FoldCacheUser) {
14998 for (auto &FoldID : IDs) {
14999 const SCEV *S = FoldCache.lookup(FoldID);
15000 if (!S) {
15001 dbgs() << "Missing entry in FoldCache for expression " << *Expr
15002 << "!\n";
15003 std::abort();
15004 }
15005 if (S != Expr) {
15006 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
15007 << " != " << *Expr << "!\n";
15008 std::abort();
15009 }
15010 }
15011 }
15012
15013 // Verify that ConstantMultipleCache computations are correct. We check that
15014 // cached multiples and recomputed multiples are multiples of each other to
15015 // verify correctness. It is possible that a recomputed multiple is different
15016 // from the cached multiple due to strengthened no wrap flags or changes in
15017 // KnownBits computations.
15018 for (auto [S, Multiple] : ConstantMultipleCache) {
15019 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
15020 if ((Multiple != 0 && RecomputedMultiple != 0 &&
15021 Multiple.urem(RecomputedMultiple) != 0 &&
15022 RecomputedMultiple.urem(Multiple) != 0)) {
15023 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
15024 << *S << " : Computed " << RecomputedMultiple
15025 << " but cache contains " << Multiple << "!\n";
15026 std::abort();
15027 }
15028 }
15029}
15030
15032 Function &F, const PreservedAnalyses &PA,
15033 FunctionAnalysisManager::Invalidator &Inv) {
15034 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
15035 // of its dependencies is invalidated.
15036 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
15037 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
15038 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
15039 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
15040 Inv.invalidate<LoopAnalysis>(F, PA);
15041}
15042
15043AnalysisKey ScalarEvolutionAnalysis::Key;
15044
15047 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
15048 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15049 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15050 auto &LI = AM.getResult<LoopAnalysis>(F);
15051 return ScalarEvolution(F, TLI, AC, DT, LI);
15052}
15053
15059
15062 // For compatibility with opt's -analyze feature under legacy pass manager
15063 // which was not ported to NPM. This keeps tests using
15064 // update_analyze_test_checks.py working.
15065 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15066 << F.getName() << "':\n";
15068 return PreservedAnalyses::all();
15069}
15070
15072 "Scalar Evolution Analysis", false, true)
15078 "Scalar Evolution Analysis", false, true)
15079
15081
15083
15085 SE.reset(new ScalarEvolution(
15087 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15089 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15090 return false;
15091}
15092
15094
15096 SE->print(OS);
15097}
15098
15100 if (!VerifySCEV)
15101 return;
15102
15103 SE->verify();
15104}
15105
15113
15115 const SCEV *RHS) {
15116 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15117}
15118
15119const SCEVPredicate *
15121 const SCEV *LHS, const SCEV *RHS) {
15123 assert(LHS->getType() == RHS->getType() &&
15124 "Type mismatch between LHS and RHS");
15125 // Unique this node based on the arguments
15126 ID.AddInteger(SCEVPredicate::P_Compare);
15127 ID.AddInteger(Pred);
15128 ID.AddPointer(LHS);
15129 ID.AddPointer(RHS);
15130 void *IP = nullptr;
15131 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15132 return S;
15133 SCEVComparePredicate *Eq = new (SCEVAllocator)
15134 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15135 UniquePreds.InsertNode(Eq, IP);
15136 return Eq;
15137}
15138
15140 const SCEVAddRecExpr *AR,
15143 // Unique this node based on the arguments
15144 ID.AddInteger(SCEVPredicate::P_Wrap);
15145 ID.AddPointer(AR);
15146 ID.AddInteger(AddedFlags);
15147 void *IP = nullptr;
15148 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15149 return S;
15150 auto *OF = new (SCEVAllocator)
15151 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15152 UniquePreds.InsertNode(OF, IP);
15153 return OF;
15154}
15155
15156namespace {
15157
15158class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15159public:
15160
15161 /// Rewrites \p S in the context of a loop L and the SCEV predication
15162 /// infrastructure.
15163 ///
15164 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15165 /// equivalences present in \p Pred.
15166 ///
15167 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15168 /// \p NewPreds such that the result will be an AddRecExpr.
15169 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15171 const SCEVPredicate *Pred) {
15172 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15173 return Rewriter.visit(S);
15174 }
15175
15176 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15177 if (Pred) {
15178 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15179 for (const auto *Pred : U->getPredicates())
15180 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15181 if (IPred->getLHS() == Expr &&
15182 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15183 return IPred->getRHS();
15184 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15185 if (IPred->getLHS() == Expr &&
15186 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15187 return IPred->getRHS();
15188 }
15189 }
15190 return convertToAddRecWithPreds(Expr);
15191 }
15192
15193 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15194 const SCEV *Operand = visit(Expr->getOperand());
15195 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15196 if (AR && AR->getLoop() == L && AR->isAffine()) {
15197 // This couldn't be folded because the operand didn't have the nuw
15198 // flag. Add the nusw flag as an assumption that we could make.
15199 const SCEV *Step = AR->getStepRecurrence(SE);
15200 Type *Ty = Expr->getType();
15201 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15202 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15203 SE.getSignExtendExpr(Step, Ty), L,
15204 AR->getNoWrapFlags());
15205 }
15206 return SE.getZeroExtendExpr(Operand, Expr->getType());
15207 }
15208
15209 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15210 const SCEV *Operand = visit(Expr->getOperand());
15211 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15212 if (AR && AR->getLoop() == L && AR->isAffine()) {
15213 // This couldn't be folded because the operand didn't have the nsw
15214 // flag. Add the nssw flag as an assumption that we could make.
15215 const SCEV *Step = AR->getStepRecurrence(SE);
15216 Type *Ty = Expr->getType();
15217 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15218 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15219 SE.getSignExtendExpr(Step, Ty), L,
15220 AR->getNoWrapFlags());
15221 }
15222 return SE.getSignExtendExpr(Operand, Expr->getType());
15223 }
15224
15225private:
15226 explicit SCEVPredicateRewriter(
15227 const Loop *L, ScalarEvolution &SE,
15228 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15229 const SCEVPredicate *Pred)
15230 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15231
15232 bool addOverflowAssumption(const SCEVPredicate *P) {
15233 if (!NewPreds) {
15234 // Check if we've already made this assumption.
15235 return Pred && Pred->implies(P, SE);
15236 }
15237 NewPreds->push_back(P);
15238 return true;
15239 }
15240
15241 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15243 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15244 return addOverflowAssumption(A);
15245 }
15246
15247 // If \p Expr represents a PHINode, we try to see if it can be represented
15248 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15249 // to add this predicate as a runtime overflow check, we return the AddRec.
15250 // If \p Expr does not meet these conditions (is not a PHI node, or we
15251 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15252 // return \p Expr.
15253 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15254 if (!isa<PHINode>(Expr->getValue()))
15255 return Expr;
15256 std::optional<
15257 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15258 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15259 if (!PredicatedRewrite)
15260 return Expr;
15261 for (const auto *P : PredicatedRewrite->second){
15262 // Wrap predicates from outer loops are not supported.
15263 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15264 if (L != WP->getExpr()->getLoop())
15265 return Expr;
15266 }
15267 if (!addOverflowAssumption(P))
15268 return Expr;
15269 }
15270 return PredicatedRewrite->first;
15271 }
15272
15273 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15274 const SCEVPredicate *Pred;
15275 const Loop *L;
15276};
15277
15278} // end anonymous namespace
15279
15280const SCEV *
15282 const SCEVPredicate &Preds) {
15283 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15284}
15285
15287 const SCEV *S, const Loop *L,
15290 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15291 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15292
15293 if (!AddRec)
15294 return nullptr;
15295
15296 // Check if any of the transformed predicates is known to be false. In that
15297 // case, it doesn't make sense to convert to a predicated AddRec, as the
15298 // versioned loop will never execute.
15299 for (const SCEVPredicate *Pred : TransformPreds) {
15300 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15301 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15302 continue;
15303
15304 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15305 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15306 if (isa<SCEVCouldNotCompute>(ExitCount))
15307 continue;
15308
15309 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15310 if (!Step->isOne())
15311 continue;
15312
15313 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15314 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15315 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15316 return nullptr;
15317 }
15318
15319 // Since the transformation was successful, we can now transfer the SCEV
15320 // predicates.
15321 Preds.append(TransformPreds.begin(), TransformPreds.end());
15322
15323 return AddRec;
15324}
15325
15326/// SCEV predicates
15330
15332 const ICmpInst::Predicate Pred,
15333 const SCEV *LHS, const SCEV *RHS)
15334 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15335 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15336 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15337}
15338
15340 ScalarEvolution &SE) const {
15341 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15342
15343 if (!Op)
15344 return false;
15345
15346 if (Pred != ICmpInst::ICMP_EQ)
15347 return false;
15348
15349 return Op->LHS == LHS && Op->RHS == RHS;
15350}
15351
15352bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15353
15355 if (Pred == ICmpInst::ICMP_EQ)
15356 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15357 else
15358 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15359 << *RHS << "\n";
15360
15361}
15362
15364 const SCEVAddRecExpr *AR,
15365 IncrementWrapFlags Flags)
15366 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15367
15368const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15369
15371 ScalarEvolution &SE) const {
15372 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15373 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15374 return false;
15375
15376 if (Op->AR == AR)
15377 return true;
15378
15379 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15381 return false;
15382
15383 const SCEV *Start = AR->getStart();
15384 const SCEV *OpStart = Op->AR->getStart();
15385 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15386 return false;
15387
15388 // Reject pointers to different address spaces.
15389 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15390 return false;
15391
15392 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15393 // narrower-type AddRec.
15394 if (SE.getTypeSizeInBits(AR->getType()) >
15395 SE.getTypeSizeInBits(Op->AR->getType()))
15396 return false;
15397
15398 const SCEV *Step = AR->getStepRecurrence(SE);
15399 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15400 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15401 return false;
15402
15403 // If both steps are positive, this implies N if N's start and step are
15404 // ULE/SLE (for NUSW/NSSW, respectively) this predicate's start and step.
15405 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15406 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15407 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15408
15409 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15410 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15411 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15412 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15413 : SE.getNoopOrSignExtend(Start, WiderTy);
15415 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15416 SE.isKnownPredicate(Pred, OpStart, Start);
15417}
15418
15420 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15421 IncrementWrapFlags IFlags = Flags;
15422
15423 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15424 IFlags = clearFlags(IFlags, IncrementNSSW);
15425
15426 return IFlags == IncrementAnyWrap;
15427}
15428
15429void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15430 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15432 OS << "<nusw>";
15434 OS << "<nssw>";
15435 OS << "\n";
15436}
15437
15440 ScalarEvolution &SE) {
15441 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15442 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15443
15444 // We can safely transfer the NSW flag as NSSW.
15445 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15446 ImpliedFlags = IncrementNSSW;
15447
15448 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15449 // If the increment is positive, the SCEV NUW flag will also imply the
15450 // WrapPredicate NUSW flag.
15451 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15452 if (Step->getValue()->getValue().isNonNegative())
15453 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15454 }
15455
15456 return ImpliedFlags;
15457}
15458
15459/// Union predicates don't get cached so create a dummy set ID for it.
15461 ScalarEvolution &SE)
15462 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15463 for (const auto *P : Preds)
15464 add(P, SE);
15465}
15466
15468 return all_of(Preds,
15469 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15470}
15471
15473 ScalarEvolution &SE) const {
15474 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15475 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15476 return this->implies(I, SE);
15477 });
15478
15479 return any_of(Preds,
15480 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15481}
15482
15484 for (const auto *Pred : Preds)
15485 Pred->print(OS, Depth);
15486}
15487
15488void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15489 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15490 for (const auto *Pred : Set->Preds)
15491 add(Pred, SE);
15492 return;
15493 }
15494
15495 // Implication checks are quadratic in the number of predicates. Stop doing
15496 // them if there are many predicates, as they should be too expensive to use
15497 // anyway at that point.
15498 bool CheckImplies = Preds.size() < 16;
15499
15500 // Only add predicate if it is not already implied by this union predicate.
15501 if (CheckImplies && implies(N, SE))
15502 return;
15503
15504 // Build a new vector containing the current predicates, except the ones that
15505 // are implied by the new predicate N.
15507 for (auto *P : Preds) {
15508 if (CheckImplies && N->implies(P, SE))
15509 continue;
15510 PrunedPreds.push_back(P);
15511 }
15512 Preds = std::move(PrunedPreds);
15513 Preds.push_back(N);
15514}
15515
15517 Loop &L)
15518 : SE(SE), L(L) {
15520 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15521}
15522
15525 for (const auto *Op : Ops)
15526 // We do not expect that forgetting cached data for SCEVConstants will ever
15527 // open any prospects for sharpening or introduce any correctness issues,
15528 // so we don't bother storing their dependencies.
15529 if (!isa<SCEVConstant>(Op))
15530 SCEVUsers[Op].insert(User);
15531}
15532
15534 for (const SCEV *Op : Ops)
15535 // We do not expect that forgetting cached data for SCEVConstants will ever
15536 // open any prospects for sharpening or introduce any correctness issues,
15537 // so we don't bother storing their dependencies.
15538 if (!isa<SCEVConstant>(Op))
15539 SCEVUsers[Op].insert(User);
15540}
15541
15543 const SCEV *Expr = SE.getSCEV(V);
15544 return getPredicatedSCEV(Expr);
15545}
15546
15548 RewriteEntry &Entry = RewriteMap[Expr];
15549
15550 // If we already have an entry and the version matches, return it.
15551 if (Entry.second && Generation == Entry.first)
15552 return Entry.second;
15553
15554 // We found an entry but it's stale. Rewrite the stale entry
15555 // according to the current predicate.
15556 if (Entry.second)
15557 Expr = Entry.second;
15558
15559 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15560 Entry = {Generation, NewSCEV};
15561
15562 return NewSCEV;
15563}
15564
15566 if (!BackedgeCount) {
15568 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15569 for (const auto *P : Preds)
15570 addPredicate(*P);
15571 }
15572 return BackedgeCount;
15573}
15574
15576 if (!SymbolicMaxBackedgeCount) {
15578 SymbolicMaxBackedgeCount =
15579 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15580 for (const auto *P : Preds)
15581 addPredicate(*P);
15582 }
15583 return SymbolicMaxBackedgeCount;
15584}
15585
15587 if (!SmallConstantMaxTripCount) {
15589 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15590 for (const auto *P : Preds)
15591 addPredicate(*P);
15592 }
15593 return *SmallConstantMaxTripCount;
15594}
15595
15597 if (Preds->implies(&Pred, SE))
15598 return;
15599
15600 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15601 NewPreds.push_back(&Pred);
15602 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15603 updateGeneration();
15604}
15605
15607 return *Preds;
15608}
15609
15610void PredicatedScalarEvolution::updateGeneration() {
15611 // If the generation number wrapped recompute everything.
15612 if (++Generation == 0) {
15613 for (auto &II : RewriteMap) {
15614 const SCEV *Rewritten = II.second.second;
15615 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15616 }
15617 }
15618}
15619
15622 const SCEV *Expr = getSCEV(V);
15623 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15624
15625 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15626
15627 // Clear the statically implied flags.
15628 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15629 addPredicate(*SE.getWrapPredicate(AR, Flags));
15630
15631 auto II = FlagsMap.insert({V, Flags});
15632 if (!II.second)
15633 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15634}
15635
15638 const SCEV *Expr = getSCEV(V);
15639 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15640
15642 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15643
15644 auto II = FlagsMap.find(V);
15645
15646 if (II != FlagsMap.end())
15647 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15648
15650}
15651
15654 const SCEV *Expr = this->getSCEV(V);
15656 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15657
15658 if (!New)
15659 return nullptr;
15660
15661 if (ExtraPreds) {
15662 ExtraPreds->append(NewPreds);
15663 return New;
15664 }
15665
15666 for (const auto *P : NewPreds)
15667 addPredicate(*P);
15668
15669 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15670 return New;
15671}
15672
15675 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15676 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15677 SE)),
15678 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15679 for (auto I : Init.FlagsMap)
15680 FlagsMap.insert(I);
15681}
15682
15684 // For each block.
15685 for (auto *BB : L.getBlocks())
15686 for (auto &I : *BB) {
15687 if (!SE.isSCEVable(I.getType()))
15688 continue;
15689
15690 auto *Expr = SE.getSCEV(&I);
15691 auto II = RewriteMap.find(Expr);
15692
15693 if (II == RewriteMap.end())
15694 continue;
15695
15696 // Don't print things that are not interesting.
15697 if (II->second.second == Expr)
15698 continue;
15699
15700 OS.indent(Depth) << "[PSE]" << I << ":\n";
15701 OS.indent(Depth + 2) << *Expr << "\n";
15702 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15703 }
15704}
15705
15708 BasicBlock *Header = L->getHeader();
15709 BasicBlock *Pred = L->getLoopPredecessor();
15710 LoopGuards Guards(SE);
15711 if (!Pred)
15712 return Guards;
15714 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15715 return Guards;
15716}
15717
15718void ScalarEvolution::LoopGuards::collectFromPHI(
15722 unsigned Depth) {
15723 if (!SE.isSCEVable(Phi.getType()))
15724 return;
15725
15726 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15727 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15728 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15729 if (!VisitedBlocks.insert(InBlock).second)
15730 return {nullptr, scCouldNotCompute};
15731
15732 // Avoid analyzing unreachable blocks so that we don't get trapped
15733 // traversing cycles with ill-formed dominance or infinite cycles
15734 if (!SE.DT.isReachableFromEntry(InBlock))
15735 return {nullptr, scCouldNotCompute};
15736
15737 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15738 if (Inserted)
15739 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15740 Depth + 1);
15741 auto &RewriteMap = G->second.RewriteMap;
15742 if (RewriteMap.empty())
15743 return {nullptr, scCouldNotCompute};
15744 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15745 if (S == RewriteMap.end())
15746 return {nullptr, scCouldNotCompute};
15747 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15748 if (!SM)
15749 return {nullptr, scCouldNotCompute};
15750 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15751 return {C0, SM->getSCEVType()};
15752 return {nullptr, scCouldNotCompute};
15753 };
15754 auto MergeMinMaxConst = [](MinMaxPattern P1,
15755 MinMaxPattern P2) -> MinMaxPattern {
15756 auto [C1, T1] = P1;
15757 auto [C2, T2] = P2;
15758 if (!C1 || !C2 || T1 != T2)
15759 return {nullptr, scCouldNotCompute};
15760 switch (T1) {
15761 case scUMaxExpr:
15762 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15763 case scSMaxExpr:
15764 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15765 case scUMinExpr:
15766 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15767 case scSMinExpr:
15768 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15769 default:
15770 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15771 }
15772 };
15773 auto P = GetMinMaxConst(0);
15774 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15775 if (!P.first)
15776 break;
15777 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15778 }
15779 if (P.first) {
15780 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15781 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15782 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15783 Guards.RewriteMap.insert({LHS, RHS});
15784 }
15785}
15786
15787// Return a new SCEV that rounds \p Expr down to the closest number that is
15788// divisible by \p DivisorVal and less than or equal to \p Expr. For now, only
15789// constant \p Expr is handled.
15791 const APInt &DivisorVal,
15792 ScalarEvolution &SE) {
15793 const APInt *ExprVal;
15794 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15795 DivisorVal.isNonPositive())
15796 return Expr;
15797 APInt Rem = ExprVal->urem(DivisorVal);
15798 // return the SCEV: Expr - Expr % Divisor
15799 return SE.getConstant(*ExprVal - Rem);
15800}
15801
15802// Return a new SCEV that modifies \p Expr to the closest number divides by
15803// \p Divisor and greater or equal than Expr. For now, only handle constant
15804// Expr.
15805static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15806 const APInt &DivisorVal,
15807 ScalarEvolution &SE) {
15808 const APInt *ExprVal;
15809 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15810 DivisorVal.isNonPositive())
15811 return Expr;
15812 APInt Rem = ExprVal->urem(DivisorVal);
15813 if (Rem.isZero())
15814 return Expr;
15815 // return the SCEV: Expr + Divisor - Expr % Divisor
15816 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15817}
15818
15820 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15823 // If we have LHS == 0, check if LHS is computing a property of some unknown
15824 // SCEV %v which we can rewrite %v to express explicitly.
15826 return false;
15827 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15828 // explicitly express that.
15829 const SCEVUnknown *URemLHS = nullptr;
15830 const SCEV *URemRHS = nullptr;
15831 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15832 return false;
15833
15834 const SCEV *Multiple =
15835 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15836 DivInfo[URemLHS] = Multiple;
15837 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15838 Multiples[URemLHS] = C->getAPInt();
15839 return true;
15840}
15841
15842// Check if the condition is a divisibility guard (A % B == 0).
15843static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15844 ScalarEvolution &SE) {
15845 const SCEV *X, *Y;
15846 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15847}
15848
15849// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15850// recursively. This is done by aligning up/down the constant value to the
15851// Divisor.
15852static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15853 APInt Divisor,
15854 ScalarEvolution &SE) {
15855 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15856 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15857 // the non-constant operand and in \p LHS the constant operand.
15858 auto IsMinMaxSCEVWithNonNegativeConstant =
15859 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15860 const SCEV *&RHS) {
15861 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15862 if (MinMax->getNumOperands() != 2)
15863 return false;
15864 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15865 if (C->getAPInt().isNegative())
15866 return false;
15867 SCTy = MinMax->getSCEVType();
15868 LHS = MinMax->getOperand(0);
15869 RHS = MinMax->getOperand(1);
15870 return true;
15871 }
15872 }
15873 return false;
15874 };
15875
15876 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15877 SCEVTypes SCTy;
15878 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15879 MinMaxRHS))
15880 return MinMaxExpr;
15881 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15882 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15883 auto *DivisibleExpr =
15884 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15885 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15887 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15888 return SE.getMinMaxExpr(SCTy, Ops);
15889}
15890
/// Collects guard conditions that hold when \p Block is entered (coming from
/// \p Pred) and records them in \p Guards.
///
/// Conditions are gathered from three places: dominating assumptions,
/// dominating llvm.experimental.guard calls, and conditional branches found
/// while walking up the chain of predecessors that have a unique successor.
/// Every integer comparison found is then turned into a rewrite rule in
/// Guards.RewriteMap, or, for some `!=` conditions, into an entry in
/// Guards.NotEqual. Blocks seen during the predecessor walk are added to
/// \p VisitedBlocks. \p Depth is the current recursion depth; when it is
/// non-zero the predecessor walk stops after two collected conditions.
15891void ScalarEvolution::LoopGuards::collectFromBlock(
15892 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15893 const BasicBlock *Block, const BasicBlock *Pred,
15894 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15895
15897
15898 SmallVector<SCEVUse> ExprsToRewrite;
// Turns a single guard `LHS Predicate RHS` into rewrite rules stored in
// RewriteMap. DivGuards carries divisibility rewrites that are applied before
// computing the constant multiple of the (possibly already rewritten) LHS.
15899 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15900 const SCEV *RHS,
15901 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15902 const LoopGuards &DivGuards) {
15903 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15904 // replacement SCEV which isn't directly implied by the structure of that
15905 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15906 // legal. See the scoping rules for flags in the header to understand why.
15907
15908 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15909 // create this form when combining two checks of the form (X u< C2 + C1) and
15910 // (X >=u C1).
15911 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15912 &ExprsToRewrite]() {
15913 const SCEVConstant *C1;
15914 const SCEVUnknown *LHSUnknown;
15915 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15916 if (!match(LHS,
15917 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15918 !C2)
15919 return false;
15920
// Range of values X may take: the region satisfying the comparison against
// C2, shifted by subtracting C1.
15921 auto ExactRegion =
15922 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15923 .sub(C1->getAPInt());
15924
15925 // Bail out, unless we have a non-wrapping, monotonic range.
15926 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15927 return false;
// Chain onto an existing rewrite for X if there is one, then clamp the value
// to [unsigned min, unsigned max] of the region.
15928 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15929 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15930 I->second = SE.getUMaxExpr(
15931 SE.getConstant(ExactRegion.getUnsignedMin()),
15932 SE.getUMinExpr(RewrittenLHS,
15933 SE.getConstant(ExactRegion.getUnsignedMax())));
15934 ExprsToRewrite.push_back(LHSUnknown);
15935 return true;
15936 };
15937 if (MatchRangeCheckIdiom())
15938 return;
15939
15940 // Do not apply information for constants or if RHS contains an AddRec.
15942 return;
15943
15944 // If RHS is SCEVUnknown, make sure the information is applied to it.
15946 std::swap(LHS, RHS);
15948 }
15949
15950 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15951 // and \p FromRewritten are the same (i.e. there has been no rewrite
15952 // registered for \p From), then puts this value in the list of rewritten
15953 // expressions.
15954 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15955 const SCEV *To) {
15956 if (From == FromRewritten)
15957 ExprsToRewrite.push_back(From);
15958 RewriteMap[From] = To;
15959 };
15960
15961 // Checks whether \p S has already been rewritten. In that case returns the
15962 // existing rewrite because we want to chain further rewrites onto the
15963 // already rewritten value. Otherwise returns \p S.
15964 auto GetMaybeRewritten = [&](const SCEV *S) {
15965 return RewriteMap.lookup_or(S, S);
15966 };
15967
15968 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15969 // Apply divisibility information when computing the constant multiple.
15970 const APInt &DividesBy =
15971 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15972
15973 // Collect rewrites for LHS and its transitive operands based on the
15974 // condition.
15975 // For min/max expressions, also apply the guard to its operands:
15976 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15977 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15978 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15979 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15980
15981 // We cannot express strict predicates in SCEV, so instead we replace them
15982 // with non-strict ones against plus or minus one of RHS depending on the
15983 // predicate.
15984 const SCEV *One = SE.getOne(RHS->getType());
15985 switch (Predicate) {
15986 case CmpInst::ICMP_ULT:
15987 if (RHS->getType()->isPointerTy())
15988 return;
// Clamp RHS to at least one before the subtraction in the ICMP_SLT case below.
15989 RHS = SE.getUMaxExpr(RHS, One);
15990 [[fallthrough]];
15991 case CmpInst::ICMP_SLT: {
15992 RHS = SE.getMinusSCEV(RHS, One);
15993 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15994 break;
15995 }
15996 case CmpInst::ICMP_UGT:
15997 case CmpInst::ICMP_SGT:
15998 RHS = SE.getAddExpr(RHS, One);
15999 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16000 break;
16001 case CmpInst::ICMP_ULE:
16002 case CmpInst::ICMP_SLE:
16003 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16004 break;
16005 case CmpInst::ICMP_UGE:
16006 case CmpInst::ICMP_SGE:
16007 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
16008 break;
16009 default:
16010 break;
16011 }
16012
// Worklist of expressions to constrain: starts with LHS and grows with the
// operands of matching min/max expressions. Visited prevents reprocessing.
16013 SmallVector<SCEVUse, 16> Worklist(1, LHS);
16014 SmallPtrSet<const SCEV *, 16> Visited;
16015
16016 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
16017 append_range(Worklist, S->operands());
16018 };
16019
16020 while (!Worklist.empty()) {
16021 const SCEV *From = Worklist.pop_back_val();
16022 if (isa<SCEVConstant>(From))
16023 continue;
16024 if (!Visited.insert(From).second)
16025 continue;
16026 const SCEV *FromRewritten = GetMaybeRewritten(From);
16027 const SCEV *To = nullptr;
16028
// Upper bounds become a min with RHS, lower bounds a max with RHS; the
// signedness of the min/max follows the predicate.
16029 switch (Predicate) {
16030 case CmpInst::ICMP_ULT:
16031 case CmpInst::ICMP_ULE:
16032 To = SE.getUMinExpr(FromRewritten, RHS);
16033 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
16034 EnqueueOperands(UMax);
16035 break;
16036 case CmpInst::ICMP_SLT:
16037 case CmpInst::ICMP_SLE:
16038 To = SE.getSMinExpr(FromRewritten, RHS);
16039 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
16040 EnqueueOperands(SMax);
16041 break;
16042 case CmpInst::ICMP_UGT:
16043 case CmpInst::ICMP_UGE:
16044 To = SE.getUMaxExpr(FromRewritten, RHS);
16045 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
16046 EnqueueOperands(UMin);
16047 break;
16048 case CmpInst::ICMP_SGT:
16049 case CmpInst::ICMP_SGE:
16050 To = SE.getSMaxExpr(FromRewritten, RHS);
16051 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
16052 EnqueueOperands(SMin);
16053 break;
16054 case CmpInst::ICMP_EQ:
16056 To = RHS;
16057 break;
16058 case CmpInst::ICMP_NE:
16059 if (match(RHS, m_scev_Zero())) {
// X != 0: X is at least the smallest non-zero multiple of DividesBy.
16060 const SCEV *OneAlignedUp =
16061 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16062 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16063 } else {
16064 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16065 // but creating the subtraction eagerly is expensive. Track the
16066 // inequalities in a separate map, and materialize the rewrite lazily
16067 // when encountering a suitable subtraction while re-writing.
16068 if (LHS->getType()->isPointerTy()) {
16072 break;
16073 }
16074 const SCEVConstant *C;
16075 const SCEV *A, *B;
16078 RHS = A;
16079 LHS = B;
16080 }
// Store the pair in a canonical (pointer-ordered) form so that lookups can
// normalize the same way.
16081 if (LHS > RHS)
16082 std::swap(LHS, RHS);
16083 Guards.NotEqual.insert({LHS, RHS});
16084 continue;
16085 }
16086 break;
16087 default:
16088 break;
16089 }
16090
16091 if (To)
16092 AddRewrite(From, FromRewritten, To);
16093 }
16094 };
16095
16097 // First, collect information from assumptions dominating the loop.
16098 for (auto &AssumeVH : SE.AC.assumptions()) {
16099 if (!AssumeVH)
16100 continue;
16101 auto *AssumeI = cast<CallInst>(AssumeVH);
16102 if (!SE.DT.dominates(AssumeI, Block))
16103 continue;
16104 Terms.emplace_back(AssumeI->getOperand(0), true);
16105 }
16106
16107 // Second, collect information from llvm.experimental.guards dominating the loop.
16108 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16109 SE.F.getParent(), Intrinsic::experimental_guard);
16110 if (GuardDecl)
16111 for (const auto *GU : GuardDecl->users())
16112 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16113 if (Guard->getFunction() == Block->getParent() &&
16114 SE.DT.dominates(Guard, Block))
16115 Terms.emplace_back(Guard->getArgOperand(0), true);
16116
16117 // Third, collect conditions from dominating branches. Starting at the loop
16118 // predecessor, climb up the predecessor chain, as long as there are
16119 // predecessors that can be found that have unique successors leading to the
16120 // original header.
16121 // TODO: share this logic with isLoopEntryGuardedByCond.
16122 unsigned NumCollectedConditions = 0;
16124 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16125 for (; Pair.first;
16126 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16127 VisitedBlocks.insert(Pair.second);
16128 const CondBrInst *LoopEntryPredicate =
16129 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16130 if (!LoopEntryPredicate)
16131 continue;
16132
// The flag records whether Pair.second is the branch's first successor, i.e.
// whether the condition holds on this edge (used as EnterIfTrue below).
16133 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16134 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16135 NumCollectedConditions++;
16136
16137 // If we are recursively collecting guards stop after 2
16138 // conditions to limit compile-time impact for now.
16139 if (Depth > 0 && NumCollectedConditions == 2)
16140 break;
16141 }
16142 // Finally, if we stopped climbing the predecessor chain because
16143 // there wasn't a unique one to continue, try to collect conditions
16144 // for PHINodes by recursively following all of their incoming
16145 // blocks and try to merge the found conditions to build a new one
16146 // for the Phi.
16147 if (Pair.second->hasNPredecessorsOrMore(2) &&
16149 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16150 for (auto &Phi : Pair.second->phis())
16151 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16152 }
16153
16154 // Now apply the information from the collected conditions to
16155 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16156 // earliest condition is processed first, except guards with divisibility
16157 // information, which are moved to the back. This ensures the SCEVs with the
16158 // shortest dependency chains are constructed first.
16160 GuardsToProcess;
16161 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16162 SmallVector<Value *, 8> Worklist;
16163 SmallPtrSet<Value *, 8> Visited;
16164 Worklist.push_back(Term);
16165 while (!Worklist.empty()) {
16166 Value *Cond = Worklist.pop_back_val();
16167 if (!Visited.insert(Cond).second)
16168 continue;
16169
16170 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
// When the condition is known false on this edge, use the inverse predicate.
16171 auto Predicate =
16172 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16173 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16174 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16175 // If LHS is a constant, apply information to the other expression.
16176 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16177 // can improve results.
16178 if (isa<SCEVConstant>(LHS)) {
16179 std::swap(LHS, RHS);
16181 }
16182 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16183 continue;
16184 }
16185
// A true logical-and implies both operands are true; a false logical-or
// implies both operands are false. Either way, both operands can be examined.
16186 Value *L, *R;
16187 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16188 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16189 Worklist.push_back(L);
16190 Worklist.push_back(R);
16191 }
16192 }
16193 }
16194
16195 // Process divisibility guards in reverse order to populate DivGuards early.
16196 DenseMap<const SCEV *, APInt> Multiples;
16197 LoopGuards DivGuards(SE);
16198 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16199 if (!isDivisibilityGuard(LHS, RHS, SE))
16200 continue;
16201 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16202 Multiples, SE);
16203 }
16204
// Every guard (divisibility guards included) is then fed to CollectCondition.
16205 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16206 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16207
16208 // Apply divisibility information last. This ensures it is applied to the
16209 // outermost expression after other rewrites for the given value.
16210 for (const auto &[K, Divisor] : Multiples) {
16211 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16212 Guards.RewriteMap[K] =
16214 Guards.rewrite(K), Divisor, SE),
16215 DivisorSCEV),
16216 DivisorSCEV);
16217 ExprsToRewrite.push_back(K);
16218 }
16219
16220 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16221 // the replacement expressions are contained in the ranges of the replaced
16222 // expressions.
16223 Guards.PreserveNUW = true;
16224 Guards.PreserveNSW = true;
16225 for (const SCEV *Expr : ExprsToRewrite) {
16226 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16227 Guards.PreserveNUW &=
16228 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16229 Guards.PreserveNSW &=
16230 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16231 }
16232
16233 // Now that all rewrite information is collected, rewrite the collected
16234 // expressions with the information in the map. This applies information to
16235 // sub-expressions.
16236 if (ExprsToRewrite.size() > 1) {
16237 for (const SCEV *Expr : ExprsToRewrite) {
// Erase the entry first so that rewriting RewriteTo does not map Expr to
// itself through its own rule.
16238 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16239 Guards.RewriteMap.erase(Expr);
16240 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16241 }
16242 }
16243}
16244
16246 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16247 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16248 /// replacement is loop invariant in the loop of the AddRec.
16249 class SCEVLoopGuardRewriter
16250 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16253
16255
16256 public:
// Binds the rewriter to the guard's rewrite map and not-equal set. FlagMask
// gains NUW and/or NSW only if the guards were marked as preserving them.
16257 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16258 const ScalarEvolution::LoopGuards &Guards)
16259 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16260 NotEqual(Guards.NotEqual) {
16261 if (Guards.PreserveNUW)
16262 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16263 if (Guards.PreserveNSW)
16264 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16265 }
16266
// AddRecs are returned unchanged (see the class comment).
16267 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16268
// Returns the mapped replacement for Expr, or Expr itself if there is none.
16269 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16270 return Map.lookup_or(Expr, Expr);
16271 }
16272
16273 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16274 if (const SCEV *S = Map.lookup(Expr))
16275 return S;
16276
16277 // If we didn't find the exact ZExt expr in the map, check if there's
16278 // an entry for a smaller ZExt we can use instead.
16279 Type *Ty = Expr->getType();
16280 const SCEV *Op = Expr->getOperand(0);
// Try successively halved widths, as long as the width is a multiple of 8
// and still wider than the operand's type.
16281 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16282 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16283 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16284 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16285 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16286 if (const SCEV *S = Map.lookup(NarrowExt))
16287 return SE.getZeroExtendExpr(S, Ty);
16288 Bitwidth = Bitwidth / 2;
16289 }
16290
16292 Expr);
16293 }
16294
// A direct map entry for the whole sign-extend takes precedence.
16295 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16296 if (const SCEV *S = Map.lookup(Expr))
16297 return S;
16299 Expr);
16300 }
16301
// A direct map entry for the whole umin takes precedence.
16302 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16303 if (const SCEV *S = Map.lookup(Expr))
16304 return S;
16306 }
16307
// A direct map entry for the whole smin takes precedence.
16308 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16309 if (const SCEV *S = Map.lookup(Expr))
16310 return S;
16312 }
16313
16314 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16315 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16316 // return UMax(S, 1).
16317 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16318 SCEVUse LHS, RHS;
16319 if (MatchBinarySub(S, LHS, RHS)) {
// Order the pair by pointer value before the lookup in NotEqual.
16320 if (LHS > RHS)
16321 std::swap(LHS, RHS);
16322 if (NotEqual.contains({LHS, RHS})) {
// The lower bound is one rounded up to a multiple of S's constant multiple.
16323 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16324 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16325 return SE.getUMaxExpr(OneAlignedUp, S);
16326 }
16327 }
16328 return nullptr;
16329 };
16330
16331 // Check if Expr itself is a subtraction pattern with guard info.
16332 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16333 return Rewritten;
16334
16335 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16336 // (Const + A + B). There may be guard info for A + B, and if so, apply
16337 // it.
16338 // TODO: Could more generally apply guards to Add sub-expressions.
16339 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16340 Expr->getNumOperands() == 3) {
16341 const SCEV *Add =
16342 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16343 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16344 return SE.getAddExpr(
16345 Expr->getOperand(0), Rewritten,
16346 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16347 if (const SCEV *S = Map.lookup(Add))
16348 return SE.getAddExpr(Expr->getOperand(0), S);
16349 }
// Otherwise rewrite operand by operand and rebuild only if something changed.
16350 SmallVector<SCEVUse, 2> Operands;
16351 bool Changed = false;
16352 for (SCEVUse Op : Expr->operands()) {
16353 Operands.push_back(
16355 Changed |= Op != Operands.back();
16356 }
16357 // We are only replacing operands with equivalent values, so transfer the
16358 // flags from the original expression.
16359 return !Changed ? Expr
16360 : SE.getAddExpr(Operands,
16362 Expr->getNoWrapFlags(), FlagMask));
16363 }
16364
// Rewrites each operand and rebuilds the multiplication only if one changed.
16365 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16366 SmallVector<SCEVUse, 2> Operands;
16367 bool Changed = false;
16368 for (SCEVUse Op : Expr->operands()) {
16369 Operands.push_back(
16371 Changed |= Op != Operands.back();
16372 }
16373 // We are only replacing operands with equivalent values, so transfer the
16374 // flags from the original expression.
16375 return !Changed ? Expr
16376 : SE.getMulExpr(Operands,
16378 Expr->getNoWrapFlags(), FlagMask));
16379 }
16380 };
16381
16382 if (RewriteMap.empty() && NotEqual.empty())
16383 return Expr;
16384
16385 SCEVLoopGuardRewriter Rewriter(SE, *this);
16386 return Rewriter.visit(Expr);
16387}
16388
16389const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16390 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16391}
16392
16394 const LoopGuards &Guards) {
16395 return Guards.rewrite(Expr);
16396}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
static constexpr Value * getValue(Ty &ValueOrUse)
static Value * getOpcode(Value &V, Type &Ty, InstrumentationConfig &IConf, InstrumentorIRBuilderTy &IIRB)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
MachineInstr unsigned OpIdx
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true by virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred ALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2006
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1692
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
Definition APInt.h:476
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1300
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1028
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1244
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:130
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator begin() const
Definition ArrayRef.h:129
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:409
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
@ ICMP_SLT
signed less than
Definition InstrTypes.h:769
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:770
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:764
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:763
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:767
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:765
@ ICMP_NE
not equal
Definition InstrTypes.h:762
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:768
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:766
bool isSigned() const
Definition InstrTypes.h:993
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:890
bool isTrueWhenEqual() const
This is just a convenience.
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:852
bool isUnsigned() const
Definition InstrTypes.h:999
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:989
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1491
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:252
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:225
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:301
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:135
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:238
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:221
iterator end()
Definition DenseMap.h:143
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:216
void swap(DerivedT &RHS)
Definition DenseMap.h:439
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:286
Analysis pass which computes a DominatorTree.
Definition Dominators.h:270
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:306
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:151
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
bool contains(const LoopT *L) const
Return true if the specified loop is contained within this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the innermost loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:612
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1069
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:113
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:107
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2, ArrayRef< const SCEVPredicate * > ExtraPreds={}) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds and ExtraPreds.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V, SmallVectorImpl< const SCEVPredicate * > *WrapPredsAdded=nullptr)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
SCEVUnionPredicate getUnionWith(const SCEVPredicate *N, ScalarEvolution &SE) const
Returns a new SCEVUnionPredicate that is the union of this predicate and the given predicate N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set of assumed no-overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Returns true if the operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static LLVM_ABI bool isGuaranteedNotToBePoison(const SCEV *Op)
Returns true if Op is guaranteed to not be poison.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI bool isLoopUniform(const SCEV *S, const Loop *L)
Returns true if the given SCEV is loop-uniform with respect to the specified loop L.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
@ LoopUniform
The SCEV is loop-uniform.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:301
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:309
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:306
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:270
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:313
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:319
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2847
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:830
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2115
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
@ Dead
Unused definition.
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2110
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant, stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2199
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:380
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
constexpr T divideCeil(U Numerator, V Denominator)
Returns the integer ceil(Numerator / Denominator).
Definition MathExtras.h:394
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2087
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1916
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2018
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:860
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:862
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.