LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
1076 ScalarEvolution &SE) const {
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1113 Type *TargetTy;
1114 ConversionFn CreatePtrCast;
1115
1116public:
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
1165 return SE.getZero(TargetTy);
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1185 getDataLayout().getTypeSizeInBits(IntPtrTy))
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although we checked in the beginning that ID is not in the cache, it is
1320 // possible that the recursive calls above inserted ID into the cache in the
1321 // meantime. So if we find it, just return it.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
1385struct ExtendOpTraitsBase {
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
1390// Used to make code generic over signed and unsigned overflow.
// The primary template is intentionally empty: only the explicit
// specializations for SCEVSignExtendExpr and SCEVZeroExtendExpr carry members.
1391template <typename ExtendOp> struct ExtendOpTraits {
1392 // Members present in each specialization:
1393 //
1394 // static const SCEV::NoWrapFlags WrapType;
 //   The no-wrap flag tied to the extension kind (FlagNSW for the sign
 //   extension specialization, FlagNUW for the zero extension one).
1395 //
1396 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
 //   Pointer to a ScalarEvolution member function; callers invoke it as
 //   (SE->*GetExtendExpr)(S, Ty, Depth).
1397 //
1398 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1399 // ICmpInst::Predicate *Pred,
1400 // ScalarEvolution *SE);
 //   Forwards to the signed or unsigned overflow-limit helper.
1401};
1402
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1441// Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1442// expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)"
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not sign/unsign-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
// Returns true when, for some Delta tried below, an add recurrence
// {Start - Delta,+,Step} in loop L already exists, carries the wrap flag for
// ExtendOpTy, and is known to satisfy the overflow-limit predicate for Delta.
// Returns false otherwise, including whenever Start is not a constant.
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
 // NOTE(review): Delta is declared `unsigned`, so the -2 and -1 initializers
 // are converted to large unsigned values before they reach `StartAI - Delta`
 // and `getConstant(..., Delta)`. Presumably this relies on arithmetic modulo
 // the recurrence's bit width; confirm the behaviour is as intended for types
 // wider than `unsigned`.
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
 // Profile the add recurrence over (PreStart, Step, L) and only look it up
 // in the uniquing table; nothing is inserted here.
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Give up if we don't already have the add recurrence we need because
1597 // actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1622 uint32_t TZ = BitWidth;
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
 // Records the mapping ID -> S in FoldCache and lists ID under S in
 // FoldCacheUser. If FoldCache already held an entry for ID, ID is first
 // removed from the previously mapped expression's user list and the entry
 // is then overwritten with S.
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
 // Unordered erase: swap the matching element to the back, then pop it.
 // The loop index `I` shadows the insert result `I` only inside this loop;
 // the assignment to `I.first->second` below uses the insert result again.
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the latter case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two possible causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1914 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1915 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1916 //
1917 // Often address arithmetics contain expressions like
1918 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1919 // This transformation is useful while proving that such expressions are
1920 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1921 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1922 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1923 if (D != 0) {
1924 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1925 const SCEV *SResidual =
1927 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1928 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1929 Depth + 1);
1930 }
1931 }
1932 }
1933
1934 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1935 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1936 if (SM->hasNoUnsignedWrap()) {
1937 // If the multiply does not unsign overflow then we can, by definition,
1938 // commute the zero extension with the multiply operation.
1940 for (SCEVUse Op : SM->operands())
1941 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1942 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1943 }
1944
1945 // zext(2^K * (trunc X to iN)) to iM ->
1946 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1947 //
1948 // Proof:
1949 //
1950 // zext(2^K * (trunc X to iN)) to iM
1951 // = zext((trunc X to iN) << K) to iM
1952 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1953 // (because shl removes the top K bits)
1954 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1955 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1956 //
1957 const APInt *C;
1958 const SCEV *TruncRHS;
1959 if (match(SM,
1960 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1961 C->isPowerOf2()) {
1962 int NewTruncBits =
1963 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1964 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1965 return getMulExpr(
1966 getZeroExtendExpr(SM->getOperand(0), Ty),
1967 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1968 SCEV::FlagNUW, Depth + 1);
1969 }
1970 }
1971
1972 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1973 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1976 SmallVector<SCEVUse, 4> Operands;
1977 for (SCEVUse Operand : MinMax->operands())
1978 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1980 return getUMinExpr(Operands);
1981 return getUMaxExpr(Operands);
1982 }
1983
1984 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1986 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1987 SmallVector<SCEVUse, 4> Operands;
1988 for (SCEVUse Operand : MinMax->operands())
1989 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1990 return getUMinExpr(Operands, /*Sequential*/ true);
1991 }
1992
1993 // The cast wasn't folded; create an explicit cast node.
1994 // Recompute the insert position, as it may have been invalidated.
1995 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1996 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1997 Op, Ty);
1998 UniqueSCEVs.InsertNode(S, IP);
1999 S->computeAndSetCanonical(*this);
2000 registerUser(S, Op);
2001 return S;
2002}
2003
// Cached entry point for building a sign extension of Op to Ty.
//
// The asserts require that Ty is strictly wider than Op's type, that Ty is a
// SCEVable type, and that Op is not pointer-typed. A FoldCache hit is returned
// directly; on a miss the work is delegated to getSignExtendExprImpl.
// NOTE(review): the signature line (source line 2005) is absent from this
// listing; the parameter names Op, Ty and Depth are taken from their uses.
2004const SCEV *
2006 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2007 "This is not an extending conversion!");
2008 assert(isSCEVable(Ty) &&
2009 "This is not a conversion to a SCEVable type!");
2010 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2011 Ty = getEffectiveSCEVType(Ty);
2012
// The cache key is built from the cast kind, the operand and the effective
// destination type.
2013 FoldID ID(scSignExtend, Op, Ty);
2014 if (const SCEV *S = FoldCache.lookup(ID))
2015 return S;
2016
2017 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
// NOTE(review): source line 2018 is absent from this listing, so the cache
// insertion below may be guarded by a condition -- confirm against the full
// file before relying on "every result is cached".
2019 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2020 return S;
2021}
2021}
2022
// Worker that builds sext(Op) to Ty, trying a sequence of folds before falling
// back to an explicit SCEVSignExtendExpr node.
//
// Folds attempted, in order: constant operand, sext(sext), sext(zext), an
// existing node in UniqueSCEVs, sext(trunc), add expressions, affine addrecs,
// a zext fallback, and a min/max distribution. When Depth exceeds
// MaxCastDepth the explicit node is created without attempting further folds.
// NOTE(review): source lines whose numbers are skipped below (including the
// first signature line and several local declarations) are absent from this
// listing; comments here describe only the visible code.
2024 unsigned Depth) {
2025 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2026 "This is not an extending conversion!");
2027 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2028 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2029 Ty = getEffectiveSCEVType(Ty);
2030
2031 // Fold if the operand is constant.
2032 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2033 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2034
2035 // sext(sext(x)) --> sext(x)
// NOTE(review): the guarding `if` that binds SS (source line 2036) is not
// visible in this listing.
2037 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2038
2039 // sext(zext(x)) --> zext(x)
// NOTE(review): the guarding `if` that binds SZ (source line 2040) is not
// visible in this listing.
2041 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2042
2043 // Before doing any expensive analysis, check to see if we've already
2044 // computed a SCEV for this Op and Ty.
// NOTE(review): the declaration of ID (source line 2045) is not visible here.
2046 ID.AddInteger(scSignExtend);
2047 ID.AddPointer(Op);
2048 ID.AddPointer(Ty);
2049 void *IP = nullptr;
2050 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2051 // Limit recursion depth.
// Past the depth limit, create and register the explicit cast node right
// away instead of attempting any of the folds below.
2052 if (Depth > MaxCastDepth) {
2053 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2054 Op, Ty);
2055 UniqueSCEVs.InsertNode(S, IP);
2056 S->computeAndSetCanonical(*this);
2057 registerUser(S, Op);
2058 return S;
2059 }
2060
2061 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2063 // It's possible the bits taken off by the truncate were all sign bits. If
2064 // so, we should be able to simplify this further.
2065 const SCEV *X = ST->getOperand();
// NOTE(review): the declaration of CR (source line 2066) is not visible; it
// is presumably a range computed for X -- confirm.
2067 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2068 unsigned NewBits = getTypeSizeInBits(Ty);
// The fold applies when truncating CR to TruncBits and sign-extending it to
// NewBits still contains CR resized directly to NewBits.
2069 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2070 CR.sextOrTrunc(NewBits)))
2071 return getTruncateOrSignExtend(X, Ty, Depth);
2072 }
2073
2074 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2075 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2076 if (SA->hasNoSignedWrap()) {
2077 // If the addition does not sign overflow then we can, by definition,
2078 // commute the sign extension with the addition operation.
// NOTE(review): the declaration of Ops (source line 2079) is not visible.
2080 for (SCEVUse Op : SA->operands())
2081 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2082 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2083 }
2084
2085 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2086 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2087 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2088 //
2089 // For instance, this will bring two seemingly different expressions:
2090 // 1 + sext(5 + 20 * %x + 24 * %y) and
2091 // sext(6 + 20 * %x + 24 * %y)
2092 // to the same form:
2093 // 2 + sext(4 + 20 * %x + 24 * %y)
2094 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2095 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
// D == 0 means no constant could be split off, so this fold is skipped.
2096 if (D != 0) {
2097 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2098 const SCEV *SResidual =
2100 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2101 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2102 Depth + 1);
2103 }
2104 }
2105 }
2106 // If the input value is a chrec scev, and we can prove that the value
2107 // did not overflow the old, smaller, value, we can sign extend all of the
2108 // operands (often constants). This allows analysis of something like
2109 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
// NOTE(review): the line binding AR (source line 2110) is not visible.
2111 if (AR->isAffine()) {
2112 const SCEV *Start = AR->getStart();
2113 const SCEV *Step = AR->getStepRecurrence(*this);
2114 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2115 const Loop *L = AR->getLoop();
2116
2117 // If we have special knowledge that this addrec won't overflow,
2118 // we don't need to do any further analysis.
2119 if (AR->hasNoSignedWrap()) {
2120 Start =
2122 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2123 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2124 }
2125
2126 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2127 // Note that this serves two purposes: It filters out loops that are
2128 // simply not analyzable, and it covers the case where this code is
2129 // being called from within backedge-taken count analysis, such that
2130 // attempting to ask for the backedge-taken count would likely result
2131 // in infinite recursion. In the latter case, the analysis code will
2132 // cope with a conservative value, and it will take care to purge
2133 // that value once it has finished.
2134 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2135 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2136 // Manually compute the final value for AR, checking for
2137 // overflow.
2138
2139 // Check whether the backedge-taken count can be losslessly cast to
2140 // the addrec's type. The count is always unsigned.
2141 const SCEV *CastedMaxBECount =
2142 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2143 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2144 CastedMaxBECount, MaxBECount->getType(), Depth);
// A round trip through the addrec's type that yields the same SCEV means
// the cast lost nothing.
2145 if (MaxBECount == RecastedMaxBECount) {
// The check is carried out in a type twice as wide as the addrec.
2146 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2147 // Check whether Start+Step*MaxBECount has no signed overflow.
2148 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2150 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2152 Depth + 1),
2153 WideTy, Depth + 1);
2154 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2155 const SCEV *WideMaxBECount =
2156 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2157 const SCEV *OperandExtendedAdd =
2158 getAddExpr(WideStart,
2159 getMulExpr(WideMaxBECount,
2160 getSignExtendExpr(Step, WideTy, Depth + 1),
// The extended narrow result and the result computed on extended
// operands are compared by pointer identity.
2163 if (SAdd == OperandExtendedAdd) {
2164 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2165 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2166 // Return the expression with the addrec on the outside.
2167 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2168 Depth + 1);
2169 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2170 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2171 }
2172 // Similar to above, only this time treat the step value as unsigned.
2173 // This covers loops that count up with an unsigned step.
2174 OperandExtendedAdd =
2175 getAddExpr(WideStart,
2176 getMulExpr(WideMaxBECount,
2177 getZeroExtendExpr(Step, WideTy, Depth + 1),
2180 if (SAdd == OperandExtendedAdd) {
2181 // If AR wraps around then
2182 //
2183 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2184 // => SAdd != OperandExtendedAdd
2185 //
2186 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2187 // (SAdd == OperandExtendedAdd => AR is NW)
2188
2189 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2190
2191 // Return the expression with the addrec on the outside.
// Note the step is zero-extended here, unlike the signed case above.
2192 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2193 Depth + 1);
2194 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2195 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2196 }
2197 }
2198 }
2199
// Whatever flags the induction-based proof yields are recorded on AR
// before re-testing for NSW.
2200 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2201 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2202 if (AR->hasNoSignedWrap()) {
2203 // Same as nsw case above - duplicated here to avoid a compile time
2204 // issue. It's not clear that the order of checks does matter, but
2205 // it's one of two possible causes for a change which was
2206 // reverted. Be conservative for the moment.
2207 Start =
2209 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2210 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2211 }
2212
2213 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2214 // if D + (C - D + Step * n) could be proven to not signed wrap
2215 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2216 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2217 const APInt &C = SC->getAPInt();
2218 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2219 if (D != 0) {
2220 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2221 const SCEV *SResidual =
2222 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2223 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2224 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2225 Depth + 1);
2226 }
2227 }
2228
2229 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2230 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2231 Start =
2233 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2234 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2235 }
2236 }
2237
2238 // If the input value is provably positive and we could not simplify
2239 // away the sext build a zext instead.
// NOTE(review): the condition guarding this return (source line 2240) is not
// visible in this listing.
2241 return getZeroExtendExpr(Op, Ty, Depth + 1);
2242
2243 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2244 // sext(smax(x, y)) -> smax(sext(x), sext(y))
// NOTE(review): the lines binding MinMax and choosing between smin and smax
// (source lines 2245-2246 and 2250) are not visible in this listing.
2247 SmallVector<SCEVUse, 4> Operands;
2248 for (SCEVUse Operand : MinMax->operands())
2249 Operands.push_back(getSignExtendExpr(Operand, Ty));
2251 return getSMinExpr(Operands);
2252 return getSMaxExpr(Operands);
2253 }
2254
2255 // The cast wasn't folded; create an explicit cast node.
2256 // Recompute the insert position, as it may have been invalidated.
2257 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2258 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2259 Op, Ty);
2260 UniqueSCEVs.InsertNode(S, IP);
2261 S->computeAndSetCanonical(*this);
2262 registerUser(S, Op);
2263 return S;
2264}
2265
// Build a cast of Op to Ty by dispatching on Kind to the matching builder:
// truncate, zero-extend, sign-extend or ptrtoint. Any other Kind reaches
// llvm_unreachable.
// NOTE(review): the first signature line (source line 2266) is absent from
// this listing; Kind and Op are named from their uses below.
2267 Type *Ty) {
2268 switch (Kind) {
2269 case scTruncate:
2270 return getTruncateExpr(Op, Ty);
2271 case scZeroExtend:
2272 return getZeroExtendExpr(Op, Ty);
2273 case scSignExtend:
2274 return getSignExtendExpr(Op, Ty);
2275 case scPtrToInt:
2276 return getPtrToIntExpr(Op, Ty);
2277 default:
2278 llvm_unreachable("Not a SCEV cast expression!");
2279 }
2280}
2281
2282/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2283/// unspecified bits out to the given type.
///
/// Strategies are tried in this order: sign-extend a negative constant; look
/// through a truncate; use the zext form if it folded; use the sext form if it
/// folded; any-extend each operand of an addrec; use the sext form for an
/// smax; otherwise use the zext form.
// NOTE(review): the first signature line (source line 2284) is absent from
// this listing.
2285 Type *Ty) {
2286 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2287 "This is not an extending conversion!");
2288 assert(isSCEVable(Ty) &&
2289 "This is not a conversion to a SCEVable type!");
2290 Ty = getEffectiveSCEVType(Ty);
2291
2292 // Sign-extend negative constants.
2293 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2294 if (SC->getAPInt().isNegative())
2295 return getSignExtendExpr(Op, Ty);
2296
2297 // Peel off a truncate cast.
// NOTE(review): the `if` binding T (source line 2298) is not visible here.
2299 const SCEV *NewOp = T->getOperand();
// If the truncate's operand is still narrower than Ty, keep extending it;
// otherwise it is at least as wide as Ty and is handed to getTruncateOrNoop.
2300 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2301 return getAnyExtendExpr(NewOp, Ty);
2302 return getTruncateOrNoop(NewOp, Ty);
2303 }
2304
2305 // Next try a zext cast. If the cast is folded, use it.
// "Folded" is detected by the result not being a SCEVZeroExtendExpr node.
2306 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2307 if (!isa<SCEVZeroExtendExpr>(ZExt))
2308 return ZExt;
2309
2310 // Next try a sext cast. If the cast is folded, use it.
2311 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2312 if (!isa<SCEVSignExtendExpr>(SExt))
2313 return SExt;
2314
2315 // Force the cast to be folded into the operands of an addrec.
2316 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
// NOTE(review): the declaration of Ops (source line 2317) is not visible.
2318 for (const SCEV *Op : AR->operands())
2319 Ops.push_back(getAnyExtendExpr(Op, Ty));
2320 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2321 }
2322
2323 // If the expression is obviously signed, use the sext cast value.
2324 if (isa<SCEVSMaxExpr>(Op))
2325 return SExt;
2326
2327 // Absent any other information, use the zext cast value.
2328 return ZExt;
2329}
2330
2331/// Process the given Ops list, which is a list of operands to be added under
2332/// the given scale, update the given map. This is a helper function for
2333/// getAddRecExpr. As an example of what it does, given a sequence of operands
2334/// that would form an add expression like this:
2335///
2336/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2337///
2338/// where A and B are constants, update the map with these values:
2339///
2340/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2341///
2342/// and add 13 + A*B*29 to AccumulatedConstant.
2343/// This will allow getAddRecExpr to produce this:
2344///
2345/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2346///
2347/// This form often exposes folding opportunities that are hidden in
2348/// the original operand list.
2349///
2350/// Return true iff it appears that any interesting folding opportunities
2351/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2352/// the common case where no interesting opportunities are present, and
2353/// is also used as a check to avoid infinite recursion.
///
/// Each operand seen for the first time is also appended to NewOps, so NewOps
/// records the map keys in first-insertion order.
// NOTE(review): this header names getAddRecExpr as the caller, but the example
// describes a plain add expression and the callers are not visible in this
// listing -- confirm the intended function name.
// NOTE(review): parts of the signature (source lines 2354, 2355 and 2357) are
// absent from this listing; M, NewOps and Ops are named from their uses.
2356 APInt &AccumulatedConstant,
2358 const APInt &Scale,
2359 ScalarEvolution &SE) {
2360 bool Interesting = false;
2361
2362 // Iterate over the add operands. They are sorted, with constants first.
2363 unsigned i = 0;
// NOTE(review): this loop has no bound check on i, so it assumes Ops holds at
// least one non-constant operand after the leading constants -- TODO confirm.
2364 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2365 ++i;
2366 // Pull a buried constant out to the outside.
2367 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2368 Interesting = true;
2369 AccumulatedConstant += Scale * C->getAPInt();
2370 }
2371
2372 // Next comes everything else. We're especially interested in multiplies
2373 // here, but they're in the middle, so just visit the rest with one loop.
2374 for (; i != Ops.size(); ++i) {
// NOTE(review): the line binding Mul (source line 2375) is not visible here.
2376 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
// A leading constant factor is folded into the running scale.
2377 APInt NewScale =
2378 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2379 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2380 // A multiplication of a constant with another add; recurse.
2381 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2382 Interesting |= CollectAddOperandsWithScales(
2383 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2384 } else {
2385 // A multiplication of a constant with some other value. Update
2386 // the map.
// The map key is the product of the remaining (non-leading) factors.
2387 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2388 const SCEV *Key = SE.getMulExpr(MulOps);
2389 auto Pair = M.insert({Key, NewScale});
2390 if (Pair.second) {
2391 NewOps.push_back(Pair.first->first);
2392 } else {
2393 Pair.first->second += NewScale;
2394 // The map already had an entry for this value, which may indicate
2395 // a folding opportunity.
2396 Interesting = true;
2397 }
2398 }
2399 } else {
2400 // An ordinary operand. Update the map.
2401 auto Pair = M.insert({Ops[i], Scale});
2402 if (Pair.second) {
2403 NewOps.push_back(Pair.first->first);
2404 } else {
2405 Pair.first->second += Scale;
2406 // The map already had an entry for this value, which may indicate
2407 // a folding opportunity.
2408 Interesting = true;
2409 }
2410 }
2411 }
2412
2413 return Interesting;
2414}
2415
// Try to prove that applying BinOp (Add, Sub or Mul) to LHS and RHS does not
// overflow, in the signed or unsigned sense selected by Signed.
//
// First attempt: compare ext(LHS op RHS) with ext(LHS) op ext(RHS) in a type
// twice as wide; identical SCEVs prove the claim. Second attempt, only for
// Add/Sub with a constant RHS and a non-null CtxI: bound LHS against a limit
// derived from the constant and query isKnownPredicateAt at CtxI.
// Returns true only when one of these proofs succeeds.
// NOTE(review): the first signature line and the lines selecting Operation,
// Extension and Pred (source lines 2416, 2419, 2425, 2428, 2431, 2436-2437 and
// 2476) are absent from this listing.
2417 const SCEV *LHS, const SCEV *RHS,
2418 const Instruction *CtxI) {
2420 unsigned);
2421 switch (BinOp) {
2422 default:
2423 llvm_unreachable("Unsupported binary op");
2424 case Instruction::Add:
2426 break;
2427 case Instruction::Sub:
2429 break;
2430 case Instruction::Mul:
2432 break;
2433 }
2434
2435 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2438
2439 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2440 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2441 auto *WideTy =
2442 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2443
// A is the narrow result extended; B is the operation on extended operands.
2444 const SCEV *A = (this->*Extension)(
2445 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2446 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2447 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2448 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2449 if (A == B)
2450 return true;
2451 // Can we use context to prove the fact we need?
2452 if (!CtxI)
2453 return false;
2454 // TODO: Support mul.
2455 if (BinOp == Instruction::Mul)
2456 return false;
2457 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2458 // TODO: Lift this limitation.
2459 if (!RHSC)
2460 return false;
2461 APInt C = RHSC->getAPInt();
2462 unsigned NumBits = C.getBitWidth();
2463 bool IsSub = (BinOp == Instruction::Sub);
2464 bool IsNegativeConst = (Signed && C.isNegative());
2465 // Compute the direction and magnitude by which we need to check overflow.
// Subtracting a non-negative constant or adding a negative one moves the
// value down; the two conditions cancel when both hold.
2466 bool OverflowDown = IsSub ^ IsNegativeConst;
2467 APInt Magnitude = C;
2468 if (IsNegativeConst) {
2469 if (C == APInt::getSignedMinValue(NumBits))
2470 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2471 // want to deal with that.
2472 return false;
2473 Magnitude = -C;
2474 }
2475
2477 if (OverflowDown) {
2478 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2479 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2480 : APInt::getMinValue(NumBits);
2481 APInt Limit = Min + Magnitude;
2482 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2483 } else {
2484 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2485 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2486 : APInt::getMaxValue(NumBits);
2487 APInt Limit = Max - Magnitude;
2488 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2489 }
2490}
2491
// Try to deduce no-wrap flags for OBO beyond the ones it already carries.
//
// Returns std::nullopt when OBO already has both nuw and nsw, when its opcode
// is not Add, Sub or Mul, or when nothing new could be deduced. Otherwise
// returns the resulting flag set.
// NOTE(review): the line naming the function (source line 2493) and the lines
// that update Flags and call the overflow check (source lines 2502, 2504,
// 2517, 2519, 2521, 2526 and 2528) are absent from this listing.
2492std::optional<SCEV::NoWrapFlags>
2494 const OverflowingBinaryOperator *OBO) {
2495 // It cannot be done any better.
2496 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2497 return std::nullopt;
2498
2499 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2500
2501 if (OBO->hasNoUnsignedWrap())
2503 if (OBO->hasNoSignedWrap())
2505
// Deduced records whether the analysis below established anything new.
2506 bool Deduced = false;
2507
2508 if (OBO->getOpcode() != Instruction::Add &&
2509 OBO->getOpcode() != Instruction::Sub &&
2510 OBO->getOpcode() != Instruction::Mul)
2511 return std::nullopt;
2512
2513 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2514 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2515
2516 const Instruction *CtxI =
// Each kind of wrap is only analysed when OBO does not already carry the
// corresponding flag.
2518 if (!OBO->hasNoUnsignedWrap() &&
2520 /* Signed */ false, LHS, RHS, CtxI)) {
2522 Deduced = true;
2523 }
2524
2525 if (!OBO->hasNoSignedWrap() &&
2527 /* Signed */ true, LHS, RHS, CtxI)) {
2529 Deduced = true;
2530 }
2531
2532 if (Deduced)
2533 return Flags;
2534 return std::nullopt;
2535}
2536
2537// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2538// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2539// can't-overflow flags for the operation if possible.
//
// The visible rules are: nsw with all operands known non-negative also gives
// nuw; a two-operand add or mul with a leading constant is checked against the
// guaranteed no-wrap region for that constant; and two further special cases
// noted in the comments near the end. Returns the resulting flags.
// NOTE(review): the first signature lines (source lines 2540-2542) and the
// statements that set flags inside several branches (e.g. source lines 2589,
// 2597, 2606, 2613, 2616) are absent from this listing.
2543 SCEV::NoWrapFlags Flags) {
2544 using namespace std::placeholders;
2545
2546 using OBO = OverflowingBinaryOperator;
2547
2548 bool CanAnalyze =
2550 (void)CanAnalyze;
2551 assert(CanAnalyze && "don't call from other places!");
2552
2553 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2554 SCEV::NoWrapFlags SignOrUnsignWrap =
2555 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2556
2557 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2558 auto IsKnownNonNegative = [&](SCEVUse U) {
2559 return SE->isKnownNonNegative(U);
2560 };
2561
2562 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2563 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2564
// Re-read the nuw/nsw bits, since the step above may have changed Flags.
2565 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2566
// Range-based inference only runs when at least one of nuw/nsw is still
// missing and the expression is a two-operand add or mul whose first operand
// is a constant.
2567 if (SignOrUnsignWrap != SignOrUnsignMask &&
2568 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2569 isa<SCEVConstant>(Ops[0])) {
2570
2571 auto Opcode = [&] {
2572 switch (Type) {
2573 case scAddExpr:
2574 return Instruction::Add;
2575 case scMulExpr:
2576 return Instruction::Mul;
2577 default:
2578 llvm_unreachable("Unexpected SCEV op.");
2579 }
2580 }();
2581
2582 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2583
2584 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2585 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2587 Opcode, C, OBO::NoSignedWrap);
2588 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2590 }
2591
2592 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2593 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2595 Opcode, C, OBO::NoUnsignedWrap);
2596 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2598 }
2599 }
2600
2601 // <0,+,nonnegative><nw> is also nuw
2602 // TODO: Add corresponding nsw case
2604 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2605 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2607
2608 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2610 Ops.size() == 2) {
2611 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2612 if (UDiv->getOperand(1) == Ops[1])
2614 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2615 if (UDiv->getOperand(1) == Ops[0])
2617 }
2618
2619 return Flags;
2620}
2621
// True when S is invariant in L and S properly dominates L's header block.
// NOTE(review): the signature line (source line 2622) is absent from this
// listing; S and L are named from their uses below.
2623 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2624}
2625
2626/// Get a canonical add expression, or something simpler if possible.
2628 SCEV::NoWrapFlags OrigFlags,
2629 unsigned Depth) {
2630 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2631 "only nuw or nsw allowed");
2632 assert(!Ops.empty() && "Cannot get empty add!");
2633 if (Ops.size() == 1) return Ops[0];
2634#ifndef NDEBUG
2635 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2636 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2637 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2638 "SCEVAddExpr operand types don't match!");
2639 unsigned NumPtrs = count_if(
2640 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2641 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2642#endif
2643
2644 const SCEV *Folded = constantFoldAndGroupOps(
2645 *this, LI, DT, Ops,
2646 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2647 [](const APInt &C) { return C.isZero(); }, // identity
2648 [](const APInt &C) { return false; }); // absorber
2649 if (Folded)
2650 return Folded;
2651
2652 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2653
2654 // Delay expensive flag strengthening until necessary.
2655 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2656 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2657 };
2658
2659 // Limit recursion calls depth.
2661 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2662
2663 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2664 // Don't strengthen flags if we have no new information.
2665 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2666 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2667 Add->setNoWrapFlags(ComputeFlags(Ops));
2668 return S;
2669 }
2670
2671 // Okay, check to see if the same value occurs in the operand list more than
2672 // once. If so, merge them together into an multiply expression. Since we
2673 // sorted the list, these values are required to be adjacent.
2674 Type *Ty = Ops[0]->getType();
2675 bool FoundMatch = false;
2676 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2677 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2678 // Scan ahead to count how many equal operands there are.
2679 unsigned Count = 2;
2680 while (i+Count != e && Ops[i+Count] == Ops[i])
2681 ++Count;
2682 // Merge the values into a multiply.
2683 SCEVUse Scale = getConstant(Ty, Count);
2684 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2685 if (Ops.size() == Count)
2686 return Mul;
2687 Ops[i] = Mul;
2688 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2689 --i; e -= Count - 1;
2690 FoundMatch = true;
2691 }
2692 if (FoundMatch)
2693 return getAddExpr(Ops, OrigFlags, Depth + 1);
2694
2695 // Check for truncates. If all the operands are truncated from the same
2696 // type, see if factoring out the truncate would permit the result to be
2697 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2698 // if the contents of the resulting outer trunc fold to something simple.
2699 auto FindTruncSrcType = [&]() -> Type * {
2700 // We're ultimately looking to fold an addrec of truncs and muls of only
2701 // constants and truncs, so if we find any other types of SCEV
2702 // as operands of the addrec then we bail and return nullptr here.
2703 // Otherwise, we return the type of the operand of a trunc that we find.
2704 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2705 return T->getOperand()->getType();
2706 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2707 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2708 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2709 return T->getOperand()->getType();
2710 }
2711 return nullptr;
2712 };
2713 if (auto *SrcType = FindTruncSrcType()) {
2714 SmallVector<SCEVUse, 8> LargeOps;
2715 bool Ok = true;
2716 // Check all the operands to see if they can be represented in the
2717 // source type of the truncate.
2718 for (const SCEV *Op : Ops) {
2720 if (T->getOperand()->getType() != SrcType) {
2721 Ok = false;
2722 break;
2723 }
2724 LargeOps.push_back(T->getOperand());
2725 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2726 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2727 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2728 SmallVector<SCEVUse, 8> LargeMulOps;
2729 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2730 if (const SCEVTruncateExpr *T =
2731 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2732 if (T->getOperand()->getType() != SrcType) {
2733 Ok = false;
2734 break;
2735 }
2736 LargeMulOps.push_back(T->getOperand());
2737 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2738 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2739 } else {
2740 Ok = false;
2741 break;
2742 }
2743 }
2744 if (Ok)
2745 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2746 } else {
2747 Ok = false;
2748 break;
2749 }
2750 }
2751 if (Ok) {
2752 // Evaluate the expression in the larger type.
2753 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2754 // If it folds to something simple, use it. Otherwise, don't.
2755 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2756 return getTruncateExpr(Fold, Ty);
2757 }
2758 }
2759
2760 if (Ops.size() == 2) {
2761 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2762 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2763 // C1).
2764 const SCEV *A = Ops[0];
2765 const SCEV *B = Ops[1];
2766 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2767 auto *C = dyn_cast<SCEVConstant>(A);
2768 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2769 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2770 auto C2 = C->getAPInt();
2771 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2772
2773 APInt ConstAdd = C1 + C2;
2774 auto AddFlags = AddExpr->getNoWrapFlags();
2775 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2777 ConstAdd.ule(C1)) {
2778 PreservedFlags =
2780 }
2781
2782 // Adding a constant with the same sign and small magnitude is NSW, if the
2783 // original AddExpr was NSW.
2785 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2786 ConstAdd.abs().ule(C1.abs())) {
2787 PreservedFlags =
2789 }
2790
2791 if (PreservedFlags != SCEV::FlagAnyWrap) {
2792 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2793 NewOps[0] = getConstant(ConstAdd);
2794 return getAddExpr(NewOps, PreservedFlags);
2795 }
2796 }
2797
2798 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2799 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2800 const SCEVAddExpr *InnerAdd;
2801 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2802 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2803 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2804 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2805 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2807 SCEV::FlagNUW)) {
2808 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2809 }
2810 }
2811 }
2812
2813 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2814 const SCEV *Y;
2815 if (Ops.size() == 2 &&
2816 match(Ops[0],
2818 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2819 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2820
2821 // Skip past any other cast SCEVs.
2822 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2823 ++Idx;
2824
2825 // If there are add operands they would be next.
2826 if (Idx < Ops.size()) {
2827 bool DeletedAdd = false;
2828 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2829 // common NUW flag for expression after inlining. Other flags cannot be
2830 // preserved, because they may depend on the original order of operations.
2831 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2832 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2833 if (Ops.size() > AddOpsInlineThreshold ||
2834 Add->getNumOperands() > AddOpsInlineThreshold)
2835 break;
2836 // If we have an add, expand the add operands onto the end of the operands
2837 // list.
2838 Ops.erase(Ops.begin()+Idx);
2839 append_range(Ops, Add->operands());
2840 DeletedAdd = true;
2841 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2842 }
2843
2844 // If we deleted at least one add, we added operands to the end of the list,
2845 // and they are not necessarily sorted. Recurse to resort and resimplify
2846 // any operands we just acquired.
2847 if (DeletedAdd)
2848 return getAddExpr(Ops, CommonFlags, Depth + 1);
2849 }
2850
2851 // Skip over the add expression until we get to a multiply.
2852 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2853 ++Idx;
2854
2855 // Check to see if there are any folding opportunities present with
2856 // operands multiplied by constant values.
2857 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2861 APInt AccumulatedConstant(BitWidth, 0);
2862 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2863 Ops, APInt(BitWidth, 1), *this)) {
2864 struct APIntCompare {
2865 bool operator()(const APInt &LHS, const APInt &RHS) const {
2866 return LHS.ult(RHS);
2867 }
2868 };
2869
2870 // Some interesting folding opportunity is present, so it's worthwhile to
2871 // re-generate the operands list. Group the operands by constant scale,
2872 // to avoid multiplying by the same constant scale multiple times.
2873 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2874 for (const SCEV *NewOp : NewOps)
2875 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2876 // Re-generate the operands list.
2877 Ops.clear();
2878 if (AccumulatedConstant != 0)
2879 Ops.push_back(getConstant(AccumulatedConstant));
2880 for (auto &MulOp : MulOpLists) {
2881 if (MulOp.first == 1) {
2882 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2883 } else if (MulOp.first != 0) {
2884 Ops.push_back(getMulExpr(
2885 getConstant(MulOp.first),
2886 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2887 SCEV::FlagAnyWrap, Depth + 1));
2888 }
2889 }
2890 if (Ops.empty())
2891 return getZero(Ty);
2892 if (Ops.size() == 1)
2893 return Ops[0];
2894 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2895 }
2896 }
2897
2898 // If we are adding something to a multiply expression, make sure the
2899 // something is not already an operand of the multiply. If so, merge it into
2900 // the multiply.
2901 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2902 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2903 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2904 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2905 if (isa<SCEVConstant>(MulOpSCEV))
2906 continue;
2907 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2908 if (MulOpSCEV == Ops[AddOp]) {
2909 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2910 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2911 if (Mul->getNumOperands() != 2) {
2912 // If the multiply has more than two operands, we must get the
2913 // Y*Z term.
2914 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2915 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2916 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2917 }
2918 const SCEV *AddOne =
2919 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2920 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2922 if (Ops.size() == 2) return OuterMul;
2923 if (AddOp < Idx) {
2924 Ops.erase(Ops.begin()+AddOp);
2925 Ops.erase(Ops.begin()+Idx-1);
2926 } else {
2927 Ops.erase(Ops.begin()+Idx);
2928 Ops.erase(Ops.begin()+AddOp-1);
2929 }
2930 Ops.push_back(OuterMul);
2931 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2932 }
2933
2934 // Check this multiply against other multiplies being added together.
2935 for (unsigned OtherMulIdx = Idx+1;
2936 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2937 ++OtherMulIdx) {
2938 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2939 // If MulOp occurs in OtherMul, we can fold the two multiplies
2940 // together.
2941 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2942 OMulOp != e; ++OMulOp)
2943 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2944 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2945 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2946 if (Mul->getNumOperands() != 2) {
2947 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2948 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2949 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2950 }
2951 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2952 if (OtherMul->getNumOperands() != 2) {
2954 OtherMul->operands().take_front(OMulOp));
2955 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2956 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2957 }
2958 const SCEV *InnerMulSum =
2959 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2960 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2962 if (Ops.size() == 2) return OuterMul;
2963 Ops.erase(Ops.begin()+Idx);
2964 Ops.erase(Ops.begin()+OtherMulIdx-1);
2965 Ops.push_back(OuterMul);
2966 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 }
2969 }
2970 }
2971
2972 // If there are any add recurrences in the operands list, see if any other
2973 // added values are loop invariant. If so, we can fold them into the
2974 // recurrence.
2975 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2976 ++Idx;
2977
2978 // Scan over all recurrences, trying to fold loop invariants into them.
2979 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2980 // Scan all of the other operands to this add and add them to the vector if
2981 // they are loop invariant w.r.t. the recurrence.
2983 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2984 const Loop *AddRecLoop = AddRec->getLoop();
2985 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2986 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2987 LIOps.push_back(Ops[i]);
2988 Ops.erase(Ops.begin()+i);
2989 --i; --e;
2990 }
2991
2992 // If we found some loop invariants, fold them into the recurrence.
2993 if (!LIOps.empty()) {
2994 // Compute nowrap flags for the addition of the loop-invariant ops and
2995 // the addrec. Temporarily push it as an operand for that purpose. These
2996 // flags are valid in the scope of the addrec only.
2997 LIOps.push_back(AddRec);
2998 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2999 LIOps.pop_back();
3000
3001 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3002 LIOps.push_back(AddRec->getStart());
3003
3004 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3005
3006 // It is not in general safe to propagate flags valid on an add within
3007 // the addrec scope to one outside it. We must prove that the inner
3008 // scope is guaranteed to execute if the outer one does to be able to
3009 // safely propagate. We know the program is undefined if poison is
3010 // produced on the inner scoped addrec. We also know that *for this use*
3011 // the outer scoped add can't overflow (because of the flags we just
3012 // computed for the inner scoped add) without the program being undefined.
3013 // Proving that entry to the outer scope necessitates entry to the inner
3014 // scope, thus proves the program undefined if the flags would be violated
3015 // in the outer scope.
3016 SCEV::NoWrapFlags AddFlags = Flags;
3017 if (AddFlags != SCEV::FlagAnyWrap) {
3018 auto *DefI = getDefiningScopeBound(LIOps);
3019 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3020 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3021 AddFlags = SCEV::FlagAnyWrap;
3022 }
3023 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3024
3025 // Build the new addrec. Propagate the NUW and NSW flags if both the
3026 // outer add and the inner addrec are guaranteed to have no overflow.
3027 // Always propagate NW.
3028 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3029 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3030
3031 // If all of the other operands were loop invariant, we are done.
3032 if (Ops.size() == 1) return NewRec;
3033
3034 // Otherwise, add the folded AddRec by the non-invariant parts.
3035 for (unsigned i = 0;; ++i)
3036 if (Ops[i] == AddRec) {
3037 Ops[i] = NewRec;
3038 break;
3039 }
3040 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3041 }
3042
3043 // Okay, if there weren't any loop invariants to be folded, check to see if
3044 // there are multiple AddRec's with the same loop induction variable being
3045 // added together. If so, we can fold them.
3046 for (unsigned OtherIdx = Idx+1;
3047 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3048 ++OtherIdx) {
3049 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3050 // so that the 1st found AddRecExpr is dominated by all others.
3051 assert(DT.dominates(
3052 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3053 AddRec->getLoop()->getHeader()) &&
3054 "AddRecExprs are not sorted in reverse dominance order?");
3055 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3056 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3057 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3058 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3059 ++OtherIdx) {
3060 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3061 if (OtherAddRec->getLoop() == AddRecLoop) {
3062 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3063 i != e; ++i) {
3064 if (i >= AddRecOps.size()) {
3065 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3066 break;
3067 }
3068 AddRecOps[i] =
3069 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3071 }
3072 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3073 }
3074 }
3075 // Step size has changed, so we cannot guarantee no self-wraparound.
3076 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3077 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3078 }
3079 }
3080
3081 // Otherwise couldn't fold anything into this recurrence. Move onto the
3082 // next one.
3083 }
3084
3085 // Okay, it looks like we really DO need an add expr. Check to see if we
3086 // already have one, otherwise create a new one.
3087 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3088}
3089
3090const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3091 SCEV::NoWrapFlags Flags) {
3093 ID.AddInteger(scAddExpr);
3094 for (const SCEV *Op : Ops)
3095 ID.AddPointer(Op);
3096 void *IP = nullptr;
3097 SCEVAddExpr *S =
3098 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3099 if (!S) {
3100 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3102 S = new (SCEVAllocator)
3103 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3104 UniqueSCEVs.InsertNode(S, IP);
3105 S->computeAndSetCanonical(*this);
3106 registerUser(S, Ops);
3107 }
3108 S->setNoWrapFlags(Flags);
3109 return S;
3110}
3111
3112const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3113 const Loop *L,
3114 SCEV::NoWrapFlags Flags) {
3115 FoldingSetNodeID ID;
3116 ID.AddInteger(scAddRecExpr);
3117 for (const SCEV *Op : Ops)
3118 ID.AddPointer(Op);
3119 ID.AddPointer(L);
3120 void *IP = nullptr;
3121 SCEVAddRecExpr *S =
3122 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3123 if (!S) {
3124 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3126 S = new (SCEVAllocator)
3127 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3128 UniqueSCEVs.InsertNode(S, IP);
3129 S->computeAndSetCanonical(*this);
3130 LoopUsers[L].push_back(S);
3131 registerUser(S, Ops);
3132 }
3133 setNoWrapFlags(S, Flags);
3134 return S;
3135}
3136
3137const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3138 SCEV::NoWrapFlags Flags) {
3139 FoldingSetNodeID ID;
3140 ID.AddInteger(scMulExpr);
3141 for (const SCEV *Op : Ops)
3142 ID.AddPointer(Op);
3143 void *IP = nullptr;
3144 SCEVMulExpr *S =
3145 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3146 if (!S) {
3147 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3149 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3150 O, Ops.size());
3151 UniqueSCEVs.InsertNode(S, IP);
3152 S->computeAndSetCanonical(*this);
3153 registerUser(S, Ops);
3154 }
3155 S->setNoWrapFlags(Flags);
3156 return S;
3157}
3158
/// Multiply two unsigned 64-bit values. The (possibly wrapped) product is
/// always returned; \p Overflow is set to true when the multiplication
/// wrapped and is left untouched otherwise, so it accumulates across calls.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  const uint64_t Product = i * j;
  // A multiplier of 0 or 1 can never wrap. For anything larger, the product
  // wrapped exactly when dividing it back by j fails to recover i.
  const bool Wrapped = j > 1 && Product / j != i;
  if (Wrapped)
    Overflow = true;
  return Product;
}
3164
3165/// Compute the result of "n choose k", the binomial coefficient. If an
3166/// intermediate computation overflows, Overflow will be set and the return will
3167/// be garbage. Overflow is not cleared on absence of overflow.
3168static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3169 // We use the multiplicative formula:
3170 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3171 // At each iteration, we take the n-th term of the numeral and divide by the
3172 // (k-n)th term of the denominator. This division will always produce an
3173 // integral result, and helps reduce the chance of overflow in the
3174 // intermediate computations. However, we can still overflow even when the
3175 // final result would fit.
3176
3177 if (n == 0 || n == k) return 1;
3178 if (k > n) return 0;
3179
3180 if (k > n/2)
3181 k = n-k;
3182
3183 uint64_t r = 1;
3184 for (uint64_t i = 1; i <= k; ++i) {
3185 r = umul_ov(r, n-(i-1), Overflow);
3186 r /= i;
3187 }
3188 return r;
3189}
3190
3191/// Determine if any of the operands in this SCEV are a constant or if
3192/// any of the add or multiply expressions in this SCEV contain a constant.
3193static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3194 struct FindConstantInAddMulChain {
3195 bool FoundConstant = false;
3196
3197 bool follow(const SCEV *S) {
3198 FoundConstant |= isa<SCEVConstant>(S);
3199 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3200 }
3201
3202 bool isDone() const {
3203 return FoundConstant;
3204 }
3205 };
3206
3207 FindConstantInAddMulChain F;
3209 ST.visitAll(StartExpr);
3210 return F.FoundConstant;
3211}
3212
3213/// Get a canonical multiply expression, or something simpler if possible.
///
/// Ops is simplified in place: operands are erased, appended and replaced as
/// the folds below fire, and most successful folds finish by recursing with
/// Depth + 1. OrigFlags may only contain NUW and/or NSW (asserted below); the
/// recursive calls in this function generally pass SCEV::FlagAnyWrap instead.
/// A single-operand list is returned as that operand unchanged.
///
/// NOTE(review): this listing carries embedded line numbers and skips 3214,
/// 3243, 3265, 3275, 3279, 3316, 3334, 3385, 3397, 3413, 3415, 3470 and 3495.
/// Whatever stood on those lines (including the first line of the signature)
/// is not visible here, so several statements below read as incomplete.
3215 SCEV::NoWrapFlags OrigFlags,
3216 unsigned Depth) {
3217 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3218 "only nuw or nsw allowed");
3219 assert(!Ops.empty() && "Cannot get empty mul!");
3220 if (Ops.size() == 1) return Ops[0];
3221#ifndef NDEBUG
3222 Type *ETy = Ops[0]->getType();
3223 assert(!ETy->isPointerTy());
3224 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3225 assert(Ops[i]->getType() == ETy &&
3226 "SCEVMulExpr operand types don't match!");
3227#endif
3228
// Fold the constant operands together by multiplication; 1 is treated as the
// identity and 0 as the absorbing element. A non-null result is returned as
// the final answer.
3229 const SCEV *Folded = constantFoldAndGroupOps(
3230 *this, LI, DT, Ops,
3231 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3232 [](const APInt &C) { return C.isOne(); }, // identity
3233 [](const APInt &C) { return C.isZero(); }); // absorber
3234 if (Folded)
3235 return Folded;
3236
3237 // Delay expensive flag strengthening until necessary.
3238 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3239 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3240 };
3241
3242 // Limit recursion calls depth.
// NOTE(review): the condition guarding this early return (embedded line 3243)
// is missing from this listing.
3244 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3245
3246 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3247 // Don't strengthen flags if we have no new information.
3248 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3249 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3250 Mul->setNoWrapFlags(ComputeFlags(Ops));
3251 return S;
3252 }
3253
// Folds that need a constant as the first operand.
3254 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3255 if (Ops.size() == 2) {
3256 // C1*(C2+V) -> C1*C2 + C1*V
3257 // If any of Add's ops are Adds or Muls with a constant, apply this
3258 // transformation as well.
3259 //
3260 // TODO: There are some cases where this transformation is not
3261 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3262 // this transformation should be narrowed down.
3263 const SCEV *Op0, *Op1;
3264 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3266 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3267 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3268 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3269 }
3270
3271 if (Ops[0]->isAllOnesValue()) {
3272 // If we have a mul by -1 of an add, try distributing the -1 among the
3273 // add operands.
3274 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3276 bool AnyFolded = false;
3277 for (const SCEV *AddOp : Add->operands()) {
3278 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
// Only distribute if at least one product simplified to something other
// than a multiply.
3280 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3281 NewOps.push_back(Mul);
3282 }
3283 if (AnyFolded)
3284 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3285 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3286 // Negation preserves a recurrence's no self-wrap property.
3287 SmallVector<SCEVUse, 4> Operands;
3288 for (const SCEV *AddRecOp : AddRec->operands())
3289 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3290 SCEV::FlagAnyWrap, Depth + 1));
3291 // Let M be the minimum representable signed value. AddRec with nsw
3292 // multiplied by -1 can have signed overflow if and only if it takes a
3293 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3294 // maximum signed value. In all other cases signed overflow is
3295 // impossible.
3296 auto FlagsMask = SCEV::FlagNW;
3297 if (AddRec->hasNoSignedWrap()) {
3298 auto MinInt =
3299 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3300 if (getSignedRangeMin(AddRec) != MinInt)
3301 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3302 }
3303 return getAddRecExpr(Operands, AddRec->getLoop(),
3304 AddRec->getNoWrapFlags(FlagsMask));
3305 }
3306 }
3307
3308 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3309 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3310 const SCEVAddExpr *InnerAdd;
3311 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3312 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
// Require a constant first add operand, a truncation of C that zero-extends
// back to C itself, and a provable NUW flag on the narrow multiply.
3313 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3314 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3315 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3317 SCEV::FlagNUW)) {
3318 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3319 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3320 };
3321 }
3322
3323 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3324 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3325 // of C1, fold to (D /u (C2 /u C1)).
3326 const SCEV *D;
3327 APInt C1V = LHSC->getAPInt();
3328 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3329 // as -1 * 1, as it won't enable additional folds.
3330 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3331 C1V = C1V.abs();
3332 const SCEVConstant *C2;
3333 if (C1V.isPowerOf2() &&
3335 C2->getAPInt().isPowerOf2() &&
3336 C1V.logBase2() <= getMinTrailingZeros(D)) {
3337 const SCEV *NewMul = nullptr;
3338 if (C1V.uge(C2->getAPInt())) {
3339 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3340 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3341 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3342 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3343 }
// If C1 was replaced by its absolute value above, negate the folded result to
// restore the sign.
3344 if (NewMul)
3345 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3346 }
3347 }
3348 }
3349
3350 // Skip over the add expression until we get to a multiply.
3351 unsigned Idx = 0;
3352 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3353 ++Idx;
3354
3355 // If there are mul operands inline them all into this expression.
3356 if (Idx < Ops.size()) {
3357 bool DeletedMul = false;
3358 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3359 if (Ops.size() > MulOpsInlineThreshold)
3360 break;
3361 // If we have a mul, expand the mul operands onto the end of the
3362 // operands list.
3363 Ops.erase(Ops.begin()+Idx);
3364 append_range(Ops, Mul->operands());
3365 DeletedMul = true;
3366 }
3367
3368 // If we deleted at least one mul, we added operands to the end of the
3369 // list, and they are not necessarily sorted. Recurse to resort and
3370 // resimplify any operands we just acquired.
3371 if (DeletedMul)
3372 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3373 }
3374
3375 // If there are any add recurrences in the operands list, see if any other
3376 // multiplied values are loop invariant. If so, we can fold them into the
3377 // recurrence.
3378 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3379 ++Idx;
3380
3381 // Scan over all recurrences, trying to fold loop invariants into them.
3382 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3383 // Scan all of the other operands to this mul and add them to the vector
3384 // if they are loop invariant w.r.t. the recurrence.
3386 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
// Erasing shifts the remaining operands down by one, so both the index and
// the cached bound are stepped back after each removal.
3387 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3388 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3389 LIOps.push_back(Ops[i]);
3390 Ops.erase(Ops.begin()+i);
3391 --i; --e;
3392 }
3393
3394 // If we found some loop invariants, fold them into the recurrence.
3395 if (!LIOps.empty()) {
3396 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3398 NewOps.reserve(AddRec->getNumOperands());
3399 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3400
3401 // If both the mul and addrec are nuw, we can preserve nuw.
3402 // If both the mul and addrec are nsw, we can only preserve nsw if either
3403 // a) they are also nuw, or
3404 // b) all multiplications of addrec operands with scale are nsw.
3405 SCEV::NoWrapFlags Flags =
3406 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3407
3408 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3409 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3410 SCEV::FlagAnyWrap, Depth + 1));
3411
// Case (b) above: drop NSW as soon as one operand's signed range is not
// contained in NSWRegion.
3412 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3414 Instruction::Mul, getSignedRange(Scale),
3416 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3417 Flags = clearFlags(Flags, SCEV::FlagNSW);
3418 }
3419 }
3420
3421 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3422
3423 // If all of the other operands were loop invariant, we are done.
3424 if (Ops.size() == 1) return NewRec;
3425
3426 // Otherwise, multiply the folded AddRec by the non-invariant parts.
// The search has no bound check: it relies on AddRec still being present in
// Ops after the loop-invariant operands were removed.
3427 for (unsigned i = 0;; ++i)
3428 if (Ops[i] == AddRec) {
3429 Ops[i] = NewRec;
3430 break;
3431 }
3432 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3433 }
3434
3435 // Okay, if there weren't any loop invariants to be folded, check to see
3436 // if there are multiple AddRec's with the same loop induction variable
3437 // being multiplied together. If so, we can fold them.
3438
3439 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3440 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3441 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3442 // ]]],+,...up to x=2n}.
3443 // Note that the arguments to choose() are always integers with values
3444 // known at compile time, never SCEV objects.
3445 //
3446 // The implementation avoids pointless extra computations when the two
3447 // addrec's are of different length (mathematically, it's equivalent to
3448 // an infinite stream of zeros on the right).
// NOTE(review): the code below computes the first coefficient as
// Choose(x, 2x - y) and starts x at 0, which does not literally match the
// "choose(x, 2x)" / "x=1" wording of the formula above -- confirm which one
// is intended.
3449 bool OpsModified = false;
3450 for (unsigned OtherIdx = Idx+1;
3451 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3452 ++OtherIdx) {
3453 const SCEVAddRecExpr *OtherAddRec =
3454 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3455 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3456 continue;
3457
3458 // Limit max number of arguments to avoid creation of unreasonably big
3459 // SCEVAddRecs with very complex operands.
3460 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3461 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3462 continue;
3463
// Overflow is sticky: once Choose or umul_ov sets it, every loop below stops
// and the fold for this pair of recurrences is abandoned.
3464 bool Overflow = false;
3465 Type *Ty = AddRec->getType();
3466 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3467 SmallVector<SCEVUse, 7> AddRecOps;
3468 for (int x = 0, xe = AddRec->getNumOperands() +
3469 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3471 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3472 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3473 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3474 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3475 z < ze && !Overflow; ++z) {
3476 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3477 uint64_t Coeff;
// The product of the two coefficients is only overflow-checked for types
// wider than 64 bits. NOTE(review): presumably a wrapped uint64_t product is
// harmless for narrower types because the constant is built in Ty -- confirm.
3478 if (LargerThan64Bits)
3479 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3480 else
3481 Coeff = Coeff1*Coeff2;
3482 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3483 const SCEV *Term1 = AddRec->getOperand(y-z);
3484 const SCEV *Term2 = OtherAddRec->getOperand(z);
3485 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3486 SCEV::FlagAnyWrap, Depth + 1));
3487 }
3488 }
3489 if (SumOps.empty())
3490 SumOps.push_back(getZero(Ty));
3491 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3492 }
3493 if (!Overflow) {
3494 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3496 if (Ops.size() == 2) return NewAddRec;
3497 Ops[Idx] = NewAddRec;
3498 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3499 OpsModified = true;
// The folded product may no longer be an add recurrence; if so, stop trying
// to merge further recurrences into it.
3500 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3501 if (!AddRec)
3502 break;
3503 }
3504 }
3505 if (OpsModified)
3506 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3507
3508 // Otherwise couldn't fold anything into this recurrence. Move onto the
3509 // next one.
3510 }
3511
3512 // Okay, it looks like we really DO need an mul expr. Check to see if we
3513 // already have one, otherwise create a new one.
3514 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3515}
3516
3517/// Represents an unsigned remainder expression based on unsigned division.
///
/// The remainder of LHS by RHS is not given a node kind of its own here: it is
/// either folded for an easy constant RHS or expressed through udiv, mul and
/// sub as LHS - ((LHS udiv RHS) * RHS).
///
/// NOTE(review): the signature line (embedded line 3518) is missing from this
/// listing; LHS and RHS are the two operands used below.
3519 assert(getEffectiveSCEVType(LHS->getType()) ==
3520 getEffectiveSCEVType(RHS->getType()) &&
3521 "SCEVURemExpr operand types don't match!");
3522
3523 // Short-circuit easy cases
3524 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3525 // If constant is one, the result is trivial
3526 if (RHSC->getValue()->isOne())
3527 return getZero(LHS->getType()); // X urem 1 --> 0
3528
3529 // If constant is a power of two, fold into a zext(trunc(LHS)).
// X urem 2^k keeps only the low k bits of X: truncate to a k-bit integer
// type and zero-extend back to the original type.
3530 if (RHSC->getAPInt().isPowerOf2()) {
3531 Type *FullTy = LHS->getType();
3532 Type *TruncTy =
3533 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3534 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3535 }
3536 }
3537
3538 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3539 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3540 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3541 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3542}
3543
3544/// Get a canonical unsigned division expression, or something simpler if
3545/// possible.
// NOTE(review): this listing jumps from 3545 to 3547 and from 3551 to 3553, so
// the signature line and the line that presumably declares `ID` are not shown.
// The visible body reads two operands named LHS and RHS, tries a series of
// folds, and otherwise creates (or reuses) a uniqued SCEVUDivExpr node.
3547 assert(!LHS->getType()->isPointerTy() &&
3548 "SCEVUDivExpr operand can't be pointer!");
3549 assert(LHS->getType() == RHS->getType() &&
3550 "SCEVUDivExpr operand types don't match!");
3551
// Fast path: return the already-uniqued node keyed on (scUDivExpr, LHS, RHS).
3553 ID.AddInteger(scUDivExpr);
3554 ID.AddPointer(LHS);
3555 ID.AddPointer(RHS);
3556 void *IP = nullptr;
3557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3558 return S;
3559
3560 // 0 udiv Y == 0
3561 if (match(LHS, m_scev_Zero()))
3562 return LHS;
3563
3564 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3565 if (RHSC->getValue()->isOne())
3566 return LHS; // X udiv 1 --> x
3567 // If the denominator is zero, the result of the udiv is undefined. Don't
3568 // try to analyze it, because the resolution chosen here may differ from
3569 // the resolution chosen in other parts of the compiler.
3570 if (!RHSC->getValue()->isZero()) {
3571 // Determine if the division can be folded into the operands of
3572 // its operands.
3573 // TODO: Generalize this to non-constants by using known-bits information.
3574 Type *Ty = LHS->getType();
3575 unsigned LZ = RHSC->getAPInt().countl_zero();
3576 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3577 // For non-power-of-two values, effectively round the value up to the
3578 // nearest power of two.
3579 if (!RHSC->getAPInt().isPowerOf2())
3580 ++MaxShiftAmt;
// ExtTy is Ty widened by MaxShiftAmt bits. The folds below compare an
// expression zero-extended to ExtTy against the same expression rebuilt
// from zero-extended operands.
3581 IntegerType *ExtTy =
3582 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3583 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3584 if (const SCEVConstant *Step =
3585 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3586 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3587 const APInt &StepInt = Step->getAPInt();
3588 const APInt &DivInt = RHSC->getAPInt();
3589 if (!StepInt.urem(DivInt) &&
3590 getZeroExtendExpr(AR, ExtTy) ==
3591 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3592 getZeroExtendExpr(Step, ExtTy),
3593 AR->getLoop(), SCEV::FlagAnyWrap)) {
3594 SmallVector<SCEVUse, 4> Operands;
3595 for (const SCEV *Op : AR->operands())
3596 Operands.push_back(getUDivExpr(Op, RHS));
3597 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3598 }
3599 /// Get a canonical UDivExpr for a recurrence.
3600 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3601 const APInt *StartRem;
3602 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3603 m_scev_APInt(StartRem))) {
3604 bool NoWrap =
3605 getZeroExtendExpr(AR, ExtTy) ==
3606 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3607 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
// NOTE(review): listing line 3608, which closes the getAddRecExpr(...) call
// above, is missing from this listing.
3609
3610 // With N <= C and both N, C as powers-of-2, the transformation
3611 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3612 // if wrapping occurs, as the division results remain equivalent for
3613 // all offsets in [(X - X%N), X).
3614 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3615 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3616 // Only fold if the subtraction can be folded in the start
3617 // expression.
3618 const SCEV *NewStart =
3619 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3620 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3621 !isa<SCEVAddExpr>(NewStart)) {
3622 const SCEV *NewLHS =
3623 getAddRecExpr(NewStart, Step, AR->getLoop(),
3624 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3625 if (LHS != NewLHS) {
3626 LHS = NewLHS;
3627
3628 // Reset the ID to include the new LHS, and check if it is
3629 // already cached.
3630 ID.clear();
3631 ID.AddInteger(scUDivExpr);
3632 ID.AddPointer(LHS);
3633 ID.AddPointer(RHS);
3634 IP = nullptr;
3635 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3636 return S;
3637 }
3638 }
3639 }
3640 }
3641 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3642 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3643 SmallVector<SCEVUse, 4> Operands;
3644 for (const SCEV *Op : M->operands())
3645 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3646 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3647 // Find an operand that's safely divisible.
3648 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3649 const SCEV *Op = M->getOperand(i);
3650 const SCEV *Div = getUDivExpr(Op, RHSC);
// Accept the operand only if the quotient folded away (is not itself a
// udiv node) and multiplying it back by the divisor reproduces Op.
3651 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3652 Operands = SmallVector<SCEVUse, 4>(M->operands());
3653 Operands[i] = Div;
3654 return getMulExpr(Operands);
3655 }
3656 }
3657 }
3658
3659 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3660 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3661 if (auto *DivisorConstant =
3662 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3663 bool Overflow = false;
3664 APInt NewRHS =
3665 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
// When the product B*C does not fit in the type, the whole expression is
// folded to the constant 0.
3666 if (Overflow) {
3667 return getConstant(RHSC->getType(), 0, false);
3668 }
3669 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3670 }
3671 }
3672
3673 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3674 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3675 SmallVector<SCEVUse, 4> Operands;
3676 for (const SCEV *Op : A->operands())
3677 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3678 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3679 Operands.clear();
// Divide each addend in turn and stop at the first one that does not divide
// exactly; the fold applies only if every addend was collected.
3680 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3681 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3682 if (isa<SCEVUDivExpr>(Op) ||
3683 getMulExpr(Op, RHS) != A->getOperand(i))
3684 break;
3685 Operands.push_back(Op);
3686 }
3687 if (Operands.size() == A->getNumOperands())
3688 return getAddExpr(Operands);
3689 }
3690 }
3691
3692 // Fold if both operands are constant.
3693 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3694 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3695 }
3696 }
3697
3698 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3699 const APInt *NegC, *C;
3700 if (match(LHS,
// NOTE(review): listing lines 3701-3702 (the pattern passed to match) are
// missing from this listing.
3703 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3704 return getZero(LHS->getType());
3705
3706 // TODO: Generalize to handle any common factors.
3707 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3708 const SCEV *NewLHS, *NewRHS;
3709 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3710 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3711 return getUDivExpr(NewLHS, NewRHS);
3712
3713 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3714 // changes). Make sure we get a new one.
3715 IP = nullptr;
3716 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
// No fold applied and no cached node exists: allocate the node, insert it
// into the uniquing set, and register it as a user of both operands.
3717 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3718 LHS, RHS);
3719 UniqueSCEVs.InsertNode(S, IP);
3720 S->computeAndSetCanonical(*this);
3721 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3722 return S;
3723}
3724
3725APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3726 APInt A = C1->getAPInt().abs();
3727 APInt B = C2->getAPInt().abs();
3728 uint32_t ABW = A.getBitWidth();
3729 uint32_t BBW = B.getBitWidth();
3730
3731 if (ABW > BBW)
3732 B = B.zext(ABW);
3733 else if (ABW < BBW)
3734 A = A.zext(BBW);
3735
3736 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3737}
3738
3739/// Get a canonical unsigned division expression, or something simpler if
3740/// possible. There is no representation for an exact udiv in SCEV IR, but we
3741/// can attempt to remove factors from the LHS and RHS. We can't do this when
3742/// it's not exact because the udiv may be clearing bits.
// NOTE(review): listing lines 3743 (the signature), 3748 (presumably the
// definition of `Mul`) and 3775 are missing from this listing. The recursive
// call below shows the function is named getUDivExactExpr and takes (LHS, RHS).
3744 // TODO: we could try to find factors in all sorts of things, but for now we
3745 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3746 // end of this file for inspiration.
3747
// Without a no-unsigned-wrap multiply on the left there is nothing to cancel;
// fall back to the ordinary udiv.
3749 if (!Mul || !Mul->hasNoUnsignedWrap())
3750 return getUDivExpr(LHS, RHS);
3751
3752 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3753 // If the mulexpr multiplies by a constant, then that constant must be the
3754 // first element of the mulexpr.
3755 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
// Identical constants cancel outright: return the product of the
// remaining operands.
3756 if (LHSCst == RHSCst) {
3757 SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
3758 return getMulExpr(Operands);
3759 }
3760
3761 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3762 // that there's a factor provided by one of the other terms. We need to
3763 // check.
3764 APInt Factor = gcd(LHSCst, RHSCst);
// isIntN(1) holds only for the values 0 and 1, i.e. when there is no
// non-trivial common factor to divide out.
3765 if (!Factor.isIntN(1)) {
3766 LHSCst =
3767 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3768 RHSCst =
3769 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3770 SmallVector<SCEVUse, 2> Operands;
3771 Operands.push_back(LHSCst);
3772 append_range(Operands, Mul->operands().drop_front());
3773 LHS = getMulExpr(Operands);
3774 RHS = RHSCst;
// NOTE(review): listing line 3775 is missing; presumably it re-derives `Mul`
// from the rebuilt LHS, which the null check below then tests.
3776 if (!Mul)
3777 return getUDivExactExpr(LHS, RHS);
3778 }
3779 }
3780 }
3781
// If RHS itself is one of the multiply's operands, drop that operand and
// return the product of the rest.
3782 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3783 if (Mul->getOperand(i) == RHS) {
3784 SmallVector<SCEVUse, 2> Operands;
3785 append_range(Operands, Mul->operands().take_front(i));
3786 append_range(Operands, Mul->operands().drop_front(i + 1));
3787 return getMulExpr(Operands);
3788 }
3789 }
3790
3791 return getUDivExpr(LHS, RHS);
3792}
3793
3794/// Get an add recurrence expression for the specified loop. Simplify the
3795/// expression as much as possible.
// NOTE(review): listing line 3796 (the first line of the signature) is
// missing. The body reads Start, Step, L and Flags and forwards to the
// operand-list form of getAddRecExpr.
3797 const Loop *L,
3798 SCEV::NoWrapFlags Flags) {
3799 SmallVector<SCEVUse, 4> Operands;
3800 Operands.push_back(Start);
// If the step is itself an add recurrence on the same loop, splice its
// operands in after Start to form one recurrence. In this case only the
// FlagNW bits of Flags are passed on.
3801 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3802 if (StepChrec->getLoop() == L) {
3803 append_range(Operands, StepChrec->operands());
3804 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3805 }
3806
3807 Operands.push_back(Step);
3808 return getAddRecExpr(Operands, L, Flags);
3809}
3810
3811/// Get an add recurrence expression for the specified loop. Simplify the
3812/// expression as much as possible.
// NOTE(review): listing lines 3813 (the first line of the signature) and 3825
// (the start of the final debug assert) are missing. The body reads an
// operand list named Operands, which it modifies in place, plus L and Flags.
3814 const Loop *L,
3815 SCEV::NoWrapFlags Flags) {
3816 if (Operands.size() == 1) return Operands[0];
3817#ifndef NDEBUG
3818 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3819 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3820 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3821 "SCEVAddRecExpr operand types don't match!");
3822 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3823 }
3824 for (const SCEV *Op : Operands)
3826 "SCEVAddRecExpr operand is not available at loop entry!");
3827#endif
3828
// A zero last operand contributes nothing: drop it and recurse, which
// eventually reaches the single-operand early return above.
3829 if (Operands.back()->isZero()) {
3830 Operands.pop_back();
3831 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3832 }
3833
3834 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3835 // use that information to infer NUW and NSW flags. However, computing a
3836 // BE count requires calling getAddRecExpr, so we may not yet have a
3837 // meaningful BE count at this point (and if we don't, we'd be stuck
3838 // with a SCEVCouldNotCompute as the cached BE count).
3839
3840 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3841
3842 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3843 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3844 const Loop *NestedLoop = NestedAR->getLoop();
// Swap the nesting when L contains NestedLoop at a shallower depth, or when
// the loops are unrelated and L's header dominates NestedLoop's header.
3845 if (L->contains(NestedLoop)
3846 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3847 : (!NestedLoop->contains(L) &&
3848 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3849 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3850 Operands[0] = NestedAR->getStart();
3851 // AddRecs require their operands be loop-invariant with respect to their
3852 // loops. Don't perform this transformation if it would break this
3853 // requirement.
3854 bool AllInvariant = all_of(
3855 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3856
3857 if (AllInvariant) {
3858 // Create a recurrence for the outer loop with the same step size.
3859 //
3860 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3861 // inner recurrence has the same property.
3862 SCEV::NoWrapFlags OuterFlags =
3863 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3864
3865 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3866 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3867 return isLoopInvariant(Op, NestedLoop);
3868 });
3869
3870 if (AllInvariant) {
3871 // Ok, both add recurrences are valid after the transformation.
3872 //
3873 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3874 // the outer recurrence has the same property.
3875 SCEV::NoWrapFlags InnerFlags =
3876 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3877 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3878 }
3879 }
3880 // Reset Operands to its original state.
3881 Operands[0] = NestedAR;
3882 }
3883 }
3884
3885 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3886 // already have one, otherwise create a new one.
3887 return getOrCreateAddRecExpr(Operands, L, Flags);
3888}
3889
// Builds the SCEV for a GEP from its IR form: takes the SCEV of the pointer
// operand, decides which of the GEP's no-wrap flags may be kept, and forwards
// to the (BaseExpr, IndexExprs, SrcElementTy, NW) form of getGEPExpr.
// NOTE(review): listing line 3890 (the first line of the signature) is
// missing; the body reads a parameter named GEP.
3891 ArrayRef<SCEVUse> IndexExprs) {
3892 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3893 // getSCEV(Base)->getType() has the same address space as Base->getType()
3894 // because SCEV::getType() preserves the address space.
3895 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3896 if (NW != GEPNoWrapFlags::none()) {
3897 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3898 // but to do that, we have to ensure that said flag is valid in the entire
3899 // defined scope of the SCEV.
3900 // TODO: non-instructions have global scope. We might be able to prove
3901 // some global scope cases
3902 auto *GEPI = dyn_cast<Instruction>(GEP);
// Flags are dropped both for non-instruction GEPs and for instructions that
// isSCEVExprNeverPoison does not accept.
3903 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3904 NW = GEPNoWrapFlags::none();
3905 }
3906
3907 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3908}
3909
// Builds the SCEV for a GEP as BaseExpr plus the sum of per-index byte
// offsets: a constant field offset for struct indices, and index * element
// size for all other indices.
// NOTE(review): listing lines 3910 (first signature line), 3913 (presumably
// the declaration of OffsetWrap), 3922 (presumably the declaration of
// Offsets), 3942 and 3965-3966 are missing from this listing.
3911 ArrayRef<SCEVUse> IndexExprs,
3912 Type *SrcElementTy, GEPNoWrapFlags NW) {
// Translate the GEP no-wrap flags into SCEV no-wrap flags for the offset
// arithmetic.
3914 if (NW.hasNoUnsignedSignedWrap())
3915 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3916 if (NW.hasNoUnsignedWrap())
3917 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3918
3919 Type *CurTy = BaseExpr->getType();
3920 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3921 bool FirstIter = true;
3923 for (SCEVUse IndexExpr : IndexExprs) {
3924 // Compute the (potentially symbolic) offset in bytes for this index.
3925 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3926 // For a struct, add the member offset.
// The cast<> requires a struct index to be a SCEVConstant.
3927 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3928 unsigned FieldNo = Index->getZExtValue();
3929 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3930 Offsets.push_back(FieldOffset);
3931
3932 // Update CurTy to the type of the field at Index.
3933 CurTy = STy->getTypeAtIndex(Index);
3934 } else {
3935 // Update CurTy to its element type.
3936 if (FirstIter) {
3937 assert(isa<PointerType>(CurTy) &&
3938 "The first index of a GEP indexes a pointer");
3939 CurTy = SrcElementTy;
3940 FirstIter = false;
3941 } else {
// NOTE(review): listing line 3942 (the body of this else branch) is missing.
3943 }
3944 // For an array, add the element offset, explicitly scaled.
3945 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3946 // Getelementptr indices are signed.
3947 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3948
3949 // Multiply the index by the element size to compute the element offset.
3950 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3951 Offsets.push_back(LocalOffset);
3952 }
3953 }
3954
3955 // Handle degenerate case of GEP without offsets.
3956 if (Offsets.empty())
3957 return BaseExpr;
3958
3959 // Add the offsets together, assuming nsw if inbounds.
3960 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3961 // Add the base address and the offset. We cannot use the nsw flag, as the
3962 // base address is unsigned. However, if we know that the offset is
3963 // non-negative, we can use nuw.
3964 bool NUW = NW.hasNoUnsignedWrap() ||
// NOTE(review): listing lines 3965-3966 (the rest of this condition and,
// presumably, the definition of BaseWrap) are missing.
3967 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3968 assert(BaseExpr->getType() == GEPExpr->getType() &&
3969 "GEP should not change type mid-flight.");
3970 return GEPExpr;
3971}
3972
// Looks up an already-uniqued SCEV for (SCEVType, Ops) in UniqueSCEVs and
// returns whatever FindNodeOrInsertPos yields; nothing is inserted.
// NOTE(review): listing lines 3974-3975 (the operand-list parameter and,
// presumably, the declaration of `ID`) are missing from this listing.
3973SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
// The lookup key is the expression kind followed by each operand pointer, in
// order.
3976 ID.AddInteger(SCEVType);
3977 for (const SCEV *Op : Ops)
3978 ID.AddPointer(Op);
3979 void *IP = nullptr;
3980 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3981}
3982
// Second overload of the cache lookup: the visible body is the same as the
// preceding overload's, keying on the expression kind and then each operand
// pointer.
// NOTE(review): listing lines 3984-3985 (the operand-list parameter, which is
// presumably where the overloads differ, and the declaration of `ID`) are
// missing from this listing.
3983SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3986 ID.AddInteger(SCEVType);
3987 for (const SCEV *Op : Ops)
3988 ID.AddPointer(Op);
3989 void *IP = nullptr;
// Returns whatever FindNodeOrInsertPos yields; nothing is inserted here.
3990 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3991}
3992
// Expresses the absolute value of Op as smax(Op, -Op).
// NOTE(review): listing line 3994, presumably where `Flags` is derived from
// IsNSW, is missing from this listing.
3993const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3995 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3996}
3997
// Builds a canonical (u|s)(min|max) expression of kind Kind over Ops, or
// something simpler when folding applies. Ops is modified in place.
// NOTE(review): listing lines 3998-3999 (the signature), 4078-4081
// (presumably the definitions of GEPred and LEPred), 4107 (presumably the
// declaration of `ID`) and 4116 (presumably the copy of Ops into `O`) are
// missing. The recursive call below shows the name getMinMaxExpr and the
// arguments (Kind, Ops).
4000 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4001 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4002 if (Ops.size() == 1) return Ops[0];
4003#ifndef NDEBUG
4004 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4005 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4006 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4007 "Operand types don't match!");
4008 assert(Ops[0]->getType()->isPointerTy() ==
4009 Ops[i]->getType()->isPointerTy() &&
4010 "min/max should be consistently pointerish");
4011 }
4012#endif
4013
4014 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4015 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4016
// The three callbacks handed to constantFoldAndGroupOps are: how to combine
// two constants, which constant is the identity for this kind, and which
// constant is the absorber for this kind.
4017 const SCEV *Folded = constantFoldAndGroupOps(
4018 *this, LI, DT, Ops,
4019 [&](const APInt &C1, const APInt &C2) {
4020 switch (Kind) {
4021 case scSMaxExpr:
4022 return APIntOps::smax(C1, C2);
4023 case scSMinExpr:
4024 return APIntOps::smin(C1, C2);
4025 case scUMaxExpr:
4026 return APIntOps::umax(C1, C2);
4027 case scUMinExpr:
4028 return APIntOps::umin(C1, C2);
4029 default:
4030 llvm_unreachable("Unknown SCEV min/max opcode");
4031 }
4032 },
4033 [&](const APInt &C) {
4034 // identity
4035 if (IsMax)
4036 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4037 else
4038 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4039 },
4040 [&](const APInt &C) {
4041 // absorber
4042 if (IsMax)
4043 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4044 else
4045 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4046 });
4047 if (Folded)
4048 return Folded;
4049
4050 // Check if we have created the same expression before.
4051 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4052 return S;
4053 }
4054
4055 // Find the first operation of the same kind
4056 unsigned Idx = 0;
4057 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4058 ++Idx;
4059
4060 // Check to see if one of the operands is of the same kind. If so, expand its
4061 // operands onto our operand list, and recurse to simplify.
4062 if (Idx < Ops.size()) {
4063 bool DeletedAny = false;
4064 while (Ops[Idx]->getSCEVType() == Kind) {
4065 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4066 Ops.erase(Ops.begin()+Idx);
4067 append_range(Ops, SMME->operands());
4068 DeletedAny = true;
4069 }
4070
4071 if (DeletedAny)
4072 return getMinMaxExpr(Kind, Ops);
4073 }
4074
4075 // Okay, check to see if the same value occurs in the operand list twice. If
4076 // so, delete one. Since we sorted the list, these values are required to
4077 // be adjacent.
4082 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4083 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
// i and e are unsigned: after an erase both are decremented so that the
// loop's ++i re-examines the same position. When i is 0 the decrement wraps
// around and the following ++i brings it back to 0.
4084 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4085 if (Ops[i] == Ops[i + 1] ||
4086 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4087 // X op Y op Y --> X op Y
4088 // X op Y --> X, if we know X, Y are ordered appropriately
4089 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4090 --i;
4091 --e;
4092 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4093 Ops[i + 1])) {
4094 // X op Y --> Y, if we know X, Y are ordered appropriately
4095 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4096 --i;
4097 --e;
4098 }
4099 }
4100
4101 if (Ops.size() == 1) return Ops[0];
4102
4103 assert(!Ops.empty() && "Reduced smax down to nothing!");
4104
4105 // Okay, it looks like we really DO need an expr. Check to see if we
4106 // already have one, otherwise create a new one.
4108 ID.AddInteger(Kind);
4109 for (const SCEV *Op : Ops)
4110 ID.AddPointer(Op);
4111 void *IP = nullptr;
4112 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4113 if (ExistingSCEV)
4114 return ExistingSCEV;
// Operand storage for the new node comes from SCEVAllocator.
4115 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4117 SCEV *S = new (SCEVAllocator)
4118 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4119
4120 UniqueSCEVs.InsertNode(S, IP);
4121 S->computeAndSetCanonical(*this);
4122 registerUser(S, Ops);
4123 return S;
4124}
4125
4126namespace {
4127
// Visitor that walks the operands of a sequential min/max expression and
// drops every operand that has already been seen, keeping only the first
// occurrence. Each visit returns either the (possibly rewritten) expression
// or std::nullopt to say "drop this operand".
// NOTE(review): listing lines 4132 (presumably the `Base` alias), 4137
// (presumably the declaration of `SeenOps`), 4146, 4162 and 4185 (presumably
// the declaration of the local `Ops`) are missing from this listing.
4128class SCEVSequentialMinMaxDeduplicatingVisitor final
4129 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4130 std::optional<const SCEV *>> {
4131 using RetVal = std::optional<const SCEV *>;
4133
4134 ScalarEvolution &SE;
4135 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4136 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4138
4139 bool canRecurseInto(SCEVTypes Kind) const {
4140 // We can only recurse into the SCEV expression of the same effective type
4141 // as the type of our root SCEV expression.
4142 return RootKind == Kind || NonSequentialRootKind == Kind;
4143 };
4144
// Shared handler for every min/max node kind. Nodes of an unrelated kind
// are returned untouched; otherwise the operands are deduplicated and the
// node is rebuilt only if something changed.
4145 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4147 "Only for min/max expressions.");
4148 SCEVTypes Kind = S->getSCEVType();
4149
4150 if (!canRecurseInto(Kind))
4151 return S;
4152
4153 auto *NAry = cast<SCEVNAryExpr>(S);
4154 SmallVector<SCEVUse> NewOps;
4155 bool Changed = visit(Kind, NAry->operands(), NewOps);
4156
4157 if (!Changed)
4158 return S;
// Every operand was a duplicate, so this whole node is dropped.
4159 if (NewOps.empty())
4160 return std::nullopt;
4161
// NOTE(review): listing line 4162 (the start of this return expression,
// including the condition choosing between the two builders) is missing.
4163 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4164 : SE.getMinMaxExpr(Kind, NewOps);
4165 }
4166
4167 RetVal visit(const SCEV *S) {
4168 // Has the whole operand been seen already?
4169 if (!SeenOps.insert(S).second)
4170 return std::nullopt;
4171 return Base::visit(S);
4172 }
4173
4174public:
4175 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4176 SCEVTypes RootKind)
4177 : SE(SE), RootKind(RootKind),
4178 NonSequentialRootKind(
4179 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4180 RootKind)) {}
4181
// Entry point: visits OrigOps in order and collects the surviving operands.
// NewOps is written only when the result is true, so passing the same
// container for both arguments leaves it untouched when nothing changed.
4182 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4183 SmallVectorImpl<SCEVUse> &NewOps) {
4184 bool Changed = false;
4186 Ops.reserve(OrigOps.size());
4187
4188 for (const SCEV *Op : OrigOps) {
4189 RetVal NewOp = visit(Op);
// NewOp differs from Op when the operand was dropped (std::nullopt) or
// replaced by a different SCEV.
4190 if (NewOp != Op)
4191 Changed = true;
4192 if (NewOp)
4193 Ops.emplace_back(*NewOp);
4194 }
4195
4196 if (Changed)
4197 NewOps = std::move(Ops);
4198 return Changed;
4199 }
4200
// All node kinds other than min/max are returned unchanged; their operands
// are not visited.
4201 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4202
4203 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4204
4205 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4206
4207 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4208
4209 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4210
4211 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4212
4213 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4214
4215 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4216
4217 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4218
4219 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4220
4221 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4222
4223 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4224 return visitAnyMinMaxExpr(Expr);
4225 }
4226
4227 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4228 return visitAnyMinMaxExpr(Expr);
4229 }
4230
4231 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4232 return visitAnyMinMaxExpr(Expr);
4233 }
4234
4235 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4236 return visitAnyMinMaxExpr(Expr);
4237 }
4238
4239 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4240 return visitAnyMinMaxExpr(Expr);
4241 }
4242
4243 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4244
4245 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4246};
4247
4248} // namespace
4249
// Classifies a SCEV kind by whether poison in any operand always makes the
// whole expression poison: true for the kinds listed first, false for the
// case handled after them.
// NOTE(review): listing lines 4250 (the signature) and 4270 (the case label
// that the FIXME and `return false` below belong to) are missing from this
// listing. The body switches on a value named Kind.
4251 switch (Kind) {
4252 case scConstant:
4253 case scVScale:
4254 case scTruncate:
4255 case scZeroExtend:
4256 case scSignExtend:
4257 case scPtrToAddr:
4258 case scPtrToInt:
4259 case scAddExpr:
4260 case scMulExpr:
4261 case scUDivExpr:
4262 case scAddRecExpr:
4263 case scUMaxExpr:
4264 case scSMaxExpr:
4265 case scUMinExpr:
4266 case scSMinExpr:
4267 case scUnknown:
4268 // If any operand is poison, the whole expression is poison.
4269 return true;
4271 // FIXME: if the *first* operand is poison, the whole expression is poison.
4272 return false; // Pessimistically, say that it does not propagate poison.
4273 case scCouldNotCompute:
4274 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4275 }
4276 llvm_unreachable("Unknown SCEV kind!");
4277}
4278
4279namespace {
4280// The only way poison may be introduced in a SCEV expression is from a
4281// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4282// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4283// introduce poison -- they encode guaranteed, non-speculated knowledge.
4284//
4285// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4286// with the notable exception of umin_seq, where only poison from the first
4287// operand is (unconditionally) propagated.
//
// SCEVPoisonCollector is a traversal callback (it is passed to visitAll
// elsewhere in this file) that records, in MaybePoison, every SCEVUnknown it
// reaches whose value is not guaranteed to be non-poison.
4288struct SCEVPoisonCollector {
4289 bool LookThroughMaybePoisonBlocking;
4290 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4291 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4292 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4293
// Called per visited node. Returning false stops the walk from descending
// into S's operands (presumably; this is the traversal's contract, which is
// not visible here).
4294 bool follow(const SCEV *S) {
4295 if (!LookThroughMaybePoisonBlocking &&
// NOTE(review): listing line 4296 (the second half of this condition) is
// missing from this listing.
4297 return false;
4298
4299 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4300 if (!isGuaranteedNotToBePoison(SU->getValue()))
4301 MaybePoison.insert(SU);
4302 }
4303 return true;
4304 }
// The collector never asks for the traversal to end early.
4305 bool isDone() const { return false; }
4306};
4307} // namespace
4308
4309/// Return true if V is poison given that AssumedPoison is already poison.
4310static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4311 // First collect all SCEVs that might result in AssumedPoison to be poison.
4312 // We need to look through potentially poison-blocking operations here,
4313 // because we want to find all SCEVs that *might* result in poison, not only
4314 // those that are *required* to.
4315 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4316 visitAll(AssumedPoison, PC1);
4317
4318 // AssumedPoison is never poison. As the assumption is false, the implication
4319 // is true. Don't bother walking the other SCEV in this case.
4320 if (PC1.MaybePoison.empty())
4321 return true;
4322
4323 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4324 // as well. We cannot look through potentially poison-blocking operations
4325 // here, as their arguments only *may* make the result poison.
4326 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4327 visitAll(S, PC2);
4328
4329 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4330 // it will also make S poison by being part of PC2.MaybePoison.
4331 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4332}
4333
// Adds to Result the underlying Value of every SCEVUnknown that the
// collector records while walking S. The walk does not look through
// potentially poison-blocking operations.
// NOTE(review): listing line 4334 (the first line of the signature) is
// missing; the name getPoisonGeneratingValues is taken from a call elsewhere
// in this file.
4335 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4336 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4337 visitAll(S, PC);
4338 for (const SCEVUnknown *SU : PC.MaybePoison)
4339 Result.insert(SU->getValue());
4340}
4341
// Decides whether instruction I may be reused for SCEV S. Returns true when
// I cannot be more poisonous than S once the poison-generating annotations on
// the instructions appended to DropPoisonGeneratingInsts are dropped; returns
// false when that cannot be shown.
// NOTE(review): listing lines 4342 (the first line of the signature), 4346
// (the condition guarding the first `return true`), 4353 (presumably the
// declaration of PoisonVals) and 4357 (presumably the declaration of Visited)
// are missing from this listing.
4343 const SCEV *S, Instruction *I,
4344 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4345 // If the instruction cannot be poison, it's always safe to reuse.
4347 return true;
4348
4349 // Otherwise, it is possible that I is more poisonous than S. Collect the
4350 // poison-contributors of S, and then check whether I has any additional
4351 // poison-contributors. Poison that is contributed through poison-generating
4352 // flags is handled by dropping those flags instead.
4354 getPoisonGeneratingValues(PoisonVals, S);
4355
// Worklist walk over I and, transitively, the operands of every instruction
// that passes the checks below.
4356 SmallVector<Value *> Worklist;
4358 Worklist.push_back(I);
4359 while (!Worklist.empty()) {
4360 Value *V = Worklist.pop_back_val();
4361 if (!Visited.insert(V).second)
4362 continue;
4363
4364 // Avoid walking large instruction graphs.
4365 if (Visited.size() > 16)
4366 return false;
4367
4368 // Either the value can't be poison, or the S would also be poison if it
4369 // is.
4370 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4371 continue;
4372
// Note: this local `I` shadows the parameter `I` for the rest of the loop
// body.
4373 auto *I = dyn_cast<Instruction>(V);
4374 if (!I)
4375 return false;
4376
4377 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4378 // can't replace an arbitrary add with disjoint or, even if we drop the
4379 // flag. We would need to convert the or into an add.
4380 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4381 if (PDI->isDisjoint())
4382 return false;
4383
4384 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4385 // because SCEV currently assumes it can't be poison. Remove this special
4386 // case once we properly model when vscale can be poison.
4387 if (auto *II = dyn_cast<IntrinsicInst>(I);
4388 II && II->getIntrinsicID() == Intrinsic::vscale)
4389 continue;
4390
4391 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4392 return false;
4393
4394 // If the instruction can't create poison, we can recurse to its operands.
4395 if (I->hasPoisonGeneratingAnnotations())
4396 DropPoisonGeneratingInsts.push_back(I);
4397
4398 llvm::append_range(Worklist, I->operands());
4399 }
4400 return true;
4401}
4402
// Builds a canonical sequential min/max expression of kind Kind over Ops, or
// something simpler when one of the folds below applies. Operand order is
// significant and is never changed. Ops is modified in place.
// NOTE(review): listing lines 4404-4405 (the rest of the signature), 4461
// (presumably the declaration of Pred), 4463 (the case label), 4482, 4497
// (presumably the declaration of `ID`) and 4507 (presumably the copy of Ops
// into `O`) are missing. The recursive calls below show the name
// getSequentialMinMaxExpr and the arguments (Kind, Ops).
4403const SCEV *
4406 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4407 "Not a SCEVSequentialMinMaxExpr!");
4408 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4409 if (Ops.size() == 1)
4410 return Ops[0];
4411#ifndef NDEBUG
4412 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4413 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4414 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4415 "Operand types don't match!");
4416 assert(Ops[0]->getType()->isPointerTy() ==
4417 Ops[i]->getType()->isPointerTy() &&
4418 "min/max should be consistently pointerish");
4419 }
4420#endif
4421
4422 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4423 // so we can *NOT* do any kind of sorting of the expressions!
4424
4425 // Check if we have created the same expression before.
4426 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4427 return S;
4428
4429 // FIXME: there are *some* simplifications that we can do here.
4430
4431 // Keep only the first instance of an operand.
4432 {
4433 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
// Ops is passed as both input and output; the visitor assigns to its output
// only when it reports a change.
4434 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4435 if (Changed)
4436 return getSequentialMinMaxExpr(Kind, Ops);
4437 }
4438
4439 // Check to see if one of the operands is of the same kind. If so, expand its
4440 // operands onto our operand list, and recurse to simplify.
4441 {
4442 unsigned Idx = 0;
4443 bool DeletedAny = false;
4444 while (Idx < Ops.size()) {
4445 if (Ops[Idx]->getSCEVType() != Kind) {
4446 ++Idx;
4447 continue;
4448 }
// The nested operands are inserted at the position the nested expression
// occupied, which keeps the overall operand order intact.
4449 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4450 Ops.erase(Ops.begin() + Idx);
4451 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4452 SMME->operands().end());
4453 DeletedAny = true;
4454 }
4455
4456 if (DeletedAny)
4457 return getSequentialMinMaxExpr(Kind, Ops);
4458 }
4459
4460 const SCEV *SaturationPoint;
4462 switch (Kind) {
4464 SaturationPoint = getZero(Ops[0]->getType());
4465 Pred = ICmpInst::ICMP_ULE;
4466 break;
4467 default:
4468 llvm_unreachable("Not a sequential min/max type.");
4469 }
4470
// Examine each adjacent pair (Ops[i - 1], Ops[i]). Any successful fold
// rewrites Ops and restarts through a recursive call.
4471 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4472 if (!isGuaranteedNotToCauseUB(Ops[i]))
4473 continue;
4474 // We can replace %x umin_seq %y with %x umin %y if either:
4475 // * %y being poison implies %x is also poison.
4476 // * %x cannot be the saturating value (e.g. zero for umin).
4477 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4478 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4479 SaturationPoint)) {
4480 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4481 Ops[i - 1] = getMinMaxExpr(
// NOTE(review): listing line 4482 (the first argument of this call) is
// missing from this listing.
4483 SeqOps);
4484 Ops.erase(Ops.begin() + i);
4485 return getSequentialMinMaxExpr(Kind, Ops);
4486 }
4487 // Fold %x umin_seq %y to %x if %x ule %y.
4488 // TODO: We might be able to prove the predicate for a later operand.
4489 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4490 Ops.erase(Ops.begin() + i);
4491 return getSequentialMinMaxExpr(Kind, Ops);
4492 }
4493 }
4494
4495 // Okay, it looks like we really DO need an expr. Check to see if we
4496 // already have one, otherwise create a new one.
4498 ID.AddInteger(Kind);
4499 for (const SCEV *Op : Ops)
4500 ID.AddPointer(Op);
4501 void *IP = nullptr;
4502 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4503 if (ExistingSCEV)
4504 return ExistingSCEV;
4505
4506 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4508 SCEV *S = new (SCEVAllocator)
4509 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4510
4511 UniqueSCEVs.InsertNode(S, IP);
4512 S->computeAndSetCanonical(*this);
4513 registerUser(S, Ops);
4514 return S;
4515}
4516
4521
4525
4530
4534
4539
4543
4545 bool Sequential) {
4546 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4547 return getUMinExpr(Ops, Sequential);
4548}
4549
4555
4556const SCEV *
4558 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4559 if (Size.isScalable())
4560 Res = getMulExpr(Res, getVScale(IntTy));
4561 return Res;
4562}
4563
4565 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4566}
4567
4569 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4570}
4571
4573 StructType *STy,
4574 unsigned FieldNo) {
4575 // We can bypass creating a target-independent constant expression and then
4576 // folding it back into a ConstantInt. This is just a compile-time
4577 // optimization.
4578 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4579 assert(!SL->getSizeInBits().isScalable() &&
4580 "Cannot get offset for structure containing scalable vector types");
4581 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4582}
4583
4585 // Don't attempt to do anything other than create a SCEVUnknown object
4586 // here. createSCEV only calls getUnknown after checking for all other
4587 // interesting possibilities, and any other code that calls getUnknown
4588 // is doing so in order to hide a value from SCEV canonicalization.
4589
4591 ID.AddInteger(scUnknown);
4592 ID.AddPointer(V);
4593 void *IP = nullptr;
4594 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4595 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4596 "Stale SCEVUnknown in uniquing map!");
4597 return S;
4598 }
4599 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4600 FirstUnknown);
4601 FirstUnknown = cast<SCEVUnknown>(S);
4602 UniqueSCEVs.InsertNode(S, IP);
4603 S->computeAndSetCanonical(*this);
4604 return S;
4605}
4606
4607//===----------------------------------------------------------------------===//
4608// Basic SCEV Analysis and PHI Idiom Recognition Code
4609//
4610
4611/// Test if values of the given type are analyzable within the SCEV
4612/// framework. This primarily includes integer types, and it can optionally
4613/// include pointer types if the ScalarEvolution class has access to
4614/// target-specific information.
4616 // Integers and pointers are always SCEVable.
4617 return Ty->isIntOrPtrTy();
4618}
4619
4620/// Return the size in bits of the specified type, for which isSCEVable must
4621/// return true.
4623 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4624 if (Ty->isPointerTy())
4626 return getDataLayout().getTypeSizeInBits(Ty);
4627}
4628
4629/// Return a type with the same bitwidth as the given type and which represents
4630/// how SCEV will treat the given type, for which isSCEVable must return
4631/// true. For pointer types, this is the pointer index sized integer type.
4633 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4634
4635 if (Ty->isIntegerTy())
4636 return Ty;
4637
4638 // The only other supported type is pointer.
4639 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4640 return getDataLayout().getIndexType(Ty);
4641}
4642
4644 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4645}
4646
4648 const SCEV *B) {
4649 /// For a valid use point to exist, the defining scope of one operand
4650 /// must dominate the other.
4651 bool PreciseA, PreciseB;
4652 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4653 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4654 if (!PreciseA || !PreciseB)
4655 // Can't tell.
4656 return false;
4657 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4658 DT.dominates(ScopeB, ScopeA);
4659}
4660
4662 return CouldNotCompute.get();
4663}
4664
4665bool ScalarEvolution::checkValidity(const SCEV *S) const {
4666 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4667 auto *SU = dyn_cast<SCEVUnknown>(S);
4668 return SU && SU->getValue() == nullptr;
4669 });
4670
4671 return !ContainsNulls;
4672}
4673
4675 HasRecMapType::iterator I = HasRecMap.find(S);
4676 if (I != HasRecMap.end())
4677 return I->second;
4678
4679 bool FoundAddRec =
4680 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4681 HasRecMap.insert({S, FoundAddRec});
4682 return FoundAddRec;
4683}
4684
4685/// Return the ValueOffsetPair set for \p S. \p S can be represented
4686/// by the value and offset from any ValueOffsetPair in the set.
4687ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4688 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4689 if (SI == ExprValueMap.end())
4690 return {};
4691 return SI->second.getArrayRef();
4692}
4693
4694/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4695/// cannot be used separately. eraseValueFromMap should be used to remove
4696/// V from ValueExprMap and ExprValueMap at the same time.
4697void ScalarEvolution::eraseValueFromMap(Value *V) {
4698 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4699 if (I != ValueExprMap.end()) {
4700 auto EVIt = ExprValueMap.find(I->second);
4701 bool Removed = EVIt->second.remove(V);
4702 (void) Removed;
4703 assert(Removed && "Value not in ExprValueMap?");
4704 ValueExprMap.erase(I);
4705 }
4706}
4707
4708void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4709 // A recursive query may have already computed the SCEV. It should be
4710 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4711 // inferred nowrap flags.
4712 auto It = ValueExprMap.find_as(V);
4713 if (It == ValueExprMap.end()) {
4714 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4715 ExprValueMap[S].insert(V);
4716 }
4717}
4718
4719/// Return an existing SCEV if it exists, otherwise analyze the expression and
4720/// create a new one.
4722 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4723
4724 if (const SCEV *S = getExistingSCEV(V))
4725 return S;
4726 return createSCEVIter(V);
4727}
4728
4730 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4731
4732 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4733 if (I != ValueExprMap.end()) {
4734 const SCEV *S = I->second;
4735 assert(checkValidity(S) &&
4736 "existing SCEV has not been properly invalidated");
4737 return S;
4738 }
4739 return nullptr;
4740}
4741
4742/// Return a SCEV corresponding to -V = -1*V
4744 SCEV::NoWrapFlags Flags) {
4745 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4746 return getConstant(
4747 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4748
4749 Type *Ty = V->getType();
4750 Ty = getEffectiveSCEVType(Ty);
4751 return getMulExpr(V, getMinusOne(Ty), Flags);
4752}
4753
4754/// If Expr computes ~A, return A else return nullptr
4755static const SCEV *MatchNotExpr(const SCEV *Expr) {
4756 const SCEV *MulOp;
4757 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4758 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4759 return MulOp;
4760 return nullptr;
4761}
4762
4763/// Return a SCEV corresponding to ~V = -1-V
4765 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4766
4767 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4768 return getConstant(
4769 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4770
4771 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4772 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4773 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4774 SmallVector<SCEVUse, 2> MatchedOperands;
4775 for (const SCEV *Operand : MME->operands()) {
4776 const SCEV *Matched = MatchNotExpr(Operand);
4777 if (!Matched)
4778 return (const SCEV *)nullptr;
4779 MatchedOperands.push_back(Matched);
4780 }
4781 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4782 MatchedOperands);
4783 };
4784 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4785 return Replaced;
4786 }
4787
4788 Type *Ty = V->getType();
4789 Ty = getEffectiveSCEVType(Ty);
4790 return getMinusSCEV(getMinusOne(Ty), V);
4791}
4792
4794 assert(P->getType()->isPointerTy());
4795
4796 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4797 // The base of an AddRec is the first operand.
4798 SmallVector<SCEVUse> Ops{AddRec->operands()};
4799 Ops[0] = removePointerBase(Ops[0]);
4800 // Don't try to transfer nowrap flags for now. We could in some cases
4801 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4802 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4803 }
4804 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4805 // The base of an Add is the pointer operand.
4806 SmallVector<SCEVUse> Ops{Add->operands()};
4807 SCEVUse *PtrOp = nullptr;
4808 for (SCEVUse &AddOp : Ops) {
4809 if (AddOp->getType()->isPointerTy()) {
4810 assert(!PtrOp && "Cannot have multiple pointer ops");
4811 PtrOp = &AddOp;
4812 }
4813 }
4814 *PtrOp = removePointerBase(*PtrOp);
4815 // Don't try to transfer nowrap flags for now. We could in some cases
4816 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4817 return getAddExpr(Ops);
4818 }
4819 // Any other expression must be a pointer base.
4820 return getZero(P->getType());
4821}
4822
4824 SCEV::NoWrapFlags Flags,
4825 unsigned Depth) {
4826 // Fast path: X - X --> 0.
4827 if (LHS == RHS)
4828 return getZero(LHS->getType());
4829
4830 // If we subtract two pointers with different pointer bases, bail.
4831 // Eventually, we're going to add an assertion to getMulExpr that we
4832 // can't multiply by a pointer.
4833 if (RHS->getType()->isPointerTy()) {
4834 if (!LHS->getType()->isPointerTy() ||
4835 getPointerBase(LHS) != getPointerBase(RHS))
4836 return getCouldNotCompute();
4837 LHS = removePointerBase(LHS);
4838 RHS = removePointerBase(RHS);
4839 }
4840
4841 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4842 // makes it so that we cannot make much use of NUW.
4843 auto AddFlags = SCEV::FlagAnyWrap;
4844 const bool RHSIsNotMinSigned =
4846 if (hasFlags(Flags, SCEV::FlagNSW)) {
4847 // Let M be the minimum representable signed value. Then (-1)*RHS
4848 // signed-wraps if and only if RHS is M. That can happen even for
4849 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4850 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4851 // (-1)*RHS, we need to prove that RHS != M.
4852 //
4853 // If LHS is non-negative and we know that LHS - RHS does not
4854 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4855 // either by proving that RHS > M or that LHS >= 0.
4856 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4857 AddFlags = SCEV::FlagNSW;
4858 }
4859 }
4860
4861 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4862 // RHS is NSW and LHS >= 0.
4863 //
4864 // The difficulty here is that the NSW flag may have been proven
4865 // relative to a loop that is to be found in a recurrence in LHS and
4866 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4867 // larger scope than intended.
4868 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4869
4870 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4871}
4872
4874 unsigned Depth) {
4875 Type *SrcTy = V->getType();
4876 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4877 "Cannot truncate or zero extend with non-integer arguments!");
4878 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4879 return V; // No conversion
4880 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4881 return getTruncateExpr(V, Ty, Depth);
4882 return getZeroExtendExpr(V, Ty, Depth);
4883}
4884
4886 unsigned Depth) {
4887 Type *SrcTy = V->getType();
4888 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4889 "Cannot truncate or zero extend with non-integer arguments!");
4890 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4891 return V; // No conversion
4892 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4893 return getTruncateExpr(V, Ty, Depth);
4894 return getSignExtendExpr(V, Ty, Depth);
4895}
4896
4897const SCEV *
4899 Type *SrcTy = V->getType();
4900 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4901 "Cannot noop or zero extend with non-integer arguments!");
4903 "getNoopOrZeroExtend cannot truncate!");
4904 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4905 return V; // No conversion
4906 return getZeroExtendExpr(V, Ty);
4907}
4908
4909const SCEV *
4911 Type *SrcTy = V->getType();
4912 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4913 "Cannot noop or sign extend with non-integer arguments!");
4915 "getNoopOrSignExtend cannot truncate!");
4916 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4917 return V; // No conversion
4918 return getSignExtendExpr(V, Ty);
4919}
4920
4921const SCEV *
4923 Type *SrcTy = V->getType();
4924 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4925 "Cannot noop or any extend with non-integer arguments!");
4927 "getNoopOrAnyExtend cannot truncate!");
4928 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4929 return V; // No conversion
4930 return getAnyExtendExpr(V, Ty);
4931}
4932
4933const SCEV *
4935 Type *SrcTy = V->getType();
4936 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4937 "Cannot truncate or noop with non-integer arguments!");
4939 "getTruncateOrNoop cannot extend!");
4940 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4941 return V; // No conversion
4942 return getTruncateExpr(V, Ty);
4943}
4944
4946 const SCEV *RHS) {
4947 const SCEV *PromotedLHS = LHS;
4948 const SCEV *PromotedRHS = RHS;
4949
4950 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4951 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4952 else
4953 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4954
4955 return getUMaxExpr(PromotedLHS, PromotedRHS);
4956}
4957
4959 const SCEV *RHS,
4960 bool Sequential) {
4961 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4962 return getUMinFromMismatchedTypes(Ops, Sequential);
4963}
4964
4965const SCEV *
4967 bool Sequential) {
4968 assert(!Ops.empty() && "At least one operand must be!");
4969 // Trivial case.
4970 if (Ops.size() == 1)
4971 return Ops[0];
4972
4973 // Find the max type first.
4974 Type *MaxType = nullptr;
4975 for (SCEVUse S : Ops)
4976 if (MaxType)
4977 MaxType = getWiderType(MaxType, S->getType());
4978 else
4979 MaxType = S->getType();
4980 assert(MaxType && "Failed to find maximum type!");
4981
4982 // Extend all ops to max type.
4983 SmallVector<SCEVUse, 2> PromotedOps;
4984 for (SCEVUse S : Ops)
4985 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4986
4987 // Generate umin.
4988 return getUMinExpr(PromotedOps, Sequential);
4989}
4990
4992 // A pointer operand may evaluate to a nonpointer expression, such as null.
4993 if (!V->getType()->isPointerTy())
4994 return V;
4995
4996 while (true) {
4997 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4998 V = AddRec->getStart();
4999 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5000 const SCEV *PtrOp = nullptr;
5001 for (const SCEV *AddOp : Add->operands()) {
5002 if (AddOp->getType()->isPointerTy()) {
5003 assert(!PtrOp && "Cannot have multiple pointer ops");
5004 PtrOp = AddOp;
5005 }
5006 }
5007 assert(PtrOp && "Must have pointer op");
5008 V = PtrOp;
5009 } else // Not something we can look further into.
5010 return V;
5011 }
5012}
5013
5014/// Push users of the given Instruction onto the given Worklist.
5018 // Push the def-use children onto the Worklist stack.
5019 for (User *U : I->users()) {
5020 auto *UserInsn = cast<Instruction>(U);
5021 if (Visited.insert(UserInsn).second)
5022 Worklist.push_back(UserInsn);
5023 }
5024}
5025
5026namespace {
5027
5028/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5029/// expression in case its Loop is L. If it is not L then
5030/// if IgnoreOtherLoops is true then use AddRec itself
5031/// otherwise rewrite cannot be done.
5032/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5033class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5034public:
5035 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5036 bool IgnoreOtherLoops = true) {
5037 SCEVInitRewriter Rewriter(L, SE);
5038 const SCEV *Result = Rewriter.visit(S);
5039 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5040 return SE.getCouldNotCompute();
5041 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5042 ? SE.getCouldNotCompute()
5043 : Result;
5044 }
5045
5046 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5047 if (!SE.isLoopInvariant(Expr, L))
5048 SeenLoopVariantSCEVUnknown = true;
5049 return Expr;
5050 }
5051
5052 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5053 // Only re-write AddRecExprs for this loop.
5054 if (Expr->getLoop() == L)
5055 return Expr->getStart();
5056 SeenOtherLoops = true;
5057 return Expr;
5058 }
5059
5060 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5061
5062 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5063
5064private:
5065 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5066 : SCEVRewriteVisitor(SE), L(L) {}
5067
5068 const Loop *L;
5069 bool SeenLoopVariantSCEVUnknown = false;
5070 bool SeenOtherLoops = false;
5071};
5072
5073/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5074/// increment expression in case its Loop is L. If it is not L then
5075/// use AddRec itself.
5076/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5077class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5078public:
5079 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5080 SCEVPostIncRewriter Rewriter(L, SE);
5081 const SCEV *Result = Rewriter.visit(S);
5082 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5083 ? SE.getCouldNotCompute()
5084 : Result;
5085 }
5086
5087 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5088 if (!SE.isLoopInvariant(Expr, L))
5089 SeenLoopVariantSCEVUnknown = true;
5090 return Expr;
5091 }
5092
5093 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5094 // Only re-write AddRecExprs for this loop.
5095 if (Expr->getLoop() == L)
5096 return Expr->getPostIncExpr(SE);
5097 SeenOtherLoops = true;
5098 return Expr;
5099 }
5100
5101 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5102
5103 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5104
5105private:
5106 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5107 : SCEVRewriteVisitor(SE), L(L) {}
5108
5109 const Loop *L;
5110 bool SeenLoopVariantSCEVUnknown = false;
5111 bool SeenOtherLoops = false;
5112};
5113
5114/// This class evaluates the compare condition by matching it against the
5115/// condition of loop latch. If there is a match we assume a true value
5116/// for the condition while building SCEV nodes.
5117class SCEVBackedgeConditionFolder
5118 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5119public:
5120 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5121 ScalarEvolution &SE) {
5122 bool IsPosBECond = false;
5123 Value *BECond = nullptr;
5124 if (BasicBlock *Latch = L->getLoopLatch()) {
5125 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5126 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5127 "Both outgoing branches should not target same header!");
5128 BECond = BI->getCondition();
5129 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5130 } else {
5131 return S;
5132 }
5133 }
5134 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5135 return Rewriter.visit(S);
5136 }
5137
5138 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5139 const SCEV *Result = Expr;
5140 bool InvariantF = SE.isLoopInvariant(Expr, L);
5141
5142 if (!InvariantF) {
5144 switch (I->getOpcode()) {
5145 case Instruction::Select: {
5146 SelectInst *SI = cast<SelectInst>(I);
5147 std::optional<const SCEV *> Res =
5148 compareWithBackedgeCondition(SI->getCondition());
5149 if (Res) {
5150 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5151 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5152 }
5153 break;
5154 }
5155 default: {
5156 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5157 if (Res)
5158 Result = *Res;
5159 break;
5160 }
5161 }
5162 }
5163 return Result;
5164 }
5165
5166private:
5167 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5168 bool IsPosBECond, ScalarEvolution &SE)
5169 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5170 IsPositiveBECond(IsPosBECond) {}
5171
5172 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5173
5174 const Loop *L;
5175 /// Loop back condition.
5176 Value *BackedgeCond = nullptr;
5177 /// Set to true if loop back is on positive branch condition.
5178 bool IsPositiveBECond;
5179};
5180
5181std::optional<const SCEV *>
5182SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5183
5184 // If value matches the backedge condition for loop latch,
5185 // then return a constant evolution node based on loopback
5186 // branch taken.
5187 if (BackedgeCond == IC)
5188 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5190 return std::nullopt;
5191}
5192
5193class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5194public:
5195 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5196 ScalarEvolution &SE) {
5197 SCEVShiftRewriter Rewriter(L, SE);
5198 const SCEV *Result = Rewriter.visit(S);
5199 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5200 }
5201
5202 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5203 // Only allow AddRecExprs for this loop.
5204 if (!SE.isLoopInvariant(Expr, L))
5205 Valid = false;
5206 return Expr;
5207 }
5208
5209 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5210 if (Expr->getLoop() == L && Expr->isAffine())
5211 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5212 Valid = false;
5213 return Expr;
5214 }
5215
5216 bool isValid() { return Valid; }
5217
5218private:
5219 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5220 : SCEVRewriteVisitor(SE), L(L) {}
5221
5222 const Loop *L;
5223 bool Valid = true;
5224};
5225
5226} // end anonymous namespace
5227
5229ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5230 if (!AR->isAffine())
5231 return SCEV::FlagAnyWrap;
5232
5233 using OBO = OverflowingBinaryOperator;
5234
5236
5237 if (!AR->hasNoSelfWrap()) {
5238 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5239 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5240 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5241 const APInt &BECountAP = BECountMax->getAPInt();
5242 unsigned NoOverflowBitWidth =
5243 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5244 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5246 }
5247 }
5248
5249 if (!AR->hasNoSignedWrap()) {
5250 ConstantRange AddRecRange = getSignedRange(AR);
5251 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5252
5254 Instruction::Add, IncRange, OBO::NoSignedWrap);
5255 if (NSWRegion.contains(AddRecRange))
5257 }
5258
5259 if (!AR->hasNoUnsignedWrap()) {
5260 ConstantRange AddRecRange = getUnsignedRange(AR);
5261 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5262
5264 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5265 if (NUWRegion.contains(AddRecRange))
5267 }
5268
5269 return Result;
5270}
5271
5273ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5275
5276 if (AR->hasNoSignedWrap())
5277 return Result;
5278
5279 if (!AR->isAffine())
5280 return Result;
5281
5282 // This function can be expensive, only try to prove NSW once per AddRec.
5283 if (!SignedWrapViaInductionTried.insert(AR).second)
5284 return Result;
5285
5286 const SCEV *Step = AR->getStepRecurrence(*this);
5287 const Loop *L = AR->getLoop();
5288
5289 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5290 // Note that this serves two purposes: It filters out loops that are
5291 // simply not analyzable, and it covers the case where this code is
5292 // being called from within backedge-taken count analysis, such that
5293 // attempting to ask for the backedge-taken count would likely result
5294 // in infinite recursion. In the latter case, the analysis code will
5295 // cope with a conservative value, and it will take care to purge
5296 // that value once it has finished.
5297 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5298
5299 // Normally, in the cases we can prove no-overflow via a
5300 // backedge guarding condition, we can also compute a backedge
5301 // taken count for the loop. The exceptions are assumptions and
5302 // guards present in the loop -- SCEV is not great at exploiting
5303 // these to compute max backedge taken counts, but can still use
5304 // these to prove lack of overflow. Use this fact to avoid
5305 // doing extra work that may not pay off.
5306
5307 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5308 AC.assumptions().empty())
5309 return Result;
5310
5311 // If the backedge is guarded by a comparison with the pre-inc value the
5312 // addrec is safe. Also, if the entry is guarded by a comparison with the
5313 // start value and the backedge is guarded by a comparison with the post-inc
5314 // value, the addrec is safe.
5316 const SCEV *OverflowLimit =
5317 getSignedOverflowLimitForStep(Step, &Pred, this);
5318 if (OverflowLimit &&
5319 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5320 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5321 Result = setFlags(Result, SCEV::FlagNSW);
5322 }
5323 return Result;
5324}
5326ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5328
5329 if (AR->hasNoUnsignedWrap())
5330 return Result;
5331
5332 if (!AR->isAffine())
5333 return Result;
5334
5335 // This function can be expensive, only try to prove NUW once per AddRec.
5336 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5337 return Result;
5338
5339 const SCEV *Step = AR->getStepRecurrence(*this);
5340 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5341 const Loop *L = AR->getLoop();
5342
5343 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5344 // Note that this serves two purposes: It filters out loops that are
5345 // simply not analyzable, and it covers the case where this code is
5346 // being called from within backedge-taken count analysis, such that
5347 // attempting to ask for the backedge-taken count would likely result
5348 // in infinite recursion. In the latter case, the analysis code will
5349 // cope with a conservative value, and it will take care to purge
5350 // that value once it has finished.
5351 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5352
5353 // Normally, in the cases we can prove no-overflow via a
5354 // backedge guarding condition, we can also compute a backedge
5355 // taken count for the loop. The exceptions are assumptions and
5356 // guards present in the loop -- SCEV is not great at exploiting
5357 // these to compute max backedge taken counts, but can still use
5358 // these to prove lack of overflow. Use this fact to avoid
5359 // doing extra work that may not pay off.
5360
5361 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5362 AC.assumptions().empty())
5363 return Result;
5364
5365 // If the backedge is guarded by a comparison with the pre-inc value the
5366 // addrec is safe. Also, if the entry is guarded by a comparison with the
5367 // start value and the backedge is guarded by a comparison with the post-inc
5368 // value, the addrec is safe.
5369 if (isKnownPositive(Step)) {
5370 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5371 getUnsignedRangeMax(Step));
5374 Result = setFlags(Result, SCEV::FlagNUW);
5375 }
5376 }
5377
5378 return Result;
5379}
5380
5381namespace {
5382
5383/// Represents an abstract binary operation. This may exist as a
5384/// normal instruction or constant expression, or may have been
5385/// derived from an expression tree.
5386struct BinaryOp {
5387 unsigned Opcode;
5388 Value *LHS;
5389 Value *RHS;
5390 bool IsNSW = false;
5391 bool IsNUW = false;
5392
5393 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5394 /// constant expression.
5395 Operator *Op = nullptr;
5396
5397 explicit BinaryOp(Operator *Op)
5398 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5399 Op(Op) {
5400 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5401 IsNSW = OBO->hasNoSignedWrap();
5402 IsNUW = OBO->hasNoUnsignedWrap();
5403 }
5404 }
5405
5406 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5407 bool IsNUW = false)
5408 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5409};
5410
5411} // end anonymous namespace
5412
5413/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5414static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5415 AssumptionCache &AC,
5416 const DominatorTree &DT,
5417 const Instruction *CxtI) {
5418 auto *Op = dyn_cast<Operator>(V);
5419 if (!Op)
5420 return std::nullopt;
5421
5422 // Implementation detail: all the cleverness here should happen without
5423 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5424 // SCEV expressions when possible, and we should not break that.
5425
5426 switch (Op->getOpcode()) {
5427 case Instruction::Add:
5428 case Instruction::Sub:
5429 case Instruction::Mul:
5430 case Instruction::UDiv:
5431 case Instruction::URem:
5432 case Instruction::And:
5433 case Instruction::AShr:
5434 case Instruction::Shl:
5435 return BinaryOp(Op);
5436
5437 case Instruction::Or: {
5438 // Convert or disjoint into add nuw nsw.
5439 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5440 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5441 /*IsNSW=*/true, /*IsNUW=*/true);
5442 return BinaryOp(Op);
5443 }
5444
5445 case Instruction::Xor:
5446 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5447 // If the RHS of the xor is a signmask, then this is just an add.
5448 // Instcombine turns add of signmask into xor as a strength reduction step.
5449 if (RHSC->getValue().isSignMask())
5450 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5451 // Binary `xor` is a bit-wise `add`.
5452 if (V->getType()->isIntegerTy(1))
5453 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5454 return BinaryOp(Op);
5455
5456 case Instruction::LShr:
5457 // Turn logical shift right of a constant into a unsigned divide.
5458 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5459 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5460
5461 // If the shift count is not less than the bitwidth, the result of
5462 // the shift is undefined. Don't try to analyze it, because the
5463 // resolution chosen here may differ from the resolution chosen in
5464 // other parts of the compiler.
5465 if (SA->getValue().ult(BitWidth)) {
5466 Constant *X =
5467 ConstantInt::get(SA->getContext(),
5468 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5469 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5470 }
5471 }
5472 return BinaryOp(Op);
5473
5474 case Instruction::ExtractValue: {
5475 auto *EVI = cast<ExtractValueInst>(Op);
5476 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5477 break;
5478
5479 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5480 if (!WO)
5481 break;
5482
5483 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5484 bool Signed = WO->isSigned();
5485 // TODO: Should add nuw/nsw flags for mul as well.
5486 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5487 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5488
5489 // Now that we know that all uses of the arithmetic-result component of
5490 // CI are guarded by the overflow check, we can go ahead and pretend
5491 // that the arithmetic is non-overflowing.
5492 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5493 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5494 }
5495
5496 default:
5497 break;
5498 }
5499
5500 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5501 // semantics as a Sub, return a binary sub expression.
5502 if (auto *II = dyn_cast<IntrinsicInst>(V))
5503 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5504 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5505
5506 return std::nullopt;
5507}
5508
5509/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5510/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5511/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5512/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5513/// follows one of the following patterns:
5514/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5515/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5516/// If the SCEV expression of \p Op conforms with one of the expected patterns
5517/// we return the type of the truncation operation, and indicate whether the
5518/// truncated type should be treated as signed/unsigned by setting
5519/// \p Signed to true/false, respectively.
5520static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5521 bool &Signed, ScalarEvolution &SE) {
5522 // The case where Op == SymbolicPHI (that is, with no type conversions on
5523 // the way) is handled by the regular add recurrence creating logic and
5524 // would have already been triggered in createAddRecForPHI. Reaching it here
5525 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5526 // because one of the other operands of the SCEVAddExpr updating this PHI is
5527 // not invariant).
5528 //
5529 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5530 // this case predicates that allow us to prove that Op == SymbolicPHI will
5531 // be added.
5532 if (Op == SymbolicPHI)
5533 return nullptr;
5534
5535 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5536 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5537 if (SourceBits != NewBits)
5538 return nullptr;
5539
5540 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5541 Signed = true;
5542 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5543 }
5544 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5545 Signed = false;
5546 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5547 }
5548 return nullptr;
5549}
5550
5551static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5552 if (!PN->getType()->isIntegerTy())
5553 return nullptr;
5554 const Loop *L = LI.getLoopFor(PN->getParent());
5555 if (!L || L->getHeader() != PN->getParent())
5556 return nullptr;
5557 return L;
5558}
5559
5560// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5561// computation that updates the phi follows the following pattern:
5562// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5563// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5564// If so, try to see if it can be rewritten as an AddRecExpr under some
5565// Predicates. If successful, return them as a pair. Also cache the results
5566// of the analysis.
5567//
5568// Example usage scenario:
5569// Say the Rewriter is called for the following SCEV:
5570// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5571// where:
5572// %X = phi i64 (%Start, %BEValue)
5573// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5574// and call this function with %SymbolicPHI = %X.
5575//
5576// The analysis will find that the value coming around the backedge has
5577// the following SCEV:
5578// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5579// Upon concluding that this matches the desired pattern, the function
5580// will return the pair {NewAddRec, SmallPredsVec} where:
5581// NewAddRec = {%Start,+,%Step}
5582// SmallPredsVec = {P1, P2, P3} as follows:
5583// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5584// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5585// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5586// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5587// under the predicates {P1,P2,P3}.
5588// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5589// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5590//
5591// TODO's:
5592//
5593// 1) Extend the Induction descriptor to also support inductions that involve
5594// casts: When needed (namely, when we are called in the context of the
5595// vectorizer induction analysis), a Set of cast instructions will be
5596// populated by this method, and provided back to isInductionPHI. This is
5597// needed to allow the vectorizer to properly record them to be ignored by
5598// the cost model and to avoid vectorizing them (otherwise these casts,
5599// which are redundant under the runtime overflow checks, will be
5600// vectorized, which can be costly).
5601//
5602// 2) Support additional induction/PHISCEV patterns: We also want to support
5603// inductions where the sext-trunc / zext-trunc operations (partly) occur
5604// after the induction update operation (the induction increment):
5605//
5606// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5607// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5608//
5609// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5610// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5611//
5612// 3) Outline common code with createAddRecFromPHI to avoid duplication.
// NOTE(review): the body below uses `Predicates`, `Ops` and `AddedFlags`, but
// no declaration of any of them is visible in this listing; the declaring
// lines appear to have been lost -- confirm against the full file.
5613std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5614ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5616
5617 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5618 // return an AddRec expression under some predicate.
5619
5620 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5621 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5622 assert(L && "Expecting an integer loop header phi");
5623
5624 // The loop may have multiple entrances or multiple exits; we can analyze
5625 // this phi as an addrec if it has a unique entry value and a unique
5626 // backedge value.
5627 Value *BEValueV = nullptr, *StartValueV = nullptr;
5628 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5629 Value *V = PN->getIncomingValue(i);
// Incoming blocks inside the loop supply the backedge value; all other
// incoming blocks supply the start value. Two different values on either
// side reset that side to null and abandon the scan.
5630 if (L->contains(PN->getIncomingBlock(i))) {
5631 if (!BEValueV) {
5632 BEValueV = V;
5633 } else if (BEValueV != V) {
5634 BEValueV = nullptr;
5635 break;
5636 }
5637 } else if (!StartValueV) {
5638 StartValueV = V;
5639 } else if (StartValueV != V) {
5640 StartValueV = nullptr;
5641 break;
5642 }
5643 }
5644 if (!BEValueV || !StartValueV)
5645 return std::nullopt;
5646
5647 const SCEV *BEValue = getSCEV(BEValueV);
5648
5649 // If the value coming around the backedge is an add with the symbolic
5650 // value we just inserted, possibly with casts that we can ignore under
5651 // an appropriate runtime guard, then we found a simple induction variable!
5652 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5653 if (!Add)
5654 return std::nullopt;
5655
5656 // If there is a single occurrence of the symbolic value, possibly
5657 // casted, replace it with a recurrence.
5658 unsigned FoundIndex = Add->getNumOperands();
5659 Type *TruncTy = nullptr;
// Signed is deliberately left uninitialized: isSimpleCastedPHI assigns it
// whenever it returns a non-null type, and it is only read after such a match.
5660 bool Signed;
// FoundIndex still equals e when the first match is seen, so the first
// matching operand is recorded and the scan stops there.
5661 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5662 if ((TruncTy =
5663 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5664 if (FoundIndex == e) {
5665 FoundIndex = i;
5666 break;
5667 }
5668
5669 if (FoundIndex == Add->getNumOperands())
5670 return std::nullopt;
5671
5672 // Create an add with everything but the specified operand.
// NOTE(review): the declaration of `Ops` is not visible in this listing.
5674 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5675 if (i != FoundIndex)
5676 Ops.push_back(Add->getOperand(i));
5677 const SCEV *Accum = getAddExpr(Ops);
5678
5679 // The runtime checks will not be valid if the step amount is
5680 // varying inside the loop.
5681 if (!isLoopInvariant(Accum, L))
5682 return std::nullopt;
5683
5684 // *** Part2: Create the predicates
5685
5686 // Analysis was successful: we have a phi-with-cast pattern for which we
5687 // can return an AddRec expression under the following predicates:
5688 //
5689 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5690 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5691 // P2: An Equal predicate that guarantees that
5692 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5693 // P3: An Equal predicate that guarantees that
5694 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5695 //
5696 // As we next prove, the above predicates guarantee that:
5697 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5698 //
5699 //
5700 // More formally, we want to prove that:
5701 // Expr(i+1) = Start + (i+1) * Accum
5702 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5703 //
5704 // Given that:
5705 // 1) Expr(0) = Start
5706 // 2) Expr(1) = Start + Accum
5707 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5708 // 3) Induction hypothesis (step i):
5709 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5710 //
5711 // Proof:
5712 // Expr(i+1) =
5713 // = Start + (i+1)*Accum
5714 // = (Start + i*Accum) + Accum
5715 // = Expr(i) + Accum
5716 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5717 // :: from step i
5718 //
5719 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5720 //
5721 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5722 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5723 // + Accum :: from P3
5724 //
5725 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5726 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5727 //
5728 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5729 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5730 //
5731 // By induction, the same applies to all iterations 1<=i<n:
5732 //
5733
5734 // Create a truncated addrec for which we will add a no overflow check (P1).
5735 const SCEV *StartVal = getSCEV(StartValueV);
5736 const SCEV *PHISCEV =
5737 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5738 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5739
5740 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5741 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5742 // will be constant.
5743 //
5744 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5745 // add P1.
5746 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
// NOTE(review): the definition of `AddedFlags` is not visible in this listing;
// presumably it selects the signed or unsigned wrap flag from `Signed` --
// confirm against the full file.
5750 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5751 Predicates.push_back(AddRecPred);
5752 }
5753
5754 // Create the Equal Predicates P2,P3:
5755
5756 // It is possible that the predicates P2 and/or P3 are computable at
5757 // compile time due to StartVal and/or Accum being constants.
5758 // If either one is, then we can check that now and escape if either P2
5759 // or P3 is false.
5760
5761 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5762 // for each of StartVal and Accum
5763 auto getExtendedExpr = [&](const SCEV *Expr,
5764 bool CreateSignExtend) -> const SCEV * {
5765 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5766 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5767 const SCEV *ExtendedExpr =
5768 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5769 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5770 return ExtendedExpr;
5771 };
5772
5773 // Given:
5774 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5775 // = getExtendedExpr(Expr)
5776 // Determine whether the predicate P: Expr == ExtendedExpr
5777 // is known to be false at compile time
5778 auto PredIsKnownFalse = [&](const SCEV *Expr,
5779 const SCEV *ExtendedExpr) -> bool {
5780 return Expr != ExtendedExpr &&
5781 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5782 };
5783
5784 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5785 if (PredIsKnownFalse(StartVal, StartExtended)) {
5786 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5787 return std::nullopt;
5788 }
5789
5790 // The Step is always Signed (because the overflow checks are either
5791 // NSSW or NUSW)
5792 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5793 if (PredIsKnownFalse(Accum, AccumExtended)) {
5794 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5795 return std::nullopt;
5796 }
5797
// Record an equality predicate only when the two expressions are neither the
// same SCEV nor already known to be equal; otherwise no runtime check is
// needed for that pair.
5798 auto AppendPredicate = [&](const SCEV *Expr,
5799 const SCEV *ExtendedExpr) -> void {
5800 if (Expr != ExtendedExpr &&
5801 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5802 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5803 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5804 Predicates.push_back(Pred);
5805 }
5806 };
5807
5808 AppendPredicate(StartVal, StartExtended);
5809 AppendPredicate(Accum, AccumExtended);
5810
5811 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5812 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5813 // into NewAR if it will also add the runtime overflow checks specified in
5814 // Predicates.
5815 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5816
5817 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5818 std::make_pair(NewAR, Predicates);
5819 // Remember the result of the analysis for this SCEV at this location.
5820 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5821 return PredRewrite;
5822}
5823
// Caching front end for createAddRecFromPHIWithCastsImpl. Returns
// std::nullopt when the phi behind SymbolicPHI is not an integer loop-header
// phi. Otherwise it replays an entry already stored in PredicatedSCEVRewrites
// for {SymbolicPHI, L}, or runs the Impl and, if that finds nothing, stores a
// failure marker (the phi mapped to itself) before returning std::nullopt.
// NOTE(review): the line carrying the function name and parameter list, and
// the declaration of `Predicates` used in the failure path, are not visible in
// this listing -- confirm against the full file.
5824std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5826 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5827 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5828 if (!L)
5829 return std::nullopt;
5830
5831 // Check to see if we already analyzed this PHI.
5832 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5833 if (I != PredicatedSCEVRewrites.end()) {
5834 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5835 I->second;
5836 // Analysis was done before and failed to create an AddRec:
5837 if (Rewrite.first == SymbolicPHI)
5838 return std::nullopt;
5839 // Analysis was done before and succeeded to create an AddRec under
5840 // a predicate:
5841 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5842 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5843 return Rewrite;
5844 }
5845
5846 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5847 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5848
5849 // Record in the cache that the analysis failed
5850 if (!Rewrite) {
// Mapping the phi to itself is the failure marker that the lookup above
// tests for via `Rewrite.first == SymbolicPHI`.
5852 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5853 return std::nullopt;
5854 }
5855
5856 return Rewrite;
5857}
5858
5859// FIXME: This utility is currently required because the Rewriter currently
5860// does not rewrite this expression:
5861// {0, +, (sext ix (trunc iy to ix) to iy)}
5862// into {0, +, %step},
5863// even when the following Equal predicate exists:
5864// "%step == (sext ix (trunc iy to ix) to iy)".
//
// Returns true when AR1 and AR2 are the same expression, or when both their
// start values and their step recurrences are pairwise equal. Two values count
// as equal if they are the same SCEV or if Preds implies an equal predicate
// between them in either operand order.
// NOTE(review): the first line of the signature (return type and function
// name) is not visible in this listing -- confirm against the full file.
5866 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5867 if (AR1 == AR2)
5868 return true;
5869
// Both operand orders of the equal predicate are tried before giving up.
5870 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5871 if (Expr1 != Expr2 &&
5872 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5873 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5874 return false;
5875 return true;
5876 };
5877
5878 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5879 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5880 return false;
5881 return true;
5882}
5883
5884/// A helper function for createAddRecFromPHI to handle simple cases.
5885///
5886/// This function tries to find an AddRec expression for the simplest (yet most
5887/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5888/// If it fails, createAddRecFromPHI will use a more general, but slow,
5889/// technique for finding the AddRec expression.
5890const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5891 Value *BEValueV,
5892 Value *StartValueV) {
5893 const Loop *L = LI.getLoopFor(PN->getParent());
5894 assert(L && L->getHeader() == PN->getParent());
5895 assert(BEValueV && StartValueV);
5896
5897 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5898 if (!BO)
5899 return nullptr;
5900
5901 if (BO->Opcode != Instruction::Add)
5902 return nullptr;
5903
5904 const SCEV *Accum = nullptr;
5905 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5906 Accum = getSCEV(BO->RHS);
5907 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5908 Accum = getSCEV(BO->LHS);
5909
5910 if (!Accum)
5911 return nullptr;
5912
5914 if (BO->IsNUW)
5915 Flags = setFlags(Flags, SCEV::FlagNUW);
5916 if (BO->IsNSW)
5917 Flags = setFlags(Flags, SCEV::FlagNSW);
5918
5919 const SCEV *StartVal = getSCEV(StartValueV);
5920 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5921 insertValueToMap(PN, PHISCEV);
5922
5923 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5924 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5925 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
5926 }
5927
5928 // We can add Flags to the post-inc expression only if we
5929 // know that it is *undefined behavior* for BEValueV to
5930 // overflow.
5931 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5932 assert(isLoopInvariant(Accum, L) &&
5933 "Accum is defined outside L, but is not invariant?");
5934 if (isAddRecNeverPoison(BEInst, L))
5935 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5936 }
5937
5938 return PHISCEV;
5939}
5940
// Try to express the loop-header phi PN as a recurrence. Returns nullptr when
// PN is not in the header of its innermost loop, when it lacks a unique start
// value or a unique backedge value, or when none of the patterns handled
// below applies.
// NOTE(review): several lines of this body are not visible in this listing
// (the declarations of `Ops` and `Flags`, the second operand of one `||`
// condition, and the callee of one call); each spot is marked below -- confirm
// against the full file.
5941const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5942 const Loop *L = LI.getLoopFor(PN->getParent());
5943 if (!L || L->getHeader() != PN->getParent())
5944 return nullptr;
5945
5946 // The loop may have multiple entrances or multiple exits; we can analyze
5947 // this phi as an addrec if it has a unique entry value and a unique
5948 // backedge value.
5949 Value *BEValueV = nullptr, *StartValueV = nullptr;
5950 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5951 Value *V = PN->getIncomingValue(i);
5952 if (L->contains(PN->getIncomingBlock(i))) {
5953 if (!BEValueV) {
5954 BEValueV = V;
5955 } else if (BEValueV != V) {
5956 BEValueV = nullptr;
5957 break;
5958 }
5959 } else if (!StartValueV) {
5960 StartValueV = V;
5961 } else if (StartValueV != V) {
5962 StartValueV = nullptr;
5963 break;
5964 }
5965 }
5966 if (!BEValueV || !StartValueV)
5967 return nullptr;
5968
5969 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5970 "PHI node already processed?");
5971
5972 // First, try to find AddRec expression without creating a fictitious symbolic
5973 // value for PN.
5974 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5975 return S;
5976
5977 // Handle PHI node value symbolically.
5978 const SCEV *SymbolicName = getUnknown(PN);
5979 insertValueToMap(PN, SymbolicName);
5980
5981 // Using this symbolic name for the PHI, analyze the value coming around
5982 // the back-edge.
5983 const SCEV *BEValue = getSCEV(BEValueV);
5984
5985 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5986 // has a special value for the first iteration of the loop.
5987
5988 // If the value coming around the backedge is an add with the symbolic
5989 // value we just inserted, then we found a simple induction variable!
5990 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5991 // If there is a single occurrence of the symbolic value, replace it
5992 // with a recurrence.
5993 unsigned FoundIndex = Add->getNumOperands();
// FoundIndex still equals e on the first hit, so the first operand equal to
// SymbolicName is recorded and the scan stops.
5994 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5995 if (Add->getOperand(i) == SymbolicName)
5996 if (FoundIndex == e) {
5997 FoundIndex = i;
5998 break;
5999 }
6000
6001 if (FoundIndex != Add->getNumOperands()) {
6002 // Create an add with everything but the specified operand.
// NOTE(review): the declaration of `Ops` is not visible in this listing.
6004 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6005 if (i != FoundIndex)
6006 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6007 L, *this));
6008 const SCEV *Accum = getAddExpr(Ops);
6009
6010 // This is not a valid addrec if the step amount is varying each
6011 // loop iteration, but is not itself an addrec in this loop.
6012 if (isLoopInvariant(Accum, L) ||
6013 (isa<SCEVAddRecExpr>(Accum) &&
6014 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
// NOTE(review): the declaration of `Flags` is not visible in this listing.
6016
6017 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6018 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6019 if (BO->IsNUW)
6020 Flags = setFlags(Flags, SCEV::FlagNUW);
6021 if (BO->IsNSW)
6022 Flags = setFlags(Flags, SCEV::FlagNSW);
6023 }
6024 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6025 if (GEP->getOperand(0) == PN) {
6026 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6027 // If the increment has any nowrap flags, then we know the address
6028 // space cannot be wrapped around.
6029 if (NW != GEPNoWrapFlags::none())
6030 Flags = setFlags(Flags, SCEV::FlagNW);
6031 // If the GEP is nuw or nusw with non-negative offset, we know that
6032 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6033 // offset is treated as signed, while the base is unsigned.
6034 if (NW.hasNoUnsignedWrap() ||
// NOTE(review): the right-hand operand of the `||` above is not visible in
// this listing; per the comment it presumably covers the "nusw with
// non-negative offset" case -- confirm against the full file.
6036 Flags = setFlags(Flags, SCEV::FlagNUW);
6037 }
6038
6039 // We cannot transfer nuw and nsw flags from subtraction
6040 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6041 // for instance.
6042 }
6043
6044 const SCEV *StartVal = getSCEV(StartValueV);
6045 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6046
6047 // Okay, for the entire analysis of this edge we assumed the PHI
6048 // to be symbolic. We now need to go back and purge all of the
6049 // entries for the scalars that use the symbolic expression.
6050 forgetMemoizedResults({SymbolicName});
6051 insertValueToMap(PN, PHISCEV);
6052
6053 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
// NOTE(review): the line naming the function called with the two arguments
// below is not visible in this listing; it presumably mirrors the
// setNoWrapFlags call in createSimpleAffineAddRec -- confirm.
6055 const_cast<SCEVAddRecExpr *>(AR),
6056 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
6057 }
6058
6059 // We can add Flags to the post-inc expression only if we
6060 // know that it is *undefined behavior* for BEValueV to
6061 // overflow.
6062 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6063 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6064 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6065
6066 return PHISCEV;
6067 }
6068 }
6069 } else {
6070 // Otherwise, this could be a loop like this:
6071 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6072 // In this case, j = {1,+,1} and BEValue is j.
6073 // Because the other in-value of i (0) fits the evolution of BEValue
6074 // i really is an addrec evolution.
6075 //
6076 // We can generalize this saying that i is the shifted value of BEValue
6077 // by one iteration:
6078 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6079
6080 // Do not allow refinement in rewriting of BEValue.
6081 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6082 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6083 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6084 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6085 const SCEV *StartVal = getSCEV(StartValueV);
// The shifted form is only used when its initial value is exactly the phi's
// start value.
6086 if (Start == StartVal) {
6087 // Okay, for the entire analysis of this edge we assumed the PHI
6088 // to be symbolic. We now need to go back and purge all of the
6089 // entries for the scalars that use the symbolic expression.
6090 forgetMemoizedResults({SymbolicName});
6091 insertValueToMap(PN, Shifted);
6092 return Shifted;
6093 }
6094 }
6095 }
6096
6097 // Remove the temporary PHI node SCEV that has been inserted while intending
6098 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6099 // as it will prevent later (possibly simpler) SCEV expressions to be added
6100 // to the ValueExprMap.
6101 eraseValueFromMap(PN);
6102
6103 return nullptr;
6104}
6105
6106// Try to match a control flow sequence that branches out at BI and merges back
6107// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6108// match.
// C is always set to BI's condition, even when the match fails; LHS and RHS
// are only written on success.
// NOTE(review): the first line of the signature (return type, function name
// and the leading parameters DT, BI and Merge used below) is not visible in
// this listing -- confirm against the full file.
6110 Value *&C, Value *&LHS, Value *&RHS) {
6111 C = BI->getCondition();
6112
6113 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6114 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6115
6116 Use &LeftUse = Merge->getOperandUse(0);
6117 Use &RightUse = Merge->getOperandUse(1);
6118
// Successor 0 feeds operand 0 and successor 1 feeds operand 1: operands are
// taken in order.
6119 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6120 LHS = LeftUse;
6121 RHS = RightUse;
6122 return true;
6123 }
6124
// Crossed case: successor 0 feeds operand 1 and vice versa, so the operands
// are swapped to keep LHS tied to successor 0.
6125 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6126 LHS = RightUse;
6127 RHS = LeftUse;
6128 return true;
6129 }
6130
6131 return false;
6132}
6133
// Decide whether the two-input phi PN can be read as "select Cond, LHS, RHS".
// Requires exactly two incoming values, all incoming blocks reachable from
// entry, and a conditional branch terminating the immediate dominator of PN's
// block; the final decision and the output values come from BrPHIToSelect.
// NOTE(review): the first line of the signature (return type, function name
// and the leading parameters DT and PN used below) is not visible in this
// listing -- confirm against the full file.
6135 Value *&Cond, Value *&LHS,
6136 Value *&RHS) {
6137 auto IsReachable =
6138 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6139 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6140 // Try to match
6141 //
6142 // br %cond, label %left, label %right
6143 // left:
6144 // br label %merge
6145 // right:
6146 // br label %merge
6147 // merge:
6148 // V = phi [ %x, %left ], [ %y, %right ]
6149 //
6150 // as "select %cond, %x, %y"
6151
6152 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6153 assert(IDom && "At least the entry block should dominate PN");
6154
6155 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6156 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6157 }
6158 return false;
6159}
6160
// Build an expression for PN by treating it as a select, when
// getOperandsForSelectLikePHI recognises the pattern; returns nullptr
// otherwise.
6161const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6162 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6163 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
// NOTE(review): the remaining operands of this `&&` condition are not visible
// in this listing -- confirm against the full file.
6166 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6167
6168 return nullptr;
6169}
6170
// Returns the first incoming value of PN as a BinaryOperator when every
// incoming value is a BinaryOperator and each later one satisfies
// isIdenticalToWhenDefined against the first; returns nullptr otherwise
// (including when PN has no incoming values).
// NOTE(review): the signature line is not visible in this listing; the only
// visible call site passes a PHINode * and stores the result in a
// BinaryOperator * -- confirm against the full file.
6172 BinaryOperator *CommonInst = nullptr;
6173 // Check if instructions are identical.
6174 for (Value *Incoming : PN->incoming_values()) {
6175 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6176 if (!IncomingInst)
6177 return nullptr;
6178 if (CommonInst) {
6179 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6180 return nullptr; // Not identical, give up
6181 } else {
6182 // Remember binary operator
6183 CommonInst = IncomingInst;
6184 }
6185 }
6186 return CommonInst;
6187}
6188
6189/// Returns SCEV for the first operand of a phi if all phi operands have
6190/// identical opcodes and operands
6191/// eg.
6192/// a: %add = %a + %b
6193/// br %c
6194/// b: %add1 = %a + %b
6195/// br %c
6196/// c: %phi = phi [%add, a], [%add1, b]
6197/// scev(%phi) => scev(%add)
///
/// Returns nullptr when getCommonInstForPHI finds no common instruction or
/// when the incoming values do not all map to the same SCEV.
6198const SCEV *
6199ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6200 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6201 if (!CommonInst)
6202 return nullptr;
6203
6204 // Check if SCEV exprs for instructions are identical.
6205 const SCEV *CommonSCEV = getSCEV(CommonInst);
6206 bool SCEVExprsIdentical =
// NOTE(review): the call that applies the lambda below (callee and range
// argument) is not visible in this listing -- confirm against the full file.
6208 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6209 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6210}
6211
6212const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6213 if (const SCEV *S = createAddRecFromPHI(PN))
6214 return S;
6215
6216 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6217 // phi node for X.
6218 if (Value *V = simplifyInstruction(
6219 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6220 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6221 return getSCEV(V);
6222
6223 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6224 return S;
6225
6226 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6227 return S;
6228
6229 // If it's not a loop phi, we can't handle it yet.
6230 return getUnknown(PN);
6231}
6232
// Search the expression rooted at Root for OperandToFind and report whether it
// was found. The walk is driven by visitAll with the FindClosure visitor
// below, which only allows descending into expressions of kind RootKind, of
// its non-sequential counterpart, or zero-extends.
6233bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6234 SCEVTypes RootKind) {
6235 struct FindClosure {
6236 const SCEV *OperandToFind;
6237 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6238 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6239
6240 bool Found = false;
6241
6242 bool canRecurseInto(SCEVTypes Kind) const {
6243 // We can only recurse into the SCEV expression of the same effective type
6244 // as the type of our root SCEV expression, and into zero-extensions.
6245 return RootKind == Kind || NonSequentialRootKind == Kind ||
6246 scZeroExtend == Kind;
6247 };
6248
6249 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6250 : OperandToFind(OperandToFind), RootKind(RootKind),
6251 NonSequentialRootKind(
// NOTE(review): the callee that maps RootKind to its non-sequential kind is
// not visible in this listing -- confirm against the full file.
6253 RootKind)) {}
6254
// Records whether S is the operand being searched for, and allows descending
// further only while nothing has been found and S is of a permitted kind.
6255 bool follow(const SCEV *S) {
6256 Found = S == OperandToFind;
6257
6258 return !isDone() && canRecurseInto(S->getSCEVType());
6259 }
6260
6261 bool isDone() const { return Found; }
6262 };
6263
6264 FindClosure FC(OperandToFind, RootKind);
6265 visitAll(Root, FC);
6266 return FC.Found;
6267}
6268
// Try to express a select-like value (TrueVal / FalseVal chosen by the icmp
// Cond) as a min/max based SCEV of type Ty. Every path that does not match one
// of the patterns below ends in `return std::nullopt`.
//
// NOTE(review): this listing skips source lines 6293, 6310-6311, 6322,
// 6342-6343 and 6356 (see the gutter numbers). The statements on those lines
// are absent here, which leaves some braces and conditions dangling; restore
// them from the upstream file before compiling.
6269std::optional<const SCEV *>
6270ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6271 ICmpInst *Cond,
6272 Value *TrueVal,
6273 Value *FalseVal) {
6274 // Try to match some simple smax or umax patterns.
6275 auto *ICI = Cond;
6276
6277 Value *LHS = ICI->getOperand(0);
6278 Value *RHS = ICI->getOperand(1);
6279
6280 switch (ICI->getPredicate()) {
// Less-than predicates are handled by swapping the compared operands and
// falling through to the greater-than handling.
6281 case ICmpInst::ICMP_SLT:
6282 case ICmpInst::ICMP_SLE:
6283 case ICmpInst::ICMP_ULT:
6284 case ICmpInst::ICMP_ULE:
6285 std::swap(LHS, RHS);
6286 [[fallthrough]];
6287 case ICmpInst::ICMP_SGT:
6288 case ICmpInst::ICMP_SGE:
6289 case ICmpInst::ICMP_UGT:
6290 case ICmpInst::ICMP_UGE:
6291 // a > b ? a+x : b+x -> max(a, b)+x
6292 // a > b ? b+x : a+x -> min(a, b)+x
// NOTE(review): line 6293 is missing here; presumably it opens the scope that
// is closed at 6334 - confirm.
6294 bool Signed = ICI->isSigned();
// LA/RA are the SCEVs of the selected values, LS/RS of the compared operands.
6295 const SCEV *LA = getSCEV(TrueVal);
6296 const SCEV *RA = getSCEV(FalseVal);
6297 const SCEV *LS = getSCEV(LHS);
6298 const SCEV *RS = getSCEV(RHS);
6299 if (LA->getType()->isPointerTy()) {
6300 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6301 // Need to make sure we can't produce weird expressions involving
6302 // negated pointers.
6303 if (LA == LS && RA == RS)
6304 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6305 if (LA == RS && RA == LS)
6306 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6307 }
// Bring a compared operand to type Ty, sign- or zero-extending according to
// the signedness of the predicate.
6308 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6309 if (Op->getType()->isPointerTy()) {
// NOTE(review): lines 6310-6311 are missing here; the pointer handling that
// precedes this early return is not visible in this listing.
6312 return Op;
6313 }
6314 if (Signed)
6315 Op = getNoopOrSignExtend(Op, Ty);
6316 else
6317 Op = getNoopOrZeroExtend(Op, Ty);
6318 return Op;
6319 };
6320 LS = CoerceOperand(LS);
6321 RS = CoerceOperand(RS);
// NOTE(review): line 6322 is missing here; presumably it holds the condition
// guarding the following `break` - confirm.
6323 break;
// Equal differences (selected value minus compared value) on both hands mean
// the select is max(LS, RS) plus that common difference.
6324 const SCEV *LDiff = getMinusSCEV(LA, LS);
6325 const SCEV *RDiff = getMinusSCEV(RA, RS);
6326 if (LDiff == RDiff)
6327 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6328 LDiff);
// Same test with the hands crossed, which yields the min form.
6329 LDiff = getMinusSCEV(LA, RS);
6330 RDiff = getMinusSCEV(RA, LS);
6331 if (LDiff == RDiff)
6332 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6333 LDiff);
6334 }
6335 break;
6336 case ICmpInst::ICMP_NE:
6337 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6338 std::swap(TrueVal, FalseVal);
6339 [[fallthrough]];
6340 case ICmpInst::ICMP_EQ:
6341 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
// NOTE(review): lines 6342-6343 are missing here; presumably they hold the
// guard that opens the scope closed at 6351 - confirm.
6344 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6345 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6346 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6347 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6348 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6349 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6350 return getAddExpr(getUMaxExpr(X, C), Y);
6351 }
6352 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6353 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6354 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6355 // -> umin_seq(x, umin (..., umin_seq(...), ...))
// NOTE(review): line 6356 is missing here; it holds the first part of the
// condition that continues on 6357.
6357 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6358 const SCEV *X = getSCEV(LHS);
// Look through any zero-extensions wrapped around X.
6359 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6360 X = ZExt->getOperand();
6361 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6362 const SCEV *FalseValExpr = getSCEV(FalseVal);
6363 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6364 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6365 /*Sequential=*/true);
6366 }
6367 }
6368 break;
6369 default:
6370 break;
6371 }
6372
// No pattern matched.
6373 return std::nullopt;
6374}
6375
6376static std::optional<const SCEV *>
6378 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6379 assert(CondExpr->getType()->isIntegerTy(1) &&
6380 TrueExpr->getType() == FalseExpr->getType() &&
6381 TrueExpr->getType()->isIntegerTy(1) &&
6382 "Unexpected operands of a select.");
6383
6384 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6385 // --> C + (umin_seq cond, x - C)
6386 //
6387 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6388 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6389 // --> C + (umin_seq ~cond, x - C)
6390
6391 // FIXME: while we can't legally model the case where both of the hands
6392 // are fully variable, we only require that the *difference* is constant.
6393 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6394 return std::nullopt;
6395
6396 const SCEV *X, *C;
6397 if (isa<SCEVConstant>(TrueExpr)) {
6398 CondExpr = SE->getNotSCEV(CondExpr);
6399 X = FalseExpr;
6400 C = TrueExpr;
6401 } else {
6402 X = TrueExpr;
6403 C = FalseExpr;
6404 }
6405 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6406 /*Sequential=*/true));
6407}
6408
6409static std::optional<const SCEV *>
6411 Value *FalseVal) {
6412 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6413 return std::nullopt;
6414
6415 const auto *SECond = SE->getSCEV(Cond);
6416 const auto *SETrue = SE->getSCEV(TrueVal);
6417 const auto *SEFalse = SE->getSCEV(FalseVal);
6418 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6419}
6420
6421const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6422 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6423 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6424 assert(TrueVal->getType() == FalseVal->getType() &&
6425 V->getType() == TrueVal->getType() &&
6426 "Types of select hands and of the result must match.");
6427
6428 // For now, only deal with i1-typed `select`s.
6429 if (!V->getType()->isIntegerTy(1))
6430 return getUnknown(V);
6431
6432 if (std::optional<const SCEV *> S =
6433 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6434 return *S;
6435
6436 return getUnknown(V);
6437}
6438
6439const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6440 Value *TrueVal,
6441 Value *FalseVal) {
6442 // Handle "constant" branch or select. This can occur for instance when a
6443 // loop pass transforms an inner loop and moves on to process the outer loop.
6444 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6445 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6446
6447 if (auto *I = dyn_cast<Instruction>(V)) {
6448 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6449 if (std::optional<const SCEV *> S =
6450 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6451 TrueVal, FalseVal))
6452 return *S;
6453 }
6454 }
6455
6456 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6457}
6458
6459/// Expand GEP instructions into add and multiply operations. This allows them
6460/// to be analyzed by regular SCEV code.
6461const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6462 assert(GEP->getSourceElementType()->isSized() &&
6463 "GEP source element type must be sized");
6464
6465 SmallVector<SCEVUse, 4> IndexExprs;
6466 for (Value *Index : GEP->indices())
6467 IndexExprs.push_back(getSCEV(Index));
6468 return getGEPExpr(GEP, IndexExprs);
6469}
6470
6471APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6472 const Instruction *CtxI) {
6473 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6474 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6475 return TrailingZeros >= BitWidth
6477 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6478 };
6479 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6480 // The result is GCD of all operands results.
6481 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6482 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6484 Res, getConstantMultiple(N->getOperand(I), CtxI));
6485 return Res;
6486 };
6487
6488 switch (S->getSCEVType()) {
6489 case scConstant:
6490 return cast<SCEVConstant>(S)->getAPInt();
6491 case scPtrToAddr:
6492 case scPtrToInt:
6493 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6494 case scUDivExpr:
6495 case scVScale:
6496 return APInt(BitWidth, 1);
6497 case scTruncate: {
6498 // Only multiples that are a power of 2 will hold after truncation.
6499 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6500 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6501 return GetShiftedByZeros(TZ);
6502 }
6503 case scZeroExtend: {
6504 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6505 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6506 }
6507 case scSignExtend: {
6508 // Only multiples that are a power of 2 will hold after sext.
6509 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6510 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6511 return GetShiftedByZeros(TZ);
6512 }
6513 case scMulExpr: {
6514 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6515 if (M->hasNoUnsignedWrap()) {
6516 // The result is the product of all operand results.
6517 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6518 for (const SCEV *Operand : M->operands().drop_front())
6519 Res = Res * getConstantMultiple(Operand, CtxI);
6520 return Res;
6521 }
6522
6523 // If there are no wrap guarentees, find the trailing zeros, which is the
6524 // sum of trailing zeros for all its operands.
6525 uint32_t TZ = 0;
6526 for (const SCEV *Operand : M->operands())
6527 TZ += getMinTrailingZeros(Operand, CtxI);
6528 return GetShiftedByZeros(TZ);
6529 }
6530 case scAddExpr:
6531 case scAddRecExpr: {
6532 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6533 if (N->hasNoUnsignedWrap())
6534 return GetGCDMultiple(N);
6535 // Find the trailing bits, which is the minimum of its operands.
6536 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6537 for (const SCEV *Operand : N->operands().drop_front())
6538 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6539 return GetShiftedByZeros(TZ);
6540 }
6541 case scUMaxExpr:
6542 case scSMaxExpr:
6543 case scUMinExpr:
6544 case scSMinExpr:
6546 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6547 case scUnknown: {
6548 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6549 // the point their underlying IR instruction has been defined. If CtxI was
6550 // not provided, use:
6551 // * the first instruction in the entry block if it is an argument
6552 // * the instruction itself otherwise.
6553 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6554 if (!CtxI) {
6555 if (isa<Argument>(U->getValue()))
6556 CtxI = &*F.getEntryBlock().begin();
6557 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6558 CtxI = I;
6559 }
6560 unsigned Known =
6561 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6562 .countMinTrailingZeros();
6563 return GetShiftedByZeros(Known);
6564 }
6565 case scCouldNotCompute:
6566 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6567 }
6568 llvm_unreachable("Unknown SCEV kind!");
6569}
6570
6572 const Instruction *CtxI) {
6573 // Skip looking up and updating the cache if there is a context instruction,
6574 // as the result will only be valid in the specified context.
6575 if (CtxI)
6576 return getConstantMultipleImpl(S, CtxI);
6577
6578 auto I = ConstantMultipleCache.find(S);
6579 if (I != ConstantMultipleCache.end())
6580 return I->second;
6581
6582 APInt Result = getConstantMultipleImpl(S, CtxI);
6583 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6584 assert(InsertPair.second && "Should insert a new key");
6585 return InsertPair.first->second;
6586}
6587
6589 APInt Multiple = getConstantMultiple(S);
6590 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6591}
6592
6594 const Instruction *CtxI) {
6595 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6596 (unsigned)getTypeSizeInBits(S->getType()));
6597}
6598
6599/// Helper method to assign a range to V from metadata present in the IR.
6600static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6602 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6603 return getConstantRangeFromMetadata(*MD);
6604 if (const auto *CB = dyn_cast<CallBase>(V))
6605 if (std::optional<ConstantRange> Range = CB->getRange())
6606 return Range;
6607 }
6608 if (auto *A = dyn_cast<Argument>(V))
6609 if (std::optional<ConstantRange> Range = A->getRange())
6610 return Range;
6611
6612 return std::nullopt;
6613}
6614
6616 SCEV::NoWrapFlags Flags) {
6617 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6618 AddRec->setNoWrapFlags(Flags);
6619 UnsignedRanges.erase(AddRec);
6620 SignedRanges.erase(AddRec);
6621 ConstantMultipleCache.erase(AddRec);
6622 }
6623}
6624
6625ConstantRange ScalarEvolution::
6626getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6627 const DataLayout &DL = getDataLayout();
6628
6629 unsigned BitWidth = getTypeSizeInBits(U->getType());
6630 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6631
6632 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6633 // use information about the trip count to improve our available range. Note
6634 // that the trip count independent cases are already handled by known bits.
6635 // WARNING: The definition of recurrence used here is subtly different than
6636 // the one used by AddRec (and thus most of this file). Step is allowed to
6637 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6638 // and other addrecs in the same loop (for non-affine addrecs). The code
6639 // below intentionally handles the case where step is not loop invariant.
6640 auto *P = dyn_cast<PHINode>(U->getValue());
6641 if (!P)
6642 return FullSet;
6643
6644 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6645 // even the values that are not available in these blocks may come from them,
6646 // and this leads to false-positive recurrence test.
6647 for (auto *Pred : predecessors(P->getParent()))
6648 if (!DT.isReachableFromEntry(Pred))
6649 return FullSet;
6650
6651 BinaryOperator *BO;
6652 Value *Start, *Step;
6653 if (!matchSimpleRecurrence(P, BO, Start, Step))
6654 return FullSet;
6655
6656 // If we found a recurrence in reachable code, we must be in a loop. Note
6657 // that BO might be in some subloop of L, and that's completely okay.
6658 auto *L = LI.getLoopFor(P->getParent());
6659 assert(L && L->getHeader() == P->getParent());
6660 if (!L->contains(BO->getParent()))
6661 // NOTE: This bailout should be an assert instead. However, asserting
6662 // the condition here exposes a case where LoopFusion is querying SCEV
6663 // with malformed loop information during the midst of the transform.
6664 // There doesn't appear to be an obvious fix, so for the moment bailout
6665 // until the caller issue can be fixed. PR49566 tracks the bug.
6666 return FullSet;
6667
6668 // TODO: Extend to other opcodes such as mul, and div
6669 switch (BO->getOpcode()) {
6670 default:
6671 return FullSet;
6672 case Instruction::AShr:
6673 case Instruction::LShr:
6674 case Instruction::Shl:
6675 break;
6676 };
6677
6678 if (BO->getOperand(0) != P)
6679 // TODO: Handle the power function forms some day.
6680 return FullSet;
6681
6682 unsigned TC = getSmallConstantMaxTripCount(L);
6683 if (!TC || TC >= BitWidth)
6684 return FullSet;
6685
6686 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6687 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6688 assert(KnownStart.getBitWidth() == BitWidth &&
6689 KnownStep.getBitWidth() == BitWidth);
6690
6691 // Compute total shift amount, being careful of overflow and bitwidths.
6692 auto MaxShiftAmt = KnownStep.getMaxValue();
6693 APInt TCAP(BitWidth, TC-1);
6694 bool Overflow = false;
6695 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6696 if (Overflow)
6697 return FullSet;
6698
6699 switch (BO->getOpcode()) {
6700 default:
6701 llvm_unreachable("filtered out above");
6702 case Instruction::AShr: {
6703 // For each ashr, three cases:
6704 // shift = 0 => unchanged value
6705 // saturation => 0 or -1
6706 // other => a value closer to zero (of the same sign)
6707 // Thus, the end value is closer to zero than the start.
6708 auto KnownEnd = KnownBits::ashr(KnownStart,
6709 KnownBits::makeConstant(TotalShift));
6710 if (KnownStart.isNonNegative())
6711 // Analogous to lshr (simply not yet canonicalized)
6712 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6713 KnownStart.getMaxValue() + 1);
6714 if (KnownStart.isNegative())
6715 // End >=u Start && End <=s Start
6716 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6717 KnownEnd.getMaxValue() + 1);
6718 break;
6719 }
6720 case Instruction::LShr: {
6721 // For each lshr, three cases:
6722 // shift = 0 => unchanged value
6723 // saturation => 0
6724 // other => a smaller positive number
6725 // Thus, the low end of the unsigned range is the last value produced.
6726 auto KnownEnd = KnownBits::lshr(KnownStart,
6727 KnownBits::makeConstant(TotalShift));
6728 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6729 KnownStart.getMaxValue() + 1);
6730 }
6731 case Instruction::Shl: {
6732 // Iff no bits are shifted out, value increases on every shift.
6733 auto KnownEnd = KnownBits::shl(KnownStart,
6734 KnownBits::makeConstant(TotalShift));
6735 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6736 return ConstantRange(KnownStart.getMinValue(),
6737 KnownEnd.getMaxValue() + 1);
6738 break;
6739 }
6740 };
6741 return FullSet;
6742}
6743
6744// The goal of this function is to check if recursively visiting the operands
6745// of this PHI might lead to an infinite loop. If we do see such a loop,
6746// there's no good way to break it, so we avoid analyzing such cases.
6747//
6748// getRangeRef previously used a visited set to avoid infinite loops, but this
6749// caused other issues: the result was dependent on the order of getRangeRef
6750// calls, and the interaction with createSCEVIter could cause a stack overflow
6751// in some cases (see issue #148253).
6752//
6753// FIXME: The way this is implemented is overly conservative; this checks
6754// for a few obviously safe patterns, but anything that doesn't lead to
6755// recursion is fine.
// Returns true when the operands of PHI may be visited by the range
// computation, false otherwise.
//
// NOTE(review): line 6756 (the function signature) and line 6758 (the
// condition guarding the first `return true`) are missing from this listing.
// The callers pass (DT, <PHINode *>); the missing condition presumably uses
// Cond/LHS/RHS as out-parameters - confirm against the upstream file.
6757 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6759 return true;
6760
// Safe pattern: every incoming value dominates the PHI itself.
6761 if (all_of(PHI->operands(),
6762 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6763 return true;
6764
6765 return false;
6766}
6767
6768const ConstantRange &
6769ScalarEvolution::getRangeRefIter(const SCEV *S,
6770 ScalarEvolution::RangeSignHint SignHint) {
6771 DenseMap<const SCEV *, ConstantRange> &Cache =
6772 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6773 : SignedRanges;
6774 SmallVector<SCEVUse> WorkList;
6775 SmallPtrSet<const SCEV *, 8> Seen;
6776
6777 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6778 // SCEVUnknown PHI node.
6779 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6780 if (!Seen.insert(Expr).second)
6781 return;
6782 if (Cache.contains(Expr))
6783 return;
6784 switch (Expr->getSCEVType()) {
6785 case scUnknown:
6786 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6787 break;
6788 [[fallthrough]];
6789 case scConstant:
6790 case scVScale:
6791 case scTruncate:
6792 case scZeroExtend:
6793 case scSignExtend:
6794 case scPtrToAddr:
6795 case scPtrToInt:
6796 case scAddExpr:
6797 case scMulExpr:
6798 case scUDivExpr:
6799 case scAddRecExpr:
6800 case scUMaxExpr:
6801 case scSMaxExpr:
6802 case scUMinExpr:
6803 case scSMinExpr:
6805 WorkList.push_back(Expr);
6806 break;
6807 case scCouldNotCompute:
6808 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6809 }
6810 };
6811 AddToWorklist(S);
6812
6813 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6814 for (unsigned I = 0; I != WorkList.size(); ++I) {
6815 const SCEV *P = WorkList[I];
6816 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6817 // If it is not a `SCEVUnknown`, just recurse into operands.
6818 if (!UnknownS) {
6819 for (const SCEV *Op : P->operands())
6820 AddToWorklist(Op);
6821 continue;
6822 }
6823 // `SCEVUnknown`'s require special treatment.
6824 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6825 if (!RangeRefPHIAllowedOperands(DT, P))
6826 continue;
6827 for (auto &Op : reverse(P->operands()))
6828 AddToWorklist(getSCEV(Op));
6829 }
6830 }
6831
6832 if (!WorkList.empty()) {
6833 // Use getRangeRef to compute ranges for items in the worklist in reverse
6834 // order. This will force ranges for earlier operands to be computed before
6835 // their users in most cases.
6836 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6837 getRangeRef(P, SignHint);
6838 }
6839 }
6840
6841 return getRangeRef(S, SignHint, 0);
6842}
6843
6844/// Determine the range for a particular SCEV. If SignHint is
6845/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6846/// with a "cleaner" unsigned (resp. signed) representation.
///
/// Results are stored through setRange for the given SignHint. Depth is
/// incremented on every recursive call into an operand.
///
/// NOTE(review): this listing skips source lines 6852, 6854, 6857, 6866, 6887,
/// 6989-6990, 6994, 7003, 7029, 7031, 7050 and 7059 (see the gutter numbers).
/// The statements on those lines are absent here; restore them from the
/// upstream file before compiling.
6847const ConstantRange &ScalarEvolution::getRangeRef(
6848 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
// Pick the cache that matches the requested signedness.
6849 DenseMap<const SCEV *, ConstantRange> &Cache =
6850 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6851 : SignedRanges;
// NOTE(review): lines 6852 and 6854 are missing around the next line;
// presumably this is the definition of RangeType used below - confirm.
6853 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6855
6856 // See if we've computed this range already.
// NOTE(review): line 6857 (presumably the cache lookup that defines I) is
// missing here.
6858 if (I != Cache.end())
6859 return I->second;
6860
6861 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6862 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6863
6864 // Switch to iteratively computing the range for S, if it is part of a deeply
6865 // nested expression.
// NOTE(review): line 6866 (the condition guarding this return) is missing.
6867 return getRangeRefIter(S, SignHint);
6868
6869 unsigned BitWidth = getTypeSizeInBits(S->getType());
6870 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6871 using OBO = OverflowingBinaryOperator;
6872
6873 // If the value has known zeros, the maximum value will have those known zeros
6874 // as well.
6875 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
// Unsigned: cap the range at the largest multiple of the known constant
// multiple that fits in BitWidth bits.
6876 APInt Multiple = getNonZeroConstantMultiple(S);
6877 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6878 if (!Remainder.isZero())
6879 ConservativeResult =
6880 ConstantRange(APInt::getMinValue(BitWidth),
6881 APInt::getMaxValue(BitWidth) - Remainder + 1);
6882 }
6883 else {
6884 uint32_t TZ = getMinTrailingZeros(S);
6885 if (TZ != 0) {
6886 ConservativeResult = ConstantRange(
// NOTE(review): line 6887 (the lower bound argument) is missing here.
6888 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6889 }
6890 }
6891
6892 switch (S->getSCEVType()) {
6893 case scConstant:
6894 llvm_unreachable("Already handled above.");
6895 case scVScale:
6896 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
// Casts: transform the operand range and intersect with the conservative
// result computed above.
6897 case scTruncate: {
6898 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6899 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6900 return setRange(
6901 Trunc, SignHint,
6902 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6903 }
6904 case scZeroExtend: {
6905 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6906 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6907 return setRange(
6908 ZExt, SignHint,
6909 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6910 }
6911 case scSignExtend: {
6912 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6913 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6914 return setRange(
6915 SExt, SignHint,
6916 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6917 }
6918 case scPtrToAddr:
6919 case scPtrToInt: {
// The operand's range is reused unchanged for the cast.
6920 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6921 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6922 return setRange(Cast, SignHint, X);
6923 }
6924 case scAddExpr: {
6925 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6926 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6927 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6928 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6929 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6930 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6931 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6932 ConservativeResult =
6933 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6934 }
// Fold the operand ranges left to right, honouring the no-wrap flags of the
// add.
6935 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6936 unsigned WrapType = OBO::AnyWrap;
6937 if (Add->hasNoSignedWrap())
6938 WrapType |= OBO::NoSignedWrap;
6939 if (Add->hasNoUnsignedWrap())
6940 WrapType |= OBO::NoUnsignedWrap;
6941 for (const SCEV *Op : drop_begin(Add->operands()))
6942 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6943 RangeType);
6944 return setRange(Add, SignHint,
6945 ConservativeResult.intersectWith(X, RangeType));
6946 }
6947 case scMulExpr: {
6948 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6949 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6950 for (const SCEV *Op : drop_begin(Mul->operands()))
6951 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6952 return setRange(Mul, SignHint,
6953 ConservativeResult.intersectWith(X, RangeType));
6954 }
6955 case scUDivExpr: {
6956 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6957 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6958 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6959 return setRange(UDiv, SignHint,
6960 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6961 }
6962 case scAddRecExpr: {
6963 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6964 // If there's no unsigned wrap, the value will never be less than its
6965 // initial value.
6966 if (AddRec->hasNoUnsignedWrap()) {
6967 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6968 if (!UnsignedMinValue.isZero())
6969 ConservativeResult = ConservativeResult.intersectWith(
6970 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6971 }
6972
6973 // If there's no signed wrap, and all the operands except initial value have
6974 // the same sign or zero, the value won't ever be:
6975 // 1: smaller than initial value if operands are non negative,
6976 // 2: bigger than initial value if operands are non positive.
6977 // For both cases, value can not cross signed min/max boundary.
6978 if (AddRec->hasNoSignedWrap()) {
6979 bool AllNonNeg = true;
6980 bool AllNonPos = true;
// Operand 0 is the start value, so the scan begins at index 1.
6981 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6982 if (!isKnownNonNegative(AddRec->getOperand(i)))
6983 AllNonNeg = false;
6984 if (!isKnownNonPositive(AddRec->getOperand(i)))
6985 AllNonPos = false;
6986 }
6987 if (AllNonNeg)
6988 ConservativeResult = ConservativeResult.intersectWith(
// NOTE(review): lines 6989-6990 (the range argument) are missing here.
6991 RangeType);
6992 else if (AllNonPos)
6993 ConservativeResult = ConservativeResult.intersectWith(
// NOTE(review): line 6994 (start of the range argument) is missing here.
6995 getSignedRangeMax(AddRec->getStart()) +
6996 1),
6997 RangeType);
6998 }
6999
7000 // TODO: non-affine addrec
7001 if (AddRec->isAffine()) {
7002 const SCEV *MaxBEScev =
// NOTE(review): line 7003 (the initializer of MaxBEScev) is missing here.
7004 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7005 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7006
7007 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7008 // MaxBECount's active bits are all <= AddRec's bit width.
7009 if (MaxBECount.getBitWidth() > BitWidth &&
7010 MaxBECount.getActiveBits() <= BitWidth)
7011 MaxBECount = MaxBECount.trunc(BitWidth);
7012 else if (MaxBECount.getBitWidth() < BitWidth)
7013 MaxBECount = MaxBECount.zext(BitWidth);
7014
// Only use the count if it now has the same width as the AddRec.
7015 if (MaxBECount.getBitWidth() == BitWidth) {
7016 auto RangeFromAffine = getRangeForAffineAR(
7017 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7018 ConservativeResult =
7019 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7020
7021 auto RangeFromFactoring = getRangeViaFactoring(
7022 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7023 ConservativeResult =
7024 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7025 }
7026 }
7027
7028 // Now try symbolic BE count and more powerful methods.
// NOTE(review): lines 7029 and 7031 are missing around the next line;
// presumably a guard that opens a scope and the initializer of
// SymbolicMaxBECount - confirm.
7030 const SCEV *SymbolicMaxBECount =
7032 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7033 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7034 AddRec->hasNoSelfWrap()) {
7035 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7036 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7037 ConservativeResult =
7038 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7039 }
7040 }
7041 }
7042
7043 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7044 }
7045 case scUMaxExpr:
7046 case scSMaxExpr:
7047 case scUMinExpr:
7048 case scSMinExpr:
7049 case scSequentialUMinExpr: {
// Map the SCEV kind to the matching min/max intrinsic ID, then fold the
// operand ranges through ConstantRange::intrinsic.
// NOTE(review): line 7050 (presumably the declaration of ID) is missing here.
7051 switch (S->getSCEVType()) {
7052 case scUMaxExpr:
7053 ID = Intrinsic::umax;
7054 break;
7055 case scSMaxExpr:
7056 ID = Intrinsic::smax;
7057 break;
7058 case scUMinExpr:
// NOTE(review): line 7059 is missing here; presumably another case label that
// shares the umin mapping - confirm.
7060 ID = Intrinsic::umin;
7061 break;
7062 case scSMinExpr:
7063 ID = Intrinsic::smin;
7064 break;
7065 default:
7066 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7067 }
7068
7069 const auto *NAry = cast<SCEVNAryExpr>(S);
7070 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7071 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7072 X = X.intrinsic(
7073 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7074 return setRange(S, SignHint,
7075 ConservativeResult.intersectWith(X, RangeType));
7076 }
7077 case scUnknown: {
7078 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7079 Value *V = U->getValue();
7080
7081 // Check if the IR explicitly contains !range metadata.
7082 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7083 if (MDRange)
7084 ConservativeResult =
7085 ConservativeResult.intersectWith(*MDRange, RangeType);
7086
7087 // Use facts about recurrences in the underlying IR. Note that add
7088 // recurrences are AddRecExprs and thus don't hit this path. This
7089 // primarily handles shift recurrences.
7090 auto CR = getRangeForUnknownRecurrence(U);
// NOTE(review): unlike the other intersections in this function, this one
// does not pass RangeType - confirm whether that is intentional.
7091 ConservativeResult = ConservativeResult.intersectWith(CR);
7092
7093 // See if ValueTracking can give us a useful range.
7094 const DataLayout &DL = getDataLayout();
7095 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7096 if (Known.getBitWidth() != BitWidth)
7097 Known = Known.zextOrTrunc(BitWidth);
7098
7099 // ValueTracking may be able to compute a tighter result for the number of
7100 // sign bits than for the value of those sign bits.
7101 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7102 if (U->getType()->isPointerTy()) {
7103 // If the pointer size is larger than the index size type, this can cause
7104 // NS to be larger than BitWidth. So compensate for this.
7105 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7106 int ptrIdxDiff = ptrSize - BitWidth;
7107 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7108 NS -= ptrIdxDiff;
7109 }
7110
7111 if (NS > 1) {
7112 // If we know any of the sign bits, we know all of the sign bits.
7113 if (!Known.Zero.getHiBits(NS).isZero())
7114 Known.Zero.setHighBits(NS);
7115 if (!Known.One.getHiBits(NS).isZero())
7116 Known.One.setHighBits(NS);
7117 }
7118
// Skip the known-bits range when min and max + 1 coincide, i.e. when the
// known bits put no bound on the value.
7119 if (Known.getMinValue() != Known.getMaxValue() + 1)
7120 ConservativeResult = ConservativeResult.intersectWith(
7121 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7122 RangeType);
// NS known sign bits bound the value to the signed min/max shifted right by
// NS - 1.
7123 if (NS > 1)
7124 ConservativeResult = ConservativeResult.intersectWith(
7125 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7126 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7127 RangeType);
7128
7129 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7130 // Strengthen the range if the underlying IR value is a
7131 // global/alloca/heap allocation using the size of the object.
7132 bool CanBeNull, CanBeFreed;
7133 uint64_t DerefBytes =
7134 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7135 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7136 // The highest address the object can start is DerefBytes bytes before
7137 // the end (unsigned max value). If this value is not a multiple of the
7138 // alignment, the last possible start value is the next lowest multiple
7139 // of the alignment. Note: The computations below cannot overflow,
7140 // because if they would there's no possible start address for the
7141 // object.
7142 APInt MaxVal =
7143 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7144 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7145 uint64_t Rem = MaxVal.urem(Align);
7146 MaxVal -= APInt(BitWidth, Rem);
// A pointer known to be non-zero cannot start below its alignment.
7147 APInt MinVal = APInt::getZero(BitWidth);
7148 if (llvm::isKnownNonZero(V, DL))
7149 MinVal = Align;
7150 ConservativeResult = ConservativeResult.intersectWith(
7151 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7152 }
7153 }
7154
7155 // A range of Phi is a subset of union of all ranges of its input.
7156 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7157 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7158 // AddRecs; return the range for the corresponding AddRec.
7159 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7160 return getRangeRef(AR, SignHint, Depth + 1);
7161
7162 // Make sure that we do not run over cycled Phis.
7163 if (RangeRefPHIAllowedOperands(DT, Phi)) {
// Start from the empty set and union in the range of each incoming value.
7164 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7165
7166 for (const auto &Op : Phi->operands()) {
7167 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7168 RangeFromOps = RangeFromOps.unionWith(OpRange);
7169 // No point to continue if we already have a full set.
7170 if (RangeFromOps.isFullSet())
7171 break;
7172 }
7173 ConservativeResult =
7174 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7175 }
7176 }
7177
7178 // vscale can't be equal to zero
7179 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7180 if (II->getIntrinsicID() == Intrinsic::vscale) {
7181 ConstantRange Disallowed = APInt::getZero(BitWidth);
7182 ConservativeResult = ConservativeResult.difference(Disallowed);
7183 }
7184
7185 return setRange(U, SignHint, std::move(ConservativeResult));
7186 }
7187 case scCouldNotCompute:
7188 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7189 }
7190
7191 return setRange(S, SignHint, std::move(ConservativeResult));
7192}
7193
7194// Given a StartRange, Step and MaxBECount for an expression compute a range of
7195// values that the expression can take. Initially, the expression has a value
7196// from StartRange and then is changed by Step up to MaxBECount times. Signed
7197// argument defines if we treat Step as signed or unsigned.
7199 const ConstantRange &StartRange,
7200 const APInt &MaxBECount,
7201 bool Signed) {
7202 unsigned BitWidth = Step.getBitWidth();
7203 assert(BitWidth == StartRange.getBitWidth() &&
7204 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7205 // If either Step or MaxBECount is 0, then the expression won't change, and we
7206 // just need to return the initial range.
7207 if (Step == 0 || MaxBECount == 0)
7208 return StartRange;
7209
7210 // If we don't know anything about the initial value (i.e. StartRange is
7211 // FullRange), then we don't know anything about the final range either.
7212 // Return FullRange.
7213 if (StartRange.isFullSet())
7214 return ConstantRange::getFull(BitWidth);
7215
7216 // If Step is signed and negative, then we use its absolute value, but we also
7217 // note that we're moving in the opposite direction.
7218 bool Descending = Signed && Step.isNegative();
7219
7220 if (Signed)
7221 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7222 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7223 // This equations hold true due to the well-defined wrap-around behavior of
7224 // APInt.
7225 Step = Step.abs();
7226
7227 // Check if Offset is more than full span of BitWidth. If it is, the
7228 // expression is guaranteed to overflow.
7229 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7230 return ConstantRange::getFull(BitWidth);
7231
7232 // Offset is by how much the expression can change. Checks above guarantee no
7233 // overflow here.
7234 APInt Offset = Step * MaxBECount;
7235
7236 // Minimum value of the final range will match the minimal value of StartRange
7237 // if the expression is increasing and will be decreased by Offset otherwise.
7238 // Maximum value of the final range will match the maximal value of StartRange
7239 // if the expression is decreasing and will be increased by Offset otherwise.
7240 APInt StartLower = StartRange.getLower();
7241 APInt StartUpper = StartRange.getUpper() - 1;
7242 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7243 : (StartUpper + std::move(Offset));
7244
7245 // It's possible that the new minimum/maximum value will fall into the initial
7246 // range (due to wrap around). This means that the expression can take any
7247 // value in this bitwidth, and we have to return full range.
7248 if (StartRange.contains(MovedBoundary))
7249 return ConstantRange::getFull(BitWidth);
7250
7251 APInt NewLower =
7252 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7253 APInt NewUpper =
7254 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7255 NewUpper += 1;
7256
7257 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7258 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7259}
7260
7261ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7262 const SCEV *Step,
7263 const APInt &MaxBECount) {
7264 assert(getTypeSizeInBits(Start->getType()) ==
7265 getTypeSizeInBits(Step->getType()) &&
7266 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7267 "mismatched bit widths");
7268
7269 // First, consider step signed.
7270 ConstantRange StartSRange = getSignedRange(Start);
7271 ConstantRange StepSRange = getSignedRange(Step);
7272
7273 // If Step can be both positive and negative, we need to find ranges for the
7274 // maximum absolute step values in both directions and union them.
7275 ConstantRange SR = getRangeForAffineARHelper(
7276 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7278 StartSRange, MaxBECount,
7279 /* Signed = */ true));
7280
7281 // Next, consider step unsigned.
7282 ConstantRange UR = getRangeForAffineARHelper(
7283 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7284 /* Signed = */ false);
7285
7286 // Finally, intersect signed and unsigned ranges.
7288}
7289
7290ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7291 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7292 ScalarEvolution::RangeSignHint SignHint) {
7293 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7294 assert(AddRec->hasNoSelfWrap() &&
7295 "This only works for non-self-wrapping AddRecs!");
7296 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7297 const SCEV *Step = AddRec->getStepRecurrence(*this);
7298 // Only deal with constant step to save compile time.
7299 if (!isa<SCEVConstant>(Step))
7300 return ConstantRange::getFull(BitWidth);
7301 // Let's make sure that we can prove that we do not self-wrap during
7302 // MaxBECount iterations. We need this because MaxBECount is a maximum
7303 // iteration count estimate, and we might infer nw from some exit for which we
7304 // do not know max exit count (or any other side reasoning).
7305 // TODO: Turn into assert at some point.
7306 if (getTypeSizeInBits(MaxBECount->getType()) >
7307 getTypeSizeInBits(AddRec->getType()))
7308 return ConstantRange::getFull(BitWidth);
7309 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7310 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7311 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7312 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7313 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7314 MaxItersWithoutWrap))
7315 return ConstantRange::getFull(BitWidth);
7316
7317 ICmpInst::Predicate LEPred =
7319 ICmpInst::Predicate GEPred =
7321 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7322
7323 // We know that there is no self-wrap. Let's take Start and End values and
7324 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7325 // the iteration. They either lie inside the range [Min(Start, End),
7326 // Max(Start, End)] or outside it:
7327 //
7328 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7329 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7330 //
7331 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7332 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7333 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7334 // Start <= End and step is positive, or Start >= End and step is negative.
7335 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7336 ConstantRange StartRange = getRangeRef(Start, SignHint);
7337 ConstantRange EndRange = getRangeRef(End, SignHint);
7338 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7339 // If they already cover full iteration space, we will know nothing useful
7340 // even if we prove what we want to prove.
7341 if (RangeBetween.isFullSet())
7342 return RangeBetween;
7343 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7344 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7345 : RangeBetween.isWrappedSet();
7346 if (IsWrappedSet)
7347 return ConstantRange::getFull(BitWidth);
7348
7349 if (isKnownPositive(Step) &&
7350 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7351 return RangeBetween;
7352 if (isKnownNegative(Step) &&
7353 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7354 return RangeBetween;
7355 return ConstantRange::getFull(BitWidth);
7356}
7357
7358ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7359 const SCEV *Step,
7360 const APInt &MaxBECount) {
7361 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7362 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7363
7364 unsigned BitWidth = MaxBECount.getBitWidth();
7365 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7366 getTypeSizeInBits(Step->getType()) == BitWidth &&
7367 "mismatched bit widths");
7368
7369 struct SelectPattern {
7370 Value *Condition = nullptr;
7371 APInt TrueValue;
7372 APInt FalseValue;
7373
7374 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7375 const SCEV *S) {
7376 std::optional<unsigned> CastOp;
7377 APInt Offset(BitWidth, 0);
7378
7380 "Should be!");
7381
7382 // Peel off a constant offset. In the future we could consider being
7383 // smarter here and handle {Start+Step,+,Step} too.
7384 const APInt *Off;
7385 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7386 Offset = *Off;
7387
7388 // Peel off a cast operation
7389 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7390 CastOp = SCast->getSCEVType();
7391 S = SCast->getOperand();
7392 }
7393
7394 using namespace llvm::PatternMatch;
7395
7396 auto *SU = dyn_cast<SCEVUnknown>(S);
7397 const APInt *TrueVal, *FalseVal;
7398 if (!SU ||
7399 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7400 m_APInt(FalseVal)))) {
7401 Condition = nullptr;
7402 return;
7403 }
7404
7405 TrueValue = *TrueVal;
7406 FalseValue = *FalseVal;
7407
7408 // Re-apply the cast we peeled off earlier
7409 if (CastOp)
7410 switch (*CastOp) {
7411 default:
7412 llvm_unreachable("Unknown SCEV cast type!");
7413
7414 case scTruncate:
7415 TrueValue = TrueValue.trunc(BitWidth);
7416 FalseValue = FalseValue.trunc(BitWidth);
7417 break;
7418 case scZeroExtend:
7419 TrueValue = TrueValue.zext(BitWidth);
7420 FalseValue = FalseValue.zext(BitWidth);
7421 break;
7422 case scSignExtend:
7423 TrueValue = TrueValue.sext(BitWidth);
7424 FalseValue = FalseValue.sext(BitWidth);
7425 break;
7426 }
7427
7428 // Re-apply the constant offset we peeled off earlier
7429 TrueValue += Offset;
7430 FalseValue += Offset;
7431 }
7432
7433 bool isRecognized() { return Condition != nullptr; }
7434 };
7435
7436 SelectPattern StartPattern(*this, BitWidth, Start);
7437 if (!StartPattern.isRecognized())
7438 return ConstantRange::getFull(BitWidth);
7439
7440 SelectPattern StepPattern(*this, BitWidth, Step);
7441 if (!StepPattern.isRecognized())
7442 return ConstantRange::getFull(BitWidth);
7443
7444 if (StartPattern.Condition != StepPattern.Condition) {
7445 // We don't handle this case today; but we could, by considering four
7446 // possibilities below instead of two. I'm not sure if there are cases where
7447 // that will help over what getRange already does, though.
7448 return ConstantRange::getFull(BitWidth);
7449 }
7450
7451 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7452 // construct arbitrary general SCEV expressions here. This function is called
7453 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7454 // say) can end up caching a suboptimal value.
7455
7456 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7457 // C2352 and C2512 (otherwise it isn't needed).
7458
7459 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7460 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7461 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7462 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7463
7464 ConstantRange TrueRange =
7465 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7466 ConstantRange FalseRange =
7467 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7468
7469 return TrueRange.unionWith(FalseRange);
7470}
7471
7472SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7473 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7474 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7475
7476 // Return early if there are no flags to propagate to the SCEV.
7478 if (BinOp->hasNoUnsignedWrap())
7480 if (BinOp->hasNoSignedWrap())
7482 if (Flags == SCEV::FlagAnyWrap)
7483 return SCEV::FlagAnyWrap;
7484
7485 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7486}
7487
7488const Instruction *
7489ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7490 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7491 return &*AddRec->getLoop()->getHeader()->begin();
7492 if (auto *U = dyn_cast<SCEVUnknown>(S))
7493 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7494 return I;
7495 return nullptr;
7496}
7497
7498const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7499 bool &Precise) {
7500 Precise = true;
7501 // Do a bounded search of the def relation of the requested SCEVs.
7502 SmallPtrSet<const SCEV *, 16> Visited;
7503 SmallVector<SCEVUse> Worklist;
7504 auto pushOp = [&](const SCEV *S) {
7505 if (!Visited.insert(S).second)
7506 return;
7507 // Threshold of 30 here is arbitrary.
7508 if (Visited.size() > 30) {
7509 Precise = false;
7510 return;
7511 }
7512 Worklist.push_back(S);
7513 };
7514
7515 for (SCEVUse S : Ops)
7516 pushOp(S);
7517
7518 const Instruction *Bound = nullptr;
7519 while (!Worklist.empty()) {
7520 SCEVUse S = Worklist.pop_back_val();
7521 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7522 if (!Bound || DT.dominates(Bound, DefI))
7523 Bound = DefI;
7524 } else {
7525 for (SCEVUse Op : S->operands())
7526 pushOp(Op);
7527 }
7528 }
7529 return Bound ? Bound : &*F.getEntryBlock().begin();
7530}
7531
7532const Instruction *
7533ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7534 bool Discard;
7535 return getDefiningScopeBound(Ops, Discard);
7536}
7537
7538bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7539 const Instruction *B) {
7540 if (A->getParent() == B->getParent() &&
7542 B->getIterator()))
7543 return true;
7544
7545 auto *BLoop = LI.getLoopFor(B->getParent());
7546 if (BLoop && BLoop->getHeader() == B->getParent() &&
7547 BLoop->getLoopPreheader() == A->getParent() &&
7549 A->getParent()->end()) &&
7550 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7551 B->getIterator()))
7552 return true;
7553 return false;
7554}
7555
7556bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7557 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7558 visitAll(Op, PC);
7559 return PC.MaybePoison.empty();
7560}
7561
7562bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7563 return !SCEVExprContains(Op, [this](const SCEV *S) {
7564 const SCEV *Op1;
7565 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7566 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7567 // is a non-zero constant, we have to assume the UDiv may be UB.
7568 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7569 });
7570}
7571
7572bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7573 // Only proceed if we can prove that I does not yield poison.
7575 return false;
7576
7577 // At this point we know that if I is executed, then it does not wrap
7578 // according to at least one of NSW or NUW. If I is not executed, then we do
7579 // not know if the calculation that I represents would wrap. Multiple
7580 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7581 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7582 // derived from other instructions that map to the same SCEV. We cannot make
7583 // that guarantee for cases where I is not executed. So we need to find a
7584 // upper bound on the defining scope for the SCEV, and prove that I is
7585 // executed every time we enter that scope. When the bounding scope is a
7586 // loop (the common case), this is equivalent to proving I executes on every
7587 // iteration of that loop.
7588 SmallVector<SCEVUse> SCEVOps;
7589 for (const Use &Op : I->operands()) {
7590 // I could be an extractvalue from a call to an overflow intrinsic.
7591 // TODO: We can do better here in some cases.
7592 if (isSCEVable(Op->getType()))
7593 SCEVOps.push_back(getSCEV(Op));
7594 }
7595 auto *DefI = getDefiningScopeBound(SCEVOps);
7596 return isGuaranteedToTransferExecutionTo(DefI, I);
7597}
7598
7599bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7600 // If we know that \c I can never be poison period, then that's enough.
7601 if (isSCEVExprNeverPoison(I))
7602 return true;
7603
7604 // If the loop only has one exit, then we know that, if the loop is entered,
7605 // any instruction dominating that exit will be executed. If any such
7606 // instruction would result in UB, the addrec cannot be poison.
7607 //
7608 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7609 // also handles uses outside the loop header (they just need to dominate the
7610 // single exit).
7611
7612 auto *ExitingBB = L->getExitingBlock();
7613 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7614 return false;
7615
7616 SmallPtrSet<const Value *, 16> KnownPoison;
7618
7619 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7620 // things that are known to be poison under that assumption go on the
7621 // Worklist.
7622 KnownPoison.insert(I);
7623 Worklist.push_back(I);
7624
7625 while (!Worklist.empty()) {
7626 const Instruction *Poison = Worklist.pop_back_val();
7627
7628 for (const Use &U : Poison->uses()) {
7629 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7630 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7631 DT.dominates(PoisonUser->getParent(), ExitingBB))
7632 return true;
7633
7634 if (propagatesPoison(U) && L->contains(PoisonUser))
7635 if (KnownPoison.insert(PoisonUser).second)
7636 Worklist.push_back(PoisonUser);
7637 }
7638 }
7639
7640 return false;
7641}
7642
7643ScalarEvolution::LoopProperties
7644ScalarEvolution::getLoopProperties(const Loop *L) {
7645 using LoopProperties = ScalarEvolution::LoopProperties;
7646
7647 auto Itr = LoopPropertiesCache.find(L);
7648 if (Itr == LoopPropertiesCache.end()) {
7649 auto HasSideEffects = [](Instruction *I) {
7650 if (auto *SI = dyn_cast<StoreInst>(I))
7651 return !SI->isSimple();
7652
7653 if (I->mayThrow())
7654 return true;
7655
7656 // Non-volatile memset / memcpy do not count as side-effect for forward
7657 // progress.
7658 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7659 return false;
7660
7661 return I->mayWriteToMemory();
7662 };
7663
7664 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7665 /*HasNoSideEffects*/ true};
7666
7667 for (auto *BB : L->getBlocks())
7668 for (auto &I : *BB) {
7670 LP.HasNoAbnormalExits = false;
7671 if (HasSideEffects(&I))
7672 LP.HasNoSideEffects = false;
7673 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7674 break; // We're already as pessimistic as we can get.
7675 }
7676
7677 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7678 assert(InsertPair.second && "We just checked!");
7679 Itr = InsertPair.first;
7680 }
7681
7682 return Itr->second;
7683}
7684
7686 // A mustprogress loop without side effects must be finite.
7687 // TODO: The check used here is very conservative. It's only *specific*
7688 // side effects which are well defined in infinite loops.
7689 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7690}
7691
7692const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7693 // Worklist item with a Value and a bool indicating whether all operands have
7694 // been visited already.
7697
7698 Stack.emplace_back(V, true);
7699 Stack.emplace_back(V, false);
7700 while (!Stack.empty()) {
7701 auto E = Stack.pop_back_val();
7702 Value *CurV = E.getPointer();
7703
7704 if (getExistingSCEV(CurV))
7705 continue;
7706
7708 const SCEV *CreatedSCEV = nullptr;
7709 // If all operands have been visited already, create the SCEV.
7710 if (E.getInt()) {
7711 CreatedSCEV = createSCEV(CurV);
7712 } else {
7713 // Otherwise get the operands we need to create SCEV's for before creating
7714 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7715 // just use it.
7716 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7717 }
7718
7719 if (CreatedSCEV) {
7720 insertValueToMap(CurV, CreatedSCEV);
7721 } else {
7722 // Queue CurV for SCEV creation, followed by its's operands which need to
7723 // be constructed first.
7724 Stack.emplace_back(CurV, true);
7725 for (Value *Op : Ops)
7726 Stack.emplace_back(Op, false);
7727 }
7728 }
7729
7730 return getExistingSCEV(V);
7731}
7732
7733const SCEV *
7734ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7735 if (!isSCEVable(V->getType()))
7736 return getUnknown(V);
7737
7738 if (Instruction *I = dyn_cast<Instruction>(V)) {
7739 // Don't attempt to analyze instructions in blocks that aren't
7740 // reachable. Such instructions don't matter, and they aren't required
7741 // to obey basic rules for definitions dominating uses which this
7742 // analysis depends on.
7743 if (!DT.isReachableFromEntry(I->getParent()))
7744 return getUnknown(PoisonValue::get(V->getType()));
7745 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7746 return getConstant(CI);
7747 else if (isa<GlobalAlias>(V))
7748 return getUnknown(V);
7749 else if (!isa<ConstantExpr>(V))
7750 return getUnknown(V);
7751
7753 if (auto BO =
7755 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7756 switch (BO->Opcode) {
7757 case Instruction::Add:
7758 case Instruction::Mul: {
7759 // For additions and multiplications, traverse add/mul chains for which we
7760 // can potentially create a single SCEV, to reduce the number of
7761 // get{Add,Mul}Expr calls.
7762 do {
7763 if (BO->Op) {
7764 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7765 Ops.push_back(BO->Op);
7766 break;
7767 }
7768 }
7769 Ops.push_back(BO->RHS);
7770 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7772 if (!NewBO ||
7773 (BO->Opcode == Instruction::Add &&
7774 (NewBO->Opcode != Instruction::Add &&
7775 NewBO->Opcode != Instruction::Sub)) ||
7776 (BO->Opcode == Instruction::Mul &&
7777 NewBO->Opcode != Instruction::Mul)) {
7778 Ops.push_back(BO->LHS);
7779 break;
7780 }
7781 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7782 // requires a SCEV for the LHS.
7783 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7784 auto *I = dyn_cast<Instruction>(BO->Op);
7785 if (I && programUndefinedIfPoison(I)) {
7786 Ops.push_back(BO->LHS);
7787 break;
7788 }
7789 }
7790 BO = NewBO;
7791 } while (true);
7792 return nullptr;
7793 }
7794 case Instruction::Sub:
7795 case Instruction::UDiv:
7796 case Instruction::URem:
7797 break;
7798 case Instruction::AShr:
7799 case Instruction::Shl:
7800 case Instruction::Xor:
7801 if (!IsConstArg)
7802 return nullptr;
7803 break;
7804 case Instruction::And:
7805 case Instruction::Or:
7806 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7807 return nullptr;
7808 break;
7809 case Instruction::LShr:
7810 return getUnknown(V);
7811 default:
7812 llvm_unreachable("Unhandled binop");
7813 break;
7814 }
7815
7816 Ops.push_back(BO->LHS);
7817 Ops.push_back(BO->RHS);
7818 return nullptr;
7819 }
7820
7821 switch (U->getOpcode()) {
7822 case Instruction::Trunc:
7823 case Instruction::ZExt:
7824 case Instruction::SExt:
7825 case Instruction::PtrToAddr:
7826 case Instruction::PtrToInt:
7827 Ops.push_back(U->getOperand(0));
7828 return nullptr;
7829
7830 case Instruction::BitCast:
7831 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7832 Ops.push_back(U->getOperand(0));
7833 return nullptr;
7834 }
7835 return getUnknown(V);
7836
7837 case Instruction::SDiv:
7838 case Instruction::SRem:
7839 Ops.push_back(U->getOperand(0));
7840 Ops.push_back(U->getOperand(1));
7841 return nullptr;
7842
7843 case Instruction::GetElementPtr:
7844 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7845 "GEP source element type must be sized");
7846 llvm::append_range(Ops, U->operands());
7847 return nullptr;
7848
7849 case Instruction::IntToPtr:
7850 return getUnknown(V);
7851
7852 case Instruction::PHI:
7853 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7854 // relevant nodes for each of them.
7855 //
7856 // The first is just to call simplifyInstruction, and get something back
7857 // that isn't a PHI.
7858 if (Value *V = simplifyInstruction(
7859 cast<PHINode>(U),
7860 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7861 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7862 assert(V);
7863 Ops.push_back(V);
7864 return nullptr;
7865 }
7866 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7867 // operands which all perform the same operation, but haven't been
7868 // CSE'ed for whatever reason.
7869 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7870 assert(BO);
7871 Ops.push_back(BO);
7872 return nullptr;
7873 }
7874 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7875 // is equivalent to a select, and analyzes it like a select.
7876 {
7877 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7879 assert(Cond);
7880 assert(LHS);
7881 assert(RHS);
7882 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7883 Ops.push_back(CondICmp->getOperand(0));
7884 Ops.push_back(CondICmp->getOperand(1));
7885 }
7886 Ops.push_back(Cond);
7887 Ops.push_back(LHS);
7888 Ops.push_back(RHS);
7889 return nullptr;
7890 }
7891 }
7892 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7893 // so just construct it recursively.
7894 //
7895 // In addition to getNodeForPHI, also construct nodes which might be needed
7896 // by getRangeRef.
7898 for (Value *V : cast<PHINode>(U)->operands())
7899 Ops.push_back(V);
7900 return nullptr;
7901 }
7902 return nullptr;
7903
7904 case Instruction::Select: {
7905 // Check if U is a select that can be simplified to a SCEVUnknown.
7906 auto CanSimplifyToUnknown = [this, U]() {
7907 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7908 return false;
7909
7910 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7911 if (!ICI)
7912 return false;
7913 Value *LHS = ICI->getOperand(0);
7914 Value *RHS = ICI->getOperand(1);
7915 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7916 ICI->getPredicate() == CmpInst::ICMP_NE) {
7918 return true;
7919 } else if (getTypeSizeInBits(LHS->getType()) >
7920 getTypeSizeInBits(U->getType()))
7921 return true;
7922 return false;
7923 };
7924 if (CanSimplifyToUnknown())
7925 return getUnknown(U);
7926
7927 llvm::append_range(Ops, U->operands());
7928 return nullptr;
7929 break;
7930 }
7931 case Instruction::Call:
7932 case Instruction::Invoke:
7933 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7934 Ops.push_back(RV);
7935 return nullptr;
7936 }
7937
7938 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7939 switch (II->getIntrinsicID()) {
7940 case Intrinsic::abs:
7941 Ops.push_back(II->getArgOperand(0));
7942 return nullptr;
7943 case Intrinsic::umax:
7944 case Intrinsic::umin:
7945 case Intrinsic::smax:
7946 case Intrinsic::smin:
7947 case Intrinsic::usub_sat:
7948 case Intrinsic::uadd_sat:
7949 Ops.push_back(II->getArgOperand(0));
7950 Ops.push_back(II->getArgOperand(1));
7951 return nullptr;
7952 case Intrinsic::start_loop_iterations:
7953 case Intrinsic::annotation:
7954 case Intrinsic::ptr_annotation:
7955 Ops.push_back(II->getArgOperand(0));
7956 return nullptr;
7957 default:
7958 break;
7959 }
7960 }
7961 break;
7962 }
7963
7964 return nullptr;
7965}
7966
7967const SCEV *ScalarEvolution::createSCEV(Value *V) {
7968 if (!isSCEVable(V->getType()))
7969 return getUnknown(V);
7970
7971 if (Instruction *I = dyn_cast<Instruction>(V)) {
7972 // Don't attempt to analyze instructions in blocks that aren't
7973 // reachable. Such instructions don't matter, and they aren't required
7974 // to obey basic rules for definitions dominating uses which this
7975 // analysis depends on.
7976 if (!DT.isReachableFromEntry(I->getParent()))
7977 return getUnknown(PoisonValue::get(V->getType()));
7978 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7979 return getConstant(CI);
7980 else if (isa<GlobalAlias>(V))
7981 return getUnknown(V);
7982 else if (!isa<ConstantExpr>(V))
7983 return getUnknown(V);
7984
7985 const SCEV *LHS;
7986 const SCEV *RHS;
7987
7989 if (auto BO =
7991 switch (BO->Opcode) {
7992 case Instruction::Add: {
7993 // The simple thing to do would be to just call getSCEV on both operands
7994 // and call getAddExpr with the result. However if we're looking at a
7995 // bunch of things all added together, this can be quite inefficient,
7996 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7997 // Instead, gather up all the operands and make a single getAddExpr call.
7998 // LLVM IR canonical form means we need only traverse the left operands.
8000 do {
8001 if (BO->Op) {
8002 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8003 AddOps.push_back(OpSCEV);
8004 break;
8005 }
8006
8007 // If a NUW or NSW flag can be applied to the SCEV for this
8008 // addition, then compute the SCEV for this addition by itself
8009 // with a separate call to getAddExpr. We need to do that
8010 // instead of pushing the operands of the addition onto AddOps,
8011 // since the flags are only known to apply to this particular
8012 // addition - they may not apply to other additions that can be
8013 // formed with operands from AddOps.
8014 const SCEV *RHS = getSCEV(BO->RHS);
8015 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8016 if (Flags != SCEV::FlagAnyWrap) {
8017 const SCEV *LHS = getSCEV(BO->LHS);
8018 if (BO->Opcode == Instruction::Sub)
8019 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8020 else
8021 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8022 break;
8023 }
8024 }
8025
8026 if (BO->Opcode == Instruction::Sub)
8027 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8028 else
8029 AddOps.push_back(getSCEV(BO->RHS));
8030
8031 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8033 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8034 NewBO->Opcode != Instruction::Sub)) {
8035 AddOps.push_back(getSCEV(BO->LHS));
8036 break;
8037 }
8038 BO = NewBO;
8039 } while (true);
8040
8041 return getAddExpr(AddOps);
8042 }
8043
8044 case Instruction::Mul: {
8046 do {
8047 if (BO->Op) {
8048 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8049 MulOps.push_back(OpSCEV);
8050 break;
8051 }
8052
8053 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8054 if (Flags != SCEV::FlagAnyWrap) {
8055 LHS = getSCEV(BO->LHS);
8056 RHS = getSCEV(BO->RHS);
8057 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8058 break;
8059 }
8060 }
8061
8062 MulOps.push_back(getSCEV(BO->RHS));
8063 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8065 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8066 MulOps.push_back(getSCEV(BO->LHS));
8067 break;
8068 }
8069 BO = NewBO;
8070 } while (true);
8071
8072 return getMulExpr(MulOps);
8073 }
8074 case Instruction::UDiv:
8075 LHS = getSCEV(BO->LHS);
8076 RHS = getSCEV(BO->RHS);
8077 return getUDivExpr(LHS, RHS);
8078 case Instruction::URem:
8079 LHS = getSCEV(BO->LHS);
8080 RHS = getSCEV(BO->RHS);
8081 return getURemExpr(LHS, RHS);
8082 case Instruction::Sub: {
8084 if (BO->Op)
8085 Flags = getNoWrapFlagsFromUB(BO->Op);
8086 LHS = getSCEV(BO->LHS);
8087 RHS = getSCEV(BO->RHS);
8088 return getMinusSCEV(LHS, RHS, Flags);
8089 }
8090 case Instruction::And:
8091 // For an expression like x&255 that merely masks off the high bits,
8092 // use zext(trunc(x)) as the SCEV expression.
8093 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8094 if (CI->isZero())
8095 return getSCEV(BO->RHS);
8096 if (CI->isMinusOne())
8097 return getSCEV(BO->LHS);
8098 const APInt &A = CI->getValue();
8099
8100 // Instcombine's ShrinkDemandedConstant may strip bits out of
8101 // constants, obscuring what would otherwise be a low-bits mask.
8102 // Use computeKnownBits to compute what ShrinkDemandedConstant
8103 // knew about to reconstruct a low-bits mask value.
8104 unsigned LZ = A.countl_zero();
8105 unsigned TZ = A.countr_zero();
8106 unsigned BitWidth = A.getBitWidth();
8107 KnownBits Known(BitWidth);
8108 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8109
8110 APInt EffectiveMask =
8111 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8112 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8113 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8114 const SCEV *LHS = getSCEV(BO->LHS);
8115 const SCEV *ShiftedLHS = nullptr;
8116 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8117 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8118 // For an expression like (x * 8) & 8, simplify the multiply.
8119 unsigned MulZeros = OpC->getAPInt().countr_zero();
8120 unsigned GCD = std::min(MulZeros, TZ);
8121 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8123 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8124 append_range(MulOps, LHSMul->operands().drop_front());
8125 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8126 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8127 }
8128 }
8129 if (!ShiftedLHS)
8130 ShiftedLHS = getUDivExpr(LHS, MulCount);
8131 return getMulExpr(
8133 getTruncateExpr(ShiftedLHS,
8134 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8135 BO->LHS->getType()),
8136 MulCount);
8137 }
8138 }
8139 // Binary `and` is a bit-wise `umin`.
8140 if (BO->LHS->getType()->isIntegerTy(1)) {
8141 LHS = getSCEV(BO->LHS);
8142 RHS = getSCEV(BO->RHS);
8143 return getUMinExpr(LHS, RHS);
8144 }
8145 break;
8146
8147 case Instruction::Or:
8148 // Binary `or` is a bit-wise `umax`.
8149 if (BO->LHS->getType()->isIntegerTy(1)) {
8150 LHS = getSCEV(BO->LHS);
8151 RHS = getSCEV(BO->RHS);
8152 return getUMaxExpr(LHS, RHS);
8153 }
8154 break;
8155
8156 case Instruction::Xor:
8157 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8158 // If the RHS of xor is -1, then this is a not operation.
8159 if (CI->isMinusOne())
8160 return getNotSCEV(getSCEV(BO->LHS));
8161
8162 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8163 // This is a variant of the check for xor with -1, and it handles
8164 // the case where instcombine has trimmed non-demanded bits out
8165 // of an xor with -1.
8166 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8167 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8168 if (LBO->getOpcode() == Instruction::And &&
8169 LCI->getValue() == CI->getValue())
8170 if (const SCEVZeroExtendExpr *Z =
8172 Type *UTy = BO->LHS->getType();
8173 const SCEV *Z0 = Z->getOperand();
8174 Type *Z0Ty = Z0->getType();
8175 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8176
8177 // If C is a low-bits mask, the zero extend is serving to
8178 // mask off the high bits. Complement the operand and
8179 // re-apply the zext.
8180 if (CI->getValue().isMask(Z0TySize))
8181 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8182
8183 // If C is a single bit, it may be in the sign-bit position
8184 // before the zero-extend. In this case, represent the xor
8185 // using an add, which is equivalent, and re-apply the zext.
8186 APInt Trunc = CI->getValue().trunc(Z0TySize);
8187 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8188 Trunc.isSignMask())
8189 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8190 UTy);
8191 }
8192 }
8193 break;
8194
8195 case Instruction::Shl:
8196 // Turn shift left of a constant amount into a multiply.
8197 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8198 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8199
8200 // If the shift count is not less than the bitwidth, the result of
8201 // the shift is undefined. Don't try to analyze it, because the
8202 // resolution chosen here may differ from the resolution chosen in
8203 // other parts of the compiler.
8204 if (SA->getValue().uge(BitWidth))
8205 break;
8206
8207 // We can safely preserve the nuw flag in all cases. It's also safe to
8208 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8209 // requires special handling. It can be preserved as long as we're not
8210 // left shifting by bitwidth - 1.
8211 auto Flags = SCEV::FlagAnyWrap;
8212 if (BO->Op) {
8213 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8214 if (any(MulFlags & SCEV::FlagNSW) &&
8215 (any(MulFlags & SCEV::FlagNUW) ||
8216 SA->getValue().ult(BitWidth - 1)))
8218 if (any(MulFlags & SCEV::FlagNUW))
8220 }
8221
8222 ConstantInt *X = ConstantInt::get(
8223 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8224 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8225 }
8226 break;
8227
8228 case Instruction::AShr:
8229 // AShr X, C, where C is a constant.
8230 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8231 if (!CI)
8232 break;
8233
8234 Type *OuterTy = BO->LHS->getType();
8235 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8236 // If the shift count is not less than the bitwidth, the result of
8237 // the shift is undefined. Don't try to analyze it, because the
8238 // resolution chosen here may differ from the resolution chosen in
8239 // other parts of the compiler.
8240 if (CI->getValue().uge(BitWidth))
8241 break;
8242
8243 if (CI->isZero())
8244 return getSCEV(BO->LHS); // shift by zero --> noop
8245
8246 uint64_t AShrAmt = CI->getZExtValue();
8247 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8248
8249 Operator *L = dyn_cast<Operator>(BO->LHS);
8250 const SCEV *AddTruncateExpr = nullptr;
8251 ConstantInt *ShlAmtCI = nullptr;
8252 const SCEV *AddConstant = nullptr;
8253
8254 if (L && L->getOpcode() == Instruction::Add) {
8255 // X = Shl A, n
8256 // Y = Add X, c
8257 // Z = AShr Y, m
8258 // n, c and m are constants.
8259
8260 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8261 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8262 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8263 if (AddOperandCI) {
8264 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8265 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8266 // since we truncate to TruncTy, the AddConstant should be of the
8267 // same type, so create a new Constant with type same as TruncTy.
8268 // Also, the Add constant should be shifted right by AShr amount.
8269 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8270 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8271 // we model the expression as sext(add(trunc(A), c << n)), since the
8272 // sext(trunc) part is already handled below, we create a
8273 // AddExpr(TruncExp) which will be used later.
8274 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8275 }
8276 }
8277 } else if (L && L->getOpcode() == Instruction::Shl) {
8278 // X = Shl A, n
8279 // Y = AShr X, m
8280 // Both n and m are constant.
8281
8282 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8283 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8284 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8285 }
8286
8287 if (AddTruncateExpr && ShlAmtCI) {
8288 // We can merge the two given cases into a single SCEV statement,
8289 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8290 // a simpler case. The following code handles the two cases:
8291 //
8292 // 1) For a two-shift sext-inreg, i.e. n = m,
8293 // use sext(trunc(x)) as the SCEV expression.
8294 //
8295 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8296 // expression. We already checked that ShlAmt < BitWidth, so
8297 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8298 // ShlAmt - AShrAmt < Amt.
8299 const APInt &ShlAmt = ShlAmtCI->getValue();
8300 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8301 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8302 ShlAmtCI->getZExtValue() - AShrAmt);
8303 const SCEV *CompositeExpr =
8304 getMulExpr(AddTruncateExpr, getConstant(Mul));
8305 if (L->getOpcode() != Instruction::Shl)
8306 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8307
8308 return getSignExtendExpr(CompositeExpr, OuterTy);
8309 }
8310 }
8311 break;
8312 }
8313 }
8314
8315 switch (U->getOpcode()) {
8316 case Instruction::Trunc:
8317 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8318
8319 case Instruction::ZExt:
8320 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8321
8322 case Instruction::SExt:
8323 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8325 // The NSW flag of a subtract does not always survive the conversion to
8326 // A + (-1)*B. By pushing sign extension onto its operands we are much
8327 // more likely to preserve NSW and allow later AddRec optimisations.
8328 //
8329 // NOTE: This is effectively duplicating this logic from getSignExtend:
8330 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8331 // but by that point the NSW information has potentially been lost.
8332 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8333 Type *Ty = U->getType();
8334 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8335 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8336 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8337 }
8338 }
8339 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8340
8341 case Instruction::BitCast:
8342 // BitCasts are no-op casts so we just eliminate the cast.
8343 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8344 return getSCEV(U->getOperand(0));
8345 break;
8346
8347 case Instruction::PtrToAddr: {
8348 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8349 if (isa<SCEVCouldNotCompute>(IntOp))
8350 return getUnknown(V);
8351 return IntOp;
8352 }
8353
8354 case Instruction::PtrToInt: {
8355 // Pointer to integer cast is straight-forward, so do model it.
8356 const SCEV *Op = getSCEV(U->getOperand(0));
8357 Type *DstIntTy = U->getType();
8358 // But only if effective SCEV (integer) type is wide enough to represent
8359 // all possible pointer values.
8360 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8361 if (isa<SCEVCouldNotCompute>(IntOp))
8362 return getUnknown(V);
8363 return IntOp;
8364 }
8365 case Instruction::IntToPtr:
8366 // Just don't deal with inttoptr casts.
8367 return getUnknown(V);
8368
8369 case Instruction::SDiv:
8370 // If both operands are non-negative, this is just an udiv.
8371 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8372 isKnownNonNegative(getSCEV(U->getOperand(1))))
8373 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8374 break;
8375
8376 case Instruction::SRem:
8377 // If both operands are non-negative, this is just an urem.
8378 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8379 isKnownNonNegative(getSCEV(U->getOperand(1))))
8380 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8381 break;
8382
8383 case Instruction::GetElementPtr:
8384 return createNodeForGEP(cast<GEPOperator>(U));
8385
8386 case Instruction::PHI:
8387 return createNodeForPHI(cast<PHINode>(U));
8388
8389 case Instruction::Select:
8390 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8391 U->getOperand(2));
8392
8393 case Instruction::Call:
8394 case Instruction::Invoke:
8395 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8396 return getSCEV(RV);
8397
8398 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8399 switch (II->getIntrinsicID()) {
8400 case Intrinsic::abs:
8401 return getAbsExpr(
8402 getSCEV(II->getArgOperand(0)),
8403 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8404 case Intrinsic::umax:
8405 LHS = getSCEV(II->getArgOperand(0));
8406 RHS = getSCEV(II->getArgOperand(1));
8407 return getUMaxExpr(LHS, RHS);
8408 case Intrinsic::umin:
8409 LHS = getSCEV(II->getArgOperand(0));
8410 RHS = getSCEV(II->getArgOperand(1));
8411 return getUMinExpr(LHS, RHS);
8412 case Intrinsic::smax:
8413 LHS = getSCEV(II->getArgOperand(0));
8414 RHS = getSCEV(II->getArgOperand(1));
8415 return getSMaxExpr(LHS, RHS);
8416 case Intrinsic::smin:
8417 LHS = getSCEV(II->getArgOperand(0));
8418 RHS = getSCEV(II->getArgOperand(1));
8419 return getSMinExpr(LHS, RHS);
8420 case Intrinsic::usub_sat: {
8421 const SCEV *X = getSCEV(II->getArgOperand(0));
8422 const SCEV *Y = getSCEV(II->getArgOperand(1));
8423 const SCEV *ClampedY = getUMinExpr(X, Y);
8424 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8425 }
8426 case Intrinsic::uadd_sat: {
8427 const SCEV *X = getSCEV(II->getArgOperand(0));
8428 const SCEV *Y = getSCEV(II->getArgOperand(1));
8429 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8430 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8431 }
8432 case Intrinsic::start_loop_iterations:
8433 case Intrinsic::annotation:
8434 case Intrinsic::ptr_annotation:
8435 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8436 // just equivalent to the first operand for SCEV purposes.
8437 return getSCEV(II->getArgOperand(0));
8438 case Intrinsic::vscale:
8439 return getVScale(II->getType());
8440 default:
8441 break;
8442 }
8443 }
8444 break;
8445 }
8446
8447 return getUnknown(V);
8448}
8449
8450//===----------------------------------------------------------------------===//
8451// Iteration Count Computation Code
8452//
8453
8455 if (isa<SCEVCouldNotCompute>(ExitCount))
8456 return getCouldNotCompute();
8457
8458 auto *ExitCountType = ExitCount->getType();
8459 assert(ExitCountType->isIntegerTy());
8460 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8461 1 + ExitCountType->getScalarSizeInBits());
8462 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8463}
8464
8466 Type *EvalTy,
8467 const Loop *L) {
8468 if (isa<SCEVCouldNotCompute>(ExitCount))
8469 return getCouldNotCompute();
8470
8471 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8472 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8473
8474 auto CanAddOneWithoutOverflow = [&]() {
8475 ConstantRange ExitCountRange =
8476 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8477 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8478 return true;
8479
8480 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8481 getMinusOne(ExitCount->getType()));
8482 };
8483
8484 // If we need to zero extend the backedge count, check if we can add one to
8485 // it prior to zero extending without overflow. Provided this is safe, it
8486 // allows better simplification of the +1.
8487 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8488 return getZeroExtendExpr(
8489 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8490
8491 // Get the total trip count from the count by adding 1. This may wrap.
8492 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8493}
8494
8495static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8496 if (!ExitCount)
8497 return 0;
8498
8499 ConstantInt *ExitConst = ExitCount->getValue();
8500
8501 // Guard against huge trip counts.
8502 if (ExitConst->getValue().getActiveBits() > 32)
8503 return 0;
8504
8505 // In case of integer overflow, this returns 0, which is correct.
8506 return ((unsigned)ExitConst->getZExtValue()) + 1;
8507}
8508
8510 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8511 return getConstantTripCount(ExitCount);
8512}
8513
8514unsigned
8516 const BasicBlock *ExitingBlock) {
8517 assert(ExitingBlock && "Must pass a non-null exiting block!");
8518 assert(L->isLoopExiting(ExitingBlock) &&
8519 "Exiting block must actually branch out of the loop!");
8520 const SCEVConstant *ExitCount =
8521 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8522 return getConstantTripCount(ExitCount);
8523}
8524
8526 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8527
8528 const auto *MaxExitCount =
8529 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8531 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8532}
8533
8535 SmallVector<BasicBlock *, 8> ExitingBlocks;
8536 L->getExitingBlocks(ExitingBlocks);
8537
8538 std::optional<unsigned> Res;
8539 for (auto *ExitingBB : ExitingBlocks) {
8540 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8541 if (!Res)
8542 Res = Multiple;
8543 Res = std::gcd(*Res, Multiple);
8544 }
8545 return Res.value_or(1);
8546}
8547
8549 const SCEV *ExitCount) {
8550 if (isa<SCEVCouldNotCompute>(ExitCount))
8551 return 1;
8552
8553 // Get the trip count
8554 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8555
8556 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8557 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8558 // the greatest power of 2 divisor less than 2^32.
8559 return Multiple.getActiveBits() > 32
8560 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8561 : (unsigned)Multiple.getZExtValue();
8562}
8563
8564/// Returns the largest constant divisor of the trip count of this loop as a
8565/// normal unsigned value, if possible. This means that the actual trip count is
8566/// always a multiple of the returned value (don't forget the trip count could
8567/// very well be zero as well!).
8568///
8569/// Returns 1 if the trip count is unknown or not guaranteed to be the
8570/// multiple of a constant (which is also the case if the trip count is simply
8571/// constant, use getSmallConstantTripCount for that case). Will also return 1
8572/// if the trip count is very large (>= 2^32).
8573///
8574/// As explained in the comments for getSmallConstantTripCount, this assumes
8575/// that control exits the loop via ExitingBlock.
8576unsigned
8578 const BasicBlock *ExitingBlock) {
8579 assert(ExitingBlock && "Must pass a non-null exiting block!");
8580 assert(L->isLoopExiting(ExitingBlock) &&
8581 "Exiting block must actually branch out of the loop!");
8582 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8583 return getSmallConstantTripMultiple(L, ExitCount);
8584}
8585
8587 const BasicBlock *ExitingBlock,
8588 ExitCountKind Kind) {
8589 switch (Kind) {
8590 case Exact:
8591 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8592 case SymbolicMaximum:
8593 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8594 case ConstantMaximum:
8595 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8596 };
8597 llvm_unreachable("Invalid ExitCountKind!");
8598}
8599
8601 const Loop *L, const BasicBlock *ExitingBlock,
8603 switch (Kind) {
8604 case Exact:
8605 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8606 Predicates);
8607 case SymbolicMaximum:
8608 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8609 Predicates);
8610 case ConstantMaximum:
8611 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8612 Predicates);
8613 };
8614 llvm_unreachable("Invalid ExitCountKind!");
8615}
8616
8619 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8620}
8621
8623 ExitCountKind Kind) {
8624 switch (Kind) {
8625 case Exact:
8626 return getBackedgeTakenInfo(L).getExact(L, this);
8627 case ConstantMaximum:
8628 return getBackedgeTakenInfo(L).getConstantMax(this);
8629 case SymbolicMaximum:
8630 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8631 };
8632 llvm_unreachable("Invalid ExitCountKind!");
8633}
8634
8637 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8638}
8639
8642 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8643}
8644
8646 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8647}
8648
8649/// Push PHI nodes in the header of the given loop onto the given Worklist.
8650static void PushLoopPHIs(const Loop *L,
8653 BasicBlock *Header = L->getHeader();
8654
8655 // Push all Loop-header PHIs onto the Worklist stack.
8656 for (PHINode &PN : Header->phis())
8657 if (Visited.insert(&PN).second)
8658 Worklist.push_back(&PN);
8659}
8660
8661ScalarEvolution::BackedgeTakenInfo &
8662ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8663 auto &BTI = getBackedgeTakenInfo(L);
8664 if (BTI.hasFullInfo())
8665 return BTI;
8666
8667 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8668
8669 if (!Pair.second)
8670 return Pair.first->second;
8671
8672 BackedgeTakenInfo Result =
8673 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8674
8675 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8676}
8677
8678ScalarEvolution::BackedgeTakenInfo &
8679ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8680 // Initially insert an invalid entry for this loop. If the insertion
8681 // succeeds, proceed to actually compute a backedge-taken count and
8682 // update the value. The temporary CouldNotCompute value tells SCEV
8683 // code elsewhere that it shouldn't attempt to request a new
8684 // backedge-taken count, which could result in infinite recursion.
8685 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8686 BackedgeTakenCounts.try_emplace(L);
8687 if (!Pair.second)
8688 return Pair.first->second;
8689
8690 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8691 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8692 // must be cleared in this scope.
8693 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8694
8695 // Now that we know more about the trip count for this loop, forget any
8696 // existing SCEV values for PHI nodes in this loop since they are only
8697 // conservative estimates made without the benefit of trip count
8698 // information. This invalidation is not necessary for correctness, and is
8699 // only done to produce more precise results.
8700 if (Result.hasAnyInfo()) {
8701 // Invalidate any expression using an addrec in this loop.
8702 SmallVector<SCEVUse, 8> ToForget;
8703 auto LoopUsersIt = LoopUsers.find(L);
8704 if (LoopUsersIt != LoopUsers.end())
8705 append_range(ToForget, LoopUsersIt->second);
8706 forgetMemoizedResults(ToForget);
8707
8708 // Invalidate constant-evolved loop header phis.
8709 for (PHINode &PN : L->getHeader()->phis())
8710 ConstantEvolutionLoopExitValue.erase(&PN);
8711 }
8712
8713 // Re-lookup the insert position, since the call to
8714 // computeBackedgeTakenCount above could result in a
8715 // recusive call to getBackedgeTakenInfo (on a different
8716 // loop), which would invalidate the iterator computed
8717 // earlier.
8718 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8719}
8720
8722 // This method is intended to forget all info about loops. It should
8723 // invalidate caches as if the following happened:
8724 // - The trip counts of all loops have changed arbitrarily
8725 // - Every llvm::Value has been updated in place to produce a different
8726 // result.
8727 BackedgeTakenCounts.clear();
8728 PredicatedBackedgeTakenCounts.clear();
8729 BECountUsers.clear();
8730 LoopPropertiesCache.clear();
8731 ConstantEvolutionLoopExitValue.clear();
8732 ValueExprMap.clear();
8733 ValuesAtScopes.clear();
8734 ValuesAtScopesUsers.clear();
8735 LoopDispositions.clear();
8736 BlockDispositions.clear();
8737 UnsignedRanges.clear();
8738 SignedRanges.clear();
8739 ExprValueMap.clear();
8740 HasRecMap.clear();
8741 ConstantMultipleCache.clear();
8742 PredicatedSCEVRewrites.clear();
8743 FoldCache.clear();
8744 FoldCacheUser.clear();
8745}
8746void ScalarEvolution::visitAndClearUsers(
8749 SmallVectorImpl<SCEVUse> &ToForget) {
8750 while (!Worklist.empty()) {
8751 Instruction *I = Worklist.pop_back_val();
8752 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8753 continue;
8754
8756 ValueExprMap.find_as(static_cast<Value *>(I));
8757 if (It != ValueExprMap.end()) {
8758 eraseValueFromMap(It->first);
8759 ToForget.push_back(It->second);
8760 if (PHINode *PN = dyn_cast<PHINode>(I))
8761 ConstantEvolutionLoopExitValue.erase(PN);
8762 }
8763
8764 PushDefUseChildren(I, Worklist, Visited);
8765 }
8766}
8767
8769 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8772 SmallVector<SCEVUse, 16> ToForget;
8773
8774 // Iterate over all the loops and sub-loops to drop SCEV information.
8775 while (!LoopWorklist.empty()) {
8776 auto *CurrL = LoopWorklist.pop_back_val();
8777
8778 // Drop any stored trip count value.
8779 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8780 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8781
8782 // Drop information about predicated SCEV rewrites for this loop.
8783 for (auto I = PredicatedSCEVRewrites.begin();
8784 I != PredicatedSCEVRewrites.end();) {
8785 std::pair<const SCEV *, const Loop *> Entry = I->first;
8786 if (Entry.second == CurrL)
8787 PredicatedSCEVRewrites.erase(I++);
8788 else
8789 ++I;
8790 }
8791
8792 auto LoopUsersItr = LoopUsers.find(CurrL);
8793 if (LoopUsersItr != LoopUsers.end())
8794 llvm::append_range(ToForget, LoopUsersItr->second);
8795
8796 // Drop information about expressions based on loop-header PHIs.
8797 PushLoopPHIs(CurrL, Worklist, Visited);
8798 visitAndClearUsers(Worklist, Visited, ToForget);
8799
8800 LoopPropertiesCache.erase(CurrL);
8801 // Forget all contained loops too, to avoid dangling entries in the
8802 // ValuesAtScopes map.
8803 LoopWorklist.append(CurrL->begin(), CurrL->end());
8804 }
8805 forgetMemoizedResults(ToForget);
8806}
8807
8809 forgetLoop(L->getOutermostLoop());
8810}
8811
8814 if (!I) return;
8815
8816 // Drop information about expressions based on loop-header PHIs.
8819 SmallVector<SCEVUse, 8> ToForget;
8820 Worklist.push_back(I);
8821 Visited.insert(I);
8822 visitAndClearUsers(Worklist, Visited, ToForget);
8823
8824 forgetMemoizedResults(ToForget);
8825}
8826
8828 if (!isSCEVable(V->getType()))
8829 return;
8830
8831 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8832 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8833 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8834 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8835 if (const SCEV *S = getExistingSCEV(V)) {
8836 struct InvalidationRootCollector {
8837 Loop *L;
8839
8840 InvalidationRootCollector(Loop *L) : L(L) {}
8841
8842 bool follow(const SCEV *S) {
8843 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8844 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8845 if (L->contains(I))
8846 Roots.push_back(S);
8847 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8848 if (L->contains(AddRec->getLoop()))
8849 Roots.push_back(S);
8850 }
8851 return true;
8852 }
8853 bool isDone() const { return false; }
8854 };
8855
8856 InvalidationRootCollector C(L);
8857 visitAll(S, C);
8858 forgetMemoizedResults(C.Roots);
8859 }
8860
8861 // Also perform the normal invalidation.
8862 forgetValue(V);
8863}
8864
8865void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8866
8868 // Unless a specific value is passed to invalidation, completely clear both
8869 // caches.
8870 if (!V) {
8871 BlockDispositions.clear();
8872 LoopDispositions.clear();
8873 return;
8874 }
8875
8876 if (!isSCEVable(V->getType()))
8877 return;
8878
8879 const SCEV *S = getExistingSCEV(V);
8880 if (!S)
8881 return;
8882
8883 // Invalidate the block and loop dispositions cached for S. Dispositions of
8884 // S's users may change if S's disposition changes (i.e. a user may change to
8885 // loop-invariant, if S changes to loop invariant), so also invalidate
8886 // dispositions of S's users recursively.
8887 SmallVector<SCEVUse, 8> Worklist = {S};
8889 while (!Worklist.empty()) {
8890 const SCEV *Curr = Worklist.pop_back_val();
8891 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8892 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8893 if (!LoopDispoRemoved && !BlockDispoRemoved)
8894 continue;
8895 auto Users = SCEVUsers.find(Curr);
8896 if (Users != SCEVUsers.end())
8897 for (const auto *User : Users->second)
8898 if (Seen.insert(User).second)
8899 Worklist.push_back(User);
8900 }
8901}
8902
8903/// Get the exact loop backedge taken count considering all loop exits. A
8904/// computable result can only be returned for loops with all exiting blocks
8905/// dominating the latch. howFarToZero assumes that the limit of each loop test
8906/// is never skipped. This is a valid assumption as long as the loop exits via
8907/// that test. For precise results, it is the caller's responsibility to specify
8908/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8909const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8910 const Loop *L, ScalarEvolution *SE,
8912 // If any exits were not computable, the loop is not computable.
8913 if (!isComplete() || ExitNotTaken.empty())
8914 return SE->getCouldNotCompute();
8915
8916 const BasicBlock *Latch = L->getLoopLatch();
8917 // All exiting blocks we have collected must dominate the only backedge.
8918 if (!Latch)
8919 return SE->getCouldNotCompute();
8920
8921 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8922 // count is simply a minimum out of all these calculated exit counts.
8924 for (const auto &ENT : ExitNotTaken) {
8925 const SCEV *BECount = ENT.ExactNotTaken;
8926 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8927 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8928 "We should only have known counts for exiting blocks that dominate "
8929 "latch!");
8930
8931 Ops.push_back(BECount);
8932
8933 if (Preds)
8934 append_range(*Preds, ENT.Predicates);
8935
8936 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8937 "Predicate should be always true!");
8938 }
8939
8940 // If an earlier exit exits on the first iteration (exit count zero), then
8941 // a later poison exit count should not propagate into the result. These are
8942 // exactly the semantics provided by umin_seq.
8943 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8944}
8945
8946const ScalarEvolution::ExitNotTakenInfo *
8947ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8948 const BasicBlock *ExitingBlock,
8949 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8950 for (const auto &ENT : ExitNotTaken)
8951 if (ENT.ExitingBlock == ExitingBlock) {
8952 if (ENT.hasAlwaysTruePredicate())
8953 return &ENT;
8954 else if (Predicates) {
8955 append_range(*Predicates, ENT.Predicates);
8956 return &ENT;
8957 }
8958 }
8959
8960 return nullptr;
8961}
8962
8963/// getConstantMax - Get the constant max backedge taken count for the loop.
8964const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8965 ScalarEvolution *SE,
8966 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8967 if (!getConstantMax())
8968 return SE->getCouldNotCompute();
8969
8970 for (const auto &ENT : ExitNotTaken)
8971 if (!ENT.hasAlwaysTruePredicate()) {
8972 if (!Predicates)
8973 return SE->getCouldNotCompute();
8974 append_range(*Predicates, ENT.Predicates);
8975 }
8976
8977 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8978 isa<SCEVConstant>(getConstantMax())) &&
8979 "No point in having a non-constant max backedge taken count!");
8980 return getConstantMax();
8981}
8982
8983const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8984 const Loop *L, ScalarEvolution *SE,
8985 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8986 if (!SymbolicMax) {
8987 // Form an expression for the maximum exit count possible for this loop. We
8988 // merge the max and exact information to approximate a version of
8989 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8990 // constants.
8991 SmallVector<SCEVUse, 4> ExitCounts;
8992
8993 for (const auto &ENT : ExitNotTaken) {
8994 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8995 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8996 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8997 "We should only have known counts for exiting blocks that "
8998 "dominate latch!");
8999 ExitCounts.push_back(ExitCount);
9000 if (Predicates)
9001 append_range(*Predicates, ENT.Predicates);
9002
9003 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9004 "Predicate should be always true!");
9005 }
9006 }
9007 if (ExitCounts.empty())
9008 SymbolicMax = SE->getCouldNotCompute();
9009 else
9010 SymbolicMax =
9011 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9012 }
9013 return SymbolicMax;
9014}
9015
9016bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9017 ScalarEvolution *SE) const {
9018 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9019 return !ENT.hasAlwaysTruePredicate();
9020 };
9021 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9022}
9023
9026
9028 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9029 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9033 // If we prove the max count is zero, so is the symbolic bound. This happens
9034 // in practice due to differences in a) how context sensitive we've chosen
9035 // to be and b) how we reason about bounds implied by UB.
9036 if (ConstantMaxNotTaken->isZero()) {
9037 this->ExactNotTaken = E = ConstantMaxNotTaken;
9038 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9039 }
9040
9043 "Exact is not allowed to be less precise than Constant Max");
9046 "Exact is not allowed to be less precise than Symbolic Max");
9049 "Symbolic Max is not allowed to be less precise than Constant Max");
9052 "No point in having a non-constant max backedge taken count!");
9054 for (const auto PredList : PredLists)
9055 for (const auto *P : PredList) {
9056 if (SeenPreds.contains(P))
9057 continue;
9058 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9059 SeenPreds.insert(P);
9060 Predicates.push_back(P);
9061 }
9062 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9063 "Backedge count should be int");
9065 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9066 "Max backedge count should be int");
9067}
9068
9076
9077/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9078/// computable exit into a persistent ExitNotTakenInfo array.
9079ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9081 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9082 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9083 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9084
9085 ExitNotTaken.reserve(ExitCounts.size());
9086 std::transform(ExitCounts.begin(), ExitCounts.end(),
9087 std::back_inserter(ExitNotTaken),
9088 [&](const EdgeExitInfo &EEI) {
9089 BasicBlock *ExitBB = EEI.first;
9090 const ExitLimit &EL = EEI.second;
9091 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9092 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9093 EL.Predicates);
9094 });
9095 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9096 isa<SCEVConstant>(ConstantMax)) &&
9097 "No point in having a non-constant max backedge taken count!");
9098}
9099
9100/// Compute the number of times the backedge of the specified loop will execute.
9101ScalarEvolution::BackedgeTakenInfo
9102ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9103 bool AllowPredicates) {
9104 SmallVector<BasicBlock *, 8> ExitingBlocks;
9105 L->getExitingBlocks(ExitingBlocks);
9106
9107 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9108
9110 bool CouldComputeBECount = true;
9111 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9112 const SCEV *MustExitMaxBECount = nullptr;
9113 const SCEV *MayExitMaxBECount = nullptr;
9114 bool MustExitMaxOrZero = false;
9115 bool IsOnlyExit = ExitingBlocks.size() == 1;
9116
9117 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9118 // and compute maxBECount.
9119 // Do a union of all the predicates here.
9120 for (BasicBlock *ExitBB : ExitingBlocks) {
9121 // We canonicalize untaken exits to br (constant), ignore them so that
9122 // proving an exit untaken doesn't negatively impact our ability to reason
9123 // about the loop as whole.
9124 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9125 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9126 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9127 if (ExitIfTrue == CI->isZero())
9128 continue;
9129 }
9130
9131 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9132
9133 assert((AllowPredicates || EL.Predicates.empty()) &&
9134 "Predicated exit limit when predicates are not allowed!");
9135
9136 // 1. For each exit that can be computed, add an entry to ExitCounts.
9137 // CouldComputeBECount is true only if all exits can be computed.
9138 if (EL.ExactNotTaken != getCouldNotCompute())
9139 ++NumExitCountsComputed;
9140 else
9141 // We couldn't compute an exact value for this exit, so
9142 // we won't be able to compute an exact value for the loop.
9143 CouldComputeBECount = false;
9144 // Remember exit count if either exact or symbolic is known. Because
9145 // Exact always implies symbolic, only check symbolic.
9146 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9147 ExitCounts.emplace_back(ExitBB, EL);
9148 else {
9149 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9150 "Exact is known but symbolic isn't?");
9151 ++NumExitCountsNotComputed;
9152 }
9153
9154 // 2. Derive the loop's MaxBECount from each exit's max number of
9155 // non-exiting iterations. Partition the loop exits into two kinds:
9156 // LoopMustExits and LoopMayExits.
9157 //
9158 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9159 // is a LoopMayExit. If any computable LoopMustExit is found, then
9160 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9161 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9162 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9163 // any
9164 // computable EL.ConstantMaxNotTaken.
9165 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9166 DT.dominates(ExitBB, Latch)) {
9167 if (!MustExitMaxBECount) {
9168 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9169 MustExitMaxOrZero = EL.MaxOrZero;
9170 } else {
9171 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9172 EL.ConstantMaxNotTaken);
9173 }
9174 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9175 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9176 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9177 else {
9178 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9179 EL.ConstantMaxNotTaken);
9180 }
9181 }
9182 }
9183 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9184 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9185 // The loop backedge will be taken the maximum or zero times if there's
9186 // a single exit that must be taken the maximum or zero times.
9187 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9188
9189 // Remember which SCEVs are used in exit limits for invalidation purposes.
9190 // We only care about non-constant SCEVs here, so we can ignore
9191 // EL.ConstantMaxNotTaken
9192 // and MaxBECount, which must be SCEVConstant.
9193 for (const auto &Pair : ExitCounts) {
9194 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9195 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9196 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9197 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9198 {L, AllowPredicates});
9199 }
9200 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9201 MaxBECount, MaxOrZero);
9202}
9203
9204ScalarEvolution::ExitLimit
9205ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9206 bool IsOnlyExit, bool AllowPredicates) {
9207 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9208 // If our exiting block does not dominate the latch, then its connection with
9209 // loop's exit limit may be far from trivial.
9210 const BasicBlock *Latch = L->getLoopLatch();
9211 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9212 return getCouldNotCompute();
9213
9214 Instruction *Term = ExitingBlock->getTerminator();
9215 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9216 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9217 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9218 "It should have one successor in loop and one exit block!");
9219 // Proceed to the next level to examine the exit condition expression.
9220 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9221 /*ControlsOnlyExit=*/IsOnlyExit,
9222 AllowPredicates);
9223 }
9224
9225 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9226 // For switch, make sure that there is a single exit from the loop.
9227 BasicBlock *Exit = nullptr;
9228 for (auto *SBB : successors(ExitingBlock))
9229 if (!L->contains(SBB)) {
9230 if (Exit) // Multiple exit successors.
9231 return getCouldNotCompute();
9232 Exit = SBB;
9233 }
9234 assert(Exit && "Exiting block must have at least one exit");
9235 return computeExitLimitFromSingleExitSwitch(
9236 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9237 }
9238
9239 return getCouldNotCompute();
9240}
9241
9243 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9244 bool AllowPredicates) {
9245 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9246 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9247 ControlsOnlyExit, AllowPredicates);
9248}
9249
9250std::optional<ScalarEvolution::ExitLimit>
9251ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9252 bool ExitIfTrue, bool ControlsOnlyExit,
9253 bool AllowPredicates) {
9254 (void)this->L;
9255 (void)this->ExitIfTrue;
9256 (void)this->AllowPredicates;
9257
9258 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9259 this->AllowPredicates == AllowPredicates &&
9260 "Variance in assumed invariant key components!");
9261 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9262 if (Itr == TripCountMap.end())
9263 return std::nullopt;
9264 return Itr->second;
9265}
9266
9267void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9268 bool ExitIfTrue,
9269 bool ControlsOnlyExit,
9270 bool AllowPredicates,
9271 const ExitLimit &EL) {
9272 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9273 this->AllowPredicates == AllowPredicates &&
9274 "Variance in assumed invariant key components!");
9275
9276 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9277 assert(InsertResult.second && "Expected successful insertion!");
9278 (void)InsertResult;
9279 (void)ExitIfTrue;
9280}
9281
9282ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9283 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9284 bool ControlsOnlyExit, bool AllowPredicates) {
9285
9286 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9287 AllowPredicates))
9288 return *MaybeEL;
9289
9290 ExitLimit EL = computeExitLimitFromCondImpl(
9291 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9292 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9293 return EL;
9294}
9295
9296ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9297 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9298 bool ControlsOnlyExit, bool AllowPredicates) {
9299 // Handle BinOp conditions (And, Or).
9300 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9301 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9302 return *LimitFromBinOp;
9303
9304 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9305 // Proceed to the next level to examine the icmp.
9306 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9307 ExitLimit EL =
9308 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9309 if (EL.hasFullInfo() || !AllowPredicates)
9310 return EL;
9311
9312 // Try again, but use SCEV predicates this time.
9313 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9314 ControlsOnlyExit,
9315 /*AllowPredicates=*/true);
9316 }
9317
9318 // Check for a constant condition. These are normally stripped out by
9319 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9320 // preserve the CFG and is temporarily leaving constant conditions
9321 // in place.
9322 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9323 if (ExitIfTrue == !CI->getZExtValue())
9324 // The backedge is always taken.
9325 return getCouldNotCompute();
9326 // The backedge is never taken.
9327 return getZero(CI->getType());
9328 }
9329
9330 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9331 // with a constant step, we can form an equivalent icmp predicate and figure
9332 // out how many iterations will be taken before we exit.
9333 const WithOverflowInst *WO;
9334 const APInt *C;
9335 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9336 match(WO->getRHS(), m_APInt(C))) {
9337 ConstantRange NWR =
9339 WO->getNoWrapKind());
9340 CmpInst::Predicate Pred;
9341 APInt NewRHSC, Offset;
9342 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9343 if (!ExitIfTrue)
9344 Pred = ICmpInst::getInversePredicate(Pred);
9345 auto *LHS = getSCEV(WO->getLHS());
9346 if (Offset != 0)
9348 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9349 ControlsOnlyExit, AllowPredicates);
9350 if (EL.hasAnyInfo())
9351 return EL;
9352 }
9353
9354 // If it's not an integer or pointer comparison then compute it the hard way.
9355 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9356}
9357
9358std::optional<ScalarEvolution::ExitLimit>
9359ScalarEvolution::computeExitLimitFromCondFromBinOp(
9360 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9361 bool ControlsOnlyExit, bool AllowPredicates) {
9362 // Check if the controlling expression for this loop is an And or Or.
9363 Value *Op0, *Op1;
9364 bool IsAnd = false;
9365 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9366 IsAnd = true;
9367 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9368 IsAnd = false;
9369 else
9370 return std::nullopt;
9371
9372 // EitherMayExit is true in these two cases:
9373 // br (and Op0 Op1), loop, exit
9374 // br (or Op0 Op1), exit, loop
9375 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9376 ExitLimit EL0 = computeExitLimitFromCondCached(
9377 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9378 AllowPredicates);
9379 ExitLimit EL1 = computeExitLimitFromCondCached(
9380 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9381 AllowPredicates);
9382
9383 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9384 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9385 if (isa<ConstantInt>(Op1))
9386 return Op1 == NeutralElement ? EL0 : EL1;
9387 if (isa<ConstantInt>(Op0))
9388 return Op0 == NeutralElement ? EL1 : EL0;
9389
9390 const SCEV *BECount = getCouldNotCompute();
9391 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9392 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9393 if (EitherMayExit) {
9394 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9395 // Both conditions must be same for the loop to continue executing.
9396 // Choose the less conservative count.
9397 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9398 EL1.ExactNotTaken != getCouldNotCompute()) {
9399 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9400 UseSequentialUMin);
9401 }
9402 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9403 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9404 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9405 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9406 else
9407 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9408 EL1.ConstantMaxNotTaken);
9409 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9410 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9411 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9412 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9413 else
9414 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9415 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9416 } else {
9417 // Both conditions must be same at the same time for the loop to exit.
9418 // For now, be conservative.
9419 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9420 BECount = EL0.ExactNotTaken;
9421 }
9422
9423 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9424 // to be more aggressive when computing BECount than when computing
9425 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9426 // and
9427 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9428 // EL1.ConstantMaxNotTaken to not.
9429 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9430 !isa<SCEVCouldNotCompute>(BECount))
9431 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9432 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9433 SymbolicMaxBECount =
9434 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9435 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9436 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9437}
9438
9439ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9440 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9441 bool AllowPredicates) {
9442 // If the condition was exit on true, convert the condition to exit on false
9443 CmpPredicate Pred;
9444 if (!ExitIfTrue)
9445 Pred = ExitCond->getCmpPredicate();
9446 else
9447 Pred = ExitCond->getInverseCmpPredicate();
9448 const ICmpInst::Predicate OriginalPred = Pred;
9449
9450 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9451 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9452
9453 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9454 AllowPredicates);
9455 if (EL.hasAnyInfo())
9456 return EL;
9457
9458 auto *ExhaustiveCount =
9459 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9460
9461 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9462 return ExhaustiveCount;
9463
9464 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9465 ExitCond->getOperand(1), L, OriginalPred);
9466}
9467ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9468 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9469 bool ControlsOnlyExit, bool AllowPredicates) {
9470
9471 // Try to evaluate any dependencies out of the loop.
9472 LHS = getSCEVAtScope(LHS, L);
9473 RHS = getSCEVAtScope(RHS, L);
9474
9475 // At this point, we would like to compute how many iterations of the
9476 // loop the predicate will return true for these inputs.
9477 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9478 // If there is a loop-invariant, force it into the RHS.
9479 std::swap(LHS, RHS);
9481 }
9482
9483 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9485 // Simplify the operands before analyzing them.
9486 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9487
9488 // If we have a comparison of a chrec against a constant, try to use value
9489 // ranges to answer this query.
9490 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9491 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9492 if (AddRec->getLoop() == L) {
9493 // Form the constant range.
9494 ConstantRange CompRange =
9495 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9496
9497 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9498 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9499 }
9500
9501 // If this loop must exit based on this condition (or execute undefined
9502 // behaviour), see if we can improve wrap flags. This is essentially
9503 // a must execute style proof.
9504 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9505 // If we can prove the test sequence produced must repeat the same values
9506 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9507 // because if it did, we'd have an infinite (undefined) loop.
9508 // TODO: We can peel off any functions which are invertible *in L*. Loop
9509 // invariant terms are effectively constants for our purposes here.
9510 SCEVUse InnerLHS = LHS;
9511 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9512 InnerLHS = ZExt->getOperand();
9513 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9514 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9515 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9516 /*OrNegative=*/true)) {
9517 auto Flags = AR->getNoWrapFlags();
9518 Flags = setFlags(Flags, SCEV::FlagNW);
9519 SmallVector<SCEVUse> Operands{AR->operands()};
9520 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9521 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9522 }
9523
9524 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9525 // From no-self-wrap, this follows trivially from the fact that every
9526 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9527 // last value before (un)signed wrap. Since we know that last value
9528 // didn't exit, nor will any smaller one.
9529 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9530 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9531 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9532 AR && AR->getLoop() == L && AR->isAffine() &&
9533 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9534 isKnownPositive(AR->getStepRecurrence(*this))) {
9535 auto Flags = AR->getNoWrapFlags();
9536 Flags = setFlags(Flags, WrapType);
9537 SmallVector<SCEVUse> Operands{AR->operands()};
9538 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9539 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9540 }
9541 }
9542 }
9543
9544 switch (Pred) {
9545 case ICmpInst::ICMP_NE: { // while (X != Y)
9546 // Convert to: while (X-Y != 0)
9547 if (LHS->getType()->isPointerTy()) {
9550 return LHS;
9551 }
9552 if (RHS->getType()->isPointerTy()) {
9555 return RHS;
9556 }
9557 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9558 AllowPredicates);
9559 if (EL.hasAnyInfo())
9560 return EL;
9561 break;
9562 }
9563 case ICmpInst::ICMP_EQ: { // while (X == Y)
9564 // Convert to: while (X-Y == 0)
9565 if (LHS->getType()->isPointerTy()) {
9568 return LHS;
9569 }
9570 if (RHS->getType()->isPointerTy()) {
9573 return RHS;
9574 }
9575 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9576 if (EL.hasAnyInfo()) return EL;
9577 break;
9578 }
9579 case ICmpInst::ICMP_SLE:
9580 case ICmpInst::ICMP_ULE:
9581 // Since the loop is finite, an invariant RHS cannot include the boundary
9582 // value, otherwise it would loop forever.
9583 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9584 !isLoopInvariant(RHS, L)) {
9585 // Otherwise, perform the addition in a wider type, to avoid overflow.
9586 // If the LHS is an addrec with the appropriate nowrap flag, the
9587 // extension will be sunk into it and the exit count can be analyzed.
9588 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9589 if (!OldType)
9590 break;
9591 // Prefer doubling the bitwidth over adding a single bit to make it more
9592 // likely that we use a legal type.
9593 auto *NewType =
9594 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9595 if (ICmpInst::isSigned(Pred)) {
9596 LHS = getSignExtendExpr(LHS, NewType);
9597 RHS = getSignExtendExpr(RHS, NewType);
9598 } else {
9599 LHS = getZeroExtendExpr(LHS, NewType);
9600 RHS = getZeroExtendExpr(RHS, NewType);
9601 }
9602 }
9604 [[fallthrough]];
9605 case ICmpInst::ICMP_SLT:
9606 case ICmpInst::ICMP_ULT: { // while (X < Y)
9607 bool IsSigned = ICmpInst::isSigned(Pred);
9608 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9609 AllowPredicates);
9610 if (EL.hasAnyInfo())
9611 return EL;
9612 break;
9613 }
9614 case ICmpInst::ICMP_SGE:
9615 case ICmpInst::ICMP_UGE:
9616 // Since the loop is finite, an invariant RHS cannot include the boundary
9617 // value, otherwise it would loop forever.
9618 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9619 !isLoopInvariant(RHS, L))
9620 break;
9622 [[fallthrough]];
9623 case ICmpInst::ICMP_SGT:
9624 case ICmpInst::ICMP_UGT: { // while (X > Y)
9625 bool IsSigned = ICmpInst::isSigned(Pred);
9626 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9627 AllowPredicates);
9628 if (EL.hasAnyInfo())
9629 return EL;
9630 break;
9631 }
9632 default:
9633 break;
9634 }
9635
9636 return getCouldNotCompute();
9637}
9638
9639ScalarEvolution::ExitLimit
9640ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9641 SwitchInst *Switch,
9642 BasicBlock *ExitingBlock,
9643 bool ControlsOnlyExit) {
9644 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9645
9646 // Give up if the exit is the default dest of a switch.
9647 if (Switch->getDefaultDest() == ExitingBlock)
9648 return getCouldNotCompute();
9649
9650 assert(L->contains(Switch->getDefaultDest()) &&
9651 "Default case must not exit the loop!");
9652 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9653 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9654
9655 // while (X != Y) --> while (X-Y != 0)
9656 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9657 if (EL.hasAnyInfo())
9658 return EL;
9659
9660 return getCouldNotCompute();
9661}
9662
9663static ConstantInt *
9665 ScalarEvolution &SE) {
9666 const SCEV *InVal = SE.getConstant(C);
9667 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9669 "Evaluation of SCEV at constant didn't fold correctly?");
9670 return cast<SCEVConstant>(Val)->getValue();
9671}
9672
9673ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9674 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9675 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9676 if (!RHS)
9677 return getCouldNotCompute();
9678
9679 const BasicBlock *Latch = L->getLoopLatch();
9680 if (!Latch)
9681 return getCouldNotCompute();
9682
9683 const BasicBlock *Predecessor = L->getLoopPredecessor();
9684 if (!Predecessor)
9685 return getCouldNotCompute();
9686
9687 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9688 // Return LHS in OutLHS and shift_op in OutOpCode.
9689 auto MatchPositiveShift =
9690 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9691
9692 using namespace PatternMatch;
9693
9694 ConstantInt *ShiftAmt;
9695 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9696 OutOpCode = Instruction::LShr;
9697 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9698 OutOpCode = Instruction::AShr;
9699 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9700 OutOpCode = Instruction::Shl;
9701 else
9702 return false;
9703
9704 return ShiftAmt->getValue().isStrictlyPositive();
9705 };
9706
9707 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9708 //
9709 // loop:
9710 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9711 // %iv.shifted = lshr i32 %iv, <positive constant>
9712 //
9713 // Return true on a successful match. Return the corresponding PHI node (%iv
9714 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9715 auto MatchShiftRecurrence =
9716 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9717 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9718
9719 {
9721 Value *V;
9722
9723 // If we encounter a shift instruction, "peel off" the shift operation,
9724 // and remember that we did so. Later when we inspect %iv's backedge
9725 // value, we will make sure that the backedge value uses the same
9726 // operation.
9727 //
9728 // Note: the peeled shift operation does not have to be the same
9729 // instruction as the one feeding into the PHI's backedge value. We only
9730 // really care about it being the same *kind* of shift instruction --
9731 // that's all that is required for our later inferences to hold.
9732 if (MatchPositiveShift(LHS, V, OpC)) {
9733 PostShiftOpCode = OpC;
9734 LHS = V;
9735 }
9736 }
9737
9738 PNOut = dyn_cast<PHINode>(LHS);
9739 if (!PNOut || PNOut->getParent() != L->getHeader())
9740 return false;
9741
9742 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9743 Value *OpLHS;
9744
9745 return
9746 // The backedge value for the PHI node must be a shift by a positive
9747 // amount
9748 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9749
9750 // of the PHI node itself
9751 OpLHS == PNOut &&
9752
9753 // and the kind of shift should match the kind of shift we peeled
9754 // off, if any.
9755 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9756 };
9757
9758 PHINode *PN;
9760 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9761 return getCouldNotCompute();
9762
9763 const DataLayout &DL = getDataLayout();
9764
9765 // The key rationale for this optimization is that for some kinds of shift
9766 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9767 // within a finite number of iterations. If the condition guarding the
9768 // backedge (in the sense that the backedge is taken if the condition is true)
9769 // is false for the value the shift recurrence stabilizes to, then we know
9770 // that the backedge is taken only a finite number of times.
9771
9772 ConstantInt *StableValue = nullptr;
9773 switch (OpCode) {
9774 default:
9775 llvm_unreachable("Impossible case!");
9776
9777 case Instruction::AShr: {
9778 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9779 // bitwidth(K) iterations.
9780 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9781 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9782 Predecessor->getTerminator(), &DT);
9783 auto *Ty = cast<IntegerType>(RHS->getType());
9784 if (Known.isNonNegative())
9785 StableValue = ConstantInt::get(Ty, 0);
9786 else if (Known.isNegative())
9787 StableValue = ConstantInt::get(Ty, -1, true);
9788 else
9789 return getCouldNotCompute();
9790
9791 break;
9792 }
9793 case Instruction::LShr:
9794 case Instruction::Shl:
9795 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9796 // stabilize to 0 in at most bitwidth(K) iterations.
9797 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9798 break;
9799 }
9800
9801 auto *Result =
9802 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9803 assert(Result->getType()->isIntegerTy(1) &&
9804 "Otherwise cannot be an operand to a branch instruction");
9805
9806 if (Result->isNullValue()) {
9807 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9808 const SCEV *UpperBound =
9810 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9811 }
9812
9813 return getCouldNotCompute();
9814}
9815
9816/// Return true if we can constant fold an instruction of the specified type,
9817/// assuming that all operands were constants.
9818static bool CanConstantFold(const Instruction *I) {
9822 return true;
9823
9824 if (const CallInst *CI = dyn_cast<CallInst>(I))
9825 if (const Function *F = CI->getCalledFunction())
9826 return canConstantFoldCallTo(CI, F);
9827 return false;
9828}
9829
9830/// Determine whether this instruction can constant evolve within this loop
9831/// assuming its operands can all constant evolve.
9832static bool canConstantEvolve(Instruction *I, const Loop *L) {
9833 // An instruction outside of the loop can't be derived from a loop PHI.
9834 if (!L->contains(I)) return false;
9835
9836 if (isa<PHINode>(I)) {
9837 // We don't currently keep track of the control flow needed to evaluate
9838 // PHIs, so we cannot handle PHIs inside of loops.
9839 return L->getHeader() == I->getParent();
9840 }
9841
9842 // If we won't be able to constant fold this expression even if the operands
9843 // are constants, bail early.
9844 return CanConstantFold(I);
9845}
9846
9847/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9848/// recursing through each instruction operand until reaching a loop header phi.
9849static PHINode *
9852 unsigned Depth) {
9854 return nullptr;
9855
9856 // Otherwise, we can evaluate this instruction if all of its operands are
9857 // constant or derived from a PHI node themselves.
9858 PHINode *PHI = nullptr;
9859 for (Value *Op : UseInst->operands()) {
9860 if (isa<Constant>(Op)) continue;
9861
9863 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9864
9865 PHINode *P = dyn_cast<PHINode>(OpInst);
9866 if (!P)
9867 // If this operand is already visited, reuse the prior result.
9868 // We may have P != PHI if this is the deepest point at which the
9869 // inconsistent paths meet.
9870 P = PHIMap.lookup(OpInst);
9871 if (!P) {
9872 // Recurse and memoize the results, whether a phi is found or not.
9873 // This recursive call invalidates pointers into PHIMap.
9874 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9875 PHIMap[OpInst] = P;
9876 }
9877 if (!P)
9878 return nullptr; // Not evolving from PHI
9879 if (PHI && PHI != P)
9880 return nullptr; // Evolving from multiple different PHIs.
9881 PHI = P;
9882 }
9883 // This is an expression evolving from a constant PHI!
9884 return PHI;
9885}
9886
9887/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9888/// in the loop that V is derived from. We allow arbitrary operations along the
9889/// way, but the operands of an operation must either be constants or a value
9890/// derived from a constant PHI. If this expression does not fit with these
9891/// constraints, return null.
9894 if (!I || !canConstantEvolve(I, L)) return nullptr;
9895
9896 if (PHINode *PN = dyn_cast<PHINode>(I))
9897 return PN;
9898
9899 // Record non-constant instructions contained by the loop.
9901 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9902}
9903
9904/// EvaluateExpression - Given an expression that passes the
9905/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9906/// in the loop has the value PHIVal. If we can't fold this expression for some
9907/// reason, return null.
9910 const DataLayout &DL,
9911 const TargetLibraryInfo *TLI) {
9912 // Convenient constant check, but redundant for recursive calls.
9913 if (Constant *C = dyn_cast<Constant>(V)) return C;
9915 if (!I) return nullptr;
9916
9917 if (Constant *C = Vals.lookup(I)) return C;
9918
9919 // An instruction inside the loop depends on a value outside the loop that we
9920 // weren't given a mapping for, or a value such as a call inside the loop.
9921 if (!canConstantEvolve(I, L)) return nullptr;
9922
9923 // An unmapped PHI can be due to a branch or another loop inside this loop,
9924 // or due to this not being the initial iteration through a loop where we
9925 // couldn't compute the evolution of this particular PHI last time.
9926 if (isa<PHINode>(I)) return nullptr;
9927
9928 std::vector<Constant*> Operands(I->getNumOperands());
9929
9930 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9931 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9932 if (!Operand) {
9933 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9934 if (!Operands[i]) return nullptr;
9935 continue;
9936 }
9937 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9938 Vals[Operand] = C;
9939 if (!C) return nullptr;
9940 Operands[i] = C;
9941 }
9942
9943 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9944 /*AllowNonDeterministic=*/false);
9945}
9946
9947
9948// If every incoming value to PN except the one for BB is a specific Constant,
9949// return that, else return nullptr.
9951 Constant *IncomingVal = nullptr;
9952
9953 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9954 if (PN->getIncomingBlock(i) == BB)
9955 continue;
9956
9957 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9958 if (!CurrentVal)
9959 return nullptr;
9960
9961 if (IncomingVal != CurrentVal) {
9962 if (IncomingVal)
9963 return nullptr;
9964 IncomingVal = CurrentVal;
9965 }
9966 }
9967
9968 return IncomingVal;
9969}
9970
9971/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9972/// in the header of its containing loop, we know the loop executes a
9973/// constant number of times, and the PHI node is just a recurrence
9974/// involving constants, fold it.
9975Constant *
9976ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9977 const APInt &BEs,
9978 const Loop *L) {
9979 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9980 if (!Inserted)
9981 return I->second;
9982
9984 return nullptr; // Not going to evaluate it.
9985
9986 Constant *&RetVal = I->second;
9987
9988 DenseMap<Instruction *, Constant *> CurrentIterVals;
9989 BasicBlock *Header = L->getHeader();
9990 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9991
9992 BasicBlock *Latch = L->getLoopLatch();
9993 if (!Latch)
9994 return nullptr;
9995
9996 for (PHINode &PHI : Header->phis()) {
9997 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9998 CurrentIterVals[&PHI] = StartCST;
9999 }
10000 if (!CurrentIterVals.count(PN))
10001 return RetVal = nullptr;
10002
10003 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10004
10005 // Execute the loop symbolically to determine the exit value.
10006 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10007 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10008
10009 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10010 unsigned IterationNum = 0;
10011 const DataLayout &DL = getDataLayout();
10012 for (; ; ++IterationNum) {
10013 if (IterationNum == NumIterations)
10014 return RetVal = CurrentIterVals[PN]; // Got exit value!
10015
10016 // Compute the value of the PHIs for the next iteration.
10017 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10018 DenseMap<Instruction *, Constant *> NextIterVals;
10019 Constant *NextPHI =
10020 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10021 if (!NextPHI)
10022 return nullptr; // Couldn't evaluate!
10023 NextIterVals[PN] = NextPHI;
10024
10025 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10026
10027 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10028 // cease to be able to evaluate one of them or if they stop evolving,
10029 // because that doesn't necessarily prevent us from computing PN.
10031 for (const auto &I : CurrentIterVals) {
10032 PHINode *PHI = dyn_cast<PHINode>(I.first);
10033 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10034 PHIsToCompute.emplace_back(PHI, I.second);
10035 }
10036 // We use two distinct loops because EvaluateExpression may invalidate any
10037 // iterators into CurrentIterVals.
10038 for (const auto &I : PHIsToCompute) {
10039 PHINode *PHI = I.first;
10040 Constant *&NextPHI = NextIterVals[PHI];
10041 if (!NextPHI) { // Not already computed.
10042 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10043 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10044 }
10045 if (NextPHI != I.second)
10046 StoppedEvolving = false;
10047 }
10048
10049 // If all entries in CurrentIterVals == NextIterVals then we can stop
10050 // iterating, the loop can't continue to change.
10051 if (StoppedEvolving)
10052 return RetVal = CurrentIterVals[PN];
10053
10054 CurrentIterVals.swap(NextIterVals);
10055 }
10056}
10057
10058const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10059 Value *Cond,
10060 bool ExitWhen) {
10061 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10062 if (!PN) return getCouldNotCompute();
10063
10064 // If the loop is canonicalized, the PHI will have exactly two entries.
10065 // That's the only form we support here.
10066 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10067
10068 DenseMap<Instruction *, Constant *> CurrentIterVals;
10069 BasicBlock *Header = L->getHeader();
10070 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10071
10072 BasicBlock *Latch = L->getLoopLatch();
10073 assert(Latch && "Should follow from NumIncomingValues == 2!");
10074
10075 for (PHINode &PHI : Header->phis()) {
10076 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10077 CurrentIterVals[&PHI] = StartCST;
10078 }
10079 if (!CurrentIterVals.count(PN))
10080 return getCouldNotCompute();
10081
10082 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10083 // the loop symbolically to determine when the condition gets a value of
10084 // "ExitWhen".
10085 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10086 const DataLayout &DL = getDataLayout();
10087 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10088 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10089 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10090
10091 // Couldn't symbolically evaluate.
10092 if (!CondVal) return getCouldNotCompute();
10093
10094 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10095 ++NumBruteForceTripCountsComputed;
10096 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10097 }
10098
10099 // Update all the PHI nodes for the next iteration.
10100 DenseMap<Instruction *, Constant *> NextIterVals;
10101
10102 // Create a list of which PHIs we need to compute. We want to do this before
10103 // calling EvaluateExpression on them because that may invalidate iterators
10104 // into CurrentIterVals.
10105 SmallVector<PHINode *, 8> PHIsToCompute;
10106 for (const auto &I : CurrentIterVals) {
10107 PHINode *PHI = dyn_cast<PHINode>(I.first);
10108 if (!PHI || PHI->getParent() != Header) continue;
10109 PHIsToCompute.push_back(PHI);
10110 }
10111 for (PHINode *PHI : PHIsToCompute) {
10112 Constant *&NextPHI = NextIterVals[PHI];
10113 if (NextPHI) continue; // Already computed!
10114
10115 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10116 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10117 }
10118 CurrentIterVals.swap(NextIterVals);
10119 }
10120
10121 // Too many iterations were needed to evaluate.
10122 return getCouldNotCompute();
10123}
10124
10125const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10127 ValuesAtScopes[V];
10128 // Check to see if we've folded this expression at this loop before.
10129 for (auto &LS : Values)
10130 if (LS.first == L)
10131 return LS.second ? LS.second : V;
10132
10133 Values.emplace_back(L, nullptr);
10134
10135 // Otherwise compute it.
10136 const SCEV *C = computeSCEVAtScope(V, L);
10137 for (auto &LS : reverse(ValuesAtScopes[V]))
10138 if (LS.first == L) {
10139 LS.second = C;
10140 if (!isa<SCEVConstant>(C))
10141 ValuesAtScopesUsers[C].push_back({L, V});
10142 break;
10143 }
10144 return C;
10145}
10146
10147/// This builds up a Constant using the ConstantExpr interface. That way, we
10148/// will return Constants for objects which aren't represented by a
10149/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10150/// Returns NULL if the SCEV isn't representable as a Constant.
10152 switch (V->getSCEVType()) {
10153 case scCouldNotCompute:
10154 case scAddRecExpr:
10155 case scVScale:
10156 return nullptr;
10157 case scConstant:
10158 return cast<SCEVConstant>(V)->getValue();
10159 case scUnknown:
10160 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
10161 case scPtrToAddr: {
10163 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10164 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10165
10166 return nullptr;
10167 }
10168 case scPtrToInt: {
10170 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10171 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10172
10173 return nullptr;
10174 }
10175 case scTruncate: {
10177 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10178 return ConstantExpr::getTrunc(CastOp, ST->getType());
10179 return nullptr;
10180 }
10181 case scAddExpr: {
10182 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10183 Constant *C = nullptr;
10184 for (const SCEV *Op : SA->operands()) {
10186 if (!OpC)
10187 return nullptr;
10188 if (!C) {
10189 C = OpC;
10190 continue;
10191 }
10192 assert(!C->getType()->isPointerTy() &&
10193 "Can only have one pointer, and it must be last");
10194 if (OpC->getType()->isPointerTy()) {
10195 // The offsets have been converted to bytes. We can add bytes using
10196 // an i8 GEP.
10197 C = ConstantExpr::getPtrAdd(OpC, C);
10198 } else {
10199 C = ConstantExpr::getAdd(C, OpC);
10200 }
10201 }
10202 return C;
10203 }
10204 case scMulExpr:
10205 case scSignExtend:
10206 case scZeroExtend:
10207 case scUDivExpr:
10208 case scSMaxExpr:
10209 case scUMaxExpr:
10210 case scSMinExpr:
10211 case scUMinExpr:
10213 return nullptr;
10214 }
10215 llvm_unreachable("Unknown SCEV kind!");
10216}
10217
10218const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10219 SmallVectorImpl<SCEVUse> &NewOps) {
10220 switch (S->getSCEVType()) {
10221 case scTruncate:
10222 case scZeroExtend:
10223 case scSignExtend:
10224 case scPtrToAddr:
10225 case scPtrToInt:
10226 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10227 case scAddRecExpr: {
10228 auto *AddRec = cast<SCEVAddRecExpr>(S);
10229 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10230 }
10231 case scAddExpr:
10232 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10233 case scMulExpr:
10234 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10235 case scUDivExpr:
10236 return getUDivExpr(NewOps[0], NewOps[1]);
10237 case scUMaxExpr:
10238 case scSMaxExpr:
10239 case scUMinExpr:
10240 case scSMinExpr:
10241 return getMinMaxExpr(S->getSCEVType(), NewOps);
10243 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10244 case scConstant:
10245 case scVScale:
10246 case scUnknown:
10247 return S;
10248 case scCouldNotCompute:
10249 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10250 }
10251 llvm_unreachable("Unknown SCEV kind!");
10252}
10253
10254const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10255 switch (V->getSCEVType()) {
10256 case scConstant:
10257 case scVScale:
10258 return V;
10259 case scAddRecExpr: {
10260 // If this is a loop recurrence for a loop that does not contain L, then we
10261 // are dealing with the final value computed by the loop.
10262 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10263 // First, attempt to evaluate each operand.
10264 // Avoid performing the look-up in the common case where the specified
10265 // expression has no loop-variant portions.
10266 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10267 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10268 if (OpAtScope == AddRec->getOperand(i))
10269 continue;
10270
10271 // Okay, at least one of these operands is loop variant but might be
10272 // foldable. Build a new instance of the folded commutative expression.
10274 NewOps.reserve(AddRec->getNumOperands());
10275 append_range(NewOps, AddRec->operands().take_front(i));
10276 NewOps.push_back(OpAtScope);
10277 for (++i; i != e; ++i)
10278 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10279
10280 const SCEV *FoldedRec = getAddRecExpr(
10281 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10282 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10283 // The addrec may be folded to a nonrecurrence, for example, if the
10284 // induction variable is multiplied by zero after constant folding. Go
10285 // ahead and return the folded value.
10286 if (!AddRec)
10287 return FoldedRec;
10288 break;
10289 }
10290
10291 // If the scope is outside the addrec's loop, evaluate it by using the
10292 // loop exit value of the addrec.
10293 if (!AddRec->getLoop()->contains(L)) {
10294 // To evaluate this recurrence, we need to know how many times the AddRec
10295 // loop iterates. Compute this now.
10296 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10297 if (BackedgeTakenCount == getCouldNotCompute())
10298 return AddRec;
10299
10300 // Then, evaluate the AddRec.
10301 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10302 }
10303
10304 return AddRec;
10305 }
10306 case scTruncate:
10307 case scZeroExtend:
10308 case scSignExtend:
10309 case scPtrToAddr:
10310 case scPtrToInt:
10311 case scAddExpr:
10312 case scMulExpr:
10313 case scUDivExpr:
10314 case scUMaxExpr:
10315 case scSMaxExpr:
10316 case scUMinExpr:
10317 case scSMinExpr:
10318 case scSequentialUMinExpr: {
10319 ArrayRef<SCEVUse> Ops = V->operands();
10320 // Avoid performing the look-up in the common case where the specified
10321 // expression has no loop-variant portions.
10322 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10323 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10324 if (OpAtScope != Ops[i].getPointer()) {
10325 // Okay, at least one of these operands is loop variant but might be
10326 // foldable. Build a new instance of the folded commutative expression.
10328 NewOps.reserve(Ops.size());
10329 append_range(NewOps, Ops.take_front(i));
10330 NewOps.push_back(OpAtScope);
10331
10332 for (++i; i != e; ++i) {
10333 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10334 NewOps.push_back(OpAtScope);
10335 }
10336
10337 return getWithOperands(V, NewOps);
10338 }
10339 }
10340 // If we got here, all operands are loop invariant.
10341 return V;
10342 }
10343 case scUnknown: {
10344 // If this instruction is evolved from a constant-evolving PHI, compute the
10345 // exit value from the loop without using SCEVs.
10346 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10348 if (!I)
10349 return V; // This is some other type of SCEVUnknown, just return it.
10350
10351 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10352 const Loop *CurrLoop = this->LI[I->getParent()];
10353 // Looking for loop exit value.
10354 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10355 PN->getParent() == CurrLoop->getHeader()) {
10356 // Okay, there is no closed form solution for the PHI node. Check
10357 // to see if the loop that contains it has a known backedge-taken
10358 // count. If so, we may be able to force computation of the exit
10359 // value.
10360 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10361 // This trivial case can show up in some degenerate cases where
10362 // the incoming IR has not yet been fully simplified.
10363 if (BackedgeTakenCount->isZero()) {
10364 Value *InitValue = nullptr;
10365 bool MultipleInitValues = false;
10366 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10367 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10368 if (!InitValue)
10369 InitValue = PN->getIncomingValue(i);
10370 else if (InitValue != PN->getIncomingValue(i)) {
10371 MultipleInitValues = true;
10372 break;
10373 }
10374 }
10375 }
10376 if (!MultipleInitValues && InitValue)
10377 return getSCEV(InitValue);
10378 }
10379 // Do we have a loop invariant value flowing around the backedge
10380 // for a loop which must execute the backedge?
10381 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10382 isKnownNonZero(BackedgeTakenCount) &&
10383 PN->getNumIncomingValues() == 2) {
10384
10385 unsigned InLoopPred =
10386 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10387 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10388 if (CurrLoop->isLoopInvariant(BackedgeVal))
10389 return getSCEV(BackedgeVal);
10390 }
10391 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10392 // Okay, we know how many times the containing loop executes. If
10393 // this is a constant evolving PHI node, get the final value at
10394 // the specified iteration number.
10395 Constant *RV =
10396 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10397 if (RV)
10398 return getSCEV(RV);
10399 }
10400 }
10401 }
10402
10403 // Okay, this is an expression that we cannot symbolically evaluate
10404 // into a SCEV. Check to see if it's possible to symbolically evaluate
10405 // the arguments into constants, and if so, try to constant propagate the
10406 // result. This is particularly useful for computing loop exit values.
10407 if (!CanConstantFold(I))
10408 return V; // This is some other type of SCEVUnknown, just return it.
10409
10410 SmallVector<Constant *, 4> Operands;
10411 Operands.reserve(I->getNumOperands());
10412 bool MadeImprovement = false;
10413 for (Value *Op : I->operands()) {
10414 if (Constant *C = dyn_cast<Constant>(Op)) {
10415 Operands.push_back(C);
10416 continue;
10417 }
10418
10419 // If any of the operands is non-constant and if they are
10420 // non-integer and non-pointer, don't even try to analyze them
10421 // with scev techniques.
10422 if (!isSCEVable(Op->getType()))
10423 return V;
10424
10425 const SCEV *OrigV = getSCEV(Op);
10426 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10427 MadeImprovement |= OrigV != OpV;
10428
10430 if (!C)
10431 return V;
10432 assert(C->getType() == Op->getType() && "Type mismatch");
10433 Operands.push_back(C);
10434 }
10435
10436 // Check to see if getSCEVAtScope actually made an improvement.
10437 if (!MadeImprovement)
10438 return V; // This is some other type of SCEVUnknown, just return it.
10439
10440 Constant *C = nullptr;
10441 const DataLayout &DL = getDataLayout();
10442 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10443 /*AllowNonDeterministic=*/false);
10444 if (!C)
10445 return V;
10446 return getSCEV(C);
10447 }
10448 case scCouldNotCompute:
10449 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10450 }
10451 llvm_unreachable("Unknown SCEV type!");
10452}
10453
10455 return getSCEVAtScope(getSCEV(V), L);
10456}
10457
10458const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10460 return stripInjectiveFunctions(ZExt->getOperand());
10462 return stripInjectiveFunctions(SExt->getOperand());
10463 return S;
10464}
10465
10466/// Finds the minimum unsigned root of the following equation:
10467///
10468/// A * X = B (mod N)
10469///
10470/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10471/// A and B isn't important.
10472///
10473/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10474static const SCEV *
10477 ScalarEvolution &SE, const Loop *L) {
10478 uint32_t BW = A.getBitWidth();
10479 assert(BW == SE.getTypeSizeInBits(B->getType()));
10480 assert(A != 0 && "A must be non-zero.");
10481
10482 // 1. D = gcd(A, N)
10483 //
10484 // The gcd of A and N may have only one prime factor: 2. The number of
10485 // trailing zeros in A is its multiplicity
10486 uint32_t Mult2 = A.countr_zero();
10487 // D = 2^Mult2
10488
10489 // 2. Check if B is divisible by D.
10490 //
10491 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10492 // is not less than multiplicity of this prime factor for D.
10493 unsigned MinTZ = SE.getMinTrailingZeros(B);
10494 // Try again with the terminator of the loop predecessor for context-specific
10495 // result, if MinTZ is too small.
10496 if (MinTZ < Mult2 && L->getLoopPredecessor())
10497 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10498 if (MinTZ < Mult2) {
10499 // Check if we can prove there's no remainder using URem.
10500 const SCEV *URem =
10501 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10502 const SCEV *Zero = SE.getZero(B->getType());
10503 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10504 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10505 if (!Predicates)
10506 return SE.getCouldNotCompute();
10507
10508 // Avoid adding a predicate that is known to be false.
10509 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10510 return SE.getCouldNotCompute();
10511 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10512 }
10513 }
10514
10515 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10516 // modulo (N / D).
10517 //
10518 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10519 // (N / D) in general. The inverse itself always fits into BW bits, though,
10520 // so we immediately truncate it.
10521 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10522 APInt I = AD.multiplicativeInverse().zext(BW);
10523
10524 // 4. Compute the minimum unsigned root of the equation:
10525 // I * (B / D) mod (N / D)
10526 // To simplify the computation, we factor out the divide by D:
10527 // (I * B mod N) / D
10528 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10529 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10530}
10531
10532/// For a given quadratic addrec, generate coefficients of the corresponding
10533/// quadratic equation, multiplied by a common value to ensure that they are
10534/// integers.
10535/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10536/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10537/// were multiplied by, and BitWidth is the bit width of the original addrec
10538/// coefficients.
10539/// This function returns std::nullopt if the addrec coefficients are not
10540/// compile-time constants.
10541static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10543 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10544 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10545 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10546 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10547 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10548 << *AddRec << '\n');
10549
10550 // We currently can only solve this if the coefficients are constants.
10551 if (!LC || !MC || !NC) {
10552 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10553 return std::nullopt;
10554 }
10555
10556 APInt L = LC->getAPInt();
10557 APInt M = MC->getAPInt();
10558 APInt N = NC->getAPInt();
10559 assert(!N.isZero() && "This is not a quadratic addrec");
10560
10561 unsigned BitWidth = LC->getAPInt().getBitWidth();
10562 unsigned NewWidth = BitWidth + 1;
10563 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10564 << BitWidth << '\n');
10565 // The sign-extension (as opposed to a zero-extension) here matches the
10566 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10567 N = N.sext(NewWidth);
10568 M = M.sext(NewWidth);
10569 L = L.sext(NewWidth);
10570
10571 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10572 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10573 // L+M, L+2M+N, L+3M+3N, ...
10574 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10575 //
10576 // The equation Acc = 0 is then
10577 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10578 // In a quadratic form it becomes:
10579 // N n^2 + (2M-N) n + 2L = 0.
10580
10581 APInt A = N;
10582 APInt B = 2 * M - A;
10583 APInt C = 2 * L;
10584 APInt T = APInt(NewWidth, 2);
10585 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10586 << "x + " << C << ", coeff bw: " << NewWidth
10587 << ", multiplied by " << T << '\n');
10588 return std::make_tuple(A, B, C, T, BitWidth);
10589}
10590
10591/// Helper function to compare optional APInts:
10592/// (a) if X and Y both exist, return min(X, Y),
10593/// (b) if neither X nor Y exist, return std::nullopt,
10594/// (c) if exactly one of X and Y exists, return that value.
10595static std::optional<APInt> MinOptional(std::optional<APInt> X,
10596 std::optional<APInt> Y) {
10597 if (X && Y) {
10598 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10599 APInt XW = X->sext(W);
10600 APInt YW = Y->sext(W);
10601 return XW.slt(YW) ? *X : *Y;
10602 }
10603 if (!X && !Y)
10604 return std::nullopt;
10605 return X ? *X : *Y;
10606}
10607
10608/// Helper function to truncate an optional APInt to a given BitWidth.
10609/// When solving addrec-related equations, it is preferable to return a value
10610/// that has the same bit width as the original addrec's coefficients. If the
10611/// solution fits in the original bit width, truncate it (except for i1).
10612/// Returning a value of a different bit width may inhibit some optimizations.
10613///
10614/// In general, a solution to a quadratic equation generated from an addrec
10615/// may require BW+1 bits, where BW is the bit width of the addrec's
10616/// coefficients. The reason is that the coefficients of the quadratic
10617/// equation are BW+1 bits wide (to avoid truncation when converting from
10618/// the addrec to the equation).
10619static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10620 unsigned BitWidth) {
10621 if (!X)
10622 return std::nullopt;
10623 unsigned W = X->getBitWidth();
10625 return X->trunc(BitWidth);
10626 return X;
10627}
10628
10629/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10630/// iterations. The values L, M, N are assumed to be signed, and they
10631/// should all have the same bit widths.
10632/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10633/// where BW is the bit width of the addrec's coefficients.
10634/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10635/// returned as such, otherwise the bit width of the returned value may
10636/// be greater than BW.
10637///
10638/// This function returns std::nullopt if
10639/// (a) the addrec coefficients are not constant, or
10640/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10641/// like x^2 = 5, no integer solutions exist, in other cases an integer
10642/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
///
/// The visible body additionally returns std::nullopt when the chrec evaluated
/// at the candidate solution is not exactly zero.
/// NOTE(review): listing lines 10644 (function name and parameters) and 10654
/// (the initializer of X) are absent from this excerpt.
10643static std::optional<APInt>
10645 APInt A, B, C, M;
10646 unsigned BitWidth;
 // Give up if no quadratic equation could be derived from the addrec.
10647 auto T = GetQuadraticEquation(AddRec);
10648 if (!T)
10649 return std::nullopt;
10650
 // Unpack the tuple produced by GetQuadraticEquation.
10651 std::tie(A, B, C, M, BitWidth) = *T;
10652 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
 // NOTE(review): the expression assigned to X (listing line 10654) is missing.
10653 std::optional<APInt> X =
10655 if (!X)
10656 return std::nullopt;
10657
 // Evaluate the chrec at the candidate iteration count and accept the
 // candidate only if the resulting value is zero.
10658 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10659 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10660 if (!V->isZero())
10661 return std::nullopt;
10662
 // Hand the solution to TruncIfPossible together with the equation's BitWidth.
10663 return TruncIfPossible(X, BitWidth);
10664}
10665
10666/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10667/// iterations. The values M, N are assumed to be signed, and they
10668/// should all have the same bit widths.
10669/// Find the least n such that c(n) does not belong to the given range,
10670/// while c(n-1) does.
10671///
10672/// This function returns std::nullopt if
10673/// (a) the addrec coefficients are not constant, or
10674/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10675/// bounds of the range.
///
/// NOTE(review): listing lines 10677 (function name and first parameter),
/// 10714 (the assignment to SO) and 10719 (the initializer of UO) are absent
/// from this excerpt.
10676static std::optional<APInt>
10678 const ConstantRange &Range, ScalarEvolution &SE) {
10679 assert(AddRec->getOperand(0)->isZero() &&
10680 "Starting value of addrec should be 0");
10681 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10682 << Range << ", addrec " << *AddRec << '\n');
10683 // This case is handled in getNumIterationsInRange. Here we can assume that
10684 // we start in the range.
10685 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10686 "Addrec's initial value should be in range");
10687
10688 APInt A, B, C, M;
10689 unsigned BitWidth;
10690 auto T = GetQuadraticEquation(AddRec);
10691 if (!T)
10692 return std::nullopt;
10693
10694 // Be careful about the return value: there can be two reasons for not
10695 // returning an actual number. First, if no solutions to the equations
10696 // were found, and second, if the solutions don't leave the given range.
10697 // The first case means that the actual solution is "unknown", the second
10698 // means that it's known, but not valid. If the solution is unknown, we
10699 // cannot make any conclusions.
10700 // Return a pair: the optional solution and a flag indicating if the
10701 // solution was found.
 // The lambda captures A, B, C, M and BitWidth by reference; they are only
 // assigned from *T after the lambda is defined, but before it is first called.
10702 auto SolveForBoundary =
10703 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10704 // Solve for signed overflow and unsigned overflow, pick the lower
10705 // solution.
10706 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10707 << Bound << " (before multiplying by " << M << ")\n");
10708 Bound *= M; // The quadratic equation multiplier.
10709
10710 std::optional<APInt> SO;
10711 if (BitWidth > 1) {
10712 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10713 "signed overflow\n");
 // NOTE(review): the statement computing SO (listing line 10714) is missing.
10715 }
10716 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10717 "unsigned overflow\n");
 // NOTE(review): the initializer of UO (listing line 10719) is missing.
10718 std::optional<APInt> UO =
10720
 // True iff the chrec value at iteration X is outside Range while the value
 // at iteration X-1 is inside it.
10721 auto LeavesRange = [&] (const APInt &X) {
10722 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10723 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10724 if (Range.contains(V0->getValue()))
10725 return false;
10726 // X should be at least 1, so X-1 is non-negative.
10727 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10728 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10729 if (Range.contains(V1->getValue()))
10730 return true;
10731 return false;
10732 };
10733
10734 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10735 // can be a solution, but the function failed to find it. We cannot treat it
10736 // as "no solution".
10737 if (!SO || !UO)
10738 return {std::nullopt, false};
10739
10740 // Check the smaller value first to see if it leaves the range.
10741 // At this point, both SO and UO must have values.
10742 std::optional<APInt> Min = MinOptional(SO, UO);
10743 if (LeavesRange(*Min))
10744 return { Min, true };
10745 std::optional<APInt> Max = Min == SO ? UO : SO;
10746 if (LeavesRange(*Max))
10747 return { Max, true };
10748
10749 // Solutions were found, but were eliminated, hence the "true".
10750 return {std::nullopt, true};
10751 };
10752
10753 std::tie(A, B, C, M, BitWidth) = *T;
10754 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10755 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10756 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10757 auto SL = SolveForBoundary(Lower);
10758 auto SU = SolveForBoundary(Upper);
10759 // If any of the solutions was unknown, no meaningful conclusions can
10760 // be made.
10761 if (!SL.second || !SU.second)
10762 return std::nullopt;
10763
10764 // Claim: The correct solution is not some value between Min and Max.
10765 //
10766 // Justification: Assuming that Min and Max are different values, one of
10767 // them is when the first signed overflow happens, the other is when the
10768 // first unsigned overflow happens. Crossing the range boundary is only
10769 // possible via an overflow (treating 0 as a special case of it, modeling
10770 // an overflow as crossing k*2^W for some k).
10771 //
10772 // The interesting case here is when Min was eliminated as an invalid
10773 // solution, but Max was not. The argument is that if there was another
10774 // overflow between Min and Max, it would also have been eliminated if
10775 // it was considered.
10776 //
10777 // For a given boundary, it is possible to have two overflows of the same
10778 // type (signed/unsigned) without having the other type in between: this
10779 // can happen when the vertex of the parabola is between the iterations
10780 // corresponding to the overflows. This is only possible when the two
10781 // overflows cross k*2^W for the same k. In such case, if the second one
10782 // left the range (and was the first one to do so), the first overflow
10783 // would have to enter the range, which would mean that either we had left
10784 // the range before or that we started outside of it. Both of these cases
10785 // are contradictions.
10786 //
10787 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10788 // solution is not some value between the Max for this boundary and the
10789 // Min of the other boundary.
10790 //
10791 // Justification: Assume that we had such Max_A and Min_B corresponding
10792 // to range boundaries A and B and such that Max_A < Min_B. If there was
10793 // a solution between Max_A and Min_B, it would have to be caused by an
10794 // overflow corresponding to either A or B. It cannot correspond to B,
10795 // since Min_B is the first occurrence of such an overflow. If it
10796 // corresponded to A, it would have to be either a signed or an unsigned
10797 // overflow that is larger than both eliminated overflows for A. But
10798 // between the eliminated overflows and this overflow, the values would
10799 // cover the entire value space, thus crossing the other boundary, which
10800 // is a contradiction.
10801
 // Combine the two per-boundary results with MinOptional and narrow the
 // outcome via TruncIfPossible.
10802 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10803}
10804
// Compute an ExitLimit for a loop whose exit test is effectively "V != 0".
// Constant, quadratic-addrec and affine-addrec forms of V are handled; every
// other shape yields getCouldNotCompute().
// NOTE(review): listing lines 10815 (presumably the declaration of
// `Predicates`, which is used below but never declared in this excerpt) and
// 10933 (the value assigned to ConstantMax) are absent -- confirm upstream.
10805ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10806 const Loop *L,
10807 bool ControlsOnlyExit,
10808 bool AllowPredicates) {
10809
10810 // This is only used for loops with a "x != y" exit test. The exit condition
10811 // is now expressed as a single expression, V = x-y. So the exit test is
10812 // effectively V != 0. We know and take advantage of the fact that this
10813 // expression is only used in a comparison-with-zero context.
10814
10816 // If the value is a constant
10817 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10818 // If the value is already zero, the branch will execute zero times.
10819 if (C->getValue()->isZero()) return C;
10820 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10821 }
10822
10823 const SCEVAddRecExpr *AddRec =
10824 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10825
10826 if (!AddRec && AllowPredicates)
10827 // Try to make this an AddRec using runtime tests, in the first X
10828 // iterations of this loop, where X is the SCEV expression found by the
10829 // algorithm below.
10830 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10831
 // Only addrecs belonging to this very loop can be analyzed.
10832 if (!AddRec || AddRec->getLoop() != L)
10833 return getCouldNotCompute();
10834
10835 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10836 // the quadratic equation to solve it.
10837 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10838 // We can only use this value if the chrec ends up with an exact zero
10839 // value at this index. When solving for "X*X != 5", for example, we
10840 // should not accept a root of 2.
10841 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10842 const auto *R = cast<SCEVConstant>(getConstant(*S));
10843 return ExitLimit(R, R, R, false, Predicates);
10844 }
10845 return getCouldNotCompute();
10846 }
10847
10848 // Otherwise we can only handle this if it is affine.
10849 if (!AddRec->isAffine())
10850 return getCouldNotCompute();
10851
10852 // If this is an affine expression, the execution count of this branch is
10853 // the minimum unsigned root of the following equation:
10854 //
10855 // Start + Step*N = 0 (mod 2^BW)
10856 //
10857 // equivalent to:
10858 //
10859 // Step*N = -Start (mod 2^BW)
10860 //
10861 // where BW is the common bit width of Start and Step.
10862
10863 // Get the initial value for the loop.
10864 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10865 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10866
10867 if (!isLoopInvariant(Step, L))
10868 return getCouldNotCompute();
10869
10870 LoopGuards Guards = LoopGuards::collect(L, *this);
10871 // Specialize step for this loop so we get context sensitive facts below.
10872 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10873
10874 // For positive steps (counting up until unsigned overflow):
10875 // N = -Start/Step (as unsigned)
10876 // For negative steps (counting down to zero):
10877 // N = Start/-Step
10878 // First compute the unsigned distance from zero in the direction of Step.
10879 bool CountDown = isKnownNegative(StepWLG);
 // The sign of the step must be known one way or the other.
10880 if (!CountDown && !isKnownNonNegative(StepWLG))
10881 return getCouldNotCompute();
10882
10883 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10884 // Handle unitary steps, which cannot wraparound.
10885 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10886 // N = Distance (as unsigned)
10887
10888 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
 // Take the tighter of the guard-aware and the plain unsigned maximum.
10889 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10890 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10891
10892 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10893 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10894 // case, and see if we can improve the bound.
10895 //
10896 // Explicitly handling this here is necessary because getUnsignedRange
10897 // isn't context-sensitive; it doesn't know that we only care about the
10898 // range inside the loop.
10899 const SCEV *Zero = getZero(Distance->getType());
10900 const SCEV *One = getOne(Distance->getType());
10901 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10902 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10903 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10904 // as "unsigned_max(Distance + 1) - 1".
10905 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10906 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10907 }
10908 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10909 Predicates);
10910 }
10911
10912 // If the condition controls loop exit (the loop exits only if the expression
10913 // is true) and the addition is no-wrap we can use unsigned divide to
10914 // compute the backedge count. In this case, the step may not divide the
10915 // distance, but we don't care because if the condition is "missed" the loop
10916 // will have undefined behavior due to wrapping.
10917 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10918 loopHasNoAbnormalExits(AddRec->getLoop())) {
10919
10920 // If the stride is zero and the start is non-zero, the loop must be
10921 // infinite. In C++, most loops are finite by assumption, in which case the
10922 // step being zero implies UB must execute if the loop is entered.
10923 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10924 !isKnownNonZero(StepWLG))
10925 return getCouldNotCompute();
10926
10927 const SCEV *Exact =
10928 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10929 const SCEV *ConstantMax = getCouldNotCompute();
10930 if (Exact != getCouldNotCompute()) {
10931 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
 // NOTE(review): the right-hand side of this assignment (listing line
 // 10933) is missing from the excerpt.
10932 ConstantMax =
10934 }
 // Use the constant bound as the symbolic bound only when no exact count
 // exists.
10935 const SCEV *SymbolicMax =
10936 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10937 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10938 }
10939
10940 // Solve the general equation.
 // This path requires a constant, non-zero step.
10941 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10942 if (!StepC || StepC->getValue()->isZero())
10943 return getCouldNotCompute();
10944 const SCEV *E = SolveLinEquationWithOverflow(
10945 StepC->getAPInt(), getNegativeSCEV(Start),
10946 AllowPredicates ? &Predicates : nullptr, *this, L);
10947
 // M is the constant maximum: the tighter of the guard-aware and the plain
 // unsigned maximum of E, or E itself when E could not be computed.
10948 const SCEV *M = E;
10949 if (E != getCouldNotCompute()) {
10950 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10951 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10952 }
10953 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10954 return ExitLimit(E, M, S, false, Predicates);
10955}
10956
10957ScalarEvolution::ExitLimit
10958ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10959 // Loops that look like: while (X == 0) are very strange indeed. We don't
10960 // handle them yet except for the trivial case. This could be expanded in the
10961 // future as needed.
10962
10963 // If the value is a constant, check to see if it is known to be non-zero
10964 // already. If so, the backedge will execute zero times.
10965 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10966 if (!C->getValue()->isZero())
10967 return getZero(C->getType());
10968 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10969 }
10970
10971 // We could implement others, but I really doubt anyone writes loops like
10972 // this, and if they did, they would already be constant folded.
10973 return getCouldNotCompute();
10974}
10975
10976std::pair<const BasicBlock *, const BasicBlock *>
10977ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10978 const {
10979 // If the block has a unique predecessor, then there is no path from the
10980 // predecessor to the block that does not go through the direct edge
10981 // from the predecessor to the block.
10982 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10983 return {Pred, BB};
10984
10985 // A loop's header is defined to be a block that dominates the loop.
10986 // If the header has a unique predecessor outside the loop, it must be
10987 // a block that has exactly one successor that can reach the loop.
10988 if (const Loop *L = LI.getLoopFor(BB))
10989 return {L->getLoopPredecessor(), L->getHeader()};
10990
10991 return {nullptr, BB};
10992}
10993
10994/// SCEV structural equivalence is usually sufficient for testing whether two
10995/// expressions are equal, however for the purposes of looking for a condition
10996/// guarding a loop, it can be useful to be a little more general, since a
10997/// front-end may have replicated the controlling expression.
10998static bool HasSameValue(const SCEV *A, const SCEV *B) {
10999 // Quick check to see if they are the same SCEV.
11000 if (A == B) return true;
11001
11002 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11003 // Not all instructions that are "identical" compute the same value. For
11004 // instance, two distinct alloca instructions allocating the same type are
11005 // identical and do not read memory; but compute distinct values.
11006 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11007 };
11008
11009 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11010 // two different instructions with the same value. Check for this case.
11011 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11012 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11013 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11014 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11015 if (ComputesEqualValues(AI, BI))
11016 return true;
11017
11018 // Otherwise assume they may have a different value.
11019 return false;
11020}
11021
11022static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11023 const SCEV *Op0, *Op1;
11024 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11025 return false;
11026 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11027 LHS = Op1;
11028 return true;
11029 }
11030 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11031 LHS = Op0;
11032 return true;
11033 }
11034 return false;
11035}
11036
11038 SCEVUse &RHS, unsigned Depth) {
 // Canonicalize the comparison "LHS Pred RHS" in place. Returns true when
 // Pred, LHS or RHS was rewritten or when the comparison was folded to a
 // trivial form; recursion is bounded by Depth.
 // NOTE(review): this excerpt is missing several listing lines: 11037 (first
 // signature line), 11043, 11080, 11091, 11104, 11169, and the
 // continuation lines 11179, 11184, 11192, 11197, 11205, 11218, 11226.
 // Statements that appear truncated below end on one of those lines.
11039 bool Changed = false;
11040 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11041 // '0 != 0'.
11042 auto TrivialCase = [&](bool TriviallyTrue) {
11044 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11045 return true;
11046 };
11047 // If we hit the max recursion limit bail out.
11048 if (Depth >= 3)
11049 return false;
11050
11051 const SCEV *NewLHS, *NewRHS;
11052 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11053 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11054 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11055 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11056
11057 // (X * vscale) pred (Y * vscale) ==> X pred Y
11058 // when both multiples are NSW.
11059 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11060 // when both multiples are NUW.
11061 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11062 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11063 !ICmpInst::isSigned(Pred))) {
11064 LHS = NewLHS;
11065 RHS = NewRHS;
11066 Changed = true;
11067 }
11068 }
11069
11070 // Canonicalize a constant to the right side.
11071 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11072 // Check for both operands constant.
11073 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11074 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11075 return TrivialCase(false);
11076 return TrivialCase(true);
11077 }
11078 // Otherwise swap the operands to put the constant on the right.
11079 std::swap(LHS, RHS);
 // NOTE(review): listing line 11080 is missing here; presumably it adjusts
 // Pred to match the operand swap -- confirm upstream.
11081 Changed = true;
11082 }
11083
11084 // If we're comparing an addrec with a value which is loop-invariant in the
11085 // addrec's loop, put the addrec on the left. Also make a dominance check,
11086 // as both operands could be addrecs loop-invariant in each other's loop.
11087 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11088 const Loop *L = AR->getLoop();
11089 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11090 std::swap(LHS, RHS);
 // NOTE(review): listing line 11091 is missing here; presumably it adjusts
 // Pred to match the operand swap -- confirm upstream.
11092 Changed = true;
11093 }
11094 }
11095
11096 // If there's a constant operand, canonicalize comparisons with boundary
11097 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11098 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11099 const APInt &RA = RC->getAPInt();
11100
11101 bool SimplifiedByConstantRange = false;
11102
11103 if (!ICmpInst::isEquality(Pred)) {
 // NOTE(review): the definition of ExactCR (listing line 11104) is missing.
11105 if (ExactCR.isFullSet())
11106 return TrivialCase(true);
11107 if (ExactCR.isEmptySet())
11108 return TrivialCase(false);
11109
11110 APInt NewRHS;
11111 CmpInst::Predicate NewPred;
11112 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11113 ICmpInst::isEquality(NewPred)) {
11114 // We were able to convert an inequality to an equality.
11115 Pred = NewPred;
11116 RHS = getConstant(NewRHS);
11117 Changed = SimplifiedByConstantRange = true;
11118 }
11119 }
11120
11121 if (!SimplifiedByConstantRange) {
11122 switch (Pred) {
11123 default:
11124 break;
11125 case ICmpInst::ICMP_EQ:
11126 case ICmpInst::ICMP_NE:
11127 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11128 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11129 Changed = true;
11130 break;
11131
11132 // The "Should have been caught earlier!" messages refer to the fact
11133 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11134 // should have fired on the corresponding cases, and canonicalized the
11135 // check to trivial case.
11136
11137 case ICmpInst::ICMP_UGE:
11138 assert(!RA.isMinValue() && "Should have been caught earlier!");
11139 Pred = ICmpInst::ICMP_UGT;
11140 RHS = getConstant(RA - 1);
11141 Changed = true;
11142 break;
11143 case ICmpInst::ICMP_ULE:
11144 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11145 Pred = ICmpInst::ICMP_ULT;
11146 RHS = getConstant(RA + 1);
11147 Changed = true;
11148 break;
11149 case ICmpInst::ICMP_SGE:
11150 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11151 Pred = ICmpInst::ICMP_SGT;
11152 RHS = getConstant(RA - 1);
11153 Changed = true;
11154 break;
11155 case ICmpInst::ICMP_SLE:
11156 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11157 Pred = ICmpInst::ICMP_SLT;
11158 RHS = getConstant(RA + 1);
11159 Changed = true;
11160 break;
11161 }
11162 }
11163 }
11164
11165 // Check for obvious equality.
11166 if (HasSameValue(LHS, RHS)) {
11167 if (ICmpInst::isTrueWhenEqual(Pred))
11168 return TrivialCase(true);
 // NOTE(review): listing line 11169 (the guard for the next return) is
 // missing from the excerpt.
11170 return TrivialCase(false);
11171 }
11172
11173 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11174 // adding or subtracting 1 from one of the operands.
 // Each rewrite is guarded by a range check on the operand being adjusted.
11175 switch (Pred) {
11176 case ICmpInst::ICMP_SLE:
11177 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11178 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11180 Pred = ICmpInst::ICMP_SLT;
11181 Changed = true;
11182 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11183 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11185 Pred = ICmpInst::ICMP_SLT;
11186 Changed = true;
11187 }
11188 break;
11189 case ICmpInst::ICMP_SGE:
11190 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11191 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11193 Pred = ICmpInst::ICMP_SGT;
11194 Changed = true;
11195 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11196 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11198 Pred = ICmpInst::ICMP_SGT;
11199 Changed = true;
11200 }
11201 break;
11202 case ICmpInst::ICMP_ULE:
11203 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11204 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11206 Pred = ICmpInst::ICMP_ULT;
11207 Changed = true;
11208 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11209 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11210 Pred = ICmpInst::ICMP_ULT;
11211 Changed = true;
11212 }
11213 break;
11214 case ICmpInst::ICMP_UGE:
11215 // If RHS is an op we can fold the -1, try that first.
11216 // Otherwise prefer LHS to preserve the nuw flag.
11217 if ((isa<SCEVConstant>(RHS) ||
11219 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11220 !getUnsignedRangeMin(RHS).isMinValue()) {
11221 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11222 Pred = ICmpInst::ICMP_UGT;
11223 Changed = true;
11224 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11225 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11227 Pred = ICmpInst::ICMP_UGT;
11228 Changed = true;
11229 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11230 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11231 Pred = ICmpInst::ICMP_UGT;
11232 Changed = true;
11233 }
11234 break;
11235 default:
11236 break;
11237 }
11238
11239 // TODO: More simplifications are possible here.
11240
11241 // Recursively simplify until we either hit a recursion limit or nothing
11242 // changes.
 // The nested call's result is deliberately discarded: this invocation has
 // already changed something, so it reports true regardless.
11243 if (Changed)
11244 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11245
11246 return Changed;
11247}
11248
// NOTE(review): the signature line (listing line 11249) is absent from this
// excerpt. The visible body reports true when the maximum of S's signed range
// is negative.
11250 return getSignedRangeMax(S).isNegative();
11251}
11252
11256
// NOTE(review): the signature line (listing line 11257) is absent from this
// excerpt. The visible body reports true when the minimum of S's signed range
// is not negative.
11258 return !getSignedRangeMin(S).isNegative();
11259}
11260
11264
// NOTE(review): the signature line (listing line 11265) is absent from this
// excerpt. The visible body reports whether S is known to be non-zero.
11266 // Query push down for cases where the unsigned range is
11267 // less than sufficient.
 // For a sign-extension, answer the question for the extended operand instead.
11268 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11269 return isKnownNonZero(SExt->getOperand(0));
 // Otherwise S is non-zero when the minimum of its unsigned range is not 0.
11270 return getUnsignedRangeMin(S) != 0;
11271}
11272
11274 bool OrNegative) {
 // NOTE(review): the first signature line (listing line 11273) is absent from
 // this excerpt; S and OrZero, used below, are presumably declared there.
 //
 // Leaf test: a constant that is a power of two (or, when OrNegative is set,
 // a negated power of two), or a vscale expression.
11275 auto NonRecursive = [OrNegative](const SCEV *S) {
11276 if (auto *C = dyn_cast<SCEVConstant>(S))
11277 return C->getAPInt().isPowerOf2() ||
11278 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11279
11280 // vscale is a power-of-two.
11281 return isa<SCEVVScale>(S);
11282 };
11283
11284 if (NonRecursive(S))
11285 return true;
11286
 // Beyond leaves, only a multiply is accepted: every operand must pass the
 // leaf test, and unless OrZero is set the product must be known non-zero.
11287 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11288 if (!Mul)
11289 return false;
11290 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11291}
11292
11294 const SCEV *S, uint64_t M,
 // Decide whether S is a multiple of M. When this cannot be settled at compile
 // time, the function falls through to the Assumptions handling at the end and
 // still returns true.
 // NOTE(review): listing lines 11293 (start of the signature), 11295 (the
 // remaining parameter(s), presumably `Assumptions`) and 11331 (the definition
 // of `P`) are absent from this excerpt -- confirm upstream.
 //
 // Nothing is treated as a multiple of 0; everything is a multiple of 1.
11296 if (M == 0)
11297 return false;
11298 if (M == 1)
11299 return true;
11300
11301 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11302 // starts with a multiple of M and at every iteration step S only adds
11303 // multiples of M.
11304 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11305 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11306 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11307
11308 // For a constant, check that "S % M == 0".
11309 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11310 APInt C = Cst->getAPInt();
11311 return C.urem(M) == 0;
11312 }
11313
11314 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
 // NOTE(review): SCEVAddRecExpr is already handled by the recursive check
 // above, so the example in this TODO may be stale.
11315
11316 // Basic tests have failed.
11317 // Check "S % M == 0" at compile time and record runtime Assumptions.
 // NOTE(review): the dyn_cast result is used below without a null check;
 // presumably callers only pass integer-typed SCEVs -- confirm.
11318 auto *STy = dyn_cast<IntegerType>(S->getType());
11319 const SCEV *SmodM =
11320 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11321 const SCEV *Zero = getZero(STy);
11322
11323 // Check whether "S % M == 0" is known at compile time.
11324 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11325 return true;
11326
11327 // Check whether "S % M != 0" is known at compile time.
11328 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11329 return false;
11330
11332
11333 // Detect redundant predicates.
11334 for (auto *A : Assumptions)
11335 if (A->implies(P, *this))
11336 return true;
11337
11338 // Only record non-redundant predicates.
11339 Assumptions.push_back(P);
11340 return true;
11341}
11342
// NOTE(review): the signature (listing line 11343) and the second half of the
// returned expression (listing line 11345) are absent from this excerpt. The
// visible half holds when both S1 and S2 are known non-negative.
11344 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11346}
11347
// Split S, relative to loop L, into the pair (value on loop entry,
// post-increment value). If the entry value cannot be computed, both members
// of the pair are getCouldNotCompute().
// NOTE(review): listing line 11349 (function name and parameters) is absent
// from this excerpt.
11348std::pair<const SCEV *, const SCEV *>
11350 // Compute SCEV on entry of loop L.
11351 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11352 if (Start == getCouldNotCompute())
11353 return { Start, Start };
11354 // Compute post increment SCEV for loop L.
11355 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
 // The code asserts that a computable entry value implies a computable
 // post-increment value.
11356 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11357 return { Start, PostInc };
11358}
11359
11361 SCEVUse RHS) {
 // Try to prove "LHS Pred RHS" by induction over a loop used by the operands:
 // the predicate must hold for the values on loop entry and for the
 // post-increment values on the backedge.
 // NOTE(review): listing lines 11360 (start of the signature) and 11363 (the
 // declaration of `LoopsUsed`) are absent from this excerpt.
11362 // First collect all loops.
11364 getUsedLoops(LHS, LoopsUsed);
11365 getUsedLoops(RHS, LoopsUsed);
11366
 // Without any loop there is nothing to induct over.
11367 if (LoopsUsed.empty())
11368 return false;
11369
11370 // Domination relationship must be a linear order on collected loops.
11371#ifndef NDEBUG
11372 for (const auto *L1 : LoopsUsed)
11373 for (const auto *L2 : LoopsUsed)
11374 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11375 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11376 "Domination relationship is not a linear order");
11377#endif
11378
 // MDL is the maximum under the "header properly dominates header" ordering,
 // i.e. the loop whose header is dominated by all the other headers.
11379 const Loop *MDL =
11380 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11381 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11382 });
11383
11384 // Get init and post increment value for LHS.
11385 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11386 // If LHS contains an unknown non-invariant SCEV then bail out.
11387 if (SplitLHS.first == getCouldNotCompute())
11388 return false;
11389 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11390 // Get init and post increment value for RHS.
11391 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11392 // If RHS contains an unknown non-invariant SCEV then bail out.
11393 if (SplitRHS.first == getCouldNotCompute())
11394 return false;
11395 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11396 // It is possible that init SCEV contains an invariant load but it does
11397 // not dominate MDL and is not available at MDL loop entry, so we should
11398 // check it here.
11399 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11400 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11401 return false;
11402
11403 // It seems the backedge guard check is faster than the entry one, so in
11404 // some cases it can speed up the whole estimation by short-circuiting.
11405 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11406 SplitRHS.second) &&
11407 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11408}
11409
11411 SCEVUse RHS) {
 // Return true if "LHS Pred RHS" can be proven by induction, by splitting, or
 // by non-recursive reasoning, tried in that order.
 // NOTE(review): the first signature line (listing line 11410) is absent from
 // this excerpt.
11412 // Canonicalize the inputs first.
11413 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11414
11415 if (isKnownViaInduction(Pred, LHS, RHS))
11416 return true;
11417
11418 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11419 return true;
11420
11421 // Otherwise see what can be done with some simple reasoning.
11422 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11423}
11424
11426 const SCEV *LHS,
11427 const SCEV *RHS) {
 // Three-way answer for "LHS Pred RHS": true, false, or std::nullopt.
 // NOTE(review): listing lines 11425 (start of the signature) and 11430 (the
 // guard for `return false`) are absent from this excerpt.
11428 if (isKnownPredicate(Pred, LHS, RHS))
11429 return true;
11431 return false;
 // Neither outcome could be established.
11432 return std::nullopt;
11433}
11434
11436 const SCEV *RHS,
11437 const Instruction *CtxI) {
 // True when "LHS Pred RHS" is known outright, or when entry to CtxI's parent
 // block is guarded by that condition.
 // NOTE(review): the first signature line (listing line 11435) is absent from
 // this excerpt.
11438 // TODO: Analyze guards and assumes from Context's block.
11439 return isKnownPredicate(Pred, LHS, RHS) ||
11440 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11441}
11442
// Three-way answer for "LHS Pred RHS" at CtxI: the context-free answer is
// used when one exists, otherwise guards on entry to CtxI's block are
// consulted for the predicate and for its inverse.
// NOTE(review): listing lines 11444 (function name and first parameters) and
// 11452 (the start of the inverse-predicate check) are absent from this
// excerpt.
11443std::optional<bool>
11445 const SCEV *RHS, const Instruction *CtxI) {
11446 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11447 if (KnownWithoutContext)
11448 return KnownWithoutContext;
11449
11450 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11451 return true;
11453 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11454 return false;
 // Neither the predicate nor its inverse could be established.
11455 return std::nullopt;
11456}
11457
11459 const SCEVAddRecExpr *LHS,
11460 const SCEV *RHS) {
 // The predicate must be guarded on loop entry for the addrec's start value
 // and on the backedge for its post-increment expression.
 // NOTE(review): the first signature line (listing line 11458) is absent from
 // this excerpt.
11461 const Loop *L = LHS->getLoop();
11462 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11463 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11464}
11465
// Thin wrapper over getMonotonicPredicateTypeImpl that, in builds with
// assertions enabled, cross-checks the result against the swapped predicate.
// NOTE(review): listing line 11467 (function name and first parameter) is
// absent from this excerpt.
11466std::optional<ScalarEvolution::MonotonicPredicateType>
11468 ICmpInst::Predicate Pred) {
11469 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11470
11471#ifndef NDEBUG
11472 // Verify an invariant: swapping the predicate should turn a monotonically
11473 // increasing change to a monotonically decreasing one, and vice versa.
11474 if (Result) {
11475 auto ResultSwapped =
11476 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11477
 // NOTE(review): ResultSwapped is dereferenced without checking that it
 // holds a value; presumably the swapped query always succeeds whenever the
 // original one does -- confirm.
11478 assert(*ResultSwapped != *Result &&
11479 "monotonicity should flip as we flip the predicate");
11480 }
11481#endif
11482
11483 return Result;
11484}
11485
// Classify how "LHS Pred <invariant>" can change as the addrec LHS advances.
// std::nullopt is returned for non-relational predicates, when the required
// no-wrap flag is missing, or when the sign of the step is unknown.
// NOTE(review): listing lines 11511, 11521 and 11524 (the statements that
// return an actual monotonicity value) are absent from this excerpt.
11486std::optional<ScalarEvolution::MonotonicPredicateType>
11487ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11488 ICmpInst::Predicate Pred) {
11489 // A zero step value for LHS means the induction variable is essentially a
11490 // loop invariant value. We don't really depend on the predicate actually
11491 // flipping from false to true (for increasing predicates, and the other way
11492 // around for decreasing predicates), all we care about is that *if* the
11493 // predicate changes then it only changes from false to true.
11494 //
11495 // A zero step value in itself is not very useful, but there may be places
11496 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11497 // as general as possible.
11498
11499 // Only handle LE/LT/GE/GT predicates.
11500 if (!ICmpInst::isRelational(Pred))
11501 return std::nullopt;
11502
11503 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11504 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11505 "Should be greater or less!");
11506
11507 // Check that AR does not wrap.
11508 if (ICmpInst::isUnsigned(Pred)) {
11509 if (!LHS->hasNoUnsignedWrap())
11510 return std::nullopt;
 // NOTE(review): the result returned for the unsigned case (listing line
 // 11511) is missing here.
11512 }
11513 assert(ICmpInst::isSigned(Pred) &&
11514 "Relational predicate is either signed or unsigned!");
11515 if (!LHS->hasNoSignedWrap())
11516 return std::nullopt;
11517
11518 const SCEV *Step = LHS->getStepRecurrence(*this);
11519
 // NOTE(review): the statements selected by the two step-sign checks below
 // (listing lines 11521 and 11524) are missing from the excerpt.
11520 if (isKnownNonNegative(Step))
11522
11523 if (isKnownNonPositive(Step))
11525
11526 return std::nullopt;
11527}
11528
11529std::optional<ScalarEvolution::LoopInvariantPredicate>
11531 const SCEV *RHS, const Loop *L,
11532 const Instruction *CtxI) {
11533 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11534 if (!isLoopInvariant(RHS, L)) {
11535 if (!isLoopInvariant(LHS, L))
11536 return std::nullopt;
11537
11538 std::swap(LHS, RHS);
11540 }
11541
11542 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11543 if (!ArLHS || ArLHS->getLoop() != L)
11544 return std::nullopt;
11545
11546 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11547 if (!MonotonicType)
11548 return std::nullopt;
11549 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11550 // true as the loop iterates, and the backedge is control dependent on
11551 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11552 //
11553 // * if the predicate was false in the first iteration then the predicate
11554 // is never evaluated again, since the loop exits without taking the
11555 // backedge.
11556 // * if the predicate was true in the first iteration then it will
11557 // continue to be true for all future iterations since it is
11558 // monotonically increasing.
11559 //
11560 // For both the above possibilities, we can replace the loop varying
11561 // predicate with its value on the first iteration of the loop (which is
11562 // loop invariant).
11563 //
11564 // A similar reasoning applies for a monotonically decreasing predicate, by
11565 // replacing true with false and false with true in the above two bullets.
11567 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11568
11569 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11571 RHS);
11572
11573 if (!CtxI)
11574 return std::nullopt;
11575 // Try to prove via context.
11576 // TODO: Support other cases.
11577 switch (Pred) {
11578 default:
11579 break;
11580 case ICmpInst::ICMP_ULE:
11581 case ICmpInst::ICMP_ULT: {
11582 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11583 // Given preconditions
11584 // (1) ArLHS does not cross the border of positive and negative parts of
11585 // range because of:
11586 // - Positive step; (TODO: lift this limitation)
11587 // - nuw - does not cross zero boundary;
11588 // - nsw - does not cross SINT_MAX boundary;
11589 // (2) ArLHS <s RHS
11590 // (3) RHS >=s 0
11591 // we can replace the loop variant ArLHS <u RHS condition with loop
11592 // invariant Start(ArLHS) <u RHS.
11593 //
11594 // Because of (1) there are two options:
11595 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11596 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11597 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11598 // Because of (2) ArLHS <u RHS is trivially true.
11599 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11600 // We can strengthen this to Start(ArLHS) <u RHS.
11601 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11602 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11603 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11604 isKnownNonNegative(RHS) &&
11605 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11607 RHS);
11608 }
11609 }
11610
11611 return std::nullopt;
11612}
11613
11614std::optional<ScalarEvolution::LoopInvariantPredicate>
11616 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11617 const Instruction *CtxI, const SCEV *MaxIter) {
11619 Pred, LHS, RHS, L, CtxI, MaxIter))
11620 return LIP;
11621 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11622 // Number of iterations expressed as UMIN isn't always great for expressing
11623 // the value on the last iteration. If the straightforward approach didn't
11624 // work, try the following trick: if the a predicate is invariant for X, it
11625 // is also invariant for umin(X, ...). So try to find something that works
11626 // among subexpressions of MaxIter expressed as umin.
11627 for (SCEVUse Op : UMin->operands())
11629 Pred, LHS, RHS, L, CtxI, Op))
11630 return LIP;
11631 return std::nullopt;
11632}
11633
11634std::optional<ScalarEvolution::LoopInvariantPredicate>
11636 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11637 const Instruction *CtxI, const SCEV *MaxIter) {
11638 // Try to prove the following set of facts:
11639 // - The predicate is monotonic in the iteration space.
11640 // - If the check does not fail on the 1st iteration:
11641 // - No overflow will happen during first MaxIter iterations;
11642 // - It will not fail on the MaxIter'th iteration.
11643 // If the check does fail on the 1st iteration, we leave the loop and no
11644 // other checks matter.
11645
11646 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11647 if (!isLoopInvariant(RHS, L)) {
11648 if (!isLoopInvariant(LHS, L))
11649 return std::nullopt;
11650
11651 std::swap(LHS, RHS);
11653 }
11654
11655 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11656 if (!AR || AR->getLoop() != L)
11657 return std::nullopt;
11658
11659 // Even if both are valid, we need to consistently chose the unsigned or the
11660 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11661 // predicate.
11662 Pred = Pred.dropSameSign();
11663
11664 // The predicate must be relational (i.e. <, <=, >=, >).
11665 if (!ICmpInst::isRelational(Pred))
11666 return std::nullopt;
11667
11668 // TODO: Support steps other than +/- 1.
11669 const SCEV *Step = AR->getStepRecurrence(*this);
11670 auto *One = getOne(Step->getType());
11671 auto *MinusOne = getNegativeSCEV(One);
11672 if (Step != One && Step != MinusOne)
11673 return std::nullopt;
11674
11675 // Type mismatch here means that MaxIter is potentially larger than max
11676 // unsigned value in start type, which mean we cannot prove no wrap for the
11677 // indvar.
11678 if (AR->getType() != MaxIter->getType())
11679 return std::nullopt;
11680
11681 // Value of IV on suggested last iteration.
11682 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11683 // Does it still meet the requirement?
11684 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11685 return std::nullopt;
11686 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11687 // not exceed max unsigned value of this type), this effectively proves
11688 // that there is no wrap during the iteration. To prove that there is no
11689 // signed/unsigned wrap, we need to check that
11690 // Start <= Last for step = 1 or Start >= Last for step = -1.
11691 ICmpInst::Predicate NoOverflowPred =
11693 if (Step == MinusOne)
11694 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11695 const SCEV *Start = AR->getStart();
11696 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11697 return std::nullopt;
11698
11699 // Everything is fine.
11700 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11701}
11702
11703bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11704 SCEVUse LHS,
11705 SCEVUse RHS) {
11706 if (HasSameValue(LHS, RHS))
11707 return ICmpInst::isTrueWhenEqual(Pred);
11708
11709 auto CheckRange = [&](bool IsSigned) {
11710 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11711 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11712 return RangeLHS.icmp(Pred, RangeRHS);
11713 };
11714
11715 // The check at the top of the function catches the case where the values are
11716 // known to be equal.
11717 if (Pred == CmpInst::ICMP_EQ)
11718 return false;
11719
11720 if (Pred == CmpInst::ICMP_NE) {
11721 if (CheckRange(true) || CheckRange(false))
11722 return true;
11723 auto *Diff = getMinusSCEV(LHS, RHS);
11724 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11725 }
11726
11727 return CheckRange(CmpInst::isSigned(Pred));
11728}
11729
11730bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11732 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11733 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11734 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11735 // OutC1 and OutC2.
11736 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11737 APInt &OutC2,
11738 SCEV::NoWrapFlags ExpectedFlags) {
11739 SCEVUse XNonConstOp, XConstOp;
11740 SCEVUse YNonConstOp, YConstOp;
11741 SCEV::NoWrapFlags XFlagsPresent;
11742 SCEV::NoWrapFlags YFlagsPresent;
11743
11744 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11745 XConstOp = getZero(X->getType());
11746 XNonConstOp = X;
11747 XFlagsPresent = ExpectedFlags;
11748 }
11749 if (!isa<SCEVConstant>(XConstOp))
11750 return false;
11751
11752 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11753 YConstOp = getZero(Y->getType());
11754 YNonConstOp = Y;
11755 YFlagsPresent = ExpectedFlags;
11756 }
11757
11758 if (YNonConstOp != XNonConstOp)
11759 return false;
11760
11761 if (!isa<SCEVConstant>(YConstOp))
11762 return false;
11763
11764 // When matching ADDs with NUW flags (and unsigned predicates), only the
11765 // second ADD (with the larger constant) requires NUW.
11766 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11767 return false;
11768 if (ExpectedFlags != SCEV::FlagNUW &&
11769 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11770 return false;
11771 }
11772
11773 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11774 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11775
11776 return true;
11777 };
11778
11779 APInt C1;
11780 APInt C2;
11781
11782 switch (Pred) {
11783 default:
11784 break;
11785
11786 case ICmpInst::ICMP_SGE:
11787 std::swap(LHS, RHS);
11788 [[fallthrough]];
11789 case ICmpInst::ICMP_SLE:
11790 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11791 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11792 return true;
11793
11794 break;
11795
11796 case ICmpInst::ICMP_SGT:
11797 std::swap(LHS, RHS);
11798 [[fallthrough]];
11799 case ICmpInst::ICMP_SLT:
11800 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11801 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11802 return true;
11803
11804 break;
11805
11806 case ICmpInst::ICMP_UGE:
11807 std::swap(LHS, RHS);
11808 [[fallthrough]];
11809 case ICmpInst::ICMP_ULE:
11810 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11811 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11812 return true;
11813
11814 break;
11815
11816 case ICmpInst::ICMP_UGT:
11817 std::swap(LHS, RHS);
11818 [[fallthrough]];
11819 case ICmpInst::ICMP_ULT:
11820 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11821 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11822 return true;
11823 break;
11824 }
11825
11826 return false;
11827}
11828
11829bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11831 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11832 return false;
11833
11834 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11835 // the stack can result in exponential time complexity.
11836 SaveAndRestore Restore(ProvingSplitPredicate, true);
11837
11838 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11839 //
11840 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11841 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11842 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11843 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11844 // use isKnownPredicate later if needed.
11845 return isKnownNonNegative(RHS) &&
11848}
11849
11850bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11851 const SCEV *LHS, const SCEV *RHS) {
11852 // No need to even try if we know the module has no guards.
11853 if (!HasGuards)
11854 return false;
11855
11856 return any_of(*BB, [&](const Instruction &I) {
11857 using namespace llvm::PatternMatch;
11858
11859 Value *Condition;
11861 m_Value(Condition))) &&
11862 isImpliedCond(Pred, LHS, RHS, Condition, false);
11863 });
11864}
11865
11866/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11867/// protected by a conditional between LHS and RHS. This is used to
11868/// to eliminate casts.
11870 CmpPredicate Pred,
11871 const SCEV *LHS,
11872 const SCEV *RHS) {
11873 // Interpret a null as meaning no loop, where there is obviously no guard
11874 // (interprocedural conditions notwithstanding). Do not bother about
11875 // unreachable loops.
11876 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11877 return true;
11878
11879 if (VerifyIR)
11880 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11881 "This cannot be done on broken IR!");
11882
11883
11884 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11885 return true;
11886
11887 BasicBlock *Latch = L->getLoopLatch();
11888 if (!Latch)
11889 return false;
11890
11891 CondBrInst *LoopContinuePredicate =
11893 if (LoopContinuePredicate &&
11894 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11895 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11896 return true;
11897
11898 // We don't want more than one activation of the following loops on the stack
11899 // -- that can lead to O(n!) time complexity.
11900 if (WalkingBEDominatingConds)
11901 return false;
11902
11903 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11904
11905 // See if we can exploit a trip count to prove the predicate.
11906 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11907 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11908 if (LatchBECount != getCouldNotCompute()) {
11909 // We know that Latch branches back to the loop header exactly
11910 // LatchBECount times. This means the backdege condition at Latch is
11911 // equivalent to "{0,+,1} u< LatchBECount".
11912 Type *Ty = LatchBECount->getType();
11913 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11914 const SCEV *LoopCounter =
11915 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11916 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11917 LatchBECount))
11918 return true;
11919 }
11920
11921 // Check conditions due to any @llvm.assume intrinsics.
11922 for (auto &AssumeVH : AC.assumptions()) {
11923 if (!AssumeVH)
11924 continue;
11925 auto *CI = cast<CallInst>(AssumeVH);
11926 if (!DT.dominates(CI, Latch->getTerminator()))
11927 continue;
11928
11929 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11930 return true;
11931 }
11932
11933 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11934 return true;
11935
11936 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11937 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11938 assert(DTN && "should reach the loop header before reaching the root!");
11939
11940 BasicBlock *BB = DTN->getBlock();
11941 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11942 return true;
11943
11944 BasicBlock *PBB = BB->getSinglePredecessor();
11945 if (!PBB)
11946 continue;
11947
11949 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11950 continue;
11951
11952 // If we have an edge `E` within the loop body that dominates the only
11953 // latch, the condition guarding `E` also guards the backedge. This
11954 // reasoning works only for loops with a single latch.
11955 // We're constructively (and conservatively) enumerating edges within the
11956 // loop body that dominate the latch. The dominator tree better agree
11957 // with us on this:
11958 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11959 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11960 BB != ContBr->getSuccessor(0)))
11961 return true;
11962 }
11963
11964 return false;
11965}
11966
11968 CmpPredicate Pred,
11969 const SCEV *LHS,
11970 const SCEV *RHS) {
11971 // Do not bother proving facts for unreachable code.
11972 if (!DT.isReachableFromEntry(BB))
11973 return true;
11974 if (VerifyIR)
11975 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11976 "This cannot be done on broken IR!");
11977
11978 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11979 // the facts (a >= b && a != b) separately. A typical situation is when the
11980 // non-strict comparison is known from ranges and non-equality is known from
11981 // dominating predicates. If we are proving strict comparison, we always try
11982 // to prove non-equality and non-strict comparison separately.
11983 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11984 const bool ProvingStrictComparison =
11985 Pred != NonStrictPredicate.dropSameSign();
11986 bool ProvedNonStrictComparison = false;
11987 bool ProvedNonEquality = false;
11988
11989 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11990 if (!ProvedNonStrictComparison)
11991 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11992 if (!ProvedNonEquality)
11993 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11994 if (ProvedNonStrictComparison && ProvedNonEquality)
11995 return true;
11996 return false;
11997 };
11998
11999 if (ProvingStrictComparison) {
12000 auto ProofFn = [&](CmpPredicate P) {
12001 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12002 };
12003 if (SplitAndProve(ProofFn))
12004 return true;
12005 }
12006
12007 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12008 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12009 const Instruction *CtxI = &BB->front();
12010 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12011 return true;
12012 if (ProvingStrictComparison) {
12013 auto ProofFn = [&](CmpPredicate P) {
12014 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12015 };
12016 if (SplitAndProve(ProofFn))
12017 return true;
12018 }
12019 return false;
12020 };
12021
12022 // Starting at the block's predecessor, climb up the predecessor chain, as long
12023 // as there are predecessors that can be found that have unique successors
12024 // leading to the original block.
12025 const Loop *ContainingLoop = LI.getLoopFor(BB);
12026 const BasicBlock *PredBB;
12027 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12028 PredBB = ContainingLoop->getLoopPredecessor();
12029 else
12030 PredBB = BB->getSinglePredecessor();
12031 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12032 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12033 const CondBrInst *BlockEntryPredicate =
12034 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12035 if (!BlockEntryPredicate)
12036 continue;
12037
12038 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12039 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12040 return true;
12041 }
12042
12043 // Check conditions due to any @llvm.assume intrinsics.
12044 for (auto &AssumeVH : AC.assumptions()) {
12045 if (!AssumeVH)
12046 continue;
12047 auto *CI = cast<CallInst>(AssumeVH);
12048 if (!DT.dominates(CI, BB))
12049 continue;
12050
12051 if (ProveViaCond(CI->getArgOperand(0), false))
12052 return true;
12053 }
12054
12055 // Check conditions due to any @llvm.experimental.guard intrinsics.
12056 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12057 F.getParent(), Intrinsic::experimental_guard);
12058 if (GuardDecl)
12059 for (const auto *GU : GuardDecl->users())
12060 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12061 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12062 if (ProveViaCond(Guard->getArgOperand(0), false))
12063 return true;
12064 return false;
12065}
12066
12068 const SCEV *LHS,
12069 const SCEV *RHS) {
12070 // Interpret a null as meaning no loop, where there is obviously no guard
12071 // (interprocedural conditions notwithstanding).
12072 if (!L)
12073 return false;
12074
12075 // Both LHS and RHS must be available at loop entry.
12077 "LHS is not available at Loop Entry");
12079 "RHS is not available at Loop Entry");
12080
12081 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12082 return true;
12083
12084 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12085}
12086
12087bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12088 const SCEV *RHS,
12089 const Value *FoundCondValue, bool Inverse,
12090 const Instruction *CtxI) {
12091 // False conditions implies anything. Do not bother analyzing it further.
12092 if (FoundCondValue ==
12093 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12094 return true;
12095
12096 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12097 return false;
12098
12099 llvm::scope_exit ClearOnExit(
12100 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12101
12102 // Recursively handle And and Or conditions.
12103 const Value *Op0, *Op1;
12104 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12105 if (!Inverse)
12106 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12107 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12108 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12109 if (Inverse)
12110 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12111 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12112 }
12113
12114 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12115 if (!ICI) return false;
12116
12117 // Now that we found a conditional branch that dominates the loop or controls
12118 // the loop latch. Check to see if it is the comparison we are looking for.
12119 CmpPredicate FoundPred;
12120 if (Inverse)
12121 FoundPred = ICI->getInverseCmpPredicate();
12122 else
12123 FoundPred = ICI->getCmpPredicate();
12124
12125 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12126 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12127
12128 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12129}
12130
12131bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12132 const SCEV *RHS, CmpPredicate FoundPred,
12133 const SCEV *FoundLHS, const SCEV *FoundRHS,
12134 const Instruction *CtxI) {
12135 // Balance the types.
12136 if (getTypeSizeInBits(LHS->getType()) <
12137 getTypeSizeInBits(FoundLHS->getType())) {
12138 // For unsigned and equality predicates, try to prove that both found
12139 // operands fit into narrow unsigned range. If so, try to prove facts in
12140 // narrow types.
12141 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12142 !FoundRHS->getType()->isPointerTy()) {
12143 auto *NarrowType = LHS->getType();
12144 auto *WideType = FoundLHS->getType();
12145 auto BitWidth = getTypeSizeInBits(NarrowType);
12146 const SCEV *MaxValue = getZeroExtendExpr(
12148 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12149 MaxValue) &&
12150 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12151 MaxValue)) {
12152 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12153 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12154 // We cannot preserve samesign after truncation.
12155 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12156 TruncFoundLHS, TruncFoundRHS, CtxI))
12157 return true;
12158 }
12159 }
12160
12161 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12162 return false;
12163 if (CmpInst::isSigned(Pred)) {
12164 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12165 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12166 } else {
12167 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12168 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12169 }
12170 } else if (getTypeSizeInBits(LHS->getType()) >
12171 getTypeSizeInBits(FoundLHS->getType())) {
12172 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12173 return false;
12174 if (CmpInst::isSigned(FoundPred)) {
12175 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12176 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12177 } else {
12178 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12179 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12180 }
12181 }
12182 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12183 FoundRHS, CtxI);
12184}
12185
12186bool ScalarEvolution::isImpliedCondBalancedTypes(
12187 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12188 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12190 getTypeSizeInBits(FoundLHS->getType()) &&
12191 "Types should be balanced!");
12192 // Canonicalize the query to match the way instcombine will have
12193 // canonicalized the comparison.
12194 if (SimplifyICmpOperands(Pred, LHS, RHS))
12195 if (LHS == RHS)
12196 return CmpInst::isTrueWhenEqual(Pred);
12197 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12198 if (FoundLHS == FoundRHS)
12199 return CmpInst::isFalseWhenEqual(FoundPred);
12200
12201 // Check to see if we can make the LHS or RHS match.
12202 if (LHS == FoundRHS || RHS == FoundLHS) {
12203 if (isa<SCEVConstant>(RHS)) {
12204 std::swap(FoundLHS, FoundRHS);
12205 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12206 } else {
12207 std::swap(LHS, RHS);
12209 }
12210 }
12211
12212 // Check whether the found predicate is the same as the desired predicate.
12213 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12214 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12215
12216 // Check whether swapping the found predicate makes it the same as the
12217 // desired predicate.
12218 if (auto P = CmpPredicate::getMatching(
12219 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12220 // We can write the implication
12221 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12222 // using one of the following ways:
12223 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12224 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12225 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12226 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12227 // Forms 1. and 2. require swapping the operands of one condition. Don't
12228 // do this if it would break canonical constant/addrec ordering.
12230 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12231 LHS, FoundLHS, FoundRHS, CtxI);
12232 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12233 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12234
12235 // There's no clear preference between forms 3. and 4., try both. Avoid
12236 // forming getNotSCEV of pointer values as the resulting subtract is
12237 // not legal.
12238 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12239 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12240 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12241 FoundRHS, CtxI))
12242 return true;
12243
12244 if (!FoundLHS->getType()->isPointerTy() &&
12245 !FoundRHS->getType()->isPointerTy() &&
12246 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12247 getNotSCEV(FoundRHS), CtxI))
12248 return true;
12249
12250 return false;
12251 }
12252
12253 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12255 assert(P1 != P2 && "Handled earlier!");
12256 return CmpInst::isRelational(P2) &&
12258 };
12259 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12260 // Unsigned comparison is the same as signed comparison when both the
12261 // operands are non-negative or negative.
12262 if (haveSameSign(FoundLHS, FoundRHS))
12263 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12264 // Create local copies that we can freely swap and canonicalize our
12265 // conditions to "le/lt".
12266 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12267 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12268 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12269 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12270 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12271 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12272 std::swap(CanonicalLHS, CanonicalRHS);
12273 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12274 }
12275 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12276 "Must be!");
12277 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12278 ICmpInst::isLE(CanonicalFoundPred)) &&
12279 "Must be!");
12280 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12281 // Use implication:
12282 // x <u y && y >=s 0 --> x <s y.
12283 // If we can prove the left part, the right part is also proven.
12284 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12285 CanonicalRHS, CanonicalFoundLHS,
12286 CanonicalFoundRHS);
12287 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12288 // Use implication:
12289 // x <s y && y <s 0 --> x <u y.
12290 // If we can prove the left part, the right part is also proven.
12291 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12292 CanonicalRHS, CanonicalFoundLHS,
12293 CanonicalFoundRHS);
12294 }
12295
12296 // Check if we can make progress by sharpening ranges.
12297 if (FoundPred == ICmpInst::ICMP_NE &&
12298 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12299
12300 const SCEVConstant *C = nullptr;
12301 const SCEV *V = nullptr;
12302
12303 if (isa<SCEVConstant>(FoundLHS)) {
12304 C = cast<SCEVConstant>(FoundLHS);
12305 V = FoundRHS;
12306 } else {
12307 C = cast<SCEVConstant>(FoundRHS);
12308 V = FoundLHS;
12309 }
12310
12311 // The guarding predicate tells us that C != V. If the known range
12312 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12313 // range we consider has to correspond to same signedness as the
12314 // predicate we're interested in folding.
12315
12316 APInt Min = ICmpInst::isSigned(Pred) ?
12318
12319 if (Min == C->getAPInt()) {
12320 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12321 // This is true even if (Min + 1) wraps around -- in case of
12322 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12323
12324 APInt SharperMin = Min + 1;
12325
12326 switch (Pred) {
12327 case ICmpInst::ICMP_SGE:
12328 case ICmpInst::ICMP_UGE:
12329 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12330 // RHS, we're done.
12331 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12332 CtxI))
12333 return true;
12334 [[fallthrough]];
12335
12336 case ICmpInst::ICMP_SGT:
12337 case ICmpInst::ICMP_UGT:
12338 // We know from the range information that (V `Pred` Min ||
12339 // V == Min). We know from the guarding condition that !(V
12340 // == Min). This gives us
12341 //
12342 // V `Pred` Min || V == Min && !(V == Min)
12343 // => V `Pred` Min
12344 //
12345 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12346
12347 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12348 return true;
12349 break;
12350
12351 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12352 case ICmpInst::ICMP_SLE:
12353 case ICmpInst::ICMP_ULE:
12354 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12355 LHS, V, getConstant(SharperMin), CtxI))
12356 return true;
12357 [[fallthrough]];
12358
12359 case ICmpInst::ICMP_SLT:
12360 case ICmpInst::ICMP_ULT:
12361 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12362 LHS, V, getConstant(Min), CtxI))
12363 return true;
12364 break;
12365
12366 default:
12367 // No change
12368 break;
12369 }
12370 }
12371 }
12372
12373 // Check whether the actual condition is beyond sufficient.
12374 if (FoundPred == ICmpInst::ICMP_EQ)
12375 if (ICmpInst::isTrueWhenEqual(Pred))
12376 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12377 return true;
12378 if (Pred == ICmpInst::ICMP_NE)
12379 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12380 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12381 return true;
12382
12383 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12384 return true;
12385
12386 // Otherwise assume the worst.
12387 return false;
12388}
12389
12390bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12391 SCEV::NoWrapFlags &Flags) {
12392 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12393 return false;
12394
12395 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12396 return true;
12397}
12398
// Tries to reduce `More - Less` to a constant. Constants found in `More` are
// added to `Diff` and constants found in `Less` are subtracted from it, each
// scaled by `DiffMul` (the product of the constant multipliers stripped from
// both sides so far). Returns `Diff` once both sides are identical or all
// non-constant terms cancel, and std::nullopt when the difference cannot be
// reduced within the fixed number of simplification rounds.
//
// NOTE(review): this listing is missing source lines 12400 (the signature
// line -- presumably ScalarEvolution::computeConstantDifference taking `More`
// and `Less`), 12414 (the condition that opens the addrec case) and 12455
// (the declaration of `Multiplicity`). Restore them from the upstream file;
// the text below is not compilable as shown.
12399std::optional<APInt>
12401 // We avoid subtracting expressions here because this function is usually
12402 // fairly deep in the call stack (i.e. is called many times).
12403
12404 unsigned BW = getTypeSizeInBits(More->getType());
12405 APInt Diff(BW, 0);
12406 APInt DiffMul(BW, 1);
12407 // Try various simplifications to reduce the difference to a constant. Limit
12408 // the number of allowed simplifications to keep compile-time low.
12409 for (unsigned I = 0; I < 8; ++I) {
12410 if (More == Less)
12411 return Diff;
12412
12413 // Reduce addrecs with identical steps to their start value.
// NOTE(review): the `if (...) {` header for this case (listing line 12414) is
// absent here; the casts below require both operands to be SCEVAddRecExprs.
12415 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12416 const auto *MAR = cast<SCEVAddRecExpr>(More);
12417
12418 if (LAR->getLoop() != MAR->getLoop())
12419 return std::nullopt;
12420
12421 // We look at affine expressions only; not for correctness but to keep
12422 // getStepRecurrence cheap.
12423 if (!LAR->isAffine() || !MAR->isAffine())
12424 return std::nullopt;
12425
12426 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12427 return std::nullopt;
12428
12429 Less = LAR->getStart();
12430 More = MAR->getStart();
12431 continue;
12432 }
12433
12434 // Try to match a common constant multiply.
12435 auto MatchConstMul =
12436 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12437 const APInt *C;
12438 const SCEV *Op;
12439 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12440 return {{Op, *C}};
12441 return std::nullopt;
12442 };
12443 if (auto MatchedMore = MatchConstMul(More)) {
12444 if (auto MatchedLess = MatchConstMul(Less)) {
12445 if (MatchedMore->second == MatchedLess->second) {
12446 More = MatchedMore->first;
12447 Less = MatchedLess->first;
12448 DiffMul *= MatchedMore->second;
12449 continue;
12450 }
12451 }
12452 }
12453
12454 // Try to cancel out common factors in two add expressions.
// NOTE(review): the declaration of `Multiplicity` (listing line 12455) is
// absent here. It is used below as a map from a non-constant operand to its
// net count (+1 per occurrence in `More`, -1 per occurrence in `Less`).
12456 auto Add = [&](const SCEV *S, int Mul) {
12457 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12458 if (Mul == 1) {
12459 Diff += C->getAPInt() * DiffMul;
12460 } else {
12461 assert(Mul == -1);
12462 Diff -= C->getAPInt() * DiffMul;
12463 }
12464 } else
12465 Multiplicity[S] += Mul;
12466 };
12467 auto Decompose = [&](const SCEV *S, int Mul) {
12468 if (isa<SCEVAddExpr>(S)) {
12469 for (const SCEV *Op : S->operands())
12470 Add(Op, Mul);
12471 } else
12472 Add(S, Mul);
12473 };
12474 Decompose(More, 1);
12475 Decompose(Less, -1);
12476
12477 // Check whether all the non-constants cancel out, or reduce to new
12478 // More/Less values.
12479 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12480 for (const auto &[S, Mul] : Multiplicity) {
12481 if (Mul == 0)
12482 continue;
12483 if (Mul == 1) {
12484 if (NewMore)
12485 return std::nullopt;
12486 NewMore = S;
12487 } else if (Mul == -1) {
12488 if (NewLess)
12489 return std::nullopt;
12490 NewLess = S;
12491 } else
12492 return std::nullopt;
12493 }
12494
12495 // Values stayed the same, no point in trying further.
12496 if (NewMore == More || NewLess == Less)
12497 return std::nullopt;
12498
12499 More = NewMore;
12500 Less = NewLess;
12501
12502 // Reduced to constant.
12503 if (!More && !Less)
12504 return Diff;
12505
12506 // Left with variable on only one side, bail out.
12507 if (!More || !Less)
12508 return std::nullopt;
12509 }
12510
12511 // Did not reduce to constant.
12512 return std::nullopt;
12513}
12514
12515bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12516 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12517 const SCEV *FoundRHS, const Instruction *CtxI) {
12518 // Try to recognize the following pattern:
12519 //
12520 // FoundRHS = ...
12521 // ...
12522 // loop:
12523 // FoundLHS = {Start,+,W}
12524 // context_bb: // Basic block from the same loop
12525 // known(Pred, FoundLHS, FoundRHS)
12526 //
12527 // If some predicate is known in the context of a loop, it is also known on
12528 // each iteration of this loop, including the first iteration. Therefore, in
12529 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12530 // prove the original pred using this fact.
12531 if (!CtxI)
12532 return false;
12533 const BasicBlock *ContextBB = CtxI->getParent();
12534 // Make sure AR varies in the context block.
12535 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12536 const Loop *L = AR->getLoop();
12537 const auto *Latch = L->getLoopLatch();
12538 // Make sure that context belongs to the loop and executes on 1st iteration
12539 // (if it ever executes at all).
12540 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12541 return false;
12542 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12543 return false;
12544 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12545 }
12546
12547 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12548 const Loop *L = AR->getLoop();
12549 const auto *Latch = L->getLoopLatch();
12550 // Make sure that context belongs to the loop and executes on 1st iteration
12551 // (if it ever executes at all).
12552 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12553 return false;
12554 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12555 return false;
12556 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12557 }
12558
12559 return false;
12560}
12561
12562bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12563 const SCEV *LHS,
12564 const SCEV *RHS,
12565 const SCEV *FoundLHS,
12566 const SCEV *FoundRHS) {
12567 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12568 return false;
12569
12570 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12571 if (!AddRecLHS)
12572 return false;
12573
12574 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12575 if (!AddRecFoundLHS)
12576 return false;
12577
12578 // We'd like to let SCEV reason about control dependencies, so we constrain
12579 // both the inequalities to be about add recurrences on the same loop. This
12580 // way we can use isLoopEntryGuardedByCond later.
12581
12582 const Loop *L = AddRecFoundLHS->getLoop();
12583 if (L != AddRecLHS->getLoop())
12584 return false;
12585
12586 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12587 //
12588 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12589 // ... (2)
12590 //
12591 // Informal proof for (2), assuming (1) [*]:
12592 //
12593 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12594 //
12595 // Then
12596 //
12597 // FoundLHS s< FoundRHS s< INT_MIN - C
12598 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12599 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12600 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12601 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12602 // <=> FoundLHS + C s< FoundRHS + C
12603 //
12604 // [*]: (1) can be proved by ruling out overflow.
12605 //
12606 // [**]: This can be proved by analyzing all the four possibilities:
12607 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12608 // (A s>= 0, B s>= 0).
12609 //
12610 // Note:
12611 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12612 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12613 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12614 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12615 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12616 // C)".
12617
12618 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12619 if (!LDiff)
12620 return false;
12621 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12622 if (!RDiff || *LDiff != *RDiff)
12623 return false;
12624
12625 if (LDiff->isMinValue())
12626 return true;
12627
12628 APInt FoundRHSLimit;
12629
12630 if (Pred == CmpInst::ICMP_ULT) {
12631 FoundRHSLimit = -(*RDiff);
12632 } else {
12633 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12634 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12635 }
12636
12637 // Try to prove (1) or (2), as needed.
12638 return isAvailableAtLoopEntry(FoundRHS, L) &&
12639 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12640 getConstant(FoundRHSLimit));
12641}
12642
// Tries to prove `LHS Pred RHS` when at least one side is a SCEVUnknown that
// wraps a PHI node, by proving the predicate for the PHI's incoming values.
// Three shapes are handled below: two PHIs in the same block, a PHI compared
// with an add recurrence whose loop header is the PHI's block, and a PHI
// compared with any other RHS. PHIs under examination are recorded in
// `PendingMerges`; meeting one that is already recorded returns false, and
// the scope-exit object removes this call's entries on every return path.
//
// NOTE(review): listing line 12692 (inside the `if (!LPhi)` swap block) is
// absent from this extract -- presumably it also swaps the predicate; confirm
// against the upstream file.
12643bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12644 const SCEV *RHS, const SCEV *FoundLHS,
12645 const SCEV *FoundRHS, unsigned Depth) {
12646 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12647
12648 llvm::scope_exit ClearOnExit([&]() {
12649 if (LPhi) {
12650 bool Erased = PendingMerges.erase(LPhi);
12651 assert(Erased && "Failed to erase LPhi!");
12652 (void)Erased;
12653 }
12654 if (RPhi) {
12655 bool Erased = PendingMerges.erase(RPhi);
12656 assert(Erased && "Failed to erase RPhi!");
12657 (void)Erased;
12658 }
12659 });
12660
12661 // Find respective Phis and check that they are not being pending.
12662 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12663 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12664 if (!PendingMerges.insert(Phi).second)
12665 return false;
12666 LPhi = Phi;
12667 }
12668 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12669 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12670 // If we detect a loop of Phi nodes being processed by this method, for
12671 // example:
12672 //
12673 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12674 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12675 //
12676 // we don't want to deal with a case that complex, so return conservative
12677 // answer false.
12678 if (!PendingMerges.insert(Phi).second)
12679 return false;
12680 RPhi = Phi;
12681 }
12682
12683 // If none of LHS, RHS is a Phi, nothing to do here.
12684 if (!LPhi && !RPhi)
12685 return false;
12686
12687 // If there is a SCEVUnknown Phi we are interested in, make it left.
12688 if (!LPhi) {
12689 std::swap(LHS, RHS);
12690 std::swap(FoundLHS, FoundRHS);
12691 std::swap(LPhi, RPhi);
12693 }
12694
12695 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12696 const BasicBlock *LBB = LPhi->getParent();
12697 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12698
// Proof attempts that are used for each pair of incoming values below.
12699 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12700 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12701 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12702 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12703 };
12704
12705 if (RPhi && RPhi->getParent() == LBB) {
12706 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12707 // If we compare two Phis from the same block, and for each entry block
12708 // the predicate is true for incoming values from this block, then the
12709 // predicate is also true for the Phis.
12710 for (const BasicBlock *IncBB : predecessors(LBB)) {
12711 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12712 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12713 if (!ProvedEasily(L, R))
12714 return false;
12715 }
12716 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12717 // Case two: RHS is also a Phi from the same basic block, and it is an
12718 // AddRec. It means that there is a loop which has both AddRec and Unknown
12719 // PHIs, for it we can compare incoming values of AddRec from above the loop
12720 // and latch with their respective incoming values of LPhi.
12721 // TODO: Generalize to handle loops with many inputs in a header.
12722 if (LPhi->getNumIncomingValues() != 2) return false;
12723
12724 auto *RLoop = RAR->getLoop();
12725 auto *Predecessor = RLoop->getLoopPredecessor();
12726 assert(Predecessor && "Loop with AddRec with no predecessor?");
12727 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12728 if (!ProvedEasily(L1, RAR->getStart()))
12729 return false;
12730 auto *Latch = RLoop->getLoopLatch();
12731 assert(Latch && "Loop with AddRec with no latch?");
12732 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12733 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12734 return false;
12735 } else {
12736 // In all other cases go over inputs of LHS and compare each of them to RHS,
12737 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12738 // At this point RHS is either a non-Phi, or it is a Phi from some block
12739 // different from LBB.
12740 for (const BasicBlock *IncBB : predecessors(LBB)) {
12741 // Check that RHS is available in this block.
12742 if (!dominates(RHS, IncBB))
12743 return false;
12744 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12745 // Make sure L does not refer to a value from a potentially previous
12746 // iteration of a loop.
12747 if (!properlyDominates(L, LBB))
12748 return false;
12749 // Addrecs are considered to properly dominate their loop, so are missed
12750 // by the previous check. Discard any values that have computable
12751 // evolution in this loop.
12752 if (auto *Loop = LI.getLoopFor(LBB))
12753 if (hasComputableLoopEvolution(L, Loop))
12754 return false;
12755 if (!ProvedEasily(L, RHS))
12756 return false;
12757 }
12758 }
12759 return true;
12760}
12761
// Tries to prove `LHS Pred RHS` from a known comparison of the same LHS
// against a logical right shift: FoundRHS must be a SCEVUnknown whose value
// matches `lshr Shiftee, ShiftValue`. Only the ULT/ULE and SLT/SLE predicates
// are handled; everything else returns false.
//
// NOTE(review): listing line 12772 (inside the `if (RHS == FoundRHS)` swap
// block) is absent from this extract -- presumably it also swaps the
// predicate; confirm against the upstream file.
12762bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12763 const SCEV *LHS,
12764 const SCEV *RHS,
12765 const SCEV *FoundLHS,
12766 const SCEV *FoundRHS) {
12767 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12768 // sure that we are dealing with same LHS.
12769 if (RHS == FoundRHS) {
12770 std::swap(LHS, RHS);
12771 std::swap(FoundLHS, FoundRHS);
12773 }
12774 if (LHS != FoundLHS)
12775 return false;
12776
12777 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12778 if (!SUFoundRHS)
12779 return false;
12780
12781 Value *Shiftee, *ShiftValue;
12782
12783 using namespace PatternMatch;
12784 if (match(SUFoundRHS->getValue(),
12785 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12786 auto *ShifteeS = getSCEV(Shiftee);
12787 // Prove one of the following:
12788 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12789 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12790 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12791 // ---> LHS <s RHS
12792 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12793 // ---> LHS <=s RHS
12794 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12795 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12796 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12797 if (isKnownNonNegative(ShifteeS))
12798 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12799 }
12800
12801 return false;
12802}
12803
12804bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12805 const SCEV *RHS,
12806 const SCEV *FoundLHS,
12807 const SCEV *FoundRHS,
12808 const Instruction *CtxI) {
12809 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12810 FoundRHS) ||
12811 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12812 FoundRHS) ||
12813 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12814 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12815 CtxI) ||
12816 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12817}
12818
12819/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12820template <typename MinMaxExprType>
12821static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12822 const SCEV *Candidate) {
12823 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12824 if (!MinMaxExpr)
12825 return false;
12826
12827 return is_contained(MinMaxExpr->operands(), Candidate);
12828}
12829
// Proves a relational predicate between two affine add recurrences by
// comparing only their start values (see the comment in the body for the
// required conditions). Non-relational predicates return false immediately.
//
// NOTE(review): several listing lines are absent from this extract: 12830
// (the line that opens the signature -- presumably
// `static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,`, judging
// by the call site that passes `*this` first), 12844 (the start of the second
// match, on RHS) and 12847-12850 (which must define `LAR`, `RAR` and `NW`
// used below). Restore them from the upstream file.
12831 CmpPredicate Pred, const SCEV *LHS,
12832 const SCEV *RHS) {
12833 // If both sides are affine addrecs for the same loop, with equal
12834 // steps, and we know the recurrences don't wrap, then we only
12835 // need to check the predicate on the starting values.
12836
12837 if (!ICmpInst::isRelational(Pred))
12838 return false;
12839
12840 const SCEV *LStart, *RStart, *Step;
12841 const Loop *L;
12842 if (!match(LHS,
12843 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12845 m_SpecificLoop(L))))
12846 return false;
12851 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12852 return false;
12853
12854 return SE.isKnownPredicate(Pred, LStart, RStart);
12855}
12856
12857/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12858/// expression?
// Only the four non-strict predicates are handled; SGE and UGE are reduced to
// SLE and ULE by swapping the operands, and every other predicate returns
// false.
//
// NOTE(review): this extract is missing listing line 12859 (the line that
// opens the signature -- presumably
// `static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,`)
// and the operands of both `return` statements (lines 12871, 12873, 12882,
// 12884). Only the comments describing them survive; restore the expressions
// from the upstream file.
12860 const SCEV *LHS, const SCEV *RHS) {
12861 switch (Pred) {
12862 default:
12863 return false;
12864
12865 case ICmpInst::ICMP_SGE:
12866 std::swap(LHS, RHS);
12867 [[fallthrough]];
12868 case ICmpInst::ICMP_SLE:
12869 return
12870 // min(A, ...) <= A
12872 // A <= max(A, ...)
12874
12875 case ICmpInst::ICMP_UGE:
12876 std::swap(LHS, RHS);
12877 [[fallthrough]];
12878 case ICmpInst::ICMP_ULE:
12879 return
12880 // min(A, ...) <= A
12881 // FIXME: what about umin_seq?
12883 // A <= max(A, ...)
12885 }
12886
12887 llvm_unreachable("covered switch fell through?!");
12888}
12889
// Tries to prove `LHS Pred RHS` from `FoundLHS Pred FoundRHS` by looking at
// the structure of LHS. After normalizing to a "greater than" comparison it
// handles two shapes: LHS being a no-signed-wrap add, and LHS being a
// SCEVUnknown that wraps an `sdiv` by a constant. As a final step it defers
// to isImpliedViaMerge. `Depth` is forwarded (incremented) to the recursive
// calls made from here.
//
// NOTE(review): several listing lines are absent from this extract and must
// be restored from the upstream file:
//   12895-12896  start of the first assert (LHS/RHS size check),
//   12902        the condition guarding the early `return false` (per the
//                comment above it, presumably a depth limit),
//   12907        a statement inside the `isLT` block (presumably swapping
//                Pred),
//   12912        the definition of `P` used below,
//   12928        the body of the innermost `if` in the unsigned case,
//   12961        the condition guarding `return false` for mismatched sizes.
12890bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12891 const SCEV *RHS,
12892 const SCEV *FoundLHS,
12893 const SCEV *FoundRHS,
12894 unsigned Depth) {
12897 "LHS and RHS have different sizes?");
12898 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12899 getTypeSizeInBits(FoundRHS->getType()) &&
12900 "FoundLHS and FoundRHS have different sizes?");
12901 // We want to avoid hurting the compile time with analysis of too big trees.
12903 return false;
12904
12905 // We only want to work with GT comparison so far.
12906 if (ICmpInst::isLT(Pred)) {
12908 std::swap(LHS, RHS);
12909 std::swap(FoundLHS, FoundRHS);
12910 }
12911
12913
12914 // For unsigned, try to reduce it to corresponding signed comparison.
12915 if (P == ICmpInst::ICMP_UGT)
12916 // We can replace unsigned predicate with its signed counterpart if all
12917 // involved values are non-negative.
12918 // TODO: We could have better support for unsigned.
12919 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12920 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12921 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12922 // use this fact to prove that LHS and RHS are non-negative.
12923 const SCEV *MinusOne = getMinusOne(LHS->getType());
12924 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12925 FoundRHS) &&
12926 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12927 FoundRHS))
12929 }
12930
12931 if (P != ICmpInst::ICMP_SGT)
12932 return false;
12933
12934 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12935 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12936 return Ext->getOperand();
12937 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12938 // the constant in some cases.
12939 return S;
12940 };
12941
12942 // Acquire values from extensions.
12943 auto *OrigLHS = LHS;
12944 auto *OrigFoundLHS = FoundLHS;
12945 LHS = GetOpFromSExt(LHS);
12946 FoundLHS = GetOpFromSExt(FoundLHS);
12947
12948 // Can the SGT predicate be proved trivially or using the found context?
12949 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12950 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12951 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12952 FoundRHS, Depth + 1);
12953 };
12954
12955 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12956 // We want to avoid creation of any new non-constant SCEV. Since we are
12957 // going to compare the operands to RHS, we should be certain that we don't
12958 // need any size extensions for this. So let's decline all cases when the
12959 // sizes of types of LHS and RHS do not match.
12960 // TODO: Maybe try to get RHS from sext to catch more cases?
12962 return false;
12963
12964 // Should not overflow.
12965 if (!LHSAddExpr->hasNoSignedWrap())
12966 return false;
12967
12968 SCEVUse LL = LHSAddExpr->getOperand(0);
12969 SCEVUse LR = LHSAddExpr->getOperand(1);
12970 auto *MinusOne = getMinusOne(RHS->getType());
12971
12972 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12973 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12974 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12975 };
12976 // Try to prove the following rule:
12977 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12978 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12979 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12980 return true;
12981 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12982 Value *LL, *LR;
12983 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12984
12985 using namespace llvm::PatternMatch;
12986
12987 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12988 // Rules for division.
12989 // We are going to perform some comparisons with Denominator and its
12990 // derivative expressions. In general case, creating a SCEV for it may
12991 // lead to a complex analysis of the entire graph, and in particular it
12992 // can request trip count recalculation for the same loop. This would
12993 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12994 // this, we only want to create SCEVs that are constants in this section.
12995 // So we bail if Denominator is not a constant.
12996 if (!isa<ConstantInt>(LR))
12997 return false;
12998
12999 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13000
13001 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13002 // then a SCEV for the numerator already exists and matches with FoundLHS.
13003 auto *Numerator = getExistingSCEV(LL);
13004 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13005 return false;
13006
13007 // Make sure that the numerator matches with FoundLHS and the denominator
13008 // is positive.
13009 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13010 return false;
13011
13012 auto *DTy = Denominator->getType();
13013 auto *FRHSTy = FoundRHS->getType();
13014 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13015 // One of types is a pointer and another one is not. We cannot extend
13016 // them properly to a wider type, so let us just reject this case.
13017 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13018 // to avoid this check.
13019 return false;
13020
13021 // Given that:
13022 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13023 auto *WTy = getWiderType(DTy, FRHSTy);
13024 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13025 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13026
13027 // Try to prove the following rule:
13028 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13029 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13030 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13031 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13032 if (isKnownNonPositive(RHS) &&
13033 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13034 return true;
13035
13036 // Try to prove the following rule:
13037 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13038 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13039 // If we divide it by Denominator > 2, then:
13040 // 1. If FoundLHS is negative, then the result is 0.
13041 // 2. If FoundLHS is non-negative, then the result is non-negative.
13042 // Anyways, the result is non-negative.
13043 auto *MinusOne = getMinusOne(WTy);
13044 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13045 if (isKnownNegative(RHS) &&
13046 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13047 return true;
13048 }
13049 }
13050
13051 // If our expression contained SCEVUnknown Phis, and we split it down and now
13052 // need to prove something for them, try to prove the predicate for every
13053 // possible incoming values of those Phis.
13054 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13055 return true;
13056
13057 return false;
13058}
13059
// Recognizes comparisons between a sign extension and a zero extension (see
// the idiom stated in the first comment of the body). SGE and UGE are reduced
// to SLE and ULE by swapping the operands; all other predicates return false.
//
// NOTE(review): this extract is missing listing line 13060 (the start of the
// signature -- presumably ScalarEvolution::isKnownPredicateExtendIdiom, which
// is called with (Pred, LHS, RHS) elsewhere in this file) and lines 13071 and
// 13079 (the second operand of each `&&`, presumably the match on RHS).
// Restore them from the upstream file.
13061 const SCEV *RHS) {
13062 // zext x u<= sext x, sext x s<= zext x
13063 const SCEV *Op;
13064 switch (Pred) {
13065 case ICmpInst::ICMP_SGE:
13066 std::swap(LHS, RHS);
13067 [[fallthrough]];
13068 case ICmpInst::ICMP_SLE: {
13069 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13070 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13072 }
13073 case ICmpInst::ICMP_UGE:
13074 std::swap(LHS, RHS);
13075 [[fallthrough]];
13076 case ICmpInst::ICMP_ULE: {
13077 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13078 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13080 }
13081 default:
13082 return false;
13083 };
13084 llvm_unreachable("unhandled case");
13085}
13086
13087bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13088 SCEVUse LHS,
13089 SCEVUse RHS) {
13090 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13091 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13092 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13093 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13094 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13095}
13096
13097bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13098 const SCEV *LHS,
13099 const SCEV *RHS,
13100 const SCEV *FoundLHS,
13101 const SCEV *FoundRHS) {
13102 switch (Pred) {
13103 default:
13104 llvm_unreachable("Unexpected CmpPredicate value!");
13105 case ICmpInst::ICMP_EQ:
13106 case ICmpInst::ICMP_NE:
13107 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13108 return true;
13109 break;
13110 case ICmpInst::ICMP_SLT:
13111 case ICmpInst::ICMP_SLE:
13112 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13113 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13114 return true;
13115 break;
13116 case ICmpInst::ICMP_SGT:
13117 case ICmpInst::ICMP_SGE:
13118 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13119 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13120 return true;
13121 break;
13122 case ICmpInst::ICMP_ULT:
13123 case ICmpInst::ICMP_ULE:
13124 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13125 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13126 return true;
13127 break;
13128 case ICmpInst::ICMP_UGT:
13129 case ICmpInst::ICMP_UGE:
13130 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13131 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13132 return true;
13133 break;
13134 }
13135
13136 // Maybe it can be proved via operations?
13137 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13138 return true;
13139
13140 return false;
13141}
13142
13143bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13144 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13145 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13146 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13147 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13148 // reduce the compile time impact of this optimization.
13149 return false;
13150
13151 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13152 if (!Addend)
13153 return false;
13154
13155 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13156
13157 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13158 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13159 ConstantRange FoundLHSRange =
13160 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13161
13162 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13163 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13164
13165 // We can also compute the range of values for `LHS` that satisfy the
13166 // consequent, "`LHS` `Pred` `RHS`":
13167 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13168 // The antecedent implies the consequent if every value of `LHS` that
13169 // satisfies the antecedent also satisfies the consequent.
13170 return LHSRange.icmp(Pred, ConstRHS);
13171}
13172
13173bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13174 bool IsSigned) {
13175 assert(isKnownPositive(Stride) && "Positive stride expected!");
13176
13177 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13178 const SCEV *One = getOne(Stride->getType());
13179
13180 if (IsSigned) {
13181 APInt MaxRHS = getSignedRangeMax(RHS);
13182 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13183 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13184
13185 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13186 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13187 }
13188
13189 APInt MaxRHS = getUnsignedRangeMax(RHS);
13190 APInt MaxValue = APInt::getMaxValue(BitWidth);
13191 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13192
13193 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13194 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13195}
13196
13197bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13198 bool IsSigned) {
13199
13200 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13201 const SCEV *One = getOne(Stride->getType());
13202
13203 if (IsSigned) {
13204 APInt MinRHS = getSignedRangeMin(RHS);
13205 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13206 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13207
13208 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13209 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13210 }
13211
13212 APInt MinRHS = getUnsignedRangeMin(RHS);
13213 APInt MinValue = APInt::getMinValue(BitWidth);
13214 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13215
13216 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13217 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13218}
13219
13221 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13222 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13223 // expression fixes the case of N=0.
13224 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13225 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13226 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13227}
13228
13229const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13230 const SCEV *Stride,
13231 const SCEV *End,
13232 unsigned BitWidth,
13233 bool IsSigned) {
13234 // The logic in this function assumes we can represent a positive stride.
13235 // If we can't, the backedge-taken count must be zero.
13236 if (IsSigned && BitWidth == 1)
13237 return getZero(Stride->getType());
13238
13239 // This code below only been closely audited for negative strides in the
13240 // unsigned comparison case, it may be correct for signed comparison, but
13241 // that needs to be established.
13242 if (IsSigned && isKnownNegative(Stride))
13243 return getCouldNotCompute();
13244
13245 // Calculate the maximum backedge count based on the range of values
13246 // permitted by Start, End, and Stride.
13247 APInt MinStart =
13248 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13249
13250 APInt MinStride =
13251 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13252
13253 // We assume either the stride is positive, or the backedge-taken count
13254 // is zero. So force StrideForMaxBECount to be at least one.
13255 APInt One(BitWidth, 1);
13256 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13257 : APIntOps::umax(One, MinStride);
13258
13259 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13260 : APInt::getMaxValue(BitWidth);
13261 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13262
13263 // Although End can be a MAX expression we estimate MaxEnd considering only
13264 // the case End = RHS of the loop termination condition. This is safe because
13265 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13266 // taken count.
13267 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13268 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13269
13270 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13271 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13272 : APIntOps::umax(MaxEnd, MinStart);
13273
13274 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13275 getConstant(StrideForMaxBECount) /* Step */);
13276}
13277
// Compute an exit limit for a loop exit controlled by "LHS < RHS" (signed when
// IsSigned is set), where LHS is expected to be an affine add recurrence of
// loop L.
//
// Returns CouldNotCompute when the induction variable cannot be recognised or
// when wrapping cannot be excluded; otherwise returns an ExitLimit built from
// the exact backedge-taken count (when known), a constant maximum and a
// symbolic maximum, together with the collected predicates.
//
// ControlsOnlyExit is required before no-wrap flags are used to reason about
// the loop; AllowPredicates permits turning LHS into an add recurrence under
// runtime predicates.
//
// NOTE(review): `IV`, `Cond` and `Predicates` are used below but the lines
// declaring them (and the line with the return type) are not present in this
// listing -- confirm against the complete file.
13279ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13280 const Loop *L, bool IsSigned,
13281 bool ControlsOnlyExit, bool AllowPredicates) {
13283
13285 bool PredicatedIV = false;
13286 if (!IV) {
// LHS is not directly an add recurrence: look through a zero-extend of an
// affine add recurrence of L and try to prove the narrow recurrence NUW.
13287 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13288 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13289 if (AR && AR->getLoop() == L && AR->isAffine()) {
13290 auto canProveNUW = [&]() {
13291 // We can use the comparison to infer no-wrap flags only if it fully
13292 // controls the loop exit.
13293 if (!ControlsOnlyExit)
13294 return false;
13295
13296 if (!isLoopInvariant(RHS, L))
13297 return false;
13298
13299 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13300 // We need the sequence defined by AR to strictly increase in the
13301 // unsigned integer domain for the logic below to hold.
13302 return false;
13303
13304 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13305 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13306 // If RHS <=u Limit, then there must exist a value V in the sequence
13307 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13308 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13309 // overflow occurs. This limit also implies that a signed comparison
13310 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13311 // the high bits on both sides must be zero.
13312 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13313 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13314 Limit = Limit.zext(OuterBitWidth);
13315 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13316 };
13317 auto Flags = AR->getNoWrapFlags();
13318 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13319 Flags = setFlags(Flags, SCEV::FlagNUW);
13320
13321 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13322 if (AR->hasNoUnsignedWrap()) {
13323 // Emulate what getZeroExtendExpr would have done during construction
13324 // if we'd been able to infer the fact just above at that time.
13325 const SCEV *Step = AR->getStepRecurrence(*this);
13326 Type *Ty = ZExt->getType();
13327 auto *S = getAddRecExpr(
13329 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13331 }
13332 }
13333 }
13334 }
13335
13336
13337 if (!IV && AllowPredicates) {
13338 // Try to make this an AddRec using runtime tests, in the first X
13339 // iterations of this loop, where X is the SCEV expression found by the
13340 // algorithm below.
13341 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13342 PredicatedIV = true;
13343 }
13344
13345 // Avoid weird loops
13346 if (!IV || IV->getLoop() != L || !IV->isAffine())
13347 return getCouldNotCompute();
13348
13349 // A precondition of this method is that the condition being analyzed
13350 // reaches an exiting branch which dominates the latch. Given that, we can
13351 // assume that an increment which violates the nowrap specification and
13352 // produces poison must cause undefined behavior when the resulting poison
13353 // value is branched upon and thus we can conclude that the backedge is
13354 // taken no more often than would be required to produce that poison value.
13355 // Note that a well defined loop can exit on the iteration which violates
13356 // the nowrap specification if there is another exit (either explicit or
13357 // implicit/exceptional) which causes the loop to execute before the
13358 // exiting instruction we're analyzing would trigger UB.
13359 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13360 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13362
13363 const SCEV *Stride = IV->getStepRecurrence(*this);
13364
13365 bool PositiveStride = isKnownPositive(Stride);
13366
13367 // Avoid negative or zero stride values.
13368 if (!PositiveStride) {
13369 // We can compute the correct backedge taken count for loops with unknown
13370 // strides if we can prove that the loop is not an infinite loop with side
13371 // effects. Here's the loop structure we are trying to handle -
13372 //
13373 // i = start
13374 // do {
13375 // A[i] = i;
13376 // i += s;
13377 // } while (i < end);
13378 //
13379 // The backedge taken count for such loops is evaluated as -
13380 // (max(end, start + stride) - start - 1) /u stride
13381 //
13382 // The additional preconditions that we need to check to prove correctness
13383 // of the above formula is as follows -
13384 //
13385 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13386 // NoWrap flag).
13387 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13388 // no side effects within the loop)
13389 // c) loop has a single static exit (with no abnormal exits)
13390 //
13391 // Precondition a) implies that if the stride is negative, this is a single
13392 // trip loop. The backedge taken count formula reduces to zero in this case.
13393 //
13394 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13395 // then a zero stride means the backedge can't be taken without executing
13396 // undefined behavior.
13397 //
13398 // The positive stride case is the same as isKnownPositive(Stride) returning
13399 // true (original behavior of the function).
13400 //
// NOTE(review): the final operand of the condition below is on a line that is
// not present in this listing.
13401 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13403 return getCouldNotCompute();
13404
13405 if (!isKnownNonZero(Stride)) {
13406 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13407 // if it might eventually be greater than start and if so, on which
13408 // iteration. We can't even produce a useful upper bound.
13409 if (!isLoopInvariant(RHS, L))
13410 return getCouldNotCompute();
13411
13412 // We allow a potentially zero stride, but we need to divide by stride
13413 // below. Since the loop can't be infinite and this check must control
13414 // the sole exit, we can infer the exit must be taken on the first
13415 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13416 // we know the numerator in the divides below must be zero, so we can
13417 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13418 // and produce the right result.
13419 // FIXME: Handle the case where Stride is poison?
13420 auto wouldZeroStrideBeUB = [&]() {
13421 // Proof by contradiction. Suppose the stride were zero. If we can
13422 // prove that the backedge *is* taken on the first iteration, then since
13423 // we know this condition controls the sole exit, we must have an
13424 // infinite loop. We can't have a (well defined) infinite loop per
13425 // check just above.
13426 // Note: The (Start - Stride) term is used to get the start' term from
13427 // (start' + stride,+,stride). Remember that we only care about the
13428 // result of this expression when stride == 0 at runtime.
13429 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13430 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13431 };
13432 if (!wouldZeroStrideBeUB()) {
13433 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13434 }
13435 }
13436 } else if (!NoWrap) {
13437 // Avoid proven overflow cases: this will ensure that the backedge taken
13438 // count will not generate any unsigned overflow.
13439 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13440 return getCouldNotCompute();
13441 }
13442
13443 // On all paths just preceding, we established the following invariant:
13444 // IV can be assumed not to overflow up to and including the exiting
13445 // iteration. We proved this in one of two ways:
13446 // 1) We can show overflow doesn't occur before the exiting iteration
13447 // 1a) canIVOverflowOnLT, and b) step of one
13448 // 2) We can show that if overflow occurs, the loop must execute UB
13449 // before any possible exit.
13450 // Note that we have not yet proved RHS invariant (in general).
13451
13452 const SCEV *Start = IV->getStart();
13453
13454 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13455 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13456 // Use integer-typed versions for actual computation; we can't subtract
13457 // pointers in general.
13458 const SCEV *OrigStart = Start;
13459 const SCEV *OrigRHS = RHS;
// NOTE(review): the statements that convert Start and RHS to integers inside
// the two pointer-type branches below are not present in this listing.
13460 if (Start->getType()->isPointerTy()) {
13462 if (isa<SCEVCouldNotCompute>(Start))
13463 return Start;
13464 }
13465 if (RHS->getType()->isPointerTy()) {
13468 return RHS;
13469 }
13470
13471 const SCEV *End = nullptr, *BECount = nullptr,
13472 *BECountIfBackedgeTaken = nullptr;
13473 if (!isLoopInvariant(RHS, L)) {
13474 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13475 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13476 any(RHSAddRec->getNoWrapFlags())) {
13477 // The structure of loop we are trying to calculate backedge count of:
13478 //
13479 // left = left_start
13480 // right = right_start
13481 //
13482 // while(left < right){
13483 // ... do something here ...
13484 // left += s1; // stride of left is s1 (s1 > 0)
13485 // right += s2; // stride of right is s2 (s2 < 0)
13486 // }
13487 //
13488
13489 const SCEV *RHSStart = RHSAddRec->getStart();
13490 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13491
13492 // If Stride - RHSStride is positive and does not overflow, we can write
13493 // backedge count as ->
13494 // ceil((End - Start) /u (Stride - RHSStride))
13495 // Where, End = max(RHSStart, Start)
13496
13497 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13498 if (isKnownNegative(RHSStride) &&
13499 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13500 RHSStride)) {
13501
13502 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13503 if (isKnownPositive(Denominator)) {
13504 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13505 : getUMaxExpr(RHSStart, Start);
13506
13507 // We can do this because End >= Start, as End = max(RHSStart, Start)
13508 const SCEV *Delta = getMinusSCEV(End, Start);
13509
13510 BECount = getUDivCeilSCEV(Delta, Denominator);
13511 BECountIfBackedgeTaken =
13512 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13513 }
13514 }
13515 }
13516 if (BECount == nullptr) {
13517 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13518 // given the start, stride and max value for the end bound of the
13519 // loop (RHS), and the fact that IV does not overflow (which is
13520 // checked above).
13521 const SCEV *MaxBECount = computeMaxBECountForLT(
13522 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13523 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13524 MaxBECount, false /*MaxOrZero*/, Predicates);
13525 }
13526 } else {
13527 // We use the expression (max(End,Start)-Start)/Stride to describe the
13528 // backedge count, as if the backedge is taken at least once
13529 // max(End,Start) is End and so the result is as above, and if not
13530 // max(End,Start) is Start so we get a backedge count of zero.
13531 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13532 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13533 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13534 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13535 // Can we prove (max(RHS,Start) > Start - Stride?
13536 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13537 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13538 // In this case, we can use a refined formula for computing backedge
13539 // taken count. The general formula remains:
13540 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13541 // We want to use the alternate formula:
13542 // "((End - 1) - (Start - Stride)) /u Stride"
13543 // Let's do a quick case analysis to show these are equivalent under
13544 // our precondition that max(RHS,Start) > Start - Stride.
13545 // * For RHS <= Start, the backedge-taken count must be zero.
13546 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13547 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13548 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13549 // of Stride. For 0 stride, we've used umax(1,Stride) above,
13550 // reducing this to the stride of 1 case.
13551 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13552 // Stride".
13553 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13554 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13555 // "((RHS - (Start - Stride) - 1) /u Stride".
13556 // Our preconditions trivially imply no overflow in that form.
13557 const SCEV *MinusOne = getMinusOne(Stride->getType());
13558 const SCEV *Numerator =
13559 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13560 BECount = getUDivExpr(Numerator, Stride);
13561 }
13562
13563 if (!BECount) {
13564 auto canProveRHSGreaterThanEqualStart = [&]() {
13565 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13566 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13567 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13568
13569 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13570 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13571 return true;
13572
13573 // (RHS > Start - 1) implies RHS >= Start.
13574 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13575 // "Start - 1" doesn't overflow.
13576 // * For signed comparison, if Start - 1 does overflow, it's equal
13577 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13578 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13579 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13580 //
13581 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13582 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13583 auto *StartMinusOne =
13584 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13585 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13586 };
13587
13588 // If we know that RHS >= Start in the context of loop, then we know
13589 // that max(RHS, Start) = RHS at this point.
13590 if (canProveRHSGreaterThanEqualStart()) {
13591 End = RHS;
13592 } else {
13593 // If RHS < Start, the backedge will be taken zero times. So in
13594 // general, we can write the backedge-taken count as:
13595 //
13596 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13597 //
13598 // We convert it to the following to make it more convenient for SCEV:
13599 //
13600 // ceil(max(RHS, Start) - Start) / Stride
13601 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13602
13603 // See what would happen if we assume the backedge is taken. This is
13604 // used to compute MaxBECount.
13605 BECountIfBackedgeTaken =
13606 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13607 }
13608
13609 // At this point, we know:
13610 //
13611 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13612 // 2. The index variable doesn't overflow.
13613 //
13614 // Therefore, we know N exists such that
13615 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13616 // doesn't overflow.
13617 //
13618 // Using this information, try to prove whether the addition in
13619 // "(End - Start) + (Stride - 1)" has unsigned overflow.
13620 const SCEV *One = getOne(Stride->getType());
13621 bool MayAddOverflow = [&] {
13622 if (isKnownToBeAPowerOfTwo(Stride)) {
13623 // Suppose Stride is a power of two, and Start/End are unsigned
13624 // integers. Let UMAX be the largest representable unsigned
13625 // integer.
13626 //
13627 // By the preconditions of this function, we know
13628 // "(Start + Stride * N) >= End", and this doesn't overflow.
13629 // As a formula:
13630 //
13631 // End <= (Start + Stride * N) <= UMAX
13632 //
13633 // Subtracting Start from all the terms:
13634 //
13635 // End - Start <= Stride * N <= UMAX - Start
13636 //
13637 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13638 //
13639 // End - Start <= Stride * N <= UMAX
13640 //
13641 // Stride * N is a multiple of Stride. Therefore,
13642 //
13643 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13644 //
13645 // Since Stride is a power of two, UMAX + 1 is divisible by
13646 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13647 // write:
13648 //
13649 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13650 //
13651 // Dropping the middle term:
13652 //
13653 // End - Start <= UMAX - (Stride - 1)
13654 //
13655 // Adding Stride - 1 to both sides:
13656 //
13657 // (End - Start) + (Stride - 1) <= UMAX
13658 //
13659 // In other words, the addition doesn't have unsigned overflow.
13660 //
13661 // A similar proof works if we treat Start/End as signed values.
13662 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13663 // to use signed max instead of unsigned max. Note that we're
13664 // trying to prove a lack of unsigned overflow in either case.
13665 return false;
13666 }
13667 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13668 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13669 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13670 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13671 // 1 <s End.
13672 //
13673 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13674 // End.
13675 return false;
13676 }
13677 return true;
13678 }();
13679
13680 const SCEV *Delta = getMinusSCEV(End, Start);
13681 if (!MayAddOverflow) {
13682 // floor((D + (S - 1)) / S)
13683 // We prefer this formulation if it's legal because it's fewer
13684 // operations.
13685 BECount =
13686 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13687 } else {
13688 BECount = getUDivCeilSCEV(Delta, Stride);
13689 }
13690 }
13691 }
13692
// Pick the constant maximum: the exact count when it is constant, otherwise
// the "if the backedge is taken" count when that is constant (reported as
// max-or-zero), otherwise a purely range-based bound.
13693 const SCEV *ConstantMaxBECount;
13694 bool MaxOrZero = false;
13695 if (isa<SCEVConstant>(BECount)) {
13696 ConstantMaxBECount = BECount;
13697 } else if (BECountIfBackedgeTaken &&
13698 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13699 // If we know exactly how many times the backedge will be taken if it's
13700 // taken at least once, then the backedge count will either be that or
13701 // zero.
13702 ConstantMaxBECount = BECountIfBackedgeTaken;
13703 MaxOrZero = true;
13704 } else {
13705 ConstantMaxBECount = computeMaxBECountForLT(
13706 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13707 }
13708
// Fall back to the unsigned range maximum of the exact count when the
// range-based bound could not be computed.
13709 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13710 !isa<SCEVCouldNotCompute>(BECount))
13711 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13712
13713 const SCEV *SymbolicMaxBECount =
13714 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13715 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13716 Predicates);
13717}
13718
// Compute an exit limit for a loop exit controlled by "LHS > RHS" (signed when
// IsSigned is set), where LHS is expected to be an affine add recurrence of
// loop L and RHS must be invariant in L.
//
// Returns CouldNotCompute when RHS is not loop-invariant, when the induction
// variable cannot be recognised, when the negated step is not known positive,
// or when wrapping cannot be excluded; otherwise returns an ExitLimit holding
// the backedge-taken count together with constant and symbolic maxima.
//
// NOTE(review): `Predicates` and `Cond` are used below but the lines declaring
// them are not present in this listing -- confirm against the complete file.
13719ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13720 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13721 bool ControlsOnlyExit, bool AllowPredicates) {
13723 // We handle only IV > Invariant
13724 if (!isLoopInvariant(RHS, L))
13725 return getCouldNotCompute();
13726
13727 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13728 if (!IV && AllowPredicates)
13729 // Try to make this an AddRec using runtime tests, in the first X
13730 // iterations of this loop, where X is the SCEV expression found by the
13731 // algorithm below.
13732 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13733
13734 // Avoid weird loops
13735 if (!IV || IV->getLoop() != L || !IV->isAffine())
13736 return getCouldNotCompute();
13737
// No-wrap flags are only relied upon when this comparison controls the only
// exit of the loop.
13738 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13739 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13741
// Work with the negated step so that a decreasing induction variable yields a
// positive Stride.
13742 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13743
13744 // Avoid negative or zero stride values
13745 if (!isKnownPositive(Stride))
13746 return getCouldNotCompute();
13747
13748 // Avoid proven overflow cases: this will ensure that the backedge taken count
13749 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13750 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13751 // behaviors like the case of C language.
13752 if (!Stride->isOne() && !NoWrap)
13753 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13754 return getCouldNotCompute();
13755
13756 const SCEV *Start = IV->getStart();
13757 const SCEV *End = RHS;
13758 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13759 // If we know that Start >= RHS in the context of loop, then we know that
13760 // min(RHS, Start) = RHS at this point.
// NOTE(review): the line opening the condition whose arguments follow is not
// present in this listing.
13762 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13763 End = RHS;
13764 else
13765 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13766 }
13767
// Pointer-typed operands are converted to integers before any arithmetic;
// a failed conversion is propagated to the caller as-is.
// NOTE(review): the statement converting Start is not present in this listing.
13768 if (Start->getType()->isPointerTy()) {
13770 if (isa<SCEVCouldNotCompute>(Start))
13771 return Start;
13772 }
13773 if (End->getType()->isPointerTy()) {
13774 End = getLosslessPtrToIntExpr(End);
13775 if (isa<SCEVCouldNotCompute>(End))
13776 return End;
13777 }
13778
13779 // Compute ((Start - End) + (Stride - 1)) / Stride.
13780 // FIXME: This can overflow. Holding off on fixing this for now;
13781 // howManyGreaterThans will hopefully be gone soon.
13782 const SCEV *One = getOne(Stride->getType());
13783 const SCEV *BECount = getUDivExpr(
13784 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13785
// NOTE(review): the unsigned alternative of the conditional below is on a line
// that is not present in this listing.
13786 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13788
13789 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13790 : getUnsignedRangeMin(Stride);
13791
13792 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13793 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13794 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13795
13796 // Although End can be a MIN expression we estimate MinEnd considering only
13797 // the case End = RHS. This is safe because in the other case (Start - End)
13798 // is zero, leading to a zero maximum backedge taken count.
13799 APInt MinEnd =
13800 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13801 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13802
// A constant exact count doubles as the constant maximum; otherwise the bound
// is ceil((MaxStart - MinEnd) / MinStride).
13803 const SCEV *ConstantMaxBECount =
13804 isa<SCEVConstant>(BECount)
13805 ? BECount
13806 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13807 getConstant(MinStride));
13808
13809 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13810 ConstantMaxBECount = BECount;
13811 const SCEV *SymbolicMaxBECount =
13812 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13813
13814 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13815 Predicates);
13816}
13817
// Compute the iteration number at which the value of this recurrence first
// falls outside Range, returned as a constant SCEV, or CouldNotCompute when
// that cannot be determined (full range, non-constant operands, wrap-around,
// or a recurrence that is neither affine nor quadratic).
//
// NOTE(review): the first line of this definition's signature is not present
// in this listing; only its trailing parameter is visible below.
13819 ScalarEvolution &SE) const {
13820 if (Range.isFullSet()) // Infinite loop.
13821 return SE.getCouldNotCompute();
13822
13823 // If the start is a non-zero constant, shift the range to simplify things.
13824 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13825 if (!SC->getValue()->isZero()) {
// Rebuild the recurrence with a zero start and recurse on the range shifted
// by the original start value.
// NOTE(review): the declaration of `Operands` and the tail of the
// getAddRecExpr call are not present in this listing.
13827 Operands[0] = SE.getZero(SC->getType());
13828 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13830 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13831 return ShiftedAddRec->getNumIterationsInRange(
13832 Range.subtract(SC->getAPInt()), SE);
13833 // This is strange and shouldn't happen.
13834 return SE.getCouldNotCompute();
13835 }
13836
13837 // The only time we can solve this is when we have all constant indices.
13838 // Otherwise, we cannot determine the overflow conditions.
13839 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13840 return SE.getCouldNotCompute();
13841
13842 // Okay at this point we know that all elements of the chrec are constants and
13843 // that the start element is zero.
13844
13845 // First check to see if the range contains zero. If not, the first
13846 // iteration exits.
13847 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13848 if (!Range.contains(APInt(BitWidth, 0)))
13849 return SE.getZero(getType());
13850
13851 if (isAffine()) {
13852 // If this is an affine expression then we have this situation:
13853 // Solve {0,+,A} in Range === Ax in Range
13854
13855 // We know that zero is in the range. If A is positive then we know that
13856 // the upper value of the range must be the first possible exit value.
13857 // If A is negative then the lower of the range is the last possible loop
13858 // value. Also note that we already checked for a full range.
13859 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13860 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13861
13862 // The exit value should be (End+A)/A.
13863 APInt ExitVal = (End + A).udiv(A);
13864 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13865
13866 // Evaluate at the exit value. If we really did fall out of the valid
13867 // range, then we computed our trip count, otherwise wrap around or other
13868 // things must have happened.
13869 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13870 if (Range.contains(Val->getValue()))
13871 return SE.getCouldNotCompute(); // Something strange happened
13872
13873 // Ensure that the previous value is in the range.
// NOTE(review): the middle line of this assert expression is not present in
// this listing.
13874 assert(Range.contains(
13876 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13877 "Linear scev computation is off in a bad way!");
13878 return SE.getConstant(ExitValue);
13879 }
13880
// Quadratic recurrences are delegated to the dedicated solver; its result is
// used only when it produced a value.
13881 if (isQuadratic()) {
13882 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13883 return SE.getConstant(*S);
13884 }
13885
13886 return SE.getCouldNotCompute();
13887}
13888
// Build the post-increment form of this recurrence, i.e. (this + step), by
// explicitly summing adjacent operands rather than going through the generic
// add folder.
//
// NOTE(review): the signature line, the declaration of `Ops` and the final
// return statement are not present in this listing.
13889const SCEVAddRecExpr *
13891 assert(getNumOperands() > 1 && "AddRec with zero step?");
13892 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13893 // but in this case we cannot guarantee that the value returned will be an
13894 // AddRec because SCEV does not have a fixed point where it stops
13895 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13896 // may happen if we reach arithmetic depth limit while simplifying. So we
13897 // construct the returned value explicitly.
13899 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13900 // (this + Step) is {A+B,+,B+C,+...,+,N}.
// Every operand except the last is replaced by the sum of itself and its
// successor.
13901 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13902 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13903 // We know that the last operand is not a constant zero (otherwise it would
13904 // have been popped out earlier). This guarantees us that if the result has
13905 // the same last operand, then it will also not be popped out, meaning that
13906 // the returned value will be an AddRec.
13907 const SCEV *Last = getOperand(getNumOperands() - 1);
13908 assert(!Last->isZero() && "Recurrency with zero step?");
13909 Ops.push_back(Last);
13912}
13913
13914// Return true when S contains at least an undef value.
// The expression tree of S is searched with SCEVExprContains and the predicate
// is applied to each visited sub-expression.
// NOTE(review): the matcher name suggests poison values are matched as well as
// undef -- confirm against the matcher's definition. The signature line of
// this definition is not present in this listing.
13916 return SCEVExprContains(
13917 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13918}
13919
13920// Return true when S contains a value that is a nullptr.
// Only SCEVUnknown sub-expressions are inspected: one whose underlying value
// pointer is null makes the whole query true; every other node kind is
// ignored.
// NOTE(review): the signature line of this definition is not present in this
// listing.
13922 return SCEVExprContains(S, [](const SCEV *S) {
13923 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13924 return SU->getValue() == nullptr;
13925 return false;
13926 });
13927}
13928
13929/// Return the size of an element read or written by Inst.
///
/// The element type is the stored value's type for a store and the result type
/// for a load. Any other instruction yields a null pointer.
// NOTE(review): the signature line and the declaration of `ETy` are not
// present in this listing.
13931 Type *Ty;
13932 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13933 Ty = Store->getValueOperand()->getType();
13934 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13935 Ty = Load->getType();
13936 else
13937 return nullptr;
13938
13940 return getSizeOfExpr(ETy, Ty);
13941}
13942
13943//===----------------------------------------------------------------------===//
13944// SCEVCallbackVH Class Implementation
13945//===----------------------------------------------------------------------===//
13946
// Drops everything cached for the tracked value: a PHI node is removed from
// the constant-evolution exit-value cache first, then the value is erased
// from the value map.
// NOTE(review): the signature line of this definition is not present in this
// listing; presumably this is the handle's value-deleted callback -- confirm.
13948 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13949 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13950 SE->ConstantEvolutionLoopExitValue.erase(PN);
13951 SE->eraseValueFromMap(getValPtr());
13952 // this now dangles!
13953}
13954
13955void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13956 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13957
13958 // Forget all the expressions associated with users of the old value,
13959 // so that future queries will recompute the expressions using the new
13960 // value.
13961 SE->forgetValue(getValPtr());
13962 // this now dangles!
13963}
13964
13965ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13966 : CallbackVH(V), SE(se) {}
13967
13968//===----------------------------------------------------------------------===//
13969// ScalarEvolution Class Implementation
13970//===----------------------------------------------------------------------===//
13971
13974 LoopInfo &LI)
13975 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13976 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13977 LoopDispositions(64), BlockDispositions(64) {
13978 // To use guards for proving predicates, we need to scan every instruction in
13979 // relevant basic blocks, and not just terminators. Doing this is a waste of
13980 // time if the IR does not actually contain any calls to
13981 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13982 //
13983 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13984 // to _add_ guards to the module when there weren't any before, and wants
13985 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13986 // efficient in lieu of being smart in that rather obscure case.
13987
13988 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13989 F.getParent(), Intrinsic::experimental_guard);
13990 HasGuards = GuardDecl && !GuardDecl->use_empty();
13991}
13992
13994 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13995 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13996 ValueExprMap(std::move(Arg.ValueExprMap)),
13997 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13998 PendingMerges(std::move(Arg.PendingMerges)),
13999 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14000 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14001 PredicatedBackedgeTakenCounts(
14002 std::move(Arg.PredicatedBackedgeTakenCounts)),
14003 BECountUsers(std::move(Arg.BECountUsers)),
14004 ConstantEvolutionLoopExitValue(
14005 std::move(Arg.ConstantEvolutionLoopExitValue)),
14006 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14007 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14008 LoopDispositions(std::move(Arg.LoopDispositions)),
14009 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14010 BlockDispositions(std::move(Arg.BlockDispositions)),
14011 SCEVUsers(std::move(Arg.SCEVUsers)),
14012 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14013 SignedRanges(std::move(Arg.SignedRanges)),
14014 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14015 UniquePreds(std::move(Arg.UniquePreds)),
14016 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14017 LoopUsers(std::move(Arg.LoopUsers)),
14018 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14019 FirstUnknown(Arg.FirstUnknown) {
14020 Arg.FirstUnknown = nullptr;
14021}
14022
14024 // Iterate through all the SCEVUnknown instances and call their
14025 // destructors, so that they release their references to their values.
14026 for (SCEVUnknown *U = FirstUnknown; U;) {
14027 SCEVUnknown *Tmp = U;
14028 U = U->Next;
14029 Tmp->~SCEVUnknown();
14030 }
14031 FirstUnknown = nullptr;
14032
14033 ExprValueMap.clear();
14034 ValueExprMap.clear();
14035 HasRecMap.clear();
14036 BackedgeTakenCounts.clear();
14037 PredicatedBackedgeTakenCounts.clear();
14038
14039 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14040 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14041 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14042 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14043}
14044
14048
14049/// When printing a top-level SCEV for trip counts, it's helpful to include
14050/// a type for constants which are otherwise hard to disambiguate.
14051static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14052 if (isa<SCEVConstant>(S))
14053 OS << *S->getType() << " ";
14054 OS << *S;
14055}
14056
14058 const Loop *L) {
14059 // Print all inner loops first
14060 for (Loop *I : *L)
14061 PrintLoopInfo(OS, SE, I);
14062
14063 OS << "Loop ";
14064 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14065 OS << ": ";
14066
14067 SmallVector<BasicBlock *, 8> ExitingBlocks;
14068 L->getExitingBlocks(ExitingBlocks);
14069 if (ExitingBlocks.size() != 1)
14070 OS << "<multiple exits> ";
14071
14072 auto *BTC = SE->getBackedgeTakenCount(L);
14073 if (!isa<SCEVCouldNotCompute>(BTC)) {
14074 OS << "backedge-taken count is ";
14075 PrintSCEVWithTypeHint(OS, BTC);
14076 } else
14077 OS << "Unpredictable backedge-taken count.";
14078 OS << "\n";
14079
14080 if (ExitingBlocks.size() > 1)
14081 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14082 OS << " exit count for " << ExitingBlock->getName() << ": ";
14083 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14084 PrintSCEVWithTypeHint(OS, EC);
14085 if (isa<SCEVCouldNotCompute>(EC)) {
14086 // Retry with predicates.
14088 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14089 if (!isa<SCEVCouldNotCompute>(EC)) {
14090 OS << "\n predicated exit count for " << ExitingBlock->getName()
14091 << ": ";
14092 PrintSCEVWithTypeHint(OS, EC);
14093 OS << "\n Predicates:\n";
14094 for (const auto *P : Predicates)
14095 P->print(OS, 4);
14096 }
14097 }
14098 OS << "\n";
14099 }
14100
14101 OS << "Loop ";
14102 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14103 OS << ": ";
14104
14105 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14106 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14107 OS << "constant max backedge-taken count is ";
14108 PrintSCEVWithTypeHint(OS, ConstantBTC);
14110 OS << ", actual taken count either this or zero.";
14111 } else {
14112 OS << "Unpredictable constant max backedge-taken count. ";
14113 }
14114
14115 OS << "\n"
14116 "Loop ";
14117 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14118 OS << ": ";
14119
14120 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14121 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14122 OS << "symbolic max backedge-taken count is ";
14123 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14125 OS << ", actual taken count either this or zero.";
14126 } else {
14127 OS << "Unpredictable symbolic max backedge-taken count. ";
14128 }
14129 OS << "\n";
14130
14131 if (ExitingBlocks.size() > 1)
14132 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14133 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14134 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14136 PrintSCEVWithTypeHint(OS, ExitBTC);
14137 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14138 // Retry with predicates.
14140 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14142 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14143 OS << "\n predicated symbolic max exit count for "
14144 << ExitingBlock->getName() << ": ";
14145 PrintSCEVWithTypeHint(OS, ExitBTC);
14146 OS << "\n Predicates:\n";
14147 for (const auto *P : Predicates)
14148 P->print(OS, 4);
14149 }
14150 }
14151 OS << "\n";
14152 }
14153
14155 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14156 if (PBT != BTC) {
14157 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14158 OS << "Loop ";
14159 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14160 OS << ": ";
14161 if (!isa<SCEVCouldNotCompute>(PBT)) {
14162 OS << "Predicated backedge-taken count is ";
14163 PrintSCEVWithTypeHint(OS, PBT);
14164 } else
14165 OS << "Unpredictable predicated backedge-taken count.";
14166 OS << "\n";
14167 OS << " Predicates:\n";
14168 for (const auto *P : Preds)
14169 P->print(OS, 4);
14170 }
14171 Preds.clear();
14172
14173 auto *PredConstantMax =
14175 if (PredConstantMax != ConstantBTC) {
14176 assert(!Preds.empty() &&
14177 "different predicated constant max BTC but no predicates");
14178 OS << "Loop ";
14179 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14180 OS << ": ";
14181 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14182 OS << "Predicated constant max backedge-taken count is ";
14183 PrintSCEVWithTypeHint(OS, PredConstantMax);
14184 } else
14185 OS << "Unpredictable predicated constant max backedge-taken count.";
14186 OS << "\n";
14187 OS << " Predicates:\n";
14188 for (const auto *P : Preds)
14189 P->print(OS, 4);
14190 }
14191 Preds.clear();
14192
14193 auto *PredSymbolicMax =
14195 if (SymbolicBTC != PredSymbolicMax) {
14196 assert(!Preds.empty() &&
14197 "Different predicated symbolic max BTC, but no predicates");
14198 OS << "Loop ";
14199 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14200 OS << ": ";
14201 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14202 OS << "Predicated symbolic max backedge-taken count is ";
14203 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14204 } else
14205 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14206 OS << "\n";
14207 OS << " Predicates:\n";
14208 for (const auto *P : Preds)
14209 P->print(OS, 4);
14210 }
14211
14213 OS << "Loop ";
14214 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14215 OS << ": ";
14216 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14217 }
14218}
14219
14220namespace llvm {
14221// Note: these overloaded operators need to be in the llvm namespace for them
14222// to be resolved correctly. If we put them outside the llvm namespace, the
14223//
14224// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14225//
14226// code below "breaks" and starts printing raw enum values as opposed to the
14227// string values.
14230 switch (LD) {
14232 OS << "Variant";
14233 break;
14235 OS << "Invariant";
14236 break;
14238 OS << "Computable";
14239 break;
14240 }
14241 return OS;
14242}
14243
14246 switch (BD) {
14248 OS << "DoesNotDominate";
14249 break;
14251 OS << "Dominates";
14252 break;
14254 OS << "ProperlyDominates";
14255 break;
14256 }
14257 return OS;
14258}
14259} // namespace llvm
14260
14262 // ScalarEvolution's implementation of the print method is to print
14263 // out SCEV values of all instructions that are interesting. Doing
14264 // this potentially causes it to create new SCEV objects though,
14265 // which technically conflicts with the const qualifier. This isn't
14266 // observable from outside the class though, so casting away the
14267 // const isn't dangerous.
14268 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14269
14270 if (ClassifyExpressions) {
14271 OS << "Classifying expressions for: ";
14272 F.printAsOperand(OS, /*PrintType=*/false);
14273 OS << "\n";
14274 for (Instruction &I : instructions(F))
14275 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14276 OS << I << '\n';
14277 OS << " --> ";
14278 const SCEV *SV = SE.getSCEV(&I);
14279 SV->print(OS);
14280 if (!isa<SCEVCouldNotCompute>(SV)) {
14281 OS << " U: ";
14282 SE.getUnsignedRange(SV).print(OS);
14283 OS << " S: ";
14284 SE.getSignedRange(SV).print(OS);
14285 }
14286
14287 const Loop *L = LI.getLoopFor(I.getParent());
14288
14289 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14290 if (AtUse != SV) {
14291 OS << " --> ";
14292 AtUse->print(OS);
14293 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14294 OS << " U: ";
14295 SE.getUnsignedRange(AtUse).print(OS);
14296 OS << " S: ";
14297 SE.getSignedRange(AtUse).print(OS);
14298 }
14299 }
14300
14301 if (L) {
14302 OS << "\t\t" "Exits: ";
14303 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14304 if (!SE.isLoopInvariant(ExitValue, L)) {
14305 OS << "<<Unknown>>";
14306 } else {
14307 OS << *ExitValue;
14308 }
14309
14310 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14311 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14312 OS << LS;
14313 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14314 OS << ": " << SE.getLoopDisposition(SV, Iter);
14315 }
14316
14317 for (const auto *InnerL : depth_first(L)) {
14318 if (InnerL == L)
14319 continue;
14320 OS << LS;
14321 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14322 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14323 }
14324
14325 OS << " }";
14326 }
14327
14328 OS << "\n";
14329 }
14330 }
14331
14332 OS << "Determining loop execution counts for: ";
14333 F.printAsOperand(OS, /*PrintType=*/false);
14334 OS << "\n";
14335 for (Loop *I : LI)
14336 PrintLoopInfo(OS, &SE, I);
14337}
14338
14341 auto &Values = LoopDispositions[S];
14342 for (auto &V : Values) {
14343 if (V.getPointer() == L)
14344 return V.getInt();
14345 }
14346 Values.emplace_back(L, LoopVariant);
14347 LoopDisposition D = computeLoopDisposition(S, L);
14348 auto &Values2 = LoopDispositions[S];
14349 for (auto &V : llvm::reverse(Values2)) {
14350 if (V.getPointer() == L) {
14351 V.setInt(D);
14352 break;
14353 }
14354 }
14355 return D;
14356}
14357
14359ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14360 switch (S->getSCEVType()) {
14361 case scConstant:
14362 case scVScale:
14363 return LoopInvariant;
14364 case scAddRecExpr: {
14365 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14366
14367 // If L is the addrec's loop, it's computable.
14368 if (AR->getLoop() == L)
14369 return LoopComputable;
14370
14371 // Add recurrences are never invariant in the function-body (null loop).
14372 if (!L)
14373 return LoopVariant;
14374
14375 // Everything that is not defined at loop entry is variant.
14376 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14377 return LoopVariant;
14378 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14379 " dominate the contained loop's header?");
14380
14381 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14382 if (AR->getLoop()->contains(L))
14383 return LoopInvariant;
14384
14385 // This recurrence is variant w.r.t. L if any of its operands
14386 // are variant.
14387 for (SCEVUse Op : AR->operands())
14388 if (!isLoopInvariant(Op, L))
14389 return LoopVariant;
14390
14391 // Otherwise it's loop-invariant.
14392 return LoopInvariant;
14393 }
14394 case scTruncate:
14395 case scZeroExtend:
14396 case scSignExtend:
14397 case scPtrToAddr:
14398 case scPtrToInt:
14399 case scAddExpr:
14400 case scMulExpr:
14401 case scUDivExpr:
14402 case scUMaxExpr:
14403 case scSMaxExpr:
14404 case scUMinExpr:
14405 case scSMinExpr:
14406 case scSequentialUMinExpr: {
14407 bool HasVarying = false;
14408 for (SCEVUse Op : S->operands()) {
14410 if (D == LoopVariant)
14411 return LoopVariant;
14412 if (D == LoopComputable)
14413 HasVarying = true;
14414 }
14415 return HasVarying ? LoopComputable : LoopInvariant;
14416 }
14417 case scUnknown:
14418 // All non-instruction values are loop invariant. All instructions are loop
14419 // invariant if they are not contained in the specified loop.
14420 // Instructions are never considered invariant in the function body
14421 // (null loop) because they are defined within the "loop".
14422 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14423 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14424 return LoopInvariant;
14425 case scCouldNotCompute:
14426 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14427 }
14428 llvm_unreachable("Unknown SCEV kind!");
14429}
14430
14432 return getLoopDisposition(S, L) == LoopInvariant;
14433}
14434
14436 return getLoopDisposition(S, L) == LoopComputable;
14437}
14438
14441 auto &Values = BlockDispositions[S];
14442 for (auto &V : Values) {
14443 if (V.getPointer() == BB)
14444 return V.getInt();
14445 }
14446 Values.emplace_back(BB, DoesNotDominateBlock);
14447 BlockDisposition D = computeBlockDisposition(S, BB);
14448 auto &Values2 = BlockDispositions[S];
14449 for (auto &V : llvm::reverse(Values2)) {
14450 if (V.getPointer() == BB) {
14451 V.setInt(D);
14452 break;
14453 }
14454 }
14455 return D;
14456}
14457
14459ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14460 switch (S->getSCEVType()) {
14461 case scConstant:
14462 case scVScale:
14464 case scAddRecExpr: {
14465 // This uses a "dominates" query instead of "properly dominates" query
14466 // to test for proper dominance too, because the instruction which
14467 // produces the addrec's value is a PHI, and a PHI effectively properly
14468 // dominates its entire containing block.
14469 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14470 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14471 return DoesNotDominateBlock;
14472
14473 // Fall through into SCEVNAryExpr handling.
14474 [[fallthrough]];
14475 }
14476 case scTruncate:
14477 case scZeroExtend:
14478 case scSignExtend:
14479 case scPtrToAddr:
14480 case scPtrToInt:
14481 case scAddExpr:
14482 case scMulExpr:
14483 case scUDivExpr:
14484 case scUMaxExpr:
14485 case scSMaxExpr:
14486 case scUMinExpr:
14487 case scSMinExpr:
14488 case scSequentialUMinExpr: {
14489 bool Proper = true;
14490 for (const SCEV *NAryOp : S->operands()) {
14492 if (D == DoesNotDominateBlock)
14493 return DoesNotDominateBlock;
14494 if (D == DominatesBlock)
14495 Proper = false;
14496 }
14497 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14498 }
14499 case scUnknown:
14500 if (Instruction *I =
14501 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14502 if (I->getParent() == BB)
14503 return DominatesBlock;
14504 if (DT.properlyDominates(I->getParent(), BB))
14506 return DoesNotDominateBlock;
14507 }
14509 case scCouldNotCompute:
14510 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14511 }
14512 llvm_unreachable("Unknown SCEV kind!");
14513}
14514
14515bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14516 return getBlockDisposition(S, BB) >= DominatesBlock;
14517}
14518
14521}
14522
14523bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14524 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14525}
14526
14527void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14528 bool Predicated) {
14529 auto &BECounts =
14530 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14531 auto It = BECounts.find(L);
14532 if (It != BECounts.end()) {
14533 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14534 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14535 if (!isa<SCEVConstant>(S)) {
14536 auto UserIt = BECountUsers.find(S);
14537 assert(UserIt != BECountUsers.end());
14538 UserIt->second.erase({L, Predicated});
14539 }
14540 }
14541 }
14542 BECounts.erase(It);
14543 }
14544}
14545
14546void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14547 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14548 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14549
14550 while (!Worklist.empty()) {
14551 const SCEV *Curr = Worklist.pop_back_val();
14552 auto Users = SCEVUsers.find(Curr);
14553 if (Users != SCEVUsers.end())
14554 for (const auto *User : Users->second)
14555 if (ToForget.insert(User).second)
14556 Worklist.push_back(User);
14557 }
14558
14559 for (const auto *S : ToForget)
14560 forgetMemoizedResultsImpl(S);
14561
14562 for (auto I = PredicatedSCEVRewrites.begin();
14563 I != PredicatedSCEVRewrites.end();) {
14564 std::pair<const SCEV *, const Loop *> Entry = I->first;
14565 if (ToForget.count(Entry.first))
14566 PredicatedSCEVRewrites.erase(I++);
14567 else
14568 ++I;
14569 }
14570}
14571
14572void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14573 LoopDispositions.erase(S);
14574 BlockDispositions.erase(S);
14575 UnsignedRanges.erase(S);
14576 SignedRanges.erase(S);
14577 HasRecMap.erase(S);
14578 ConstantMultipleCache.erase(S);
14579
14580 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14581 UnsignedWrapViaInductionTried.erase(AR);
14582 SignedWrapViaInductionTried.erase(AR);
14583 }
14584
14585 auto ExprIt = ExprValueMap.find(S);
14586 if (ExprIt != ExprValueMap.end()) {
14587 for (Value *V : ExprIt->second) {
14588 auto ValueIt = ValueExprMap.find_as(V);
14589 if (ValueIt != ValueExprMap.end())
14590 ValueExprMap.erase(ValueIt);
14591 }
14592 ExprValueMap.erase(ExprIt);
14593 }
14594
14595 auto ScopeIt = ValuesAtScopes.find(S);
14596 if (ScopeIt != ValuesAtScopes.end()) {
14597 for (const auto &Pair : ScopeIt->second)
14598 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14599 llvm::erase(ValuesAtScopesUsers[Pair.second],
14600 std::make_pair(Pair.first, S));
14601 ValuesAtScopes.erase(ScopeIt);
14602 }
14603
14604 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14605 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14606 for (const auto &Pair : ScopeUserIt->second)
14607 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14608 ValuesAtScopesUsers.erase(ScopeUserIt);
14609 }
14610
14611 auto BEUsersIt = BECountUsers.find(S);
14612 if (BEUsersIt != BECountUsers.end()) {
14613 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14614 auto Copy = BEUsersIt->second;
14615 for (const auto &Pair : Copy)
14616 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14617 BECountUsers.erase(BEUsersIt);
14618 }
14619
14620 auto FoldUser = FoldCacheUser.find(S);
14621 if (FoldUser != FoldCacheUser.end())
14622 for (auto &KV : FoldUser->second)
14623 FoldCache.erase(KV);
14624 FoldCacheUser.erase(S);
14625}
14626
14627void
14628ScalarEvolution::getUsedLoops(const SCEV *S,
14629 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14630 struct FindUsedLoops {
14631 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14632 : LoopsUsed(LoopsUsed) {}
14633 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14634 bool follow(const SCEV *S) {
14635 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14636 LoopsUsed.insert(AR->getLoop());
14637 return true;
14638 }
14639
14640 bool isDone() const { return false; }
14641 };
14642
14643 FindUsedLoops F(LoopsUsed);
14644 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14645}
14646
14647void ScalarEvolution::getReachableBlocks(
14650 Worklist.push_back(&F.getEntryBlock());
14651 while (!Worklist.empty()) {
14652 BasicBlock *BB = Worklist.pop_back_val();
14653 if (!Reachable.insert(BB).second)
14654 continue;
14655
14656 Value *Cond;
14657 BasicBlock *TrueBB, *FalseBB;
14658 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14659 m_BasicBlock(FalseBB)))) {
14660 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14661 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14662 continue;
14663 }
14664
14665 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14666 const SCEV *L = getSCEV(Cmp->getOperand(0));
14667 const SCEV *R = getSCEV(Cmp->getOperand(1));
14668 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14669 Worklist.push_back(TrueBB);
14670 continue;
14671 }
14672 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14673 R)) {
14674 Worklist.push_back(FalseBB);
14675 continue;
14676 }
14677 }
14678 }
14679
14680 append_range(Worklist, successors(BB));
14681 }
14682}
14683
14685 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14686 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14687
14688 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14689
14690 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14691 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14692 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14693
14694 const SCEV *visitConstant(const SCEVConstant *Constant) {
14695 return SE.getConstant(Constant->getAPInt());
14696 }
14697
14698 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14699 return SE.getUnknown(Expr->getValue());
14700 }
14701
14702 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14703 return SE.getCouldNotCompute();
14704 }
14705 };
14706
14707 SCEVMapper SCM(SE2);
14708 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14709 SE2.getReachableBlocks(ReachableBlocks, F);
14710
14711 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14712 if (containsUndefs(Old) || containsUndefs(New)) {
14713 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14714 // not propagate undef aggressively). This means we can (and do) fail
14715 // verification in cases where a transform makes a value go from "undef"
14716 // to "undef+1" (say). The transform is fine, since in both cases the
14717 // result is "undef", but SCEV thinks the value increased by 1.
14718 return nullptr;
14719 }
14720
14721 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14722 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14723 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14724 return nullptr;
14725
14726 return Delta;
14727 };
14728
14729 while (!LoopStack.empty()) {
14730 auto *L = LoopStack.pop_back_val();
14731 llvm::append_range(LoopStack, *L);
14732
14733 // Only verify BECounts in reachable loops. For an unreachable loop,
14734 // any BECount is legal.
14735 if (!ReachableBlocks.contains(L->getHeader()))
14736 continue;
14737
14738 // Only verify cached BECounts. Computing new BECounts may change the
14739 // results of subsequent SCEV uses.
14740 auto It = BackedgeTakenCounts.find(L);
14741 if (It == BackedgeTakenCounts.end())
14742 continue;
14743
14744 auto *CurBECount =
14745 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14746 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14747
14748 if (CurBECount == SE2.getCouldNotCompute() ||
14749 NewBECount == SE2.getCouldNotCompute()) {
14750 // NB! This situation is legal, but is very suspicious -- whatever pass
14751 // changed the loop to make a trip count go from could not compute to
14752 // computable or vice-versa *should have* invalidated SCEV. However, we
14753 // choose not to assert here (for now) since we don't want false
14754 // positives.
14755 continue;
14756 }
14757
14758 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14759 SE.getTypeSizeInBits(NewBECount->getType()))
14760 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14761 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14762 SE.getTypeSizeInBits(NewBECount->getType()))
14763 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14764
14765 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14766 if (Delta && !Delta->isZero()) {
14767 dbgs() << "Trip Count for " << *L << " Changed!\n";
14768 dbgs() << "Old: " << *CurBECount << "\n";
14769 dbgs() << "New: " << *NewBECount << "\n";
14770 dbgs() << "Delta: " << *Delta << "\n";
14771 std::abort();
14772 }
14773 }
14774
14775 // Collect all valid loops currently in LoopInfo.
14776 SmallPtrSet<Loop *, 32> ValidLoops;
14777 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14778 while (!Worklist.empty()) {
14779 Loop *L = Worklist.pop_back_val();
14780 if (ValidLoops.insert(L).second)
14781 Worklist.append(L->begin(), L->end());
14782 }
14783 for (const auto &KV : ValueExprMap) {
14784#ifndef NDEBUG
14785 // Check for SCEV expressions referencing invalid/deleted loops.
14786 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14787 assert(ValidLoops.contains(AR->getLoop()) &&
14788 "AddRec references invalid loop");
14789 }
14790#endif
14791
14792 // Check that the value is also part of the reverse map.
14793 auto It = ExprValueMap.find(KV.second);
14794 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14795 dbgs() << "Value " << *KV.first
14796 << " is in ValueExprMap but not in ExprValueMap\n";
14797 std::abort();
14798 }
14799
14800 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14801 if (!ReachableBlocks.contains(I->getParent()))
14802 continue;
14803 const SCEV *OldSCEV = SCM.visit(KV.second);
14804 const SCEV *NewSCEV = SE2.getSCEV(I);
14805 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14806 if (Delta && !Delta->isZero()) {
14807 dbgs() << "SCEV for value " << *I << " changed!\n"
14808 << "Old: " << *OldSCEV << "\n"
14809 << "New: " << *NewSCEV << "\n"
14810 << "Delta: " << *Delta << "\n";
14811 std::abort();
14812 }
14813 }
14814 }
14815
14816 for (const auto &KV : ExprValueMap) {
14817 for (Value *V : KV.second) {
14818 const SCEV *S = ValueExprMap.lookup(V);
14819 if (!S) {
14820 dbgs() << "Value " << *V
14821 << " is in ExprValueMap but not in ValueExprMap\n";
14822 std::abort();
14823 }
14824 if (S != KV.first) {
14825 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14826 << *KV.first << "\n";
14827 std::abort();
14828 }
14829 }
14830 }
14831
14832 // Verify integrity of SCEV users.
14833 for (const auto &S : UniqueSCEVs) {
14834 for (SCEVUse Op : S.operands()) {
14835 // We do not store dependencies of constants.
14836 if (isa<SCEVConstant>(Op))
14837 continue;
14838 auto It = SCEVUsers.find(Op);
14839 if (It != SCEVUsers.end() && It->second.count(&S))
14840 continue;
14841 dbgs() << "Use of operand " << *Op << " by user " << S
14842 << " is not being tracked!\n";
14843 std::abort();
14844 }
14845 }
14846
14847 // Verify integrity of ValuesAtScopes users.
14848 for (const auto &ValueAndVec : ValuesAtScopes) {
14849 const SCEV *Value = ValueAndVec.first;
14850 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14851 const Loop *L = LoopAndValueAtScope.first;
14852 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14853 if (!isa<SCEVConstant>(ValueAtScope)) {
14854 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14855 if (It != ValuesAtScopesUsers.end() &&
14856 is_contained(It->second, std::make_pair(L, Value)))
14857 continue;
14858 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14859 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14860 std::abort();
14861 }
14862 }
14863 }
14864
14865 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14866 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14867 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14868 const Loop *L = LoopAndValue.first;
14869 const SCEV *Value = LoopAndValue.second;
14871 auto It = ValuesAtScopes.find(Value);
14872 if (It != ValuesAtScopes.end() &&
14873 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14874 continue;
14875 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14876 << *ValueAtScope << " missing in ValuesAtScopes\n";
14877 std::abort();
14878 }
14879 }
14880
14881 // Verify integrity of BECountUsers.
14882 auto VerifyBECountUsers = [&](bool Predicated) {
14883 auto &BECounts =
14884 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14885 for (const auto &LoopAndBEInfo : BECounts) {
14886 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14887 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14888 if (!isa<SCEVConstant>(S)) {
14889 auto UserIt = BECountUsers.find(S);
14890 if (UserIt != BECountUsers.end() &&
14891 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14892 continue;
14893 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14894 << " missing from BECountUsers\n";
14895 std::abort();
14896 }
14897 }
14898 }
14899 }
14900 };
14901 VerifyBECountUsers(/* Predicated */ false);
14902 VerifyBECountUsers(/* Predicated */ true);
14903
14904 // Verify integrity of loop disposition cache.
14905 for (auto &[S, Values] : LoopDispositions) {
14906 for (auto [Loop, CachedDisposition] : Values) {
14907 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14908 if (CachedDisposition != RecomputedDisposition) {
14909 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14910 << " is incorrect: cached " << CachedDisposition << ", actual "
14911 << RecomputedDisposition << "\n";
14912 std::abort();
14913 }
14914 }
14915 }
14916
14917 // Verify integrity of the block disposition cache.
14918 for (auto &[S, Values] : BlockDispositions) {
14919 for (auto [BB, CachedDisposition] : Values) {
14920 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14921 if (CachedDisposition != RecomputedDisposition) {
14922 dbgs() << "Cached disposition of " << *S << " for block %"
14923 << BB->getName() << " is incorrect: cached " << CachedDisposition
14924 << ", actual " << RecomputedDisposition << "\n";
14925 std::abort();
14926 }
14927 }
14928 }
14929
14930 // Verify FoldCache/FoldCacheUser caches.
14931 for (auto [FoldID, Expr] : FoldCache) {
14932 auto I = FoldCacheUser.find(Expr);
14933 if (I == FoldCacheUser.end()) {
14934 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14935 << "!\n";
14936 std::abort();
14937 }
14938 if (!is_contained(I->second, FoldID)) {
14939 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14940 std::abort();
14941 }
14942 }
14943 for (auto [Expr, IDs] : FoldCacheUser) {
14944 for (auto &FoldID : IDs) {
14945 const SCEV *S = FoldCache.lookup(FoldID);
14946 if (!S) {
14947 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14948 << "!\n";
14949 std::abort();
14950 }
14951 if (S != Expr) {
14952 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14953 << " != " << *Expr << "!\n";
14954 std::abort();
14955 }
14956 }
14957 }
14958
14959 // Verify that ConstantMultipleCache computations are correct. We check that
14960 // cached multiples and recomputed multiples are multiples of each other to
14961 // verify correctness. It is possible that a recomputed multiple is different
14962 // from the cached multiple due to strengthened no wrap flags or changes in
14963 // KnownBits computations.
14964 for (auto [S, Multiple] : ConstantMultipleCache) {
14965 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14966 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14967 Multiple.urem(RecomputedMultiple) != 0 &&
14968 RecomputedMultiple.urem(Multiple) != 0)) {
14969 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14970 << *S << " : Computed " << RecomputedMultiple
14971 << " but cache contains " << Multiple << "!\n";
14972 std::abort();
14973 }
14974 }
14975}
14976
14978 Function &F, const PreservedAnalyses &PA,
14979 FunctionAnalysisManager::Invalidator &Inv) {
14980 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14981 // of its dependencies is invalidated.
14982 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14983 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14984 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14985 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14986 Inv.invalidate<LoopAnalysis>(F, PA);
14987}
14988
// Out-of-line definition of the static AnalysisKey member declared by
// ScalarEvolutionAnalysis.
14989AnalysisKey ScalarEvolutionAnalysis::Key;
14990
14993 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14994 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14995 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14996 auto &LI = AM.getResult<LoopAnalysis>(F);
14997 return ScalarEvolution(F, TLI, AC, DT, LI);
14998}
14999
15005
15008 // For compatibility with opt's -analyze feature under legacy pass manager
15009 // which was not ported to NPM. This keeps tests using
15010 // update_analyze_test_checks.py working.
15011 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15012 << F.getName() << "':\n";
15014 return PreservedAnalyses::all();
15015}
15016
15018 "Scalar Evolution Analysis", false, true)
15024 "Scalar Evolution Analysis", false, true)
15025
15027
15029
15031 SE.reset(new ScalarEvolution(
15033 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15035 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15036 return false;
15037}
15038
15040
15042 SE->print(OS);
15043}
15044
15046 if (!VerifySCEV)
15047 return;
15048
15049 SE->verify();
15050}
15051
15059
15061 const SCEV *RHS) {
15062 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15063}
15064
15065const SCEVPredicate *
15067 const SCEV *LHS, const SCEV *RHS) {
15069 assert(LHS->getType() == RHS->getType() &&
15070 "Type mismatch between LHS and RHS");
15071 // Unique this node based on the arguments
15072 ID.AddInteger(SCEVPredicate::P_Compare);
15073 ID.AddInteger(Pred);
15074 ID.AddPointer(LHS);
15075 ID.AddPointer(RHS);
15076 void *IP = nullptr;
15077 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15078 return S;
15079 SCEVComparePredicate *Eq = new (SCEVAllocator)
15080 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15081 UniquePreds.InsertNode(Eq, IP);
15082 return Eq;
15083}
15084
15086 const SCEVAddRecExpr *AR,
15089 // Unique this node based on the arguments
15090 ID.AddInteger(SCEVPredicate::P_Wrap);
15091 ID.AddPointer(AR);
15092 ID.AddInteger(AddedFlags);
15093 void *IP = nullptr;
15094 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15095 return S;
15096 auto *OF = new (SCEVAllocator)
15097 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15098 UniquePreds.InsertNode(OF, IP);
15099 return OF;
15100}
15101
15102namespace {
15103
15104class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15105public:
15106
15107 /// Rewrites \p S in the context of a loop L and the SCEV predication
15108 /// infrastructure.
15109 ///
15110 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15111 /// equivalences present in \p Pred.
15112 ///
15113 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15114 /// \p NewPreds such that the result will be an AddRecExpr.
15115 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15117 const SCEVPredicate *Pred) {
15118 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15119 return Rewriter.visit(S);
15120 }
15121
15122 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15123 if (Pred) {
15124 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15125 for (const auto *Pred : U->getPredicates())
15126 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15127 if (IPred->getLHS() == Expr &&
15128 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15129 return IPred->getRHS();
15130 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15131 if (IPred->getLHS() == Expr &&
15132 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15133 return IPred->getRHS();
15134 }
15135 }
15136 return convertToAddRecWithPreds(Expr);
15137 }
15138
15139 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15140 const SCEV *Operand = visit(Expr->getOperand());
15141 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15142 if (AR && AR->getLoop() == L && AR->isAffine()) {
15143 // This couldn't be folded because the operand didn't have the nuw
15144 // flag. Add the nusw flag as an assumption that we could make.
15145 const SCEV *Step = AR->getStepRecurrence(SE);
15146 Type *Ty = Expr->getType();
15147 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15148 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15149 SE.getSignExtendExpr(Step, Ty), L,
15150 AR->getNoWrapFlags());
15151 }
15152 return SE.getZeroExtendExpr(Operand, Expr->getType());
15153 }
15154
15155 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15156 const SCEV *Operand = visit(Expr->getOperand());
15157 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15158 if (AR && AR->getLoop() == L && AR->isAffine()) {
15159 // This couldn't be folded because the operand didn't have the nsw
15160 // flag. Add the nssw flag as an assumption that we could make.
15161 const SCEV *Step = AR->getStepRecurrence(SE);
15162 Type *Ty = Expr->getType();
15163 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15164 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15165 SE.getSignExtendExpr(Step, Ty), L,
15166 AR->getNoWrapFlags());
15167 }
15168 return SE.getSignExtendExpr(Operand, Expr->getType());
15169 }
15170
15171private:
/// Stores the rewrite context: the loop \p L, the optional output list
/// \p NewPreds that addOverflowAssumption() appends to when non-null, and
/// the optional existing predicate \p Pred consulted by visitUnknown() and
/// addOverflowAssumption(). \p SE is forwarded to the visitor base class.
15172 explicit SCEVPredicateRewriter(
15173 const Loop *L, ScalarEvolution &SE,
15174 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15175 const SCEVPredicate *Pred)
15176 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15177
15178 bool addOverflowAssumption(const SCEVPredicate *P) {
15179 if (!NewPreds) {
15180 // Check if we've already made this assumption.
15181 return Pred && Pred->implies(P, SE);
15182 }
15183 NewPreds->push_back(P);
15184 return true;
15185 }
15186
15187 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15189 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15190 return addOverflowAssumption(A);
15191 }
15192
15193 // If \p Expr represents a PHINode, we try to see if it can be represented
15194 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15195 // to add this predicate as a runtime overflow check, we return the AddRec.
15196 // If \p Expr does not meet these conditions (is not a PHI node, or we
15197 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15198 // return \p Expr.
15199 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15200 if (!isa<PHINode>(Expr->getValue()))
15201 return Expr;
15202 std::optional<
15203 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15204 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15205 if (!PredicatedRewrite)
15206 return Expr;
15207 for (const auto *P : PredicatedRewrite->second){
15208 // Wrap predicates from outer loops are not supported.
15209 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15210 if (L != WP->getExpr()->getLoop())
15211 return Expr;
15212 }
15213 if (!addOverflowAssumption(P))
15214 return Expr;
15215 }
15216 return PredicatedRewrite->first;
15217 }
15218
15219 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15220 const SCEVPredicate *Pred;
15221 const Loop *L;
15222};
15223
15224} // end anonymous namespace
15225
15226const SCEV *
15228 const SCEVPredicate &Preds) {
15229 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15230}
15231
15233 const SCEV *S, const Loop *L,
15236 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15237 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15238
15239 if (!AddRec)
15240 return nullptr;
15241
15242 // Check if any of the transformed predicates is known to be false. In that
15243 // case, it doesn't make sense to convert to a predicated AddRec, as the
15244 // versioned loop will never execute.
15245 for (const SCEVPredicate *Pred : TransformPreds) {
15246 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15247 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15248 continue;
15249
15250 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15251 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15252 if (isa<SCEVCouldNotCompute>(ExitCount))
15253 continue;
15254
15255 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15256 if (!Step->isOne())
15257 continue;
15258
15259 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15260 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15261 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15262 return nullptr;
15263 }
15264
15265 // Since the transformation was successful, we can now transfer the SCEV
15266 // predicates.
15267 Preds.append(TransformPreds.begin(), TransformPreds.end());
15268
15269 return AddRec;
15270}
15271
15272/// SCEV predicates
15276
15278 const ICmpInst::Predicate Pred,
15279 const SCEV *LHS, const SCEV *RHS)
15280 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15281 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15282 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15283}
15284
15286 ScalarEvolution &SE) const {
15287 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15288
15289 if (!Op)
15290 return false;
15291
15292 if (Pred != ICmpInst::ICMP_EQ)
15293 return false;
15294
15295 return Op->LHS == LHS && Op->RHS == RHS;
15296}
15297
/// A compare predicate is never reported as trivially true.
15298bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15299
15301 if (Pred == ICmpInst::ICMP_EQ)
15302 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15303 else
15304 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15305 << *RHS << "\n";
15306
15307}
15308
15310 const SCEVAddRecExpr *AR,
15311 IncrementWrapFlags Flags)
15312 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15313
/// Returns the AddRec expression stored in this wrap predicate.
15314const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15315
15317 ScalarEvolution &SE) const {
15318 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15319 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15320 return false;
15321
15322 if (Op->AR == AR)
15323 return true;
15324
15325 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15327 return false;
15328
15329 const SCEV *Start = AR->getStart();
15330 const SCEV *OpStart = Op->AR->getStart();
15331 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15332 return false;
15333
15334 // Reject pointers to different address spaces.
15335 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15336 return false;
15337
15338 const SCEV *Step = AR->getStepRecurrence(SE);
15339 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15340 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15341 return false;
15342
15343 // If both steps are positive, this implies N, if N's start and step are
15344 // ULE/SLE (for NUSW/NSSW) than this'.
15345 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15346 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15347 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15348
15349 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15350 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15351 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15352 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15353 : SE.getNoopOrSignExtend(Start, WiderTy);
15355 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15356 SE.isKnownPredicate(Pred, OpStart, Start);
15357}
15358
15360 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15361 IncrementWrapFlags IFlags = Flags;
15362
15363 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15364 IFlags = clearFlags(IFlags, IncrementNSSW);
15365
15366 return IFlags == IncrementAnyWrap;
15367}
15368
15369void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15370 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15372 OS << "<nusw>";
15374 OS << "<nssw>";
15375 OS << "\n";
15376}
15377
15380 ScalarEvolution &SE) {
15381 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15382 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15383
15384 // We can safely transfer the NSW flag as NSSW.
15385 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15386 ImpliedFlags = IncrementNSSW;
15387
15388 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15389 // If the increment is positive, the SCEV NUW flag will also imply the
15390 // WrapPredicate NUSW flag.
15391 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15392 if (Step->getValue()->getValue().isNonNegative())
15393 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15394 }
15395
15396 return ImpliedFlags;
15397}
15398
15399/// Union predicates don't get cached so create a dummy set ID for it.
15401 ScalarEvolution &SE)
15402 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15403 for (const auto *P : Preds)
15404 add(P, SE);
15405}
15406
15408 return all_of(Preds,
15409 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15410}
15411
15413 ScalarEvolution &SE) const {
15414 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15415 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15416 return this->implies(I, SE);
15417 });
15418
15419 return any_of(Preds,
15420 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15421}
15422
15424 for (const auto *Pred : Preds)
15425 Pred->print(OS, Depth);
15426}
15427
15428void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15429 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15430 for (const auto *Pred : Set->Preds)
15431 add(Pred, SE);
15432 return;
15433 }
15434
15435 // Implication checks are quadratic in the number of predicates. Stop doing
15436 // them if there are many predicates, as they should be too expensive to use
15437 // anyway at that point.
15438 bool CheckImplies = Preds.size() < 16;
15439
15440 // Only add predicate if it is not already implied by this union predicate.
15441 if (CheckImplies && implies(N, SE))
15442 return;
15443
15444 // Build a new vector containing the current predicates, except the ones that
15445 // are implied by the new predicate N.
15447 for (auto *P : Preds) {
15448 if (CheckImplies && N->implies(P, SE))
15449 continue;
15450 PrunedPreds.push_back(P);
15451 }
15452 Preds = std::move(PrunedPreds);
15453 Preds.push_back(N);
15454}
15455
15457 Loop &L)
15458 : SE(SE), L(L) {
15460 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15461}
15462
15465 for (const auto *Op : Ops)
15466 // We do not expect that forgetting cached data for SCEVConstants will ever
15467 // open any prospects for sharpening or introduce any correctness issues,
15468 // so we don't bother storing their dependencies.
15469 if (!isa<SCEVConstant>(Op))
15470 SCEVUsers[Op].insert(User);
15471}
15472
15474 for (const SCEV *Op : Ops)
15475 // We do not expect that forgetting cached data for SCEVConstants will ever
15476 // open any prospects for sharpening or introduce any correctness issues,
15477 // so we don't bother storing their dependencies.
15478 if (!isa<SCEVConstant>(Op))
15479 SCEVUsers[Op].insert(User);
15480}
15481
15483 const SCEV *Expr = SE.getSCEV(V);
15484 return getPredicatedSCEV(Expr);
15485}
15486
15488 RewriteEntry &Entry = RewriteMap[Expr];
15489
15490 // If we already have an entry and the version matches, return it.
15491 if (Entry.second && Generation == Entry.first)
15492 return Entry.second;
15493
15494 // We found an entry but it's stale. Rewrite the stale entry
15495 // according to the current predicate.
15496 if (Entry.second)
15497 Expr = Entry.second;
15498
15499 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15500 Entry = {Generation, NewSCEV};
15501
15502 return NewSCEV;
15503}
15504
15506 if (!BackedgeCount) {
15508 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15509 for (const auto *P : Preds)
15510 addPredicate(*P);
15511 }
15512 return BackedgeCount;
15513}
15514
15516 if (!SymbolicMaxBackedgeCount) {
15518 SymbolicMaxBackedgeCount =
15519 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15520 for (const auto *P : Preds)
15521 addPredicate(*P);
15522 }
15523 return SymbolicMaxBackedgeCount;
15524}
15525
15527 if (!SmallConstantMaxTripCount) {
15529 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15530 for (const auto *P : Preds)
15531 addPredicate(*P);
15532 }
15533 return *SmallConstantMaxTripCount;
15534}
15535
15537 if (Preds->implies(&Pred, SE))
15538 return;
15539
15540 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15541 NewPreds.push_back(&Pred);
15542 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15543 updateGeneration();
15544}
15545
15547 return *Preds;
15548}
15549
15550void PredicatedScalarEvolution::updateGeneration() {
15551 // If the generation number wrapped recompute everything.
15552 if (++Generation == 0) {
15553 for (auto &II : RewriteMap) {
15554 const SCEV *Rewritten = II.second.second;
15555 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15556 }
15557 }
15558}
15559
15562 const SCEV *Expr = getSCEV(V);
15563 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15564
15565 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15566
15567 // Clear the statically implied flags.
15568 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15569 addPredicate(*SE.getWrapPredicate(AR, Flags));
15570
15571 auto II = FlagsMap.insert({V, Flags});
15572 if (!II.second)
15573 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15574}
15575
15578 const SCEV *Expr = getSCEV(V);
15579 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15580
15582 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15583
15584 auto II = FlagsMap.find(V);
15585
15586 if (II != FlagsMap.end())
15587 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15588
15590}
15591
15593 const SCEV *Expr = this->getSCEV(V);
15595 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15596
15597 if (!New)
15598 return nullptr;
15599
15600 for (const auto *P : NewPreds)
15601 addPredicate(*P);
15602
15603 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15604 return New;
15605}
15606
15609 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15610 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15611 SE)),
15612 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15613 for (auto I : Init.FlagsMap)
15614 FlagsMap.insert(I);
15615}
15616
15618 // For each block.
15619 for (auto *BB : L.getBlocks())
15620 for (auto &I : *BB) {
15621 if (!SE.isSCEVable(I.getType()))
15622 continue;
15623
15624 auto *Expr = SE.getSCEV(&I);
15625 auto II = RewriteMap.find(Expr);
15626
15627 if (II == RewriteMap.end())
15628 continue;
15629
15630 // Don't print things that are not interesting.
15631 if (II->second.second == Expr)
15632 continue;
15633
15634 OS.indent(Depth) << "[PSE]" << I << ":\n";
15635 OS.indent(Depth + 2) << *Expr << "\n";
15636 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15637 }
15638}
15639
15642 BasicBlock *Header = L->getHeader();
15643 BasicBlock *Pred = L->getLoopPredecessor();
15644 LoopGuards Guards(SE);
15645 if (!Pred)
15646 return Guards;
15648 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15649 return Guards;
15650}
15651
15652void ScalarEvolution::LoopGuards::collectFromPHI(
15656 unsigned Depth) {
15657 if (!SE.isSCEVable(Phi.getType()))
15658 return;
15659
15660 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15661 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15662 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15663 if (!VisitedBlocks.insert(InBlock).second)
15664 return {nullptr, scCouldNotCompute};
15665
15666 // Avoid analyzing unreachable blocks so that we don't get trapped
15667 // traversing cycles with ill-formed dominance or infinite cycles
15668 if (!SE.DT.isReachableFromEntry(InBlock))
15669 return {nullptr, scCouldNotCompute};
15670
15671 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15672 if (Inserted)
15673 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15674 Depth + 1);
15675 auto &RewriteMap = G->second.RewriteMap;
15676 if (RewriteMap.empty())
15677 return {nullptr, scCouldNotCompute};
15678 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15679 if (S == RewriteMap.end())
15680 return {nullptr, scCouldNotCompute};
15681 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15682 if (!SM)
15683 return {nullptr, scCouldNotCompute};
15684 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15685 return {C0, SM->getSCEVType()};
15686 return {nullptr, scCouldNotCompute};
15687 };
15688 auto MergeMinMaxConst = [](MinMaxPattern P1,
15689 MinMaxPattern P2) -> MinMaxPattern {
15690 auto [C1, T1] = P1;
15691 auto [C2, T2] = P2;
15692 if (!C1 || !C2 || T1 != T2)
15693 return {nullptr, scCouldNotCompute};
15694 switch (T1) {
15695 case scUMaxExpr:
15696 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15697 case scSMaxExpr:
15698 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15699 case scUMinExpr:
15700 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15701 case scSMinExpr:
15702 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15703 default:
15704 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15705 }
15706 };
15707 auto P = GetMinMaxConst(0);
15708 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15709 if (!P.first)
15710 break;
15711 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15712 }
15713 if (P.first) {
15714 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15715 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15716 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15717 Guards.RewriteMap.insert({LHS, RHS});
15718 }
15719}
15720
15721 // Return a new SCEV that rounds \p Expr down to the closest number that is
15722 // divisible by \p DivisorVal and less than or equal to \p Expr. For now, only
15723 // constant Expr is handled.
15725 const APInt &DivisorVal,
15726 ScalarEvolution &SE) {
15727 const APInt *ExprVal;
15728 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15729 DivisorVal.isNonPositive())
15730 return Expr;
15731 APInt Rem = ExprVal->urem(DivisorVal);
15732 // return the SCEV: Expr - Expr % Divisor
15733 return SE.getConstant(*ExprVal - Rem);
15734}
15735
15736// Return a new SCEV that modifies \p Expr to the closest number divides by
15737// \p Divisor and greater or equal than Expr. For now, only handle constant
15738// Expr.
15739static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15740 const APInt &DivisorVal,
15741 ScalarEvolution &SE) {
15742 const APInt *ExprVal;
15743 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15744 DivisorVal.isNonPositive())
15745 return Expr;
15746 APInt Rem = ExprVal->urem(DivisorVal);
15747 if (Rem.isZero())
15748 return Expr;
15749 // return the SCEV: Expr + Divisor - Expr % Divisor
15750 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15751}
15752
15754 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15757 // If we have LHS == 0, check if LHS is computing a property of some unknown
15758 // SCEV %v which we can rewrite %v to express explicitly.
15760 return false;
15761 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15762 // explicitly express that.
15763 const SCEVUnknown *URemLHS = nullptr;
15764 const SCEV *URemRHS = nullptr;
15765 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15766 return false;
15767
15768 const SCEV *Multiple =
15769 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15770 DivInfo[URemLHS] = Multiple;
15771 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15772 Multiples[URemLHS] = C->getAPInt();
15773 return true;
15774}
15775
15776// Check if the condition is a divisibility guard (A % B == 0).
15777static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15778 ScalarEvolution &SE) {
15779 const SCEV *X, *Y;
15780 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15781}
15782
15783// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15784// recursively. This is done by aligning up/down the constant value to the
15785// Divisor.
15786static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15787 APInt Divisor,
15788 ScalarEvolution &SE) {
15789 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15790 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15791 // the non-constant operand and in \p LHS the constant operand.
15792 auto IsMinMaxSCEVWithNonNegativeConstant =
15793 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15794 const SCEV *&RHS) {
15795 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15796 if (MinMax->getNumOperands() != 2)
15797 return false;
15798 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15799 if (C->getAPInt().isNegative())
15800 return false;
15801 SCTy = MinMax->getSCEVType();
15802 LHS = MinMax->getOperand(0);
15803 RHS = MinMax->getOperand(1);
15804 return true;
15805 }
15806 }
15807 return false;
15808 };
15809
15810 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15811 SCEVTypes SCTy;
15812 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15813 MinMaxRHS))
15814 return MinMaxExpr;
15815 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15816 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15817 auto *DivisibleExpr =
15818 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15819 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15821 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15822 return SE.getMinMaxExpr(SCTy, Ops);
15823}
15824
15825void ScalarEvolution::LoopGuards::collectFromBlock(
15826 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15827 const BasicBlock *Block, const BasicBlock *Pred,
15828 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15829
15831
15832 SmallVector<SCEVUse> ExprsToRewrite;
15833 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15834 const SCEV *RHS,
15835 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15836 const LoopGuards &DivGuards) {
15837 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15838 // replacement SCEV which isn't directly implied by the structure of that
15839 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15840 // legal. See the scoping rules for flags in the header to understand why.
15841
15842 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15843 // create this form when combining two checks of the form (X u< C2 + C1) and
15844 // (X >=u C1).
15845 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15846 &ExprsToRewrite]() {
15847 const SCEVConstant *C1;
15848 const SCEVUnknown *LHSUnknown;
15849 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15850 if (!match(LHS,
15851 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15852 !C2)
15853 return false;
15854
15855 auto ExactRegion =
15856 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15857 .sub(C1->getAPInt());
15858
15859 // Bail out, unless we have a non-wrapping, monotonic range.
15860 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15861 return false;
15862 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15863 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15864 I->second = SE.getUMaxExpr(
15865 SE.getConstant(ExactRegion.getUnsignedMin()),
15866 SE.getUMinExpr(RewrittenLHS,
15867 SE.getConstant(ExactRegion.getUnsignedMax())));
15868 ExprsToRewrite.push_back(LHSUnknown);
15869 return true;
15870 };
15871 if (MatchRangeCheckIdiom())
15872 return;
15873
15874 // Do not apply information for constants or if RHS contains an AddRec.
15876 return;
15877
15878 // If RHS is SCEVUnknown, make sure the information is applied to it.
15880 std::swap(LHS, RHS);
15882 }
15883
15884 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15885 // and \p FromRewritten are the same (i.e. there has been no rewrite
15886 // registered for \p From), then puts this value in the list of rewritten
15887 // expressions.
15888 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15889 const SCEV *To) {
15890 if (From == FromRewritten)
15891 ExprsToRewrite.push_back(From);
15892 RewriteMap[From] = To;
15893 };
15894
15895 // Checks whether \p S has already been rewritten. In that case returns the
15896 // existing rewrite because we want to chain further rewrites onto the
15897 // already rewritten value. Otherwise returns \p S.
15898 auto GetMaybeRewritten = [&](const SCEV *S) {
15899 return RewriteMap.lookup_or(S, S);
15900 };
15901
15902 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15903 // Apply divisibility information when computing the constant multiple.
15904 const APInt &DividesBy =
15905 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15906
15907 // Collect rewrites for LHS and its transitive operands based on the
15908 // condition.
15909 // For min/max expressions, also apply the guard to its operands:
15910 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15911 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15912 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15913 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15914
15915 // We cannot express strict predicates in SCEV, so instead we replace them
15916 // with non-strict ones against plus or minus one of RHS depending on the
15917 // predicate.
15918 const SCEV *One = SE.getOne(RHS->getType());
15919 switch (Predicate) {
15920 case CmpInst::ICMP_ULT:
15921 if (RHS->getType()->isPointerTy())
15922 return;
15923 RHS = SE.getUMaxExpr(RHS, One);
15924 [[fallthrough]];
15925 case CmpInst::ICMP_SLT: {
15926 RHS = SE.getMinusSCEV(RHS, One);
15927 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15928 break;
15929 }
15930 case CmpInst::ICMP_UGT:
15931 case CmpInst::ICMP_SGT:
15932 RHS = SE.getAddExpr(RHS, One);
15933 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15934 break;
15935 case CmpInst::ICMP_ULE:
15936 case CmpInst::ICMP_SLE:
15937 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15938 break;
15939 case CmpInst::ICMP_UGE:
15940 case CmpInst::ICMP_SGE:
15941 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15942 break;
15943 default:
15944 break;
15945 }
15946
15947 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15948 SmallPtrSet<const SCEV *, 16> Visited;
15949
15950 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15951 append_range(Worklist, S->operands());
15952 };
15953
15954 while (!Worklist.empty()) {
15955 const SCEV *From = Worklist.pop_back_val();
15956 if (isa<SCEVConstant>(From))
15957 continue;
15958 if (!Visited.insert(From).second)
15959 continue;
15960 const SCEV *FromRewritten = GetMaybeRewritten(From);
15961 const SCEV *To = nullptr;
15962
15963 switch (Predicate) {
15964 case CmpInst::ICMP_ULT:
15965 case CmpInst::ICMP_ULE:
15966 To = SE.getUMinExpr(FromRewritten, RHS);
15967 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15968 EnqueueOperands(UMax);
15969 break;
15970 case CmpInst::ICMP_SLT:
15971 case CmpInst::ICMP_SLE:
15972 To = SE.getSMinExpr(FromRewritten, RHS);
15973 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15974 EnqueueOperands(SMax);
15975 break;
15976 case CmpInst::ICMP_UGT:
15977 case CmpInst::ICMP_UGE:
15978 To = SE.getUMaxExpr(FromRewritten, RHS);
15979 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15980 EnqueueOperands(UMin);
15981 break;
15982 case CmpInst::ICMP_SGT:
15983 case CmpInst::ICMP_SGE:
15984 To = SE.getSMaxExpr(FromRewritten, RHS);
15985 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15986 EnqueueOperands(SMin);
15987 break;
15988 case CmpInst::ICMP_EQ:
15990 To = RHS;
15991 break;
15992 case CmpInst::ICMP_NE:
15993 if (match(RHS, m_scev_Zero())) {
15994 const SCEV *OneAlignedUp =
15995 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15996 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15997 } else {
15998 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15999 // but creating the subtraction eagerly is expensive. Track the
16000 // inequalities in a separate map, and materialize the rewrite lazily
16001 // when encountering a suitable subtraction while re-writing.
16002 if (LHS->getType()->isPointerTy()) {
16006 break;
16007 }
16008 const SCEVConstant *C;
16009 const SCEV *A, *B;
16012 RHS = A;
16013 LHS = B;
16014 }
16015 if (LHS > RHS)
16016 std::swap(LHS, RHS);
16017 Guards.NotEqual.insert({LHS, RHS});
16018 continue;
16019 }
16020 break;
16021 default:
16022 break;
16023 }
16024
16025 if (To)
16026 AddRewrite(From, FromRewritten, To);
16027 }
16028 };
16029
16031 // First, collect information from assumptions dominating the loop.
16032 for (auto &AssumeVH : SE.AC.assumptions()) {
16033 if (!AssumeVH)
16034 continue;
16035 auto *AssumeI = cast<CallInst>(AssumeVH);
16036 if (!SE.DT.dominates(AssumeI, Block))
16037 continue;
16038 Terms.emplace_back(AssumeI->getOperand(0), true);
16039 }
16040
16041 // Second, collect information from llvm.experimental.guards dominating the loop.
16042 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16043 SE.F.getParent(), Intrinsic::experimental_guard);
16044 if (GuardDecl)
16045 for (const auto *GU : GuardDecl->users())
16046 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16047 if (Guard->getFunction() == Block->getParent() &&
16048 SE.DT.dominates(Guard, Block))
16049 Terms.emplace_back(Guard->getArgOperand(0), true);
16050
16051 // Third, collect conditions from dominating branches. Starting at the loop
16052 // predecessor, climb up the predecessor chain, as long as there are
16053 // predecessors that can be found that have unique successors leading to the
16054 // original header.
16055 // TODO: share this logic with isLoopEntryGuardedByCond.
16056 unsigned NumCollectedConditions = 0;
16058 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16059 for (; Pair.first;
16060 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16061 VisitedBlocks.insert(Pair.second);
16062 const CondBrInst *LoopEntryPredicate =
16063 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16064 if (!LoopEntryPredicate)
16065 continue;
16066
16067 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16068 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16069 NumCollectedConditions++;
16070
16071 // If we are recursively collecting guards stop after 2
16072 // conditions to limit compile-time impact for now.
16073 if (Depth > 0 && NumCollectedConditions == 2)
16074 break;
16075 }
16076 // Finally, if we stopped climbing the predecessor chain because
16077 // there wasn't a unique one to continue, try to collect conditions
16078 // for PHINodes by recursively following all of their incoming
16079 // blocks and try to merge the found conditions to build a new one
16080 // for the Phi.
16081 if (Pair.second->hasNPredecessorsOrMore(2) &&
16083 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16084 for (auto &Phi : Pair.second->phis())
16085 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16086 }
16087
16088 // Now apply the information from the collected conditions to
16089 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16090 // earliest condition is processed first, except guards with divisibility
16091 // information, which are moved to the back. This ensures the SCEVs with the
16092 // shortest dependency chains are constructed first.
16094 GuardsToProcess;
16095 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16096 SmallVector<Value *, 8> Worklist;
16097 SmallPtrSet<Value *, 8> Visited;
16098 Worklist.push_back(Term);
16099 while (!Worklist.empty()) {
16100 Value *Cond = Worklist.pop_back_val();
16101 if (!Visited.insert(Cond).second)
16102 continue;
16103
16104 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16105 auto Predicate =
16106 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16107 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16108 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16109 // If LHS is a constant, apply information to the other expression.
16110 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16111 // can improve results.
16112 if (isa<SCEVConstant>(LHS)) {
16113 std::swap(LHS, RHS);
16115 }
16116 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16117 continue;
16118 }
16119
16120 Value *L, *R;
16121 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16122 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16123 Worklist.push_back(L);
16124 Worklist.push_back(R);
16125 }
16126 }
16127 }
16128
16129 // Process divisibility guards in reverse order to populate DivGuards early.
16130 DenseMap<const SCEV *, APInt> Multiples;
16131 LoopGuards DivGuards(SE);
16132 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16133 if (!isDivisibilityGuard(LHS, RHS, SE))
16134 continue;
16135 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16136 Multiples, SE);
16137 }
16138
16139 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16140 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16141
16142 // Apply divisibility information last. This ensures it is applied to the
16143 // outermost expression after other rewrites for the given value.
16144 for (const auto &[K, Divisor] : Multiples) {
16145 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16146 Guards.RewriteMap[K] =
16148 Guards.rewrite(K), Divisor, SE),
16149 DivisorSCEV),
16150 DivisorSCEV);
16151 ExprsToRewrite.push_back(K);
16152 }
16153
16154 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16155 // the replacement expressions are contained in the ranges of the replaced
16156 // expressions.
16157 Guards.PreserveNUW = true;
16158 Guards.PreserveNSW = true;
16159 for (const SCEV *Expr : ExprsToRewrite) {
16160 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16161 Guards.PreserveNUW &=
16162 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16163 Guards.PreserveNSW &=
16164 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16165 }
16166
16167 // Now that all rewrite information is collected, rewrite the collected
16168 // expressions with the information in the map. This applies information to
16169 // sub-expressions.
16170 if (ExprsToRewrite.size() > 1) {
16171 for (const SCEV *Expr : ExprsToRewrite) {
16172 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16173 Guards.RewriteMap.erase(Expr);
16174 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16175 }
16176 }
16177}
16178
16180 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16181 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16182 /// replacement is loop invariant in the loop of the AddRec.
16183 class SCEVLoopGuardRewriter
16184 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16187
16189
16190 public:
/// Construct a rewriter for the given set of loop guards.
///
/// \p SE is forwarded to the SCEVRewriteVisitor base. Map and NotEqual are
/// initialized from Guards.RewriteMap and Guards.NotEqual respectively.
16191 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16192 const ScalarEvolution::LoopGuards &Guards)
16193 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16194 NotEqual(Guards.NotEqual) {
// Add NUW / NSW to FlagMask only when the guards recorded that the
// corresponding flag may be preserved. FlagMask is later handed to
// ScalarEvolution::maskFlags together with an expression's no-wrap flags
// when a rewritten expression is rebuilt.
16195 if (Guards.PreserveNUW)
16196 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16197 if (Guards.PreserveNSW)
16198 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16199 }
16200
/// AddRecs are returned unchanged: no rewrite is applied to them or to their
/// operands, because a replacement is not guaranteed to be loop invariant in
/// the loop of the AddRec.
16201 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16202
/// Return the replacement registered in Map for \p Expr, or \p Expr itself
/// when Map has no entry for it.
16203 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16204 return Map.lookup_or(Expr, Expr);
16205 }
16206
16207 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16208 if (const SCEV *S = Map.lookup(Expr))
16209 return S;
16210
16211 // If we didn't find the exact ZExt expr in the map, check if there's
16212 // an entry for a smaller ZExt we can use instead.
16213 Type *Ty = Expr->getType();
16214 const SCEV *Op = Expr->getOperand(0);
16215 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16216 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16217 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16218 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16219 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16220 if (const SCEV *S = Map.lookup(NarrowExt))
16221 return SE.getZeroExtendExpr(S, Ty);
16222 Bitwidth = Bitwidth / 2;
16223 }
16224
16226 Expr);
16227 }
16228
16229 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16230 if (const SCEV *S = Map.lookup(Expr))
16231 return S;
16233 Expr);
16234 }
16235
16236 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16237 if (const SCEV *S = Map.lookup(Expr))
16238 return S;
16240 }
16241
16242 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16243 if (const SCEV *S = Map.lookup(Expr))
16244 return S;
16246 }
16247
16248 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16249 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16250 // return UMax(S, 1).
16251 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16252 SCEVUse LHS, RHS;
16253 if (MatchBinarySub(S, LHS, RHS)) {
16254 if (LHS > RHS)
16255 std::swap(LHS, RHS);
16256 if (NotEqual.contains({LHS, RHS})) {
16257 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16258 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16259 return SE.getUMaxExpr(OneAlignedUp, S);
16260 }
16261 }
16262 return nullptr;
16263 };
16264
16265 // Check if Expr itself is a subtraction pattern with guard info.
16266 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16267 return Rewritten;
16268
16269 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16270 // (Const + A + B). There may be guard info for A + B, and if so, apply
16271 // it.
16272 // TODO: Could more generally apply guards to Add sub-expressions.
16273 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16274 Expr->getNumOperands() == 3) {
16275 const SCEV *Add =
16276 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16277 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16278 return SE.getAddExpr(
16279 Expr->getOperand(0), Rewritten,
16280 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16281 if (const SCEV *S = Map.lookup(Add))
16282 return SE.getAddExpr(Expr->getOperand(0), S);
16283 }
16284 SmallVector<SCEVUse, 2> Operands;
16285 bool Changed = false;
16286 for (SCEVUse Op : Expr->operands()) {
16287 Operands.push_back(
16289 Changed |= Op != Operands.back();
16290 }
16291 // We are only replacing operands with equivalent values, so transfer the
16292 // flags from the original expression.
16293 return !Changed ? Expr
16294 : SE.getAddExpr(Operands,
16296 Expr->getNoWrapFlags(), FlagMask));
16297 }
16298
16299 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16300 SmallVector<SCEVUse, 2> Operands;
16301 bool Changed = false;
16302 for (SCEVUse Op : Expr->operands()) {
16303 Operands.push_back(
16305 Changed |= Op != Operands.back();
16306 }
16307 // We are only replacing operands with equivalent values, so transfer the
16308 // flags from the original expression.
16309 return !Changed ? Expr
16310 : SE.getMulExpr(Operands,
16312 Expr->getNoWrapFlags(), FlagMask));
16313 }
16314 };
16315
16316 if (RewriteMap.empty() && NotEqual.empty())
16317 return Expr;
16318
16319 SCEVLoopGuardRewriter Rewriter(SE, *this);
16320 return Rewriter.visit(Expr);
16321}
16322
/// Convenience overload: collect the loop guards for \p L with
/// LoopGuards::collect and apply them to \p Expr through the overload that
/// takes a LoopGuards object. Returns whatever that overload returns.
16323const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16324 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16325}
16326
16328 const LoopGuards &Guards) {
16329 return Guards.rewrite(Expr);
16330}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:851
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2011
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1043
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1555
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1406
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:956
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1810
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1697
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1503
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1662
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1776
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1305
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1016
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1472
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:784
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set of assumed no-overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Does operation BinOp between LHS and RHS provably not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and return the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:736
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:767
TypeSize getSizeInBits() const
Definition DataLayout.h:747
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2266
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2271
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2276
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2852
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2281
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:818
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:368
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:317
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:202
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods to support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.