LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we go take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
1076 ScalarEvolution &SE) const {
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1113 Type *TargetTy;
1114 ConversionFn CreatePtrCast;
1115
1116public:
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
1165 return SE.getZero(TargetTy);
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1185 getDataLayout().getTypeSizeInBits(IntPtrTy))
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although we checked in the beginning that ID is not in the cache, it is
1320 // possible that during recursion and different modification ID was inserted
1321 // into the cache. So if we find it, just return it.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
1385struct ExtendOpTraitsBase {
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
1390// Used to make code generic over signed and unsigned overflow.
1391template <typename ExtendOp> struct ExtendOpTraits {
1392 // Members present:
1393 //
1394 // static const SCEV::NoWrapFlags WrapType;
1395 //
1396 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1397 //
1398 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1399 // ICmpInst::Predicate *Pred,
1400 // ScalarEvolution *SE);
1401};
1402
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1441// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1442// expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)"
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not sign/unsign-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Give up if we don't already have the add recurrence we need because
1597 // actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1622 uint32_t TZ = BitWidth;
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the later case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two issue possible causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1914 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1915 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1916 //
1917 // Often address arithmetics contain expressions like
1918 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1919 // This transformation is useful while proving that such expressions are
1920 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1921 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1922 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1923 if (D != 0) {
1924 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1925 const SCEV *SResidual =
1927 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1928 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1929 Depth + 1);
1930 }
1931 }
1932 }
1933
1934 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1935 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1936 if (SM->hasNoUnsignedWrap()) {
1937 // If the multiply does not unsign overflow then we can, by definition,
1938 // commute the zero extension with the multiply operation.
1940 for (SCEVUse Op : SM->operands())
1941 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1942 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1943 }
1944
1945 // zext(2^K * (trunc X to iN)) to iM ->
1946 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1947 //
1948 // Proof:
1949 //
1950 // zext(2^K * (trunc X to iN)) to iM
1951 // = zext((trunc X to iN) << K) to iM
1952 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1953 // (because shl removes the top K bits)
1954 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1955 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1956 //
1957 const APInt *C;
1958 const SCEV *TruncRHS;
1959 if (match(SM,
1960 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1961 C->isPowerOf2()) {
1962 int NewTruncBits =
1963 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1964 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1965 return getMulExpr(
1966 getZeroExtendExpr(SM->getOperand(0), Ty),
1967 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1968 SCEV::FlagNUW, Depth + 1);
1969 }
1970 }
1971
1972 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1973 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1976 SmallVector<SCEVUse, 4> Operands;
1977 for (SCEVUse Operand : MinMax->operands())
1978 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1980 return getUMinExpr(Operands);
1981 return getUMaxExpr(Operands);
1982 }
1983
1984 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1986 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1987 SmallVector<SCEVUse, 4> Operands;
1988 for (SCEVUse Operand : MinMax->operands())
1989 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1990 return getUMinExpr(Operands, /*Sequential*/ true);
1991 }
1992
1993 // The cast wasn't folded; create an explicit cast node.
1994 // Recompute the insert position, as it may have been invalidated.
1995 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1996 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1997 Op, Ty);
1998 UniqueSCEVs.InsertNode(S, IP);
1999 S->computeAndSetCanonical(*this);
2000 registerUser(S, Op);
2001 return S;
2002}
2003
2004const SCEV *
2006 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2007 "This is not an extending conversion!");
2008 assert(isSCEVable(Ty) &&
2009 "This is not a conversion to a SCEVable type!");
2010 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2011 Ty = getEffectiveSCEVType(Ty);
2012
2013 FoldID ID(scSignExtend, Op, Ty);
2014 if (const SCEV *S = FoldCache.lookup(ID))
2015 return S;
2016
2017 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2019 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2020 return S;
2021}
2022
2024 unsigned Depth) {
2025 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2026 "This is not an extending conversion!");
2027 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2028 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2029 Ty = getEffectiveSCEVType(Ty);
2030
2031 // Fold if the operand is constant.
2032 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2033 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2034
2035 // sext(sext(x)) --> sext(x)
2037 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2038
2039 // sext(zext(x)) --> zext(x)
2041 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2042
2043 // Before doing any expensive analysis, check to see if we've already
2044 // computed a SCEV for this Op and Ty.
2046 ID.AddInteger(scSignExtend);
2047 ID.AddPointer(Op);
2048 ID.AddPointer(Ty);
2049 void *IP = nullptr;
2050 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2051 // Limit recursion depth.
2052 if (Depth > MaxCastDepth) {
2053 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2054 Op, Ty);
2055 UniqueSCEVs.InsertNode(S, IP);
2056 S->computeAndSetCanonical(*this);
2057 registerUser(S, Op);
2058 return S;
2059 }
2060
2061 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2063 // It's possible the bits taken off by the truncate were all sign bits. If
2064 // so, we should be able to simplify this further.
2065 const SCEV *X = ST->getOperand();
2067 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2068 unsigned NewBits = getTypeSizeInBits(Ty);
2069 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2070 CR.sextOrTrunc(NewBits)))
2071 return getTruncateOrSignExtend(X, Ty, Depth);
2072 }
2073
2074 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2075 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2076 if (SA->hasNoSignedWrap()) {
2077 // If the addition does not sign overflow then we can, by definition,
2078 // commute the sign extension with the addition operation.
2080 for (SCEVUse Op : SA->operands())
2081 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2082 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2083 }
2084
2085 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2086 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2087 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2088 //
2089 // For instance, this will bring two seemingly different expressions:
2090 // 1 + sext(5 + 20 * %x + 24 * %y) and
2091 // sext(6 + 20 * %x + 24 * %y)
2092 // to the same form:
2093 // 2 + sext(4 + 20 * %x + 24 * %y)
2094 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2095 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2096 if (D != 0) {
2097 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2098 const SCEV *SResidual =
2100 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2101 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2102 Depth + 1);
2103 }
2104 }
2105 }
2106 // If the input value is a chrec scev, and we can prove that the value
2107 // did not overflow the old, smaller, value, we can sign extend all of the
2108 // operands (often constants). This allows analysis of something like
2109 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2111 if (AR->isAffine()) {
2112 const SCEV *Start = AR->getStart();
2113 const SCEV *Step = AR->getStepRecurrence(*this);
2114 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2115 const Loop *L = AR->getLoop();
2116
2117 // If we have special knowledge that this addrec won't overflow,
2118 // we don't need to do any further analysis.
2119 if (AR->hasNoSignedWrap()) {
2120 Start =
2122 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2123 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2124 }
2125
2126 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2127 // Note that this serves two purposes: It filters out loops that are
2128 // simply not analyzable, and it covers the case where this code is
2129 // being called from within backedge-taken count analysis, such that
2130 // attempting to ask for the backedge-taken count would likely result
2131 // in infinite recursion. In the later case, the analysis code will
2132 // cope with a conservative value, and it will take care to purge
2133 // that value once it has finished.
2134 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2135 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2136 // Manually compute the final value for AR, checking for
2137 // overflow.
2138
2139 // Check whether the backedge-taken count can be losslessly casted to
2140 // the addrec's type. The count is always unsigned.
2141 const SCEV *CastedMaxBECount =
2142 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2143 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2144 CastedMaxBECount, MaxBECount->getType(), Depth);
2145 if (MaxBECount == RecastedMaxBECount) {
2146 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2147 // Check whether Start+Step*MaxBECount has no signed overflow.
2148 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2150 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2152 Depth + 1),
2153 WideTy, Depth + 1);
2154 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2155 const SCEV *WideMaxBECount =
2156 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2157 const SCEV *OperandExtendedAdd =
2158 getAddExpr(WideStart,
2159 getMulExpr(WideMaxBECount,
2160 getSignExtendExpr(Step, WideTy, Depth + 1),
2163 if (SAdd == OperandExtendedAdd) {
2164 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2165 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2166 // Return the expression with the addrec on the outside.
2167 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2168 Depth + 1);
2169 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2170 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2171 }
2172 // Similar to above, only this time treat the step value as unsigned.
2173 // This covers loops that count up with an unsigned step.
2174 OperandExtendedAdd =
2175 getAddExpr(WideStart,
2176 getMulExpr(WideMaxBECount,
2177 getZeroExtendExpr(Step, WideTy, Depth + 1),
2180 if (SAdd == OperandExtendedAdd) {
2181 // If AR wraps around then
2182 //
2183 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2184 // => SAdd != OperandExtendedAdd
2185 //
2186 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2187 // (SAdd == OperandExtendedAdd => AR is NW)
2188
2189 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2190
2191 // Return the expression with the addrec on the outside.
2192 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2193 Depth + 1);
2194 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2195 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2196 }
2197 }
2198 }
2199
2200 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2201 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2202 if (AR->hasNoSignedWrap()) {
2203 // Same as nsw case above - duplicated here to avoid a compile time
2204 // issue. It's not clear that the order of checks does matter, but
2205 // it's one of two issue possible causes for a change which was
2206 // reverted. Be conservative for the moment.
2207 Start =
2209 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2210 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2211 }
2212
2213 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2214 // if D + (C - D + Step * n) could be proven to not signed wrap
2215 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2216 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2217 const APInt &C = SC->getAPInt();
2218 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2219 if (D != 0) {
2220 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2221 const SCEV *SResidual =
2222 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2223 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2224 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2225 Depth + 1);
2226 }
2227 }
2228
2229 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2230 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2231 Start =
2233 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2234 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2235 }
2236 }
2237
2238 // If the input value is provably positive and we could not simplify
2239 // away the sext build a zext instead.
2241 return getZeroExtendExpr(Op, Ty, Depth + 1);
2242
2243 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2244 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2247 SmallVector<SCEVUse, 4> Operands;
2248 for (SCEVUse Operand : MinMax->operands())
2249 Operands.push_back(getSignExtendExpr(Operand, Ty));
2251 return getSMinExpr(Operands);
2252 return getSMaxExpr(Operands);
2253 }
2254
2255 // The cast wasn't folded; create an explicit cast node.
2256 // Recompute the insert position, as it may have been invalidated.
2257 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2258 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2259 Op, Ty);
2260 UniqueSCEVs.InsertNode(S, IP);
2261 S->computeAndSetCanonical(*this);
2262 registerUser(S, Op);
2263 return S;
2264}
2265
2267 Type *Ty) {
2268 switch (Kind) {
2269 case scTruncate:
2270 return getTruncateExpr(Op, Ty);
2271 case scZeroExtend:
2272 return getZeroExtendExpr(Op, Ty);
2273 case scSignExtend:
2274 return getSignExtendExpr(Op, Ty);
2275 case scPtrToInt:
2276 return getPtrToIntExpr(Op, Ty);
2277 default:
2278 llvm_unreachable("Not a SCEV cast expression!");
2279 }
2280}
2281
2282/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2283/// unspecified bits out to the given type.
2285 Type *Ty) {
2286 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2287 "This is not an extending conversion!");
2288 assert(isSCEVable(Ty) &&
2289 "This is not a conversion to a SCEVable type!");
2290 Ty = getEffectiveSCEVType(Ty);
2291
2292 // Sign-extend negative constants.
2293 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2294 if (SC->getAPInt().isNegative())
2295 return getSignExtendExpr(Op, Ty);
2296
2297 // Peel off a truncate cast.
2299 const SCEV *NewOp = T->getOperand();
2300 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2301 return getAnyExtendExpr(NewOp, Ty);
2302 return getTruncateOrNoop(NewOp, Ty);
2303 }
2304
2305 // Next try a zext cast. If the cast is folded, use it.
2306 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2307 if (!isa<SCEVZeroExtendExpr>(ZExt))
2308 return ZExt;
2309
2310 // Next try a sext cast. If the cast is folded, use it.
2311 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2312 if (!isa<SCEVSignExtendExpr>(SExt))
2313 return SExt;
2314
2315 // Force the cast to be folded into the operands of an addrec.
2316 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2318 for (const SCEV *Op : AR->operands())
2319 Ops.push_back(getAnyExtendExpr(Op, Ty));
2320 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2321 }
2322
2323 // If the expression is obviously signed, use the sext cast value.
2324 if (isa<SCEVSMaxExpr>(Op))
2325 return SExt;
2326
2327 // Absent any other information, use the zext cast value.
2328 return ZExt;
2329}
2330
2331/// Process the given Ops list, which is a list of operands to be added under
2332/// the given scale, update the given map. This is a helper function for
2333/// getAddRecExpr. As an example of what it does, given a sequence of operands
2334/// that would form an add expression like this:
2335///
2336/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2337///
2338/// where A and B are constants, update the map with these values:
2339///
2340/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2341///
2342/// and add 13 + A*B*29 to AccumulatedConstant.
2343/// This will allow getAddRecExpr to produce this:
2344///
2345/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2346///
2347/// This form often exposes folding opportunities that are hidden in
2348/// the original operand list.
2349///
2350/// Return true iff it appears that any interesting folding opportunities
2351/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2352/// the common case where no interesting opportunities are present, and
2353/// is also used as a check to avoid infinite recursion.
2356 APInt &AccumulatedConstant,
2358 const APInt &Scale,
2359 ScalarEvolution &SE) {
2360 bool Interesting = false;
2361
2362 // Iterate over the add operands. They are sorted, with constants first.
2363 unsigned i = 0;
2364 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2365 ++i;
2366 // Pull a buried constant out to the outside.
2367 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2368 Interesting = true;
2369 AccumulatedConstant += Scale * C->getAPInt();
2370 }
2371
2372 // Next comes everything else. We're especially interested in multiplies
2373 // here, but they're in the middle, so just visit the rest with one loop.
2374 for (; i != Ops.size(); ++i) {
2376 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2377 APInt NewScale =
2378 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2379 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2380 // A multiplication of a constant with another add; recurse.
2381 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2382 Interesting |= CollectAddOperandsWithScales(
2383 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2384 } else {
2385 // A multiplication of a constant with some other value. Update
2386 // the map.
2387 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2388 const SCEV *Key = SE.getMulExpr(MulOps);
2389 auto Pair = M.insert({Key, NewScale});
2390 if (Pair.second) {
2391 NewOps.push_back(Pair.first->first);
2392 } else {
2393 Pair.first->second += NewScale;
2394 // The map already had an entry for this value, which may indicate
2395 // a folding opportunity.
2396 Interesting = true;
2397 }
2398 }
2399 } else {
2400 // An ordinary operand. Update the map.
2401 auto Pair = M.insert({Ops[i], Scale});
2402 if (Pair.second) {
2403 NewOps.push_back(Pair.first->first);
2404 } else {
2405 Pair.first->second += Scale;
2406 // The map already had an entry for this value, which may indicate
2407 // a folding opportunity.
2408 Interesting = true;
2409 }
2410 }
2411 }
2412
2413 return Interesting;
2414}
2415
2417 const SCEV *LHS, const SCEV *RHS,
2418 const Instruction *CtxI) {
2420 unsigned);
2421 switch (BinOp) {
2422 default:
2423 llvm_unreachable("Unsupported binary op");
2424 case Instruction::Add:
2426 break;
2427 case Instruction::Sub:
2429 break;
2430 case Instruction::Mul:
2432 break;
2433 }
2434
2435 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2438
2439 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2440 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2441 auto *WideTy =
2442 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2443
2444 const SCEV *A = (this->*Extension)(
2445 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2446 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2447 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2448 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2449 if (A == B)
2450 return true;
2451 // Can we use context to prove the fact we need?
2452 if (!CtxI)
2453 return false;
2454 // TODO: Support mul.
2455 if (BinOp == Instruction::Mul)
2456 return false;
2457 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2458 // TODO: Lift this limitation.
2459 if (!RHSC)
2460 return false;
2461 APInt C = RHSC->getAPInt();
2462 unsigned NumBits = C.getBitWidth();
2463 bool IsSub = (BinOp == Instruction::Sub);
2464 bool IsNegativeConst = (Signed && C.isNegative());
2465 // Compute the direction and magnitude by which we need to check overflow.
2466 bool OverflowDown = IsSub ^ IsNegativeConst;
2467 APInt Magnitude = C;
2468 if (IsNegativeConst) {
2469 if (C == APInt::getSignedMinValue(NumBits))
2470 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2471 // want to deal with that.
2472 return false;
2473 Magnitude = -C;
2474 }
2475
2477 if (OverflowDown) {
2478 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2479 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2480 : APInt::getMinValue(NumBits);
2481 APInt Limit = Min + Magnitude;
2482 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2483 } else {
2484 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2485 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2486 : APInt::getMaxValue(NumBits);
2487 APInt Limit = Max - Magnitude;
2488 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2489 }
2490}
2491
2492std::optional<SCEV::NoWrapFlags>
2494 const OverflowingBinaryOperator *OBO) {
2495 // It cannot be done any better.
2496 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2497 return std::nullopt;
2498
2499 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2500
2501 if (OBO->hasNoUnsignedWrap())
2503 if (OBO->hasNoSignedWrap())
2505
2506 bool Deduced = false;
2507
2508 if (OBO->getOpcode() != Instruction::Add &&
2509 OBO->getOpcode() != Instruction::Sub &&
2510 OBO->getOpcode() != Instruction::Mul)
2511 return std::nullopt;
2512
2513 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2514 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2515
2516 const Instruction *CtxI =
2518 if (!OBO->hasNoUnsignedWrap() &&
2520 /* Signed */ false, LHS, RHS, CtxI)) {
2522 Deduced = true;
2523 }
2524
2525 if (!OBO->hasNoSignedWrap() &&
2527 /* Signed */ true, LHS, RHS, CtxI)) {
2529 Deduced = true;
2530 }
2531
2532 if (Deduced)
2533 return Flags;
2534 return std::nullopt;
2535}
2536
2537// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2538// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2539// can't-overflow flags for the operation if possible.
2543 SCEV::NoWrapFlags Flags) {
2544 using namespace std::placeholders;
2545
2546 using OBO = OverflowingBinaryOperator;
2547
2548 bool CanAnalyze =
2550 (void)CanAnalyze;
2551 assert(CanAnalyze && "don't call from other places!");
2552
2553 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2554 SCEV::NoWrapFlags SignOrUnsignWrap =
2555 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2556
2557 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2558 auto IsKnownNonNegative = [&](SCEVUse U) {
2559 return SE->isKnownNonNegative(U);
2560 };
2561
2562 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2563 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2564
2565 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2566
2567 if (SignOrUnsignWrap != SignOrUnsignMask &&
2568 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2569 isa<SCEVConstant>(Ops[0])) {
2570
2571 auto Opcode = [&] {
2572 switch (Type) {
2573 case scAddExpr:
2574 return Instruction::Add;
2575 case scMulExpr:
2576 return Instruction::Mul;
2577 default:
2578 llvm_unreachable("Unexpected SCEV op.");
2579 }
2580 }();
2581
2582 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2583
2584 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2585 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2587 Opcode, C, OBO::NoSignedWrap);
2588 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2590 }
2591
2592 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2593 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2595 Opcode, C, OBO::NoUnsignedWrap);
2596 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2598 }
2599 }
2600
2601 // <0,+,nonnegative><nw> is also nuw
2602 // TODO: Add corresponding nsw case
2604 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2605 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2607
2608 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2610 Ops.size() == 2) {
2611 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2612 if (UDiv->getOperand(1) == Ops[1])
2614 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2615 if (UDiv->getOperand(1) == Ops[0])
2617 }
2618
2619 return Flags;
2620}
2621
2623 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2624}
2625
2626/// Get a canonical add expression, or something simpler if possible.
2628 SCEV::NoWrapFlags OrigFlags,
2629 unsigned Depth) {
2630 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2631 "only nuw or nsw allowed");
2632 assert(!Ops.empty() && "Cannot get empty add!");
2633 if (Ops.size() == 1) return Ops[0];
2634#ifndef NDEBUG
2635 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2636 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2637 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2638 "SCEVAddExpr operand types don't match!");
2639 unsigned NumPtrs = count_if(
2640 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2641 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2642#endif
2643
2644 const SCEV *Folded = constantFoldAndGroupOps(
2645 *this, LI, DT, Ops,
2646 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2647 [](const APInt &C) { return C.isZero(); }, // identity
2648 [](const APInt &C) { return false; }); // absorber
2649 if (Folded)
2650 return Folded;
2651
2652 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2653
2654 // Delay expensive flag strengthening until necessary.
2655 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2656 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2657 };
2658
2659 // Limit recursion calls depth.
2661 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2662
2663 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2664 // Don't strengthen flags if we have no new information.
2665 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2666 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2667 Add->setNoWrapFlags(ComputeFlags(Ops));
2668 return S;
2669 }
2670
2671 // Okay, check to see if the same value occurs in the operand list more than
2672 // once. If so, merge them together into an multiply expression. Since we
2673 // sorted the list, these values are required to be adjacent.
2674 Type *Ty = Ops[0]->getType();
2675 bool FoundMatch = false;
2676 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2677 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2678 // Scan ahead to count how many equal operands there are.
2679 unsigned Count = 2;
2680 while (i+Count != e && Ops[i+Count] == Ops[i])
2681 ++Count;
2682 // Merge the values into a multiply.
2683 SCEVUse Scale = getConstant(Ty, Count);
2684 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2685 if (Ops.size() == Count)
2686 return Mul;
2687 Ops[i] = Mul;
2688 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2689 --i; e -= Count - 1;
2690 FoundMatch = true;
2691 }
2692 if (FoundMatch)
2693 return getAddExpr(Ops, OrigFlags, Depth + 1);
2694
2695 // Check for truncates. If all the operands are truncated from the same
2696 // type, see if factoring out the truncate would permit the result to be
2697 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2698 // if the contents of the resulting outer trunc fold to something simple.
2699 auto FindTruncSrcType = [&]() -> Type * {
2700 // We're ultimately looking to fold an addrec of truncs and muls of only
2701 // constants and truncs, so if we find any other types of SCEV
2702 // as operands of the addrec then we bail and return nullptr here.
2703 // Otherwise, we return the type of the operand of a trunc that we find.
2704 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2705 return T->getOperand()->getType();
2706 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2707 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2708 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2709 return T->getOperand()->getType();
2710 }
2711 return nullptr;
2712 };
2713 if (auto *SrcType = FindTruncSrcType()) {
2714 SmallVector<SCEVUse, 8> LargeOps;
2715 bool Ok = true;
2716 // Check all the operands to see if they can be represented in the
2717 // source type of the truncate.
2718 for (const SCEV *Op : Ops) {
2720 if (T->getOperand()->getType() != SrcType) {
2721 Ok = false;
2722 break;
2723 }
2724 LargeOps.push_back(T->getOperand());
2725 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2726 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2727 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2728 SmallVector<SCEVUse, 8> LargeMulOps;
2729 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2730 if (const SCEVTruncateExpr *T =
2731 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2732 if (T->getOperand()->getType() != SrcType) {
2733 Ok = false;
2734 break;
2735 }
2736 LargeMulOps.push_back(T->getOperand());
2737 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2738 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2739 } else {
2740 Ok = false;
2741 break;
2742 }
2743 }
2744 if (Ok)
2745 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2746 } else {
2747 Ok = false;
2748 break;
2749 }
2750 }
2751 if (Ok) {
2752 // Evaluate the expression in the larger type.
2753 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2754 // If it folds to something simple, use it. Otherwise, don't.
2755 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2756 return getTruncateExpr(Fold, Ty);
2757 }
2758 }
2759
2760 if (Ops.size() == 2) {
2761 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2762 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2763 // C1).
2764 const SCEV *A = Ops[0];
2765 const SCEV *B = Ops[1];
2766 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2767 auto *C = dyn_cast<SCEVConstant>(A);
2768 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2769 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2770 auto C2 = C->getAPInt();
2771 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2772
2773 APInt ConstAdd = C1 + C2;
2774 auto AddFlags = AddExpr->getNoWrapFlags();
2775 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2777 ConstAdd.ule(C1)) {
2778 PreservedFlags =
2780 }
2781
2782 // Adding a constant with the same sign and small magnitude is NSW, if the
2783 // original AddExpr was NSW.
2785 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2786 ConstAdd.abs().ule(C1.abs())) {
2787 PreservedFlags =
2789 }
2790
2791 if (PreservedFlags != SCEV::FlagAnyWrap) {
2792 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2793 NewOps[0] = getConstant(ConstAdd);
2794 return getAddExpr(NewOps, PreservedFlags);
2795 }
2796 }
2797
2798 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2799 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2800 const SCEVAddExpr *InnerAdd;
2801 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2802 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2803 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2804 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2805 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2807 SCEV::FlagNUW)) {
2808 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2809 }
2810 }
2811 }
2812
2813 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2814 const SCEV *Y;
2815 if (Ops.size() == 2 &&
2816 match(Ops[0],
2818 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2819 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2820
2821 // Skip past any other cast SCEVs.
2822 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2823 ++Idx;
2824
2825 // If there are add operands they would be next.
2826 if (Idx < Ops.size()) {
2827 bool DeletedAdd = false;
2828 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2829 // common NUW flag for expression after inlining. Other flags cannot be
2830 // preserved, because they may depend on the original order of operations.
2831 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2832 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2833 if (Ops.size() > AddOpsInlineThreshold ||
2834 Add->getNumOperands() > AddOpsInlineThreshold)
2835 break;
2836 // If we have an add, expand the add operands onto the end of the operands
2837 // list.
2838 Ops.erase(Ops.begin()+Idx);
2839 append_range(Ops, Add->operands());
2840 DeletedAdd = true;
2841 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2842 }
2843
2844 // If we deleted at least one add, we added operands to the end of the list,
2845 // and they are not necessarily sorted. Recurse to resort and resimplify
2846 // any operands we just acquired.
2847 if (DeletedAdd)
2848 return getAddExpr(Ops, CommonFlags, Depth + 1);
2849 }
2850
2851 // Skip over the add expression until we get to a multiply.
2852 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2853 ++Idx;
2854
2855 // Check to see if there are any folding opportunities present with
2856 // operands multiplied by constant values.
2857 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2861 APInt AccumulatedConstant(BitWidth, 0);
2862 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2863 Ops, APInt(BitWidth, 1), *this)) {
2864 struct APIntCompare {
2865 bool operator()(const APInt &LHS, const APInt &RHS) const {
2866 return LHS.ult(RHS);
2867 }
2868 };
2869
2870 // Some interesting folding opportunity is present, so its worthwhile to
2871 // re-generate the operands list. Group the operands by constant scale,
2872 // to avoid multiplying by the same constant scale multiple times.
2873 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2874 for (const SCEV *NewOp : NewOps)
2875 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2876 // Re-generate the operands list.
2877 Ops.clear();
2878 if (AccumulatedConstant != 0)
2879 Ops.push_back(getConstant(AccumulatedConstant));
2880 for (auto &MulOp : MulOpLists) {
2881 if (MulOp.first == 1) {
2882 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2883 } else if (MulOp.first != 0) {
2884 Ops.push_back(getMulExpr(
2885 getConstant(MulOp.first),
2886 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2887 SCEV::FlagAnyWrap, Depth + 1));
2888 }
2889 }
2890 if (Ops.empty())
2891 return getZero(Ty);
2892 if (Ops.size() == 1)
2893 return Ops[0];
2894 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2895 }
2896 }
2897
2898 // If we are adding something to a multiply expression, make sure the
2899 // something is not already an operand of the multiply. If so, merge it into
2900 // the multiply.
2901 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2902 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2903 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2904 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2905 if (isa<SCEVConstant>(MulOpSCEV))
2906 continue;
2907 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2908 if (MulOpSCEV == Ops[AddOp]) {
2909 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2910 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2911 if (Mul->getNumOperands() != 2) {
2912 // If the multiply has more than two operands, we must get the
2913 // Y*Z term.
2914 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2915 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2916 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2917 }
2918 const SCEV *AddOne =
2919 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2920 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2922 if (Ops.size() == 2) return OuterMul;
2923 if (AddOp < Idx) {
2924 Ops.erase(Ops.begin()+AddOp);
2925 Ops.erase(Ops.begin()+Idx-1);
2926 } else {
2927 Ops.erase(Ops.begin()+Idx);
2928 Ops.erase(Ops.begin()+AddOp-1);
2929 }
2930 Ops.push_back(OuterMul);
2931 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2932 }
2933
2934 // Check this multiply against other multiplies being added together.
2935 for (unsigned OtherMulIdx = Idx+1;
2936 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2937 ++OtherMulIdx) {
2938 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2939 // If MulOp occurs in OtherMul, we can fold the two multiplies
2940 // together.
2941 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2942 OMulOp != e; ++OMulOp)
2943 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2944 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2945 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2946 if (Mul->getNumOperands() != 2) {
2947 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2948 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2949 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2950 }
2951 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2952 if (OtherMul->getNumOperands() != 2) {
2954 OtherMul->operands().take_front(OMulOp));
2955 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2956 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2957 }
2958 const SCEV *InnerMulSum =
2959 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2960 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2962 if (Ops.size() == 2) return OuterMul;
2963 Ops.erase(Ops.begin()+Idx);
2964 Ops.erase(Ops.begin()+OtherMulIdx-1);
2965 Ops.push_back(OuterMul);
2966 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 }
2969 }
2970 }
2971
2972 // If there are any add recurrences in the operands list, see if any other
2973 // added values are loop invariant. If so, we can fold them into the
2974 // recurrence.
2975 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2976 ++Idx;
2977
2978 // Scan over all recurrences, trying to fold loop invariants into them.
2979 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2980 // Scan all of the other operands to this add and add them to the vector if
2981 // they are loop invariant w.r.t. the recurrence.
2983 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2984 const Loop *AddRecLoop = AddRec->getLoop();
2985 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2986 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2987 LIOps.push_back(Ops[i]);
2988 Ops.erase(Ops.begin()+i);
2989 --i; --e;
2990 }
2991
2992 // If we found some loop invariants, fold them into the recurrence.
2993 if (!LIOps.empty()) {
2994 // Compute nowrap flags for the addition of the loop-invariant ops and
2995 // the addrec. Temporarily push it as an operand for that purpose. These
2996 // flags are valid in the scope of the addrec only.
2997 LIOps.push_back(AddRec);
2998 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2999 LIOps.pop_back();
3000
3001 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3002 LIOps.push_back(AddRec->getStart());
3003
3004 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3005
3006 // It is not in general safe to propagate flags valid on an add within
3007 // the addrec scope to one outside it. We must prove that the inner
3008 // scope is guaranteed to execute if the outer one does to be able to
3009 // safely propagate. We know the program is undefined if poison is
3010 // produced on the inner scoped addrec. We also know that *for this use*
3011 // the outer scoped add can't overflow (because of the flags we just
3012 // computed for the inner scoped add) without the program being undefined.
3013 // Proving that entry to the outer scope neccesitates entry to the inner
3014 // scope, thus proves the program undefined if the flags would be violated
3015 // in the outer scope.
3016 SCEV::NoWrapFlags AddFlags = Flags;
3017 if (AddFlags != SCEV::FlagAnyWrap) {
3018 auto *DefI = getDefiningScopeBound(LIOps);
3019 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3020 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3021 AddFlags = SCEV::FlagAnyWrap;
3022 }
3023 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3024
3025 // Build the new addrec. Propagate the NUW and NSW flags if both the
3026 // outer add and the inner addrec are guaranteed to have no overflow.
3027 // Always propagate NW.
3028 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3029 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3030
3031 // If all of the other operands were loop invariant, we are done.
3032 if (Ops.size() == 1) return NewRec;
3033
3034 // Otherwise, add the folded AddRec by the non-invariant parts.
3035 for (unsigned i = 0;; ++i)
3036 if (Ops[i] == AddRec) {
3037 Ops[i] = NewRec;
3038 break;
3039 }
3040 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3041 }
3042
3043 // Okay, if there weren't any loop invariants to be folded, check to see if
3044 // there are multiple AddRec's with the same loop induction variable being
3045 // added together. If so, we can fold them.
3046 for (unsigned OtherIdx = Idx+1;
3047 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3048 ++OtherIdx) {
3049 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3050 // so that the 1st found AddRecExpr is dominated by all others.
3051 assert(DT.dominates(
3052 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3053 AddRec->getLoop()->getHeader()) &&
3054 "AddRecExprs are not sorted in reverse dominance order?");
3055 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3056 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3057 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3058 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3059 ++OtherIdx) {
3060 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3061 if (OtherAddRec->getLoop() == AddRecLoop) {
3062 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3063 i != e; ++i) {
3064 if (i >= AddRecOps.size()) {
3065 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3066 break;
3067 }
3068 AddRecOps[i] =
3069 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3071 }
3072 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3073 }
3074 }
3075 // Step size has changed, so we cannot guarantee no self-wraparound.
3076 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3077 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3078 }
3079 }
3080
3081 // Otherwise couldn't fold anything into this recurrence. Move onto the
3082 // next one.
3083 }
3084
3085 // Okay, it looks like we really DO need an add expr. Check to see if we
3086 // already have one, otherwise create a new one.
3087 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3088}
3089
3090const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3091 SCEV::NoWrapFlags Flags) {
3093 ID.AddInteger(scAddExpr);
3094 for (const SCEV *Op : Ops)
3095 ID.AddPointer(Op);
3096 void *IP = nullptr;
3097 SCEVAddExpr *S =
3098 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3099 if (!S) {
3100 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3102 S = new (SCEVAllocator)
3103 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3104 UniqueSCEVs.InsertNode(S, IP);
3105 S->computeAndSetCanonical(*this);
3106 registerUser(S, Ops);
3107 }
3108 S->setNoWrapFlags(Flags);
3109 return S;
3110}
3111
3112const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3113 const Loop *L,
3114 SCEV::NoWrapFlags Flags) {
3115 FoldingSetNodeID ID;
3116 ID.AddInteger(scAddRecExpr);
3117 for (const SCEV *Op : Ops)
3118 ID.AddPointer(Op);
3119 ID.AddPointer(L);
3120 void *IP = nullptr;
3121 SCEVAddRecExpr *S =
3122 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3123 if (!S) {
3124 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3126 S = new (SCEVAllocator)
3127 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3128 UniqueSCEVs.InsertNode(S, IP);
3129 S->computeAndSetCanonical(*this);
3130 LoopUsers[L].push_back(S);
3131 registerUser(S, Ops);
3132 }
3133 setNoWrapFlags(S, Flags);
3134 return S;
3135}
3136
3137const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3138 SCEV::NoWrapFlags Flags) {
3139 FoldingSetNodeID ID;
3140 ID.AddInteger(scMulExpr);
3141 for (const SCEV *Op : Ops)
3142 ID.AddPointer(Op);
3143 void *IP = nullptr;
3144 SCEVMulExpr *S =
3145 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3146 if (!S) {
3147 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3149 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3150 O, Ops.size());
3151 UniqueSCEVs.InsertNode(S, IP);
3152 S->computeAndSetCanonical(*this);
3153 registerUser(S, Ops);
3154 }
3155 S->setNoWrapFlags(Flags);
3156 return S;
3157}
3158
/// Multiply two unsigned 64-bit values, setting \p Overflow if the product
/// wraps. \p Overflow is only ever set, never cleared, so it accumulates
/// across a chain of multiplications.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t Product = i * j;
  // Division-based wrap check. Guarded by j > 1: multiplying by 0 or 1 can
  // never overflow, and the guard also avoids dividing by zero.
  if (j > 1 && Product / j != i)
    Overflow = true;
  return Product;
}
3164
3165/// Compute the result of "n choose k", the binomial coefficient. If an
3166/// intermediate computation overflows, Overflow will be set and the return will
3167/// be garbage. Overflow is not cleared on absence of overflow.
3168static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3169 // We use the multiplicative formula:
3170 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3171 // At each iteration, we take the n-th term of the numeral and divide by the
3172 // (k-n)th term of the denominator. This division will always produce an
3173 // integral result, and helps reduce the chance of overflow in the
3174 // intermediate computations. However, we can still overflow even when the
3175 // final result would fit.
3176
3177 if (n == 0 || n == k) return 1;
3178 if (k > n) return 0;
3179
3180 if (k > n/2)
3181 k = n-k;
3182
3183 uint64_t r = 1;
3184 for (uint64_t i = 1; i <= k; ++i) {
3185 r = umul_ov(r, n-(i-1), Overflow);
3186 r /= i;
3187 }
3188 return r;
3189}
3190
3191/// Determine if any of the operands in this SCEV are a constant or if
3192/// any of the add or multiply expressions in this SCEV contain a constant.
3193static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3194 struct FindConstantInAddMulChain {
3195 bool FoundConstant = false;
3196
3197 bool follow(const SCEV *S) {
3198 FoundConstant |= isa<SCEVConstant>(S);
3199 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3200 }
3201
3202 bool isDone() const {
3203 return FoundConstant;
3204 }
3205 };
3206
3207 FindConstantInAddMulChain F;
3209 ST.visitAll(StartExpr);
3210 return F.FoundConstant;
3211}
3212
/// Get a canonical multiply expression, or something simpler if possible.
///
/// Simplifies the operand list in stages — constant folding, distributing a
/// leading constant over adds/addrecs, inlining nested muls, and folding
/// loop-invariant factors or same-loop addrecs into recurrences — before
/// falling back to a uniqued SCEVMulExpr node.
// NOTE(review): the declarator line of this definition and several interior
// continuation lines appear to be missing from this extract; the surviving
// code is preserved byte-for-byte below, with the gaps marked. TODO: restore
// the missing lines from the original source before compiling.
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  // Fold constants together and canonicalize operand order; 1 is the
  // multiplicative identity, 0 the absorber.
  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
      [](const APInt &C) { return C.isOne(); }, // identity
      [](const APInt &C) { return C.isZero(); }); // absorber
  if (Folded)
    return Folded;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion calls depth.
  // NOTE(review): the guard condition for this early bail-out is missing
  // from this extract.
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      // If any of Add's ops are Adds or Muls with a constant, apply this
      // transformation as well.
      //
      // TODO: There are some cases where this transformation is not
      // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
      // this transformation should be narrowed down.
      const SCEV *Op0, *Op1;
      if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
      // NOTE(review): continuation of this condition is missing from this
      // extract.
        const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
        const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
        return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
      }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          // NOTE(review): the declaration of NewOps is missing from this
          // extract.
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
            // Only rebuild the add if at least one -1 actually folded away.
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<SCEVUse, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
                                          SCEV::FlagAnyWrap, Depth + 1));
          // Let M be the minimum representable signed value. AddRec with nsw
          // multiplied by -1 can have signed overflow if and only if it takes a
          // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
          // maximum signed value. In all other cases signed overflow is
          // impossible.
          auto FlagsMask = SCEV::FlagNW;
          if (AddRec->hasNoSignedWrap()) {
            auto MinInt =
                APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
            if (getSignedRangeMin(AddRec) != MinInt)
              FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(FlagsMask));
        }
      }

      // Try to push the constant operand into a ZExt: C * zext (A + B) ->
      // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
      const SCEVAddExpr *InnerAdd;
      if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
        const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
        if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
            getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
            hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
                     SCEV::FlagNUW)) {
          auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
          return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
        };
      }

      // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
      // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
      // of C1, fold to (D /u (C2 /u C1)).
      const SCEV *D;
      APInt C1V = LHSC->getAPInt();
      // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
      // as -1 * 1, as it won't enable additional folds.
      if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
        C1V = C1V.abs();
      const SCEVConstant *C2;
      if (C1V.isPowerOf2() &&
      // NOTE(review): the match() clause binding D and C2 is missing from
      // this extract.
          C2->getAPInt().isPowerOf2() &&
          C1V.logBase2() <= getMinTrailingZeros(D)) {
        const SCEV *NewMul = nullptr;
        if (C1V.uge(C2->getAPInt())) {
          NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
        } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
          assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
          NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
        }
        if (NewMul)
          // Undo the earlier abs() by negating when C1 was flipped.
          return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
      }
    }
  }

  // Skip over the add expression until we get to a multiply.
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have an mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Mul->operands());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    // NOTE(review): the declaration of LIOps is missing from this extract.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
      // NOTE(review): the declaration of NewOps is missing from this extract.
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);

      // If both the mul and addrec are nuw, we can preserve nuw.
      // If both the mul and addrec are nsw, we can only preserve nsw if either
      // a) they are also nuw, or
      // b) all multiplications of addrec operands with scale are nsw.
      SCEV::NoWrapFlags Flags =
          AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));

      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
                                    SCEV::FlagAnyWrap, Depth + 1));

        if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
          // NOTE(review): the declaration of NSWRegion (a guaranteed-no-wrap
          // ConstantRange for this multiply) is missing from this extract.
              Instruction::Mul, getSignedRange(Scale),
          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
            Flags = clearFlags(Flags, SCEV::FlagNSW);
        }
      }

      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
    //   ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
        dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
        continue;

      // Limit max number of arguments to avoid creation of unreasonably big
      // SCEVAddRecs with very complex operands.
      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
          MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
        continue;

      bool Overflow = false;
      Type *Ty = AddRec->getType();
      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
      SmallVector<SCEVUse, 7> AddRecOps;
      for (int x = 0, xe = AddRec->getNumOperands() +
             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
        // NOTE(review): the declaration of SumOps is missing from this
        // extract.
        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
                 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
               z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
            const SCEV *CoeffTerm = getConstant(Ty, Coeff);
            const SCEV *Term1 = AddRec->getOperand(y-z);
            const SCEV *Term2 = OtherAddRec->getOperand(z);
            SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                        SCEV::FlagAnyWrap, Depth + 1));
          }
        }
        if (SumOps.empty())
          SumOps.push_back(getZero(Ty));
        AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
      }
      if (!Overflow) {
        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
        if (Ops.size() == 2) return NewAddRec;
        Ops[Idx] = NewAddRec;
        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        OpsModified = true;
        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
        if (!AddRec)
          break;
      }
    }
    if (OpsModified)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an mul expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
3516
3517/// Represents an unsigned remainder expression based on unsigned division.
3519 assert(getEffectiveSCEVType(LHS->getType()) ==
3520 getEffectiveSCEVType(RHS->getType()) &&
3521 "SCEVURemExpr operand types don't match!");
3522
3523 // Short-circuit easy cases
3524 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3525 // If constant is one, the result is trivial
3526 if (RHSC->getValue()->isOne())
3527 return getZero(LHS->getType()); // X urem 1 --> 0
3528
3529 // If constant is a power of two, fold into a zext(trunc(LHS)).
3530 if (RHSC->getAPInt().isPowerOf2()) {
3531 Type *FullTy = LHS->getType();
3532 Type *TruncTy =
3533 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3534 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3535 }
3536 }
3537
3538 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3539 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3540 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3541 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3542}
3543
3544/// Get a canonical unsigned division expression, or something simpler if
3545/// possible.
// NOTE(review): the extraction dropped several lines here (the numbering
// gaps mark them) — presumably the signature
// `const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *RHS)`
// at 3546 and the `FoldingSetNodeID ID;` declaration at 3552. Confirm
// against the original source before editing.
3547 assert(!LHS->getType()->isPointerTy() &&
3548 "SCEVUDivExpr operand can't be pointer!");
3549 assert(LHS->getType() == RHS->getType() &&
3550 "SCEVUDivExpr operand types don't match!");
3551
// Profile (udiv LHS, RHS) in the folding set and reuse the cached node if
// this exact expression has been built before.
3553 ID.AddInteger(scUDivExpr);
3554 ID.AddPointer(LHS);
3555 ID.AddPointer(RHS);
3556 void *IP = nullptr;
3557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3558 return S;
3559
3560 // 0 udiv Y == 0
3561 if (match(LHS, m_scev_Zero()))
3562 return LHS;
3563
3564 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3565 if (RHSC->getValue()->isOne())
3566 return LHS; // X udiv 1 --> x
3567 // If the denominator is zero, the result of the udiv is undefined. Don't
3568 // try to analyze it, because the resolution chosen here may differ from
3569 // the resolution chosen in other parts of the compiler.
3570 if (!RHSC->getValue()->isZero()) {
3571 // Determine if the division can be folded into the operands of
3572 // its operands.
3573 // TODO: Generalize this to non-constants by using known-bits information.
3574 Type *Ty = LHS->getType();
3575 unsigned LZ = RHSC->getAPInt().countl_zero();
3576 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3577 // For non-power-of-two values, effectively round the value up to the
3578 // nearest power of two.
3579 if (!RHSC->getAPInt().isPowerOf2())
3580 ++MaxShiftAmt;
// ExtTy is wide enough that the candidate folds below cannot wrap; a fold
// is accepted only if zero-extension to ExtTy distributes over it.
3581 IntegerType *ExtTy =
3582 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3583 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3584 if (const SCEVConstant *Step =
3585 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3586 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3587 const APInt &StepInt = Step->getAPInt();
3588 const APInt &DivInt = RHSC->getAPInt();
3589 if (!StepInt.urem(DivInt) &&
3590 getZeroExtendExpr(AR, ExtTy) ==
3591 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3592 getZeroExtendExpr(Step, ExtTy),
3593 AR->getLoop(), SCEV::FlagAnyWrap)) {
3594 SmallVector<SCEVUse, 4> Operands;
3595 for (const SCEV *Op : AR->operands())
3596 Operands.push_back(getUDivExpr(Op, RHS));
3597 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3598 }
3599 /// Get a canonical UDivExpr for a recurrence.
3600 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3601 const APInt *StartRem;
3602 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3603 m_scev_APInt(StartRem))) {
3604 bool NoWrap =
3605 getZeroExtendExpr(AR, ExtTy) ==
3606 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
// NOTE(review): line 3608 (presumably the trailing flag argument of this
// getAddRecExpr call) was lost in extraction.
3607 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3609
3610 // With N <= C and both N, C as powers-of-2, the transformation
3611 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3612 // if wrapping occurs, as the division results remain equivalent for
3613 // all offsets in [[(X - X%N), X).
3614 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3615 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3616 // Only fold if the subtraction can be folded in the start
3617 // expression.
3618 const SCEV *NewStart =
3619 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3620 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3621 !isa<SCEVAddExpr>(NewStart)) {
3622 const SCEV *NewLHS =
3623 getAddRecExpr(NewStart, Step, AR->getLoop(),
3624 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3625 if (LHS != NewLHS) {
3626 LHS = NewLHS;
3627
3628 // Reset the ID to include the new LHS, and check if it is
3629 // already cached.
3630 ID.clear();
3631 ID.AddInteger(scUDivExpr);
3632 ID.AddPointer(LHS);
3633 ID.AddPointer(RHS);
3634 IP = nullptr;
3635 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3636 return S;
3637 }
3638 }
3639 }
3640 }
3641 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3642 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3643 SmallVector<SCEVUse, 4> Operands;
3644 for (const SCEV *Op : M->operands())
3645 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3646 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3647 // Find an operand that's safely divisible.
3648 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3649 const SCEV *Op = M->getOperand(i);
3650 const SCEV *Div = getUDivExpr(Op, RHSC);
// Accept only if the division is exact: multiplying back by RHSC must
// reconstruct the original operand.
3651 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3652 Operands = SmallVector<SCEVUse, 4>(M->operands());
3653 Operands[i] = Div;
3654 return getMulExpr(Operands);
3655 }
3656 }
3657 }
3658
3659 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3660 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3661 if (auto *DivisorConstant =
3662 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3663 bool Overflow = false;
3664 APInt NewRHS =
3665 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
// If B*C overflows, the combined divisor exceeds any value of the type,
// so the quotient is zero.
3666 if (Overflow) {
3667 return getConstant(RHSC->getType(), 0, false);
3668 }
3669 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3670 }
3671 }
3672
3673 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3674 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3675 SmallVector<SCEVUse, 4> Operands;
3676 for (const SCEV *Op : A->operands())
3677 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3678 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3679 Operands.clear();
3680 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3681 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
// Every addend must divide exactly, otherwise abandon the fold.
3682 if (isa<SCEVUDivExpr>(Op) ||
3683 getMulExpr(Op, RHS) != A->getOperand(i))
3684 break;
3685 Operands.push_back(Op);
3686 }
3687 if (Operands.size() == A->getNumOperands())
3688 return getAddExpr(Operands);
3689 }
3690 }
3691
3692 // Fold if both operands are constant.
3693 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3694 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3695 }
3696 }
3697
3698 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3699 const APInt *NegC, *C;
// NOTE(review): lines 3701-3702 (the rest of this match() pattern) were
// lost in extraction.
3700 if (match(LHS,
3703 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3704 return getZero(LHS->getType());
3705
3706 // TODO: Generalize to handle any common factors.
3707 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3708 const SCEV *NewLHS, *NewRHS;
3709 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3710 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3711 return getUDivExpr(NewLHS, NewRHS);
3712
3713 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3714 // changes). Make sure we get a new one.
3715 IP = nullptr;
3716 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3717 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3718 LHS, RHS);
3719 UniqueSCEVs.InsertNode(S, IP);
3720 S->computeAndSetCanonical(*this);
3721 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3722 return S;
3723}
3724
3725APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3726 APInt A = C1->getAPInt().abs();
3727 APInt B = C2->getAPInt().abs();
3728 uint32_t ABW = A.getBitWidth();
3729 uint32_t BBW = B.getBitWidth();
3730
3731 if (ABW > BBW)
3732 B = B.zext(ABW);
3733 else if (ABW < BBW)
3734 A = A.zext(BBW);
3735
3736 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3737}
3738
3739/// Get a canonical unsigned division expression, or something simpler if
3740/// possible. There is no representation for an exact udiv in SCEV IR, but we
3741/// can attempt to remove factors from the LHS and RHS. We can't do this when
3742/// it's not exact because the udiv may be clearing bits.
// NOTE(review): the extraction dropped line 3743 (presumably the signature
// `const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, const
// SCEV *RHS)`) and line 3748 (presumably the declaration of `Mul`, a
// dyn_cast of LHS to SCEVMulExpr). Confirm against the original source.
3744 // TODO: we could try to find factors in all sorts of things, but for now we
3745 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3746 // end of this file for inspiration.
3747
// Only nuw multiplies are safe to strip factors from; anything else falls
// back to an ordinary (non-exact) udiv.
3749 if (!Mul || !Mul->hasNoUnsignedWrap())
3750 return getUDivExpr(LHS, RHS);
3751
3752 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3753 // If the mulexpr multiplies by a constant, then that constant must be the
3754 // first element of the mulexpr.
3755 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3756 if (LHSCst == RHSCst) {
// Identical constant factor on both sides: cancel it entirely.
3757 SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
3758 return getMulExpr(Operands);
3759 }
3760
3761 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3762 // that there's a factor provided by one of the other terms. We need to
3763 // check.
3764 APInt Factor = gcd(LHSCst, RHSCst);
3765 if (!Factor.isIntN(1)) {
3766 LHSCst =
3767 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3768 RHSCst =
3769 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3770 SmallVector<SCEVUse, 2> Operands;
3771 Operands.push_back(LHSCst);
3772 append_range(Operands, Mul->operands().drop_front());
3773 LHS = getMulExpr(Operands);
3774 RHS = RHSCst;
// NOTE(review): line 3775 (presumably re-deriving `Mul` from the new LHS)
// was lost in extraction; the check below relies on it.
3776 if (!Mul)
3777 return getUDivExactExpr(LHS, RHS);
3778 }
3779 }
3780 }
3781
// Cancel a non-constant operand of the multiply that matches RHS exactly.
3782 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3783 if (Mul->getOperand(i) == RHS) {
3784 SmallVector<SCEVUse, 2> Operands;
3785 append_range(Operands, Mul->operands().take_front(i));
3786 append_range(Operands, Mul->operands().drop_front(i + 1));
3787 return getMulExpr(Operands);
3788 }
3789 }
3790
3791 return getUDivExpr(LHS, RHS);
3792}
3793
3794/// Get an add recurrence expression for the specified loop. Simplify the
3795/// expression as much as possible.
// NOTE(review): the extraction dropped line 3796, presumably
// `const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,`.
3797 const Loop *L,
3798 SCEV::NoWrapFlags Flags) {
3799 SmallVector<SCEVUse, 4> Operands;
3800 Operands.push_back(Start);
// If the step is itself a recurrence on the same loop, flatten
// {X,+,{Y,+,Z}} into {X,+,Y,+,Z}; only the NW flag survives the merge.
3801 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3802 if (StepChrec->getLoop() == L) {
3803 append_range(Operands, StepChrec->operands());
3804 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3805 }
3806
3807 Operands.push_back(Step);
3808 return getAddRecExpr(Operands, L, Flags);
3809}
3810
3811/// Get an add recurrence expression for the specified loop. Simplify the
3812/// expression as much as possible.
// NOTE(review): the extraction dropped line 3813 (presumably the signature
// taking a SmallVectorImpl/ArrayRef of operands) and line 3825 (presumably
// an `assert(isAvailableAtLoopEntry(Op, L) &&` matching the message below).
3814 const Loop *L,
3815 SCEV::NoWrapFlags Flags) {
3816 if (Operands.size() == 1) return Operands[0];
3817#ifndef NDEBUG
3818 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3819 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3820 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3821 "SCEVAddRecExpr operand types don't match!");
3822 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3823 }
3824 for (const SCEV *Op : Operands)
3826 "SCEVAddRecExpr operand is not available at loop entry!");
3827#endif
3828
3829 if (Operands.back()->isZero()) {
3830 Operands.pop_back();
3831 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3832 }
3833
3834 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3835 // use that information to infer NUW and NSW flags. However, computing a
3836 // BE count requires calling getAddRecExpr, so we may not yet have a
3837 // meaningful BE count at this point (and if we don't, we'd be stuck
3838 // with a SCEVCouldNotCompute as the cached BE count).
3839
3840 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3841
3842 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3843 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3844 const Loop *NestedLoop = NestedAR->getLoop();
// Swap the nesting if the inner AddRec belongs to a deeper loop we
// contain, or to an unrelated loop dominated by our header.
3845 if (L->contains(NestedLoop)
3846 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3847 : (!NestedLoop->contains(L) &&
3848 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3849 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3850 Operands[0] = NestedAR->getStart();
3851 // AddRecs require their operands be loop-invariant with respect to their
3852 // loops. Don't perform this transformation if it would break this
3853 // requirement.
3854 bool AllInvariant = all_of(
3855 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3856
3857 if (AllInvariant) {
3858 // Create a recurrence for the outer loop with the same step size.
3859 //
3860 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3861 // inner recurrence has the same property.
3862 SCEV::NoWrapFlags OuterFlags =
3863 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3864
3865 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3866 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3867 return isLoopInvariant(Op, NestedLoop);
3868 });
3869
3870 if (AllInvariant) {
3871 // Ok, both add recurrences are valid after the transformation.
3872 //
3873 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3874 // the outer recurrence has the same property.
3875 SCEV::NoWrapFlags InnerFlags =
3876 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3877 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3878 }
3879 }
3880 // Reset Operands to its original state.
3881 Operands[0] = NestedAR;
3882 }
3883 }
3884
3885 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3886 // already have one, otherwise create a new one.
3887 return getOrCreateAddRecExpr(Operands, L, Flags);
3888}
3889
// Translate a GEP operator into a SCEV pointer expression, propagating the
// GEP's no-wrap flags only when provably valid for the whole SCEV scope.
// NOTE(review): line 3890 (presumably the signature
// `const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,`) was lost in
// extraction.
3891 ArrayRef<SCEVUse> IndexExprs) {
3892 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3893 // getSCEV(Base)->getType() has the same address space as Base->getType()
3894 // because SCEV::getType() preserves the address space.
3895 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3896 if (NW != GEPNoWrapFlags::none()) {
3897 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3898 // but to do that, we have to ensure that said flag is valid in the entire
3899 // defined scope of the SCEV.
3900 // TODO: non-instructions have global scope. We might be able to prove
3901 // some global scope cases
3902 auto *GEPI = dyn_cast<Instruction>(GEP);
3903 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3904 NW = GEPNoWrapFlags::none();
3905 }
3906
3907 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3908}
3909
// Build the SCEV for a GEP from its base expression and index expressions:
// accumulate per-index byte offsets (struct field offsets or scaled array
// indices) and add them to the base.
// NOTE(review): the extraction dropped lines 3910 (the start of the
// signature), 3913 (presumably initialization of `OffsetWrap`), 3922
// (presumably the `Offsets` vector declaration), 3942 (presumably the
// array-element-type update for CurTy) and 3965-3966 (presumably the
// non-negativity check and `BaseWrap` selection). Confirm against the
// original source.
3911 ArrayRef<SCEVUse> IndexExprs,
3912 Type *SrcElementTy, GEPNoWrapFlags NW) {
// Map GEP nusw/nuw onto SCEV NSW/NUW for the offset arithmetic.
3914 if (NW.hasNoUnsignedSignedWrap())
3915 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3916 if (NW.hasNoUnsignedWrap())
3917 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3918
3919 Type *CurTy = BaseExpr->getType();
3920 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3921 bool FirstIter = true;
3923 for (SCEVUse IndexExpr : IndexExprs) {
3924 // Compute the (potentially symbolic) offset in bytes for this index.
3925 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3926 // For a struct, add the member offset.
3927 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3928 unsigned FieldNo = Index->getZExtValue();
3929 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3930 Offsets.push_back(FieldOffset);
3931
3932 // Update CurTy to the type of the field at Index.
3933 CurTy = STy->getTypeAtIndex(Index);
3934 } else {
3935 // Update CurTy to its element type.
3936 if (FirstIter) {
3937 assert(isa<PointerType>(CurTy) &&
3938 "The first index of a GEP indexes a pointer");
3939 CurTy = SrcElementTy;
3940 FirstIter = false;
3941 } else {
3943 }
3944 // For an array, add the element offset, explicitly scaled.
3945 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3946 // Getelementptr indices are signed.
3947 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3948
3949 // Multiply the index by the element size to compute the element offset.
3950 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3951 Offsets.push_back(LocalOffset);
3952 }
3953 }
3954
3955 // Handle degenerate case of GEP without offsets.
3956 if (Offsets.empty())
3957 return BaseExpr;
3958
3959 // Add the offsets together, assuming nsw if inbounds.
3960 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3961 // Add the base address and the offset. We cannot use the nsw flag, as the
3962 // base address is unsigned. However, if we know that the offset is
3963 // non-negative, we can use nuw.
3964 bool NUW = NW.hasNoUnsignedWrap() ||
3967 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3968 assert(BaseExpr->getType() == GEPExpr->getType() &&
3969 "GEP should not change type mid-flight.");
3970 return GEPExpr;
3971}
3972
// Look up an already-uniqued SCEV of the given kind/operands in the folding
// set; returns null if none exists (the insert position is discarded).
// NOTE(review): lines 3974-3975 (the remaining parameter(s) and presumably a
// `FoldingSetNodeID ID;` declaration) were lost in extraction.
3973SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3976 ID.AddInteger(SCEVType);
3977 for (const SCEV *Op : Ops)
3978 ID.AddPointer(Op);
3979 void *IP = nullptr;
3980 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3981}
3982
// Overload of the cache lookup above for a different operand container type.
// NOTE(review): lines 3984-3985 (the remaining parameter(s) and presumably a
// `FoldingSetNodeID ID;` declaration) were lost in extraction.
3983SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3986 ID.AddInteger(SCEVType);
3987 for (const SCEV *Op : Ops)
3988 ID.AddPointer(Op);
3989 void *IP = nullptr;
3990 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3991}
3992
// abs(Op) modeled as smax(Op, -Op).
// NOTE(review): line 3994 (presumably deriving `Flags` for the negation
// from IsNSW) was lost in extraction — confirm against the original source.
3993const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3995 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3996}
3997
// Build a canonical (non-sequential) min/max expression of the given Kind:
// constant-fold, flatten same-kind operands, drop redundant/ordered
// operands, then unique the node in the folding set.
// NOTE(review): the extraction dropped the signature (lines 3998-3999),
// the predicate definitions at 4078-4081 (presumably GEPred/LEPred
// ICmp predicates selected by IsSigned), the `FoldingSetNodeID ID;` at
// 4107, and the operand copy into `O` at 4116. Confirm against the
// original source before editing.
4000 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4001 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4002 if (Ops.size() == 1) return Ops[0];
4003#ifndef NDEBUG
4004 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4005 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4006 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4007 "Operand types don't match!");
4008 assert(Ops[0]->getType()->isPointerTy() ==
4009 Ops[i]->getType()->isPointerTy() &&
4010 "min/max should be consistently pointerish");
4011 }
4012#endif
4013
4014 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4015 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4016
// Fold constant operands pairwise and group the rest; identity operands
// are dropped and an absorber collapses the whole expression.
4017 const SCEV *Folded = constantFoldAndGroupOps(
4018 *this, LI, DT, Ops,
4019 [&](const APInt &C1, const APInt &C2) {
4020 switch (Kind) {
4021 case scSMaxExpr:
4022 return APIntOps::smax(C1, C2);
4023 case scSMinExpr:
4024 return APIntOps::smin(C1, C2);
4025 case scUMaxExpr:
4026 return APIntOps::umax(C1, C2);
4027 case scUMinExpr:
4028 return APIntOps::umin(C1, C2);
4029 default:
4030 llvm_unreachable("Unknown SCEV min/max opcode");
4031 }
4032 },
4033 [&](const APInt &C) {
4034 // identity
4035 if (IsMax)
4036 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4037 else
4038 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4039 },
4040 [&](const APInt &C) {
4041 // absorber
4042 if (IsMax)
4043 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4044 else
4045 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4046 });
4047 if (Folded)
4048 return Folded;
4049
4050 // Check if we have created the same expression before.
4051 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4052 return S;
4053 }
4054
4055 // Find the first operation of the same kind
4056 unsigned Idx = 0;
4057 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4058 ++Idx;
4059
4060 // Check to see if one of the operands is of the same kind. If so, expand its
4061 // operands onto our operand list, and recurse to simplify.
4062 if (Idx < Ops.size()) {
4063 bool DeletedAny = false;
4064 while (Ops[Idx]->getSCEVType() == Kind) {
4065 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4066 Ops.erase(Ops.begin()+Idx);
4067 append_range(Ops, SMME->operands());
4068 DeletedAny = true;
4069 }
4070
4071 if (DeletedAny)
4072 return getMinMaxExpr(Kind, Ops);
4073 }
4074
4075 // Okay, check to see if the same value occurs in the operand list twice. If
4076 // so, delete one. Since we sorted the list, these values are required to
4077 // be adjacent.
4082 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4083 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4084 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4085 if (Ops[i] == Ops[i + 1] ||
4086 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4087 // X op Y op Y --> X op Y
4088 // X op Y --> X, if we know X, Y are ordered appropriately
4089 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4090 --i;
4091 --e;
4092 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4093 Ops[i + 1])) {
4094 // X op Y --> Y, if we know X, Y are ordered appropriately
4095 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4096 --i;
4097 --e;
4098 }
4099 }
4100
4101 if (Ops.size() == 1) return Ops[0];
4102
4103 assert(!Ops.empty() && "Reduced smax down to nothing!");
4104
4105 // Okay, it looks like we really DO need an expr. Check to see if we
4106 // already have one, otherwise create a new one.
4108 ID.AddInteger(Kind);
4109 for (const SCEV *Op : Ops)
4110 ID.AddPointer(Op);
4111 void *IP = nullptr;
4112 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4113 if (ExistingSCEV)
4114 return ExistingSCEV;
4115 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4117 SCEV *S = new (SCEVAllocator)
4118 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4119
4120 UniqueSCEVs.InsertNode(S, IP);
4121 S->computeAndSetCanonical(*this);
4122 registerUser(S, Ops);
4123 return S;
4124}
4125
4126namespace {
4127
// Visitor that removes repeated operands from a sequential min/max
// expression tree: the first occurrence of each operand is kept, later
// occurrences are dropped. Only recurses into min/max expressions of the
// root's kind (or its non-sequential variant); every other node kind is
// returned unchanged.
// NOTE(review): the extraction dropped line 4132 (presumably a
// `using Base = SCEVVisitor<...>;` alias used by Base::visit below), line
// 4137 (presumably the `SeenOps` set declaration), line 4146 (the start of
// the assert in visitAnyMinMaxExpr), and line 4162 (presumably the
// sequential-vs-non-sequential condition of the ternary). Confirm against
// the original source.
4128class SCEVSequentialMinMaxDeduplicatingVisitor final
4129 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4130 std::optional<const SCEV *>> {
4131 using RetVal = std::optional<const SCEV *>;
4133
4134 ScalarEvolution &SE;
4135 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4136 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4138
4139 bool canRecurseInto(SCEVTypes Kind) const {
4140 // We can only recurse into the SCEV expression of the same effective type
4141 // as the type of our root SCEV expression.
4142 return RootKind == Kind || NonSequentialRootKind == Kind;
4143 };
4144
4145 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4147 "Only for min/max expressions.");
4148 SCEVTypes Kind = S->getSCEVType();
4149
4150 if (!canRecurseInto(Kind))
4151 return S;
4152
4153 auto *NAry = cast<SCEVNAryExpr>(S);
4154 SmallVector<SCEVUse> NewOps;
4155 bool Changed = visit(Kind, NAry->operands(), NewOps);
4156
4157 if (!Changed)
4158 return S;
// All operands were duplicates of earlier ones; signal removal by
// returning an empty optional.
4159 if (NewOps.empty())
4160 return std::nullopt;
4161
4163 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4164 : SE.getMinMaxExpr(Kind, NewOps);
4165 }
4166
4167 RetVal visit(const SCEV *S) {
4168 // Has the whole operand been seen already?
4169 if (!SeenOps.insert(S).second)
4170 return std::nullopt;
4171 return Base::visit(S);
4172 }
4173
4174public:
4175 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4176 SCEVTypes RootKind)
4177 : SE(SE), RootKind(RootKind),
4178 NonSequentialRootKind(
4179 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4180 RootKind)) {}
4181
// Deduplicate OrigOps into NewOps; returns true (and fills NewOps) only if
// something changed. Note: OrigOps and NewOps may alias the same vector in
// callers; the result is only moved into NewOps at the end.
4182 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4183 SmallVectorImpl<SCEVUse> &NewOps) {
4184 bool Changed = false;
4186 Ops.reserve(OrigOps.size());
4187
4188 for (const SCEV *Op : OrigOps) {
4189 RetVal NewOp = visit(Op);
4190 if (NewOp != Op)
4191 Changed = true;
4192 if (NewOp)
4193 Ops.emplace_back(*NewOp);
4194 }
4195
4196 if (Changed)
4197 NewOps = std::move(Ops);
4198 return Changed;
4199 }
4200
4201 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4202
4203 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4204
4205 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4206
4207 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4208
4209 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4210
4211 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4212
4213 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4214
4215 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4216
4217 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4218
4219 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4220
4221 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4222
4223 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4224 return visitAnyMinMaxExpr(Expr);
4225 }
4226
4227 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4228 return visitAnyMinMaxExpr(Expr);
4229 }
4230
4231 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4232 return visitAnyMinMaxExpr(Expr);
4233 }
4234
4235 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4236 return visitAnyMinMaxExpr(Expr);
4237 }
4238
4239 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4240 return visitAnyMinMaxExpr(Expr);
4241 }
4242
4243 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4244
4245 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4246};
4247
4248} // namespace
4249
// Classify, per SCEV kind, whether a poison operand unconditionally makes
// the whole expression poison.
// NOTE(review): the extraction dropped line 4250 (the function signature)
// and line 4270 (presumably `case scSequentialUMinExpr:`, to which the
// FIXME below refers). Confirm against the original source.
4251 switch (Kind) {
4252 case scConstant:
4253 case scVScale:
4254 case scTruncate:
4255 case scZeroExtend:
4256 case scSignExtend:
4257 case scPtrToAddr:
4258 case scPtrToInt:
4259 case scAddExpr:
4260 case scMulExpr:
4261 case scUDivExpr:
4262 case scAddRecExpr:
4263 case scUMaxExpr:
4264 case scSMaxExpr:
4265 case scUMinExpr:
4266 case scSMinExpr:
4267 case scUnknown:
4268 // If any operand is poison, the whole expression is poison.
4269 return true;
4271 // FIXME: if the *first* operand is poison, the whole expression is poison.
4272 return false; // Pessimistically, say that it does not propagate poison.
4273 case scCouldNotCompute:
4274 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4275 }
4276 llvm_unreachable("Unknown SCEV kind!");
4277}
4278
4279namespace {
4280// The only way poison may be introduced in a SCEV expression is from a
4281// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4282// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4283// introduce poison -- they encode guaranteed, non-speculated knowledge.
4284//
4285// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4286// with the notable exception of umin_seq, where only poison from the first
4287// operand is (unconditionally) propagated.
// Traversal helper for visitAll: records every SCEVUnknown whose underlying
// value is not guaranteed non-poison. When LookThroughMaybePoisonBlocking is
// false, the walk stops at nodes that may block poison propagation.
// NOTE(review): line 4296 (the rest of the `follow` condition — presumably a
// call checking whether S's kind unconditionally propagates poison) was lost
// in extraction.
4288struct SCEVPoisonCollector {
4289 bool LookThroughMaybePoisonBlocking;
4290 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4291 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4292 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4293
4294 bool follow(const SCEV *S) {
4295 if (!LookThroughMaybePoisonBlocking &&
4297 return false;
4298
4299 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4300 if (!isGuaranteedNotToBePoison(SU->getValue()))
4301 MaybePoison.insert(SU);
4302 }
4303 return true;
4304 }
4305 bool isDone() const { return false; }
4306};
4307} // namespace
4308
4309/// Return true if V is poison given that AssumedPoison is already poison.
4310static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4311 // First collect all SCEVs that might result in AssumedPoison to be poison.
4312 // We need to look through potentially poison-blocking operations here,
4313 // because we want to find all SCEVs that *might* result in poison, not only
4314 // those that are *required* to.
4315 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4316 visitAll(AssumedPoison, PC1);
4317
4318 // AssumedPoison is never poison. As the assumption is false, the implication
4319 // is true. Don't bother walking the other SCEV in this case.
4320 if (PC1.MaybePoison.empty())
4321 return true;
4322
4323 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4324 // as well. We cannot look through potentially poison-blocking operations
4325 // here, as their arguments only *may* make the result poison.
4326 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4327 visitAll(S, PC2);
4328
4329 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4330 // it will also make S poison by being part of PC2.MaybePoison.
4331 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4332}
4333
// Collect into Result the IR values underlying the SCEVUnknowns of S that
// may be poison (without looking through poison-blocking operations).
// NOTE(review): line 4334 (the first line of the signature, presumably
// `void ScalarEvolution::getPoisonGeneratingValues(`) was lost in extraction.
4335 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4336 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4337 visitAll(S, PC);
4338 for (const SCEVUnknown *SU : PC.MaybePoison)
4339 Result.insert(SU->getValue());
4340}
4341
// Decide whether instruction I can be reused as the materialization of SCEV
// S without introducing extra poison; instructions whose only extra poison
// comes from poison-generating flags/metadata are collected into
// DropPoisonGeneratingInsts so the caller can strip them.
// NOTE(review): the extraction dropped line 4342 (the start of the
// signature), line 4346 (presumably the `isGuaranteedNotToBePoison(I)`
// condition for the early return), line 4353 (presumably the `PoisonVals`
// set declaration) and line 4357 (presumably the `Visited` set
// declaration). Confirm against the original source.
4343 const SCEV *S, Instruction *I,
4344 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4345 // If the instruction cannot be poison, it's always safe to reuse.
4347 return true;
4348
4349 // Otherwise, it is possible that I is more poisonous that S. Collect the
4350 // poison-contributors of S, and then check whether I has any additional
4351 // poison-contributors. Poison that is contributed through poison-generating
4352 // flags is handled by dropping those flags instead.
4354 getPoisonGeneratingValues(PoisonVals, S);
4355
4356 SmallVector<Value *> Worklist;
4358 Worklist.push_back(I);
4359 while (!Worklist.empty()) {
4360 Value *V = Worklist.pop_back_val();
4361 if (!Visited.insert(V).second)
4362 continue;
4363
4364 // Avoid walking large instruction graphs.
4365 if (Visited.size() > 16)
4366 return false;
4367
4368 // Either the value can't be poison, or the S would also be poison if it
4369 // is.
4370 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4371 continue;
4372
4373 auto *I = dyn_cast<Instruction>(V);
4374 if (!I)
4375 return false;
4376
4377 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4378 // can't replace an arbitrary add with disjoint or, even if we drop the
4379 // flag. We would need to convert the or into an add.
4380 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4381 if (PDI->isDisjoint())
4382 return false;
4383
4384 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4385 // because SCEV currently assumes it can't be poison. Remove this special
4386 // case once we proper model when vscale can be poison.
4387 if (auto *II = dyn_cast<IntrinsicInst>(I);
4388 II && II->getIntrinsicID() == Intrinsic::vscale)
4389 continue;
4390
4391 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4392 return false;
4393
4394 // If the instruction can't create poison, we can recurse to its operands.
4395 if (I->hasPoisonGeneratingAnnotations())
4396 DropPoisonGeneratingInsts.push_back(I);
4397
4398 llvm::append_range(Worklist, I->operands());
4399 }
4400 return true;
4401}
4402
// Build a canonical sequential (short-circuiting) min/max expression, e.g.
// umin_seq. Unlike getMinMaxExpr, operand order is significant and must not
// be sorted; simplifications are limited to deduplication, flattening, and
// locally-provable folds.
// NOTE(review): the extraction dropped lines 4404-4405 (the rest of the
// signature), 4461 (presumably `ICmpInst::Predicate Pred;`), 4463
// (presumably `case scSequentialUMinExpr:`), 4482 (presumably the
// non-sequential kind argument of getMinMaxExpr), 4497 (presumably
// `FoldingSetNodeID ID;`) and 4507 (presumably the operand copy into `O`).
// Confirm against the original source before editing.
4403const SCEV *
4406 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4407 "Not a SCEVSequentialMinMaxExpr!");
4408 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4409 if (Ops.size() == 1)
4410 return Ops[0];
4411#ifndef NDEBUG
4412 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4413 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4414 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4415 "Operand types don't match!");
4416 assert(Ops[0]->getType()->isPointerTy() ==
4417 Ops[i]->getType()->isPointerTy() &&
4418 "min/max should be consistently pointerish");
4419 }
4420#endif
4421
4422 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4423 // so we can *NOT* do any kind of sorting of the expressions!
4424
4425 // Check if we have created the same expression before.
4426 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4427 return S;
4428
4429 // FIXME: there are *some* simplifications that we can do here.
4430
4431 // Keep only the first instance of an operand.
4432 {
4433 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4434 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4435 if (Changed)
4436 return getSequentialMinMaxExpr(Kind, Ops);
4437 }
4438
4439 // Check to see if one of the operands is of the same kind. If so, expand its
4440 // operands onto our operand list, and recurse to simplify.
4441 {
4442 unsigned Idx = 0;
4443 bool DeletedAny = false;
4444 while (Idx < Ops.size()) {
4445 if (Ops[Idx]->getSCEVType() != Kind) {
4446 ++Idx;
4447 continue;
4448 }
4449 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4450 Ops.erase(Ops.begin() + Idx);
4451 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4452 SMME->operands().end());
4453 DeletedAny = true;
4454 }
4455
4456 if (DeletedAny)
4457 return getSequentialMinMaxExpr(Kind, Ops);
4458 }
4459
4460 const SCEV *SaturationPoint;
4462 switch (Kind) {
4464 SaturationPoint = getZero(Ops[0]->getType());
4465 Pred = ICmpInst::ICMP_ULE;
4466 break;
4467 default:
4468 llvm_unreachable("Not a sequential min/max type.");
4469 }
4470
4471 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4472 if (!isGuaranteedNotToCauseUB(Ops[i]))
4473 continue;
4474 // We can replace %x umin_seq %y with %x umin %y if either:
4475 // * %y being poison implies %x is also poison.
4476 // * %x cannot be the saturating value (e.g. zero for umin).
4477 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4478 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4479 SaturationPoint)) {
4480 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4481 Ops[i - 1] = getMinMaxExpr(
4483 SeqOps);
4484 Ops.erase(Ops.begin() + i);
4485 return getSequentialMinMaxExpr(Kind, Ops);
4486 }
4487 // Fold %x umin_seq %y to %x if %x ule %y.
4488 // TODO: We might be able to prove the predicate for a later operand.
4489 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4490 Ops.erase(Ops.begin() + i);
4491 return getSequentialMinMaxExpr(Kind, Ops);
4492 }
4493 }
4494
4495 // Okay, it looks like we really DO need an expr. Check to see if we
4496 // already have one, otherwise create a new one.
4498 ID.AddInteger(Kind);
4499 for (const SCEV *Op : Ops)
4500 ID.AddPointer(Op);
4501 void *IP = nullptr;
4502 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4503 if (ExistingSCEV)
4504 return ExistingSCEV;
4505
4506 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4508 SCEV *S = new (SCEVAllocator)
4509 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4510
4511 UniqueSCEVs.InsertNode(S, IP);
4512 S->computeAndSetCanonical(*this);
4513 registerUser(S, Ops);
4514 return S;
4515}
4516
4521
4525
4530
4534
4539
4543
4545 bool Sequential) {
4546 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4547 return getUMinExpr(Ops, Sequential);
4548}
4549
4555
4556const SCEV *
4558 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
// Scalable sizes: the known-minimum size must be multiplied by vscale.
4559 if (Size.isScalable())
4560 Res = getMulExpr(Res, getVScale(IntTy));
4561 return Res;
4562}
4563
// Alloc size per DataLayout (includes any alignment padding).
4565 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4566}
4567
// Store size per DataLayout (bytes actually touched by a store of StoreTy).
4569 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4570}
4571
4573 StructType *STy,
4574 unsigned FieldNo) {
4575 // We can bypass creating a target-independent constant expression and then
4576 // folding it back into a ConstantInt. This is just a compile-time
4577 // optimization.
// StructLayout holds the precomputed, target-specific field offsets for STy.
4578 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4579 assert(!SL->getSizeInBits().isScalable() &&
4580 "Cannot get offset for structure containing scalable vector types");
4581 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4582}
4583
4585 // Don't attempt to do anything other than create a SCEVUnknown object
4586 // here. createSCEV only calls getUnknown after checking for all other
4587 // interesting possibilities, and any other code that calls getUnknown
4588 // is doing so in order to hide a value from SCEV canonicalization.
4589
// NOTE(review): the FoldingSetNodeID declaration was elided by extraction
// just above this point.
4591 ID.AddInteger(scUnknown);
4592 ID.AddPointer(V);
4593 void *IP = nullptr;
4594 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4595 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4596 "Stale SCEVUnknown in uniquing map!");
4597 return S;
4598 }
// Not found: allocate a fresh SCEVUnknown and thread it onto the
// FirstUnknown list so it can be invalidated later.
4599 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4600 FirstUnknown);
4601 FirstUnknown = cast<SCEVUnknown>(S);
4602 UniqueSCEVs.InsertNode(S, IP);
4603 S->computeAndSetCanonical(*this);
4604 return S;
4605}
4606
4607//===----------------------------------------------------------------------===//
4608// Basic SCEV Analysis and PHI Idiom Recognition Code
4609//
4610
4611/// Test if values of the given type are analyzable within the SCEV
4612/// framework. This primarily includes integer types, and it can optionally
4613/// include pointer types if the ScalarEvolution class has access to
4614/// target-specific information.
4616 // Integers and pointers are always SCEVable.
// Everything else (floating point, vectors, aggregates) is not analyzable.
4617 return Ty->isIntOrPtrTy();
4618}
4619
4620/// Return the size in bits of the specified type, for which isSCEVable must
4621/// return true.
4623 assert(isSCEVable(Ty) && "Type is not SCEVable!");
// NOTE(review): the pointer branch's return statement was elided by
// extraction here; presumably it returns the index-type width for pointers.
4624 if (Ty->isPointerTy())
4626 return getDataLayout().getTypeSizeInBits(Ty);
4627}
4628
4629/// Return a type with the same bitwidth as the given type and which represents
4630/// how SCEV will treat the given type, for which isSCEVable must return
4631/// true. For pointer types, this is the pointer index sized integer type.
4633 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4634
// Integer types are already in SCEV's canonical form.
4635 if (Ty->isIntegerTy())
4636 return Ty;
4637
4638 // The only other supported type is pointer.
4639 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4640 return getDataLayout().getIndexType(Ty);
4641}
4642
// Return the wider of the two types; ties favor T1.
4644 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4645}
4646
4648 const SCEV *B) {
4649 /// For a valid use point to exist, the defining scope of one operand
4650 /// must dominate the other.
// Conservatively answer false if either defining-scope bound is imprecise.
4651 bool PreciseA, PreciseB;
4652 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4653 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4654 if (!PreciseA || !PreciseB)
4655 // Can't tell.
4656 return false;
4657 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4658 DT.dominates(ScopeB, ScopeA);
4659}
4660
// Singleton sentinel owned by this ScalarEvolution instance.
4662 return CouldNotCompute.get();
4663}
4664
4665bool ScalarEvolution::checkValidity(const SCEV *S) const {
4666 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4667 auto *SU = dyn_cast<SCEVUnknown>(S);
4668 return SU && SU->getValue() == nullptr;
4669 });
4670
4671 return !ContainsNulls;
4672}
4673
// Fast path: return the memoized answer if we have seen S before.
4675 HasRecMapType::iterator I = HasRecMap.find(S);
4676 if (I != HasRecMap.end())
4677 return I->second;
4678
// Not cached yet: scan the whole expression once and memoize the result.
4679 bool FoundAddRec =
4680 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4681 HasRecMap.insert({S, FoundAddRec});
4682 return FoundAddRec;
4683}
4684
4685/// Return the ValueOffsetPair set for \p S. \p S can be represented
4686/// by the value and offset from any ValueOffsetPair in the set.
4687ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4688 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4689 if (SI == ExprValueMap.end())
4690 return {};
4691 return SI->second.getArrayRef();
4692}
4693
4694/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4695/// cannot be used separately. eraseValueFromMap should be used to remove
4696/// V from ValueExprMap and ExprValueMap at the same time.
4697void ScalarEvolution::eraseValueFromMap(Value *V) {
4698 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4699 if (I != ValueExprMap.end()) {
4700 auto EVIt = ExprValueMap.find(I->second);
4701 bool Removed = EVIt->second.remove(V);
4702 (void) Removed;
4703 assert(Removed && "Value not in ExprValueMap?");
4704 ValueExprMap.erase(I);
4705 }
4706}
4707
4708void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4709 // A recursive query may have already computed the SCEV. It should be
4710 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4711 // inferred nowrap flags.
4712 auto It = ValueExprMap.find_as(V);
4713 if (It == ValueExprMap.end()) {
4714 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4715 ExprValueMap[S].insert(V);
4716 }
4717}
4718
4719/// Return an existing SCEV if it exists, otherwise analyze the expression and
4720/// create a new one.
4722 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4723
// Reuse a previously computed (and still valid) expression when possible.
4724 if (const SCEV *S = getExistingSCEV(V))
4725 return S;
// Not yet computed: build the SCEV for V.
4726 return createSCEVIter(V);
4727}
4728
4730 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4731
4732 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4733 if (I != ValueExprMap.end()) {
4734 const SCEV *S = I->second;
4735 assert(checkValidity(S) &&
4736 "existing SCEV has not been properly invalidated");
4737 return S;
4738 }
// No SCEV has been computed (or retained) for V yet.
4739 return nullptr;
4740}
4741
4742/// Return a SCEV corresponding to -V = -1*V
4744 SCEV::NoWrapFlags Flags) {
// Constants fold to their negation directly.
4745 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4746 return getConstant(
4747 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4748
// Otherwise represent -V as (-1) * V in the effective (integer) type.
4749 Type *Ty = V->getType();
4750 Ty = getEffectiveSCEVType(Ty);
4751 return getMulExpr(V, getMinusOne(Ty), Flags);
4752}
4753
4754/// If Expr computes ~A, return A else return nullptr
4755static const SCEV *MatchNotExpr(const SCEV *Expr) {
4756 const SCEV *MulOp;
4757 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4758 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4759 return MulOp;
4760 return nullptr;
4761}
4762
4763/// Return a SCEV corresponding to ~V = -1-V
4765 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4766
// Constants fold to their bitwise complement directly.
4767 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4768 return getConstant(
4769 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4770
4771 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4772 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4773 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4774 SmallVector<SCEVUse, 2> MatchedOperands;
// Every operand must itself be a ~x pattern for the fold to apply.
4775 for (const SCEV *Operand : MME->operands()) {
4776 const SCEV *Matched = MatchNotExpr(Operand);
4777 if (!Matched)
4778 return (const SCEV *)nullptr;
4779 MatchedOperands.push_back(Matched);
4780 }
4781 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4782 MatchedOperands);
4783 };
4784 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4785 return Replaced;
4786 }
4787
// Fallback: ~V == -1 - V.
4788 Type *Ty = V->getType();
4789 Ty = getEffectiveSCEVType(Ty);
4790 return getMinusSCEV(getMinusOne(Ty), V);
4791}
4792
4794 assert(P->getType()->isPointerTy());
4795
// Recursively strip the base out of an AddRec's start value.
4796 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4797 // The base of an AddRec is the first operand.
4798 SmallVector<SCEVUse> Ops{AddRec->operands()};
4799 Ops[0] = removePointerBase(Ops[0]);
4800 // Don't try to transfer nowrap flags for now. We could in some cases
4801 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4802 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4803 }
4804 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4805 // The base of an Add is the pointer operand.
4806 SmallVector<SCEVUse> Ops{Add->operands()};
4807 SCEVUse *PtrOp = nullptr;
4808 for (SCEVUse &AddOp : Ops) {
4809 if (AddOp->getType()->isPointerTy()) {
4810 assert(!PtrOp && "Cannot have multiple pointer ops");
4811 PtrOp = &AddOp;
4812 }
4813 }
// Since P itself is pointer-typed, exactly one Add operand must be too.
4814 *PtrOp = removePointerBase(*PtrOp);
4815 // Don't try to transfer nowrap flags for now. We could in some cases
4816 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4817 return getAddExpr(Ops);
4818 }
4819 // Any other expression must be a pointer base.
4820 return getZero(P->getType());
4821}
4822
4824 SCEV::NoWrapFlags Flags,
4825 unsigned Depth) {
4826 // Fast path: X - X --> 0.
4827 if (LHS == RHS)
4828 return getZero(LHS->getType());
4829
4830 // If we subtract two pointers with different pointer bases, bail.
4831 // Eventually, we're going to add an assertion to getMulExpr that we
4832 // can't multiply by a pointer.
4833 if (RHS->getType()->isPointerTy()) {
4834 if (!LHS->getType()->isPointerTy() ||
4835 getPointerBase(LHS) != getPointerBase(RHS))
4836 return getCouldNotCompute();
// Same base: reduce to an integer subtraction of the offsets.
4837 LHS = removePointerBase(LHS);
4838 RHS = removePointerBase(RHS);
4839 }
4840
4841 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4842 // makes it so that we cannot make much use of NUW.
4843 auto AddFlags = SCEV::FlagAnyWrap;
// NOTE(review): the initializer of RHSIsNotMinSigned was elided by
// extraction; upstream it checks that signed-min is excluded from RHS's
// signed range.
4844 const bool RHSIsNotMinSigned =
4846 if (hasFlags(Flags, SCEV::FlagNSW)) {
4847 // Let M be the minimum representable signed value. Then (-1)*RHS
4848 // signed-wraps if and only if RHS is M. That can happen even for
4849 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4850 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4851 // (-1)*RHS, we need to prove that RHS != M.
4852 //
4853 // If LHS is non-negative and we know that LHS - RHS does not
4854 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4855 // either by proving that RHS > M or that LHS >= 0.
4856 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4857 AddFlags = SCEV::FlagNSW;
4858 }
4859 }
4860
4861 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4862 // RHS is NSW and LHS >= 0.
4863 //
4864 // The difficulty here is that the NSW flag may have been proven
4865 // relative to a loop that is to be found in a recurrence in LHS and
4866 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4867 // larger scope than intended.
4868 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4869
4870 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4871}
4872
4874 unsigned Depth) {
// Same width: no cast; wider source: truncate; narrower source: zext.
4875 Type *SrcTy = V->getType();
4876 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4877 "Cannot truncate or zero extend with non-integer arguments!");
4878 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4879 return V; // No conversion
4880 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4881 return getTruncateExpr(V, Ty, Depth);
4882 return getZeroExtendExpr(V, Ty, Depth);
4883}
4884
4886 unsigned Depth) {
// Same width: no cast; wider source: truncate; narrower source: sext.
// NOTE(review): the assert text below says "zero extend" — a copy/paste
// from the zext variant; the behavior here is sign extension.
4887 Type *SrcTy = V->getType();
4888 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4889 "Cannot truncate or zero extend with non-integer arguments!");
4890 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4891 return V; // No conversion
4892 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4893 return getTruncateExpr(V, Ty, Depth);
4894 return getSignExtendExpr(V, Ty, Depth);
4895}
4896
4897const SCEV *
4898 Type *SrcTy = V->getType();
4899 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4901 "Cannot noop or zero extend with non-integer arguments!");
4903 "getNoopOrZeroExtend cannot truncate!");
// Widths must match or grow; truncation is not allowed by this helper.
4904 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4905 return V; // No conversion
4906 return getZeroExtendExpr(V, Ty);
4907}
4908
4909const SCEV *
4911 Type *SrcTy = V->getType();
4912 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4913 "Cannot noop or sign extend with non-integer arguments!");
4915 "getNoopOrSignExtend cannot truncate!");
// Widths must match or grow; truncation is not allowed by this helper.
4916 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4917 return V; // No conversion
4918 return getSignExtendExpr(V, Ty);
4919}
4920
4921const SCEV *
4923 Type *SrcTy = V->getType();
4924 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4925 "Cannot noop or any extend with non-integer arguments!");
4927 "getNoopOrAnyExtend cannot truncate!");
// Widths must match or grow; truncation is not allowed by this helper.
4928 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4929 return V; // No conversion
4930 return getAnyExtendExpr(V, Ty);
4931}
4932
4933const SCEV *
4935 Type *SrcTy = V->getType();
4936 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4937 "Cannot truncate or noop with non-integer arguments!");
4939 "getTruncateOrNoop cannot extend!");
// Widths must match or shrink; extension is not allowed by this helper.
4940 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4941 return V; // No conversion
4942 return getTruncateExpr(V, Ty);
4943}
4944
4946 const SCEV *RHS) {
// Zero-extend the narrower operand, then take the unsigned max.
4947 const SCEV *PromotedLHS = LHS;
4948 const SCEV *PromotedRHS = RHS;
4949
4950 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4951 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4952 else
4953 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4954
4955 return getUMaxExpr(PromotedLHS, PromotedRHS);
4956}
4957
4959 const SCEV *RHS,
4960 bool Sequential) {
// Forward the binary form to the ArrayRef overload.
4961 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4962 return getUMinFromMismatchedTypes(Ops, Sequential);
4963}
4964
4965const SCEV *
4967 bool Sequential) {
// Promote all operands to a common (widest) type, then build the umin.
4968 assert(!Ops.empty() && "At least one operand must be!");
4969 // Trivial case.
4970 if (Ops.size() == 1)
4971 return Ops[0];
4972
4973 // Find the max type first.
4974 Type *MaxType = nullptr;
4975 for (SCEVUse S : Ops)
4976 if (MaxType)
4977 MaxType = getWiderType(MaxType, S->getType());
4978 else
4979 MaxType = S->getType();
4980 assert(MaxType && "Failed to find maximum type!");
4981
4982 // Extend all ops to max type.
4983 SmallVector<SCEVUse, 2> PromotedOps;
4984 for (SCEVUse S : Ops)
4985 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4986
4987 // Generate umin.
4988 return getUMinExpr(PromotedOps, Sequential);
4989}
4990
4992 // A pointer operand may evaluate to a nonpointer expression, such as null.
4993 if (!V->getType()->isPointerTy())
4994 return V;
4995
// Walk into AddRec start values and Add pointer operands until the
// underlying base expression is reached.
4996 while (true) {
4997 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4998 V = AddRec->getStart();
4999 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5000 const SCEV *PtrOp = nullptr;
5001 for (const SCEV *AddOp : Add->operands()) {
5002 if (AddOp->getType()->isPointerTy()) {
5003 assert(!PtrOp && "Cannot have multiple pointer ops");
5004 PtrOp = AddOp;
5005 }
5006 }
5007 assert(PtrOp && "Must have pointer op");
5008 V = PtrOp;
5009 } else // Not something we can look further into.
5010 return V;
5011 }
5012}
5013
5014/// Push users of the given Instruction onto the given Worklist.
5018 // Push the def-use children onto the Worklist stack.
5019 for (User *U : I->users()) {
5020 auto *UserInsn = cast<Instruction>(U);
// Enqueue each user at most once; Visited tracks what is already queued.
5021 if (Visited.insert(UserInsn).second)
5022 Worklist.push_back(UserInsn);
5023 }
5024}
5025
5026namespace {
5027
5028/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5029/// expression in case its Loop is L. If it is not L then
5030/// if IgnoreOtherLoops is true then use AddRec itself
5031/// otherwise rewrite cannot be done.
5032/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5033class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5034public:
5035 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5036 bool IgnoreOtherLoops = true) {
5037 SCEVInitRewriter Rewriter(L, SE);
5038 const SCEV *Result = Rewriter.visit(S);
5039 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5040 return SE.getCouldNotCompute();
5041 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5042 ? SE.getCouldNotCompute()
5043 : Result;
5044 }
5045
5046 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5047 if (!SE.isLoopInvariant(Expr, L))
5048 SeenLoopVariantSCEVUnknown = true;
5049 return Expr;
5050 }
5051
5052 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5053 // Only re-write AddRecExprs for this loop.
5054 if (Expr->getLoop() == L)
5055 return Expr->getStart();
5056 SeenOtherLoops = true;
5057 return Expr;
5058 }
5059
5060 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5061
5062 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5063
5064private:
5065 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5066 : SCEVRewriteVisitor(SE), L(L) {}
5067
5068 const Loop *L;
5069 bool SeenLoopVariantSCEVUnknown = false;
5070 bool SeenOtherLoops = false;
5071};
5072
5073/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5074/// increment expression in case its Loop is L. If it is not L then
5075/// use AddRec itself.
5076/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5077class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5078public:
5079 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5080 SCEVPostIncRewriter Rewriter(L, SE);
5081 const SCEV *Result = Rewriter.visit(S);
5082 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5083 ? SE.getCouldNotCompute()
5084 : Result;
5085 }
5086
5087 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5088 if (!SE.isLoopInvariant(Expr, L))
5089 SeenLoopVariantSCEVUnknown = true;
5090 return Expr;
5091 }
5092
5093 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5094 // Only re-write AddRecExprs for this loop.
5095 if (Expr->getLoop() == L)
5096 return Expr->getPostIncExpr(SE);
5097 SeenOtherLoops = true;
5098 return Expr;
5099 }
5100
5101 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5102
5103 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5104
5105private:
5106 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5107 : SCEVRewriteVisitor(SE), L(L) {}
5108
5109 const Loop *L;
5110 bool SeenLoopVariantSCEVUnknown = false;
5111 bool SeenOtherLoops = false;
5112};
5113
5114/// This class evaluates the compare condition by matching it against the
5115/// condition of loop latch. If there is a match we assume a true value
5116/// for the condition while building SCEV nodes.
5117class SCEVBackedgeConditionFolder
5118 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5119public:
5120 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5121 ScalarEvolution &SE) {
5122 bool IsPosBECond = false;
5123 Value *BECond = nullptr;
5124 if (BasicBlock *Latch = L->getLoopLatch()) {
// NOTE(review): "CondBrInst" appears to be extraction damage; upstream
// this matches a conditional BranchInst — confirm against llvm-project.
5125 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5126 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5127 "Both outgoing branches should not target same header!");
5128 BECond = BI->getCondition();
5129 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5130 } else {
5131 return S;
5132 }
5133 }
5134 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5135 return Rewriter.visit(S);
5136 }
5137
5138 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5139 const SCEV *Result = Expr;
5140 bool InvariantF = SE.isLoopInvariant(Expr, L);
5141
5142 if (!InvariantF) {
// NOTE(review): the declaration of I (the Instruction behind Expr) was
// elided by extraction at this point.
5144 switch (I->getOpcode()) {
5145 case Instruction::Select: {
// Selects whose condition matches the backedge condition collapse to the
// taken arm.
5146 SelectInst *SI = cast<SelectInst>(I);
5147 std::optional<const SCEV *> Res =
5148 compareWithBackedgeCondition(SI->getCondition());
5149 if (Res) {
5150 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5151 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5152 }
5153 break;
5154 }
5155 default: {
5156 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5157 if (Res)
5158 Result = *Res;
5159 break;
5160 }
5161 }
5162 }
5163 return Result;
5164 }
5165
5166private:
5167 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5168 bool IsPosBECond, ScalarEvolution &SE)
5169 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5170 IsPositiveBECond(IsPosBECond) {}
5171
5172 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5173
5174 const Loop *L;
5175 /// Loop back condition.
5176 Value *BackedgeCond = nullptr;
5177 /// Set to true if loop back is on positive branch condition.
5178 bool IsPositiveBECond;
5179};
5180
5181std::optional<const SCEV *>
5182SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5183
5184 // If value matches the backedge condition for loop latch,
5185 // then return a constant evolution node based on loopback
5186 // branch taken.
// NOTE(review): the ternary's false arm (presumably SE.getZero of i1) was
// elided by extraction below.
5187 if (BackedgeCond == IC)
5188 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5190 return std::nullopt;
5191}
5192
5193class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5194public:
5195 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5196 ScalarEvolution &SE) {
5197 SCEVShiftRewriter Rewriter(L, SE);
5198 const SCEV *Result = Rewriter.visit(S);
5199 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5200 }
5201
5202 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5203 // Only allow AddRecExprs for this loop.
5204 if (!SE.isLoopInvariant(Expr, L))
5205 Valid = false;
5206 return Expr;
5207 }
5208
5209 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5210 if (Expr->getLoop() == L && Expr->isAffine())
5211 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5212 Valid = false;
5213 return Expr;
5214 }
5215
5216 bool isValid() { return Valid; }
5217
5218private:
5219 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5220 : SCEVRewriteVisitor(SE), L(L) {}
5221
5222 const Loop *L;
5223 bool Valid = true;
5224};
5225
5226} // end anonymous namespace
5227
5229ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
// NOTE(review): several lines were elided by extraction in this function
// (the Result initialization, the ConstantRange::makeGuaranteedNoWrapRegion
// calls, and the setFlags updates); comments describe the visible logic.
5230 if (!AR->isAffine())
5231 return SCEV::FlagAnyWrap;
5232
5233 using OBO = OverflowingBinaryOperator;
5234
5236
// Self-wrap: the trip count times the step must fit in the value's width.
5237 if (!AR->hasNoSelfWrap()) {
5238 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5239 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5240 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5241 const APInt &BECountAP = BECountMax->getAPInt();
5242 unsigned NoOverflowBitWidth =
5243 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5244 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5246 }
5247 }
5248
// Signed wrap: the AddRec's whole signed range must lie in the region
// guaranteed not to overflow when the step is added.
5249 if (!AR->hasNoSignedWrap()) {
5250 ConstantRange AddRecRange = getSignedRange(AR);
5251 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5252
5254 Instruction::Add, IncRange, OBO::NoSignedWrap);
5255 if (NSWRegion.contains(AddRecRange))
5257 }
5258
// Unsigned wrap: same argument using unsigned ranges.
5259 if (!AR->hasNoUnsignedWrap()) {
5260 ConstantRange AddRecRange = getUnsignedRange(AR);
5261 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5262
5264 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5265 if (NUWRegion.contains(AddRecRange))
5267 }
5268
5269 return Result;
5270}
5271
5273ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
// NOTE(review): the declaration/initializer of Result (and the return type
// line) were elided by extraction here.
5275
5276 if (AR->hasNoSignedWrap())
5277 return Result;
5278
5279 if (!AR->isAffine())
5280 return Result;
5281
5282 // This function can be expensive, only try to prove NSW once per AddRec.
5283 if (!SignedWrapViaInductionTried.insert(AR).second)
5284 return Result;
5285
5286 const SCEV *Step = AR->getStepRecurrence(*this);
5287 const Loop *L = AR->getLoop();
5288
5289 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5290 // Note that this serves two purposes: It filters out loops that are
5291 // simply not analyzable, and it covers the case where this code is
5292 // being called from within backedge-taken count analysis, such that
5293 // attempting to ask for the backedge-taken count would likely result
5294 // in infinite recursion. In the latter case, the analysis code will
5295 // cope with a conservative value, and it will take care to purge
5296 // that value once it has finished.
5297 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5298
5299 // Normally, in the cases we can prove no-overflow via a
5300 // backedge guarding condition, we can also compute a backedge
5301 // taken count for the loop. The exceptions are assumptions and
5302 // guards present in the loop -- SCEV is not great at exploiting
5303 // these to compute max backedge taken counts, but can still use
5304 // these to prove lack of overflow. Use this fact to avoid
5305 // doing extra work that may not pay off.
5306
5307 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5308 AC.assumptions().empty())
5309 return Result;
5310
5311 // If the backedge is guarded by a comparison with the pre-inc value the
5312 // addrec is safe. Also, if the entry is guarded by a comparison with the
5313 // start value and the backedge is guarded by a comparison with the post-inc
5314 // value, the addrec is safe.
// NOTE(review): the ICmpInst::Predicate declaration was elided by
// extraction just above this point.
5316 const SCEV *OverflowLimit =
5317 getSignedOverflowLimitForStep(Step, &Pred, this);
5318 if (OverflowLimit &&
5319 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5320 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5321 Result = setFlags(Result, SCEV::FlagNSW);
5322 }
5323 return Result;
5324}
5326ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
// NOTE(review): the declaration/initializer of Result (and the return type
// line) were elided by extraction here.
5328
5329 if (AR->hasNoUnsignedWrap())
5330 return Result;
5331
5332 if (!AR->isAffine())
5333 return Result;
5334
5335 // This function can be expensive, only try to prove NUW once per AddRec.
5336 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5337 return Result;
5338
5339 const SCEV *Step = AR->getStepRecurrence(*this);
5340 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5341 const Loop *L = AR->getLoop();
5342
5343 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5344 // Note that this serves two purposes: It filters out loops that are
5345 // simply not analyzable, and it covers the case where this code is
5346 // being called from within backedge-taken count analysis, such that
5347 // attempting to ask for the backedge-taken count would likely result
5348 // in infinite recursion. In the latter case, the analysis code will
5349 // cope with a conservative value, and it will take care to purge
5350 // that value once it has finished.
5351 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5352
5353 // Normally, in the cases we can prove no-overflow via a
5354 // backedge guarding condition, we can also compute a backedge
5355 // taken count for the loop. The exceptions are assumptions and
5356 // guards present in the loop -- SCEV is not great at exploiting
5357 // these to compute max backedge taken counts, but can still use
5358 // these to prove lack of overflow. Use this fact to avoid
5359 // doing extra work that may not pay off.
5360
5361 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5362 AC.assumptions().empty())
5363 return Result;
5364
5365 // If the backedge is guarded by a comparison with the pre-inc value the
5366 // addrec is safe. Also, if the entry is guarded by a comparison with the
5367 // start value and the backedge is guarded by a comparison with the post-inc
5368 // value, the addrec is safe.
// NOTE(review): the guard-check call (isLoopBackedgeGuardedByCond with an
// unsigned predicate) was elided by extraction inside this if-body.
5369 if (isKnownPositive(Step)) {
5370 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5371 getUnsignedRangeMax(Step));
5374 Result = setFlags(Result, SCEV::FlagNUW);
5375 }
5376 }
5377
5378 return Result;
5379}
5380
5381namespace {
5382
5383/// Represents an abstract binary operation. This may exist as a
5384/// normal instruction or constant expression, or may have been
5385/// derived from an expression tree.
5386struct BinaryOp {
// Instruction opcode (e.g. Instruction::Add) describing the operation.
5387 unsigned Opcode;
5388 Value *LHS;
5389 Value *RHS;
// No-wrap flags; populated from the source operator when available.
5390 bool IsNSW = false;
5391 bool IsNUW = false;
5392
5393 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5394 /// constant expression.
5395 Operator *Op = nullptr;
5396
// Wrap an existing two-operand Operator, inheriting its nsw/nuw flags.
5397 explicit BinaryOp(Operator *Op)
5398 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5399 Op(Op) {
5400 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5401 IsNSW = OBO->hasNoSignedWrap();
5402 IsNUW = OBO->hasNoUnsignedWrap();
5403 }
5404 }
5405
// Build a synthetic operation not backed by a concrete Operator (Op stays
// null).
5406 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5407 bool IsNUW = false)
5408 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5409};
5410
5411} // end anonymous namespace
5412
5413/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5414static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5415 AssumptionCache &AC,
5416 const DominatorTree &DT,
5417 const Instruction *CxtI) {
5418 auto *Op = dyn_cast<Operator>(V);
5419 if (!Op)
5420 return std::nullopt;
5421
5422 // Implementation detail: all the cleverness here should happen without
5423 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5424 // SCEV expressions when possible, and we should not break that.
5425
5426 switch (Op->getOpcode()) {
// These opcodes map directly to a BinaryOp without any rewriting.
5427 case Instruction::Add:
5428 case Instruction::Sub:
5429 case Instruction::Mul:
5430 case Instruction::UDiv:
5431 case Instruction::URem:
5432 case Instruction::And:
5433 case Instruction::AShr:
5434 case Instruction::Shl:
5435 return BinaryOp(Op);
5436
5437 case Instruction::Or: {
5438 // Convert or disjoint into add nuw nsw.
5439 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5440 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5441 /*IsNSW=*/true, /*IsNUW=*/true);
5442 return BinaryOp(Op);
5443 }
5444
5445 case Instruction::Xor:
5446 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5447 // If the RHS of the xor is a signmask, then this is just an add.
5448 // Instcombine turns add of signmask into xor as a strength reduction step.
5449 if (RHSC->getValue().isSignMask())
5450 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5451 // Binary `xor` is a bit-wise `add`.
5452 if (V->getType()->isIntegerTy(1))
5453 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5454 return BinaryOp(Op);
5455
5456 case Instruction::LShr:
5457 // Turn logical shift right of a constant into an unsigned divide.
5458 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5459 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5460
5461 // If the shift count is not less than the bitwidth, the result of
5462 // the shift is undefined. Don't try to analyze it, because the
5463 // resolution chosen here may differ from the resolution chosen in
5464 // other parts of the compiler.
5465 if (SA->getValue().ult(BitWidth)) {
5466 Constant *X =
5467 ConstantInt::get(SA->getContext(),
5468 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5469 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5470 }
5471 }
5472 return BinaryOp(Op);
5473
// extractvalue of an overflow intrinsic's arithmetic result (index 0).
5474 case Instruction::ExtractValue: {
5475 auto *EVI = cast<ExtractValueInst>(Op);
5476 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5477 break;
5478
5479 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5480 if (!WO)
5481 break;
5482
5483 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5484 bool Signed = WO->isSigned();
5485 // TODO: Should add nuw/nsw flags for mul as well.
5486 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5487 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5488
5489 // Now that we know that all uses of the arithmetic-result component of
5490 // CI are guarded by the overflow check, we can go ahead and pretend
5491 // that the arithmetic is non-overflowing.
5492 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5493 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5494 }
5495
5496 default:
5497 break;
5498 }
5499
5500 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5501 // semantics as a Sub, return a binary sub expression.
5502 if (auto *II = dyn_cast<IntrinsicInst>(V))
5503 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5504 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5505
5506 return std::nullopt;
5507}
5508
5509/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5510/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5511/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5512/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5513/// follows one of the following patterns:
5514/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5515/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5516/// If the SCEV expression of \p Op conforms with one of the expected patterns
5517/// we return the type of the truncation operation, and indicate whether the
5518/// truncated type should be treated as signed/unsigned by setting
5519/// \p Signed to true/false, respectively.
5520static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5521 bool &Signed, ScalarEvolution &SE) {
5522 // The case where Op == SymbolicPHI (that is, with no type conversions on
5523 // the way) is handled by the regular add recurrence creating logic and
5524 // would have already been triggered in createAddRecForPHI. Reaching it here
5525 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5526 // because one of the other operands of the SCEVAddExpr updating this PHI is
5527 // not invariant).
5528 //
5529 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5530 // this case predicates that allow us to prove that Op == SymbolicPHI will
5531 // be added.
5532 if (Op == SymbolicPHI)
5533 return nullptr;
5534
5535 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5536 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5537 if (SourceBits != NewBits)
5538 return nullptr;
5539
5540 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5541 Signed = true;
5542 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5543 }
5544 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5545 Signed = false;
5546 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5547 }
5548 return nullptr;
5549}
5550
5551static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5552 if (!PN->getType()->isIntegerTy())
5553 return nullptr;
5554 const Loop *L = LI.getLoopFor(PN->getParent());
5555 if (!L || L->getHeader() != PN->getParent())
5556 return nullptr;
5557 return L;
5558}
5559
5560// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5561// computation that updates the phi follows the following pattern:
5562// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5563// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5564// If so, try to see if it can be rewritten as an AddRecExpr under some
5565// Predicates. If successful, return them as a pair. Also cache the results
5566// of the analysis.
5567//
5568// Example usage scenario:
5569// Say the Rewriter is called for the following SCEV:
5570// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5571// where:
5572// %X = phi i64 (%Start, %BEValue)
5573// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5574// and call this function with %SymbolicPHI = %X.
5575//
5576// The analysis will find that the value coming around the backedge has
5577// the following SCEV:
5578// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5579// Upon concluding that this matches the desired pattern, the function
5580// will return the pair {NewAddRec, SmallPredsVec} where:
5581// NewAddRec = {%Start,+,%Step}
5582// SmallPredsVec = {P1, P2, P3} as follows:
5583// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5584// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5585// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5586// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5587// under the predicates {P1,P2,P3}.
5588// This predicated rewrite will be cached in PredicatedSCEVRewrites:
// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5590//
5591// TODO's:
5592//
5593// 1) Extend the Induction descriptor to also support inductions that involve
5594// casts: When needed (namely, when we are called in the context of the
5595// vectorizer induction analysis), a Set of cast instructions will be
5596// populated by this method, and provided back to isInductionPHI. This is
5597// needed to allow the vectorizer to properly record them to be ignored by
5598// the cost model and to avoid vectorizing them (otherwise these casts,
5599// which are redundant under the runtime overflow checks, will be
5600// vectorized, which can be costly).
5601//
5602// 2) Support additional induction/PHISCEV patterns: We also want to support
5603// inductions where the sext-trunc / zext-trunc operations (partly) occur
5604// after the induction update operation (the induction increment):
5605//
5606// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5607// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5608//
5609// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5610// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5611//
5612// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5613std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5614ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5616
5617 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5618 // return an AddRec expression under some predicate.
5619
5620 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5621 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5622 assert(L && "Expecting an integer loop header phi");
5623
5624 // The loop may have multiple entrances or multiple exits; we can analyze
5625 // this phi as an addrec if it has a unique entry value and a unique
5626 // backedge value.
5627 Value *BEValueV = nullptr, *StartValueV = nullptr;
5628 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5629 Value *V = PN->getIncomingValue(i);
5630 if (L->contains(PN->getIncomingBlock(i))) {
5631 if (!BEValueV) {
5632 BEValueV = V;
5633 } else if (BEValueV != V) {
5634 BEValueV = nullptr;
5635 break;
5636 }
5637 } else if (!StartValueV) {
5638 StartValueV = V;
5639 } else if (StartValueV != V) {
5640 StartValueV = nullptr;
5641 break;
5642 }
5643 }
5644 if (!BEValueV || !StartValueV)
5645 return std::nullopt;
5646
5647 const SCEV *BEValue = getSCEV(BEValueV);
5648
5649 // If the value coming around the backedge is an add with the symbolic
5650 // value we just inserted, possibly with casts that we can ignore under
5651 // an appropriate runtime guard, then we found a simple induction variable!
5652 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5653 if (!Add)
5654 return std::nullopt;
5655
5656 // If there is a single occurrence of the symbolic value, possibly
5657 // casted, replace it with a recurrence.
5658 unsigned FoundIndex = Add->getNumOperands();
5659 Type *TruncTy = nullptr;
5660 bool Signed;
5661 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5662 if ((TruncTy =
5663 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5664 if (FoundIndex == e) {
5665 FoundIndex = i;
5666 break;
5667 }
5668
5669 if (FoundIndex == Add->getNumOperands())
5670 return std::nullopt;
5671
5672 // Create an add with everything but the specified operand.
5674 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5675 if (i != FoundIndex)
5676 Ops.push_back(Add->getOperand(i));
5677 const SCEV *Accum = getAddExpr(Ops);
5678
5679 // The runtime checks will not be valid if the step amount is
5680 // varying inside the loop.
5681 if (!isLoopInvariant(Accum, L))
5682 return std::nullopt;
5683
5684 // *** Part2: Create the predicates
5685
5686 // Analysis was successful: we have a phi-with-cast pattern for which we
5687 // can return an AddRec expression under the following predicates:
5688 //
5689 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5690 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5691 // P2: An Equal predicate that guarantees that
5692 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5693 // P3: An Equal predicate that guarantees that
5694 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5695 //
5696 // As we next prove, the above predicates guarantee that:
5697 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5698 //
5699 //
5700 // More formally, we want to prove that:
5701 // Expr(i+1) = Start + (i+1) * Accum
5702 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5703 //
5704 // Given that:
5705 // 1) Expr(0) = Start
5706 // 2) Expr(1) = Start + Accum
5707 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5708 // 3) Induction hypothesis (step i):
5709 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5710 //
5711 // Proof:
5712 // Expr(i+1) =
5713 // = Start + (i+1)*Accum
5714 // = (Start + i*Accum) + Accum
5715 // = Expr(i) + Accum
5716 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5717 // :: from step i
5718 //
5719 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5720 //
5721 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5722 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5723 // + Accum :: from P3
5724 //
5725 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5726 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5727 //
5728 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5729 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5730 //
5731 // By induction, the same applies to all iterations 1<=i<n:
5732 //
5733
5734 // Create a truncated addrec for which we will add a no overflow check (P1).
5735 const SCEV *StartVal = getSCEV(StartValueV);
5736 const SCEV *PHISCEV =
5737 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5738 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5739
5740 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5741 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5742 // will be constant.
5743 //
5744 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5745 // add P1.
5746 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5750 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5751 Predicates.push_back(AddRecPred);
5752 }
5753
5754 // Create the Equal Predicates P2,P3:
5755
5756 // It is possible that the predicates P2 and/or P3 are computable at
5757 // compile time due to StartVal and/or Accum being constants.
5758 // If either one is, then we can check that now and escape if either P2
5759 // or P3 is false.
5760
5761 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5762 // for each of StartVal and Accum
5763 auto getExtendedExpr = [&](const SCEV *Expr,
5764 bool CreateSignExtend) -> const SCEV * {
5765 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5766 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5767 const SCEV *ExtendedExpr =
5768 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5769 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5770 return ExtendedExpr;
5771 };
5772
5773 // Given:
5774 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5775 // = getExtendedExpr(Expr)
5776 // Determine whether the predicate P: Expr == ExtendedExpr
5777 // is known to be false at compile time
5778 auto PredIsKnownFalse = [&](const SCEV *Expr,
5779 const SCEV *ExtendedExpr) -> bool {
5780 return Expr != ExtendedExpr &&
5781 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5782 };
5783
5784 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5785 if (PredIsKnownFalse(StartVal, StartExtended)) {
5786 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5787 return std::nullopt;
5788 }
5789
5790 // The Step is always Signed (because the overflow checks are either
5791 // NSSW or NUSW)
5792 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5793 if (PredIsKnownFalse(Accum, AccumExtended)) {
5794 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5795 return std::nullopt;
5796 }
5797
5798 auto AppendPredicate = [&](const SCEV *Expr,
5799 const SCEV *ExtendedExpr) -> void {
5800 if (Expr != ExtendedExpr &&
5801 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5802 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5803 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5804 Predicates.push_back(Pred);
5805 }
5806 };
5807
5808 AppendPredicate(StartVal, StartExtended);
5809 AppendPredicate(Accum, AccumExtended);
5810
5811 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5812 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5813 // into NewAR if it will also add the runtime overflow checks specified in
5814 // Predicates.
5815 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5816
5817 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5818 std::make_pair(NewAR, Predicates);
5819 // Remember the result of the analysis for this SCEV at this locayyytion.
5820 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5821 return PredRewrite;
5822}
5823
5824std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5826 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5827 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5828 if (!L)
5829 return std::nullopt;
5830
5831 // Check to see if we already analyzed this PHI.
5832 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5833 if (I != PredicatedSCEVRewrites.end()) {
5834 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5835 I->second;
5836 // Analysis was done before and failed to create an AddRec:
5837 if (Rewrite.first == SymbolicPHI)
5838 return std::nullopt;
5839 // Analysis was done before and succeeded to create an AddRec under
5840 // a predicate:
5841 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5842 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5843 return Rewrite;
5844 }
5845
5846 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5847 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5848
5849 // Record in the cache that the analysis failed
5850 if (!Rewrite) {
5852 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5853 return std::nullopt;
5854 }
5855
5856 return Rewrite;
5857}
5858
5859// FIXME: This utility is currently required because the Rewriter currently
5860// does not rewrite this expression:
5861// {0, +, (sext ix (trunc iy to ix) to iy)}
5862// into {0, +, %step},
5863// even when the following Equal predicate exists:
5864// "%step == (sext ix (trunc iy to ix) to iy)".
5866 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5867 if (AR1 == AR2)
5868 return true;
5869
5870 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5871 if (Expr1 != Expr2 &&
5872 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5873 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5874 return false;
5875 return true;
5876 };
5877
5878 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5879 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5880 return false;
5881 return true;
5882}
5883
5884/// A helper function for createAddRecFromPHI to handle simple cases.
5885///
5886/// This function tries to find an AddRec expression for the simplest (yet most
5887/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5888/// If it fails, createAddRecFromPHI will use a more general, but slow,
5889/// technique for finding the AddRec expression.
5890const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5891 Value *BEValueV,
5892 Value *StartValueV) {
5893 const Loop *L = LI.getLoopFor(PN->getParent());
5894 assert(L && L->getHeader() == PN->getParent());
5895 assert(BEValueV && StartValueV);
5896
5897 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5898 if (!BO)
5899 return nullptr;
5900
5901 if (BO->Opcode != Instruction::Add)
5902 return nullptr;
5903
5904 const SCEV *Accum = nullptr;
5905 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5906 Accum = getSCEV(BO->RHS);
5907 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5908 Accum = getSCEV(BO->LHS);
5909
5910 if (!Accum)
5911 return nullptr;
5912
5914 if (BO->IsNUW)
5915 Flags = setFlags(Flags, SCEV::FlagNUW);
5916 if (BO->IsNSW)
5917 Flags = setFlags(Flags, SCEV::FlagNSW);
5918
5919 const SCEV *StartVal = getSCEV(StartValueV);
5920 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5921 insertValueToMap(PN, PHISCEV);
5922
5923 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5924 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5925 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
5926 }
5927
5928 // We can add Flags to the post-inc expression only if we
5929 // know that it is *undefined behavior* for BEValueV to
5930 // overflow.
5931 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5932 assert(isLoopInvariant(Accum, L) &&
5933 "Accum is defined outside L, but is not invariant?");
5934 if (isAddRecNeverPoison(BEInst, L))
5935 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5936 }
5937
5938 return PHISCEV;
5939}
5940
5941const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5942 const Loop *L = LI.getLoopFor(PN->getParent());
5943 if (!L || L->getHeader() != PN->getParent())
5944 return nullptr;
5945
5946 // The loop may have multiple entrances or multiple exits; we can analyze
5947 // this phi as an addrec if it has a unique entry value and a unique
5948 // backedge value.
5949 Value *BEValueV = nullptr, *StartValueV = nullptr;
5950 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5951 Value *V = PN->getIncomingValue(i);
5952 if (L->contains(PN->getIncomingBlock(i))) {
5953 if (!BEValueV) {
5954 BEValueV = V;
5955 } else if (BEValueV != V) {
5956 BEValueV = nullptr;
5957 break;
5958 }
5959 } else if (!StartValueV) {
5960 StartValueV = V;
5961 } else if (StartValueV != V) {
5962 StartValueV = nullptr;
5963 break;
5964 }
5965 }
5966 if (!BEValueV || !StartValueV)
5967 return nullptr;
5968
5969 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5970 "PHI node already processed?");
5971
5972 // First, try to find AddRec expression without creating a fictituos symbolic
5973 // value for PN.
5974 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5975 return S;
5976
5977 // Handle PHI node value symbolically.
5978 const SCEV *SymbolicName = getUnknown(PN);
5979 insertValueToMap(PN, SymbolicName);
5980
5981 // Using this symbolic name for the PHI, analyze the value coming around
5982 // the back-edge.
5983 const SCEV *BEValue = getSCEV(BEValueV);
5984
5985 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5986 // has a special value for the first iteration of the loop.
5987
5988 // If the value coming around the backedge is an add with the symbolic
5989 // value we just inserted, then we found a simple induction variable!
5990 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5991 // If there is a single occurrence of the symbolic value, replace it
5992 // with a recurrence.
5993 unsigned FoundIndex = Add->getNumOperands();
5994 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5995 if (Add->getOperand(i) == SymbolicName)
5996 if (FoundIndex == e) {
5997 FoundIndex = i;
5998 break;
5999 }
6000
6001 if (FoundIndex != Add->getNumOperands()) {
6002 // Create an add with everything but the specified operand.
6004 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6005 if (i != FoundIndex)
6006 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6007 L, *this));
6008 const SCEV *Accum = getAddExpr(Ops);
6009
6010 // This is not a valid addrec if the step amount is varying each
6011 // loop iteration, but is not itself an addrec in this loop.
6012 if (isLoopInvariant(Accum, L) ||
6013 (isa<SCEVAddRecExpr>(Accum) &&
6014 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6016
6017 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6018 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6019 if (BO->IsNUW)
6020 Flags = setFlags(Flags, SCEV::FlagNUW);
6021 if (BO->IsNSW)
6022 Flags = setFlags(Flags, SCEV::FlagNSW);
6023 }
6024 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6025 if (GEP->getOperand(0) == PN) {
6026 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6027 // If the increment has any nowrap flags, then we know the address
6028 // space cannot be wrapped around.
6029 if (NW != GEPNoWrapFlags::none())
6030 Flags = setFlags(Flags, SCEV::FlagNW);
6031 // If the GEP is nuw or nusw with non-negative offset, we know that
6032 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6033 // offset is treated as signed, while the base is unsigned.
6034 if (NW.hasNoUnsignedWrap() ||
6036 Flags = setFlags(Flags, SCEV::FlagNUW);
6037 }
6038
6039 // We cannot transfer nuw and nsw flags from subtraction
6040 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6041 // for instance.
6042 }
6043
6044 const SCEV *StartVal = getSCEV(StartValueV);
6045 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6046
6047 // Okay, for the entire analysis of this edge we assumed the PHI
6048 // to be symbolic. We now need to go back and purge all of the
6049 // entries for the scalars that use the symbolic expression.
6050 forgetMemoizedResults({SymbolicName});
6051 insertValueToMap(PN, PHISCEV);
6052
6053 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
6055 const_cast<SCEVAddRecExpr *>(AR),
6056 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
6057 }
6058
6059 // We can add Flags to the post-inc expression only if we
6060 // know that it is *undefined behavior* for BEValueV to
6061 // overflow.
6062 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6063 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6064 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6065
6066 return PHISCEV;
6067 }
6068 }
6069 } else {
6070 // Otherwise, this could be a loop like this:
6071 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6072 // In this case, j = {1,+,1} and BEValue is j.
6073 // Because the other in-value of i (0) fits the evolution of BEValue
6074 // i really is an addrec evolution.
6075 //
6076 // We can generalize this saying that i is the shifted value of BEValue
6077 // by one iteration:
6078 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6079
6080 // Do not allow refinement in rewriting of BEValue.
6081 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6082 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6083 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6084 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6085 const SCEV *StartVal = getSCEV(StartValueV);
6086 if (Start == StartVal) {
6087 // Okay, for the entire analysis of this edge we assumed the PHI
6088 // to be symbolic. We now need to go back and purge all of the
6089 // entries for the scalars that use the symbolic expression.
6090 forgetMemoizedResults({SymbolicName});
6091 insertValueToMap(PN, Shifted);
6092 return Shifted;
6093 }
6094 }
6095 }
6096
6097 // Remove the temporary PHI node SCEV that has been inserted while intending
6098 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6099 // as it will prevent later (possibly simpler) SCEV expressions to be added
6100 // to the ValueExprMap.
6101 eraseValueFromMap(PN);
6102
6103 return nullptr;
6104}
6105
6106// Try to match a control flow sequence that branches out at BI and merges back
6107// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6108// match.
6110 Value *&C, Value *&LHS, Value *&RHS) {
6111 C = BI->getCondition();
6112
6113 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6114 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6115
6116 Use &LeftUse = Merge->getOperandUse(0);
6117 Use &RightUse = Merge->getOperandUse(1);
6118
6119 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6120 LHS = LeftUse;
6121 RHS = RightUse;
6122 return true;
6123 }
6124
6125 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6126 LHS = RightUse;
6127 RHS = LeftUse;
6128 return true;
6129 }
6130
6131 return false;
6132}
6133
6135 Value *&Cond, Value *&LHS,
6136 Value *&RHS) {
6137 auto IsReachable =
6138 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6139 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6140 // Try to match
6141 //
6142 // br %cond, label %left, label %right
6143 // left:
6144 // br label %merge
6145 // right:
6146 // br label %merge
6147 // merge:
6148 // V = phi [ %x, %left ], [ %y, %right ]
6149 //
6150 // as "select %cond, %x, %y"
6151
6152 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6153 assert(IDom && "At least the entry block should dominate PN");
6154
6155 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6156 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6157 }
6158 return false;
6159}
6160
6161const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6162 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6163 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6166 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6167
6168 return nullptr;
6169}
6170
6172 BinaryOperator *CommonInst = nullptr;
6173 // Check if instructions are identical.
6174 for (Value *Incoming : PN->incoming_values()) {
6175 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6176 if (!IncomingInst)
6177 return nullptr;
6178 if (CommonInst) {
6179 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6180 return nullptr; // Not identical, give up
6181 } else {
6182 // Remember binary operator
6183 CommonInst = IncomingInst;
6184 }
6185 }
6186 return CommonInst;
6187}
6188
6189/// Returns SCEV for the first operand of a phi if all phi operands have
6190/// identical opcodes and operands
6191/// eg.
6192/// a: %add = %a + %b
6193/// br %c
6194/// b: %add1 = %a + %b
6195/// br %c
6196/// c: %phi = phi [%add, a], [%add1, b]
6197/// scev(%phi) => scev(%add)
6198const SCEV *
6199ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6200 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6201 if (!CommonInst)
6202 return nullptr;
6203
6204 // Check if SCEV exprs for instructions are identical.
6205 const SCEV *CommonSCEV = getSCEV(CommonInst);
6206 bool SCEVExprsIdentical =
6208 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6209 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6210}
6211
6212const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6213 if (const SCEV *S = createAddRecFromPHI(PN))
6214 return S;
6215
6216 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6217 // phi node for X.
6218 if (Value *V = simplifyInstruction(
6219 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6220 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6221 return getSCEV(V);
6222
6223 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6224 return S;
6225
6226 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6227 return S;
6228
6229 // If it's not a loop phi, we can't handle it yet.
6230 return getUnknown(PN);
6231}
6232
6233bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6234 SCEVTypes RootKind) {
6235 struct FindClosure {
6236 const SCEV *OperandToFind;
6237 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6238 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6239
6240 bool Found = false;
6241
6242 bool canRecurseInto(SCEVTypes Kind) const {
6243 // We can only recurse into the SCEV expression of the same effective type
6244 // as the type of our root SCEV expression, and into zero-extensions.
6245 return RootKind == Kind || NonSequentialRootKind == Kind ||
6246 scZeroExtend == Kind;
6247 };
6248
6249 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6250 : OperandToFind(OperandToFind), RootKind(RootKind),
6251 NonSequentialRootKind(
6253 RootKind)) {}
6254
6255 bool follow(const SCEV *S) {
6256 Found = S == OperandToFind;
6257
6258 return !isDone() && canRecurseInto(S->getSCEVType());
6259 }
6260
6261 bool isDone() const { return Found; }
6262 };
6263
6264 FindClosure FC(OperandToFind, RootKind);
6265 visitAll(Root, FC);
6266 return FC.Found;
6267}
6268
/// Try to build a SCEV for a select (or select-shaped PHI) whose condition is
/// an integer compare.  Recognized shapes, per the inline comments below:
///   - a > b ? a+x : b+x  ->  max(a, b)+x  (and the min / signed variants);
///   - x == 0 ? C+y : x+y ->  umax(x, C)+y  when C u<= 1;
///   - x == 0 ? 0 : umin(..., x, ...) -> umin_seq(x, umin(...)).
/// \p Ty is the type of the resulting expression, \p TrueVal / \p FalseVal
/// the two hands of the select.  Returns std::nullopt when no pattern
/// applies so the caller can fall back to other modelling strategies.
/// NOTE(review): this listing is missing a few physical lines (the embedded
/// line numbers jump over 6293, 6310-6311, 6322, 6342-6343 and 6356);
/// compare against the upstream sources before editing this function.
6269std::optional<const SCEV *>
6270ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6271 ICmpInst *Cond,
6272 Value *TrueVal,
6273 Value *FalseVal) {
6274 // Try to match some simple smax or umax patterns.
6275 auto *ICI = Cond;
6276
6277 Value *LHS = ICI->getOperand(0);
6278 Value *RHS = ICI->getOperand(1);
6279
6280 switch (ICI->getPredicate()) {
6281 case ICmpInst::ICMP_SLT:
6282 case ICmpInst::ICMP_SLE:
6283 case ICmpInst::ICMP_ULT:
6284 case ICmpInst::ICMP_ULE:
 // Canonicalize "less than" forms into the "greater than" handling below by
 // swapping the compared operands.
6285 std::swap(LHS, RHS);
6286 [[fallthrough]];
6287 case ICmpInst::ICMP_SGT:
6288 case ICmpInst::ICMP_SGE:
6289 case ICmpInst::ICMP_UGT:
6290 case ICmpInst::ICMP_UGE:
6291 // a > b ? a+x : b+x -> max(a, b)+x
6292 // a > b ? b+x : a+x -> min(a, b)+x
6294 bool Signed = ICI->isSigned();
6295 const SCEV *LA = getSCEV(TrueVal);
6296 const SCEV *RA = getSCEV(FalseVal);
6297 const SCEV *LS = getSCEV(LHS);
6298 const SCEV *RS = getSCEV(RHS);
6299 if (LA->getType()->isPointerTy()) {
6300 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6301 // Need to make sure we can't produce weird expressions involving
6302 // negated pointers.
6303 if (LA == LS && RA == RS)
6304 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6305 if (LA == RS && RA == LS)
6306 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6307 }
 // Widen the compared operands (sign- or zero-extending as appropriate for
 // the predicate) to the result type before forming min/max.
6308 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6309 if (Op->getType()->isPointerTy()) {
 // NOTE(review): lines 6310-6311 are absent from this listing; the
 // pointer-coercion step preceding this early return is not visible here.
6312 return Op;
6313 }
6314 if (Signed)
6315 Op = getNoopOrSignExtend(Op, Ty);
6316 else
6317 Op = getNoopOrZeroExtend(Op, Ty);
6318 return Op;
6319 };
6320 LS = CoerceOperand(LS);
6321 RS = CoerceOperand(RS);
 // NOTE(review): the guard condition for this break (line 6322) is missing
 // from this listing.
6323 break;
6324 const SCEV *LDiff = getMinusSCEV(LA, LS);
6325 const SCEV *RDiff = getMinusSCEV(RA, RS);
6326 if (LDiff == RDiff)
6327 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6328 LDiff);
6329 LDiff = getMinusSCEV(LA, RS);
6330 RDiff = getMinusSCEV(RA, LS);
6331 if (LDiff == RDiff)
6332 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6333 LDiff);
6334 }
6335 break;
6336 case ICmpInst::ICMP_NE:
6337 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6338 std::swap(TrueVal, FalseVal);
6339 [[fallthrough]];
6340 case ICmpInst::ICMP_EQ:
6341 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6344 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6345 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6346 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6347 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6348 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6349 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6350 return getAddExpr(getUMaxExpr(X, C), Y);
6351 }
6352 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6353 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6354 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6355 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6357 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
 // Strip zero-extensions so X is compared at its narrowest width.
6358 const SCEV *X = getSCEV(LHS);
6359 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6360 X = ZExt->getOperand();
6361 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6362 const SCEV *FalseValExpr = getSCEV(FalseVal);
6363 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6364 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6365 /*Sequential=*/true);
6366 }
6367 }
6368 break;
6369 default:
6370 break;
6371 }
6372
6373 return std::nullopt;
6374}
6375
/// Model an i1 select where at least one hand is a constant as
///   C + (umin_seq cond-or-~cond, X - C)
/// per the derivation in the comments below.  Returns std::nullopt when
/// neither hand is a SCEVConstant.
/// NOTE(review): the first line of the parameter list (original line 6377,
/// which should name \p SE and \p CondExpr) is missing from this listing.
6376static std::optional<const SCEV *>
6378 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6379 assert(CondExpr->getType()->isIntegerTy(1) &&
6380 TrueExpr->getType() == FalseExpr->getType() &&
6381 TrueExpr->getType()->isIntegerTy(1) &&
6382 "Unexpected operands of a select.")
6383
6384 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6385 // --> C + (umin_seq cond, x - C)
6386 //
6387 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6388 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6389 // --> C + (umin_seq ~cond, x - C)
6390
6391 // FIXME: while we can't legally model the case where both of the hands
6392 // are fully variable, we only require that the *difference* is constant.
6393 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6394 return std::nullopt;
6395
 // Pick X (the variable hand) and C (the constant hand); when the constant
 // is the true hand, negate the condition so the umin_seq form still applies.
6396 const SCEV *X, *C;
6397 if (isa<SCEVConstant>(TrueExpr)) {
6398 CondExpr = SE->getNotSCEV(CondExpr);
6399 X = FalseExpr;
6400 C = TrueExpr;
6401 } else {
6402 X = TrueExpr;
6403 C = FalseExpr;
6404 }
6405 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6406 /*Sequential=*/true));
6407}
6408
/// Value-level wrapper: bail out early unless at least one select hand is a
/// ConstantInt, then translate all three operands to SCEV and delegate to the
/// SCEV-level overload above.
/// NOTE(review): the first line of the parameter list (original line 6410,
/// which should name \p SE, \p Cond and \p TrueVal) is missing from this
/// listing.
6409static std::optional<const SCEV *>
6411 Value *FalseVal) {
6412 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6413 return std::nullopt;
6414
6415 const auto *SECond = SE->getSCEV(Cond);
6416 const auto *SETrue = SE->getSCEV(TrueVal);
6417 const auto *SEFalse = SE->getSCEV(FalseVal);
6418 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6419}
6420
6421const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6422 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6423 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6424 assert(TrueVal->getType() == FalseVal->getType() &&
6425 V->getType() == TrueVal->getType() &&
6426 "Types of select hands and of the result must match.");
6427
6428 // For now, only deal with i1-typed `select`s.
6429 if (!V->getType()->isIntegerTy(1))
6430 return getUnknown(V);
6431
6432 if (std::optional<const SCEV *> S =
6433 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6434 return *S;
6435
6436 return getUnknown(V);
6437}
6438
6439const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6440 Value *TrueVal,
6441 Value *FalseVal) {
6442 // Handle "constant" branch or select. This can occur for instance when a
6443 // loop pass transforms an inner loop and moves on to process the outer loop.
6444 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6445 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6446
6447 if (auto *I = dyn_cast<Instruction>(V)) {
6448 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6449 if (std::optional<const SCEV *> S =
6450 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6451 TrueVal, FalseVal))
6452 return *S;
6453 }
6454 }
6455
6456 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6457}
6458
6459/// Expand GEP instructions into add and multiply operations. This allows them
6460/// to be analyzed by regular SCEV code.
6461const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6462 assert(GEP->getSourceElementType()->isSized() &&
6463 "GEP source element type must be sized");
6464
6465 SmallVector<SCEVUse, 4> IndexExprs;
6466 for (Value *Index : GEP->indices())
6467 IndexExprs.push_back(getSCEV(Index));
6468 return getGEPExpr(GEP, IndexExprs);
6469}
6470
/// Compute the largest constant that S is known to be a multiple of, as an
/// APInt of S's bit width, dispatching on the SCEV kind.  \p CtxI, when
/// non-null, is the context instruction used for ValueTracking queries on
/// scUnknown.
/// NOTE(review): a few physical lines are missing from this listing (the
/// embedded numbers skip 6476, 6483 and 6545) — notably the other arm of
/// GetShiftedByZeros' ternary, the GCD accumulation call, and what is
/// presumably a scSequentialUMinExpr case label; consult upstream before
/// editing.
6471APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6472 const Instruction *CtxI) {
6473 uint64_t BitWidth = getTypeSizeInBits(S->getType());
 // Turn a trailing-zero count into the corresponding power-of-two multiple.
6474 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6475 return TrailingZeros >= BitWidth
6477 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6478 };
6479 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6480 // The result is GCD of all operands results.
6481 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6482 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6484 Res, getConstantMultiple(N->getOperand(I), CtxI));
6485 return Res;
6486 };
6487
6488 switch (S->getSCEVType()) {
6489 case scConstant:
6490 return cast<SCEVConstant>(S)->getAPInt();
6491 case scPtrToAddr:
6492 case scPtrToInt:
 // Pointer-to-integer casts preserve the multiple of the operand.
6493 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6494 case scUDivExpr:
6495 case scVScale:
 // Nothing better is known: every value is a multiple of 1.
6496 return APInt(BitWidth, 1);
6497 case scTruncate: {
6498 // Only multiples that are a power of 2 will hold after truncation.
6499 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6500 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6501 return GetShiftedByZeros(TZ);
6502 }
6503 case scZeroExtend: {
6504 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6505 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6506 }
6507 case scSignExtend: {
6508 // Only multiples that are a power of 2 will hold after sext.
6509 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6510 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6511 return GetShiftedByZeros(TZ);
6512 }
6513 case scMulExpr: {
6514 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6515 if (M->hasNoUnsignedWrap()) {
6516 // The result is the product of all operand results.
6517 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6518 for (const SCEV *Operand : M->operands().drop_front())
6519 Res = Res * getConstantMultiple(Operand, CtxI);
6520 return Res;
6521 }
6522
6523 // If there are no wrap guarentees, find the trailing zeros, which is the
6524 // sum of trailing zeros for all its operands.
6525 uint32_t TZ = 0;
6526 for (const SCEV *Operand : M->operands())
6527 TZ += getMinTrailingZeros(Operand, CtxI);
6528 return GetShiftedByZeros(TZ);
6529 }
6530 case scAddExpr:
6531 case scAddRecExpr: {
6532 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6533 if (N->hasNoUnsignedWrap())
6534 return GetGCDMultiple(N);
6535 // Find the trailing bits, which is the minimum of its operands.
6536 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6537 for (const SCEV *Operand : N->operands().drop_front())
6538 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6539 return GetShiftedByZeros(TZ);
6540 }
6541 case scUMaxExpr:
6542 case scSMaxExpr:
6543 case scUMinExpr:
6544 case scSMinExpr:
6546 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6547 case scUnknown: {
6548 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6549 // the point their underlying IR instruction has been defined. If CtxI was
6550 // not provided, use:
6551 // * the first instruction in the entry block if it is an argument
6552 // * the instruction itself otherwise.
6553 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6554 if (!CtxI) {
6555 if (isa<Argument>(U->getValue()))
6556 CtxI = &*F.getEntryBlock().begin();
6557 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6558 CtxI = I;
6559 }
6560 unsigned Known =
6561 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6562 .countMinTrailingZeros();
6563 return GetShiftedByZeros(Known);
6564 }
6565 case scCouldNotCompute:
6566 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6567 }
6568 llvm_unreachable("Unknown SCEV kind!");
6569}
6570
/// Cached front-end for getConstantMultipleImpl.  Context-sensitive queries
/// (CtxI != nullptr) bypass the cache entirely, since their result is only
/// valid at that program point.
/// NOTE(review): the signature line (original line 6571) is missing from
/// this listing; only the trailing parameter line is visible.
6572 const Instruction *CtxI) {
6573 // Skip looking up and updating the cache if there is a context instruction,
6574 // as the result will only be valid in the specified context.
6575 if (CtxI)
6576 return getConstantMultipleImpl(S, CtxI);
6577
6578 auto I = ConstantMultipleCache.find(S);
6579 if (I != ConstantMultipleCache.end())
6580 return I->second;
6581
 // Compute once, memoize, and return the cached copy.
6582 APInt Result = getConstantMultipleImpl(S, CtxI);
6583 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6584 assert(InsertPair.second && "Should insert a new key");
6585 return InsertPair.first->second;
6586}
6587
/// Like getConstantMultiple, but maps a result of 0 to 1 so callers can use
/// the value as a divisor (e.g. in urem-based range reasoning).
/// NOTE(review): the signature line (original line 6588) is missing from
/// this listing.
6589 APInt Multiple = getConstantMultiple(S);
6590 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6591}
6592
/// Minimum number of known trailing zero bits of S: the trailing zeros of its
/// constant multiple, clamped to S's bit width.
/// NOTE(review): the signature line (original line 6593) is missing from
/// this listing.
6594 const Instruction *CtxI) {
6595 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6596 (unsigned)getTypeSizeInBits(S->getType()));
6597}
6598
6599/// Helper method to assign a range to V from metadata present in the IR.
/// Sources consulted: !range metadata on instructions, range attributes on
/// call sites, and range attributes on arguments; returns std::nullopt when
/// none are present.
/// NOTE(review): the line opening the instruction branch (original line
/// 6601, presumably a dyn_cast<Instruction> guard) is missing from this
/// listing.
6600static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6602 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6603 return getConstantRangeFromMetadata(*MD);
6604 if (const auto *CB = dyn_cast<CallBase>(V))
6605 if (std::optional<ConstantRange> Range = CB->getRange())
6606 return Range;
6607 }
6608 if (auto *A = dyn_cast<Argument>(V))
6609 if (std::optional<ConstantRange> Range = A->getRange())
6610 return Range;
6611
6612 return std::nullopt;
6613}
6614
/// Strengthen the no-wrap flags of an addrec.  When the flags actually
/// change, the cached ranges and constant-multiple entries for that addrec
/// are invalidated, since stronger flags can yield tighter results.
/// NOTE(review): the signature line (original line 6615, naming the
/// SCEVAddRecExpr parameter) is missing from this listing.
6616 SCEV::NoWrapFlags Flags) {
6617 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6618 AddRec->setNoWrapFlags(Flags);
6619 UnsignedRanges.erase(AddRec);
6620 SignedRanges.erase(AddRec);
6621 ConstantMultipleCache.erase(AddRec);
6622 }
6623}
6624
/// Compute a range for a SCEVUnknown that is a shift recurrence
/// (PHI updated by ashr/lshr/shl of itself), using the loop's constant max
/// trip count to bound the total shift.  Returns the full set whenever any
/// precondition fails.
6625ConstantRange ScalarEvolution::
6626getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6627 const DataLayout &DL = getDataLayout();
6628
6629 unsigned BitWidth = getTypeSizeInBits(U->getType());
6630 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6631
6632 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6633 // use information about the trip count to improve our available range. Note
6634 // that the trip count independent cases are already handled by known bits.
6635 // WARNING: The definition of recurrence used here is subtly different than
6636 // the one used by AddRec (and thus most of this file). Step is allowed to
6637 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6638 // and other addrecs in the same loop (for non-affine addrecs). The code
6639 // below intentionally handles the case where step is not loop invariant.
6640 auto *P = dyn_cast<PHINode>(U->getValue());
6641 if (!P)
6642 return FullSet;
6643
6644 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6645 // even the values that are not available in these blocks may come from them,
6646 // and this leads to false-positive recurrence test.
6647 for (auto *Pred : predecessors(P->getParent()))
6648 if (!DT.isReachableFromEntry(Pred))
6649 return FullSet;
6650
6651 BinaryOperator *BO;
6652 Value *Start, *Step;
6653 if (!matchSimpleRecurrence(P, BO, Start, Step))
6654 return FullSet;
6655
6656 // If we found a recurrence in reachable code, we must be in a loop. Note
6657 // that BO might be in some subloop of L, and that's completely okay.
6658 auto *L = LI.getLoopFor(P->getParent());
6659 assert(L && L->getHeader() == P->getParent());
6660 if (!L->contains(BO->getParent()))
6661 // NOTE: This bailout should be an assert instead. However, asserting
6662 // the condition here exposes a case where LoopFusion is querying SCEV
6663 // with malformed loop information during the midst of the transform.
6664 // There doesn't appear to be an obvious fix, so for the moment bailout
6665 // until the caller issue can be fixed. PR49566 tracks the bug.
6666 return FullSet;
6667
6668 // TODO: Extend to other opcodes such as mul, and div
6669 switch (BO->getOpcode()) {
6670 default:
6671 return FullSet;
6672 case Instruction::AShr:
6673 case Instruction::LShr:
6674 case Instruction::Shl:
6675 break;
6676 };
6677
6678 if (BO->getOperand(0) != P)
6679 // TODO: Handle the power function forms some day.
6680 return FullSet;
6681
 // A trip count of 0 (unknown) or >= BitWidth gives no useful bound: the
 // total shift below could saturate the whole width anyway.
6682 unsigned TC = getSmallConstantMaxTripCount(L);
6683 if (!TC || TC >= BitWidth)
6684 return FullSet;
6685
6686 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6687 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6688 assert(KnownStart.getBitWidth() == BitWidth &&
6689 KnownStep.getBitWidth() == BitWidth);
6690
6691 // Compute total shift amount, being careful of overflow and bitwidths.
6692 auto MaxShiftAmt = KnownStep.getMaxValue();
6693 APInt TCAP(BitWidth, TC-1);
6694 bool Overflow = false;
6695 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6696 if (Overflow)
6697 return FullSet;
6698
6699 switch (BO->getOpcode()) {
6700 default:
6701 llvm_unreachable("filtered out above");
6702 case Instruction::AShr: {
6703 // For each ashr, three cases:
6704 // shift = 0 => unchanged value
6705 // saturation => 0 or -1
6706 // other => a value closer to zero (of the same sign)
6707 // Thus, the end value is closer to zero than the start.
6708 auto KnownEnd = KnownBits::ashr(KnownStart,
6709 KnownBits::makeConstant(TotalShift));
6710 if (KnownStart.isNonNegative())
6711 // Analogous to lshr (simply not yet canonicalized)
6712 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6713 KnownStart.getMaxValue() + 1);
6714 if (KnownStart.isNegative())
6715 // End >=u Start && End <=s Start
6716 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6717 KnownEnd.getMaxValue() + 1);
 // Sign of the start value is unknown: no bound derivable here.
6718 break;
6719 }
6720 case Instruction::LShr: {
6721 // For each lshr, three cases:
6722 // shift = 0 => unchanged value
6723 // saturation => 0
6724 // other => a smaller positive number
6725 // Thus, the low end of the unsigned range is the last value produced.
6726 auto KnownEnd = KnownBits::lshr(KnownStart,
6727 KnownBits::makeConstant(TotalShift));
6728 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6729 KnownStart.getMaxValue() + 1);
6730 }
6731 case Instruction::Shl: {
6732 // Iff no bits are shifted out, value increases on every shift.
6733 auto KnownEnd = KnownBits::shl(KnownStart,
6734 KnownBits::makeConstant(TotalShift));
6735 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6736 return ConstantRange(KnownStart.getMinValue(),
6737 KnownEnd.getMaxValue() + 1);
6738 break;
6739 }
6740 };
6741 return FullSet;
6742}
6743
6744// The goal of this function is to check if recursively visiting the operands
6745// of this PHI might lead to an infinite loop. If we do see such a loop,
6746// there's no good way to break it, so we avoid analyzing such cases.
6747//
6748// getRangeRef previously used a visited set to avoid infinite loops, but this
6749// caused other issues: the result was dependent on the order of getRangeRef
6750// calls, and the interaction with createSCEVIter could cause a stack overflow
6751// in some cases (see issue #148253).
6752//
6753// FIXME: The way this is implemented is overly conservative; this checks
6754// for a few obviously safe patterns, but anything that doesn't lead to
6755// recursion is fine.
// Returns true only for obviously-safe PHIs: the (partially missing) early
// pattern check below, or PHIs all of whose operands dominate the PHI itself
// and therefore cannot depend on it.
// NOTE(review): the function signature (original line 6756) and the early
// check that uses Cond/LHS/RHS (original line 6758) are missing from this
// listing; compare with upstream before editing.
6757 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6759 return true;
6760
6761 if (all_of(PHI->operands(),
6762 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6763 return true;
6764
6765 return false;
6766}
6767
/// Iterative companion of getRangeRef: collect the transitive operands of S
/// (including through PHI-backed SCEVUnknowns that pass
/// RangeRefPHIAllowedOperands) into a worklist, compute ranges for them in
/// reverse discovery order so operands tend to be cached before their users,
/// then answer the query for S itself at depth 0.
/// NOTE(review): one case label appears to be missing from the switch below
/// (the embedded line numbers skip 6804, presumably scSequentialUMinExpr).
6768const ConstantRange &
6769ScalarEvolution::getRangeRefIter(const SCEV *S,
6770 ScalarEvolution::RangeSignHint SignHint) {
6771 DenseMap<const SCEV *, ConstantRange> &Cache =
6772 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6773 : SignedRanges;
6774 SmallVector<SCEVUse> WorkList;
6775 SmallPtrSet<const SCEV *, 8> Seen;
6776
6777 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6778 // SCEVUnknown PHI node.
6779 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6780 if (!Seen.insert(Expr).second)
6781 return;
 // Already-cached expressions need no recomputation.
6782 if (Cache.contains(Expr))
6783 return;
6784 switch (Expr->getSCEVType()) {
6785 case scUnknown:
6786 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6787 break;
6788 [[fallthrough]];
6789 case scConstant:
6790 case scVScale:
6791 case scTruncate:
6792 case scZeroExtend:
6793 case scSignExtend:
6794 case scPtrToAddr:
6795 case scPtrToInt:
6796 case scAddExpr:
6797 case scMulExpr:
6798 case scUDivExpr:
6799 case scAddRecExpr:
6800 case scUMaxExpr:
6801 case scSMaxExpr:
6802 case scUMinExpr:
6803 case scSMinExpr:
6805 WorkList.push_back(Expr);
6806 break;
6807 case scCouldNotCompute:
6808 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6809 }
6810 };
6811 AddToWorklist(S);
6812
6813 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6814 for (unsigned I = 0; I != WorkList.size(); ++I) {
6815 const SCEV *P = WorkList[I];
6816 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6817 // If it is not a `SCEVUnknown`, just recurse into operands.
6818 if (!UnknownS) {
6819 for (const SCEV *Op : P->operands())
6820 AddToWorklist(Op);
6821 continue;
6822 }
6823 // `SCEVUnknown`'s require special treatment.
6824 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
 // Only traverse PHIs whose operand graph is known to be cycle-safe.
6825 if (!RangeRefPHIAllowedOperands(DT, P))
6826 continue;
6827 for (auto &Op : reverse(P->operands()))
6828 AddToWorklist(getSCEV(Op));
6829 }
6830 }
6831
6832 if (!WorkList.empty()) {
6833 // Use getRangeRef to compute ranges for items in the worklist in reverse
6834 // order. This will force ranges for earlier operands to be computed before
6835 // their users in most cases.
6836 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6837 getRangeRef(P, SignHint);
6838 }
6839 }
6840
6841 return getRangeRef(S, SignHint, 0);
6842}
6843
6844/// Determine the range for a particular SCEV. If SignHint is
6845/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6846/// with a "cleaner" unsigned (resp. signed) representation.
/// Results are memoized per sign hint (UnsignedRanges / SignedRanges).
/// NOTE(review): several physical lines are missing from this listing (the
/// embedded line numbers skip 6852, 6854, 6857, 6865-6866, 6887, 6989-6990,
/// 6994, 7003, 7029, 7031, 7050 and 7059) — among them the RangeType
/// declaration, the cache lookup, the depth-threshold guard, parts of the
/// addrec sign bounds, the max-BE-count calls, and the Intrinsic::ID
/// declaration; compare with upstream before editing.
6847const ConstantRange &ScalarEvolution::getRangeRef(
6848 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6849 DenseMap<const SCEV *, ConstantRange> &Cache =
6850 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6851 : SignedRanges;
6853 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6855
6856 // See if we've computed this range already.
6858 if (I != Cache.end())
6859 return I->second;
6860
6861 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6862 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6863
6864 // Switch to iteratively computing the range for S, if it is part of a deeply
6865 // nested expression.
6867 return getRangeRefIter(S, SignHint);
6868
6869 unsigned BitWidth = getTypeSizeInBits(S->getType());
6870 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6871 using OBO = OverflowingBinaryOperator;
6872
6873 // If the value has known zeros, the maximum value will have those known zeros
6874 // as well.
6875 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6876 APInt Multiple = getNonZeroConstantMultiple(S);
6877 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6878 if (!Remainder.isZero())
6879 ConservativeResult =
6880 ConstantRange(APInt::getMinValue(BitWidth),
6881 APInt::getMaxValue(BitWidth) - Remainder + 1);
6882 }
6883 else {
6884 uint32_t TZ = getMinTrailingZeros(S);
6885 if (TZ != 0) {
6886 ConservativeResult = ConstantRange(
6888 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6889 }
6890 }
6891
 // Refine ConservativeResult per SCEV kind; each case intersects with the
 // kind-specific bound and memoizes via setRange.
6892 switch (S->getSCEVType()) {
6893 case scConstant:
6894 llvm_unreachable("Already handled above.");
6895 case scVScale:
6896 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6897 case scTruncate: {
6898 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6899 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6900 return setRange(
6901 Trunc, SignHint,
6902 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6903 }
6904 case scZeroExtend: {
6905 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6906 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6907 return setRange(
6908 ZExt, SignHint,
6909 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6910 }
6911 case scSignExtend: {
6912 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6913 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6914 return setRange(
6915 SExt, SignHint,
6916 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6917 }
6918 case scPtrToAddr:
6919 case scPtrToInt: {
6920 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6921 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6922 return setRange(Cast, SignHint, X);
6923 }
6924 case scAddExpr: {
6925 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6926 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6927 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6928 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6929 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6930 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6931 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6932 ConservativeResult =
6933 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6934 }
 // Fold operand ranges left-to-right, honoring the add's no-wrap flags.
6935 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6936 unsigned WrapType = OBO::AnyWrap;
6937 if (Add->hasNoSignedWrap())
6938 WrapType |= OBO::NoSignedWrap;
6939 if (Add->hasNoUnsignedWrap())
6940 WrapType |= OBO::NoUnsignedWrap;
6941 for (const SCEV *Op : drop_begin(Add->operands()))
6942 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6943 RangeType);
6944 return setRange(Add, SignHint,
6945 ConservativeResult.intersectWith(X, RangeType));
6946 }
6947 case scMulExpr: {
6948 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6949 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6950 for (const SCEV *Op : drop_begin(Mul->operands()))
6951 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6952 return setRange(Mul, SignHint,
6953 ConservativeResult.intersectWith(X, RangeType));
6954 }
6955 case scUDivExpr: {
6956 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6957 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6958 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6959 return setRange(UDiv, SignHint,
6960 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6961 }
6962 case scAddRecExpr: {
6963 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6964 // If there's no unsigned wrap, the value will never be less than its
6965 // initial value.
6966 if (AddRec->hasNoUnsignedWrap()) {
6967 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6968 if (!UnsignedMinValue.isZero())
6969 ConservativeResult = ConservativeResult.intersectWith(
6970 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6971 }
6972
6973 // If there's no signed wrap, and all the operands except initial value have
6974 // the same sign or zero, the value won't ever be:
6975 // 1: smaller than initial value if operands are non negative,
6976 // 2: bigger than initial value if operands are non positive.
6977 // For both cases, value can not cross signed min/max boundary.
6978 if (AddRec->hasNoSignedWrap()) {
6979 bool AllNonNeg = true;
6980 bool AllNonPos = true;
6981 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6982 if (!isKnownNonNegative(AddRec->getOperand(i)))
6983 AllNonNeg = false;
6984 if (!isKnownNonPositive(AddRec->getOperand(i)))
6985 AllNonPos = false;
6986 }
6987 if (AllNonNeg)
6988 ConservativeResult = ConservativeResult.intersectWith(
6991 RangeType);
6992 else if (AllNonPos)
6993 ConservativeResult = ConservativeResult.intersectWith(
6995 getSignedRangeMax(AddRec->getStart()) +
6996 1),
6997 RangeType);
6998 }
6999
7000 // TODO: non-affine addrec
7001 if (AddRec->isAffine()) {
7002 const SCEV *MaxBEScev =
7004 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7005 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7006
7007 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7008 // MaxBECount's active bits are all <= AddRec's bit width.
7009 if (MaxBECount.getBitWidth() > BitWidth &&
7010 MaxBECount.getActiveBits() <= BitWidth)
7011 MaxBECount = MaxBECount.trunc(BitWidth);
7012 else if (MaxBECount.getBitWidth() < BitWidth)
7013 MaxBECount = MaxBECount.zext(BitWidth);
7014
7015 if (MaxBECount.getBitWidth() == BitWidth) {
7016 auto RangeFromAffine = getRangeForAffineAR(
7017 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7018 ConservativeResult =
7019 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7020
7021 auto RangeFromFactoring = getRangeViaFactoring(
7022 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7023 ConservativeResult =
7024 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7025 }
7026 }
7027
7028 // Now try symbolic BE count and more powerful methods.
7030 const SCEV *SymbolicMaxBECount =
7032 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7033 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7034 AddRec->hasNoSelfWrap()) {
7035 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7036 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7037 ConservativeResult =
7038 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7039 }
7040 }
7041 }
7042
7043 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7044 }
7045 case scUMaxExpr:
7046 case scSMaxExpr:
7047 case scUMinExpr:
7048 case scSMinExpr:
7049 case scSequentialUMinExpr: {
 // Map the SCEV min/max kind onto the matching intrinsic so
 // ConstantRange::intrinsic can combine the operand ranges.
7051 switch (S->getSCEVType()) {
7052 case scUMaxExpr:
7053 ID = Intrinsic::umax;
7054 break;
7055 case scSMaxExpr:
7056 ID = Intrinsic::smax;
7057 break;
7058 case scUMinExpr:
7060 ID = Intrinsic::umin;
7061 break;
7062 case scSMinExpr:
7063 ID = Intrinsic::smin;
7064 break;
7065 default:
7066 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7067 }
7068
7069 const auto *NAry = cast<SCEVNAryExpr>(S);
7070 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7071 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7072 X = X.intrinsic(
7073 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7074 return setRange(S, SignHint,
7075 ConservativeResult.intersectWith(X, RangeType));
7076 }
7077 case scUnknown: {
7078 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7079 Value *V = U->getValue();
7080
7081 // Check if the IR explicitly contains !range metadata.
7082 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7083 if (MDRange)
7084 ConservativeResult =
7085 ConservativeResult.intersectWith(*MDRange, RangeType);
7086
7087 // Use facts about recurrences in the underlying IR. Note that add
7088 // recurrences are AddRecExprs and thus don't hit this path. This
7089 // primarily handles shift recurrences.
7090 auto CR = getRangeForUnknownRecurrence(U);
7091 ConservativeResult = ConservativeResult.intersectWith(CR);
7092
7093 // See if ValueTracking can give us a useful range.
7094 const DataLayout &DL = getDataLayout();
7095 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7096 if (Known.getBitWidth() != BitWidth)
7097 Known = Known.zextOrTrunc(BitWidth);
7098
7099 // ValueTracking may be able to compute a tighter result for the number of
7100 // sign bits than for the value of those sign bits.
7101 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7102 if (U->getType()->isPointerTy()) {
7103 // If the pointer size is larger than the index size type, this can cause
7104 // NS to be larger than BitWidth. So compensate for this.
7105 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7106 int ptrIdxDiff = ptrSize - BitWidth;
7107 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7108 NS -= ptrIdxDiff;
7109 }
7110
7111 if (NS > 1) {
7112 // If we know any of the sign bits, we know all of the sign bits.
7113 if (!Known.Zero.getHiBits(NS).isZero())
7114 Known.Zero.setHighBits(NS);
7115 if (!Known.One.getHiBits(NS).isZero())
7116 Known.One.setHighBits(NS);
7117 }
7118
7119 if (Known.getMinValue() != Known.getMaxValue() + 1)
7120 ConservativeResult = ConservativeResult.intersectWith(
7121 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7122 RangeType);
7123 if (NS > 1)
7124 ConservativeResult = ConservativeResult.intersectWith(
7125 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7126 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7127 RangeType);
7128
7129 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7130 // Strengthen the range if the underlying IR value is a
7131 // global/alloca/heap allocation using the size of the object.
7132 bool CanBeNull, CanBeFreed;
7133 uint64_t DerefBytes =
7134 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7135 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7136 // The highest address the object can start is DerefBytes bytes before
7137 // the end (unsigned max value). If this value is not a multiple of the
7138 // alignment, the last possible start value is the next lowest multiple
7139 // of the alignment. Note: The computations below cannot overflow,
7140 // because if they would there's no possible start address for the
7141 // object.
7142 APInt MaxVal =
7143 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7144 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7145 uint64_t Rem = MaxVal.urem(Align);
7146 MaxVal -= APInt(BitWidth, Rem);
7147 APInt MinVal = APInt::getZero(BitWidth);
7148 if (llvm::isKnownNonZero(V, DL))
7149 MinVal = Align;
7150 ConservativeResult = ConservativeResult.intersectWith(
7151 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7152 }
7153 }
7154
7155 // A range of Phi is a subset of union of all ranges of its input.
7156 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7157 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7158 // AddRecs; return the range for the corresponding AddRec.
7159 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7160 return getRangeRef(AR, SignHint, Depth + 1);
7161
7162 // Make sure that we do not run over cycled Phis.
7163 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7164 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7165
7166 for (const auto &Op : Phi->operands()) {
7167 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7168 RangeFromOps = RangeFromOps.unionWith(OpRange);
7169 // No point to continue if we already have a full set.
7170 if (RangeFromOps.isFullSet())
7171 break;
7172 }
7173 ConservativeResult =
7174 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7175 }
7176 }
7177
7178 // vscale can't be equal to zero
7179 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7180 if (II->getIntrinsicID() == Intrinsic::vscale) {
7181 ConstantRange Disallowed = APInt::getZero(BitWidth);
7182 ConservativeResult = ConservativeResult.difference(Disallowed);
7183 }
7184
7185 return setRange(U, SignHint, std::move(ConservativeResult));
7186 }
7187 case scCouldNotCompute:
7188 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7189 }
7190
7191 return setRange(S, SignHint, std::move(ConservativeResult));
7192}
7193
// Given a StartRange, Step and MaxBECount for an expression compute a range of
// values that the expression can take. Initially, the expression has a value
// from StartRange and then is changed by Step up to MaxBECount times. Signed
// argument defines if we treat Step as signed or unsigned.
// NOTE(review): the opening line of this definition (the helper's name and
// its leading Step parameter) was dropped from this excerpt -- presumably
// "static ConstantRange getRangeForAffineARHelper(APInt Step," -- confirm
// against the full file.
                                       const ConstantRange &StartRange,
                                       const APInt &MaxBECount,
                                       bool Signed) {
  unsigned BitWidth = Step.getBitWidth();
  assert(BitWidth == StartRange.getBitWidth() &&
         BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
  // If either Step or MaxBECount is 0, then the expression won't change, and we
  // just need to return the initial range.
  if (Step == 0 || MaxBECount == 0)
    return StartRange;

  // If we don't know anything about the initial value (i.e. StartRange is
  // FullRange), then we don't know anything about the final range either.
  // Return FullRange.
  if (StartRange.isFullSet())
    return ConstantRange::getFull(BitWidth);

  // If Step is signed and negative, then we use its absolute value, but we also
  // note that we're moving in the opposite direction.
  bool Descending = Signed && Step.isNegative();

  if (Signed)
    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
    // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
    // This equations hold true due to the well-defined wrap-around behavior of
    // APInt.
    Step = Step.abs();

  // Check if Offset is more than full span of BitWidth. If it is, the
  // expression is guaranteed to overflow.
  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
    return ConstantRange::getFull(BitWidth);

  // Offset is by how much the expression can change. Checks above guarantee no
  // overflow here.
  APInt Offset = Step * MaxBECount;

  // Minimum value of the final range will match the minimal value of StartRange
  // if the expression is increasing and will be decreased by Offset otherwise.
  // Maximum value of the final range will match the maximal value of StartRange
  // if the expression is decreasing and will be increased by Offset otherwise.
  APInt StartLower = StartRange.getLower();
  // Upper bound of a ConstantRange is exclusive; convert to inclusive max.
  APInt StartUpper = StartRange.getUpper() - 1;
  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
                                   : (StartUpper + std::move(Offset));

  // It's possible that the new minimum/maximum value will fall into the initial
  // range (due to wrap around). This means that the expression can take any
  // value in this bitwidth, and we have to return full range.
  if (StartRange.contains(MovedBoundary))
    return ConstantRange::getFull(BitWidth);

  APInt NewLower =
      Descending ? std::move(MovedBoundary) : std::move(StartLower);
  APInt NewUpper =
      Descending ? std::move(StartUpper) : std::move(MovedBoundary);
  NewUpper += 1;

  // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
  return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
}
7260
// Compute a conservative range for the affine AddRec {Start,+,Step} over at
// most MaxBECount iterations by combining signed and unsigned reasoning.
ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
                                                   const SCEV *Step,
                                                   const APInt &MaxBECount) {
  assert(getTypeSizeInBits(Start->getType()) ==
             getTypeSizeInBits(Step->getType()) &&
         getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
         "mismatched bit widths");

  // First, consider step signed.
  ConstantRange StartSRange = getSignedRange(Start);
  ConstantRange StepSRange = getSignedRange(Step);

  // If Step can be both positive and negative, we need to find ranges for the
  // maximum absolute step values in both directions and union them.
  ConstantRange SR = getRangeForAffineARHelper(
      StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
  // NOTE(review): a line was dropped from this excerpt here -- presumably the
  // "SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),"
  // call that these continuation lines belong to -- confirm.
                                           StartSRange, MaxBECount,
                                           /* Signed = */ true));

  // Next, consider step unsigned.
  ConstantRange UR = getRangeForAffineARHelper(
      getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
      /* Signed = */ false);

  // Finally, intersect signed and unsigned ranges.
  // NOTE(review): the return statement was dropped from this excerpt --
  // presumably the intersection of SR and UR -- confirm against the full file.
}
7289
// Compute a range for an affine AddRec with the no-self-wrap flag by bounding
// its values between the start and end values, when the recurrence can be
// shown to be monotone over at most MaxBECount iterations.
ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
    const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
    ScalarEvolution::RangeSignHint SignHint) {
  assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
  assert(AddRec->hasNoSelfWrap() &&
         "This only works for non-self-wrapping AddRecs!");
  const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
  const SCEV *Step = AddRec->getStepRecurrence(*this);
  // Only deal with constant step to save compile time.
  if (!isa<SCEVConstant>(Step))
    return ConstantRange::getFull(BitWidth);
  // Let's make sure that we can prove that we do not self-wrap during
  // MaxBECount iterations. We need this because MaxBECount is a maximum
  // iteration count estimate, and we might infer nw from some exit for which we
  // do not know max exit count (or any other side reasoning).
  // TODO: Turn into assert at some point.
  if (getTypeSizeInBits(MaxBECount->getType()) >
      getTypeSizeInBits(AddRec->getType()))
    return ConstantRange::getFull(BitWidth);
  MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
  const SCEV *RangeWidth = getMinusOne(AddRec->getType());
  // |Step|, computed via umin(Step, -Step), works for both signs.
  const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
  const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
  if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
                                         MaxItersWithoutWrap))
    return ConstantRange::getFull(BitWidth);

  ICmpInst::Predicate LEPred =
  // NOTE(review): the initializers of LEPred and GEPred were dropped from
  // this excerpt -- presumably signed/unsigned LE and GE predicates selected
  // on IsSigned -- confirm against the full file.
  ICmpInst::Predicate GEPred =
  const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);

  // We know that there is no self-wrap. Let's take Start and End values and
  // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
  // the iteration. They either lie inside the range [Min(Start, End),
  // Max(Start, End)] or outside it:
  //
  // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
  // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
  //
  // No self wrap flag guarantees that the intermediate values cannot be BOTH
  // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
  // knowledge, let's try to prove that we are dealing with Case 1. It is so if
  // Start <= End and step is positive, or Start >= End and step is negative.
  const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
  ConstantRange StartRange = getRangeRef(Start, SignHint);
  ConstantRange EndRange = getRangeRef(End, SignHint);
  ConstantRange RangeBetween = StartRange.unionWith(EndRange);
  // If they already cover full iteration space, we will know nothing useful
  // even if we prove what we want to prove.
  if (RangeBetween.isFullSet())
    return RangeBetween;
  // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
  bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
                               : RangeBetween.isWrappedSet();
  if (IsWrappedSet)
    return ConstantRange::getFull(BitWidth);

  if (isKnownPositive(Step) &&
      isKnownPredicateViaConstantRanges(LEPred, Start, End))
    return RangeBetween;
  if (isKnownNegative(Step) &&
      isKnownPredicateViaConstantRanges(GEPred, Start, End))
    return RangeBetween;
  return ConstantRange::getFull(BitWidth);
}
7357
// Compute a range for an AddRec whose start and step are both selects on the
// same condition, by analyzing the two underlying AddRecs and unioning them.
ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
                                                    const SCEV *Step,
                                                    const APInt &MaxBECount) {
  // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
  //                           == RangeOf({A,+,P}) union RangeOf({B,+,Q})

  unsigned BitWidth = MaxBECount.getBitWidth();
  assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
         getTypeSizeInBits(Step->getType()) == BitWidth &&
         "mismatched bit widths");

  // Helper that recognizes a SCEV of the shape "select(C, T, F)" (possibly
  // behind a constant offset and/or an integral cast) and extracts the
  // condition plus the two constant arms.
  struct SelectPattern {
    Value *Condition = nullptr;
    APInt TrueValue;
    APInt FalseValue;

    explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
                           const SCEV *S) {
      std::optional<unsigned> CastOp;
      APInt Offset(BitWidth, 0);

      // NOTE(review): the first line of this assert was dropped from this
      // excerpt -- presumably a bit-width check on S against BitWidth --
      // confirm against the full file.
             "Should be!");

      // Peel off a constant offset. In the future we could consider being
      // smarter here and handle {Start+Step,+,Step} too.
      const APInt *Off;
      if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
        Offset = *Off;

      // Peel off a cast operation
      if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
        CastOp = SCast->getSCEVType();
        S = SCast->getOperand();
      }

      using namespace llvm::PatternMatch;

      auto *SU = dyn_cast<SCEVUnknown>(S);
      const APInt *TrueVal, *FalseVal;
      if (!SU ||
          !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
                                          m_APInt(FalseVal)))) {
        // Not a recognizable select; leave Condition null so isRecognized()
        // reports failure.
        Condition = nullptr;
        return;
      }

      TrueValue = *TrueVal;
      FalseValue = *FalseVal;

      // Re-apply the cast we peeled off earlier
      if (CastOp)
        switch (*CastOp) {
        default:
          llvm_unreachable("Unknown SCEV cast type!");

        case scTruncate:
          TrueValue = TrueValue.trunc(BitWidth);
          FalseValue = FalseValue.trunc(BitWidth);
          break;
        case scZeroExtend:
          TrueValue = TrueValue.zext(BitWidth);
          FalseValue = FalseValue.zext(BitWidth);
          break;
        case scSignExtend:
          TrueValue = TrueValue.sext(BitWidth);
          FalseValue = FalseValue.sext(BitWidth);
          break;
        }

      // Re-apply the constant offset we peeled off earlier
      TrueValue += Offset;
      FalseValue += Offset;
    }

    bool isRecognized() { return Condition != nullptr; }
  };

  SelectPattern StartPattern(*this, BitWidth, Start);
  if (!StartPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  SelectPattern StepPattern(*this, BitWidth, Step);
  if (!StepPattern.isRecognized())
    return ConstantRange::getFull(BitWidth);

  if (StartPattern.Condition != StepPattern.Condition) {
    // We don't handle this case today; but we could, by considering four
    // possibilities below instead of two. I'm not sure if there are cases where
    // that will help over what getRange already does, though.
    return ConstantRange::getFull(BitWidth);
  }

  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
  // construct arbitrary general SCEV expressions here.  This function is called
  // from deep in the call stack, and calling getSCEV (on a sext instruction,
  // say) can end up caching a suboptimal value.

  // FIXME: without the explicit `this` receiver below, MSVC errors out with
  // C2352 and C2512 (otherwise it isn't needed).

  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);

  ConstantRange TrueRange =
      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
  ConstantRange FalseRange =
      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);

  return TrueRange.unionWith(FalseRange);
}
7471
// Derive SCEV no-wrap flags from the IR nuw/nsw flags on V, but only when
// doing so is sound: the flags are dropped unless the instruction is known to
// execute (and thus trigger UB) whenever its SCEV would wrap.
SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
  // Constant expressions carry no UB guarantees; nothing to propagate.
  if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
  const BinaryOperator *BinOp = cast<BinaryOperator>(V);

  // Return early if there are no flags to propagate to the SCEV.
  // NOTE(review): lines were dropped from this excerpt here -- presumably the
  // declaration of a Flags accumulator and the setFlags(...) calls these two
  // hasNo*Wrap() checks feed -- confirm against the full file.
  if (BinOp->hasNoUnsignedWrap())
  if (BinOp->hasNoSignedWrap())
  if (Flags == SCEV::FlagAnyWrap)
    return SCEV::FlagAnyWrap;

  // Only keep the flags if poison from a wrapping BinOp would be guaranteed
  // to cause UB along every path that reaches the defining scope.
  return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
}
7487
7488const Instruction *
7489ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7490 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7491 return &*AddRec->getLoop()->getHeader()->begin();
7492 if (auto *U = dyn_cast<SCEVUnknown>(S))
7493 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7494 return I;
7495 return nullptr;
7496}
7497
7498const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7499 bool &Precise) {
7500 Precise = true;
7501 // Do a bounded search of the def relation of the requested SCEVs.
7502 SmallPtrSet<const SCEV *, 16> Visited;
7503 SmallVector<SCEVUse> Worklist;
7504 auto pushOp = [&](const SCEV *S) {
7505 if (!Visited.insert(S).second)
7506 return;
7507 // Threshold of 30 here is arbitrary.
7508 if (Visited.size() > 30) {
7509 Precise = false;
7510 return;
7511 }
7512 Worklist.push_back(S);
7513 };
7514
7515 for (SCEVUse S : Ops)
7516 pushOp(S);
7517
7518 const Instruction *Bound = nullptr;
7519 while (!Worklist.empty()) {
7520 SCEVUse S = Worklist.pop_back_val();
7521 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7522 if (!Bound || DT.dominates(Bound, DefI))
7523 Bound = DefI;
7524 } else {
7525 for (SCEVUse Op : S->operands())
7526 pushOp(Op);
7527 }
7528 }
7529 return Bound ? Bound : &*F.getEntryBlock().begin();
7530}
7531
7532const Instruction *
7533ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7534 bool Discard;
7535 return getDefiningScopeBound(Ops, Discard);
7536}
7537
// Return true if execution is guaranteed to reach instruction B whenever
// instruction A executes.
bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
                                                        const Instruction *B) {
  // Same-block case: check the straight-line path from A to B.
  // NOTE(review): a line was dropped from this excerpt here -- presumably the
  // isGuaranteedToTransferExecutionToSuccessor call whose continuation
  // argument "B->getIterator()" appears below -- confirm against the full
  // file.
  if (A->getParent() == B->getParent() &&
                                                 B->getIterator()))
    return true;

  // Cross-block case: A in the preheader of B's loop, B in the loop header;
  // both the tail of the preheader and the header prefix must transfer
  // execution.
  auto *BLoop = LI.getLoopFor(B->getParent());
  // NOTE(review): another isGuaranteedToTransferExecutionToSuccessor call line
  // was dropped from this excerpt before "A->getParent()->end())" -- confirm.
  if (BLoop && BLoop->getHeader() == B->getParent() &&
      BLoop->getLoopPreheader() == A->getParent() &&
                                                 A->getParent()->end()) &&
      isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
                                                 B->getIterator()))
    return true;
  return false;
}
7555
7556bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7557 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7558 visitAll(Op, PC);
7559 return PC.MaybePoison.empty();
7560}
7561
7562bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7563 return !SCEVExprContains(Op, [this](const SCEV *S) {
7564 const SCEV *Op1;
7565 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7566 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7567 // is a non-zero constant, we have to assume the UDiv may be UB.
7568 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7569 });
7570}
7571
// Return true if the SCEV corresponding to \p I is never poison, by proving
// that I executes whenever its defining scope is entered.
bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
  // Only proceed if we can prove that I does not yield poison.
  // NOTE(review): the guard line was dropped from this excerpt here --
  // presumably a programUndefinedIfPoison(I) check whose failure returns
  // false -- confirm against the full file.
    return false;

  // At this point we know that if I is executed, then it does not wrap
  // according to at least one of NSW or NUW. If I is not executed, then we do
  // not know if the calculation that I represents would wrap. Multiple
  // instructions can map to the same SCEV. If we apply NSW or NUW from I to
  // the SCEV, we must guarantee no wrapping for that SCEV also when it is
  // derived from other instructions that map to the same SCEV. We cannot make
  // that guarantee for cases where I is not executed. So we need to find a
  // upper bound on the defining scope for the SCEV, and prove that I is
  // executed every time we enter that scope. When the bounding scope is a
  // loop (the common case), this is equivalent to proving I executes on every
  // iteration of that loop.
  SmallVector<SCEVUse> SCEVOps;
  for (const Use &Op : I->operands()) {
    // I could be an extractvalue from a call to an overflow intrinsic.
    // TODO: We can do better here in some cases.
    if (isSCEVable(Op->getType()))
      SCEVOps.push_back(getSCEV(Op));
  }
  auto *DefI = getDefiningScopeBound(SCEVOps);
  return isGuaranteedToTransferExecutionTo(DefI, I);
}
7598
// Return true if the post-inc add recurrence represented by \p I in loop \p L
// can never be poison.
bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
  // If we know that \c I can never be poison period, then that's enough.
  if (isSCEVExprNeverPoison(I))
    return true;

  // If the loop only has one exit, then we know that, if the loop is entered,
  // any instruction dominating that exit will be executed. If any such
  // instruction would result in UB, the addrec cannot be poison.
  //
  // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
  // also handles uses outside the loop header (they just need to dominate the
  // single exit).

  auto *ExitingBB = L->getExitingBlock();
  if (!ExitingBB || !loopHasNoAbnormalExits(L))
    return false;

  SmallPtrSet<const Value *, 16> KnownPoison;
  // NOTE(review): the declaration of Worklist was dropped from this excerpt
  // here -- presumably a SmallVector of const Instruction* -- confirm against
  // the full file.

  // We start by assuming \c I, the post-inc add recurrence, is poison.  Only
  // things that are known to be poison under that assumption go on the
  // Worklist.
  KnownPoison.insert(I);
  Worklist.push_back(I);

  while (!Worklist.empty()) {
    const Instruction *Poison = Worklist.pop_back_val();

    for (const Use &U : Poison->uses()) {
      const Instruction *PoisonUser = cast<Instruction>(U.getUser());
      // Poison reaching a UB-triggering user that dominates the sole exit
      // means the addrec cannot actually be poison.
      if (mustTriggerUB(PoisonUser, KnownPoison) &&
          DT.dominates(PoisonUser->getParent(), ExitingBB))
        return true;

      if (propagatesPoison(U) && L->contains(PoisonUser))
        if (KnownPoison.insert(PoisonUser).second)
          Worklist.push_back(PoisonUser);
    }
  }

  return false;
}
7642
// Compute (and cache) whether loop \p L has abnormal exits or side effects
// relevant to forward-progress reasoning.
ScalarEvolution::LoopProperties
ScalarEvolution::getLoopProperties(const Loop *L) {
  using LoopProperties = ScalarEvolution::LoopProperties;

  auto Itr = LoopPropertiesCache.find(L);
  if (Itr == LoopPropertiesCache.end()) {
    // A write counts as a side effect unless it is a simple store or a
    // non-volatile mem intrinsic; anything that may throw also counts.
    auto HasSideEffects = [](Instruction *I) {
      if (auto *SI = dyn_cast<StoreInst>(I))
        return !SI->isSimple();

      if (I->mayThrow())
        return true;

      // Non-volatile memset / memcpy do not count as side-effect for forward
      // progress.
      if (isa<MemIntrinsic>(I) && !I->isVolatile())
        return false;

      return I->mayWriteToMemory();
    };

    // Start optimistic; each instruction can only pessimize the result.
    LoopProperties LP = {/* HasNoAbnormalExits */ true,
                         /*HasNoSideEffects*/ true};

    for (auto *BB : L->getBlocks())
      for (auto &I : *BB) {
        // NOTE(review): the condition guarding this line was dropped from
        // this excerpt -- presumably a check that I may not transfer
        // execution to its successor -- confirm against the full file.
        LP.HasNoAbnormalExits = false;
        if (HasSideEffects(&I))
          LP.HasNoSideEffects = false;
        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
          break; // We're already as pessimistic as we can get.
      }

    auto InsertPair = LoopPropertiesCache.insert({L, LP});
    assert(InsertPair.second && "We just checked!");
    Itr = InsertPair.first;
  }

  return Itr->second;
}
7684
  // NOTE(review): the signature line of this definition was dropped from this
  // excerpt -- presumably
  // "bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {" --
  // confirm against the full file.
  // A mustprogress loop without side effects must be finite.
  // TODO: The check used here is very conservative.  It's only *specific*
  // side effects which are well defined in infinite loops.
  return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
}
7691
// Iteratively build the SCEV for \p V: operands are pushed onto an explicit
// stack and created before the value that uses them, avoiding deep recursion.
const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
  // Worklist item with a Value and a bool indicating whether all operands have
  // been visited already.
  // NOTE(review): the declaration of Stack was dropped from this excerpt --
  // presumably a SmallVector of PointerIntPair<Value *, 1, bool> -- confirm
  // against the full file.

  Stack.emplace_back(V, true);
  Stack.emplace_back(V, false);
  while (!Stack.empty()) {
    auto E = Stack.pop_back_val();
    Value *CurV = E.getPointer();

    // Already materialized (possibly by an earlier iteration); skip.
    if (getExistingSCEV(CurV))
      continue;

    // NOTE(review): the declaration of Ops was dropped from this excerpt --
    // presumably a SmallVector<Value *> -- confirm against the full file.
    const SCEV *CreatedSCEV = nullptr;
    // If all operands have been visited already, create the SCEV.
    if (E.getInt()) {
      CreatedSCEV = createSCEV(CurV);
    } else {
      // Otherwise get the operands we need to create SCEV's for before creating
      // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
      // just use it.
      CreatedSCEV = getOperandsToCreate(CurV, Ops);
    }

    if (CreatedSCEV) {
      insertValueToMap(CurV, CreatedSCEV);
    } else {
      // Queue CurV for SCEV creation, followed by its's operands which need to
      // be constructed first.
      Stack.emplace_back(CurV, true);
      for (Value *Op : Ops)
        Stack.emplace_back(Op, false);
    }
  }

  return getExistingSCEV(V);
}
7732
// Return a trivially-constructible SCEV for \p V, or nullptr after filling
// \p Ops with the operand Values whose SCEVs must be created first (used by
// createSCEVIter to order construction without recursion).
const SCEV *
ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
  if (!isSCEVable(V->getType()))
    return getUnknown(V);

  if (Instruction *I = dyn_cast<Instruction>(V)) {
    // Don't attempt to analyze instructions in blocks that aren't
    // reachable. Such instructions don't matter, and they aren't required
    // to obey basic rules for definitions dominating uses which this
    // analysis depends on.
    if (!DT.isReachableFromEntry(I->getParent()))
      return getUnknown(PoisonValue::get(V->getType()));
  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
    return getConstant(CI);
  else if (isa<GlobalAlias>(V))
    return getUnknown(V);
  else if (!isa<ConstantExpr>(V))
    return getUnknown(V);

  // NOTE(review): a line was dropped from this excerpt here -- presumably the
  // cast of V to an Operator, bound to the name U used below -- confirm
  // against the full file.
  if (auto BO =
  // NOTE(review): the MatchBinaryOp call completing this condition was
  // dropped from this excerpt -- confirm against the full file.
    bool IsConstArg = isa<ConstantInt>(BO->RHS);
    switch (BO->Opcode) {
    case Instruction::Add:
    case Instruction::Mul: {
      // For additions and multiplications, traverse add/mul chains for which we
      // can potentially create a single SCEV, to reduce the number of
      // get{Add,Mul}Expr calls.
      do {
        if (BO->Op) {
          // Reuse an already-created SCEV for an intermediate chain value.
          if (BO->Op != V && getExistingSCEV(BO->Op)) {
            Ops.push_back(BO->Op);
            break;
          }
        }
        Ops.push_back(BO->RHS);
        // NOTE(review): the trailing argument line of this call was dropped
        // from this excerpt -- confirm against the full file.
        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
        if (!NewBO ||
            (BO->Opcode == Instruction::Add &&
             (NewBO->Opcode != Instruction::Add &&
              NewBO->Opcode != Instruction::Sub)) ||
            (BO->Opcode == Instruction::Mul &&
             NewBO->Opcode != Instruction::Mul)) {
          Ops.push_back(BO->LHS);
          break;
        }
        // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
        // requires a SCEV for the LHS.
        if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
          auto *I = dyn_cast<Instruction>(BO->Op);
          if (I && programUndefinedIfPoison(I)) {
            Ops.push_back(BO->LHS);
            break;
          }
        }
        BO = NewBO;
      } while (true);
      return nullptr;
    }
    case Instruction::Sub:
    case Instruction::UDiv:
    case Instruction::URem:
      break;
    case Instruction::AShr:
    case Instruction::Shl:
    case Instruction::Xor:
      // Only const-RHS shifts/xors are representable as SCEVs.
      if (!IsConstArg)
        return nullptr;
      break;
    case Instruction::And:
    case Instruction::Or:
      if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
        return nullptr;
      break;
    case Instruction::LShr:
      return getUnknown(V);
    default:
      llvm_unreachable("Unhandled binop");
      break;
    }

    Ops.push_back(BO->LHS);
    Ops.push_back(BO->RHS);
    return nullptr;
  }

  switch (U->getOpcode()) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::PtrToAddr:
  case Instruction::PtrToInt:
    Ops.push_back(U->getOperand(0));
    return nullptr;

  case Instruction::BitCast:
    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
      Ops.push_back(U->getOperand(0));
      return nullptr;
    }
    return getUnknown(V);

  case Instruction::SDiv:
  case Instruction::SRem:
    Ops.push_back(U->getOperand(0));
    Ops.push_back(U->getOperand(1));
    return nullptr;

  case Instruction::GetElementPtr:
    assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
           "GEP source element type must be sized");
    llvm::append_range(Ops, U->operands());
    return nullptr;

  case Instruction::IntToPtr:
    return getUnknown(V);

  case Instruction::PHI:
    // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
    // relevant nodes for each of them.
    //
    // The first is just to call simplifyInstruction, and get something back
    // that isn't a PHI.
    if (Value *V = simplifyInstruction(
            cast<PHINode>(U),
            {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
             /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
      assert(V);
      Ops.push_back(V);
      return nullptr;
    }
    // The second is createNodeForPHIWithIdenticalOperands: this looks for
    // operands which all perform the same operation, but haven't been
    // CSE'ed for whatever reason.
    if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
      assert(BO);
      Ops.push_back(BO);
      return nullptr;
    }
    // The third is createNodeFromSelectLikePHI; this takes a PHI which
    // is equivalent to a select, and analyzes it like a select.
    {
      Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
      // NOTE(review): the guard opening this scope was dropped from this
      // excerpt -- presumably the match of the branch-PHI-to-select idiom
      // that populates Cond/LHS/RHS -- confirm against the full file.
      assert(Cond);
      assert(LHS);
      assert(RHS);
      if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
        Ops.push_back(CondICmp->getOperand(0));
        Ops.push_back(CondICmp->getOperand(1));
      }
      Ops.push_back(Cond);
      Ops.push_back(LHS);
      Ops.push_back(RHS);
      return nullptr;
      }
    }
    // The fourth way is createAddRecFromPHI. It's complicated to handle here,
    // so just construct it recursively.
    //
    // In addition to getNodeForPHI, also construct nodes which might be needed
    // by getRangeRef.
    // NOTE(review): the condition opening this block was dropped from this
    // excerpt -- confirm against the full file.
      for (Value *V : cast<PHINode>(U)->operands())
        Ops.push_back(V);
      return nullptr;
    }
    return nullptr;

  case Instruction::Select: {
    // Check if U is a select that can be simplified to a SCEVUnknown.
    auto CanSimplifyToUnknown = [this, U]() {
      if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
        return false;

      auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
      if (!ICI)
        return false;
      Value *LHS = ICI->getOperand(0);
      Value *RHS = ICI->getOperand(1);
      // NOTE(review): a condition line was dropped from this excerpt inside
      // the EQ/NE branch -- confirm against the full file.
      if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
          ICI->getPredicate() == CmpInst::ICMP_NE) {
          return true;
      } else if (getTypeSizeInBits(LHS->getType()) >
                 getTypeSizeInBits(U->getType()))
        return true;
      return false;
    };
    if (CanSimplifyToUnknown())
      return getUnknown(U);

    llvm::append_range(Ops, U->operands());
    return nullptr;
    break;
  }
  case Instruction::Call:
  case Instruction::Invoke:
    if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
      Ops.push_back(RV);
      return nullptr;
    }

    if (auto *II = dyn_cast<IntrinsicInst>(U)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
        Ops.push_back(II->getArgOperand(0));
        return nullptr;
      case Intrinsic::umax:
      case Intrinsic::umin:
      case Intrinsic::smax:
      case Intrinsic::smin:
      case Intrinsic::usub_sat:
      case Intrinsic::uadd_sat:
        Ops.push_back(II->getArgOperand(0));
        Ops.push_back(II->getArgOperand(1));
        return nullptr;
      case Intrinsic::start_loop_iterations:
      case Intrinsic::annotation:
      case Intrinsic::ptr_annotation:
        Ops.push_back(II->getArgOperand(0));
        return nullptr;
      default:
        break;
      }
    }
    break;
  }

  return nullptr;
}
7966
7967const SCEV *ScalarEvolution::createSCEV(Value *V) {
7968 if (!isSCEVable(V->getType()))
7969 return getUnknown(V);
7970
7971 if (Instruction *I = dyn_cast<Instruction>(V)) {
7972 // Don't attempt to analyze instructions in blocks that aren't
7973 // reachable. Such instructions don't matter, and they aren't required
7974 // to obey basic rules for definitions dominating uses which this
7975 // analysis depends on.
7976 if (!DT.isReachableFromEntry(I->getParent()))
7977 return getUnknown(PoisonValue::get(V->getType()));
7978 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7979 return getConstant(CI);
7980 else if (isa<GlobalAlias>(V))
7981 return getUnknown(V);
7982 else if (!isa<ConstantExpr>(V))
7983 return getUnknown(V);
7984
7985 const SCEV *LHS;
7986 const SCEV *RHS;
7987
7989 if (auto BO =
7991 switch (BO->Opcode) {
7992 case Instruction::Add: {
7993 // The simple thing to do would be to just call getSCEV on both operands
7994 // and call getAddExpr with the result. However if we're looking at a
7995 // bunch of things all added together, this can be quite inefficient,
7996 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7997 // Instead, gather up all the operands and make a single getAddExpr call.
7998 // LLVM IR canonical form means we need only traverse the left operands.
8000 do {
8001 if (BO->Op) {
8002 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8003 AddOps.push_back(OpSCEV);
8004 break;
8005 }
8006
8007 // If a NUW or NSW flag can be applied to the SCEV for this
8008 // addition, then compute the SCEV for this addition by itself
8009 // with a separate call to getAddExpr. We need to do that
8010 // instead of pushing the operands of the addition onto AddOps,
8011 // since the flags are only known to apply to this particular
8012 // addition - they may not apply to other additions that can be
8013 // formed with operands from AddOps.
8014 const SCEV *RHS = getSCEV(BO->RHS);
8015 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8016 if (Flags != SCEV::FlagAnyWrap) {
8017 const SCEV *LHS = getSCEV(BO->LHS);
8018 if (BO->Opcode == Instruction::Sub)
8019 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8020 else
8021 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8022 break;
8023 }
8024 }
8025
8026 if (BO->Opcode == Instruction::Sub)
8027 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8028 else
8029 AddOps.push_back(getSCEV(BO->RHS));
8030
8031 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8033 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8034 NewBO->Opcode != Instruction::Sub)) {
8035 AddOps.push_back(getSCEV(BO->LHS));
8036 break;
8037 }
8038 BO = NewBO;
8039 } while (true);
8040
8041 return getAddExpr(AddOps);
8042 }
8043
8044 case Instruction::Mul: {
8046 do {
8047 if (BO->Op) {
8048 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8049 MulOps.push_back(OpSCEV);
8050 break;
8051 }
8052
8053 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8054 if (Flags != SCEV::FlagAnyWrap) {
8055 LHS = getSCEV(BO->LHS);
8056 RHS = getSCEV(BO->RHS);
8057 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8058 break;
8059 }
8060 }
8061
8062 MulOps.push_back(getSCEV(BO->RHS));
8063 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8065 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8066 MulOps.push_back(getSCEV(BO->LHS));
8067 break;
8068 }
8069 BO = NewBO;
8070 } while (true);
8071
8072 return getMulExpr(MulOps);
8073 }
8074 case Instruction::UDiv:
8075 LHS = getSCEV(BO->LHS);
8076 RHS = getSCEV(BO->RHS);
8077 return getUDivExpr(LHS, RHS);
8078 case Instruction::URem:
8079 LHS = getSCEV(BO->LHS);
8080 RHS = getSCEV(BO->RHS);
8081 return getURemExpr(LHS, RHS);
8082 case Instruction::Sub: {
8084 if (BO->Op)
8085 Flags = getNoWrapFlagsFromUB(BO->Op);
8086 LHS = getSCEV(BO->LHS);
8087 RHS = getSCEV(BO->RHS);
8088 return getMinusSCEV(LHS, RHS, Flags);
8089 }
8090 case Instruction::And:
8091 // For an expression like x&255 that merely masks off the high bits,
8092 // use zext(trunc(x)) as the SCEV expression.
8093 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8094 if (CI->isZero())
8095 return getSCEV(BO->RHS);
8096 if (CI->isMinusOne())
8097 return getSCEV(BO->LHS);
8098 const APInt &A = CI->getValue();
8099
8100 // Instcombine's ShrinkDemandedConstant may strip bits out of
8101 // constants, obscuring what would otherwise be a low-bits mask.
8102 // Use computeKnownBits to compute what ShrinkDemandedConstant
8103 // knew about to reconstruct a low-bits mask value.
8104 unsigned LZ = A.countl_zero();
8105 unsigned TZ = A.countr_zero();
8106 unsigned BitWidth = A.getBitWidth();
8107 KnownBits Known(BitWidth);
8108 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8109
8110 APInt EffectiveMask =
8111 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8112 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8113 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8114 const SCEV *LHS = getSCEV(BO->LHS);
8115 const SCEV *ShiftedLHS = nullptr;
8116 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8117 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8118 // For an expression like (x * 8) & 8, simplify the multiply.
8119 unsigned MulZeros = OpC->getAPInt().countr_zero();
8120 unsigned GCD = std::min(MulZeros, TZ);
8121 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8123 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8124 append_range(MulOps, LHSMul->operands().drop_front());
8125 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8126 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8127 }
8128 }
8129 if (!ShiftedLHS)
8130 ShiftedLHS = getUDivExpr(LHS, MulCount);
8131 return getMulExpr(
8133 getTruncateExpr(ShiftedLHS,
8134 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8135 BO->LHS->getType()),
8136 MulCount);
8137 }
8138 }
8139 // Binary `and` is a bit-wise `umin`.
8140 if (BO->LHS->getType()->isIntegerTy(1)) {
8141 LHS = getSCEV(BO->LHS);
8142 RHS = getSCEV(BO->RHS);
8143 return getUMinExpr(LHS, RHS);
8144 }
8145 break;
8146
8147 case Instruction::Or:
8148 // Binary `or` is a bit-wise `umax`.
8149 if (BO->LHS->getType()->isIntegerTy(1)) {
8150 LHS = getSCEV(BO->LHS);
8151 RHS = getSCEV(BO->RHS);
8152 return getUMaxExpr(LHS, RHS);
8153 }
8154 break;
8155
8156 case Instruction::Xor:
8157 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8158 // If the RHS of xor is -1, then this is a not operation.
8159 if (CI->isMinusOne())
8160 return getNotSCEV(getSCEV(BO->LHS));
8161
8162 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8163 // This is a variant of the check for xor with -1, and it handles
8164 // the case where instcombine has trimmed non-demanded bits out
8165 // of an xor with -1.
8166 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8167 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8168 if (LBO->getOpcode() == Instruction::And &&
8169 LCI->getValue() == CI->getValue())
8170 if (const SCEVZeroExtendExpr *Z =
8172 Type *UTy = BO->LHS->getType();
8173 const SCEV *Z0 = Z->getOperand();
8174 Type *Z0Ty = Z0->getType();
8175 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8176
8177 // If C is a low-bits mask, the zero extend is serving to
8178 // mask off the high bits. Complement the operand and
8179 // re-apply the zext.
8180 if (CI->getValue().isMask(Z0TySize))
8181 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8182
8183 // If C is a single bit, it may be in the sign-bit position
8184 // before the zero-extend. In this case, represent the xor
8185 // using an add, which is equivalent, and re-apply the zext.
8186 APInt Trunc = CI->getValue().trunc(Z0TySize);
8187 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8188 Trunc.isSignMask())
8189 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8190 UTy);
8191 }
8192 }
8193 break;
8194
8195 case Instruction::Shl:
8196 // Turn shift left of a constant amount into a multiply.
8197 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8198 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8199
8200 // If the shift count is not less than the bitwidth, the result of
8201 // the shift is undefined. Don't try to analyze it, because the
8202 // resolution chosen here may differ from the resolution chosen in
8203 // other parts of the compiler.
8204 if (SA->getValue().uge(BitWidth))
8205 break;
8206
8207 // We can safely preserve the nuw flag in all cases. It's also safe to
8208 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8209 // requires special handling. It can be preserved as long as we're not
8210 // left shifting by bitwidth - 1.
8211 auto Flags = SCEV::FlagAnyWrap;
8212 if (BO->Op) {
8213 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8214 if (any(MulFlags & SCEV::FlagNSW) &&
8215 (any(MulFlags & SCEV::FlagNUW) ||
8216 SA->getValue().ult(BitWidth - 1)))
8218 if (any(MulFlags & SCEV::FlagNUW))
8220 }
8221
8222 ConstantInt *X = ConstantInt::get(
8223 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8224 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8225 }
8226 break;
8227
8228 case Instruction::AShr:
8229 // AShr X, C, where C is a constant.
8230 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8231 if (!CI)
8232 break;
8233
8234 Type *OuterTy = BO->LHS->getType();
8235 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8236 // If the shift count is not less than the bitwidth, the result of
8237 // the shift is undefined. Don't try to analyze it, because the
8238 // resolution chosen here may differ from the resolution chosen in
8239 // other parts of the compiler.
8240 if (CI->getValue().uge(BitWidth))
8241 break;
8242
8243 if (CI->isZero())
8244 return getSCEV(BO->LHS); // shift by zero --> noop
8245
8246 uint64_t AShrAmt = CI->getZExtValue();
8247 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8248
8249 Operator *L = dyn_cast<Operator>(BO->LHS);
8250 const SCEV *AddTruncateExpr = nullptr;
8251 ConstantInt *ShlAmtCI = nullptr;
8252 const SCEV *AddConstant = nullptr;
8253
8254 if (L && L->getOpcode() == Instruction::Add) {
8255 // X = Shl A, n
8256 // Y = Add X, c
8257 // Z = AShr Y, m
8258 // n, c and m are constants.
8259
8260 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8261 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8262 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8263 if (AddOperandCI) {
8264 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8265 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8266 // since we truncate to TruncTy, the AddConstant should be of the
8267 // same type, so create a new Constant with type same as TruncTy.
8268 // Also, the Add constant should be shifted right by AShr amount.
8269 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8270 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8271 // we model the expression as sext(add(trunc(A), c << n)), since the
8272 // sext(trunc) part is already handled below, we create a
8273 // AddExpr(TruncExp) which will be used later.
8274 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8275 }
8276 }
8277 } else if (L && L->getOpcode() == Instruction::Shl) {
8278 // X = Shl A, n
8279 // Y = AShr X, m
8280 // Both n and m are constant.
8281
8282 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8283 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8284 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8285 }
8286
8287 if (AddTruncateExpr && ShlAmtCI) {
8288 // We can merge the two given cases into a single SCEV statement,
      // in case n = m, the mul expression will be 2^0, so it gets resolved to
8290 // a simpler case. The following code handles the two cases:
8291 //
8292 // 1) For a two-shift sext-inreg, i.e. n = m,
8293 // use sext(trunc(x)) as the SCEV expression.
8294 //
8295 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8296 // expression. We already checked that ShlAmt < BitWidth, so
8297 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8298 // ShlAmt - AShrAmt < Amt.
8299 const APInt &ShlAmt = ShlAmtCI->getValue();
8300 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8301 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8302 ShlAmtCI->getZExtValue() - AShrAmt);
8303 const SCEV *CompositeExpr =
8304 getMulExpr(AddTruncateExpr, getConstant(Mul));
8305 if (L->getOpcode() != Instruction::Shl)
8306 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8307
8308 return getSignExtendExpr(CompositeExpr, OuterTy);
8309 }
8310 }
8311 break;
8312 }
8313 }
8314
8315 switch (U->getOpcode()) {
8316 case Instruction::Trunc:
8317 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8318
8319 case Instruction::ZExt:
8320 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8321
8322 case Instruction::SExt:
8323 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8325 // The NSW flag of a subtract does not always survive the conversion to
8326 // A + (-1)*B. By pushing sign extension onto its operands we are much
8327 // more likely to preserve NSW and allow later AddRec optimisations.
8328 //
8329 // NOTE: This is effectively duplicating this logic from getSignExtend:
8330 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8331 // but by that point the NSW information has potentially been lost.
8332 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8333 Type *Ty = U->getType();
8334 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8335 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8336 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8337 }
8338 }
8339 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8340
8341 case Instruction::BitCast:
8342 // BitCasts are no-op casts so we just eliminate the cast.
8343 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8344 return getSCEV(U->getOperand(0));
8345 break;
8346
8347 case Instruction::PtrToAddr: {
8348 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8349 if (isa<SCEVCouldNotCompute>(IntOp))
8350 return getUnknown(V);
8351 return IntOp;
8352 }
8353
8354 case Instruction::PtrToInt: {
8355 // Pointer to integer cast is straight-forward, so do model it.
8356 const SCEV *Op = getSCEV(U->getOperand(0));
8357 Type *DstIntTy = U->getType();
8358 // But only if effective SCEV (integer) type is wide enough to represent
8359 // all possible pointer values.
8360 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8361 if (isa<SCEVCouldNotCompute>(IntOp))
8362 return getUnknown(V);
8363 return IntOp;
8364 }
8365 case Instruction::IntToPtr:
8366 // Just don't deal with inttoptr casts.
8367 return getUnknown(V);
8368
8369 case Instruction::SDiv:
8370 // If both operands are non-negative, this is just an udiv.
8371 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8372 isKnownNonNegative(getSCEV(U->getOperand(1))))
8373 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8374 break;
8375
8376 case Instruction::SRem:
8377 // If both operands are non-negative, this is just an urem.
8378 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8379 isKnownNonNegative(getSCEV(U->getOperand(1))))
8380 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8381 break;
8382
8383 case Instruction::GetElementPtr:
8384 return createNodeForGEP(cast<GEPOperator>(U));
8385
8386 case Instruction::PHI:
8387 return createNodeForPHI(cast<PHINode>(U));
8388
8389 case Instruction::Select:
8390 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8391 U->getOperand(2));
8392
8393 case Instruction::Call:
8394 case Instruction::Invoke:
8395 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8396 return getSCEV(RV);
8397
8398 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8399 switch (II->getIntrinsicID()) {
8400 case Intrinsic::abs:
8401 return getAbsExpr(
8402 getSCEV(II->getArgOperand(0)),
8403 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8404 case Intrinsic::umax:
8405 LHS = getSCEV(II->getArgOperand(0));
8406 RHS = getSCEV(II->getArgOperand(1));
8407 return getUMaxExpr(LHS, RHS);
8408 case Intrinsic::umin:
8409 LHS = getSCEV(II->getArgOperand(0));
8410 RHS = getSCEV(II->getArgOperand(1));
8411 return getUMinExpr(LHS, RHS);
8412 case Intrinsic::smax:
8413 LHS = getSCEV(II->getArgOperand(0));
8414 RHS = getSCEV(II->getArgOperand(1));
8415 return getSMaxExpr(LHS, RHS);
8416 case Intrinsic::smin:
8417 LHS = getSCEV(II->getArgOperand(0));
8418 RHS = getSCEV(II->getArgOperand(1));
8419 return getSMinExpr(LHS, RHS);
8420 case Intrinsic::usub_sat: {
8421 const SCEV *X = getSCEV(II->getArgOperand(0));
8422 const SCEV *Y = getSCEV(II->getArgOperand(1));
8423 const SCEV *ClampedY = getUMinExpr(X, Y);
8424 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8425 }
8426 case Intrinsic::uadd_sat: {
8427 const SCEV *X = getSCEV(II->getArgOperand(0));
8428 const SCEV *Y = getSCEV(II->getArgOperand(1));
8429 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8430 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8431 }
8432 case Intrinsic::start_loop_iterations:
8433 case Intrinsic::annotation:
8434 case Intrinsic::ptr_annotation:
        // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
        // just equivalent to the first operand for SCEV purposes.
8437 return getSCEV(II->getArgOperand(0));
8438 case Intrinsic::vscale:
8439 return getVScale(II->getType());
8440 default:
8441 break;
8442 }
8443 }
8444 break;
8445 }
8446
8447 return getUnknown(V);
8448}
8449
8450//===----------------------------------------------------------------------===//
8451// Iteration Count Computation Code
8452//
8453
8455 if (isa<SCEVCouldNotCompute>(ExitCount))
8456 return getCouldNotCompute();
8457
8458 auto *ExitCountType = ExitCount->getType();
8459 assert(ExitCountType->isIntegerTy());
8460 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8461 1 + ExitCountType->getScalarSizeInBits());
8462 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8463}
8464
8466 Type *EvalTy,
8467 const Loop *L) {
8468 if (isa<SCEVCouldNotCompute>(ExitCount))
8469 return getCouldNotCompute();
8470
8471 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8472 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8473
8474 auto CanAddOneWithoutOverflow = [&]() {
8475 ConstantRange ExitCountRange =
8476 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8477 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8478 return true;
8479
8480 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8481 getMinusOne(ExitCount->getType()));
8482 };
8483
8484 // If we need to zero extend the backedge count, check if we can add one to
8485 // it prior to zero extending without overflow. Provided this is safe, it
8486 // allows better simplification of the +1.
8487 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8488 return getZeroExtendExpr(
8489 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8490
8491 // Get the total trip count from the count by adding 1. This may wrap.
8492 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8493}
8494
8495static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8496 if (!ExitCount)
8497 return 0;
8498
8499 ConstantInt *ExitConst = ExitCount->getValue();
8500
8501 // Guard against huge trip counts.
8502 if (ExitConst->getValue().getActiveBits() > 32)
8503 return 0;
8504
8505 // In case of integer overflow, this returns 0, which is correct.
8506 return ((unsigned)ExitConst->getZExtValue()) + 1;
8507}
8508
8510 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8511 return getConstantTripCount(ExitCount);
8512}
8513
8514unsigned
8516 const BasicBlock *ExitingBlock) {
8517 assert(ExitingBlock && "Must pass a non-null exiting block!");
8518 assert(L->isLoopExiting(ExitingBlock) &&
8519 "Exiting block must actually branch out of the loop!");
8520 const SCEVConstant *ExitCount =
8521 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8522 return getConstantTripCount(ExitCount);
8523}
8524
8526 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8527
8528 const auto *MaxExitCount =
8529 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8531 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8532}
8533
8535 SmallVector<BasicBlock *, 8> ExitingBlocks;
8536 L->getExitingBlocks(ExitingBlocks);
8537
8538 std::optional<unsigned> Res;
8539 for (auto *ExitingBB : ExitingBlocks) {
8540 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8541 if (!Res)
8542 Res = Multiple;
8543 Res = std::gcd(*Res, Multiple);
8544 }
8545 return Res.value_or(1);
8546}
8547
8549 const SCEV *ExitCount) {
8550 if (isa<SCEVCouldNotCompute>(ExitCount))
8551 return 1;
8552
8553 // Get the trip count
8554 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8555
8556 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8557 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8558 // the greatest power of 2 divisor less than 2^32.
8559 return Multiple.getActiveBits() > 32
8560 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8561 : (unsigned)Multiple.getZExtValue();
8562}
8563
8564/// Returns the largest constant divisor of the trip count of this loop as a
8565/// normal unsigned value, if possible. This means that the actual trip count is
8566/// always a multiple of the returned value (don't forget the trip count could
8567/// very well be zero as well!).
8568///
8569/// Returns 1 if the trip count is unknown or not guaranteed to be the
8570/// multiple of a constant (which is also the case if the trip count is simply
8571/// constant, use getSmallConstantTripCount for that case), Will also return 1
8572/// if the trip count is very large (>= 2^32).
8573///
8574/// As explained in the comments for getSmallConstantTripCount, this assumes
8575/// that control exits the loop via ExitingBlock.
8576unsigned
8578 const BasicBlock *ExitingBlock) {
8579 assert(ExitingBlock && "Must pass a non-null exiting block!");
8580 assert(L->isLoopExiting(ExitingBlock) &&
8581 "Exiting block must actually branch out of the loop!");
8582 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8583 return getSmallConstantTripMultiple(L, ExitCount);
8584}
8585
8587 const BasicBlock *ExitingBlock,
8588 ExitCountKind Kind) {
8589 switch (Kind) {
8590 case Exact:
8591 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8592 case SymbolicMaximum:
8593 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8594 case ConstantMaximum:
8595 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8596 };
8597 llvm_unreachable("Invalid ExitCountKind!");
8598}
8599
8601 const Loop *L, const BasicBlock *ExitingBlock,
8603 switch (Kind) {
8604 case Exact:
8605 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8606 Predicates);
8607 case SymbolicMaximum:
8608 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8609 Predicates);
8610 case ConstantMaximum:
8611 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8612 Predicates);
8613 };
8614 llvm_unreachable("Invalid ExitCountKind!");
8615}
8616
8619 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8620}
8621
8623 ExitCountKind Kind) {
8624 switch (Kind) {
8625 case Exact:
8626 return getBackedgeTakenInfo(L).getExact(L, this);
8627 case ConstantMaximum:
8628 return getBackedgeTakenInfo(L).getConstantMax(this);
8629 case SymbolicMaximum:
8630 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8631 };
8632 llvm_unreachable("Invalid ExitCountKind!");
8633}
8634
8637 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8638}
8639
8642 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8643}
8644
8646 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8647}
8648
8649/// Push PHI nodes in the header of the given loop onto the given Worklist.
8650static void PushLoopPHIs(const Loop *L,
8653 BasicBlock *Header = L->getHeader();
8654
8655 // Push all Loop-header PHIs onto the Worklist stack.
8656 for (PHINode &PN : Header->phis())
8657 if (Visited.insert(&PN).second)
8658 Worklist.push_back(&PN);
8659}
8660
8661ScalarEvolution::BackedgeTakenInfo &
8662ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8663 auto &BTI = getBackedgeTakenInfo(L);
8664 if (BTI.hasFullInfo())
8665 return BTI;
8666
8667 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8668
8669 if (!Pair.second)
8670 return Pair.first->second;
8671
8672 BackedgeTakenInfo Result =
8673 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8674
8675 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8676}
8677
// Compute (or fetch the cached) backedge-taken info for loop L, invalidating
// now-imprecise cached SCEVs for the loop's PHIs when new info is learned.
ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
  // Initially insert an invalid entry for this loop. If the insertion
  // succeeds, proceed to actually compute a backedge-taken count and
  // update the value. The temporary CouldNotCompute value tells SCEV
  // code elsewhere that it shouldn't attempt to request a new
  // backedge-taken count, which could result in infinite recursion.
  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
      BackedgeTakenCounts.try_emplace(L);
  if (!Pair.second)
    return Pair.first->second; // Cached (possibly in-flight) entry exists.

  // computeBackedgeTakenCount may allocate memory for its result. Inserting it
  // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
  // must be cleared in this scope.
  BackedgeTakenInfo Result = computeBackedgeTakenCount(L);

  // Now that we know more about the trip count for this loop, forget any
  // existing SCEV values for PHI nodes in this loop since they are only
  // conservative estimates made without the benefit of trip count
  // information. This invalidation is not necessary for correctness, and is
  // only done to produce more precise results.
  if (Result.hasAnyInfo()) {
    // Invalidate any expression using an addrec in this loop.
    SmallVector<SCEVUse, 8> ToForget;
    auto LoopUsersIt = LoopUsers.find(L);
    if (LoopUsersIt != LoopUsers.end())
      append_range(ToForget, LoopUsersIt->second);
    forgetMemoizedResults(ToForget);

    // Invalidate constant-evolved loop header phis.
    for (PHINode &PN : L->getHeader()->phis())
      ConstantEvolutionLoopExitValue.erase(&PN);
  }

  // Re-lookup the insert position, since the call to
  // computeBackedgeTakenCount above could result in a
  // recursive call to getBackedgeTakenInfo (on a different
  // loop), which would invalidate the iterator computed
  // earlier.
  return BackedgeTakenCounts.find(L)->second = std::move(Result);
}
8720
8722 // This method is intended to forget all info about loops. It should
8723 // invalidate caches as if the following happened:
8724 // - The trip counts of all loops have changed arbitrarily
8725 // - Every llvm::Value has been updated in place to produce a different
8726 // result.
8727 BackedgeTakenCounts.clear();
8728 PredicatedBackedgeTakenCounts.clear();
8729 BECountUsers.clear();
8730 LoopPropertiesCache.clear();
8731 ConstantEvolutionLoopExitValue.clear();
8732 ValueExprMap.clear();
8733 ValuesAtScopes.clear();
8734 ValuesAtScopesUsers.clear();
8735 LoopDispositions.clear();
8736 BlockDispositions.clear();
8737 UnsignedRanges.clear();
8738 SignedRanges.clear();
8739 ExprValueMap.clear();
8740 HasRecMap.clear();
8741 ConstantMultipleCache.clear();
8742 PredicatedSCEVRewrites.clear();
8743 FoldCache.clear();
8744 FoldCacheUser.clear();
8745}
8746void ScalarEvolution::visitAndClearUsers(
8749 SmallVectorImpl<SCEVUse> &ToForget) {
8750 while (!Worklist.empty()) {
8751 Instruction *I = Worklist.pop_back_val();
8752 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8753 continue;
8754
8756 ValueExprMap.find_as(static_cast<Value *>(I));
8757 if (It != ValueExprMap.end()) {
8758 eraseValueFromMap(It->first);
8759 ToForget.push_back(It->second);
8760 if (PHINode *PN = dyn_cast<PHINode>(I))
8761 ConstantEvolutionLoopExitValue.erase(PN);
8762 }
8763
8764 PushDefUseChildren(I, Worklist, Visited);
8765 }
8766}
8767
8769 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8772 SmallVector<SCEVUse, 16> ToForget;
8773
8774 // Iterate over all the loops and sub-loops to drop SCEV information.
8775 while (!LoopWorklist.empty()) {
8776 auto *CurrL = LoopWorklist.pop_back_val();
8777
8778 // Drop any stored trip count value.
8779 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8780 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8781
8782 // Drop information about predicated SCEV rewrites for this loop.
8783 for (auto I = PredicatedSCEVRewrites.begin();
8784 I != PredicatedSCEVRewrites.end();) {
8785 std::pair<const SCEV *, const Loop *> Entry = I->first;
8786 if (Entry.second == CurrL)
8787 PredicatedSCEVRewrites.erase(I++);
8788 else
8789 ++I;
8790 }
8791
8792 auto LoopUsersItr = LoopUsers.find(CurrL);
8793 if (LoopUsersItr != LoopUsers.end())
8794 llvm::append_range(ToForget, LoopUsersItr->second);
8795
8796 // Drop information about expressions based on loop-header PHIs.
8797 PushLoopPHIs(CurrL, Worklist, Visited);
8798 visitAndClearUsers(Worklist, Visited, ToForget);
8799
8800 LoopPropertiesCache.erase(CurrL);
8801 // Forget all contained loops too, to avoid dangling entries in the
8802 // ValuesAtScopes map.
8803 LoopWorklist.append(CurrL->begin(), CurrL->end());
8804 }
8805 forgetMemoizedResults(ToForget);
8806}
8807
8809 forgetLoop(L->getOutermostLoop());
8810}
8811
8814 if (!I) return;
8815
8816 // Drop information about expressions based on loop-header PHIs.
8819 SmallVector<SCEVUse, 8> ToForget;
8820 Worklist.push_back(I);
8821 Visited.insert(I);
8822 visitAndClearUsers(Worklist, Visited, ToForget);
8823
8824 forgetMemoizedResults(ToForget);
8825}
8826
8828 if (!isSCEVable(V->getType()))
8829 return;
8830
8831 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8832 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8833 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8834 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8835 if (const SCEV *S = getExistingSCEV(V)) {
8836 struct InvalidationRootCollector {
8837 Loop *L;
8839
8840 InvalidationRootCollector(Loop *L) : L(L) {}
8841
8842 bool follow(const SCEV *S) {
8843 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8844 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8845 if (L->contains(I))
8846 Roots.push_back(S);
8847 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8848 if (L->contains(AddRec->getLoop()))
8849 Roots.push_back(S);
8850 }
8851 return true;
8852 }
8853 bool isDone() const { return false; }
8854 };
8855
8856 InvalidationRootCollector C(L);
8857 visitAll(S, C);
8858 forgetMemoizedResults(C.Roots);
8859 }
8860
8861 // Also perform the normal invalidation.
8862 forgetValue(V);
8863}
8864
8865void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8866
8868 // Unless a specific value is passed to invalidation, completely clear both
8869 // caches.
8870 if (!V) {
8871 BlockDispositions.clear();
8872 LoopDispositions.clear();
8873 return;
8874 }
8875
8876 if (!isSCEVable(V->getType()))
8877 return;
8878
8879 const SCEV *S = getExistingSCEV(V);
8880 if (!S)
8881 return;
8882
8883 // Invalidate the block and loop dispositions cached for S. Dispositions of
8884 // S's users may change if S's disposition changes (i.e. a user may change to
8885 // loop-invariant, if S changes to loop invariant), so also invalidate
8886 // dispositions of S's users recursively.
8887 SmallVector<SCEVUse, 8> Worklist = {S};
8889 while (!Worklist.empty()) {
8890 const SCEV *Curr = Worklist.pop_back_val();
8891 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8892 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8893 if (!LoopDispoRemoved && !BlockDispoRemoved)
8894 continue;
8895 auto Users = SCEVUsers.find(Curr);
8896 if (Users != SCEVUsers.end())
8897 for (const auto *User : Users->second)
8898 if (Seen.insert(User).second)
8899 Worklist.push_back(User);
8900 }
8901}
8902
8903/// Get the exact loop backedge taken count considering all loop exits. A
8904/// computable result can only be returned for loops with all exiting blocks
8905/// dominating the latch. howFarToZero assumes that the limit of each loop test
8906/// is never skipped. This is a valid assumption as long as the loop exits via
8907/// that test. For precise results, it is the caller's responsibility to specify
8908/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8909const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8910 const Loop *L, ScalarEvolution *SE,
8912 // If any exits were not computable, the loop is not computable.
8913 if (!isComplete() || ExitNotTaken.empty())
8914 return SE->getCouldNotCompute();
8915
8916 const BasicBlock *Latch = L->getLoopLatch();
8917 // All exiting blocks we have collected must dominate the only backedge.
8918 if (!Latch)
8919 return SE->getCouldNotCompute();
8920
8921 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8922 // count is simply a minimum out of all these calculated exit counts.
8924 for (const auto &ENT : ExitNotTaken) {
8925 const SCEV *BECount = ENT.ExactNotTaken;
8926 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8927 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8928 "We should only have known counts for exiting blocks that dominate "
8929 "latch!");
8930
8931 Ops.push_back(BECount);
8932
8933 if (Preds)
8934 append_range(*Preds, ENT.Predicates);
8935
8936 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8937 "Predicate should be always true!");
8938 }
8939
8940 // If an earlier exit exits on the first iteration (exit count zero), then
  // a later poison exit count should not propagate into the result. These are
8942 // exactly the semantics provided by umin_seq.
8943 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8944}
8945
8946const ScalarEvolution::ExitNotTakenInfo *
8947ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8948 const BasicBlock *ExitingBlock,
8949 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8950 for (const auto &ENT : ExitNotTaken)
8951 if (ENT.ExitingBlock == ExitingBlock) {
8952 if (ENT.hasAlwaysTruePredicate())
8953 return &ENT;
8954 else if (Predicates) {
8955 append_range(*Predicates, ENT.Predicates);
8956 return &ENT;
8957 }
8958 }
8959
8960 return nullptr;
8961}
8962
8963/// getConstantMax - Get the constant max backedge taken count for the loop.
8964const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8965 ScalarEvolution *SE,
8966 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8967 if (!getConstantMax())
8968 return SE->getCouldNotCompute();
8969
8970 for (const auto &ENT : ExitNotTaken)
8971 if (!ENT.hasAlwaysTruePredicate()) {
8972 if (!Predicates)
8973 return SE->getCouldNotCompute();
8974 append_range(*Predicates, ENT.Predicates);
8975 }
8976
8977 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8978 isa<SCEVConstant>(getConstantMax())) &&
8979 "No point in having a non-constant max backedge taken count!");
8980 return getConstantMax();
8981}
8982
8983const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8984 const Loop *L, ScalarEvolution *SE,
8985 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8986 if (!SymbolicMax) {
8987 // Form an expression for the maximum exit count possible for this loop. We
8988 // merge the max and exact information to approximate a version of
8989 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8990 // constants.
8991 SmallVector<SCEVUse, 4> ExitCounts;
8992
8993 for (const auto &ENT : ExitNotTaken) {
8994 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8995 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8996 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8997 "We should only have known counts for exiting blocks that "
8998 "dominate latch!");
8999 ExitCounts.push_back(ExitCount);
9000 if (Predicates)
9001 append_range(*Predicates, ENT.Predicates);
9002
9003 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9004 "Predicate should be always true!");
9005 }
9006 }
9007 if (ExitCounts.empty())
9008 SymbolicMax = SE->getCouldNotCompute();
9009 else
9010 SymbolicMax =
9011 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9012 }
9013 return SymbolicMax;
9014}
9015
9016bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9017 ScalarEvolution *SE) const {
9018 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9019 return !ENT.hasAlwaysTruePredicate();
9020 };
9021 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9022}
9023
9026
// NOTE(review): the constructor's name line and the first parameter line(s)
// are not visible in this excerpt; what follows is the remainder of the
// ExitLimit constructor's parameter list and body — confirm against the
// full source.
9028 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9029 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
// (one or more parameter/member-initializer lines are elided here in this
// excerpt)
9033 // If we prove the max count is zero, so is the symbolic bound. This happens
9034 // in practice due to differences in a) how context sensitive we've chosen
9035 // to be and b) how we reason about bounds implied by UB.
9036 if (ConstantMaxNotTaken->isZero()) {
9037 this->ExactNotTaken = E = ConstantMaxNotTaken;
9038 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9039 }
9040
// The assertion messages below document the required precision ordering
// among the three counts; the assert condition expressions themselves are
// elided in this excerpt.
9043 "Exact is not allowed to be less precise than Constant Max");
9046 "Exact is not allowed to be less precise than Symbolic Max");
9049 "Symbolic Max is not allowed to be less precise than Constant Max");
9052 "No point in having a non-constant max backedge taken count!");
// Flatten the per-exit predicate lists into Predicates, de-duplicating via
// SeenPreds and accepting only leaf (non-union) predicates.
9054 for (const auto PredList : PredLists)
9055 for (const auto *P : PredList) {
9056 if (SeenPreds.contains(P))
9057 continue;
9058 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9059 SeenPreds.insert(P);
9060 Predicates.push_back(P);
9061 }
// Backedge-taken counts are integer-typed; pointers are never valid here.
9062 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9063 "Backedge count should be int");
9065 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9066 "Max backedge count should be int");
9067}
9068
9076
9077/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9078/// computable exit into a persistent ExitNotTakenInfo array.
9079ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
// NOTE(review): the first parameter line (line 9080, presumably the
// ExitCounts container of EdgeExitInfo pairs consumed below) is elided in
// this excerpt — confirm against the class declaration.
9081 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9082 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9083 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9084
// Convert each (exiting block, ExitLimit) pair into a persistent
// ExitNotTakenInfo entry.
9085 ExitNotTaken.reserve(ExitCounts.size());
9086 std::transform(ExitCounts.begin(), ExitCounts.end(),
9087 std::back_inserter(ExitNotTaken),
9088 [&](const EdgeExitInfo &EEI) {
9089 BasicBlock *ExitBB = EEI.first;
9090 const ExitLimit &EL = EEI.second;
9091 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9092 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9093 EL.Predicates);
9094 });
9095 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9096 isa<SCEVConstant>(ConstantMax)) &&
9097 "No point in having a non-constant max backedge taken count!");
9098}
9099
9100/// Compute the number of times the backedge of the specified loop will execute.
9101ScalarEvolution::BackedgeTakenInfo
9102ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9103 bool AllowPredicates) {
9104 SmallVector<BasicBlock *, 8> ExitingBlocks;
9105 L->getExitingBlocks(ExitingBlocks);
9106
9107 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9108
// NOTE(review): the declaration of the ExitCounts vector (line 9109,
// populated via emplace_back below) is elided in this excerpt.
9110 bool CouldComputeBECount = true;
9111 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9112 const SCEV *MustExitMaxBECount = nullptr;
9113 const SCEV *MayExitMaxBECount = nullptr;
9114 bool MustExitMaxOrZero = false;
9115 bool IsOnlyExit = ExitingBlocks.size() == 1;
9116
9117 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9118 // and compute maxBECount.
9119 // Do a union of all the predicates here.
9120 for (BasicBlock *ExitBB : ExitingBlocks) {
9121 // We canonicalize untaken exits to br (constant), ignore them so that
9122 // proving an exit untaken doesn't negatively impact our ability to reason
9123 // about the loop as whole.
9124 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9125 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9126 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
// The exit is provably never taken; skip it rather than letting it
// pessimize the whole-loop result.
9127 if (ExitIfTrue == CI->isZero())
9128 continue;
9129 }
9130
9131 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9132
9133 assert((AllowPredicates || EL.Predicates.empty()) &&
9134 "Predicated exit limit when predicates are not allowed!");
9135
9136 // 1. For each exit that can be computed, add an entry to ExitCounts.
9137 // CouldComputeBECount is true only if all exits can be computed.
9138 if (EL.ExactNotTaken != getCouldNotCompute())
9139 ++NumExitCountsComputed;
9140 else
9141 // We couldn't compute an exact value for this exit, so
9142 // we won't be able to compute an exact value for the loop.
9143 CouldComputeBECount = false;
9144 // Remember exit count if either exact or symbolic is known. Because
9145 // Exact always implies symbolic, only check symbolic.
9146 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9147 ExitCounts.emplace_back(ExitBB, EL);
9148 else {
9149 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9150 "Exact is known but symbolic isn't?");
9151 ++NumExitCountsNotComputed;
9152 }
9153
9154 // 2. Derive the loop's MaxBECount from each exit's max number of
9155 // non-exiting iterations. Partition the loop exits into two kinds:
9156 // LoopMustExits and LoopMayExits.
9157 //
9158 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9159 // is a LoopMayExit. If any computable LoopMustExit is found, then
9160 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9161 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9162 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9163 // any
9164 // computable EL.ConstantMaxNotTaken.
9165 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9166 DT.dominates(ExitBB, Latch)) {
9167 if (!MustExitMaxBECount) {
9168 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9169 MustExitMaxOrZero = EL.MaxOrZero;
9170 } else {
// Must-exits combine with umin: the loop stops at the earliest one.
9171 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9172 EL.ConstantMaxNotTaken);
9173 }
9174 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9175 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9176 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9177 else {
// May-exits combine with umax: any of them might be the last to fire.
9178 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9179 EL.ConstantMaxNotTaken);
9180 }
9181 }
9182 }
9183 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9184 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9185 // The loop backedge will be taken the maximum or zero times if there's
9186 // a single exit that must be taken the maximum or zero times.
9187 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9188
9189 // Remember which SCEVs are used in exit limits for invalidation purposes.
9190 // We only care about non-constant SCEVs here, so we can ignore
9191 // EL.ConstantMaxNotTaken
9192 // and MaxBECount, which must be SCEVConstant.
9193 for (const auto &Pair : ExitCounts) {
9194 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9195 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9196 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9197 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9198 {L, AllowPredicates});
9199 }
9200 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9201 MaxBECount, MaxOrZero);
9202}
9203
9204ScalarEvolution::ExitLimit
9205ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9206 bool IsOnlyExit, bool AllowPredicates) {
9207 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9208 // If our exiting block does not dominate the latch, then its connection with
9209 // loop's exit limit may be far from trivial.
9210 const BasicBlock *Latch = L->getLoopLatch();
9211 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9212 return getCouldNotCompute();
9213
9214 Instruction *Term = ExitingBlock->getTerminator();
9215 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9216 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9217 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9218 "It should have one successor in loop and one exit block!");
9219 // Proceed to the next level to examine the exit condition expression.
9220 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9221 /*ControlsOnlyExit=*/IsOnlyExit,
9222 AllowPredicates);
9223 }
9224
9225 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9226 // For switch, make sure that there is a single exit from the loop.
9227 BasicBlock *Exit = nullptr;
9228 for (auto *SBB : successors(ExitingBlock))
9229 if (!L->contains(SBB)) {
9230 if (Exit) // Multiple exit successors.
9231 return getCouldNotCompute();
9232 Exit = SBB;
9233 }
9234 assert(Exit && "Exiting block must have at least one exit");
9235 return computeExitLimitFromSingleExitSwitch(
9236 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9237 }
9238
9239 return getCouldNotCompute();
9240}
9241
// NOTE(review): the signature's first line (return type and function name,
// line 9242) is elided in this excerpt; the body constructs a fresh
// per-condition cache and delegates to the cached variant.
9243 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9244 bool AllowPredicates) {
// The cache key fixes (L, ExitIfTrue, AllowPredicates); entries vary only
// in (ExitCond, ControlsOnlyExit).
9245 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9246 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9247 ControlsOnlyExit, AllowPredicates);
9248}
9249
9250std::optional<ScalarEvolution::ExitLimit>
9251ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9252 bool ExitIfTrue, bool ControlsOnlyExit,
9253 bool AllowPredicates) {
9254 (void)this->L;
9255 (void)this->ExitIfTrue;
9256 (void)this->AllowPredicates;
9257
9258 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9259 this->AllowPredicates == AllowPredicates &&
9260 "Variance in assumed invariant key components!");
9261 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9262 if (Itr == TripCountMap.end())
9263 return std::nullopt;
9264 return Itr->second;
9265}
9266
9267void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9268 bool ExitIfTrue,
9269 bool ControlsOnlyExit,
9270 bool AllowPredicates,
9271 const ExitLimit &EL) {
9272 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9273 this->AllowPredicates == AllowPredicates &&
9274 "Variance in assumed invariant key components!");
9275
9276 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9277 assert(InsertResult.second && "Expected successful insertion!");
9278 (void)InsertResult;
9279 (void)ExitIfTrue;
9280}
9281
9282ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9283 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9284 bool ControlsOnlyExit, bool AllowPredicates) {
9285
9286 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9287 AllowPredicates))
9288 return *MaybeEL;
9289
9290 ExitLimit EL = computeExitLimitFromCondImpl(
9291 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9292 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9293 return EL;
9294}
9295
9296ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9297 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9298 bool ControlsOnlyExit, bool AllowPredicates) {
9299 // Handle BinOp conditions (And, Or).
9300 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9301 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9302 return *LimitFromBinOp;
9303
9304 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9305 // Proceed to the next level to examine the icmp.
9306 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9307 ExitLimit EL =
9308 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9309 if (EL.hasFullInfo() || !AllowPredicates)
9310 return EL;
9311
9312 // Try again, but use SCEV predicates this time.
9313 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9314 ControlsOnlyExit,
9315 /*AllowPredicates=*/true);
9316 }
9317
9318 // Check for a constant condition. These are normally stripped out by
9319 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9320 // preserve the CFG and is temporarily leaving constant conditions
9321 // in place.
9322 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9323 if (ExitIfTrue == !CI->getZExtValue())
9324 // The backedge is always taken.
9325 return getCouldNotCompute();
9326 // The backedge is never taken.
9327 return getZero(CI->getType());
9328 }
9329
9330 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9331 // with a constant step, we can form an equivalent icmp predicate and figure
9332 // out how many iterations will be taken before we exit.
9333 const WithOverflowInst *WO;
9334 const APInt *C;
9335 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9336 match(WO->getRHS(), m_APInt(C))) {
// NOTE(review): the first half of this initializer (line 9338, presumably a
// ConstantRange::makeExactNoWrapRegion-style call on *C) is elided in this
// excerpt — confirm against the full source.
9337 ConstantRange NWR =
9339 WO->getNoWrapKind());
9340 CmpInst::Predicate Pred;
9341 APInt NewRHSC, Offset;
// Translate the no-wrap region into an equivalent icmp (Pred, NewRHSC)
// with an additive Offset on the LHS.
9342 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9343 if (!ExitIfTrue)
9344 Pred = ICmpInst::getInversePredicate(Pred);
9345 auto *LHS = getSCEV(WO->getLHS());
// NOTE(review): the statement applying the nonzero Offset to LHS
// (line 9347) is elided in this excerpt.
9346 if (Offset != 0)
9348 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9349 ControlsOnlyExit, AllowPredicates);
9350 if (EL.hasAnyInfo())
9351 return EL;
9352 }
9353
9354 // If it's not an integer or pointer comparison then compute it the hard way.
9355 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9356}
9357
// Decompose an exit condition of the form (and/or Op0 Op1) — including the
// select-based "logical" forms — into the exit limits of its operands and
// merge them. Returns std::nullopt when ExitCond is not such a binop.
9358std::optional<ScalarEvolution::ExitLimit>
9359ScalarEvolution::computeExitLimitFromCondFromBinOp(
9360 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9361 bool ControlsOnlyExit, bool AllowPredicates) {
9362 // Check if the controlling expression for this loop is an And or Or.
9363 Value *Op0, *Op1;
9364 bool IsAnd;
9365 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9366 IsAnd = true;
9367 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9368 IsAnd = false;
9369 else
9370 return std::nullopt;
9371
9372 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement".
9373 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9374 if (Op0 == NeutralElement)
9375 std::swap(Op0, Op1);
// If one operand is the neutral element, the condition reduces to the other
// operand alone.
9376 if (Op1 == NeutralElement)
9377 return computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
9378 ControlsOnlyExit, AllowPredicates);
9379
9380 // A sub-condition of a non-trivial binop never solely controls the exit,
9381 // whether we exit always depends on both conditions.
9382 ExitLimit EL0 = computeExitLimitFromCondCached(
9383 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9384 ExitLimit EL1 = computeExitLimitFromCondCached(
9385 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9386
9387 // EitherMayExit is true in these two cases:
9388 // br (and Op0 Op1), loop, exit
9389 // br (or Op0 Op1), exit, loop
9390 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9391
9392 const SCEV *BECount = getCouldNotCompute();
9393 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9394 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9395 if (EitherMayExit) {
// When ExitCond is a select-form logical op rather than a plain
// BinaryOperator, use the sequential (poison-safe) umin — presumably
// because select-form and/or must not propagate poison from the second
// arm; confirm against LangRef.
9396 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9397 // Both conditions must be same for the loop to continue executing.
9398 // Choose the less conservative count.
9399 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9400 EL1.ExactNotTaken != getCouldNotCompute()) {
9401 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9402 UseSequentialUMin);
9403 }
// Merge each max bound with umin, treating CouldNotCompute on one side as
// "use the other side".
9404 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9405 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9406 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9407 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9408 else
9409 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9410 EL1.ConstantMaxNotTaken);
9411 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9412 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9413 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9414 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9415 else
9416 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9417 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9418 } else {
9419 // Both conditions must be same at the same time for the loop to exit.
9420 // For now, be conservative.
9421 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9422 BECount = EL0.ExactNotTaken;
9423 }
9424
9425 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9426 // to be more aggressive when computing BECount than when computing
9427 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9428 // and
9429 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9430 // EL1.ConstantMaxNotTaken to not.
9431 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9432 !isa<SCEVCouldNotCompute>(BECount))
9433 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
// Fall back for the symbolic bound: prefer the exact count, else the
// constant bound derived above.
9434 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9435 SymbolicMaxBECount =
9436 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9437 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9438 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9439}
9440
9441ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9442 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9443 bool AllowPredicates) {
9444 // If the condition was exit on true, convert the condition to exit on false
9445 CmpPredicate Pred;
9446 if (!ExitIfTrue)
9447 Pred = ExitCond->getCmpPredicate();
9448 else
9449 Pred = ExitCond->getInverseCmpPredicate();
9450 const ICmpInst::Predicate OriginalPred = Pred;
9451
9452 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9453 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9454
9455 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9456 AllowPredicates);
9457 if (EL.hasAnyInfo())
9458 return EL;
9459
9460 auto *ExhaustiveCount =
9461 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9462
9463 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9464 return ExhaustiveCount;
9465
9466 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9467 ExitCond->getOperand(1), L, OriginalPred);
9468}
// SCEV-level worker for computeExitLimitFromICmp: normalizes the operands,
// opportunistically strengthens addrec wrap flags for conditions that control
// a finite loop, then dispatches on the predicate to the howFarTo* /
// howMany* engines.
9469ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9470 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9471 bool ControlsOnlyExit, bool AllowPredicates) {
9472
9473 // Try to evaluate any dependencies out of the loop.
9474 LHS = getSCEVAtScope(LHS, L);
9475 RHS = getSCEVAtScope(RHS, L);
9476
9477 // At this point, we would like to compute how many iterations of the
9478 // loop the predicate will return true for these inputs.
9479 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9480 // If there is a loop-invariant, force it into the RHS.
9481 std::swap(LHS, RHS);
// NOTE(review): the predicate-swap statement (line 9482, presumably
// Pred = ICmpInst::getSwappedCmpPredicate(Pred)) is elided in this excerpt.
9483 }
9484
// NOTE(review): the second half of this initializer (line 9486, the final
// finiteness conjunct) is elided in this excerpt.
9485 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9487 // Simplify the operands before analyzing them.
9488 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9489
9490 // If we have a comparison of a chrec against a constant, try to use value
9491 // ranges to answer this query.
9492 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9493 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9494 if (AddRec->getLoop() == L) {
9495 // Form the constant range.
9496 ConstantRange CompRange =
9497 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9498
9499 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9500 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9501 }
9502
9503 // If this loop must exit based on this condition (or execute undefined
9504 // behaviour), see if we can improve wrap flags. This is essentially
9505 // a must execute style proof.
9506 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9507 // If we can prove the test sequence produced must repeat the same values
9508 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9509 // because if it did, we'd have an infinite (undefined) loop.
9510 // TODO: We can peel off any functions which are invertible *in L*. Loop
9511 // invariant terms are effectively constants for our purposes here.
9512 SCEVUse InnerLHS = LHS;
9513 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9514 InnerLHS = ZExt->getOperand();
9515 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9516 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9517 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9518 /*OrNegative=*/true)) {
// Record the newly proven no-self-wrap (NW) fact on the addrec.
9519 auto Flags = AR->getNoWrapFlags();
9520 Flags = setFlags(Flags, SCEV::FlagNW);
9521 SmallVector<SCEVUse> Operands{AR->operands()};
9522 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9523 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9524 }
9525
9526 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9527 // From no-self-wrap, this follows trivially from the fact that every
9528 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9529 // last value before (un)signed wrap. Since we know that last value
9530 // didn't exit, nor will any smaller one.
9531 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9532 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9533 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9534 AR && AR->getLoop() == L && AR->isAffine() &&
9535 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9536 isKnownPositive(AR->getStepRecurrence(*this))) {
9537 auto Flags = AR->getNoWrapFlags();
9538 Flags = setFlags(Flags, WrapType);
9539 SmallVector<SCEVUse> Operands{AR->operands()};
9540 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9541 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9542 }
9543 }
9544 }
9545
9546 switch (Pred) {
9547 case ICmpInst::ICMP_NE: { // while (X != Y)
9548 // Convert to: while (X-Y != 0)
// NOTE(review): the bodies of the pointer-type guards below (lines
// 9550-9551 and 9555-9556, presumably a lossless ptr-to-int conversion
// whose failure is returned) are elided in this excerpt.
9549 if (LHS->getType()->isPointerTy()) {
9552 return LHS;
9553 }
9554 if (RHS->getType()->isPointerTy()) {
9557 return RHS;
9558 }
9559 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9560 AllowPredicates);
9561 if (EL.hasAnyInfo())
9562 return EL;
9563 break;
9564 }
9565 case ICmpInst::ICMP_EQ: { // while (X == Y)
9566 // Convert to: while (X-Y == 0)
// NOTE(review): pointer-guard bodies (lines 9568-9569 and 9573-9574) are
// elided in this excerpt, as in the ICMP_NE case above.
9567 if (LHS->getType()->isPointerTy()) {
9570 return LHS;
9571 }
9572 if (RHS->getType()->isPointerTy()) {
9575 return RHS;
9576 }
9577 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9578 if (EL.hasAnyInfo()) return EL;
9579 break;
9580 }
9581 case ICmpInst::ICMP_SLE:
9582 case ICmpInst::ICMP_ULE:
9583 // Since the loop is finite, an invariant RHS cannot include the boundary
9584 // value, otherwise it would loop forever.
9585 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9586 !isLoopInvariant(RHS, L)) {
9587 // Otherwise, perform the addition in a wider type, to avoid overflow.
9588 // If the LHS is an addrec with the appropriate nowrap flag, the
9589 // extension will be sunk into it and the exit count can be analyzed.
9590 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9591 if (!OldType)
9592 break;
9593 // Prefer doubling the bitwidth over adding a single bit to make it more
9594 // likely that we use a legal type.
9595 auto *NewType =
9596 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9597 if (ICmpInst::isSigned(Pred)) {
9598 LHS = getSignExtendExpr(LHS, NewType);
9599 RHS = getSignExtendExpr(RHS, NewType);
9600 } else {
9601 LHS = getZeroExtendExpr(LHS, NewType);
9602 RHS = getZeroExtendExpr(RHS, NewType);
9603 }
9604 }
// NOTE(review): one line (9605, presumably the <=/>= to </> adjustment of
// RHS or Pred before falling through) is elided in this excerpt.
9606 [[fallthrough]];
9607 case ICmpInst::ICMP_SLT:
9608 case ICmpInst::ICMP_ULT: { // while (X < Y)
9609 bool IsSigned = ICmpInst::isSigned(Pred);
9610 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9611 AllowPredicates);
9612 if (EL.hasAnyInfo())
9613 return EL;
9614 break;
9615 }
9616 case ICmpInst::ICMP_SGE:
9617 case ICmpInst::ICMP_UGE:
9618 // Since the loop is finite, an invariant RHS cannot include the boundary
9619 // value, otherwise it would loop forever.
9620 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9621 !isLoopInvariant(RHS, L))
9622 break;
// NOTE(review): one line (9623, the mirror of the elided 9605 adjustment)
// is elided in this excerpt.
9624 [[fallthrough]];
9625 case ICmpInst::ICMP_SGT:
9626 case ICmpInst::ICMP_UGT: { // while (X > Y)
9627 bool IsSigned = ICmpInst::isSigned(Pred);
9628 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9629 AllowPredicates);
9630 if (EL.hasAnyInfo())
9631 return EL;
9632 break;
9633 }
9634 default:
9635 break;
9636 }
9637
9638 return getCouldNotCompute();
9639}
9640
9641ScalarEvolution::ExitLimit
9642ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9643 SwitchInst *Switch,
9644 BasicBlock *ExitingBlock,
9645 bool ControlsOnlyExit) {
9646 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9647
9648 // Give up if the exit is the default dest of a switch.
9649 if (Switch->getDefaultDest() == ExitingBlock)
9650 return getCouldNotCompute();
9651
9652 assert(L->contains(Switch->getDefaultDest()) &&
9653 "Default case must not exit the loop!");
9654 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9655 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9656
9657 // while (X != Y) --> while (X-Y != 0)
9658 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9659 if (EL.hasAnyInfo())
9660 return EL;
9661
9662 return getCouldNotCompute();
9663}
9664
// Evaluate an add-recurrence at the constant iteration number C and return
// the folded result as a ConstantInt.
// NOTE(review): the second signature line (function name and leading
// parameters, line 9666) and the assert's condition (line 9670) are elided
// in this excerpt — confirm against the full source.
9665static ConstantInt *
9667 ScalarEvolution &SE) {
9668 const SCEV *InVal = SE.getConstant(C);
9669 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9671 "Evaluation of SCEV at constant didn't fold correctly?");
9672 return cast<SCEVConstant>(Val)->getValue();
9673}
9674
// Recognize exit conditions comparing a shift recurrence (lshr/ashr/shl of a
// header PHI by a positive constant) against a constant. Such recurrences
// stabilize to 0 or -1 within bitwidth iterations, which bounds the backedge
// count when the stable value fails the exit comparison.
9675ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9676 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9677 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9678 if (!RHS)
9679 return getCouldNotCompute();
9680
9681 const BasicBlock *Latch = L->getLoopLatch();
9682 if (!Latch)
9683 return getCouldNotCompute();
9684
9685 const BasicBlock *Predecessor = L->getLoopPredecessor();
9686 if (!Predecessor)
9687 return getCouldNotCompute();
9688
9689 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9690 // Return LHS in OutLHS and shift_opt in OutOpCode.
9691 auto MatchPositiveShift =
9692 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9693
9694 using namespace PatternMatch;
9695
9696 ConstantInt *ShiftAmt;
9697 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9698 OutOpCode = Instruction::LShr;
9699 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9700 OutOpCode = Instruction::AShr;
9701 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9702 OutOpCode = Instruction::Shl;
9703 else
9704 return false;
9705
// A zero shift amount would not make progress toward the stable value.
9706 return ShiftAmt->getValue().isStrictlyPositive();
9707 };
9708
9709 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9710 //
9711 // loop:
9712 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9713 // %iv.shifted = lshr i32 %iv, <positive constant>
9714 //
9715 // Return true on a successful match. Return the corresponding PHI node (%iv
9716 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9717 auto MatchShiftRecurrence =
9718 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9719 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9720
9721 {
// NOTE(review): the declaration of OpC (line 9722, presumably
// Instruction::BinaryOps OpC) is elided in this excerpt.
9723 Value *V;
9724
9725 // If we encounter a shift instruction, "peel off" the shift operation,
9726 // and remember that we did so. Later when we inspect %iv's backedge
9727 // value, we will make sure that the backedge value uses the same
9728 // operation.
9729 //
9730 // Note: the peeled shift operation does not have to be the same
9731 // instruction as the one feeding into the PHI's backedge value. We only
9732 // really care about it being the same *kind* of shift instruction --
9733 // that's all that is required for our later inferences to hold.
9734 if (MatchPositiveShift(LHS, V, OpC)) {
9735 PostShiftOpCode = OpC;
9736 LHS = V;
9737 }
9738 }
9739
9740 PNOut = dyn_cast<PHINode>(LHS);
9741 if (!PNOut || PNOut->getParent() != L->getHeader())
9742 return false;
9743
9744 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9745 Value *OpLHS;
9746
9747 return
9748 // The backedge value for the PHI node must be a shift by a positive
9749 // amount
9750 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9751
9752 // of the PHI node itself
9753 OpLHS == PNOut &&
9754
9755 // and the kind of shift should be match the kind of shift we peeled
9756 // off, if any.
9757 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9758 };
9759
9760 PHINode *PN;
// NOTE(review): the declaration of OpCode (line 9761, presumably
// Instruction::BinaryOps OpCode) is elided in this excerpt.
9762 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9763 return getCouldNotCompute();
9764
9765 const DataLayout &DL = getDataLayout();
9766
9767 // The key rationale for this optimization is that for some kinds of shift
9768 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9769 // within a finite number of iterations. If the condition guarding the
9770 // backedge (in the sense that the backedge is taken if the condition is true)
9771 // is false for the value the shift recurrence stabilizes to, then we know
9772 // that the backedge is taken only a finite number of times.
9773
9774 ConstantInt *StableValue = nullptr;
9775 switch (OpCode) {
9776 default:
9777 llvm_unreachable("Impossible case!");
9778
9779 case Instruction::AShr: {
9780 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9781 // bitwidth(K) iterations.
9782 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9783 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9784 Predecessor->getTerminator(), &DT);
9785 auto *Ty = cast<IntegerType>(RHS->getType());
9786 if (Known.isNonNegative())
9787 StableValue = ConstantInt::get(Ty, 0);
9788 else if (Known.isNegative())
9789 StableValue = ConstantInt::get(Ty, -1, true);
9790 else
// Sign of the start value is unknown, so the stable value is unknown too.
9791 return getCouldNotCompute();
9792
9793 break;
9794 }
9795 case Instruction::LShr:
9796 case Instruction::Shl:
9797 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9798 // stabilize to 0 in at most bitwidth(K) iterations.
9799 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9800 break;
9801 }
9802
9803 auto *Result =
9804 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9805 assert(Result->getType()->isIntegerTy(1) &&
9806 "Otherwise cannot be an operand to a branch instruction");
9807
// If the comparison is false at the stable value, the backedge stops being
// taken once the recurrence stabilizes, i.e. within bitwidth iterations.
9808 if (Result->isNullValue()) {
9809 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
// NOTE(review): the initializer of UpperBound (line 9811, presumably a
// constant built from BitWidth) is elided in this excerpt.
9810 const SCEV *UpperBound =
9812 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9813 }
9814
9815 return getCouldNotCompute();
9816}
9817
9818/// Return true if we can constant fold an instruction of the specified type,
9819/// assuming that all operands were constants.
// NOTE(review): source lines 9821-9823 (the instruction-kind checks that guard
// the early "return true" below) are missing from this listing — verify
// against upstream ScalarEvolution.cpp.
9820static bool CanConstantFold(const Instruction *I) {
9824 return true;
9825
// A call is foldable only when the callee is statically known and is one of
// the calls canConstantFoldCallTo accepts.
9826 if (const CallInst *CI = dyn_cast<CallInst>(I))
9827 if (const Function *F = CI->getCalledFunction())
9828 return canConstantFoldCallTo(CI, F);
9829 return false;
9830}
9831
9832/// Determine whether this instruction can constant evolve within this loop
9833/// assuming its operands can all constant evolve.
9834static bool canConstantEvolve(Instruction *I, const Loop *L) {
9835 // An instruction outside of the loop can't be derived from a loop PHI.
9836 if (!L->contains(I)) return false;
9837
9838 if (isa<PHINode>(I)) {
9839 // We don't currently keep track of the control flow needed to evaluate
9840 // PHIs, so we cannot handle PHIs inside of loops.
// Only loop-header PHIs qualify; the brute-force evaluators below track
// their per-iteration values explicitly.
9841 return L->getHeader() == I->getParent();
9842 }
9843
9844 // If we won't be able to constant fold this expression even if the operands
9845 // are constants, bail early.
9846 return CanConstantFold(I);
9847}
9848
9849/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9850/// recursing through each instruction operand until reaching a loop header phi.
// NOTE(review): source lines 9852-9853 (the remaining parameters: the
// instruction, the loop, and the PHIMap memo table) and 9855 (the
// recursion-depth cutoff guarding the early nullptr return) are missing from
// this listing — verify against upstream ScalarEvolution.cpp.
9851static PHINode *
9854 unsigned Depth) {
9856 return nullptr;
9857
9858 // Otherwise, we can evaluate this instruction if all of its operands are
9859 // constant or derived from a PHI node themselves.
9860 PHINode *PHI = nullptr;
9861 for (Value *Op : UseInst->operands()) {
9862 if (isa<Constant>(Op)) continue;
9863
// NOTE(review): source line 9864 (the dyn_cast of Op to the local OpInst) is
// missing from this listing.
9865 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9866
9867 PHINode *P = dyn_cast<PHINode>(OpInst);
9868 if (!P)
9869 // If this operand is already visited, reuse the prior result.
9870 // We may have P != PHI if this is the deepest point at which the
9871 // inconsistent paths meet.
9872 P = PHIMap.lookup(OpInst);
9873 if (!P) {
9874 // Recurse and memoize the results, whether a phi is found or not.
9875 // This recursive call invalidates pointers into PHIMap.
9876 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9877 PHIMap[OpInst] = P;
9878 }
9879 if (!P)
9880 return nullptr; // Not evolving from PHI
// Every operand must trace back to the same single header PHI.
9881 if (PHI && PHI != P)
9882 return nullptr; // Evolving from multiple different PHIs.
9883 PHI = P;
9884 }
9885 // This is a expression evolving from a constant PHI!
9886 return PHI;
9887}
9888
9889/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9890/// in the loop that V is derived from. We allow arbitrary operations along the
9891/// way, but the operands of an operation must either be constants or a value
9892/// derived from a constant PHI. If this expression does not fit with these
9893/// constraints, return null.
// NOTE(review): source lines 9894-9895 (the signature taking V and L, with the
// dyn_cast producing I) and 9902 (the local PHIMap declaration passed to the
// recursion below) are missing from this listing.
9896 if (!I || !canConstantEvolve(I, L)) return nullptr;
9897
9898 if (PHINode *PN = dyn_cast<PHINode>(I))
9899 return PN;
9900
9901 // Record non-constant instructions contained by the loop.
9903 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9904}
9905
9906/// EvaluateExpression - Given an expression that passes the
9907/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9908/// in the loop has the value PHIVal. If we can't fold this expression for some
9909/// reason, return null.
// NOTE(review): source lines 9910-9911 (the signature: the value V, the loop,
// and the Vals memo map) and 9916 (the dyn_cast of V to Instruction *I) are
// missing from this listing — verify against upstream ScalarEvolution.cpp.
9912 const DataLayout &DL,
9913 const TargetLibraryInfo *TLI) {
9914 // Convenient constant check, but redundant for recursive calls.
9915 if (Constant *C = dyn_cast<Constant>(V)) return C;
9917 if (!I) return nullptr;
9918
9919 if (Constant *C = Vals.lookup(I)) return C;
9920
9921 // An instruction inside the loop depends on a value outside the loop that we
9922 // weren't given a mapping for, or a value such as a call inside the loop.
9923 if (!canConstantEvolve(I, L)) return nullptr;
9924
9925 // An unmapped PHI can be due to a branch or another loop inside this loop,
9926 // or due to this not being the initial iteration through a loop where we
9927 // couldn't compute the evolution of this particular PHI last time.
9928 if (isa<PHINode>(I)) return nullptr;
9929
9930 std::vector<Constant*> Operands(I->getNumOperands());
9931
9932 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9933 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9934 if (!Operand) {
// Non-instruction operands must already be constants.
9935 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9936 if (!Operands[i]) return nullptr;
9937 continue;
9938 }
9939 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
// Memoize before the null check so failed evaluations are cached too.
9940 Vals[Operand] = C;
9941 if (!C) return nullptr;
9942 Operands[i] = C;
9943 }
9944
9945 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9946 /*AllowNonDeterministic=*/false);
9947}
9948
9949
9950// If every incoming value to PN except the one for BB is a specific Constant,
9951// return that, else return nullptr.
// NOTE(review): source line 9952 (the signature taking PN and BB) is missing
// from this listing.
9953 Constant *IncomingVal = nullptr;
9954
9955 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9956 if (PN->getIncomingBlock(i) == BB)
9957 continue;
9958
9959 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9960 if (!CurrentVal)
9961 return nullptr;
9962
// All non-BB edges must agree on one constant value.
9963 if (IncomingVal != CurrentVal) {
9964 if (IncomingVal)
9965 return nullptr;
9966 IncomingVal = CurrentVal;
9967 }
9968 }
9969
9970 return IncomingVal;
9971}
9972
9973/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9974/// in the header of its containing loop, we know the loop executes a
9975/// constant number of times, and the PHI node is just a recurrence
9976/// involving constants, fold it.
// Results (including failures, as nullptr) are memoized per-PHI in
// ConstantEvolutionLoopExitValue.
// NOTE(review): source line 9985 (the guard that bails out before the early
// nullptr return below — by upstream convention a BEs-vs-
// MaxBruteForceIterations check) and line 10032 (the PHIsToCompute vector
// declaration used in the main loop) are missing from this listing.
9977Constant *
9978ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9979 const APInt &BEs,
9980 const Loop *L) {
9981 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9982 if (!Inserted)
9983 return I->second;
9984
9986 return nullptr; // Not going to evaluate it.
9987
9988 Constant *&RetVal = I->second;
9989
9990 DenseMap<Instruction *, Constant *> CurrentIterVals;
9991 BasicBlock *Header = L->getHeader();
9992 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9993
9994 BasicBlock *Latch = L->getLoopLatch();
9995 if (!Latch)
9996 return nullptr;
9997
// Seed iteration 0 with each header PHI's unique constant start value.
9998 for (PHINode &PHI : Header->phis()) {
9999 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10000 CurrentIterVals[&PHI] = StartCST;
10001 }
10002 if (!CurrentIterVals.count(PN))
10003 return RetVal = nullptr;
10004
10005 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10006
10007 // Execute the loop symbolically to determine the exit value.
10008 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10009 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10010
10011 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10012 unsigned IterationNum = 0;
10013 const DataLayout &DL = getDataLayout();
10014 for (; ; ++IterationNum) {
10015 if (IterationNum == NumIterations)
10016 return RetVal = CurrentIterVals[PN]; // Got exit value!
10017
10018 // Compute the value of the PHIs for the next iteration.
10019 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10020 DenseMap<Instruction *, Constant *> NextIterVals;
10021 Constant *NextPHI =
10022 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10023 if (!NextPHI)
10024 return nullptr; // Couldn't evaluate!
10025 NextIterVals[PN] = NextPHI;
10026
10027 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10028
10029 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10030 // cease to be able to evaluate one of them or if they stop evolving,
10031 // because that doesn't necessarily prevent us from computing PN.
10033 for (const auto &I : CurrentIterVals) {
10034 PHINode *PHI = dyn_cast<PHINode>(I.first);
10035 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10036 PHIsToCompute.emplace_back(PHI, I.second);
10037 }
10038 // We use two distinct loops because EvaluateExpression may invalidate any
10039 // iterators into CurrentIterVals.
10040 for (const auto &I : PHIsToCompute) {
10041 PHINode *PHI = I.first;
10042 Constant *&NextPHI = NextIterVals[PHI];
10043 if (!NextPHI) { // Not already computed.
10044 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10045 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10046 }
10047 if (NextPHI != I.second)
10048 StoppedEvolving = false;
10049 }
10050
10051 // If all entries in CurrentIterVals == NextIterVals then we can stop
10052 // iterating, the loop can't continue to change.
10053 if (StoppedEvolving)
10054 return RetVal = CurrentIterVals[PN];
10055
10056 CurrentIterVals.swap(NextIterVals);
10057 }
10058}
10059
/// Brute-force exit-count computation: symbolically execute the loop one
/// iteration at a time (bounded by MaxBruteForceIterations) until the exit
/// condition Cond evaluates to the constant ExitWhen, and return that
/// iteration number as an i32 SCEV constant; otherwise CouldNotCompute.
10060const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10061 Value *Cond,
10062 bool ExitWhen) {
10063 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10064 if (!PN) return getCouldNotCompute();
10065
10066 // If the loop is canonicalized, the PHI will have exactly two entries.
10067 // That's the only form we support here.
10068 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10069
10070 DenseMap<Instruction *, Constant *> CurrentIterVals;
10071 BasicBlock *Header = L->getHeader();
10072 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10073
10074 BasicBlock *Latch = L->getLoopLatch();
10075 assert(Latch && "Should follow from NumIncomingValues == 2!");
10076
// Seed iteration 0 with each header PHI's unique constant start value.
10077 for (PHINode &PHI : Header->phis()) {
10078 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10079 CurrentIterVals[&PHI] = StartCST;
10080 }
10081 if (!CurrentIterVals.count(PN))
10082 return getCouldNotCompute();
10083
10084 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10085 // the loop symbolically to determine when the condition gets a value of
10086 // "ExitWhen".
10087 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10088 const DataLayout &DL = getDataLayout();
10089 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10090 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10091 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10092
10093 // Couldn't symbolically evaluate.
10094 if (!CondVal) return getCouldNotCompute();
10095
10096 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10097 ++NumBruteForceTripCountsComputed;
10098 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10099 }
10100
10101 // Update all the PHI nodes for the next iteration.
10102 DenseMap<Instruction *, Constant *> NextIterVals;
10103
10104 // Create a list of which PHIs we need to compute. We want to do this before
10105 // calling EvaluateExpression on them because that may invalidate iterators
10106 // into CurrentIterVals.
10107 SmallVector<PHINode *, 8> PHIsToCompute;
10108 for (const auto &I : CurrentIterVals) {
10109 PHINode *PHI = dyn_cast<PHINode>(I.first);
10110 if (!PHI || PHI->getParent() != Header) continue;
10111 PHIsToCompute.push_back(PHI);
10112 }
10113 for (PHINode *PHI : PHIsToCompute) {
10114 Constant *&NextPHI = NextIterVals[PHI];
10115 if (NextPHI) continue; // Already computed!
10116
10117 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10118 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10119 }
10120 CurrentIterVals.swap(NextIterVals);
10121 }
10122
10123 // Too many iterations were needed to evaluate.
10124 return getCouldNotCompute();
10125}
10126
// Memoizing front end for computeSCEVAtScope: per-(V, L) results are cached
// in ValuesAtScopes, and non-constant results are also registered in
// ValuesAtScopesUsers for invalidation.
// NOTE(review): source line 10128 (the declaration binding "Values" as a
// reference to the ValuesAtScopes[V] entry) is missing from this listing.
10127const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10129 ValuesAtScopes[V];
10130 // Check to see if we've folded this expression at this loop before.
10131 for (auto &LS : Values)
10132 if (LS.first == L)
// A null cached value is the in-progress placeholder: return V unchanged.
10133 return LS.second ? LS.second : V;
10134
10135 Values.emplace_back(L, nullptr);
10136
10137 // Otherwise compute it.
10138 const SCEV *C = computeSCEVAtScope(V, L);
// Re-index ValuesAtScopes[V] here: the recursive computation above may have
// grown the map and invalidated the earlier reference.
10139 for (auto &LS : reverse(ValuesAtScopes[V]))
10140 if (LS.first == L) {
10141 LS.second = C;
10142 if (!isa<SCEVConstant>(C))
10143 ValuesAtScopesUsers[C].push_back({L, V});
10144 break;
10145 }
10146 return C;
10147}
10148
10149/// This builds up a Constant using the ConstantExpr interface. That way, we
10150/// will return Constants for objects which aren't represented by a
10151/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10152/// Returns NULL if the SCEV isn't representable as a Constant.
// NOTE(review): several source lines are missing from this listing: 10153
// (the signature taking the SCEV *V), 10164/10171/10178 (the local cast-expr
// declarations named P2I/ST used in the cast cases), 10187 (the recursive
// BuildConstantFromSCEV call producing OpC inside scAddExpr), and 10214
// (presumably the scSequentialUMinExpr case label before the final nullptr).
// Verify against upstream ScalarEvolution.cpp.
10154 switch (V->getSCEVType()) {
10155 case scCouldNotCompute:
10156 case scAddRecExpr:
10157 case scVScale:
10158 return nullptr;
10159 case scConstant:
10160 return cast<SCEVConstant>(V)->getValue();
10161 case scUnknown:
10162 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
10163 case scPtrToAddr: {
10165 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10166 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10167
10168 return nullptr;
10169 }
10170 case scPtrToInt: {
10172 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10173 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10174
10175 return nullptr;
10176 }
10177 case scTruncate: {
10179 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10180 return ConstantExpr::getTrunc(CastOp, ST->getType());
10181 return nullptr;
10182 }
10183 case scAddExpr: {
10184 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10185 Constant *C = nullptr;
10186 for (const SCEV *Op : SA->operands()) {
10188 if (!OpC)
10189 return nullptr;
10190 if (!C) {
10191 C = OpC;
10192 continue;
10193 }
// SCEV add canonicalization places at most one pointer operand last, so the
// running sum C must still be an integer here.
10194 assert(!C->getType()->isPointerTy() &&
10195 "Can only have one pointer, and it must be last");
10196 if (OpC->getType()->isPointerTy()) {
10197 // The offsets have been converted to bytes. We can add bytes using
10198 // an i8 GEP.
10199 C = ConstantExpr::getPtrAdd(OpC, C);
10200 } else {
10201 C = ConstantExpr::getAdd(C, OpC);
10202 }
10203 }
10204 return C;
10205 }
10206 case scMulExpr:
10207 case scSignExtend:
10208 case scZeroExtend:
10209 case scUDivExpr:
10210 case scSMaxExpr:
10211 case scUMaxExpr:
10212 case scSMinExpr:
10213 case scUMinExpr:
10215 return nullptr;
10216 }
10217 llvm_unreachable("Unknown SCEV kind!");
10218}
10219
// Rebuild the expression S with its operands replaced by NewOps, preserving
// the expression kind and (where applicable) its no-wrap flags. Leaf kinds
// (constant/vscale/unknown) have no operands and are returned unchanged.
// NOTE(review): source line 10244 (presumably the scSequentialUMinExpr case
// label preceding the getSequentialMinMaxExpr return) is missing from this
// listing.
10220const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10221 SmallVectorImpl<SCEVUse> &NewOps) {
10222 switch (S->getSCEVType()) {
10223 case scTruncate:
10224 case scZeroExtend:
10225 case scSignExtend:
10226 case scPtrToAddr:
10227 case scPtrToInt:
10228 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10229 case scAddRecExpr: {
10230 auto *AddRec = cast<SCEVAddRecExpr>(S);
10231 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10232 }
10233 case scAddExpr:
10234 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10235 case scMulExpr:
10236 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10237 case scUDivExpr:
10238 return getUDivExpr(NewOps[0], NewOps[1]);
10239 case scUMaxExpr:
10240 case scSMaxExpr:
10241 case scUMinExpr:
10242 case scSMinExpr:
10243 return getMinMaxExpr(S->getSCEVType(), NewOps);
10245 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10246 case scConstant:
10247 case scVScale:
10248 case scUnknown:
10249 return S;
10250 case scCouldNotCompute:
10251 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10252 }
10253 llvm_unreachable("Unknown SCEV kind!");
10254}
10255
// Worker for getSCEVAtScope: evaluate V in the scope of loop L, folding
// loop-variant operands and computing addrec/PHI exit values where possible.
// NOTE(review): several source lines are missing from this listing: 10275 and
// 10329 (the local NewOps SmallVector declarations), 10349 (the dyn_cast of
// SU->getValue() to Instruction *I in the scUnknown case), and 10431
// (presumably the BuildConstantFromSCEV(OpV) call producing C in the operand
// loop). Verify against upstream ScalarEvolution.cpp.
10256const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10257 switch (V->getSCEVType()) {
10258 case scConstant:
10259 case scVScale:
10260 return V;
10261 case scAddRecExpr: {
10262 // If this is a loop recurrence for a loop that does not contain L, then we
10263 // are dealing with the final value computed by the loop.
10264 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10265 // First, attempt to evaluate each operand.
10266 // Avoid performing the look-up in the common case where the specified
10267 // expression has no loop-variant portions.
10268 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10269 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10270 if (OpAtScope == AddRec->getOperand(i))
10271 continue;
10272
10273 // Okay, at least one of these operands is loop variant but might be
10274 // foldable. Build a new instance of the folded commutative expression.
10276 NewOps.reserve(AddRec->getNumOperands());
10277 append_range(NewOps, AddRec->operands().take_front(i));
10278 NewOps.push_back(OpAtScope);
10279 for (++i; i != e; ++i)
10280 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10281
10282 const SCEV *FoldedRec = getAddRecExpr(
10283 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10284 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10285 // The addrec may be folded to a nonrecurrence, for example, if the
10286 // induction variable is multiplied by zero after constant folding. Go
10287 // ahead and return the folded value.
10288 if (!AddRec)
10289 return FoldedRec;
10290 break;
10291 }
10292
10293 // If the scope is outside the addrec's loop, evaluate it by using the
10294 // loop exit value of the addrec.
10295 if (!AddRec->getLoop()->contains(L)) {
10296 // To evaluate this recurrence, we need to know how many times the AddRec
10297 // loop iterates. Compute this now.
10298 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10299 if (BackedgeTakenCount == getCouldNotCompute())
10300 return AddRec;
10301
10302 // Then, evaluate the AddRec.
10303 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10304 }
10305
10306 return AddRec;
10307 }
10308 case scTruncate:
10309 case scZeroExtend:
10310 case scSignExtend:
10311 case scPtrToAddr:
10312 case scPtrToInt:
10313 case scAddExpr:
10314 case scMulExpr:
10315 case scUDivExpr:
10316 case scUMaxExpr:
10317 case scSMaxExpr:
10318 case scUMinExpr:
10319 case scSMinExpr:
10320 case scSequentialUMinExpr: {
10321 ArrayRef<SCEVUse> Ops = V->operands();
10322 // Avoid performing the look-up in the common case where the specified
10323 // expression has no loop-variant portions.
10324 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10325 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10326 if (OpAtScope != Ops[i].getPointer()) {
10327 // Okay, at least one of these operands is loop variant but might be
10328 // foldable. Build a new instance of the folded commutative expression.
10330 NewOps.reserve(Ops.size());
10331 append_range(NewOps, Ops.take_front(i));
10332 NewOps.push_back(OpAtScope);
10333
10334 for (++i; i != e; ++i) {
10335 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10336 NewOps.push_back(OpAtScope);
10337 }
10338
10339 return getWithOperands(V, NewOps);
10340 }
10341 }
10342 // If we got here, all operands are loop invariant.
10343 return V;
10344 }
10345 case scUnknown: {
10346 // If this instruction is evolved from a constant-evolving PHI, compute the
10347 // exit value from the loop without using SCEVs.
10348 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10350 if (!I)
10351 return V; // This is some other type of SCEVUnknown, just return it.
10352
10353 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10354 const Loop *CurrLoop = this->LI[I->getParent()];
10355 // Looking for loop exit value.
10356 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10357 PN->getParent() == CurrLoop->getHeader()) {
10358 // Okay, there is no closed form solution for the PHI node. Check
10359 // to see if the loop that contains it has a known backedge-taken
10360 // count. If so, we may be able to force computation of the exit
10361 // value.
10362 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10363 // This trivial case can show up in some degenerate cases where
10364 // the incoming IR has not yet been fully simplified.
10365 if (BackedgeTakenCount->isZero()) {
10366 Value *InitValue = nullptr;
10367 bool MultipleInitValues = false;
// With a zero trip count the PHI's value is its (unique) init value from
// outside the loop.
10368 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10369 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10370 if (!InitValue)
10371 InitValue = PN->getIncomingValue(i);
10372 else if (InitValue != PN->getIncomingValue(i)) {
10373 MultipleInitValues = true;
10374 break;
10375 }
10376 }
10377 }
10378 if (!MultipleInitValues && InitValue)
10379 return getSCEV(InitValue);
10380 }
10381 // Do we have a loop invariant value flowing around the backedge
10382 // for a loop which must execute the backedge?
10383 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10384 isKnownNonZero(BackedgeTakenCount) &&
10385 PN->getNumIncomingValues() == 2) {
10386
10387 unsigned InLoopPred =
10388 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10389 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10390 if (CurrLoop->isLoopInvariant(BackedgeVal))
10391 return getSCEV(BackedgeVal);
10392 }
10393 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10394 // Okay, we know how many times the containing loop executes. If
10395 // this is a constant evolving PHI node, get the final value at
10396 // the specified iteration number.
10397 Constant *RV =
10398 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10399 if (RV)
10400 return getSCEV(RV);
10401 }
10402 }
10403 }
10404
10405 // Okay, this is an expression that we cannot symbolically evaluate
10406 // into a SCEV. Check to see if it's possible to symbolically evaluate
10407 // the arguments into constants, and if so, try to constant propagate the
10408 // result. This is particularly useful for computing loop exit values.
10409 if (!CanConstantFold(I))
10410 return V; // This is some other type of SCEVUnknown, just return it.
10411
10412 SmallVector<Constant *, 4> Operands;
10413 Operands.reserve(I->getNumOperands());
10414 bool MadeImprovement = false;
10415 for (Value *Op : I->operands()) {
10416 if (Constant *C = dyn_cast<Constant>(Op)) {
10417 Operands.push_back(C);
10418 continue;
10419 }
10420
10421 // If any of the operands is non-constant and if they are
10422 // non-integer and non-pointer, don't even try to analyze them
10423 // with scev techniques.
10424 if (!isSCEVable(Op->getType()))
10425 return V;
10426
10427 const SCEV *OrigV = getSCEV(Op);
10428 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10429 MadeImprovement |= OrigV != OpV;
10430
10432 if (!C)
10433 return V;
10434 assert(C->getType() == Op->getType() && "Type mismatch");
10435 Operands.push_back(C);
10436 }
10437
10438 // Check to see if getSCEVAtScope actually made an improvement.
10439 if (!MadeImprovement)
10440 return V; // This is some other type of SCEVUnknown, just return it.
10441
10442 Constant *C = nullptr;
10443 const DataLayout &DL = getDataLayout();
10444 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10445 /*AllowNonDeterministic=*/false);
10446 if (!C)
10447 return V;
10448 return getSCEV(C);
10449 }
10450 case scCouldNotCompute:
10451 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10452 }
10453 llvm_unreachable("Unknown SCEV type!");
10454}
10455
// Value-based overload: translate V to its SCEV, then evaluate at scope L.
// NOTE(review): source line 10456 (the signature taking Value *V and the Loop)
// is missing from this listing.
10457 return getSCEVAtScope(getSCEV(V), L);
10458}
10459
// Strip zext/sext wrappers (injective functions) off S, recursively.
// NOTE(review): source lines 10461 and 10463 (the dyn_cast guards that bind
// ZExt and SExt before each recursive return) are missing from this listing.
10460const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10462 return stripInjectiveFunctions(ZExt->getOperand());
10464 return stripInjectiveFunctions(SExt->getOperand());
10465 return S;
10466}
10467
10468/// Finds the minimum unsigned root of the following equation:
10469///
10470/// A * X = B (mod N)
10471///
10472/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10473/// A and B isn't important.
10474///
10475/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
// NOTE(review): source lines 10477-10478 (the function name and the leading
// parameters — the APInt A, the SCEV B, and the optional Predicates vector
// referenced below) are missing from this listing.
10476static const SCEV *
10479 ScalarEvolution &SE, const Loop *L) {
10480 uint32_t BW = A.getBitWidth();
10481 assert(BW == SE.getTypeSizeInBits(B->getType()));
10482 assert(A != 0 && "A must be non-zero.");
10483
10484 // 1. D = gcd(A, N)
10485 //
10486 // The gcd of A and N may have only one prime factor: 2. The number of
10487 // trailing zeros in A is its multiplicity
10488 uint32_t Mult2 = A.countr_zero();
10489 // D = 2^Mult2
10490
10491 // 2. Check if B is divisible by D.
10492 //
10493 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10494 // is not less than multiplicity of this prime factor for D.
10495 unsigned MinTZ = SE.getMinTrailingZeros(B);
10496 // Try again with the terminator of the loop predecessor for context-specific
10497 // result, if MinTZ is too small.
10498 if (MinTZ < Mult2 && L->getLoopPredecessor())
10499 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10500 if (MinTZ < Mult2) {
10501 // Check if we can prove there's no remainder using URem.
10502 const SCEV *URem =
10503 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10504 const SCEV *Zero = SE.getZero(B->getType());
10505 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10506 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10507 if (!Predicates)
10508 return SE.getCouldNotCompute();
10509
10510 // Avoid adding a predicate that is known to be false.
10511 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10512 return SE.getCouldNotCompute();
10513 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10514 }
10515 }
10516
10517 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10518 // modulo (N / D).
10519 //
10520 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10521 // (N / D) in general. The inverse itself always fits into BW bits, though,
10522 // so we immediately truncate it.
10523 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10524 APInt I = AD.multiplicativeInverse().zext(BW);
10525
10526 // 4. Compute the minimum unsigned root of the equation:
10527 // I * (B / D) mod (N / D)
10528 // To simplify the computation, we factor out the divide by D:
10529 // (I * B mod N) / D
10530 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10531 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10532}
10533
10534/// For a given quadratic addrec, generate coefficients of the corresponding
10535/// quadratic equation, multiplied by a common value to ensure that they are
10536/// integers.
10537/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10538/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10539/// were multiplied by, and BitWidth is the bit width of the original addrec
10540/// coefficients.
10541/// This function returns std::nullopt if the addrec coefficients are not
10542/// compile-time constants.
// NOTE(review): source line 10544 (the function name and its SCEVAddRecExpr
// parameter) is missing from this listing.
10543static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10545 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10546 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10547 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10548 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10549 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10550 << *AddRec << '\n');
10551
10552 // We currently can only solve this if the coefficients are constants.
10553 if (!LC || !MC || !NC) {
10554 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10555 return std::nullopt;
10556 }
10557
10558 APInt L = LC->getAPInt();
10559 APInt M = MC->getAPInt();
10560 APInt N = NC->getAPInt();
10561 assert(!N.isZero() && "This is not a quadratic addrec");
10562
10563 unsigned BitWidth = LC->getAPInt().getBitWidth();
10564 unsigned NewWidth = BitWidth + 1;
10565 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10566 << BitWidth << '\n');
10567 // The sign-extension (as opposed to a zero-extension) here matches the
10568 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10569 N = N.sext(NewWidth);
10570 M = M.sext(NewWidth);
10571 L = L.sext(NewWidth);
10572
10573 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10574 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10575 // L+M, L+2M+N, L+3M+3N, ...
10576 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10577 //
10578 // The equation Acc = 0 is then
10579 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10580 // In a quadratic form it becomes:
10581 // N n^2 + (2M-N) n + 2L = 0.
10582
10583 APInt A = N;
10584 APInt B = 2 * M - A;
10585 APInt C = 2 * L;
// T records the common multiplier (2) applied to all three coefficients.
10586 APInt T = APInt(NewWidth, 2);
10587 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10588 << "x + " << C << ", coeff bw: " << NewWidth
10589 << ", multiplied by " << T << '\n');
10590 return std::make_tuple(A, B, C, T, BitWidth);
10591}
10592
10593/// Helper function to compare optional APInts:
10594/// (a) if X and Y both exist, return min(X, Y),
10595/// (b) if neither X nor Y exist, return std::nullopt,
10596/// (c) if exactly one of X and Y exists, return that value.
10597static std::optional<APInt> MinOptional(std::optional<APInt> X,
10598 std::optional<APInt> Y) {
10599 if (X && Y) {
// Compare in a common signed width, but return the original (unwidened)
// operand so the caller's bit width is preserved.
10600 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10601 APInt XW = X->sext(W);
10602 APInt YW = Y->sext(W);
10603 return XW.slt(YW) ? *X : *Y;
10604 }
10605 if (!X && !Y)
10606 return std::nullopt;
10607 return X ? *X : *Y;
10608}
10609
10610/// Helper function to truncate an optional APInt to a given BitWidth.
10611/// When solving addrec-related equations, it is preferable to return a value
10612/// that has the same bit width as the original addrec's coefficients. If the
10613/// solution fits in the original bit width, truncate it (except for i1).
10614/// Returning a value of a different bit width may inhibit some optimizations.
10615///
10616/// In general, a solution to a quadratic equation generated from an addrec
10617/// may require BW+1 bits, where BW is the bit width of the addrec's
10618/// coefficients. The reason is that the coefficients of the quadratic
10619/// equation are BW+1 bits wide (to avoid truncation when converting from
10620/// the addrec to the equation).
10621static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10622 unsigned BitWidth) {
10623 if (!X)
10624 return std::nullopt;
10625 unsigned W = X->getBitWidth();
// NOTE(review): source line 10626 (the condition deciding when the value may
// be truncated to BitWidth — per the doc comment, when it fits and
// BitWidth > 1) is missing from this listing.
10627 return X->trunc(BitWidth);
10628 return X;
10629}
10630
10631/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10632/// iterations. The values L, M, N are assumed to be signed, and they
10633/// should all have the same bit widths.
10634/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10635/// where BW is the bit width of the addrec's coefficients.
10636/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10637/// returned as such, otherwise the bit width of the returned value may
10638/// be greater than BW.
10639///
10640/// This function returns std::nullopt if
10641/// (a) the addrec coefficients are not constant, or
10642/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10643/// like x^2 = 5, no integer solutions exist, in other cases an integer
10644/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
// NOTE(review): source lines 10646 (the function name and its AddRec/SE
// parameters) and 10656 (the APIntOps::SolveQuadraticEquationWrap call whose
// result initializes X) are missing from this listing.
10645static std::optional<APInt>
10647 APInt A, B, C, M;
10648 unsigned BitWidth;
10649 auto T = GetQuadraticEquation(AddRec);
10650 if (!T)
10651 return std::nullopt;
10652
10653 std::tie(A, B, C, M, BitWidth) = *T;
10654 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10655 std::optional<APInt> X =
10657 if (!X)
10658 return std::nullopt;
10659
// Re-evaluate the chrec at the candidate solution to reject spurious roots.
10660 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10661 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10662 if (!V->isZero())
10663 return std::nullopt;
10664
10665 return TruncIfPossible(X, BitWidth);
10667
10668/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10669/// iterations. The values M, N are assumed to be signed, and they
10670/// should all have the same bit widths.
10671/// Find the least n such that c(n) does not belong to the given range,
10672/// while c(n-1) does.
10673///
10674/// This function returns std::nullopt if
10675/// (a) the addrec coefficients are not constant, or
10676/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10677/// bounds of the range.
10678static std::optional<APInt>
10680 const ConstantRange &Range, ScalarEvolution &SE) {
10681 assert(AddRec->getOperand(0)->isZero() &&
10682 "Starting value of addrec should be 0");
10683 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10684 << Range << ", addrec " << *AddRec << '\n');
10685 // This case is handled in getNumIterationsInRange. Here we can assume that
10686 // we start in the range.
10687 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10688 "Addrec's initial value should be in range");
10689
10690 APInt A, B, C, M;
10691 unsigned BitWidth;
10692 auto T = GetQuadraticEquation(AddRec);
10693 if (!T)
10694 return std::nullopt;
10695
10696 // Be careful about the return value: there can be two reasons for not
10697 // returning an actual number. First, if no solutions to the equations
10698 // were found, and second, if the solutions don't leave the given range.
10699 // The first case means that the actual solution is "unknown", the second
10700 // means that it's known, but not valid. If the solution is unknown, we
10701 // cannot make any conclusions.
10702 // Return a pair: the optional solution and a flag indicating if the
10703 // solution was found.
10704 auto SolveForBoundary =
10705 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10706 // Solve for signed overflow and unsigned overflow, pick the lower
10707 // solution.
10708 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10709 << Bound << " (before multiplying by " << M << ")\n");
10710 Bound *= M; // The quadratic equation multiplier.
10711
10712 std::optional<APInt> SO;
10713 if (BitWidth > 1) {
10714 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10715 "signed overflow\n");
10717 }
10718 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10719 "unsigned overflow\n");
10720 std::optional<APInt> UO =
10722
10723 auto LeavesRange = [&] (const APInt &X) {
10724 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10725 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10726 if (Range.contains(V0->getValue()))
10727 return false;
10728 // X should be at least 1, so X-1 is non-negative.
10729 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10730 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10731 if (Range.contains(V1->getValue()))
10732 return true;
10733 return false;
10734 };
10735
10736 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10737 // can be a solution, but the function failed to find it. We cannot treat it
10738 // as "no solution".
10739 if (!SO || !UO)
10740 return {std::nullopt, false};
10741
10742 // Check the smaller value first to see if it leaves the range.
10743 // At this point, both SO and UO must have values.
10744 std::optional<APInt> Min = MinOptional(SO, UO);
10745 if (LeavesRange(*Min))
10746 return { Min, true };
10747 std::optional<APInt> Max = Min == SO ? UO : SO;
10748 if (LeavesRange(*Max))
10749 return { Max, true };
10750
10751 // Solutions were found, but were eliminated, hence the "true".
10752 return {std::nullopt, true};
10753 };
10754
10755 std::tie(A, B, C, M, BitWidth) = *T;
10756 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10757 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10758 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10759 auto SL = SolveForBoundary(Lower);
10760 auto SU = SolveForBoundary(Upper);
10761 // If any of the solutions was unknown, no meaninigful conclusions can
10762 // be made.
10763 if (!SL.second || !SU.second)
10764 return std::nullopt;
10765
10766 // Claim: The correct solution is not some value between Min and Max.
10767 //
10768 // Justification: Assuming that Min and Max are different values, one of
10769 // them is when the first signed overflow happens, the other is when the
10770 // first unsigned overflow happens. Crossing the range boundary is only
10771 // possible via an overflow (treating 0 as a special case of it, modeling
10772 // an overflow as crossing k*2^W for some k).
10773 //
10774 // The interesting case here is when Min was eliminated as an invalid
10775 // solution, but Max was not. The argument is that if there was another
10776 // overflow between Min and Max, it would also have been eliminated if
10777 // it was considered.
10778 //
10779 // For a given boundary, it is possible to have two overflows of the same
10780 // type (signed/unsigned) without having the other type in between: this
10781 // can happen when the vertex of the parabola is between the iterations
10782 // corresponding to the overflows. This is only possible when the two
10783 // overflows cross k*2^W for the same k. In such case, if the second one
10784 // left the range (and was the first one to do so), the first overflow
10785 // would have to enter the range, which would mean that either we had left
10786 // the range before or that we started outside of it. Both of these cases
10787 // are contradictions.
10788 //
10789 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10790 // solution is not some value between the Max for this boundary and the
10791 // Min of the other boundary.
10792 //
10793 // Justification: Assume that we had such Max_A and Min_B corresponding
10794 // to range boundaries A and B and such that Max_A < Min_B. If there was
10795 // a solution between Max_A and Min_B, it would have to be caused by an
10796 // overflow corresponding to either A or B. It cannot correspond to B,
10797 // since Min_B is the first occurrence of such an overflow. If it
10798 // corresponded to A, it would have to be either a signed or an unsigned
10799 // overflow that is larger than both eliminated overflows for A. But
10800 // between the eliminated overflows and this overflow, the values would
10801 // cover the entire value space, thus crossing the other boundary, which
10802 // is a contradiction.
10803
10804 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10805}
10806
10807ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10808 const Loop *L,
10809 bool ControlsOnlyExit,
10810 bool AllowPredicates) {
10811
10812 // This is only used for loops with a "x != y" exit test. The exit condition
10813 // is now expressed as a single expression, V = x-y. So the exit test is
10814 // effectively V != 0. We know and take advantage of the fact that this
10815 // expression only being used in a comparison by zero context.
10816
10818 // If the value is a constant
10819 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10820 // If the value is already zero, the branch will execute zero times.
10821 if (C->getValue()->isZero()) return C;
10822 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10823 }
10824
10825 const SCEVAddRecExpr *AddRec =
10826 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10827
10828 if (!AddRec && AllowPredicates)
10829 // Try to make this an AddRec using runtime tests, in the first X
10830 // iterations of this loop, where X is the SCEV expression found by the
10831 // algorithm below.
10832 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10833
10834 if (!AddRec || AddRec->getLoop() != L)
10835 return getCouldNotCompute();
10836
10837 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10838 // the quadratic equation to solve it.
10839 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10840 // We can only use this value if the chrec ends up with an exact zero
10841 // value at this index. When solving for "X*X != 5", for example, we
10842 // should not accept a root of 2.
10843 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10844 const auto *R = cast<SCEVConstant>(getConstant(*S));
10845 return ExitLimit(R, R, R, false, Predicates);
10846 }
10847 return getCouldNotCompute();
10848 }
10849
10850 // Otherwise we can only handle this if it is affine.
10851 if (!AddRec->isAffine())
10852 return getCouldNotCompute();
10853
10854 // If this is an affine expression, the execution count of this branch is
10855 // the minimum unsigned root of the following equation:
10856 //
10857 // Start + Step*N = 0 (mod 2^BW)
10858 //
10859 // equivalent to:
10860 //
10861 // Step*N = -Start (mod 2^BW)
10862 //
10863 // where BW is the common bit width of Start and Step.
10864
10865 // Get the initial value for the loop.
10866 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10867 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10868
10869 if (!isLoopInvariant(Step, L))
10870 return getCouldNotCompute();
10871
10872 LoopGuards Guards = LoopGuards::collect(L, *this);
10873 // Specialize step for this loop so we get context sensitive facts below.
10874 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10875
10876 // For positive steps (counting up until unsigned overflow):
10877 // N = -Start/Step (as unsigned)
10878 // For negative steps (counting down to zero):
10879 // N = Start/-Step
10880 // First compute the unsigned distance from zero in the direction of Step.
10881 bool CountDown = isKnownNegative(StepWLG);
10882 if (!CountDown && !isKnownNonNegative(StepWLG))
10883 return getCouldNotCompute();
10884
10885 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10886 // Handle unitary steps, which cannot wraparound.
10887 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10888 // N = Distance (as unsigned)
10889
10890 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10891 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10892 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10893
10894 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10895 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10896 // case, and see if we can improve the bound.
10897 //
10898 // Explicitly handling this here is necessary because getUnsignedRange
10899 // isn't context-sensitive; it doesn't know that we only care about the
10900 // range inside the loop.
10901 const SCEV *Zero = getZero(Distance->getType());
10902 const SCEV *One = getOne(Distance->getType());
10903 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10904 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10905 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10906 // as "unsigned_max(Distance + 1) - 1".
10907 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10908 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10909 }
10910 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10911 Predicates);
10912 }
10913
10914 // If the condition controls loop exit (the loop exits only if the expression
10915 // is true) and the addition is no-wrap we can use unsigned divide to
10916 // compute the backedge count. In this case, the step may not divide the
10917 // distance, but we don't care because if the condition is "missed" the loop
10918 // will have undefined behavior due to wrapping.
10919 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10920 loopHasNoAbnormalExits(AddRec->getLoop())) {
10921
10922 // If the stride is zero and the start is non-zero, the loop must be
10923 // infinite. In C++, most loops are finite by assumption, in which case the
10924 // step being zero implies UB must execute if the loop is entered.
10925 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10926 !isKnownNonZero(StepWLG))
10927 return getCouldNotCompute();
10928
10929 const SCEV *Exact =
10930 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10931 const SCEV *ConstantMax = getCouldNotCompute();
10932 if (Exact != getCouldNotCompute()) {
10933 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10934 ConstantMax =
10936 }
10937 const SCEV *SymbolicMax =
10938 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10939 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10940 }
10941
10942 // Solve the general equation.
10943 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10944 if (!StepC || StepC->getValue()->isZero())
10945 return getCouldNotCompute();
10946 const SCEV *E = SolveLinEquationWithOverflow(
10947 StepC->getAPInt(), getNegativeSCEV(Start),
10948 AllowPredicates ? &Predicates : nullptr, *this, L);
10949
10950 const SCEV *M = E;
10951 if (E != getCouldNotCompute()) {
10952 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10953 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10954 }
10955 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10956 return ExitLimit(E, M, S, false, Predicates);
10957}
10958
10959ScalarEvolution::ExitLimit
10960ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10961 // Loops that look like: while (X == 0) are very strange indeed. We don't
10962 // handle them yet except for the trivial case. This could be expanded in the
10963 // future as needed.
10964
10965 // If the value is a constant, check to see if it is known to be non-zero
10966 // already. If so, the backedge will execute zero times.
10967 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10968 if (!C->getValue()->isZero())
10969 return getZero(C->getType());
10970 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10971 }
10972
10973 // We could implement others, but I really doubt anyone writes loops like
10974 // this, and if they did, they would already be constant folded.
10975 return getCouldNotCompute();
10976}
10977
10978std::pair<const BasicBlock *, const BasicBlock *>
10979ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10980 const {
10981 // If the block has a unique predecessor, then there is no path from the
10982 // predecessor to the block that does not go through the direct edge
10983 // from the predecessor to the block.
10984 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10985 return {Pred, BB};
10986
10987 // A loop's header is defined to be a block that dominates the loop.
10988 // If the header has a unique predecessor outside the loop, it must be
10989 // a block that has exactly one successor that can reach the loop.
10990 if (const Loop *L = LI.getLoopFor(BB))
10991 return {L->getLoopPredecessor(), L->getHeader()};
10992
10993 return {nullptr, BB};
10994}
10995
10996/// SCEV structural equivalence is usually sufficient for testing whether two
10997/// expressions are equal, however for the purposes of looking for a condition
10998/// guarding a loop, it can be useful to be a little more general, since a
10999/// front-end may have replicated the controlling expression.
11000static bool HasSameValue(const SCEV *A, const SCEV *B) {
11001 // Quick check to see if they are the same SCEV.
11002 if (A == B) return true;
11003
11004 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11005 // Not all instructions that are "identical" compute the same value. For
11006 // instance, two distinct alloca instructions allocating the same type are
11007 // identical and do not read memory; but compute distinct values.
11008 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11009 };
11010
11011 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11012 // two different instructions with the same value. Check for this case.
11013 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11014 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11015 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11016 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11017 if (ComputesEqualValues(AI, BI))
11018 return true;
11019
11020 // Otherwise assume they may have a different value.
11021 return false;
11022}
11023
11024static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11025 const SCEV *Op0, *Op1;
11026 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11027 return false;
11028 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11029 LHS = Op1;
11030 return true;
11031 }
11032 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11033 LHS = Op0;
11034 return true;
11035 }
11036 return false;
11037}
11038
11040 SCEVUse &RHS, unsigned Depth) {
11041 bool Changed = false;
11042 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11043 // '0 != 0'.
11044 auto TrivialCase = [&](bool TriviallyTrue) {
11046 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11047 return true;
11048 };
11049 // If we hit the max recursion limit bail out.
11050 if (Depth >= 3)
11051 return false;
11052
11053 const SCEV *NewLHS, *NewRHS;
11054 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11055 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11056 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11057 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11058
11059 // (X * vscale) pred (Y * vscale) ==> X pred Y
11060 // when both multiples are NSW.
11061 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11062 // when both multiples are NUW.
11063 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11064 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11065 !ICmpInst::isSigned(Pred))) {
11066 LHS = NewLHS;
11067 RHS = NewRHS;
11068 Changed = true;
11069 }
11070 }
11071
11072 // Canonicalize a constant to the right side.
11073 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11074 // Check for both operands constant.
11075 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11076 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11077 return TrivialCase(false);
11078 return TrivialCase(true);
11079 }
11080 // Otherwise swap the operands to put the constant on the right.
11081 std::swap(LHS, RHS);
11083 Changed = true;
11084 }
11085
11086 // If we're comparing an addrec with a value which is loop-invariant in the
11087 // addrec's loop, put the addrec on the left. Also make a dominance check,
11088 // as both operands could be addrecs loop-invariant in each other's loop.
11089 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11090 const Loop *L = AR->getLoop();
11091 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11092 std::swap(LHS, RHS);
11094 Changed = true;
11095 }
11096 }
11097
11098 // If there's a constant operand, canonicalize comparisons with boundary
11099 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11100 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11101 const APInt &RA = RC->getAPInt();
11102
11103 bool SimplifiedByConstantRange = false;
11104
11105 if (!ICmpInst::isEquality(Pred)) {
11107 if (ExactCR.isFullSet())
11108 return TrivialCase(true);
11109 if (ExactCR.isEmptySet())
11110 return TrivialCase(false);
11111
11112 APInt NewRHS;
11113 CmpInst::Predicate NewPred;
11114 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11115 ICmpInst::isEquality(NewPred)) {
11116 // We were able to convert an inequality to an equality.
11117 Pred = NewPred;
11118 RHS = getConstant(NewRHS);
11119 Changed = SimplifiedByConstantRange = true;
11120 }
11121 }
11122
11123 if (!SimplifiedByConstantRange) {
11124 switch (Pred) {
11125 default:
11126 break;
11127 case ICmpInst::ICMP_EQ:
11128 case ICmpInst::ICMP_NE:
11129 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11130 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11131 Changed = true;
11132 break;
11133
11134 // The "Should have been caught earlier!" messages refer to the fact
11135 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11136 // should have fired on the corresponding cases, and canonicalized the
11137 // check to trivial case.
11138
11139 case ICmpInst::ICMP_UGE:
11140 assert(!RA.isMinValue() && "Should have been caught earlier!");
11141 Pred = ICmpInst::ICMP_UGT;
11142 RHS = getConstant(RA - 1);
11143 Changed = true;
11144 break;
11145 case ICmpInst::ICMP_ULE:
11146 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11147 Pred = ICmpInst::ICMP_ULT;
11148 RHS = getConstant(RA + 1);
11149 Changed = true;
11150 break;
11151 case ICmpInst::ICMP_SGE:
11152 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11153 Pred = ICmpInst::ICMP_SGT;
11154 RHS = getConstant(RA - 1);
11155 Changed = true;
11156 break;
11157 case ICmpInst::ICMP_SLE:
11158 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11159 Pred = ICmpInst::ICMP_SLT;
11160 RHS = getConstant(RA + 1);
11161 Changed = true;
11162 break;
11163 }
11164 }
11165 }
11166
11167 // Check for obvious equality.
11168 if (HasSameValue(LHS, RHS)) {
11169 if (ICmpInst::isTrueWhenEqual(Pred))
11170 return TrivialCase(true);
11172 return TrivialCase(false);
11173 }
11174
11175 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11176 // adding or subtracting 1 from one of the operands.
11177 switch (Pred) {
11178 case ICmpInst::ICMP_SLE:
11179 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11180 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11182 Pred = ICmpInst::ICMP_SLT;
11183 Changed = true;
11184 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11185 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11187 Pred = ICmpInst::ICMP_SLT;
11188 Changed = true;
11189 }
11190 break;
11191 case ICmpInst::ICMP_SGE:
11192 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11193 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11195 Pred = ICmpInst::ICMP_SGT;
11196 Changed = true;
11197 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11198 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11200 Pred = ICmpInst::ICMP_SGT;
11201 Changed = true;
11202 }
11203 break;
11204 case ICmpInst::ICMP_ULE:
11205 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11206 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11208 Pred = ICmpInst::ICMP_ULT;
11209 Changed = true;
11210 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11211 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11212 Pred = ICmpInst::ICMP_ULT;
11213 Changed = true;
11214 }
11215 break;
11216 case ICmpInst::ICMP_UGE:
11217 // If RHS is an op we can fold the -1, try that first.
11218 // Otherwise prefer LHS to preserve the nuw flag.
11219 if ((isa<SCEVConstant>(RHS) ||
11221 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11222 !getUnsignedRangeMin(RHS).isMinValue()) {
11223 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11224 Pred = ICmpInst::ICMP_UGT;
11225 Changed = true;
11226 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11227 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11229 Pred = ICmpInst::ICMP_UGT;
11230 Changed = true;
11231 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11232 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11233 Pred = ICmpInst::ICMP_UGT;
11234 Changed = true;
11235 }
11236 break;
11237 default:
11238 break;
11239 }
11240
11241 // TODO: More simplifications are possible here.
11242
11243 // Recursively simplify until we either hit a recursion limit or nothing
11244 // changes.
11245 if (Changed)
11246 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11247
11248 return Changed;
11249}
11250
11252 return getSignedRangeMax(S).isNegative();
11253}
11254
11258
11260 return !getSignedRangeMin(S).isNegative();
11261}
11262
11266
11268 // Query push down for cases where the unsigned range is
11269 // less than sufficient.
11270 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11271 return isKnownNonZero(SExt->getOperand(0));
11272 return getUnsignedRangeMin(S) != 0;
11273}
11274
11276 bool OrNegative) {
11277 auto NonRecursive = [OrNegative](const SCEV *S) {
11278 if (auto *C = dyn_cast<SCEVConstant>(S))
11279 return C->getAPInt().isPowerOf2() ||
11280 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11281
11282 // vscale is a power-of-two.
11283 return isa<SCEVVScale>(S);
11284 };
11285
11286 if (NonRecursive(S))
11287 return true;
11288
11289 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11290 if (!Mul)
11291 return false;
11292 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11293}
11294
11296 const SCEV *S, uint64_t M,
11298 if (M == 0)
11299 return false;
11300 if (M == 1)
11301 return true;
11302
11303 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11304 // starts with a multiple of M and at every iteration step S only adds
11305 // multiples of M.
11306 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11307 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11308 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11309
11310 // For a constant, check that "S % M == 0".
11311 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11312 APInt C = Cst->getAPInt();
11313 return C.urem(M) == 0;
11314 }
11315
11316 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11317
11318 // Basic tests have failed.
11319 // Check "S % M == 0" at compile time and record runtime Assumptions.
11320 auto *STy = dyn_cast<IntegerType>(S->getType());
11321 const SCEV *SmodM =
11322 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11323 const SCEV *Zero = getZero(STy);
11324
11325 // Check whether "S % M == 0" is known at compile time.
11326 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11327 return true;
11328
11329 // Check whether "S % M != 0" is known at compile time.
11330 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11331 return false;
11332
11334
11335 // Detect redundant predicates.
11336 for (auto *A : Assumptions)
11337 if (A->implies(P, *this))
11338 return true;
11339
11340 // Only record non-redundant predicates.
11341 Assumptions.push_back(P);
11342 return true;
11343}
11344
11346 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11348}
11349
11350std::pair<const SCEV *, const SCEV *>
11352 // Compute SCEV on entry of loop L.
11353 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11354 if (Start == getCouldNotCompute())
11355 return { Start, Start };
11356 // Compute post increment SCEV for loop L.
11357 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11358 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11359 return { Start, PostInc };
11360}
11361
11363 SCEVUse RHS) {
11364 // First collect all loops.
11366 getUsedLoops(LHS, LoopsUsed);
11367 getUsedLoops(RHS, LoopsUsed);
11368
11369 if (LoopsUsed.empty())
11370 return false;
11371
11372 // Domination relationship must be a linear order on collected loops.
11373#ifndef NDEBUG
11374 for (const auto *L1 : LoopsUsed)
11375 for (const auto *L2 : LoopsUsed)
11376 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11377 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11378 "Domination relationship is not a linear order");
11379#endif
11380
11381 const Loop *MDL =
11382 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11383 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11384 });
11385
11386 // Get init and post increment value for LHS.
11387 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11388 // if LHS contains unknown non-invariant SCEV then bail out.
11389 if (SplitLHS.first == getCouldNotCompute())
11390 return false;
11391 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11392 // Get init and post increment value for RHS.
11393 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11394 // if RHS contains unknown non-invariant SCEV then bail out.
11395 if (SplitRHS.first == getCouldNotCompute())
11396 return false;
11397 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11398 // It is possible that init SCEV contains an invariant load but it does
11399 // not dominate MDL and is not available at MDL loop entry, so we should
11400 // check it here.
11401 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11402 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11403 return false;
11404
11405 // It seems backedge guard check is faster than entry one so in some cases
11406 // it can speed up whole estimation by short circuit
11407 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11408 SplitRHS.second) &&
11409 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11410}
11411
11413 SCEVUse RHS) {
11414 // Canonicalize the inputs first.
11415 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11416
11417 if (isKnownViaInduction(Pred, LHS, RHS))
11418 return true;
11419
11420 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11421 return true;
11422
11423 // Otherwise see what can be done with some simple reasoning.
11424 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11425}
11426
11428 const SCEV *LHS,
11429 const SCEV *RHS) {
11430 if (isKnownPredicate(Pred, LHS, RHS))
11431 return true;
11433 return false;
11434 return std::nullopt;
11435}
11436
11438 const SCEV *RHS,
11439 const Instruction *CtxI) {
11440 // TODO: Analyze guards and assumes from Context's block.
11441 return isKnownPredicate(Pred, LHS, RHS) ||
11442 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11443}
11444
11445std::optional<bool>
11447 const SCEV *RHS, const Instruction *CtxI) {
11448 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11449 if (KnownWithoutContext)
11450 return KnownWithoutContext;
11451
11452 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11453 return true;
11455 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11456 return false;
11457 return std::nullopt;
11458}
11459
11461 const SCEVAddRecExpr *LHS,
11462 const SCEV *RHS) {
11463 const Loop *L = LHS->getLoop();
11464 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11465 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11466}
11467
11468std::optional<ScalarEvolution::MonotonicPredicateType>
11470 ICmpInst::Predicate Pred) {
11471 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11472
11473#ifndef NDEBUG
11474 // Verify an invariant: inverting the predicate should turn a monotonically
11475 // increasing change to a monotonically decreasing one, and vice versa.
11476 if (Result) {
11477 auto ResultSwapped =
11478 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11479
11480 assert(*ResultSwapped != *Result &&
11481 "monotonicity should flip as we flip the predicate");
11482 }
11483#endif
11484
11485 return Result;
11486}
11487
11488std::optional<ScalarEvolution::MonotonicPredicateType>
11489ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11490 ICmpInst::Predicate Pred) {
11491 // A zero step value for LHS means the induction variable is essentially a
11492 // loop invariant value. We don't really depend on the predicate actually
11493 // flipping from false to true (for increasing predicates, and the other way
11494 // around for decreasing predicates), all we care about is that *if* the
11495 // predicate changes then it only changes from false to true.
11496 //
11497 // A zero step value in itself is not very useful, but there may be places
11498 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11499 // as general as possible.
11500
11501 // Only handle LE/LT/GE/GT predicates.
11502 if (!ICmpInst::isRelational(Pred))
11503 return std::nullopt;
11504
11505 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11506 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11507 "Should be greater or less!");
11508
11509 // Check that AR does not wrap.
11510 if (ICmpInst::isUnsigned(Pred)) {
11511 if (!LHS->hasNoUnsignedWrap())
11512 return std::nullopt;
11514 }
11515 assert(ICmpInst::isSigned(Pred) &&
11516 "Relational predicate is either signed or unsigned!");
11517 if (!LHS->hasNoSignedWrap())
11518 return std::nullopt;
11519
11520 const SCEV *Step = LHS->getStepRecurrence(*this);
11521
11522 if (isKnownNonNegative(Step))
11524
11525 if (isKnownNonPositive(Step))
11527
11528 return std::nullopt;
11529}
11530
11531std::optional<ScalarEvolution::LoopInvariantPredicate>
11533 const SCEV *RHS, const Loop *L,
11534 const Instruction *CtxI) {
11535 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11536 if (!isLoopInvariant(RHS, L)) {
11537 if (!isLoopInvariant(LHS, L))
11538 return std::nullopt;
11539
11540 std::swap(LHS, RHS);
11542 }
11543
11544 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11545 if (!ArLHS || ArLHS->getLoop() != L)
11546 return std::nullopt;
11547
11548 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11549 if (!MonotonicType)
11550 return std::nullopt;
11551 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11552 // true as the loop iterates, and the backedge is control dependent on
11553 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11554 //
11555 // * if the predicate was false in the first iteration then the predicate
11556 // is never evaluated again, since the loop exits without taking the
11557 // backedge.
11558 // * if the predicate was true in the first iteration then it will
11559 // continue to be true for all future iterations since it is
11560 // monotonically increasing.
11561 //
11562 // For both the above possibilities, we can replace the loop varying
11563 // predicate with its value on the first iteration of the loop (which is
11564 // loop invariant).
11565 //
11566 // A similar reasoning applies for a monotonically decreasing predicate, by
11567 // replacing true with false and false with true in the above two bullets.
11569 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11570
11571 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11573 RHS);
11574
11575 if (!CtxI)
11576 return std::nullopt;
11577 // Try to prove via context.
11578 // TODO: Support other cases.
11579 switch (Pred) {
11580 default:
11581 break;
11582 case ICmpInst::ICMP_ULE:
11583 case ICmpInst::ICMP_ULT: {
11584 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11585 // Given preconditions
11586 // (1) ArLHS does not cross the border of positive and negative parts of
11587 // range because of:
11588 // - Positive step; (TODO: lift this limitation)
11589 // - nuw - does not cross zero boundary;
11590 // - nsw - does not cross SINT_MAX boundary;
11591 // (2) ArLHS <s RHS
11592 // (3) RHS >=s 0
11593 // we can replace the loop variant ArLHS <u RHS condition with loop
11594 // invariant Start(ArLHS) <u RHS.
11595 //
11596 // Because of (1) there are two options:
11597 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11598 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11599 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11600 // Because of (2) ArLHS <u RHS is trivially true.
11601 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11602 // We can strengthen this to Start(ArLHS) <u RHS.
11603 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11604 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11605 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11606 isKnownNonNegative(RHS) &&
11607 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11609 RHS);
11610 }
11611 }
11612
11613 return std::nullopt;
11614}
11615
11616std::optional<ScalarEvolution::LoopInvariantPredicate>
11618 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11619 const Instruction *CtxI, const SCEV *MaxIter) {
11621 Pred, LHS, RHS, L, CtxI, MaxIter))
11622 return LIP;
11623 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11624 // Number of iterations expressed as UMIN isn't always great for expressing
11625 // the value on the last iteration. If the straightforward approach didn't
11626 // work, try the following trick: if the a predicate is invariant for X, it
11627 // is also invariant for umin(X, ...). So try to find something that works
11628 // among subexpressions of MaxIter expressed as umin.
11629 for (SCEVUse Op : UMin->operands())
11631 Pred, LHS, RHS, L, CtxI, Op))
11632 return LIP;
11633 return std::nullopt;
11634}
11635
11636std::optional<ScalarEvolution::LoopInvariantPredicate>
11638 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11639 const Instruction *CtxI, const SCEV *MaxIter) {
11640 // Try to prove the following set of facts:
11641 // - The predicate is monotonic in the iteration space.
11642 // - If the check does not fail on the 1st iteration:
11643 // - No overflow will happen during first MaxIter iterations;
11644 // - It will not fail on the MaxIter'th iteration.
11645 // If the check does fail on the 1st iteration, we leave the loop and no
11646 // other checks matter.
11647
11648 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11649 if (!isLoopInvariant(RHS, L)) {
11650 if (!isLoopInvariant(LHS, L))
11651 return std::nullopt;
11652
11653 std::swap(LHS, RHS);
11655 }
11656
11657 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11658 if (!AR || AR->getLoop() != L)
11659 return std::nullopt;
11660
11661 // Even if both are valid, we need to consistently chose the unsigned or the
11662 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11663 // predicate.
11664 Pred = Pred.dropSameSign();
11665
11666 // The predicate must be relational (i.e. <, <=, >=, >).
11667 if (!ICmpInst::isRelational(Pred))
11668 return std::nullopt;
11669
11670 // TODO: Support steps other than +/- 1.
11671 const SCEV *Step = AR->getStepRecurrence(*this);
11672 auto *One = getOne(Step->getType());
11673 auto *MinusOne = getNegativeSCEV(One);
11674 if (Step != One && Step != MinusOne)
11675 return std::nullopt;
11676
11677 // Type mismatch here means that MaxIter is potentially larger than max
11678 // unsigned value in start type, which mean we cannot prove no wrap for the
11679 // indvar.
11680 if (AR->getType() != MaxIter->getType())
11681 return std::nullopt;
11682
11683 // Value of IV on suggested last iteration.
11684 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11685 // Does it still meet the requirement?
11686 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11687 return std::nullopt;
11688 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11689 // not exceed max unsigned value of this type), this effectively proves
11690 // that there is no wrap during the iteration. To prove that there is no
11691 // signed/unsigned wrap, we need to check that
11692 // Start <= Last for step = 1 or Start >= Last for step = -1.
11693 ICmpInst::Predicate NoOverflowPred =
11695 if (Step == MinusOne)
11696 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11697 const SCEV *Start = AR->getStart();
11698 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11699 return std::nullopt;
11700
11701 // Everything is fine.
11702 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11703}
11704
11705bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11706 SCEVUse LHS,
11707 SCEVUse RHS) {
11708 if (HasSameValue(LHS, RHS))
11709 return ICmpInst::isTrueWhenEqual(Pred);
11710
11711 auto CheckRange = [&](bool IsSigned) {
11712 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11713 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11714 return RangeLHS.icmp(Pred, RangeRHS);
11715 };
11716
11717 // The check at the top of the function catches the case where the values are
11718 // known to be equal.
11719 if (Pred == CmpInst::ICMP_EQ)
11720 return false;
11721
11722 if (Pred == CmpInst::ICMP_NE) {
11723 if (CheckRange(true) || CheckRange(false))
11724 return true;
11725 auto *Diff = getMinusSCEV(LHS, RHS);
11726 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11727 }
11728
11729 return CheckRange(CmpInst::isSigned(Pred));
11730}
11731
11732bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11734 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11735 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11736 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11737 // OutC1 and OutC2.
11738 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11739 APInt &OutC2,
11740 SCEV::NoWrapFlags ExpectedFlags) {
11741 SCEVUse XNonConstOp, XConstOp;
11742 SCEVUse YNonConstOp, YConstOp;
11743 SCEV::NoWrapFlags XFlagsPresent;
11744 SCEV::NoWrapFlags YFlagsPresent;
11745
11746 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11747 XConstOp = getZero(X->getType());
11748 XNonConstOp = X;
11749 XFlagsPresent = ExpectedFlags;
11750 }
11751 if (!isa<SCEVConstant>(XConstOp))
11752 return false;
11753
11754 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11755 YConstOp = getZero(Y->getType());
11756 YNonConstOp = Y;
11757 YFlagsPresent = ExpectedFlags;
11758 }
11759
11760 if (YNonConstOp != XNonConstOp)
11761 return false;
11762
11763 if (!isa<SCEVConstant>(YConstOp))
11764 return false;
11765
11766 // When matching ADDs with NUW flags (and unsigned predicates), only the
11767 // second ADD (with the larger constant) requires NUW.
11768 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11769 return false;
11770 if (ExpectedFlags != SCEV::FlagNUW &&
11771 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11772 return false;
11773 }
11774
11775 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11776 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11777
11778 return true;
11779 };
11780
11781 APInt C1;
11782 APInt C2;
11783
11784 switch (Pred) {
11785 default:
11786 break;
11787
11788 case ICmpInst::ICMP_SGE:
11789 std::swap(LHS, RHS);
11790 [[fallthrough]];
11791 case ICmpInst::ICMP_SLE:
11792 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11793 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11794 return true;
11795
11796 break;
11797
11798 case ICmpInst::ICMP_SGT:
11799 std::swap(LHS, RHS);
11800 [[fallthrough]];
11801 case ICmpInst::ICMP_SLT:
11802 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11803 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11804 return true;
11805
11806 break;
11807
11808 case ICmpInst::ICMP_UGE:
11809 std::swap(LHS, RHS);
11810 [[fallthrough]];
11811 case ICmpInst::ICMP_ULE:
11812 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11813 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11814 return true;
11815
11816 break;
11817
11818 case ICmpInst::ICMP_UGT:
11819 std::swap(LHS, RHS);
11820 [[fallthrough]];
11821 case ICmpInst::ICMP_ULT:
11822 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11823 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11824 return true;
11825 break;
11826 }
11827
11828 return false;
11829}
11830
11831bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11833 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11834 return false;
11835
11836 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11837 // the stack can result in exponential time complexity.
11838 SaveAndRestore Restore(ProvingSplitPredicate, true);
11839
11840 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11841 //
11842 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11843 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11844 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11845 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11846 // use isKnownPredicate later if needed.
11847 return isKnownNonNegative(RHS) &&
11850}
11851
11852bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11853 const SCEV *LHS, const SCEV *RHS) {
11854 // No need to even try if we know the module has no guards.
11855 if (!HasGuards)
11856 return false;
11857
11858 return any_of(*BB, [&](const Instruction &I) {
11859 using namespace llvm::PatternMatch;
11860
11861 Value *Condition;
11863 m_Value(Condition))) &&
11864 isImpliedCond(Pred, LHS, RHS, Condition, false);
11865 });
11866}
11867
11868/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11869/// protected by a conditional between LHS and RHS. This is used to
11870/// to eliminate casts.
11872 CmpPredicate Pred,
11873 const SCEV *LHS,
11874 const SCEV *RHS) {
11875 // Interpret a null as meaning no loop, where there is obviously no guard
11876 // (interprocedural conditions notwithstanding). Do not bother about
11877 // unreachable loops.
11878 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11879 return true;
11880
11881 if (VerifyIR)
11882 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11883 "This cannot be done on broken IR!");
11884
11885
11886 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11887 return true;
11888
11889 BasicBlock *Latch = L->getLoopLatch();
11890 if (!Latch)
11891 return false;
11892
11893 CondBrInst *LoopContinuePredicate =
11895 if (LoopContinuePredicate &&
11896 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11897 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11898 return true;
11899
11900 // We don't want more than one activation of the following loops on the stack
11901 // -- that can lead to O(n!) time complexity.
11902 if (WalkingBEDominatingConds)
11903 return false;
11904
11905 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11906
11907 // See if we can exploit a trip count to prove the predicate.
11908 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11909 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11910 if (LatchBECount != getCouldNotCompute()) {
11911 // We know that Latch branches back to the loop header exactly
11912 // LatchBECount times. This means the backdege condition at Latch is
11913 // equivalent to "{0,+,1} u< LatchBECount".
11914 Type *Ty = LatchBECount->getType();
11915 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11916 const SCEV *LoopCounter =
11917 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11918 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11919 LatchBECount))
11920 return true;
11921 }
11922
11923 // Check conditions due to any @llvm.assume intrinsics.
11924 for (auto &AssumeVH : AC.assumptions()) {
11925 if (!AssumeVH)
11926 continue;
11927 auto *CI = cast<CallInst>(AssumeVH);
11928 if (!DT.dominates(CI, Latch->getTerminator()))
11929 continue;
11930
11931 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11932 return true;
11933 }
11934
11935 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11936 return true;
11937
11938 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11939 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11940 assert(DTN && "should reach the loop header before reaching the root!");
11941
11942 BasicBlock *BB = DTN->getBlock();
11943 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11944 return true;
11945
11946 BasicBlock *PBB = BB->getSinglePredecessor();
11947 if (!PBB)
11948 continue;
11949
11951 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11952 continue;
11953
11954 // If we have an edge `E` within the loop body that dominates the only
11955 // latch, the condition guarding `E` also guards the backedge. This
11956 // reasoning works only for loops with a single latch.
11957 // We're constructively (and conservatively) enumerating edges within the
11958 // loop body that dominate the latch. The dominator tree better agree
11959 // with us on this:
11960 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11961 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11962 BB != ContBr->getSuccessor(0)))
11963 return true;
11964 }
11965
11966 return false;
11967}
11968
11970 CmpPredicate Pred,
11971 const SCEV *LHS,
11972 const SCEV *RHS) {
11973 // Do not bother proving facts for unreachable code.
11974 if (!DT.isReachableFromEntry(BB))
11975 return true;
11976 if (VerifyIR)
11977 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11978 "This cannot be done on broken IR!");
11979
11980 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11981 // the facts (a >= b && a != b) separately. A typical situation is when the
11982 // non-strict comparison is known from ranges and non-equality is known from
11983 // dominating predicates. If we are proving strict comparison, we always try
11984 // to prove non-equality and non-strict comparison separately.
11985 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11986 const bool ProvingStrictComparison =
11987 Pred != NonStrictPredicate.dropSameSign();
11988 bool ProvedNonStrictComparison = false;
11989 bool ProvedNonEquality = false;
11990
11991 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11992 if (!ProvedNonStrictComparison)
11993 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11994 if (!ProvedNonEquality)
11995 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11996 if (ProvedNonStrictComparison && ProvedNonEquality)
11997 return true;
11998 return false;
11999 };
12000
12001 if (ProvingStrictComparison) {
12002 auto ProofFn = [&](CmpPredicate P) {
12003 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12004 };
12005 if (SplitAndProve(ProofFn))
12006 return true;
12007 }
12008
12009 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12010 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12011 const Instruction *CtxI = &BB->front();
12012 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12013 return true;
12014 if (ProvingStrictComparison) {
12015 auto ProofFn = [&](CmpPredicate P) {
12016 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12017 };
12018 if (SplitAndProve(ProofFn))
12019 return true;
12020 }
12021 return false;
12022 };
12023
12024 // Starting at the block's predecessor, climb up the predecessor chain, as long
12025 // as there are predecessors that can be found that have unique successors
12026 // leading to the original block.
12027 const Loop *ContainingLoop = LI.getLoopFor(BB);
12028 const BasicBlock *PredBB;
12029 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12030 PredBB = ContainingLoop->getLoopPredecessor();
12031 else
12032 PredBB = BB->getSinglePredecessor();
12033 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12034 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12035 const CondBrInst *BlockEntryPredicate =
12036 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12037 if (!BlockEntryPredicate)
12038 continue;
12039
12040 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12041 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12042 return true;
12043 }
12044
12045 // Check conditions due to any @llvm.assume intrinsics.
12046 for (auto &AssumeVH : AC.assumptions()) {
12047 if (!AssumeVH)
12048 continue;
12049 auto *CI = cast<CallInst>(AssumeVH);
12050 if (!DT.dominates(CI, BB))
12051 continue;
12052
12053 if (ProveViaCond(CI->getArgOperand(0), false))
12054 return true;
12055 }
12056
12057 // Check conditions due to any @llvm.experimental.guard intrinsics.
12058 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12059 F.getParent(), Intrinsic::experimental_guard);
12060 if (GuardDecl)
12061 for (const auto *GU : GuardDecl->users())
12062 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12063 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12064 if (ProveViaCond(Guard->getArgOperand(0), false))
12065 return true;
12066 return false;
12067}
12068
12070 const SCEV *LHS,
12071 const SCEV *RHS) {
12072 // Interpret a null as meaning no loop, where there is obviously no guard
12073 // (interprocedural conditions notwithstanding).
12074 if (!L)
12075 return false;
12076
12077 // Both LHS and RHS must be available at loop entry.
12079 "LHS is not available at Loop Entry");
12081 "RHS is not available at Loop Entry");
12082
12083 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12084 return true;
12085
12086 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12087}
12088
12089bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12090 const SCEV *RHS,
12091 const Value *FoundCondValue, bool Inverse,
12092 const Instruction *CtxI) {
12093 // False conditions implies anything. Do not bother analyzing it further.
12094 if (FoundCondValue ==
12095 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12096 return true;
12097
12098 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12099 return false;
12100
12101 llvm::scope_exit ClearOnExit(
12102 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12103
12104 // Recursively handle And and Or conditions.
12105 const Value *Op0, *Op1;
12106 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12107 if (!Inverse)
12108 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12109 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12110 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12111 if (Inverse)
12112 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12113 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12114 }
12115
12116 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12117 if (!ICI) return false;
12118
12119 // Now that we found a conditional branch that dominates the loop or controls
12120 // the loop latch. Check to see if it is the comparison we are looking for.
12121 CmpPredicate FoundPred;
12122 if (Inverse)
12123 FoundPred = ICI->getInverseCmpPredicate();
12124 else
12125 FoundPred = ICI->getCmpPredicate();
12126
12127 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12128 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12129
12130 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12131}
12132
12133bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12134 const SCEV *RHS, CmpPredicate FoundPred,
12135 const SCEV *FoundLHS, const SCEV *FoundRHS,
12136 const Instruction *CtxI) {
12137 // Balance the types.
12138 if (getTypeSizeInBits(LHS->getType()) <
12139 getTypeSizeInBits(FoundLHS->getType())) {
12140 // For unsigned and equality predicates, try to prove that both found
12141 // operands fit into narrow unsigned range. If so, try to prove facts in
12142 // narrow types.
12143 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12144 !FoundRHS->getType()->isPointerTy()) {
12145 auto *NarrowType = LHS->getType();
12146 auto *WideType = FoundLHS->getType();
12147 auto BitWidth = getTypeSizeInBits(NarrowType);
12148 const SCEV *MaxValue = getZeroExtendExpr(
12150 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12151 MaxValue) &&
12152 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12153 MaxValue)) {
12154 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12155 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12156 // We cannot preserve samesign after truncation.
12157 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12158 TruncFoundLHS, TruncFoundRHS, CtxI))
12159 return true;
12160 }
12161 }
12162
12163 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12164 return false;
12165 if (CmpInst::isSigned(Pred)) {
12166 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12167 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12168 } else {
12169 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12170 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12171 }
12172 } else if (getTypeSizeInBits(LHS->getType()) >
12173 getTypeSizeInBits(FoundLHS->getType())) {
12174 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12175 return false;
12176 if (CmpInst::isSigned(FoundPred)) {
12177 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12178 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12179 } else {
12180 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12181 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12182 }
12183 }
12184 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12185 FoundRHS, CtxI);
12186}
12187
12188bool ScalarEvolution::isImpliedCondBalancedTypes(
12189 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12190 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12192 getTypeSizeInBits(FoundLHS->getType()) &&
12193 "Types should be balanced!");
12194 // Canonicalize the query to match the way instcombine will have
12195 // canonicalized the comparison.
12196 if (SimplifyICmpOperands(Pred, LHS, RHS))
12197 if (LHS == RHS)
12198 return CmpInst::isTrueWhenEqual(Pred);
12199 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12200 if (FoundLHS == FoundRHS)
12201 return CmpInst::isFalseWhenEqual(FoundPred);
12202
12203 // Check to see if we can make the LHS or RHS match.
12204 if (LHS == FoundRHS || RHS == FoundLHS) {
12205 if (isa<SCEVConstant>(RHS)) {
12206 std::swap(FoundLHS, FoundRHS);
12207 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12208 } else {
12209 std::swap(LHS, RHS);
12211 }
12212 }
12213
12214 // Check whether the found predicate is the same as the desired predicate.
12215 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12216 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12217
12218 // Check whether swapping the found predicate makes it the same as the
12219 // desired predicate.
12220 if (auto P = CmpPredicate::getMatching(
12221 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12222 // We can write the implication
12223 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12224 // using one of the following ways:
12225 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12226 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12227 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12228 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12229 // Forms 1. and 2. require swapping the operands of one condition. Don't
12230 // do this if it would break canonical constant/addrec ordering.
12232 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12233 LHS, FoundLHS, FoundRHS, CtxI);
12234 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12235 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12236
12237 // There's no clear preference between forms 3. and 4., try both. Avoid
12238 // forming getNotSCEV of pointer values as the resulting subtract is
12239 // not legal.
12240 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12241 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12242 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12243 FoundRHS, CtxI))
12244 return true;
12245
12246 if (!FoundLHS->getType()->isPointerTy() &&
12247 !FoundRHS->getType()->isPointerTy() &&
12248 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12249 getNotSCEV(FoundRHS), CtxI))
12250 return true;
12251
12252 return false;
12253 }
12254
12255 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12257 assert(P1 != P2 && "Handled earlier!");
12258 return CmpInst::isRelational(P2) &&
12260 };
12261 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12262 // Unsigned comparison is the same as signed comparison when both the
12263 // operands are non-negative or negative.
12264 if (haveSameSign(FoundLHS, FoundRHS))
12265 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12266 // Create local copies that we can freely swap and canonicalize our
12267 // conditions to "le/lt".
12268 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12269 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12270 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12271 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12272 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12273 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12274 std::swap(CanonicalLHS, CanonicalRHS);
12275 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12276 }
12277 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12278 "Must be!");
12279 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12280 ICmpInst::isLE(CanonicalFoundPred)) &&
12281 "Must be!");
12282 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12283 // Use implication:
12284 // x <u y && y >=s 0 --> x <s y.
12285 // If we can prove the left part, the right part is also proven.
12286 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12287 CanonicalRHS, CanonicalFoundLHS,
12288 CanonicalFoundRHS);
12289 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12290 // Use implication:
12291 // x <s y && y <s 0 --> x <u y.
12292 // If we can prove the left part, the right part is also proven.
12293 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12294 CanonicalRHS, CanonicalFoundLHS,
12295 CanonicalFoundRHS);
12296 }
12297
12298 // Check if we can make progress by sharpening ranges.
12299 if (FoundPred == ICmpInst::ICMP_NE &&
12300 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12301
12302 const SCEVConstant *C = nullptr;
12303 const SCEV *V = nullptr;
12304
12305 if (isa<SCEVConstant>(FoundLHS)) {
12306 C = cast<SCEVConstant>(FoundLHS);
12307 V = FoundRHS;
12308 } else {
12309 C = cast<SCEVConstant>(FoundRHS);
12310 V = FoundLHS;
12311 }
12312
12313 // The guarding predicate tells us that C != V. If the known range
12314 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12315 // range we consider has to correspond to same signedness as the
12316 // predicate we're interested in folding.
12317
12318 APInt Min = ICmpInst::isSigned(Pred) ?
12320
12321 if (Min == C->getAPInt()) {
12322 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12323 // This is true even if (Min + 1) wraps around -- in case of
12324 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12325
12326 APInt SharperMin = Min + 1;
12327
12328 switch (Pred) {
12329 case ICmpInst::ICMP_SGE:
12330 case ICmpInst::ICMP_UGE:
12331 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12332 // RHS, we're done.
12333 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12334 CtxI))
12335 return true;
12336 [[fallthrough]];
12337
12338 case ICmpInst::ICMP_SGT:
12339 case ICmpInst::ICMP_UGT:
12340 // We know from the range information that (V `Pred` Min ||
12341 // V == Min). We know from the guarding condition that !(V
12342 // == Min). This gives us
12343 //
12344 // V `Pred` Min || V == Min && !(V == Min)
12345 // => V `Pred` Min
12346 //
12347 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12348
12349 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12350 return true;
12351 break;
12352
12353 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12354 case ICmpInst::ICMP_SLE:
12355 case ICmpInst::ICMP_ULE:
12356 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12357 LHS, V, getConstant(SharperMin), CtxI))
12358 return true;
12359 [[fallthrough]];
12360
12361 case ICmpInst::ICMP_SLT:
12362 case ICmpInst::ICMP_ULT:
12363 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12364 LHS, V, getConstant(Min), CtxI))
12365 return true;
12366 break;
12367
12368 default:
12369 // No change
12370 break;
12371 }
12372 }
12373 }
12374
12375 // Check whether the actual condition is beyond sufficient.
12376 if (FoundPred == ICmpInst::ICMP_EQ)
12377 if (ICmpInst::isTrueWhenEqual(Pred))
12378 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12379 return true;
12380 if (Pred == ICmpInst::ICMP_NE)
12381 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12382 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12383 return true;
12384
12385 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12386 return true;
12387
12388 // Otherwise assume the worst.
12389 return false;
12390}
12391
12392bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12393 SCEV::NoWrapFlags &Flags) {
12394 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12395 return false;
12396
12397 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12398 return true;
12399}
12400
12401std::optional<APInt>
12403 // We avoid subtracting expressions here because this function is usually
12404 // fairly deep in the call stack (i.e. is called many times).
12405
12406 unsigned BW = getTypeSizeInBits(More->getType());
12407 APInt Diff(BW, 0);
12408 APInt DiffMul(BW, 1);
12409 // Try various simplifications to reduce the difference to a constant. Limit
12410 // the number of allowed simplifications to keep compile-time low.
12411 for (unsigned I = 0; I < 8; ++I) {
12412 if (More == Less)
12413 return Diff;
12414
12415 // Reduce addrecs with identical steps to their start value.
12417 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12418 const auto *MAR = cast<SCEVAddRecExpr>(More);
12419
12420 if (LAR->getLoop() != MAR->getLoop())
12421 return std::nullopt;
12422
12423 // We look at affine expressions only; not for correctness but to keep
12424 // getStepRecurrence cheap.
12425 if (!LAR->isAffine() || !MAR->isAffine())
12426 return std::nullopt;
12427
12428 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12429 return std::nullopt;
12430
12431 Less = LAR->getStart();
12432 More = MAR->getStart();
12433 continue;
12434 }
12435
12436 // Try to match a common constant multiply.
12437 auto MatchConstMul =
12438 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12439 const APInt *C;
12440 const SCEV *Op;
12441 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12442 return {{Op, *C}};
12443 return std::nullopt;
12444 };
12445 if (auto MatchedMore = MatchConstMul(More)) {
12446 if (auto MatchedLess = MatchConstMul(Less)) {
12447 if (MatchedMore->second == MatchedLess->second) {
12448 More = MatchedMore->first;
12449 Less = MatchedLess->first;
12450 DiffMul *= MatchedMore->second;
12451 continue;
12452 }
12453 }
12454 }
12455
12456 // Try to cancel out common factors in two add expressions.
12458 auto Add = [&](const SCEV *S, int Mul) {
12459 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12460 if (Mul == 1) {
12461 Diff += C->getAPInt() * DiffMul;
12462 } else {
12463 assert(Mul == -1);
12464 Diff -= C->getAPInt() * DiffMul;
12465 }
12466 } else
12467 Multiplicity[S] += Mul;
12468 };
12469 auto Decompose = [&](const SCEV *S, int Mul) {
12470 if (isa<SCEVAddExpr>(S)) {
12471 for (const SCEV *Op : S->operands())
12472 Add(Op, Mul);
12473 } else
12474 Add(S, Mul);
12475 };
12476 Decompose(More, 1);
12477 Decompose(Less, -1);
12478
12479 // Check whether all the non-constants cancel out, or reduce to new
12480 // More/Less values.
12481 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12482 for (const auto &[S, Mul] : Multiplicity) {
12483 if (Mul == 0)
12484 continue;
12485 if (Mul == 1) {
12486 if (NewMore)
12487 return std::nullopt;
12488 NewMore = S;
12489 } else if (Mul == -1) {
12490 if (NewLess)
12491 return std::nullopt;
12492 NewLess = S;
12493 } else
12494 return std::nullopt;
12495 }
12496
12497 // Values stayed the same, no point in trying further.
12498 if (NewMore == More || NewLess == Less)
12499 return std::nullopt;
12500
12501 More = NewMore;
12502 Less = NewLess;
12503
12504 // Reduced to constant.
12505 if (!More && !Less)
12506 return Diff;
12507
12508 // Left with variable on only one side, bail out.
12509 if (!More || !Less)
12510 return std::nullopt;
12511 }
12512
12513 // Did not reduce to constant.
12514 return std::nullopt;
12515}
12516
12517bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12518 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12519 const SCEV *FoundRHS, const Instruction *CtxI) {
12520 // Try to recognize the following pattern:
12521 //
12522 // FoundRHS = ...
12523 // ...
12524 // loop:
12525 // FoundLHS = {Start,+,W}
12526 // context_bb: // Basic block from the same loop
12527 // known(Pred, FoundLHS, FoundRHS)
12528 //
12529 // If some predicate is known in the context of a loop, it is also known on
12530 // each iteration of this loop, including the first iteration. Therefore, in
12531 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12532 // prove the original pred using this fact.
12533 if (!CtxI)
12534 return false;
12535 const BasicBlock *ContextBB = CtxI->getParent();
12536 // Make sure AR varies in the context block.
12537 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12538 const Loop *L = AR->getLoop();
12539 const auto *Latch = L->getLoopLatch();
12540 // Make sure that context belongs to the loop and executes on 1st iteration
12541 // (if it ever executes at all).
12542 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12543 return false;
12544 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12545 return false;
12546 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12547 }
12548
12549 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12550 const Loop *L = AR->getLoop();
12551 const auto *Latch = L->getLoopLatch();
12552 // Make sure that context belongs to the loop and executes on 1st iteration
12553 // (if it ever executes at all).
12554 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12555 return false;
12556 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12557 return false;
12558 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12559 }
12560
12561 return false;
12562}
12563
12564bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12565 const SCEV *LHS,
12566 const SCEV *RHS,
12567 const SCEV *FoundLHS,
12568 const SCEV *FoundRHS) {
12569 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12570 return false;
12571
12572 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12573 if (!AddRecLHS)
12574 return false;
12575
12576 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12577 if (!AddRecFoundLHS)
12578 return false;
12579
12580 // We'd like to let SCEV reason about control dependencies, so we constrain
12581 // both the inequalities to be about add recurrences on the same loop. This
12582 // way we can use isLoopEntryGuardedByCond later.
12583
12584 const Loop *L = AddRecFoundLHS->getLoop();
12585 if (L != AddRecLHS->getLoop())
12586 return false;
12587
12588 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12589 //
12590 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12591 // ... (2)
12592 //
12593 // Informal proof for (2), assuming (1) [*]:
12594 //
12595 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12596 //
12597 // Then
12598 //
12599 // FoundLHS s< FoundRHS s< INT_MIN - C
12600 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12601 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12602 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12603 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12604 // <=> FoundLHS + C s< FoundRHS + C
12605 //
12606 // [*]: (1) can be proved by ruling out overflow.
12607 //
12608 // [**]: This can be proved by analyzing all the four possibilities:
12609 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12610 // (A s>= 0, B s>= 0).
12611 //
12612 // Note:
12613 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12614 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12615 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12616 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12617 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12618 // C)".
12619
12620 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12621 if (!LDiff)
12622 return false;
12623 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12624 if (!RDiff || *LDiff != *RDiff)
12625 return false;
12626
12627 if (LDiff->isMinValue())
12628 return true;
12629
12630 APInt FoundRHSLimit;
12631
12632 if (Pred == CmpInst::ICMP_ULT) {
12633 FoundRHSLimit = -(*RDiff);
12634 } else {
12635 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12636 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12637 }
12638
12639 // Try to prove (1) or (2), as needed.
12640 return isAvailableAtLoopEntry(FoundRHS, L) &&
12641 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12642 getConstant(FoundRHSLimit));
12643}
12644
12645bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12646 const SCEV *RHS, const SCEV *FoundLHS,
12647 const SCEV *FoundRHS, unsigned Depth) {
12648 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12649
12650 llvm::scope_exit ClearOnExit([&]() {
12651 if (LPhi) {
12652 bool Erased = PendingMerges.erase(LPhi);
12653 assert(Erased && "Failed to erase LPhi!");
12654 (void)Erased;
12655 }
12656 if (RPhi) {
12657 bool Erased = PendingMerges.erase(RPhi);
12658 assert(Erased && "Failed to erase RPhi!");
12659 (void)Erased;
12660 }
12661 });
12662
12663 // Find respective Phis and check that they are not being pending.
12664 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12665 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12666 if (!PendingMerges.insert(Phi).second)
12667 return false;
12668 LPhi = Phi;
12669 }
12670 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12671 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12672 // If we detect a loop of Phi nodes being processed by this method, for
12673 // example:
12674 //
12675 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12676 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12677 //
12678 // we don't want to deal with a case that complex, so return conservative
12679 // answer false.
12680 if (!PendingMerges.insert(Phi).second)
12681 return false;
12682 RPhi = Phi;
12683 }
12684
12685 // If none of LHS, RHS is a Phi, nothing to do here.
12686 if (!LPhi && !RPhi)
12687 return false;
12688
12689 // If there is a SCEVUnknown Phi we are interested in, make it left.
12690 if (!LPhi) {
12691 std::swap(LHS, RHS);
12692 std::swap(FoundLHS, FoundRHS);
12693 std::swap(LPhi, RPhi);
12695 }
12696
12697 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12698 const BasicBlock *LBB = LPhi->getParent();
12699 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12700
12701 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12702 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12703 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12704 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12705 };
12706
12707 if (RPhi && RPhi->getParent() == LBB) {
12708 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12709 // If we compare two Phis from the same block, and for each entry block
12710 // the predicate is true for incoming values from this block, then the
12711 // predicate is also true for the Phis.
12712 for (const BasicBlock *IncBB : predecessors(LBB)) {
12713 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12714 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12715 if (!ProvedEasily(L, R))
12716 return false;
12717 }
12718 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12719 // Case two: RHS is also a Phi from the same basic block, and it is an
12720 // AddRec. It means that there is a loop which has both AddRec and Unknown
12721 // PHIs, for it we can compare incoming values of AddRec from above the loop
12722 // and latch with their respective incoming values of LPhi.
12723 // TODO: Generalize to handle loops with many inputs in a header.
12724 if (LPhi->getNumIncomingValues() != 2) return false;
12725
12726 auto *RLoop = RAR->getLoop();
12727 auto *Predecessor = RLoop->getLoopPredecessor();
12728 assert(Predecessor && "Loop with AddRec with no predecessor?");
12729 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12730 if (!ProvedEasily(L1, RAR->getStart()))
12731 return false;
12732 auto *Latch = RLoop->getLoopLatch();
12733 assert(Latch && "Loop with AddRec with no latch?");
12734 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12735 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12736 return false;
12737 } else {
12738 // In all other cases go over inputs of LHS and compare each of them to RHS,
12739 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12740 // At this point RHS is either a non-Phi, or it is a Phi from some block
12741 // different from LBB.
12742 for (const BasicBlock *IncBB : predecessors(LBB)) {
12743 // Check that RHS is available in this block.
12744 if (!dominates(RHS, IncBB))
12745 return false;
12746 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12747 // Make sure L does not refer to a value from a potentially previous
12748 // iteration of a loop.
12749 if (!properlyDominates(L, LBB))
12750 return false;
12751 // Addrecs are considered to properly dominate their loop, so are missed
12752 // by the previous check. Discard any values that have computable
12753 // evolution in this loop.
12754 if (auto *Loop = LI.getLoopFor(LBB))
12755 if (hasComputableLoopEvolution(L, Loop))
12756 return false;
12757 if (!ProvedEasily(L, RHS))
12758 return false;
12759 }
12760 }
12761 return true;
12762}
12763
12764bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12765 const SCEV *LHS,
12766 const SCEV *RHS,
12767 const SCEV *FoundLHS,
12768 const SCEV *FoundRHS) {
12769 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12770 // sure that we are dealing with same LHS.
12771 if (RHS == FoundRHS) {
12772 std::swap(LHS, RHS);
12773 std::swap(FoundLHS, FoundRHS);
12775 }
12776 if (LHS != FoundLHS)
12777 return false;
12778
12779 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12780 if (!SUFoundRHS)
12781 return false;
12782
12783 Value *Shiftee, *ShiftValue;
12784
12785 using namespace PatternMatch;
12786 if (match(SUFoundRHS->getValue(),
12787 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12788 auto *ShifteeS = getSCEV(Shiftee);
12789 // Prove one of the following:
12790 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12791 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12792 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12793 // ---> LHS <s RHS
12794 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12795 // ---> LHS <=s RHS
12796 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12797 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12798 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12799 if (isKnownNonNegative(ShifteeS))
12800 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12801 }
12802
12803 return false;
12804}
12805
12806bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12807 const SCEV *RHS,
12808 const SCEV *FoundLHS,
12809 const SCEV *FoundRHS,
12810 const Instruction *CtxI) {
12811 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12812 FoundRHS) ||
12813 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12814 FoundRHS) ||
12815 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12816 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12817 CtxI) ||
12818 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12819}
12820
12821/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12822template <typename MinMaxExprType>
12823static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12824 const SCEV *Candidate) {
12825 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12826 if (!MinMaxExpr)
12827 return false;
12828
12829 return is_contained(MinMaxExpr->operands(), Candidate);
12830}
12831
12833 CmpPredicate Pred, const SCEV *LHS,
12834 const SCEV *RHS) {
12835 // If both sides are affine addrecs for the same loop, with equal
12836 // steps, and we know the recurrences don't wrap, then we only
12837 // need to check the predicate on the starting values.
12838
12839 if (!ICmpInst::isRelational(Pred))
12840 return false;
12841
12842 const SCEV *LStart, *RStart, *Step;
12843 const Loop *L;
12844 if (!match(LHS,
12845 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12847 m_SpecificLoop(L))))
12848 return false;
12853 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12854 return false;
12855
12856 return SE.isKnownPredicate(Pred, LStart, RStart);
12857}
12858
12859/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12860/// expression?
12862 const SCEV *LHS, const SCEV *RHS) {
12863 switch (Pred) {
12864 default:
12865 return false;
12866
12867 case ICmpInst::ICMP_SGE:
12868 std::swap(LHS, RHS);
12869 [[fallthrough]];
12870 case ICmpInst::ICMP_SLE:
12871 return
12872 // min(A, ...) <= A
12874 // A <= max(A, ...)
12876
12877 case ICmpInst::ICMP_UGE:
12878 std::swap(LHS, RHS);
12879 [[fallthrough]];
12880 case ICmpInst::ICMP_ULE:
12881 return
12882 // min(A, ...) <= A
12883 // FIXME: what about umin_seq?
12885 // A <= max(A, ...)
12887 }
12888
12889 llvm_unreachable("covered switch fell through?!");
12890}
12891
12892bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12893 const SCEV *RHS,
12894 const SCEV *FoundLHS,
12895 const SCEV *FoundRHS,
12896 unsigned Depth) {
12899 "LHS and RHS have different sizes?");
12900 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12901 getTypeSizeInBits(FoundRHS->getType()) &&
12902 "FoundLHS and FoundRHS have different sizes?");
12903 // We want to avoid hurting the compile time with analysis of too big trees.
12905 return false;
12906
12907 // We only want to work with GT comparison so far.
12908 if (ICmpInst::isLT(Pred)) {
12910 std::swap(LHS, RHS);
12911 std::swap(FoundLHS, FoundRHS);
12912 }
12913
12915
12916 // For unsigned, try to reduce it to corresponding signed comparison.
12917 if (P == ICmpInst::ICMP_UGT)
12918 // We can replace unsigned predicate with its signed counterpart if all
12919 // involved values are non-negative.
12920 // TODO: We could have better support for unsigned.
12921 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12922 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12923 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12924 // use this fact to prove that LHS and RHS are non-negative.
12925 const SCEV *MinusOne = getMinusOne(LHS->getType());
12926 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12927 FoundRHS) &&
12928 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12929 FoundRHS))
12931 }
12932
12933 if (P != ICmpInst::ICMP_SGT)
12934 return false;
12935
12936 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12937 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12938 return Ext->getOperand();
12939 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12940 // the constant in some cases.
12941 return S;
12942 };
12943
12944 // Acquire values from extensions.
12945 auto *OrigLHS = LHS;
12946 auto *OrigFoundLHS = FoundLHS;
12947 LHS = GetOpFromSExt(LHS);
12948 FoundLHS = GetOpFromSExt(FoundLHS);
12949
12950 // Is the SGT predicate can be proved trivially or using the found context.
12951 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12952 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12953 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12954 FoundRHS, Depth + 1);
12955 };
12956
12957 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12958 // We want to avoid creation of any new non-constant SCEV. Since we are
12959 // going to compare the operands to RHS, we should be certain that we don't
12960 // need any size extensions for this. So let's decline all cases when the
12961 // sizes of types of LHS and RHS do not match.
12962 // TODO: Maybe try to get RHS from sext to catch more cases?
12964 return false;
12965
12966 // Should not overflow.
12967 if (!LHSAddExpr->hasNoSignedWrap())
12968 return false;
12969
12970 SCEVUse LL = LHSAddExpr->getOperand(0);
12971 SCEVUse LR = LHSAddExpr->getOperand(1);
12972 auto *MinusOne = getMinusOne(RHS->getType());
12973
12974 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12975 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12976 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12977 };
12978 // Try to prove the following rule:
12979 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12980 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12981 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12982 return true;
12983 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12984 Value *LL, *LR;
12985 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12986
12987 using namespace llvm::PatternMatch;
12988
12989 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12990 // Rules for division.
12991 // We are going to perform some comparisons with Denominator and its
12992 // derivative expressions. In general case, creating a SCEV for it may
12993 // lead to a complex analysis of the entire graph, and in particular it
12994 // can request trip count recalculation for the same loop. This would
12995 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12996 // this, we only want to create SCEVs that are constants in this section.
12997 // So we bail if Denominator is not a constant.
12998 if (!isa<ConstantInt>(LR))
12999 return false;
13000
13001 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13002
13003 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13004 // then a SCEV for the numerator already exists and matches with FoundLHS.
13005 auto *Numerator = getExistingSCEV(LL);
13006 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13007 return false;
13008
13009 // Make sure that the numerator matches with FoundLHS and the denominator
13010 // is positive.
13011 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13012 return false;
13013
13014 auto *DTy = Denominator->getType();
13015 auto *FRHSTy = FoundRHS->getType();
13016 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13017 // One of types is a pointer and another one is not. We cannot extend
13018 // them properly to a wider type, so let us just reject this case.
13019 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13020 // to avoid this check.
13021 return false;
13022
13023 // Given that:
13024 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13025 auto *WTy = getWiderType(DTy, FRHSTy);
13026 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13027 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13028
13029 // Try to prove the following rule:
13030 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13031 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13032 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13033 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13034 if (isKnownNonPositive(RHS) &&
13035 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13036 return true;
13037
13038 // Try to prove the following rule:
13039 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13040 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13041 // If we divide it by Denominator > 2, then:
13042 // 1. If FoundLHS is negative, then the result is 0.
13043 // 2. If FoundLHS is non-negative, then the result is non-negative.
13044 // Anyways, the result is non-negative.
13045 auto *MinusOne = getMinusOne(WTy);
13046 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13047 if (isKnownNegative(RHS) &&
13048 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13049 return true;
13050 }
13051 }
13052
13053 // If our expression contained SCEVUnknown Phis, and we split it down and now
13054 // need to prove something for them, try to prove the predicate for every
13055 // possible incoming values of those Phis.
13056 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13057 return true;
13058
13059 return false;
13060}
13061
13063 const SCEV *RHS) {
13064 // zext x u<= sext x, sext x s<= zext x
13065 const SCEV *Op;
13066 switch (Pred) {
13067 case ICmpInst::ICMP_SGE:
13068 std::swap(LHS, RHS);
13069 [[fallthrough]];
13070 case ICmpInst::ICMP_SLE: {
13071 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13072 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13074 }
13075 case ICmpInst::ICMP_UGE:
13076 std::swap(LHS, RHS);
13077 [[fallthrough]];
13078 case ICmpInst::ICMP_ULE: {
13079 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13080 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13082 }
13083 default:
13084 return false;
13085 };
13086 llvm_unreachable("unhandled case");
13087}
13088
13089bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13090 SCEVUse LHS,
13091 SCEVUse RHS) {
13092 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13093 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13094 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13095 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13096 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13097}
13098
13099bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13100 const SCEV *LHS,
13101 const SCEV *RHS,
13102 const SCEV *FoundLHS,
13103 const SCEV *FoundRHS) {
13104 switch (Pred) {
13105 default:
13106 llvm_unreachable("Unexpected CmpPredicate value!");
13107 case ICmpInst::ICMP_EQ:
13108 case ICmpInst::ICMP_NE:
13109 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13110 return true;
13111 break;
13112 case ICmpInst::ICMP_SLT:
13113 case ICmpInst::ICMP_SLE:
13114 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13115 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13116 return true;
13117 break;
13118 case ICmpInst::ICMP_SGT:
13119 case ICmpInst::ICMP_SGE:
13120 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13121 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13122 return true;
13123 break;
13124 case ICmpInst::ICMP_ULT:
13125 case ICmpInst::ICMP_ULE:
13126 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13127 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13128 return true;
13129 break;
13130 case ICmpInst::ICMP_UGT:
13131 case ICmpInst::ICMP_UGE:
13132 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13133 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13134 return true;
13135 break;
13136 }
13137
13138 // Maybe it can be proved via operations?
13139 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13140 return true;
13141
13142 return false;
13143}
13144
13145bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13146 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13147 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13148 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13149 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13150 // reduce the compile time impact of this optimization.
13151 return false;
13152
13153 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13154 if (!Addend)
13155 return false;
13156
13157 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13158
13159 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13160 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13161 ConstantRange FoundLHSRange =
13162 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13163
13164 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13165 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13166
13167 // We can also compute the range of values for `LHS` that satisfy the
13168 // consequent, "`LHS` `Pred` `RHS`":
13169 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13170 // The antecedent implies the consequent if every value of `LHS` that
13171 // satisfies the antecedent also satisfies the consequent.
13172 return LHSRange.icmp(Pred, ConstRHS);
13173}
13174
13175bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13176 bool IsSigned) {
13177 assert(isKnownPositive(Stride) && "Positive stride expected!");
13178
13179 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13180 const SCEV *One = getOne(Stride->getType());
13181
13182 if (IsSigned) {
13183 APInt MaxRHS = getSignedRangeMax(RHS);
13184 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13185 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13186
13187 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13188 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13189 }
13190
13191 APInt MaxRHS = getUnsignedRangeMax(RHS);
13192 APInt MaxValue = APInt::getMaxValue(BitWidth);
13193 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13194
13195 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13196 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13197}
13198
13199bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13200 bool IsSigned) {
13201
13202 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13203 const SCEV *One = getOne(Stride->getType());
13204
13205 if (IsSigned) {
13206 APInt MinRHS = getSignedRangeMin(RHS);
13207 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13208 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13209
13210 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13211 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13212 }
13213
13214 APInt MinRHS = getUnsignedRangeMin(RHS);
13215 APInt MinValue = APInt::getMinValue(BitWidth);
13216 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13217
13218 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13219 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13220}
13221
13223 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13224 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13225 // expression fixes the case of N=0.
13226 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13227 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13228 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13229}
13230
13231const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13232 const SCEV *Stride,
13233 const SCEV *End,
13234 unsigned BitWidth,
13235 bool IsSigned) {
13236 // The logic in this function assumes we can represent a positive stride.
13237 // If we can't, the backedge-taken count must be zero.
13238 if (IsSigned && BitWidth == 1)
13239 return getZero(Stride->getType());
13240
13241 // This code below only been closely audited for negative strides in the
13242 // unsigned comparison case, it may be correct for signed comparison, but
13243 // that needs to be established.
13244 if (IsSigned && isKnownNegative(Stride))
13245 return getCouldNotCompute();
13246
13247 // Calculate the maximum backedge count based on the range of values
13248 // permitted by Start, End, and Stride.
13249 APInt MinStart =
13250 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13251
13252 APInt MinStride =
13253 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13254
13255 // We assume either the stride is positive, or the backedge-taken count
13256 // is zero. So force StrideForMaxBECount to be at least one.
13257 APInt One(BitWidth, 1);
13258 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13259 : APIntOps::umax(One, MinStride);
13260
13261 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13262 : APInt::getMaxValue(BitWidth);
13263 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13264
13265 // Although End can be a MAX expression we estimate MaxEnd considering only
13266 // the case End = RHS of the loop termination condition. This is safe because
13267 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13268 // taken count.
13269 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13270 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13271
13272 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13273 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13274 : APIntOps::umax(MaxEnd, MinStart);
13275
13276 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13277 getConstant(StrideForMaxBECount) /* Step */);
13278}
13279
13281ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13282 const Loop *L, bool IsSigned,
13283 bool ControlsOnlyExit, bool AllowPredicates) {
13285
13287 bool PredicatedIV = false;
13288 if (!IV) {
13289 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13290 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13291 if (AR && AR->getLoop() == L && AR->isAffine()) {
13292 auto canProveNUW = [&]() {
13293 // We can use the comparison to infer no-wrap flags only if it fully
13294 // controls the loop exit.
13295 if (!ControlsOnlyExit)
13296 return false;
13297
13298 if (!isLoopInvariant(RHS, L))
13299 return false;
13300
13301 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13302 // We need the sequence defined by AR to strictly increase in the
13303 // unsigned integer domain for the logic below to hold.
13304 return false;
13305
13306 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13307 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13308 // If RHS <=u Limit, then there must exist a value V in the sequence
13309 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13310 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13311 // overflow occurs. This limit also implies that a signed comparison
13312 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13313 // the high bits on both sides must be zero.
13314 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13315 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13316 Limit = Limit.zext(OuterBitWidth);
13317 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13318 };
13319 auto Flags = AR->getNoWrapFlags();
13320 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13321 Flags = setFlags(Flags, SCEV::FlagNUW);
13322
13323 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13324 if (AR->hasNoUnsignedWrap()) {
13325 // Emulate what getZeroExtendExpr would have done during construction
13326 // if we'd been able to infer the fact just above at that time.
13327 const SCEV *Step = AR->getStepRecurrence(*this);
13328 Type *Ty = ZExt->getType();
13329 auto *S = getAddRecExpr(
13331 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13333 }
13334 }
13335 }
13336 }
13337
13338
13339 if (!IV && AllowPredicates) {
13340 // Try to make this an AddRec using runtime tests, in the first X
13341 // iterations of this loop, where X is the SCEV expression found by the
13342 // algorithm below.
13343 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13344 PredicatedIV = true;
13345 }
13346
13347 // Avoid weird loops
13348 if (!IV || IV->getLoop() != L || !IV->isAffine())
13349 return getCouldNotCompute();
13350
13351 // A precondition of this method is that the condition being analyzed
13352 // reaches an exiting branch which dominates the latch. Given that, we can
13353 // assume that an increment which violates the nowrap specification and
13354 // produces poison must cause undefined behavior when the resulting poison
13355 // value is branched upon and thus we can conclude that the backedge is
13356 // taken no more often than would be required to produce that poison value.
13357 // Note that a well defined loop can exit on the iteration which violates
13358 // the nowrap specification if there is another exit (either explicit or
13359 // implicit/exceptional) which causes the loop to execute before the
13360 // exiting instruction we're analyzing would trigger UB.
13361 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13362 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13364
13365 const SCEV *Stride = IV->getStepRecurrence(*this);
13366
13367 bool PositiveStride = isKnownPositive(Stride);
13368
13369 // Avoid negative or zero stride values.
13370 if (!PositiveStride) {
13371 // We can compute the correct backedge taken count for loops with unknown
13372 // strides if we can prove that the loop is not an infinite loop with side
13373 // effects. Here's the loop structure we are trying to handle -
13374 //
13375 // i = start
13376 // do {
13377 // A[i] = i;
13378 // i += s;
13379 // } while (i < end);
13380 //
13381 // The backedge taken count for such loops is evaluated as -
13382 // (max(end, start + stride) - start - 1) /u stride
13383 //
13384 // The additional preconditions that we need to check to prove correctness
13385 // of the above formula is as follows -
13386 //
13387 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13388 // NoWrap flag).
13389 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13390 // no side effects within the loop)
13391 // c) loop has a single static exit (with no abnormal exits)
13392 //
13393 // Precondition a) implies that if the stride is negative, this is a single
13394 // trip loop. The backedge taken count formula reduces to zero in this case.
13395 //
13396 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13397 // then a zero stride means the backedge can't be taken without executing
13398 // undefined behavior.
13399 //
13400 // The positive stride case is the same as isKnownPositive(Stride) returning
13401 // true (original behavior of the function).
13402 //
13403 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13405 return getCouldNotCompute();
13406
13407 if (!isKnownNonZero(Stride)) {
13408 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13409 // if it might eventually be greater than start and if so, on which
13410 // iteration. We can't even produce a useful upper bound.
13411 if (!isLoopInvariant(RHS, L))
13412 return getCouldNotCompute();
13413
13414 // We allow a potentially zero stride, but we need to divide by stride
13415 // below. Since the loop can't be infinite and this check must control
13416 // the sole exit, we can infer the exit must be taken on the first
13417 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13418 // we know the numerator in the divides below must be zero, so we can
13419 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13420 // and produce the right result.
13421 // FIXME: Handle the case where Stride is poison?
13422 auto wouldZeroStrideBeUB = [&]() {
13423 // Proof by contradiction. Suppose the stride were zero. If we can
13424 // prove that the backedge *is* taken on the first iteration, then since
13425 // we know this condition controls the sole exit, we must have an
13426 // infinite loop. We can't have a (well defined) infinite loop per
13427 // check just above.
13428 // Note: The (Start - Stride) term is used to get the start' term from
13429 // (start' + stride,+,stride). Remember that we only care about the
13430 // result of this expression when stride == 0 at runtime.
13431 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13432 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13433 };
13434 if (!wouldZeroStrideBeUB()) {
13435 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13436 }
13437 }
13438 } else if (!NoWrap) {
13439 // Avoid proven overflow cases: this will ensure that the backedge taken
13440 // count will not generate any unsigned overflow.
13441 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13442 return getCouldNotCompute();
13443 }
13444
  // On all paths just preceding, we established the following invariant:
13446 // IV can be assumed not to overflow up to and including the exiting
13447 // iteration. We proved this in one of two ways:
13448 // 1) We can show overflow doesn't occur before the exiting iteration
13449 // 1a) canIVOverflowOnLT, and b) step of one
13450 // 2) We can show that if overflow occurs, the loop must execute UB
13451 // before any possible exit.
13452 // Note that we have not yet proved RHS invariant (in general).
13453
13454 const SCEV *Start = IV->getStart();
13455
13456 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13457 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13458 // Use integer-typed versions for actual computation; we can't subtract
13459 // pointers in general.
13460 const SCEV *OrigStart = Start;
13461 const SCEV *OrigRHS = RHS;
13462 if (Start->getType()->isPointerTy()) {
13464 if (isa<SCEVCouldNotCompute>(Start))
13465 return Start;
13466 }
13467 if (RHS->getType()->isPointerTy()) {
13470 return RHS;
13471 }
13472
13473 const SCEV *End = nullptr, *BECount = nullptr,
13474 *BECountIfBackedgeTaken = nullptr;
13475 if (!isLoopInvariant(RHS, L)) {
13476 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13477 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13478 any(RHSAddRec->getNoWrapFlags())) {
13479 // The structure of loop we are trying to calculate backedge count of:
13480 //
13481 // left = left_start
13482 // right = right_start
13483 //
13484 // while(left < right){
13485 // ... do something here ...
13486 // left += s1; // stride of left is s1 (s1 > 0)
13487 // right += s2; // stride of right is s2 (s2 < 0)
13488 // }
13489 //
13490
13491 const SCEV *RHSStart = RHSAddRec->getStart();
13492 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13493
13494 // If Stride - RHSStride is positive and does not overflow, we can write
13495 // backedge count as ->
13496 // ceil((End - Start) /u (Stride - RHSStride))
13497 // Where, End = max(RHSStart, Start)
13498
13499 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13500 if (isKnownNegative(RHSStride) &&
13501 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13502 RHSStride)) {
13503
13504 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13505 if (isKnownPositive(Denominator)) {
13506 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13507 : getUMaxExpr(RHSStart, Start);
13508
13509 // We can do this because End >= Start, as End = max(RHSStart, Start)
13510 const SCEV *Delta = getMinusSCEV(End, Start);
13511
13512 BECount = getUDivCeilSCEV(Delta, Denominator);
13513 BECountIfBackedgeTaken =
13514 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13515 }
13516 }
13517 }
13518 if (BECount == nullptr) {
13519 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13520 // given the start, stride and max value for the end bound of the
13521 // loop (RHS), and the fact that IV does not overflow (which is
13522 // checked above).
13523 const SCEV *MaxBECount = computeMaxBECountForLT(
13524 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13525 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13526 MaxBECount, false /*MaxOrZero*/, Predicates);
13527 }
13528 } else {
13529 // We use the expression (max(End,Start)-Start)/Stride to describe the
13530 // backedge count, as if the backedge is taken at least once
13531 // max(End,Start) is End and so the result is as above, and if not
13532 // max(End,Start) is Start so we get a backedge count of zero.
13533 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13534 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13535 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13536 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13537 // Can we prove (max(RHS,Start) > Start - Stride?
13538 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13539 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13540 // In this case, we can use a refined formula for computing backedge
13541 // taken count. The general formula remains:
13542 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13543 // We want to use the alternate formula:
13544 // "((End - 1) - (Start - Stride)) /u Stride"
13545 // Let's do a quick case analysis to show these are equivalent under
13546 // our precondition that max(RHS,Start) > Start - Stride.
13547 // * For RHS <= Start, the backedge-taken count must be zero.
13548 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13549 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13550 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
          //   of Stride. For 0 stride, we've used umin(1,Stride) above,
13552 // reducing this to the stride of 1 case.
13553 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13554 // Stride".
13555 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13556 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13557 // "((RHS - (Start - Stride) - 1) /u Stride".
13558 // Our preconditions trivially imply no overflow in that form.
13559 const SCEV *MinusOne = getMinusOne(Stride->getType());
13560 const SCEV *Numerator =
13561 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13562 BECount = getUDivExpr(Numerator, Stride);
13563 }
13564
13565 if (!BECount) {
13566 auto canProveRHSGreaterThanEqualStart = [&]() {
13567 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13568 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13569 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13570
13571 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13572 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13573 return true;
13574
13575 // (RHS > Start - 1) implies RHS >= Start.
13576 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13577 // "Start - 1" doesn't overflow.
13578 // * For signed comparison, if Start - 1 does overflow, it's equal
13579 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13580 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13581 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13582 //
13583 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13584 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13585 auto *StartMinusOne =
13586 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13587 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13588 };
13589
13590 // If we know that RHS >= Start in the context of loop, then we know
13591 // that max(RHS, Start) = RHS at this point.
13592 if (canProveRHSGreaterThanEqualStart()) {
13593 End = RHS;
13594 } else {
13595 // If RHS < Start, the backedge will be taken zero times. So in
13596 // general, we can write the backedge-taken count as:
13597 //
13598 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13599 //
13600 // We convert it to the following to make it more convenient for SCEV:
13601 //
13602 // ceil(max(RHS, Start) - Start) / Stride
13603 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13604
13605 // See what would happen if we assume the backedge is taken. This is
13606 // used to compute MaxBECount.
13607 BECountIfBackedgeTaken =
13608 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13609 }
13610
13611 // At this point, we know:
13612 //
13613 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13614 // 2. The index variable doesn't overflow.
13615 //
13616 // Therefore, we know N exists such that
13617 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13618 // doesn't overflow.
13619 //
13620 // Using this information, try to prove whether the addition in
13621 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13622 const SCEV *One = getOne(Stride->getType());
13623 bool MayAddOverflow = [&] {
13624 if (isKnownToBeAPowerOfTwo(Stride)) {
13625 // Suppose Stride is a power of two, and Start/End are unsigned
13626 // integers. Let UMAX be the largest representable unsigned
13627 // integer.
13628 //
13629 // By the preconditions of this function, we know
13630 // "(Start + Stride * N) >= End", and this doesn't overflow.
13631 // As a formula:
13632 //
13633 // End <= (Start + Stride * N) <= UMAX
13634 //
13635 // Subtracting Start from all the terms:
13636 //
13637 // End - Start <= Stride * N <= UMAX - Start
13638 //
13639 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13640 //
13641 // End - Start <= Stride * N <= UMAX
13642 //
13643 // Stride * N is a multiple of Stride. Therefore,
13644 //
13645 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13646 //
13647 // Since Stride is a power of two, UMAX + 1 is divisible by
13648 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13649 // write:
13650 //
13651 // End - Start <= Stride * N <= UMAX - Stride - 1
13652 //
13653 // Dropping the middle term:
13654 //
13655 // End - Start <= UMAX - Stride - 1
13656 //
13657 // Adding Stride - 1 to both sides:
13658 //
13659 // (End - Start) + (Stride - 1) <= UMAX
13660 //
13661 // In other words, the addition doesn't have unsigned overflow.
13662 //
13663 // A similar proof works if we treat Start/End as signed values.
13664 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13665 // to use signed max instead of unsigned max. Note that we're
13666 // trying to prove a lack of unsigned overflow in either case.
13667 return false;
13668 }
13669 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13670 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13671 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13672 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13673 // 1 <s End.
13674 //
13675 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13676 // End.
13677 return false;
13678 }
13679 return true;
13680 }();
13681
13682 const SCEV *Delta = getMinusSCEV(End, Start);
13683 if (!MayAddOverflow) {
13684 // floor((D + (S - 1)) / S)
13685 // We prefer this formulation if it's legal because it's fewer
13686 // operations.
13687 BECount =
13688 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13689 } else {
13690 BECount = getUDivCeilSCEV(Delta, Stride);
13691 }
13692 }
13693 }
13694
13695 const SCEV *ConstantMaxBECount;
13696 bool MaxOrZero = false;
13697 if (isa<SCEVConstant>(BECount)) {
13698 ConstantMaxBECount = BECount;
13699 } else if (BECountIfBackedgeTaken &&
13700 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13701 // If we know exactly how many times the backedge will be taken if it's
13702 // taken at least once, then the backedge count will either be that or
13703 // zero.
13704 ConstantMaxBECount = BECountIfBackedgeTaken;
13705 MaxOrZero = true;
13706 } else {
13707 ConstantMaxBECount = computeMaxBECountForLT(
13708 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13709 }
13710
13711 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13712 !isa<SCEVCouldNotCompute>(BECount))
13713 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13714
13715 const SCEV *SymbolicMaxBECount =
13716 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13717 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13718 Predicates);
13719}
13720
13721ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13722 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13723 bool ControlsOnlyExit, bool AllowPredicates) {
13725 // We handle only IV > Invariant
13726 if (!isLoopInvariant(RHS, L))
13727 return getCouldNotCompute();
13728
13729 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13730 if (!IV && AllowPredicates)
13731 // Try to make this an AddRec using runtime tests, in the first X
13732 // iterations of this loop, where X is the SCEV expression found by the
13733 // algorithm below.
13734 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13735
13736 // Avoid weird loops
13737 if (!IV || IV->getLoop() != L || !IV->isAffine())
13738 return getCouldNotCompute();
13739
13740 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13741 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13743
13744 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13745
13746 // Avoid negative or zero stride values
13747 if (!isKnownPositive(Stride))
13748 return getCouldNotCompute();
13749
13750 // Avoid proven overflow cases: this will ensure that the backedge taken count
13751 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13752 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13753 // behaviors like the case of C language.
13754 if (!Stride->isOne() && !NoWrap)
13755 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13756 return getCouldNotCompute();
13757
13758 const SCEV *Start = IV->getStart();
13759 const SCEV *End = RHS;
13760 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13761 // If we know that Start >= RHS in the context of loop, then we know that
13762 // min(RHS, Start) = RHS at this point.
13764 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13765 End = RHS;
13766 else
13767 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13768 }
13769
13770 if (Start->getType()->isPointerTy()) {
13772 if (isa<SCEVCouldNotCompute>(Start))
13773 return Start;
13774 }
13775 if (End->getType()->isPointerTy()) {
13776 End = getLosslessPtrToIntExpr(End);
13777 if (isa<SCEVCouldNotCompute>(End))
13778 return End;
13779 }
13780
13781 // Compute ((Start - End) + (Stride - 1)) / Stride.
13782 // FIXME: This can overflow. Holding off on fixing this for now;
13783 // howManyGreaterThans will hopefully be gone soon.
13784 const SCEV *One = getOne(Stride->getType());
13785 const SCEV *BECount = getUDivExpr(
13786 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13787
13788 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13790
13791 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13792 : getUnsignedRangeMin(Stride);
13793
13794 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13795 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13796 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13797
13798 // Although End can be a MIN expression we estimate MinEnd considering only
13799 // the case End = RHS. This is safe because in the other case (Start - End)
13800 // is zero, leading to a zero maximum backedge taken count.
13801 APInt MinEnd =
13802 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13803 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13804
13805 const SCEV *ConstantMaxBECount =
13806 isa<SCEVConstant>(BECount)
13807 ? BECount
13808 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13809 getConstant(MinStride));
13810
13811 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13812 ConstantMaxBECount = BECount;
13813 const SCEV *SymbolicMaxBECount =
13814 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13815
13816 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13817 Predicates);
13818}
13819
13821 ScalarEvolution &SE) const {
13822 if (Range.isFullSet()) // Infinite loop.
13823 return SE.getCouldNotCompute();
13824
13825 // If the start is a non-zero constant, shift the range to simplify things.
13826 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13827 if (!SC->getValue()->isZero()) {
13829 Operands[0] = SE.getZero(SC->getType());
13830 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13832 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13833 return ShiftedAddRec->getNumIterationsInRange(
13834 Range.subtract(SC->getAPInt()), SE);
13835 // This is strange and shouldn't happen.
13836 return SE.getCouldNotCompute();
13837 }
13838
13839 // The only time we can solve this is when we have all constant indices.
13840 // Otherwise, we cannot determine the overflow conditions.
13841 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13842 return SE.getCouldNotCompute();
13843
13844 // Okay at this point we know that all elements of the chrec are constants and
13845 // that the start element is zero.
13846
13847 // First check to see if the range contains zero. If not, the first
13848 // iteration exits.
13849 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13850 if (!Range.contains(APInt(BitWidth, 0)))
13851 return SE.getZero(getType());
13852
13853 if (isAffine()) {
13854 // If this is an affine expression then we have this situation:
13855 // Solve {0,+,A} in Range === Ax in Range
13856
13857 // We know that zero is in the range. If A is positive then we know that
13858 // the upper value of the range must be the first possible exit value.
13859 // If A is negative then the lower of the range is the last possible loop
13860 // value. Also note that we already checked for a full range.
13861 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13862 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13863
13864 // The exit value should be (End+A)/A.
13865 APInt ExitVal = (End + A).udiv(A);
13866 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13867
13868 // Evaluate at the exit value. If we really did fall out of the valid
13869 // range, then we computed our trip count, otherwise wrap around or other
13870 // things must have happened.
13871 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13872 if (Range.contains(Val->getValue()))
13873 return SE.getCouldNotCompute(); // Something strange happened
13874
13875 // Ensure that the previous value is in the range.
13876 assert(Range.contains(
13878 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13879 "Linear scev computation is off in a bad way!");
13880 return SE.getConstant(ExitValue);
13881 }
13882
13883 if (isQuadratic()) {
13884 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13885 return SE.getConstant(*S);
13886 }
13887
13888 return SE.getCouldNotCompute();
13889}
13890
13891const SCEVAddRecExpr *
13893 assert(getNumOperands() > 1 && "AddRec with zero step?");
13894 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13895 // but in this case we cannot guarantee that the value returned will be an
13896 // AddRec because SCEV does not have a fixed point where it stops
13897 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13898 // may happen if we reach arithmetic depth limit while simplifying. So we
13899 // construct the returned value explicitly.
13901 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13902 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13903 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13904 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13905 // We know that the last operand is not a constant zero (otherwise it would
13906 // have been popped out earlier). This guarantees us that if the result has
13907 // the same last operand, then it will also not be popped out, meaning that
13908 // the returned value will be an AddRec.
13909 const SCEV *Last = getOperand(getNumOperands() - 1);
13910 assert(!Last->isZero() && "Recurrency with zero step?");
13911 Ops.push_back(Last);
13914}
13915
13916// Return true when S contains at least an undef value.
13918 return SCEVExprContains(
13919 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13920}
13921
13922// Return true when S contains a value that is a nullptr.
13924 return SCEVExprContains(S, [](const SCEV *S) {
13925 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13926 return SU->getValue() == nullptr;
13927 return false;
13928 });
13929}
13930
13931/// Return the size of an element read or written by Inst.
13933 Type *Ty;
13934 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13935 Ty = Store->getValueOperand()->getType();
13936 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13937 Ty = Load->getType();
13938 else
13939 return nullptr;
13940
13942 return getSizeOfExpr(ETy, Ty);
13943}
13944
13945//===----------------------------------------------------------------------===//
13946// SCEVCallbackVH Class Implementation
13947//===----------------------------------------------------------------------===//
13948
13950 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13951 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13952 SE->ConstantEvolutionLoopExitValue.erase(PN);
13953 SE->eraseValueFromMap(getValPtr());
13954 // this now dangles!
13955}
13956
13957void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13958 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13959
13960 // Forget all the expressions associated with users of the old value,
13961 // so that future queries will recompute the expressions using the new
13962 // value.
13963 SE->forgetValue(getValPtr());
13964 // this now dangles!
13965}
13966
// Construct a callback handle that watches V on behalf of the given
// ScalarEvolution instance (notified via deleted()/allUsesReplacedWith()).
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}
13969
13970//===----------------------------------------------------------------------===//
13971// ScalarEvolution Class Implementation
13972//===----------------------------------------------------------------------===//
13973
13976 LoopInfo &LI)
13977 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13978 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13979 LoopDispositions(64), BlockDispositions(64) {
13980 // To use guards for proving predicates, we need to scan every instruction in
13981 // relevant basic blocks, and not just terminators. Doing this is a waste of
13982 // time if the IR does not actually contain any calls to
13983 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13984 //
13985 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13986 // to _add_ guards to the module when there weren't any before, and wants
13987 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13988 // efficient in lieu of being smart in that rather obscure case.
13989
13990 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13991 F.getParent(), Intrinsic::experimental_guard);
13992 HasGuards = GuardDecl && !GuardDecl->use_empty();
13993}
13994
13996 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13997 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13998 ValueExprMap(std::move(Arg.ValueExprMap)),
13999 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14000 PendingMerges(std::move(Arg.PendingMerges)),
14001 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14002 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14003 PredicatedBackedgeTakenCounts(
14004 std::move(Arg.PredicatedBackedgeTakenCounts)),
14005 BECountUsers(std::move(Arg.BECountUsers)),
14006 ConstantEvolutionLoopExitValue(
14007 std::move(Arg.ConstantEvolutionLoopExitValue)),
14008 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14009 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14010 LoopDispositions(std::move(Arg.LoopDispositions)),
14011 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14012 BlockDispositions(std::move(Arg.BlockDispositions)),
14013 SCEVUsers(std::move(Arg.SCEVUsers)),
14014 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14015 SignedRanges(std::move(Arg.SignedRanges)),
14016 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14017 UniquePreds(std::move(Arg.UniquePreds)),
14018 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14019 LoopUsers(std::move(Arg.LoopUsers)),
14020 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14021 FirstUnknown(Arg.FirstUnknown) {
14022 Arg.FirstUnknown = nullptr;
14023}
14024
14026 // Iterate through all the SCEVUnknown instances and call their
14027 // destructors, so that they release their references to their values.
14028 for (SCEVUnknown *U = FirstUnknown; U;) {
14029 SCEVUnknown *Tmp = U;
14030 U = U->Next;
14031 Tmp->~SCEVUnknown();
14032 }
14033 FirstUnknown = nullptr;
14034
14035 ExprValueMap.clear();
14036 ValueExprMap.clear();
14037 HasRecMap.clear();
14038 BackedgeTakenCounts.clear();
14039 PredicatedBackedgeTakenCounts.clear();
14040
14041 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14042 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14043 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14044 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14045}
14046
14050
14051/// When printing a top-level SCEV for trip counts, it's helpful to include
14052/// a type for constants which are otherwise hard to disambiguate.
14053static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14054 if (isa<SCEVConstant>(S))
14055 OS << *S->getType() << " ";
14056 OS << *S;
14057}
14058
14060 const Loop *L) {
14061 // Print all inner loops first
14062 for (Loop *I : *L)
14063 PrintLoopInfo(OS, SE, I);
14064
14065 OS << "Loop ";
14066 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14067 OS << ": ";
14068
14069 SmallVector<BasicBlock *, 8> ExitingBlocks;
14070 L->getExitingBlocks(ExitingBlocks);
14071 if (ExitingBlocks.size() != 1)
14072 OS << "<multiple exits> ";
14073
14074 auto *BTC = SE->getBackedgeTakenCount(L);
14075 if (!isa<SCEVCouldNotCompute>(BTC)) {
14076 OS << "backedge-taken count is ";
14077 PrintSCEVWithTypeHint(OS, BTC);
14078 } else
14079 OS << "Unpredictable backedge-taken count.";
14080 OS << "\n";
14081
14082 if (ExitingBlocks.size() > 1)
14083 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14084 OS << " exit count for " << ExitingBlock->getName() << ": ";
14085 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14086 PrintSCEVWithTypeHint(OS, EC);
14087 if (isa<SCEVCouldNotCompute>(EC)) {
14088 // Retry with predicates.
14090 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14091 if (!isa<SCEVCouldNotCompute>(EC)) {
14092 OS << "\n predicated exit count for " << ExitingBlock->getName()
14093 << ": ";
14094 PrintSCEVWithTypeHint(OS, EC);
14095 OS << "\n Predicates:\n";
14096 for (const auto *P : Predicates)
14097 P->print(OS, 4);
14098 }
14099 }
14100 OS << "\n";
14101 }
14102
14103 OS << "Loop ";
14104 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14105 OS << ": ";
14106
14107 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14108 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14109 OS << "constant max backedge-taken count is ";
14110 PrintSCEVWithTypeHint(OS, ConstantBTC);
14112 OS << ", actual taken count either this or zero.";
14113 } else {
14114 OS << "Unpredictable constant max backedge-taken count. ";
14115 }
14116
14117 OS << "\n"
14118 "Loop ";
14119 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14120 OS << ": ";
14121
14122 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14123 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14124 OS << "symbolic max backedge-taken count is ";
14125 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14127 OS << ", actual taken count either this or zero.";
14128 } else {
14129 OS << "Unpredictable symbolic max backedge-taken count. ";
14130 }
14131 OS << "\n";
14132
14133 if (ExitingBlocks.size() > 1)
14134 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14135 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14136 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14138 PrintSCEVWithTypeHint(OS, ExitBTC);
14139 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14140 // Retry with predicates.
14142 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14144 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14145 OS << "\n predicated symbolic max exit count for "
14146 << ExitingBlock->getName() << ": ";
14147 PrintSCEVWithTypeHint(OS, ExitBTC);
14148 OS << "\n Predicates:\n";
14149 for (const auto *P : Predicates)
14150 P->print(OS, 4);
14151 }
14152 }
14153 OS << "\n";
14154 }
14155
14157 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14158 if (PBT != BTC) {
14159 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14160 OS << "Loop ";
14161 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14162 OS << ": ";
14163 if (!isa<SCEVCouldNotCompute>(PBT)) {
14164 OS << "Predicated backedge-taken count is ";
14165 PrintSCEVWithTypeHint(OS, PBT);
14166 } else
14167 OS << "Unpredictable predicated backedge-taken count.";
14168 OS << "\n";
14169 OS << " Predicates:\n";
14170 for (const auto *P : Preds)
14171 P->print(OS, 4);
14172 }
14173 Preds.clear();
14174
14175 auto *PredConstantMax =
14177 if (PredConstantMax != ConstantBTC) {
14178 assert(!Preds.empty() &&
14179 "different predicated constant max BTC but no predicates");
14180 OS << "Loop ";
14181 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14182 OS << ": ";
14183 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14184 OS << "Predicated constant max backedge-taken count is ";
14185 PrintSCEVWithTypeHint(OS, PredConstantMax);
14186 } else
14187 OS << "Unpredictable predicated constant max backedge-taken count.";
14188 OS << "\n";
14189 OS << " Predicates:\n";
14190 for (const auto *P : Preds)
14191 P->print(OS, 4);
14192 }
14193 Preds.clear();
14194
14195 auto *PredSymbolicMax =
14197 if (SymbolicBTC != PredSymbolicMax) {
14198 assert(!Preds.empty() &&
14199 "Different predicated symbolic max BTC, but no predicates");
14200 OS << "Loop ";
14201 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14202 OS << ": ";
14203 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14204 OS << "Predicated symbolic max backedge-taken count is ";
14205 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14206 } else
14207 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14208 OS << "\n";
14209 OS << " Predicates:\n";
14210 for (const auto *P : Preds)
14211 P->print(OS, 4);
14212 }
14213
14215 OS << "Loop ";
14216 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14217 OS << ": ";
14218 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14219 }
14220}
14221
14222namespace llvm {
14223// Note: these overloaded operators need to be in the llvm namespace for them
14224// to be resolved correctly. If we put them outside the llvm namespace, the
14225//
14226// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14227//
14228// code below "breaks" and start printing raw enum values as opposed to the
14229// string values.
14232 switch (LD) {
14234 OS << "Variant";
14235 break;
14237 OS << "Invariant";
14238 break;
14240 OS << "Computable";
14241 break;
14242 }
14243 return OS;
14244}
14245
14248 switch (BD) {
14250 OS << "DoesNotDominate";
14251 break;
14253 OS << "Dominates";
14254 break;
14256 OS << "ProperlyDominates";
14257 break;
14258 }
14259 return OS;
14260}
14261} // namespace llvm
14262
14264 // ScalarEvolution's implementation of the print method is to print
14265 // out SCEV values of all instructions that are interesting. Doing
14266 // this potentially causes it to create new SCEV objects though,
14267 // which technically conflicts with the const qualifier. This isn't
14268 // observable from outside the class though, so casting away the
14269 // const isn't dangerous.
14270 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14271
14272 if (ClassifyExpressions) {
14273 OS << "Classifying expressions for: ";
14274 F.printAsOperand(OS, /*PrintType=*/false);
14275 OS << "\n";
14276 for (Instruction &I : instructions(F))
14277 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14278 OS << I << '\n';
14279 OS << " --> ";
14280 const SCEV *SV = SE.getSCEV(&I);
14281 SV->print(OS);
14282 if (!isa<SCEVCouldNotCompute>(SV)) {
14283 OS << " U: ";
14284 SE.getUnsignedRange(SV).print(OS);
14285 OS << " S: ";
14286 SE.getSignedRange(SV).print(OS);
14287 }
14288
14289 const Loop *L = LI.getLoopFor(I.getParent());
14290
14291 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14292 if (AtUse != SV) {
14293 OS << " --> ";
14294 AtUse->print(OS);
14295 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14296 OS << " U: ";
14297 SE.getUnsignedRange(AtUse).print(OS);
14298 OS << " S: ";
14299 SE.getSignedRange(AtUse).print(OS);
14300 }
14301 }
14302
14303 if (L) {
14304 OS << "\t\t" "Exits: ";
14305 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14306 if (!SE.isLoopInvariant(ExitValue, L)) {
14307 OS << "<<Unknown>>";
14308 } else {
14309 OS << *ExitValue;
14310 }
14311
14312 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14313 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14314 OS << LS;
14315 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14316 OS << ": " << SE.getLoopDisposition(SV, Iter);
14317 }
14318
14319 for (const auto *InnerL : depth_first(L)) {
14320 if (InnerL == L)
14321 continue;
14322 OS << LS;
14323 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14324 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14325 }
14326
14327 OS << " }";
14328 }
14329
14330 OS << "\n";
14331 }
14332 }
14333
14334 OS << "Determining loop execution counts for: ";
14335 F.printAsOperand(OS, /*PrintType=*/false);
14336 OS << "\n";
14337 for (Loop *I : LI)
14338 PrintLoopInfo(OS, &SE, I);
14339}
14340
14343 auto &Values = LoopDispositions[S];
14344 for (auto &V : Values) {
14345 if (V.getPointer() == L)
14346 return V.getInt();
14347 }
14348 Values.emplace_back(L, LoopVariant);
14349 LoopDisposition D = computeLoopDisposition(S, L);
14350 auto &Values2 = LoopDispositions[S];
14351 for (auto &V : llvm::reverse(Values2)) {
14352 if (V.getPointer() == L) {
14353 V.setInt(D);
14354 break;
14355 }
14356 }
14357 return D;
14358}
14359
14361ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14362 switch (S->getSCEVType()) {
14363 case scConstant:
14364 case scVScale:
14365 return LoopInvariant;
14366 case scAddRecExpr: {
14367 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14368
14369 // If L is the addrec's loop, it's computable.
14370 if (AR->getLoop() == L)
14371 return LoopComputable;
14372
14373 // Add recurrences are never invariant in the function-body (null loop).
14374 if (!L)
14375 return LoopVariant;
14376
14377 // Everything that is not defined at loop entry is variant.
14378 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14379 return LoopVariant;
14380 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14381 " dominate the contained loop's header?");
14382
14383 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14384 if (AR->getLoop()->contains(L))
14385 return LoopInvariant;
14386
14387 // This recurrence is variant w.r.t. L if any of its operands
14388 // are variant.
14389 for (SCEVUse Op : AR->operands())
14390 if (!isLoopInvariant(Op, L))
14391 return LoopVariant;
14392
14393 // Otherwise it's loop-invariant.
14394 return LoopInvariant;
14395 }
14396 case scTruncate:
14397 case scZeroExtend:
14398 case scSignExtend:
14399 case scPtrToAddr:
14400 case scPtrToInt:
14401 case scAddExpr:
14402 case scMulExpr:
14403 case scUDivExpr:
14404 case scUMaxExpr:
14405 case scSMaxExpr:
14406 case scUMinExpr:
14407 case scSMinExpr:
14408 case scSequentialUMinExpr: {
14409 bool HasVarying = false;
14410 for (SCEVUse Op : S->operands()) {
14412 if (D == LoopVariant)
14413 return LoopVariant;
14414 if (D == LoopComputable)
14415 HasVarying = true;
14416 }
14417 return HasVarying ? LoopComputable : LoopInvariant;
14418 }
14419 case scUnknown:
14420 // All non-instruction values are loop invariant. All instructions are loop
14421 // invariant if they are not contained in the specified loop.
14422 // Instructions are never considered invariant in the function body
14423 // (null loop) because they are defined within the "loop".
14424 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14425 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14426 return LoopInvariant;
14427 case scCouldNotCompute:
14428 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14429 }
14430 llvm_unreachable("Unknown SCEV kind!");
14431}
14432
14434 return getLoopDisposition(S, L) == LoopInvariant;
14435}
14436
14438 return getLoopDisposition(S, L) == LoopComputable;
14439}
14440
14443 auto &Values = BlockDispositions[S];
14444 for (auto &V : Values) {
14445 if (V.getPointer() == BB)
14446 return V.getInt();
14447 }
14448 Values.emplace_back(BB, DoesNotDominateBlock);
14449 BlockDisposition D = computeBlockDisposition(S, BB);
14450 auto &Values2 = BlockDispositions[S];
14451 for (auto &V : llvm::reverse(Values2)) {
14452 if (V.getPointer() == BB) {
14453 V.setInt(D);
14454 break;
14455 }
14456 }
14457 return D;
14458}
14459
14461ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14462 switch (S->getSCEVType()) {
14463 case scConstant:
14464 case scVScale:
14466 case scAddRecExpr: {
14467 // This uses a "dominates" query instead of "properly dominates" query
14468 // to test for proper dominance too, because the instruction which
14469 // produces the addrec's value is a PHI, and a PHI effectively properly
14470 // dominates its entire containing block.
14471 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14472 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14473 return DoesNotDominateBlock;
14474
14475 // Fall through into SCEVNAryExpr handling.
14476 [[fallthrough]];
14477 }
14478 case scTruncate:
14479 case scZeroExtend:
14480 case scSignExtend:
14481 case scPtrToAddr:
14482 case scPtrToInt:
14483 case scAddExpr:
14484 case scMulExpr:
14485 case scUDivExpr:
14486 case scUMaxExpr:
14487 case scSMaxExpr:
14488 case scUMinExpr:
14489 case scSMinExpr:
14490 case scSequentialUMinExpr: {
14491 bool Proper = true;
14492 for (const SCEV *NAryOp : S->operands()) {
14494 if (D == DoesNotDominateBlock)
14495 return DoesNotDominateBlock;
14496 if (D == DominatesBlock)
14497 Proper = false;
14498 }
14499 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14500 }
14501 case scUnknown:
14502 if (Instruction *I =
14503 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14504 if (I->getParent() == BB)
14505 return DominatesBlock;
14506 if (DT.properlyDominates(I->getParent(), BB))
14508 return DoesNotDominateBlock;
14509 }
14511 case scCouldNotCompute:
14512 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14513 }
14514 llvm_unreachable("Unknown SCEV kind!");
14515}
14516
bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
  // BlockDisposition is ordered so that both DominatesBlock and
  // ProperlyDominatesBlock satisfy this >= comparison.
  return getBlockDisposition(S, BB) >= DominatesBlock;
}
14520
14523}
14524
14525bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14526 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14527}
14528
14529void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14530 bool Predicated) {
14531 auto &BECounts =
14532 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14533 auto It = BECounts.find(L);
14534 if (It != BECounts.end()) {
14535 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14536 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14537 if (!isa<SCEVConstant>(S)) {
14538 auto UserIt = BECountUsers.find(S);
14539 assert(UserIt != BECountUsers.end());
14540 UserIt->second.erase({L, Predicated});
14541 }
14542 }
14543 }
14544 BECounts.erase(It);
14545 }
14546}
14547
14548void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14549 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14550 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14551
14552 while (!Worklist.empty()) {
14553 const SCEV *Curr = Worklist.pop_back_val();
14554 auto Users = SCEVUsers.find(Curr);
14555 if (Users != SCEVUsers.end())
14556 for (const auto *User : Users->second)
14557 if (ToForget.insert(User).second)
14558 Worklist.push_back(User);
14559 }
14560
14561 for (const auto *S : ToForget)
14562 forgetMemoizedResultsImpl(S);
14563
14564 for (auto I = PredicatedSCEVRewrites.begin();
14565 I != PredicatedSCEVRewrites.end();) {
14566 std::pair<const SCEV *, const Loop *> Entry = I->first;
14567 if (ToForget.count(Entry.first))
14568 PredicatedSCEVRewrites.erase(I++);
14569 else
14570 ++I;
14571 }
14572}
14573
14574void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14575 LoopDispositions.erase(S);
14576 BlockDispositions.erase(S);
14577 UnsignedRanges.erase(S);
14578 SignedRanges.erase(S);
14579 HasRecMap.erase(S);
14580 ConstantMultipleCache.erase(S);
14581
14582 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14583 UnsignedWrapViaInductionTried.erase(AR);
14584 SignedWrapViaInductionTried.erase(AR);
14585 }
14586
14587 auto ExprIt = ExprValueMap.find(S);
14588 if (ExprIt != ExprValueMap.end()) {
14589 for (Value *V : ExprIt->second) {
14590 auto ValueIt = ValueExprMap.find_as(V);
14591 if (ValueIt != ValueExprMap.end())
14592 ValueExprMap.erase(ValueIt);
14593 }
14594 ExprValueMap.erase(ExprIt);
14595 }
14596
14597 auto ScopeIt = ValuesAtScopes.find(S);
14598 if (ScopeIt != ValuesAtScopes.end()) {
14599 for (const auto &Pair : ScopeIt->second)
14600 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14601 llvm::erase(ValuesAtScopesUsers[Pair.second],
14602 std::make_pair(Pair.first, S));
14603 ValuesAtScopes.erase(ScopeIt);
14604 }
14605
14606 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14607 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14608 for (const auto &Pair : ScopeUserIt->second)
14609 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14610 ValuesAtScopesUsers.erase(ScopeUserIt);
14611 }
14612
14613 auto BEUsersIt = BECountUsers.find(S);
14614 if (BEUsersIt != BECountUsers.end()) {
14615 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14616 auto Copy = BEUsersIt->second;
14617 for (const auto &Pair : Copy)
14618 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14619 BECountUsers.erase(BEUsersIt);
14620 }
14621
14622 auto FoldUser = FoldCacheUser.find(S);
14623 if (FoldUser != FoldCacheUser.end())
14624 for (auto &KV : FoldUser->second)
14625 FoldCache.erase(KV);
14626 FoldCacheUser.erase(S);
14627}
14628
14629void
14630ScalarEvolution::getUsedLoops(const SCEV *S,
14631 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14632 struct FindUsedLoops {
14633 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14634 : LoopsUsed(LoopsUsed) {}
14635 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14636 bool follow(const SCEV *S) {
14637 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14638 LoopsUsed.insert(AR->getLoop());
14639 return true;
14640 }
14641
14642 bool isDone() const { return false; }
14643 };
14644
14645 FindUsedLoops F(LoopsUsed);
14646 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14647}
14648
14649void ScalarEvolution::getReachableBlocks(
14652 Worklist.push_back(&F.getEntryBlock());
14653 while (!Worklist.empty()) {
14654 BasicBlock *BB = Worklist.pop_back_val();
14655 if (!Reachable.insert(BB).second)
14656 continue;
14657
14658 Value *Cond;
14659 BasicBlock *TrueBB, *FalseBB;
14660 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14661 m_BasicBlock(FalseBB)))) {
14662 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14663 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14664 continue;
14665 }
14666
14667 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14668 const SCEV *L = getSCEV(Cmp->getOperand(0));
14669 const SCEV *R = getSCEV(Cmp->getOperand(1));
14670 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14671 Worklist.push_back(TrueBB);
14672 continue;
14673 }
14674 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14675 R)) {
14676 Worklist.push_back(FalseBB);
14677 continue;
14678 }
14679 }
14680 }
14681
14682 append_range(Worklist, successors(BB));
14683 }
14684}
14685
14687 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14688 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14689
14690 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14691
14692 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14693 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14694 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14695
14696 const SCEV *visitConstant(const SCEVConstant *Constant) {
14697 return SE.getConstant(Constant->getAPInt());
14698 }
14699
14700 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14701 return SE.getUnknown(Expr->getValue());
14702 }
14703
14704 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14705 return SE.getCouldNotCompute();
14706 }
14707 };
14708
14709 SCEVMapper SCM(SE2);
14710 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14711 SE2.getReachableBlocks(ReachableBlocks, F);
14712
14713 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14714 if (containsUndefs(Old) || containsUndefs(New)) {
14715 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14716 // not propagate undef aggressively). This means we can (and do) fail
14717 // verification in cases where a transform makes a value go from "undef"
14718 // to "undef+1" (say). The transform is fine, since in both cases the
14719 // result is "undef", but SCEV thinks the value increased by 1.
14720 return nullptr;
14721 }
14722
14723 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14724 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14725 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14726 return nullptr;
14727
14728 return Delta;
14729 };
14730
14731 while (!LoopStack.empty()) {
14732 auto *L = LoopStack.pop_back_val();
14733 llvm::append_range(LoopStack, *L);
14734
14735 // Only verify BECounts in reachable loops. For an unreachable loop,
14736 // any BECount is legal.
14737 if (!ReachableBlocks.contains(L->getHeader()))
14738 continue;
14739
14740 // Only verify cached BECounts. Computing new BECounts may change the
14741 // results of subsequent SCEV uses.
14742 auto It = BackedgeTakenCounts.find(L);
14743 if (It == BackedgeTakenCounts.end())
14744 continue;
14745
14746 auto *CurBECount =
14747 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14748 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14749
14750 if (CurBECount == SE2.getCouldNotCompute() ||
14751 NewBECount == SE2.getCouldNotCompute()) {
14752 // NB! This situation is legal, but is very suspicious -- whatever pass
14753 // change the loop to make a trip count go from could not compute to
14754 // computable or vice-versa *should have* invalidated SCEV. However, we
14755 // choose not to assert here (for now) since we don't want false
14756 // positives.
14757 continue;
14758 }
14759
14760 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14761 SE.getTypeSizeInBits(NewBECount->getType()))
14762 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14763 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14764 SE.getTypeSizeInBits(NewBECount->getType()))
14765 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14766
14767 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14768 if (Delta && !Delta->isZero()) {
14769 dbgs() << "Trip Count for " << *L << " Changed!\n";
14770 dbgs() << "Old: " << *CurBECount << "\n";
14771 dbgs() << "New: " << *NewBECount << "\n";
14772 dbgs() << "Delta: " << *Delta << "\n";
14773 std::abort();
14774 }
14775 }
14776
14777 // Collect all valid loops currently in LoopInfo.
14778 SmallPtrSet<Loop *, 32> ValidLoops;
14779 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14780 while (!Worklist.empty()) {
14781 Loop *L = Worklist.pop_back_val();
14782 if (ValidLoops.insert(L).second)
14783 Worklist.append(L->begin(), L->end());
14784 }
14785 for (const auto &KV : ValueExprMap) {
14786#ifndef NDEBUG
14787 // Check for SCEV expressions referencing invalid/deleted loops.
14788 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14789 assert(ValidLoops.contains(AR->getLoop()) &&
14790 "AddRec references invalid loop");
14791 }
14792#endif
14793
14794 // Check that the value is also part of the reverse map.
14795 auto It = ExprValueMap.find(KV.second);
14796 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14797 dbgs() << "Value " << *KV.first
14798 << " is in ValueExprMap but not in ExprValueMap\n";
14799 std::abort();
14800 }
14801
14802 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14803 if (!ReachableBlocks.contains(I->getParent()))
14804 continue;
14805 const SCEV *OldSCEV = SCM.visit(KV.second);
14806 const SCEV *NewSCEV = SE2.getSCEV(I);
14807 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14808 if (Delta && !Delta->isZero()) {
14809 dbgs() << "SCEV for value " << *I << " changed!\n"
14810 << "Old: " << *OldSCEV << "\n"
14811 << "New: " << *NewSCEV << "\n"
14812 << "Delta: " << *Delta << "\n";
14813 std::abort();
14814 }
14815 }
14816 }
14817
14818 for (const auto &KV : ExprValueMap) {
14819 for (Value *V : KV.second) {
14820 const SCEV *S = ValueExprMap.lookup(V);
14821 if (!S) {
14822 dbgs() << "Value " << *V
14823 << " is in ExprValueMap but not in ValueExprMap\n";
14824 std::abort();
14825 }
14826 if (S != KV.first) {
14827 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14828 << *KV.first << "\n";
14829 std::abort();
14830 }
14831 }
14832 }
14833
14834 // Verify integrity of SCEV users.
14835 for (const auto &S : UniqueSCEVs) {
14836 for (SCEVUse Op : S.operands()) {
14837 // We do not store dependencies of constants.
14838 if (isa<SCEVConstant>(Op))
14839 continue;
14840 auto It = SCEVUsers.find(Op);
14841 if (It != SCEVUsers.end() && It->second.count(&S))
14842 continue;
14843 dbgs() << "Use of operand " << *Op << " by user " << S
14844 << " is not being tracked!\n";
14845 std::abort();
14846 }
14847 }
14848
14849 // Verify integrity of ValuesAtScopes users.
14850 for (const auto &ValueAndVec : ValuesAtScopes) {
14851 const SCEV *Value = ValueAndVec.first;
14852 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14853 const Loop *L = LoopAndValueAtScope.first;
14854 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14855 if (!isa<SCEVConstant>(ValueAtScope)) {
14856 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14857 if (It != ValuesAtScopesUsers.end() &&
14858 is_contained(It->second, std::make_pair(L, Value)))
14859 continue;
14860 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14861 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14862 std::abort();
14863 }
14864 }
14865 }
14866
14867 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14868 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14869 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14870 const Loop *L = LoopAndValue.first;
14871 const SCEV *Value = LoopAndValue.second;
14873 auto It = ValuesAtScopes.find(Value);
14874 if (It != ValuesAtScopes.end() &&
14875 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14876 continue;
14877 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14878 << *ValueAtScope << " missing in ValuesAtScopes\n";
14879 std::abort();
14880 }
14881 }
14882
14883 // Verify integrity of BECountUsers.
14884 auto VerifyBECountUsers = [&](bool Predicated) {
14885 auto &BECounts =
14886 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14887 for (const auto &LoopAndBEInfo : BECounts) {
14888 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14889 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14890 if (!isa<SCEVConstant>(S)) {
14891 auto UserIt = BECountUsers.find(S);
14892 if (UserIt != BECountUsers.end() &&
14893 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14894 continue;
14895 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14896 << " missing from BECountUsers\n";
14897 std::abort();
14898 }
14899 }
14900 }
14901 }
14902 };
14903 VerifyBECountUsers(/* Predicated */ false);
14904 VerifyBECountUsers(/* Predicated */ true);
14905
 14906  // Verify integrity of loop disposition cache.
14907 for (auto &[S, Values] : LoopDispositions) {
14908 for (auto [Loop, CachedDisposition] : Values) {
14909 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14910 if (CachedDisposition != RecomputedDisposition) {
14911 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14912 << " is incorrect: cached " << CachedDisposition << ", actual "
14913 << RecomputedDisposition << "\n";
14914 std::abort();
14915 }
14916 }
14917 }
14918
14919 // Verify integrity of the block disposition cache.
14920 for (auto &[S, Values] : BlockDispositions) {
14921 for (auto [BB, CachedDisposition] : Values) {
14922 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14923 if (CachedDisposition != RecomputedDisposition) {
14924 dbgs() << "Cached disposition of " << *S << " for block %"
14925 << BB->getName() << " is incorrect: cached " << CachedDisposition
14926 << ", actual " << RecomputedDisposition << "\n";
14927 std::abort();
14928 }
14929 }
14930 }
14931
14932 // Verify FoldCache/FoldCacheUser caches.
14933 for (auto [FoldID, Expr] : FoldCache) {
14934 auto I = FoldCacheUser.find(Expr);
14935 if (I == FoldCacheUser.end()) {
14936 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14937 << "!\n";
14938 std::abort();
14939 }
14940 if (!is_contained(I->second, FoldID)) {
14941 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14942 std::abort();
14943 }
14944 }
14945 for (auto [Expr, IDs] : FoldCacheUser) {
14946 for (auto &FoldID : IDs) {
14947 const SCEV *S = FoldCache.lookup(FoldID);
14948 if (!S) {
14949 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14950 << "!\n";
14951 std::abort();
14952 }
14953 if (S != Expr) {
14954 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14955 << " != " << *Expr << "!\n";
14956 std::abort();
14957 }
14958 }
14959 }
14960
14961 // Verify that ConstantMultipleCache computations are correct. We check that
14962 // cached multiples and recomputed multiples are multiples of each other to
14963 // verify correctness. It is possible that a recomputed multiple is different
14964 // from the cached multiple due to strengthened no wrap flags or changes in
14965 // KnownBits computations.
14966 for (auto [S, Multiple] : ConstantMultipleCache) {
14967 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14968 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14969 Multiple.urem(RecomputedMultiple) != 0 &&
14970 RecomputedMultiple.urem(Multiple) != 0)) {
14971 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14972 << *S << " : Computed " << RecomputedMultiple
14973 << " but cache contains " << Multiple << "!\n";
14974 std::abort();
14975 }
14976 }
14977}
14978
14980 Function &F, const PreservedAnalyses &PA,
14981 FunctionAnalysisManager::Invalidator &Inv) {
14982 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14983 // of its dependencies is invalidated.
14984 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14985 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14986 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14987 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14988 Inv.invalidate<LoopAnalysis>(F, PA);
14989}
14990
// Opaque key used by the new pass manager's analysis caching machinery to
// uniquely identify ScalarEvolutionAnalysis.
AnalysisKey ScalarEvolutionAnalysis::Key;
14992
14995 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14996 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14997 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14998 auto &LI = AM.getResult<LoopAnalysis>(F);
14999 return ScalarEvolution(F, TLI, AC, DT, LI);
15000}
15001
15007
15010 // For compatibility with opt's -analyze feature under legacy pass manager
15011 // which was not ported to NPM. This keeps tests using
15012 // update_analyze_test_checks.py working.
15013 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15014 << F.getName() << "':\n";
15016 return PreservedAnalyses::all();
15017}
15018
15020 "Scalar Evolution Analysis", false, true)
15026 "Scalar Evolution Analysis", false, true)
15027
15029
15031
15033 SE.reset(new ScalarEvolution(
15035 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15037 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15038 return false;
15039}
15040
15042
15044 SE->print(OS);
15045}
15046
15048 if (!VerifySCEV)
15049 return;
15050
15051 SE->verify();
15052}
15053
15061
15063 const SCEV *RHS) {
15064 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15065}
15066
15067const SCEVPredicate *
15069 const SCEV *LHS, const SCEV *RHS) {
15071 assert(LHS->getType() == RHS->getType() &&
15072 "Type mismatch between LHS and RHS");
15073 // Unique this node based on the arguments
15074 ID.AddInteger(SCEVPredicate::P_Compare);
15075 ID.AddInteger(Pred);
15076 ID.AddPointer(LHS);
15077 ID.AddPointer(RHS);
15078 void *IP = nullptr;
15079 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15080 return S;
15081 SCEVComparePredicate *Eq = new (SCEVAllocator)
15082 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15083 UniquePreds.InsertNode(Eq, IP);
15084 return Eq;
15085}
15086
15088 const SCEVAddRecExpr *AR,
15091 // Unique this node based on the arguments
15092 ID.AddInteger(SCEVPredicate::P_Wrap);
15093 ID.AddPointer(AR);
15094 ID.AddInteger(AddedFlags);
15095 void *IP = nullptr;
15096 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15097 return S;
15098 auto *OF = new (SCEVAllocator)
15099 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15100 UniquePreds.InsertNode(OF, IP);
15101 return OF;
15102}
15103
15104namespace {
15105
15106class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15107public:
15108
15109 /// Rewrites \p S in the context of a loop L and the SCEV predication
15110 /// infrastructure.
15111 ///
15112 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15113 /// equivalences present in \p Pred.
15114 ///
15115 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15116 /// \p NewPreds such that the result will be an AddRecExpr.
15117 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15119 const SCEVPredicate *Pred) {
15120 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15121 return Rewriter.visit(S);
15122 }
15123
15124 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15125 if (Pred) {
15126 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15127 for (const auto *Pred : U->getPredicates())
15128 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15129 if (IPred->getLHS() == Expr &&
15130 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15131 return IPred->getRHS();
15132 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15133 if (IPred->getLHS() == Expr &&
15134 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15135 return IPred->getRHS();
15136 }
15137 }
15138 return convertToAddRecWithPreds(Expr);
15139 }
15140
15141 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15142 const SCEV *Operand = visit(Expr->getOperand());
15143 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15144 if (AR && AR->getLoop() == L && AR->isAffine()) {
15145 // This couldn't be folded because the operand didn't have the nuw
15146 // flag. Add the nusw flag as an assumption that we could make.
15147 const SCEV *Step = AR->getStepRecurrence(SE);
15148 Type *Ty = Expr->getType();
15149 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15150 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15151 SE.getSignExtendExpr(Step, Ty), L,
15152 AR->getNoWrapFlags());
15153 }
15154 return SE.getZeroExtendExpr(Operand, Expr->getType());
15155 }
15156
15157 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15158 const SCEV *Operand = visit(Expr->getOperand());
15159 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15160 if (AR && AR->getLoop() == L && AR->isAffine()) {
15161 // This couldn't be folded because the operand didn't have the nsw
15162 // flag. Add the nssw flag as an assumption that we could make.
15163 const SCEV *Step = AR->getStepRecurrence(SE);
15164 Type *Ty = Expr->getType();
15165 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15166 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15167 SE.getSignExtendExpr(Step, Ty), L,
15168 AR->getNoWrapFlags());
15169 }
15170 return SE.getSignExtendExpr(Operand, Expr->getType());
15171 }
15172
15173private:
15174 explicit SCEVPredicateRewriter(
15175 const Loop *L, ScalarEvolution &SE,
15176 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15177 const SCEVPredicate *Pred)
15178 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15179
15180 bool addOverflowAssumption(const SCEVPredicate *P) {
15181 if (!NewPreds) {
15182 // Check if we've already made this assumption.
15183 return Pred && Pred->implies(P, SE);
15184 }
15185 NewPreds->push_back(P);
15186 return true;
15187 }
15188
15189 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15191 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15192 return addOverflowAssumption(A);
15193 }
15194
15195 // If \p Expr represents a PHINode, we try to see if it can be represented
15196 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15197 // to add this predicate as a runtime overflow check, we return the AddRec.
15198 // If \p Expr does not meet these conditions (is not a PHI node, or we
15199 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15200 // return \p Expr.
15201 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15202 if (!isa<PHINode>(Expr->getValue()))
15203 return Expr;
15204 std::optional<
15205 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15206 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15207 if (!PredicatedRewrite)
15208 return Expr;
15209 for (const auto *P : PredicatedRewrite->second){
15210 // Wrap predicates from outer loops are not supported.
15211 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15212 if (L != WP->getExpr()->getLoop())
15213 return Expr;
15214 }
15215 if (!addOverflowAssumption(P))
15216 return Expr;
15217 }
15218 return PredicatedRewrite->first;
15219 }
15220
15221 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15222 const SCEVPredicate *Pred;
15223 const Loop *L;
15224};
15225
15226} // end anonymous namespace
15227
15228const SCEV *
15230 const SCEVPredicate &Preds) {
15231 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15232}
15233
15235 const SCEV *S, const Loop *L,
15238 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15239 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15240
15241 if (!AddRec)
15242 return nullptr;
15243
15244 // Check if any of the transformed predicates is known to be false. In that
15245 // case, it doesn't make sense to convert to a predicated AddRec, as the
15246 // versioned loop will never execute.
15247 for (const SCEVPredicate *Pred : TransformPreds) {
15248 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15249 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15250 continue;
15251
15252 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15253 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15254 if (isa<SCEVCouldNotCompute>(ExitCount))
15255 continue;
15256
15257 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15258 if (!Step->isOne())
15259 continue;
15260
15261 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15262 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15263 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15264 return nullptr;
15265 }
15266
15267 // Since the transformation was successful, we can now transfer the SCEV
15268 // predicates.
15269 Preds.append(TransformPreds.begin(), TransformPreds.end());
15270
15271 return AddRec;
15272}
15273
15274/// SCEV predicates
15278
15280 const ICmpInst::Predicate Pred,
15281 const SCEV *LHS, const SCEV *RHS)
15282 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15283 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15284 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15285}
15286
15288 ScalarEvolution &SE) const {
15289 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15290
15291 if (!Op)
15292 return false;
15293
15294 if (Pred != ICmpInst::ICMP_EQ)
15295 return false;
15296
15297 return Op->LHS == LHS && Op->RHS == RHS;
15298}
15299
15300bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15301
15303 if (Pred == ICmpInst::ICMP_EQ)
15304 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15305 else
15306 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15307 << *RHS << "\n";
15308
15309}
15310
15312 const SCEVAddRecExpr *AR,
15313 IncrementWrapFlags Flags)
15314 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15315
15316const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15317
15319 ScalarEvolution &SE) const {
15320 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15321 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15322 return false;
15323
15324 if (Op->AR == AR)
15325 return true;
15326
15327 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15329 return false;
15330
15331 const SCEV *Start = AR->getStart();
15332 const SCEV *OpStart = Op->AR->getStart();
15333 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15334 return false;
15335
15336 // Reject pointers to different address spaces.
15337 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15338 return false;
15339
15340 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15341 // narrower-type AddRec.
15342 if (SE.getTypeSizeInBits(AR->getType()) >
15343 SE.getTypeSizeInBits(Op->AR->getType()))
15344 return false;
15345
15346 const SCEV *Step = AR->getStepRecurrence(SE);
15347 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15348 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15349 return false;
15350
  // If both steps are positive, this implies N if N's start and step are
  // ULE/SLE (for NUSW/NSSW) compared to this'.
15353 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15354 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15355 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15356
15357 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15358 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15359 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15360 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15361 : SE.getNoopOrSignExtend(Start, WiderTy);
15363 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15364 SE.isKnownPredicate(Pred, OpStart, Start);
15365}
15366
15368 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15369 IncrementWrapFlags IFlags = Flags;
15370
15371 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15372 IFlags = clearFlags(IFlags, IncrementNSSW);
15373
15374 return IFlags == IncrementAnyWrap;
15375}
15376
15377void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15378 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15380 OS << "<nusw>";
15382 OS << "<nssw>";
15383 OS << "\n";
15384}
15385
15388 ScalarEvolution &SE) {
15389 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15390 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15391
15392 // We can safely transfer the NSW flag as NSSW.
15393 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15394 ImpliedFlags = IncrementNSSW;
15395
15396 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15397 // If the increment is positive, the SCEV NUW flag will also imply the
15398 // WrapPredicate NUSW flag.
15399 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15400 if (Step->getValue()->getValue().isNonNegative())
15401 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15402 }
15403
15404 return ImpliedFlags;
15405}
15406
15407/// Union predicates don't get cached so create a dummy set ID for it.
15409 ScalarEvolution &SE)
15410 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15411 for (const auto *P : Preds)
15412 add(P, SE);
15413}
15414
15416 return all_of(Preds,
15417 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15418}
15419
15421 ScalarEvolution &SE) const {
15422 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15423 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15424 return this->implies(I, SE);
15425 });
15426
15427 return any_of(Preds,
15428 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15429}
15430
15432 for (const auto *Pred : Preds)
15433 Pred->print(OS, Depth);
15434}
15435
15436void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15437 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15438 for (const auto *Pred : Set->Preds)
15439 add(Pred, SE);
15440 return;
15441 }
15442
15443 // Implication checks are quadratic in the number of predicates. Stop doing
15444 // them if there are many predicates, as they should be too expensive to use
15445 // anyway at that point.
15446 bool CheckImplies = Preds.size() < 16;
15447
15448 // Only add predicate if it is not already implied by this union predicate.
15449 if (CheckImplies && implies(N, SE))
15450 return;
15451
15452 // Build a new vector containing the current predicates, except the ones that
15453 // are implied by the new predicate N.
15455 for (auto *P : Preds) {
15456 if (CheckImplies && N->implies(P, SE))
15457 continue;
15458 PrunedPreds.push_back(P);
15459 }
15460 Preds = std::move(PrunedPreds);
15461 Preds.push_back(N);
15462}
15463
15465 Loop &L)
15466 : SE(SE), L(L) {
15468 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15469}
15470
15473 for (const auto *Op : Ops)
15474 // We do not expect that forgetting cached data for SCEVConstants will ever
15475 // open any prospects for sharpening or introduce any correctness issues,
15476 // so we don't bother storing their dependencies.
15477 if (!isa<SCEVConstant>(Op))
15478 SCEVUsers[Op].insert(User);
15479}
15480
15482 for (const SCEV *Op : Ops)
15483 // We do not expect that forgetting cached data for SCEVConstants will ever
15484 // open any prospects for sharpening or introduce any correctness issues,
15485 // so we don't bother storing their dependencies.
15486 if (!isa<SCEVConstant>(Op))
15487 SCEVUsers[Op].insert(User);
15488}
15489
15491 const SCEV *Expr = SE.getSCEV(V);
15492 return getPredicatedSCEV(Expr);
15493}
15494
15496 RewriteEntry &Entry = RewriteMap[Expr];
15497
15498 // If we already have an entry and the version matches, return it.
15499 if (Entry.second && Generation == Entry.first)
15500 return Entry.second;
15501
15502 // We found an entry but it's stale. Rewrite the stale entry
15503 // according to the current predicate.
15504 if (Entry.second)
15505 Expr = Entry.second;
15506
15507 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15508 Entry = {Generation, NewSCEV};
15509
15510 return NewSCEV;
15511}
15512
15514 if (!BackedgeCount) {
15516 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15517 for (const auto *P : Preds)
15518 addPredicate(*P);
15519 }
15520 return BackedgeCount;
15521}
15522
15524 if (!SymbolicMaxBackedgeCount) {
15526 SymbolicMaxBackedgeCount =
15527 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15528 for (const auto *P : Preds)
15529 addPredicate(*P);
15530 }
15531 return SymbolicMaxBackedgeCount;
15532}
15533
15535 if (!SmallConstantMaxTripCount) {
15537 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15538 for (const auto *P : Preds)
15539 addPredicate(*P);
15540 }
15541 return *SmallConstantMaxTripCount;
15542}
15543
15545 if (Preds->implies(&Pred, SE))
15546 return;
15547
15548 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15549 NewPreds.push_back(&Pred);
15550 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15551 updateGeneration();
15552}
15553
15555 return *Preds;
15556}
15557
15558void PredicatedScalarEvolution::updateGeneration() {
15559 // If the generation number wrapped recompute everything.
15560 if (++Generation == 0) {
15561 for (auto &II : RewriteMap) {
15562 const SCEV *Rewritten = II.second.second;
15563 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15564 }
15565 }
15566}
15567
15570 const SCEV *Expr = getSCEV(V);
15571 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15572
15573 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15574
15575 // Clear the statically implied flags.
15576 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15577 addPredicate(*SE.getWrapPredicate(AR, Flags));
15578
15579 auto II = FlagsMap.insert({V, Flags});
15580 if (!II.second)
15581 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15582}
15583
15586 const SCEV *Expr = getSCEV(V);
15587 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15588
15590 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15591
15592 auto II = FlagsMap.find(V);
15593
15594 if (II != FlagsMap.end())
15595 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15596
15598}
15599
15601 const SCEV *Expr = this->getSCEV(V);
15603 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15604
15605 if (!New)
15606 return nullptr;
15607
15608 for (const auto *P : NewPreds)
15609 addPredicate(*P);
15610
15611 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15612 return New;
15613}
15614
15617 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15618 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15619 SE)),
15620 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15621 for (auto I : Init.FlagsMap)
15622 FlagsMap.insert(I);
15623}
15624
15626 // For each block.
15627 for (auto *BB : L.getBlocks())
15628 for (auto &I : *BB) {
15629 if (!SE.isSCEVable(I.getType()))
15630 continue;
15631
15632 auto *Expr = SE.getSCEV(&I);
15633 auto II = RewriteMap.find(Expr);
15634
15635 if (II == RewriteMap.end())
15636 continue;
15637
15638 // Don't print things that are not interesting.
15639 if (II->second.second == Expr)
15640 continue;
15641
15642 OS.indent(Depth) << "[PSE]" << I << ":\n";
15643 OS.indent(Depth + 2) << *Expr << "\n";
15644 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15645 }
15646}
15647
15650 BasicBlock *Header = L->getHeader();
15651 BasicBlock *Pred = L->getLoopPredecessor();
15652 LoopGuards Guards(SE);
15653 if (!Pred)
15654 return Guards;
15656 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15657 return Guards;
15658}
15659
15660void ScalarEvolution::LoopGuards::collectFromPHI(
15664 unsigned Depth) {
15665 if (!SE.isSCEVable(Phi.getType()))
15666 return;
15667
15668 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15669 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15670 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15671 if (!VisitedBlocks.insert(InBlock).second)
15672 return {nullptr, scCouldNotCompute};
15673
15674 // Avoid analyzing unreachable blocks so that we don't get trapped
15675 // traversing cycles with ill-formed dominance or infinite cycles
15676 if (!SE.DT.isReachableFromEntry(InBlock))
15677 return {nullptr, scCouldNotCompute};
15678
15679 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15680 if (Inserted)
15681 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15682 Depth + 1);
15683 auto &RewriteMap = G->second.RewriteMap;
15684 if (RewriteMap.empty())
15685 return {nullptr, scCouldNotCompute};
15686 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15687 if (S == RewriteMap.end())
15688 return {nullptr, scCouldNotCompute};
15689 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15690 if (!SM)
15691 return {nullptr, scCouldNotCompute};
15692 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15693 return {C0, SM->getSCEVType()};
15694 return {nullptr, scCouldNotCompute};
15695 };
15696 auto MergeMinMaxConst = [](MinMaxPattern P1,
15697 MinMaxPattern P2) -> MinMaxPattern {
15698 auto [C1, T1] = P1;
15699 auto [C2, T2] = P2;
15700 if (!C1 || !C2 || T1 != T2)
15701 return {nullptr, scCouldNotCompute};
15702 switch (T1) {
15703 case scUMaxExpr:
15704 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15705 case scSMaxExpr:
15706 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15707 case scUMinExpr:
15708 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15709 case scSMinExpr:
15710 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15711 default:
15712 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15713 }
15714 };
15715 auto P = GetMinMaxConst(0);
15716 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15717 if (!P.first)
15718 break;
15719 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15720 }
15721 if (P.first) {
15722 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15723 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15724 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15725 Guards.RewriteMap.insert({LHS, RHS});
15726 }
15727}
15728
// Return a new SCEV that modifies \p Expr to the closest value that is
// divisible by \p Divisor and less than or equal to \p Expr. For now, only
// constant \p Expr is handled.
15733 const APInt &DivisorVal,
15734 ScalarEvolution &SE) {
15735 const APInt *ExprVal;
15736 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15737 DivisorVal.isNonPositive())
15738 return Expr;
15739 APInt Rem = ExprVal->urem(DivisorVal);
15740 // return the SCEV: Expr - Expr % Divisor
15741 return SE.getConstant(*ExprVal - Rem);
15742}
15743
15744// Return a new SCEV that modifies \p Expr to the closest number divides by
15745// \p Divisor and greater or equal than Expr. For now, only handle constant
15746// Expr.
15747static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15748 const APInt &DivisorVal,
15749 ScalarEvolution &SE) {
15750 const APInt *ExprVal;
15751 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15752 DivisorVal.isNonPositive())
15753 return Expr;
15754 APInt Rem = ExprVal->urem(DivisorVal);
15755 if (Rem.isZero())
15756 return Expr;
15757 // return the SCEV: Expr + Divisor - Expr % Divisor
15758 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15759}
15760
15762 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15765 // If we have LHS == 0, check if LHS is computing a property of some unknown
15766 // SCEV %v which we can rewrite %v to express explicitly.
15768 return false;
15769 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15770 // explicitly express that.
15771 const SCEVUnknown *URemLHS = nullptr;
15772 const SCEV *URemRHS = nullptr;
15773 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15774 return false;
15775
15776 const SCEV *Multiple =
15777 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15778 DivInfo[URemLHS] = Multiple;
15779 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15780 Multiples[URemLHS] = C->getAPInt();
15781 return true;
15782}
15783
15784// Check if the condition is a divisibility guard (A % B == 0).
15785static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15786 ScalarEvolution &SE) {
15787 const SCEV *X, *Y;
15788 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15789}
15790
15791// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15792// recursively. This is done by aligning up/down the constant value to the
15793// Divisor.
15794static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15795 APInt Divisor,
15796 ScalarEvolution &SE) {
15797 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15798 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15799 // the non-constant operand and in \p LHS the constant operand.
15800 auto IsMinMaxSCEVWithNonNegativeConstant =
15801 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15802 const SCEV *&RHS) {
15803 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15804 if (MinMax->getNumOperands() != 2)
15805 return false;
15806 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15807 if (C->getAPInt().isNegative())
15808 return false;
15809 SCTy = MinMax->getSCEVType();
15810 LHS = MinMax->getOperand(0);
15811 RHS = MinMax->getOperand(1);
15812 return true;
15813 }
15814 }
15815 return false;
15816 };
15817
15818 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15819 SCEVTypes SCTy;
15820 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15821 MinMaxRHS))
15822 return MinMaxExpr;
15823 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15824 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15825 auto *DivisibleExpr =
15826 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15827 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15829 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15830 return SE.getMinMaxExpr(SCTy, Ops);
15831}
15832
15833void ScalarEvolution::LoopGuards::collectFromBlock(
15834 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15835 const BasicBlock *Block, const BasicBlock *Pred,
15836 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15837
15839
15840 SmallVector<SCEVUse> ExprsToRewrite;
15841 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15842 const SCEV *RHS,
15843 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15844 const LoopGuards &DivGuards) {
15845 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15846 // replacement SCEV which isn't directly implied by the structure of that
15847 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15848 // legal. See the scoping rules for flags in the header to understand why.
15849
15850 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15851 // create this form when combining two checks of the form (X u< C2 + C1) and
15852 // (X >=u C1).
15853 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15854 &ExprsToRewrite]() {
15855 const SCEVConstant *C1;
15856 const SCEVUnknown *LHSUnknown;
15857 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15858 if (!match(LHS,
15859 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15860 !C2)
15861 return false;
15862
15863 auto ExactRegion =
15864 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15865 .sub(C1->getAPInt());
15866
15867 // Bail out, unless we have a non-wrapping, monotonic range.
15868 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15869 return false;
15870 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15871 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15872 I->second = SE.getUMaxExpr(
15873 SE.getConstant(ExactRegion.getUnsignedMin()),
15874 SE.getUMinExpr(RewrittenLHS,
15875 SE.getConstant(ExactRegion.getUnsignedMax())));
15876 ExprsToRewrite.push_back(LHSUnknown);
15877 return true;
15878 };
15879 if (MatchRangeCheckIdiom())
15880 return;
15881
15882 // Do not apply information for constants or if RHS contains an AddRec.
15884 return;
15885
15886 // If RHS is SCEVUnknown, make sure the information is applied to it.
15888 std::swap(LHS, RHS);
15890 }
15891
15892 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15893 // and \p FromRewritten are the same (i.e. there has been no rewrite
15894 // registered for \p From), then puts this value in the list of rewritten
15895 // expressions.
15896 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15897 const SCEV *To) {
15898 if (From == FromRewritten)
15899 ExprsToRewrite.push_back(From);
15900 RewriteMap[From] = To;
15901 };
15902
15903 // Checks whether \p S has already been rewritten. In that case returns the
15904 // existing rewrite because we want to chain further rewrites onto the
15905 // already rewritten value. Otherwise returns \p S.
15906 auto GetMaybeRewritten = [&](const SCEV *S) {
15907 return RewriteMap.lookup_or(S, S);
15908 };
15909
15910 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15911 // Apply divisibility information when computing the constant multiple.
15912 const APInt &DividesBy =
15913 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15914
15915 // Collect rewrites for LHS and its transitive operands based on the
15916 // condition.
15917 // For min/max expressions, also apply the guard to its operands:
15918 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15919 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15920 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15921 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15922
15923 // We cannot express strict predicates in SCEV, so instead we replace them
15924 // with non-strict ones against plus or minus one of RHS depending on the
15925 // predicate.
15926 const SCEV *One = SE.getOne(RHS->getType());
15927 switch (Predicate) {
15928 case CmpInst::ICMP_ULT:
15929 if (RHS->getType()->isPointerTy())
15930 return;
15931 RHS = SE.getUMaxExpr(RHS, One);
15932 [[fallthrough]];
15933 case CmpInst::ICMP_SLT: {
15934 RHS = SE.getMinusSCEV(RHS, One);
15935 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15936 break;
15937 }
15938 case CmpInst::ICMP_UGT:
15939 case CmpInst::ICMP_SGT:
15940 RHS = SE.getAddExpr(RHS, One);
15941 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15942 break;
15943 case CmpInst::ICMP_ULE:
15944 case CmpInst::ICMP_SLE:
15945 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15946 break;
15947 case CmpInst::ICMP_UGE:
15948 case CmpInst::ICMP_SGE:
15949 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15950 break;
15951 default:
15952 break;
15953 }
15954
15955 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15956 SmallPtrSet<const SCEV *, 16> Visited;
15957
15958 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15959 append_range(Worklist, S->operands());
15960 };
15961
15962 while (!Worklist.empty()) {
15963 const SCEV *From = Worklist.pop_back_val();
15964 if (isa<SCEVConstant>(From))
15965 continue;
15966 if (!Visited.insert(From).second)
15967 continue;
15968 const SCEV *FromRewritten = GetMaybeRewritten(From);
15969 const SCEV *To = nullptr;
15970
15971 switch (Predicate) {
15972 case CmpInst::ICMP_ULT:
15973 case CmpInst::ICMP_ULE:
15974 To = SE.getUMinExpr(FromRewritten, RHS);
15975 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15976 EnqueueOperands(UMax);
15977 break;
15978 case CmpInst::ICMP_SLT:
15979 case CmpInst::ICMP_SLE:
15980 To = SE.getSMinExpr(FromRewritten, RHS);
15981 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15982 EnqueueOperands(SMax);
15983 break;
15984 case CmpInst::ICMP_UGT:
15985 case CmpInst::ICMP_UGE:
15986 To = SE.getUMaxExpr(FromRewritten, RHS);
15987 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15988 EnqueueOperands(UMin);
15989 break;
15990 case CmpInst::ICMP_SGT:
15991 case CmpInst::ICMP_SGE:
15992 To = SE.getSMaxExpr(FromRewritten, RHS);
15993 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15994 EnqueueOperands(SMin);
15995 break;
15996 case CmpInst::ICMP_EQ:
15998 To = RHS;
15999 break;
16000 case CmpInst::ICMP_NE:
16001 if (match(RHS, m_scev_Zero())) {
16002 const SCEV *OneAlignedUp =
16003 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16004 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16005 } else {
16006 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16007 // but creating the subtraction eagerly is expensive. Track the
16008 // inequalities in a separate map, and materialize the rewrite lazily
16009 // when encountering a suitable subtraction while re-writing.
16010 if (LHS->getType()->isPointerTy()) {
16014 break;
16015 }
16016 const SCEVConstant *C;
16017 const SCEV *A, *B;
16020 RHS = A;
16021 LHS = B;
16022 }
16023 if (LHS > RHS)
16024 std::swap(LHS, RHS);
16025 Guards.NotEqual.insert({LHS, RHS});
16026 continue;
16027 }
16028 break;
16029 default:
16030 break;
16031 }
16032
16033 if (To)
16034 AddRewrite(From, FromRewritten, To);
16035 }
16036 };
16037
16039 // First, collect information from assumptions dominating the loop.
16040 for (auto &AssumeVH : SE.AC.assumptions()) {
16041 if (!AssumeVH)
16042 continue;
16043 auto *AssumeI = cast<CallInst>(AssumeVH);
16044 if (!SE.DT.dominates(AssumeI, Block))
16045 continue;
16046 Terms.emplace_back(AssumeI->getOperand(0), true);
16047 }
16048
16049 // Second, collect information from llvm.experimental.guards dominating the loop.
16050 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16051 SE.F.getParent(), Intrinsic::experimental_guard);
16052 if (GuardDecl)
16053 for (const auto *GU : GuardDecl->users())
16054 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16055 if (Guard->getFunction() == Block->getParent() &&
16056 SE.DT.dominates(Guard, Block))
16057 Terms.emplace_back(Guard->getArgOperand(0), true);
16058
16059 // Third, collect conditions from dominating branches. Starting at the loop
16060 // predecessor, climb up the predecessor chain, as long as there are
16061 // predecessors that can be found that have unique successors leading to the
16062 // original header.
16063 // TODO: share this logic with isLoopEntryGuardedByCond.
16064 unsigned NumCollectedConditions = 0;
16066 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16067 for (; Pair.first;
16068 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16069 VisitedBlocks.insert(Pair.second);
16070 const CondBrInst *LoopEntryPredicate =
16071 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16072 if (!LoopEntryPredicate)
16073 continue;
16074
16075 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16076 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16077 NumCollectedConditions++;
16078
16079 // If we are recursively collecting guards stop after 2
16080 // conditions to limit compile-time impact for now.
16081 if (Depth > 0 && NumCollectedConditions == 2)
16082 break;
16083 }
16084 // Finally, if we stopped climbing the predecessor chain because
16085 // there wasn't a unique one to continue, try to collect conditions
16086 // for PHINodes by recursively following all of their incoming
16087 // blocks and try to merge the found conditions to build a new one
16088 // for the Phi.
16089 if (Pair.second->hasNPredecessorsOrMore(2) &&
16091 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16092 for (auto &Phi : Pair.second->phis())
16093 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16094 }
16095
16096 // Now apply the information from the collected conditions to
16097 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16098 // earliest conditions is processed first, except guards with divisibility
16099 // information, which are moved to the back. This ensures the SCEVs with the
16100 // shortest dependency chains are constructed first.
16102 GuardsToProcess;
16103 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16104 SmallVector<Value *, 8> Worklist;
16105 SmallPtrSet<Value *, 8> Visited;
16106 Worklist.push_back(Term);
16107 while (!Worklist.empty()) {
16108 Value *Cond = Worklist.pop_back_val();
16109 if (!Visited.insert(Cond).second)
16110 continue;
16111
16112 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16113 auto Predicate =
16114 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16115 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16116 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16117 // If LHS is a constant, apply information to the other expression.
16118 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16119 // can improve results.
16120 if (isa<SCEVConstant>(LHS)) {
16121 std::swap(LHS, RHS);
16123 }
16124 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16125 continue;
16126 }
16127
16128 Value *L, *R;
16129 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16130 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16131 Worklist.push_back(L);
16132 Worklist.push_back(R);
16133 }
16134 }
16135 }
16136
16137 // Process divisibility guards in reverse order to populate DivGuards early.
16138 DenseMap<const SCEV *, APInt> Multiples;
16139 LoopGuards DivGuards(SE);
16140 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16141 if (!isDivisibilityGuard(LHS, RHS, SE))
16142 continue;
16143 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16144 Multiples, SE);
16145 }
16146
16147 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16148 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16149
16150 // Apply divisibility information last. This ensures it is applied to the
16151 // outermost expression after other rewrites for the given value.
16152 for (const auto &[K, Divisor] : Multiples) {
16153 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16154 Guards.RewriteMap[K] =
16156 Guards.rewrite(K), Divisor, SE),
16157 DivisorSCEV),
16158 DivisorSCEV);
16159 ExprsToRewrite.push_back(K);
16160 }
16161
16162 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16163 // the replacement expressions are contained in the ranges of the replaced
16164 // expressions.
16165 Guards.PreserveNUW = true;
16166 Guards.PreserveNSW = true;
16167 for (const SCEV *Expr : ExprsToRewrite) {
16168 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16169 Guards.PreserveNUW &=
16170 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16171 Guards.PreserveNSW &=
16172 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16173 }
16174
 16175 // Now that all rewrite information is collected, rewrite the collected
16176 // expressions with the information in the map. This applies information to
16177 // sub-expressions.
16178 if (ExprsToRewrite.size() > 1) {
16179 for (const SCEV *Expr : ExprsToRewrite) {
16180 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16181 Guards.RewriteMap.erase(Expr);
16182 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16183 }
16184 }
16185}
16186
16188 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16189 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16190 /// replacement is loop invariant in the loop of the AddRec.
16191 class SCEVLoopGuardRewriter
16192 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16195
16197
16198 public:
16199 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16200 const ScalarEvolution::LoopGuards &Guards)
16201 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16202 NotEqual(Guards.NotEqual) {
16203 if (Guards.PreserveNUW)
16204 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16205 if (Guards.PreserveNSW)
16206 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16207 }
16208
16209 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16210
16211 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16212 return Map.lookup_or(Expr, Expr);
16213 }
16214
16215 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16216 if (const SCEV *S = Map.lookup(Expr))
16217 return S;
16218
 16219 // If we didn't find the exact ZExt expr in the map, check if there's
16220 // an entry for a smaller ZExt we can use instead.
16221 Type *Ty = Expr->getType();
16222 const SCEV *Op = Expr->getOperand(0);
16223 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16224 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16225 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16226 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16227 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16228 if (const SCEV *S = Map.lookup(NarrowExt))
16229 return SE.getZeroExtendExpr(S, Ty);
16230 Bitwidth = Bitwidth / 2;
16231 }
16232
16234 Expr);
16235 }
16236
16237 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16238 if (const SCEV *S = Map.lookup(Expr))
16239 return S;
16241 Expr);
16242 }
16243
16244 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16245 if (const SCEV *S = Map.lookup(Expr))
16246 return S;
16248 }
16249
16250 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16251 if (const SCEV *S = Map.lookup(Expr))
16252 return S;
16254 }
16255
16256 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16257 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16258 // return UMax(S, 1).
16259 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16260 SCEVUse LHS, RHS;
16261 if (MatchBinarySub(S, LHS, RHS)) {
16262 if (LHS > RHS)
16263 std::swap(LHS, RHS);
16264 if (NotEqual.contains({LHS, RHS})) {
16265 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16266 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16267 return SE.getUMaxExpr(OneAlignedUp, S);
16268 }
16269 }
16270 return nullptr;
16271 };
16272
16273 // Check if Expr itself is a subtraction pattern with guard info.
16274 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16275 return Rewritten;
16276
16277 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16278 // (Const + A + B). There may be guard info for A + B, and if so, apply
16279 // it.
16280 // TODO: Could more generally apply guards to Add sub-expressions.
16281 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16282 Expr->getNumOperands() == 3) {
16283 const SCEV *Add =
16284 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16285 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16286 return SE.getAddExpr(
16287 Expr->getOperand(0), Rewritten,
16288 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16289 if (const SCEV *S = Map.lookup(Add))
16290 return SE.getAddExpr(Expr->getOperand(0), S);
16291 }
16292 SmallVector<SCEVUse, 2> Operands;
16293 bool Changed = false;
16294 for (SCEVUse Op : Expr->operands()) {
16295 Operands.push_back(
16297 Changed |= Op != Operands.back();
16298 }
16299 // We are only replacing operands with equivalent values, so transfer the
16300 // flags from the original expression.
16301 return !Changed ? Expr
16302 : SE.getAddExpr(Operands,
16304 Expr->getNoWrapFlags(), FlagMask));
16305 }
16306
16307 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16308 SmallVector<SCEVUse, 2> Operands;
16309 bool Changed = false;
16310 for (SCEVUse Op : Expr->operands()) {
16311 Operands.push_back(
16313 Changed |= Op != Operands.back();
16314 }
16315 // We are only replacing operands with equivalent values, so transfer the
16316 // flags from the original expression.
16317 return !Changed ? Expr
16318 : SE.getMulExpr(Operands,
16320 Expr->getNoWrapFlags(), FlagMask));
16321 }
16322 };
16323
16324 if (RewriteMap.empty() && NotEqual.empty())
16325 return Expr;
16326
16327 SCEVLoopGuardRewriter Rewriter(SE, *this);
16328 return Rewriter.visit(Expr);
16329}
16330
16331const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16332 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16333}
16334
16336 const LoopGuards &Guards) {
16337 return Guards.rewrite(Expr);
16338}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:851
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2022
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1054
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing the numBits high bits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:967
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1708
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1316
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1027
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1472
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be a useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:589
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:616
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Return true if the operation BinOp between LHS and RHS provably does not overflow (signed or unsigned, per Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2863
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:829
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:557
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2115
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2110
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2199
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:372
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2087
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1916
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2018
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.