Bug Summary

File: build/source/llvm/lib/Analysis/ScalarEvolution.cpp
Warning: line 4437, column 3
Called C++ object pointer is null
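
This diagnostic is produced by one of the analyzer's core checkers (core.CallAndMessage) when a member function is invoked through an object pointer that may be null along at least one feasible path reaching the call (here, the call at line 4437). The snippet below is a minimal, hypothetical sketch of that pattern; it is not taken from ScalarEvolution.cpp, and the names Widget, maybeGet, and use are invented for illustration.

    // Hypothetical reduction of the "Called C++ object pointer is null" pattern.
    // Analyze with: clang++ --analyze widget.cpp
    struct Widget {
      int size() const { return 42; }
    };

    Widget *maybeGet(bool Available, Widget &Storage) {
      return Available ? &Storage : nullptr; // may return a null pointer
    }

    int use(bool Available) {
      Widget W;
      Widget *P = maybeGet(Available, W);
      // On the path where Available is false, P is null here; the member call
      // below is what the analyzer reports as "Called C++ object pointer is null".
      return P->size();
    }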

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name ScalarEvolution.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I lib/Analysis -I /build/source/llvm/lib/Analysis -I include -I /build/source/llvm/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm=build-llvm -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm=build-llvm -fcoverage-prefix-map=/build/source/= -source-date-epoch 1668078801 -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm -fdebug-prefix-map=/build/source/build-llvm=build-llvm -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-11-10-135928-647445-1 -x c++ /build/source/llvm/lib/Analysis/ScalarEvolution.cpp

/build/source/llvm/lib/Analysis/ScalarEvolution.cpp

1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
64#include "llvm/ADT/DepthFirstIterator.h"
65#include "llvm/ADT/EquivalenceClasses.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/None.h"
68#include "llvm/ADT/Optional.h"
69#include "llvm/ADT/STLExtras.h"
70#include "llvm/ADT/ScopeExit.h"
71#include "llvm/ADT/Sequence.h"
72#include "llvm/ADT/SmallPtrSet.h"
73#include "llvm/ADT/SmallSet.h"
74#include "llvm/ADT/SmallVector.h"
75#include "llvm/ADT/Statistic.h"
76#include "llvm/ADT/StringRef.h"
77#include "llvm/Analysis/AssumptionCache.h"
78#include "llvm/Analysis/ConstantFolding.h"
79#include "llvm/Analysis/InstructionSimplify.h"
80#include "llvm/Analysis/LoopInfo.h"
81#include "llvm/Analysis/ScalarEvolutionExpressions.h"
82#include "llvm/Analysis/TargetLibraryInfo.h"
83#include "llvm/Analysis/ValueTracking.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
89#include "llvm/IR/ConstantRange.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/DerivedTypes.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstIterator.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
101#include "llvm/IR/IntrinsicInst.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
111#include "llvm/InitializePasses.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
114#include "llvm/Support/CommandLine.h"
115#include "llvm/Support/Compiler.h"
116#include "llvm/Support/Debug.h"
117#include "llvm/Support/ErrorHandling.h"
118#include "llvm/Support/KnownBits.h"
119#include "llvm/Support/SaveAndRestore.h"
120#include "llvm/Support/raw_ostream.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <tuple>
130#include <utility>
131#include <vector>
132
133using namespace llvm;
134using namespace PatternMatch;
135
136#define DEBUG_TYPE "scalar-evolution"
137
138STATISTIC(NumTripCountsComputed,
139          "Number of loops with predictable loop counts");
140STATISTIC(NumTripCountsNotComputed,
141          "Number of loops without predictable loop counts");
142STATISTIC(NumBruteForceTripCountsComputed,
143          "Number of loops with trip counts computed by force");
144
145#ifdef EXPENSIVE_CHECKS
146bool llvm::VerifySCEV = true;
147#else
148bool llvm::VerifySCEV = false;
149#endif
150
151static cl::opt<unsigned>
152 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
153 cl::desc("Maximum number of iterations SCEV will "
154 "symbolically execute a constant "
155 "derived loop"),
156 cl::init(100));
157
158static cl::opt<bool, true> VerifySCEVOpt(
159 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
160 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
161static cl::opt<bool> VerifySCEVStrict(
162 "verify-scev-strict", cl::Hidden,
163 cl::desc("Enable stricter verification with -verify-scev is passed"));
164static cl::opt<bool>
165 VerifySCEVMap("verify-scev-maps", cl::Hidden,
166 cl::desc("Verify no dangling value in ScalarEvolution's "
167 "ExprValueMap (slow)"));
168
169static cl::opt<bool> VerifyIR(
170 "scev-verify-ir", cl::Hidden,
171 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
172 cl::init(false));
173
174static cl::opt<unsigned> MulOpsInlineThreshold(
175 "scev-mulops-inline-threshold", cl::Hidden,
176 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
177 cl::init(32));
178
179static cl::opt<unsigned> AddOpsInlineThreshold(
180 "scev-addops-inline-threshold", cl::Hidden,
181 cl::desc("Threshold for inlining addition operands into a SCEV"),
182 cl::init(500));
183
184static cl::opt<unsigned> MaxSCEVCompareDepth(
185 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
186 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
187 cl::init(32));
188
189static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
190 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
191 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
192 cl::init(2));
193
194static cl::opt<unsigned> MaxValueCompareDepth(
195 "scalar-evolution-max-value-compare-depth", cl::Hidden,
196 cl::desc("Maximum depth of recursive value complexity comparisons"),
197 cl::init(2));
198
199static cl::opt<unsigned>
200 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
201 cl::desc("Maximum depth of recursive arithmetics"),
202 cl::init(32));
203
204static cl::opt<unsigned> MaxConstantEvolvingDepth(
205 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
206 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
207
208static cl::opt<unsigned>
209 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
210 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
211 cl::init(8));
212
213static cl::opt<unsigned>
214 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
215 cl::desc("Max coefficients in AddRec during evolving"),
216 cl::init(8));
217
218static cl::opt<unsigned>
219 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
220 cl::desc("Size of the expression which is considered huge"),
221 cl::init(4096));
222
223static cl::opt<bool>
224ClassifyExpressions("scalar-evolution-classify-expressions",
225 cl::Hidden, cl::init(true),
226 cl::desc("When printing analysis, include information on every instruction"));
227
228static cl::opt<bool> UseExpensiveRangeSharpening(
229 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
230 cl::init(false),
231 cl::desc("Use more powerful methods of sharpening expression ranges. May "
232 "be costly in terms of compile time"));
233
234static cl::opt<unsigned> MaxPhiSCCAnalysisSize(
235 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
236 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
237 "Phi strongly connected components"),
238 cl::init(8));
239
240static cl::opt<bool>
241 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
242 cl::desc("Handle <= and >= in finite loops"),
243 cl::init(true));
244
245static cl::opt<bool> UseContextForNoWrapFlagInference(
246 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
247 cl::desc("Infer nuw/nsw flags using context where suitable"),
248 cl::init(true));
249
250//===----------------------------------------------------------------------===//
251// SCEV class definitions
252//===----------------------------------------------------------------------===//
253
254//===----------------------------------------------------------------------===//
255// Implementation of the SCEV class.
256//
257
258#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
259LLVM_DUMP_METHOD void SCEV::dump() const {
260 print(dbgs());
261 dbgs() << '\n';
262}
263#endif
264
265void SCEV::print(raw_ostream &OS) const {
266 switch (getSCEVType()) {
267 case scConstant:
268 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
269 return;
270 case scPtrToInt: {
271 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
272 const SCEV *Op = PtrToInt->getOperand();
273 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
274 << *PtrToInt->getType() << ")";
275 return;
276 }
277 case scTruncate: {
278 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
279 const SCEV *Op = Trunc->getOperand();
280 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
281 << *Trunc->getType() << ")";
282 return;
283 }
284 case scZeroExtend: {
285 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
286 const SCEV *Op = ZExt->getOperand();
287 OS << "(zext " << *Op->getType() << " " << *Op << " to "
288 << *ZExt->getType() << ")";
289 return;
290 }
291 case scSignExtend: {
292 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
293 const SCEV *Op = SExt->getOperand();
294 OS << "(sext " << *Op->getType() << " " << *Op << " to "
295 << *SExt->getType() << ")";
296 return;
297 }
298 case scAddRecExpr: {
299 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
300 OS << "{" << *AR->getOperand(0);
301 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
302 OS << ",+," << *AR->getOperand(i);
303 OS << "}<";
304 if (AR->hasNoUnsignedWrap())
305 OS << "nuw><";
306 if (AR->hasNoSignedWrap())
307 OS << "nsw><";
308 if (AR->hasNoSelfWrap() &&
309 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
310 OS << "nw><";
311 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
312 OS << ">";
313 return;
314 }
315 case scAddExpr:
316 case scMulExpr:
317 case scUMaxExpr:
318 case scSMaxExpr:
319 case scUMinExpr:
320 case scSMinExpr:
321 case scSequentialUMinExpr: {
322 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
323 const char *OpStr = nullptr;
324 switch (NAry->getSCEVType()) {
325 case scAddExpr: OpStr = " + "; break;
326 case scMulExpr: OpStr = " * "; break;
327 case scUMaxExpr: OpStr = " umax "; break;
328 case scSMaxExpr: OpStr = " smax "; break;
329 case scUMinExpr:
330 OpStr = " umin ";
331 break;
332 case scSMinExpr:
333 OpStr = " smin ";
334 break;
335 case scSequentialUMinExpr:
336 OpStr = " umin_seq ";
337 break;
338 default:
339      llvm_unreachable("There are no other nary expression types.");
340 }
341 OS << "(";
342 ListSeparator LS(OpStr);
343 for (const SCEV *Op : NAry->operands())
344 OS << LS << *Op;
345 OS << ")";
346 switch (NAry->getSCEVType()) {
347 case scAddExpr:
348 case scMulExpr:
349 if (NAry->hasNoUnsignedWrap())
350 OS << "<nuw>";
351 if (NAry->hasNoSignedWrap())
352 OS << "<nsw>";
353 break;
354 default:
355 // Nothing to print for other nary expressions.
356 break;
357 }
358 return;
359 }
360 case scUDivExpr: {
361 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
362 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
363 return;
364 }
365 case scUnknown: {
366 const SCEVUnknown *U = cast<SCEVUnknown>(this);
367 Type *AllocTy;
368 if (U->isSizeOf(AllocTy)) {
369 OS << "sizeof(" << *AllocTy << ")";
370 return;
371 }
372 if (U->isAlignOf(AllocTy)) {
373 OS << "alignof(" << *AllocTy << ")";
374 return;
375 }
376
377 Type *CTy;
378 Constant *FieldNo;
379 if (U->isOffsetOf(CTy, FieldNo)) {
380 OS << "offsetof(" << *CTy << ", ";
381 FieldNo->printAsOperand(OS, false);
382 OS << ")";
383 return;
384 }
385
386 // Otherwise just print it normally.
387 U->getValue()->printAsOperand(OS, false);
388 return;
389 }
390 case scCouldNotCompute:
391 OS << "***COULDNOTCOMPUTE***";
392 return;
393 }
394  llvm_unreachable("Unknown SCEV kind!");
395}
396
397Type *SCEV::getType() const {
398 switch (getSCEVType()) {
399 case scConstant:
400 return cast<SCEVConstant>(this)->getType();
401 case scPtrToInt:
402 case scTruncate:
403 case scZeroExtend:
404 case scSignExtend:
405 return cast<SCEVCastExpr>(this)->getType();
406 case scAddRecExpr:
407 return cast<SCEVAddRecExpr>(this)->getType();
408 case scMulExpr:
409 return cast<SCEVMulExpr>(this)->getType();
410 case scUMaxExpr:
411 case scSMaxExpr:
412 case scUMinExpr:
413 case scSMinExpr:
414 return cast<SCEVMinMaxExpr>(this)->getType();
415 case scSequentialUMinExpr:
416 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
417 case scAddExpr:
418 return cast<SCEVAddExpr>(this)->getType();
419 case scUDivExpr:
420 return cast<SCEVUDivExpr>(this)->getType();
421 case scUnknown:
422 return cast<SCEVUnknown>(this)->getType();
423 case scCouldNotCompute:
424    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
425 }
426  llvm_unreachable("Unknown SCEV kind!");
427}
428
429bool SCEV::isZero() const {
430 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
431 return SC->getValue()->isZero();
432 return false;
433}
434
435bool SCEV::isOne() const {
436 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
437 return SC->getValue()->isOne();
438 return false;
439}
440
441bool SCEV::isAllOnesValue() const {
442 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
443 return SC->getValue()->isMinusOne();
444 return false;
445}
446
447bool SCEV::isNonConstantNegative() const {
448 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
449 if (!Mul) return false;
450
451 // If there is a constant factor, it will be first.
452 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
453 if (!SC) return false;
454
455 // Return true if the value is negative, this matches things like (-42 * V).
456 return SC->getAPInt().isNegative();
457}
458
459SCEVCouldNotCompute::SCEVCouldNotCompute() :
460 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
461
462bool SCEVCouldNotCompute::classof(const SCEV *S) {
463 return S->getSCEVType() == scCouldNotCompute;
464}
465
466const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
467 FoldingSetNodeID ID;
468 ID.AddInteger(scConstant);
469 ID.AddPointer(V);
470 void *IP = nullptr;
471 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
472 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
473 UniqueSCEVs.InsertNode(S, IP);
474 return S;
475}
476
477const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
478 return getConstant(ConstantInt::get(getContext(), Val));
479}
480
481const SCEV *
482ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
483 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
484 return getConstant(ConstantInt::get(ITy, V, isSigned));
485}
486
487SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
488 const SCEV *op, Type *ty)
489 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
490 Operands[0] = op;
491}
492
493SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
494 Type *ITy)
495 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
496  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
497         "Must be a non-bit-width-changing pointer-to-integer cast!");
499
500SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
501 SCEVTypes SCEVTy, const SCEV *op,
502 Type *ty)
503 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
504
505SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
506 Type *ty)
507 : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
508  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
509         "Cannot truncate non-integer value!");
510}
511
512SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
513 const SCEV *op, Type *ty)
514 : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
515  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
516         "Cannot zero extend non-integer value!");
517}
518
519SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
520 const SCEV *op, Type *ty)
521 : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
522  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
523         "Cannot sign extend non-integer value!");
524}
525
526void SCEVUnknown::deleted() {
527 // Clear this SCEVUnknown from various maps.
528 SE->forgetMemoizedResults(this);
529
530 // Remove this SCEVUnknown from the uniquing map.
531 SE->UniqueSCEVs.RemoveNode(this);
532
533 // Release the value.
534 setValPtr(nullptr);
535}
536
537void SCEVUnknown::allUsesReplacedWith(Value *New) {
538 // Clear this SCEVUnknown from various maps.
539 SE->forgetMemoizedResults(this);
540
541 // Remove this SCEVUnknown from the uniquing map.
542 SE->UniqueSCEVs.RemoveNode(this);
543
544 // Replace the value pointer in case someone is still using this SCEVUnknown.
545 setValPtr(New);
546}
547
548bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
549 if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
550 if (VCE->getOpcode() == Instruction::PtrToInt)
551 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
552 if (CE->getOpcode() == Instruction::GetElementPtr &&
553 CE->getOperand(0)->isNullValue() &&
554 CE->getNumOperands() == 2)
555 if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
556 if (CI->isOne()) {
557 AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
558 return true;
559 }
560
561 return false;
562}
563
564bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
565 if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
566 if (VCE->getOpcode() == Instruction::PtrToInt)
567 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
568 if (CE->getOpcode() == Instruction::GetElementPtr &&
569 CE->getOperand(0)->isNullValue()) {
570 Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
571 if (StructType *STy = dyn_cast<StructType>(Ty))
572 if (!STy->isPacked() &&
573 CE->getNumOperands() == 3 &&
574 CE->getOperand(1)->isNullValue()) {
575 if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
576 if (CI->isOne() &&
577 STy->getNumElements() == 2 &&
578 STy->getElementType(0)->isIntegerTy(1)) {
579 AllocTy = STy->getElementType(1);
580 return true;
581 }
582 }
583 }
584
585 return false;
586}
587
588bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
589 if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
590 if (VCE->getOpcode() == Instruction::PtrToInt)
591 if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
592 if (CE->getOpcode() == Instruction::GetElementPtr &&
593 CE->getNumOperands() == 3 &&
594 CE->getOperand(0)->isNullValue() &&
595 CE->getOperand(1)->isNullValue()) {
596 Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
597 // Ignore vector types here so that ScalarEvolutionExpander doesn't
598 // emit getelementptrs that index into vectors.
599 if (Ty->isStructTy() || Ty->isArrayTy()) {
600 CTy = Ty;
601 FieldNo = CE->getOperand(2);
602 return true;
603 }
604 }
605
606 return false;
607}
608
609//===----------------------------------------------------------------------===//
610// SCEV Utilities
611//===----------------------------------------------------------------------===//
612
613/// Compare the two values \p LV and \p RV in terms of their "complexity" where
614/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
615/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
616/// have been previously deemed to be "equally complex" by this routine. It is
617/// intended to avoid exponential time complexity in cases like:
618///
619/// %a = f(%x, %y)
620/// %b = f(%a, %a)
621/// %c = f(%b, %b)
622///
623/// %d = f(%x, %y)
624/// %e = f(%d, %d)
625/// %f = f(%e, %e)
626///
627/// CompareValueComplexity(%f, %c)
628///
629/// Since we do not continue running this routine on expression trees once we
630/// have seen unequal values, there is no need to track them in the cache.
631static int
632CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
633 const LoopInfo *const LI, Value *LV, Value *RV,
634 unsigned Depth) {
635 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
636 return 0;
637
638 // Order pointer values after integer values. This helps SCEVExpander form
639 // GEPs.
640 bool LIsPointer = LV->getType()->isPointerTy(),
641 RIsPointer = RV->getType()->isPointerTy();
642 if (LIsPointer != RIsPointer)
643 return (int)LIsPointer - (int)RIsPointer;
644
645 // Compare getValueID values.
646 unsigned LID = LV->getValueID(), RID = RV->getValueID();
647 if (LID != RID)
648 return (int)LID - (int)RID;
649
650 // Sort arguments by their position.
651 if (const auto *LA = dyn_cast<Argument>(LV)) {
652 const auto *RA = cast<Argument>(RV);
653 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
654 return (int)LArgNo - (int)RArgNo;
655 }
656
657 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
658 const auto *RGV = cast<GlobalValue>(RV);
659
660 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
661 auto LT = GV->getLinkage();
662 return !(GlobalValue::isPrivateLinkage(LT) ||
663 GlobalValue::isInternalLinkage(LT));
664 };
665
666 // Use the names to distinguish the two values, but only if the
667 // names are semantically important.
668 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
669 return LGV->getName().compare(RGV->getName());
670 }
671
672 // For instructions, compare their loop depth, and their operand count. This
673 // is pretty loose.
674 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
675 const auto *RInst = cast<Instruction>(RV);
676
677 // Compare loop depths.
678 const BasicBlock *LParent = LInst->getParent(),
679 *RParent = RInst->getParent();
680 if (LParent != RParent) {
681 unsigned LDepth = LI->getLoopDepth(LParent),
682 RDepth = LI->getLoopDepth(RParent);
683 if (LDepth != RDepth)
684 return (int)LDepth - (int)RDepth;
685 }
686
687 // Compare the number of operands.
688 unsigned LNumOps = LInst->getNumOperands(),
689 RNumOps = RInst->getNumOperands();
690 if (LNumOps != RNumOps)
691 return (int)LNumOps - (int)RNumOps;
692
693 for (unsigned Idx : seq(0u, LNumOps)) {
694 int Result =
695 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
696 RInst->getOperand(Idx), Depth + 1);
697 if (Result != 0)
698 return Result;
699 }
700 }
701
702 EqCacheValue.unionSets(LV, RV);
703 return 0;
704}
705
706// Return negative, zero, or positive, if LHS is less than, equal to, or greater
707// than RHS, respectively. A three-way result allows recursive comparisons to be
708// more efficient.
709// If the max analysis depth was reached, return None, assuming we do not know
710// if they are equivalent for sure.
711static Optional<int>
712CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
713 EquivalenceClasses<const Value *> &EqCacheValue,
714 const LoopInfo *const LI, const SCEV *LHS,
715 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
716 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
717 if (LHS == RHS)
718 return 0;
719
720 // Primarily, sort the SCEVs by their getSCEVType().
721 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
722 if (LType != RType)
723 return (int)LType - (int)RType;
724
725 if (EqCacheSCEV.isEquivalent(LHS, RHS))
726 return 0;
727
728 if (Depth > MaxSCEVCompareDepth)
729 return None;
730
731 // Aside from the getSCEVType() ordering, the particular ordering
732 // isn't very important except that it's beneficial to be consistent,
733 // so that (a + b) and (b + a) don't end up as different expressions.
734 switch (LType) {
735 case scUnknown: {
736 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
737 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
738
739 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
740 RU->getValue(), Depth + 1);
741 if (X == 0)
742 EqCacheSCEV.unionSets(LHS, RHS);
743 return X;
744 }
745
746 case scConstant: {
747 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
748 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
749
750 // Compare constant values.
751 const APInt &LA = LC->getAPInt();
752 const APInt &RA = RC->getAPInt();
753 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
754 if (LBitWidth != RBitWidth)
755 return (int)LBitWidth - (int)RBitWidth;
756 return LA.ult(RA) ? -1 : 1;
757 }
758
759 case scAddRecExpr: {
760 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
761 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
762
763 // There is always a dominance between two recs that are used by one SCEV,
764 // so we can safely sort recs by loop header dominance. We require such
765 // order in getAddExpr.
766 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
767 if (LLoop != RLoop) {
768 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
769      assert(LHead != RHead && "Two loops share the same header?");
770 if (DT.dominates(LHead, RHead))
771 return 1;
772 else
773        assert(DT.dominates(RHead, LHead) &&
774               "No dominance between recurrences used by one SCEV?");
775 return -1;
776 }
777
778 // Addrec complexity grows with operand count.
779 unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
780 if (LNumOps != RNumOps)
781 return (int)LNumOps - (int)RNumOps;
782
783 // Lexicographically compare.
784 for (unsigned i = 0; i != LNumOps; ++i) {
785 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
786 LA->getOperand(i), RA->getOperand(i), DT,
787 Depth + 1);
788 if (X != 0)
789 return X;
790 }
791 EqCacheSCEV.unionSets(LHS, RHS);
792 return 0;
793 }
794
795 case scAddExpr:
796 case scMulExpr:
797 case scSMaxExpr:
798 case scUMaxExpr:
799 case scSMinExpr:
800 case scUMinExpr:
801 case scSequentialUMinExpr: {
802 const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
803 const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
804
805 // Lexicographically compare n-ary expressions.
806 unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
807 if (LNumOps != RNumOps)
808 return (int)LNumOps - (int)RNumOps;
809
810 for (unsigned i = 0; i != LNumOps; ++i) {
811 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
812 LC->getOperand(i), RC->getOperand(i), DT,
813 Depth + 1);
814 if (X != 0)
815 return X;
816 }
817 EqCacheSCEV.unionSets(LHS, RHS);
818 return 0;
819 }
820
821 case scUDivExpr: {
822 const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
823 const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
824
825 // Lexicographically compare udiv expressions.
826 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
827 RC->getLHS(), DT, Depth + 1);
828 if (X != 0)
829 return X;
830 X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
831 RC->getRHS(), DT, Depth + 1);
832 if (X == 0)
833 EqCacheSCEV.unionSets(LHS, RHS);
834 return X;
835 }
836
837 case scPtrToInt:
838 case scTruncate:
839 case scZeroExtend:
840 case scSignExtend: {
841 const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
842 const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
843
844 // Compare cast expressions by operand.
845 auto X =
846 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
847 RC->getOperand(), DT, Depth + 1);
848 if (X == 0)
849 EqCacheSCEV.unionSets(LHS, RHS);
850 return X;
851 }
852
853 case scCouldNotCompute:
854    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
855 }
856  llvm_unreachable("Unknown SCEV kind!");
857}
858
859/// Given a list of SCEV objects, order them by their complexity, and group
860/// objects of the same complexity together by value. When this routine is
861/// finished, we know that any duplicates in the vector are consecutive and that
862/// complexity is monotonically increasing.
863///
864/// Note that we go take special precautions to ensure that we get deterministic
865/// results from this routine. In other words, we don't want the results of
866/// this to depend on where the addresses of various SCEV objects happened to
867/// land in memory.
868static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
869 LoopInfo *LI, DominatorTree &DT) {
870 if (Ops.size() < 2) return; // Noop
871
872 EquivalenceClasses<const SCEV *> EqCacheSCEV;
873 EquivalenceClasses<const Value *> EqCacheValue;
874
875 // Whether LHS has provably less complexity than RHS.
876 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
877 auto Complexity =
878 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
879 return Complexity && *Complexity < 0;
880 };
881 if (Ops.size() == 2) {
882 // This is the common case, which also happens to be trivially simple.
883 // Special case it.
884 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
885 if (IsLessComplex(RHS, LHS))
886 std::swap(LHS, RHS);
887 return;
888 }
889
890 // Do the rough sort by complexity.
891 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
892 return IsLessComplex(LHS, RHS);
893 });
894
895 // Now that we are sorted by complexity, group elements of the same
896 // complexity. Note that this is, at worst, N^2, but the vector is likely to
897 // be extremely short in practice. Note that we take this approach because we
898 // do not want to depend on the addresses of the objects we are grouping.
899 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
900 const SCEV *S = Ops[i];
901 unsigned Complexity = S->getSCEVType();
902
903 // If there are any objects of the same complexity and same value as this
904 // one, group them.
905 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
906 if (Ops[j] == S) { // Found a duplicate.
907 // Move it to immediately after i'th element.
908 std::swap(Ops[i+1], Ops[j]);
909 ++i; // no need to rescan it.
910 if (i == e-2) return; // Done!
911 }
912 }
913 }
914}
915
916/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
917/// least HugeExprThreshold nodes).
918static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
919 return any_of(Ops, [](const SCEV *S) {
920 return S->getExpressionSize() >= HugeExprThreshold;
921 });
922}
923
924//===----------------------------------------------------------------------===//
925// Simple SCEV method implementations
926//===----------------------------------------------------------------------===//
927
928/// Compute BC(It, K). The result has width W. Assume, K > 0.
929static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
930 ScalarEvolution &SE,
931 Type *ResultTy) {
932 // Handle the simplest case efficiently.
933 if (K == 1)
934 return SE.getTruncateOrZeroExtend(It, ResultTy);
935
936 // We are using the following formula for BC(It, K):
937 //
938 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
939 //
940 // Suppose, W is the bitwidth of the return value. We must be prepared for
941 // overflow. Hence, we must assure that the result of our computation is
942 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
943 // safe in modular arithmetic.
944 //
945 // However, this code doesn't use exactly that formula; the formula it uses
946 // is something like the following, where T is the number of factors of 2 in
947 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
948 // exponentiation:
949 //
950 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
951 //
952 // This formula is trivially equivalent to the previous formula. However,
953 // this formula can be implemented much more efficiently. The trick is that
954 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
955 // arithmetic. To do exact division in modular arithmetic, all we have
956 // to do is multiply by the inverse. Therefore, this step can be done at
957 // width W.
958 //
959 // The next issue is how to safely do the division by 2^T. The way this
960 // is done is by doing the multiplication step at a width of at least W + T
961 // bits. This way, the bottom W+T bits of the product are accurate. Then,
962 // when we perform the division by 2^T (which is equivalent to a right shift
963 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
964 // truncated out after the division by 2^T.
965 //
966 // In comparison to just directly using the first formula, this technique
967 // is much more efficient; using the first formula requires W * K bits,
968 // but this formula less than W + K bits. Also, the first formula requires
969 // a division step, whereas this formula only requires multiplies and shifts.
970 //
971 // It doesn't matter whether the subtraction step is done in the calculation
972 // width or the input iteration count's width; if the subtraction overflows,
973 // the result must be zero anyway. We prefer here to do it in the width of
974 // the induction variable because it helps a lot for certain cases; CodeGen
975 // isn't smart enough to ignore the overflow, which leads to much less
976 // efficient code if the width of the subtraction is wider than the native
977 // register width.
978 //
979 // (It's possible to not widen at all by pulling out factors of 2 before
980 // the multiplication; for example, K=2 can be calculated as
981 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
982 // extra arithmetic, so it's not an obvious win, and it gets
983 // much more complicated for K > 3.)
984
985 // Protection from insane SCEVs; this bound is conservative,
986 // but it probably doesn't matter.
987 if (K > 1000)
988 return SE.getCouldNotCompute();
989
990 unsigned W = SE.getTypeSizeInBits(ResultTy);
991
992 // Calculate K! / 2^T and T; we divide out the factors of two before
993 // multiplying for calculating K! / 2^T to avoid overflow.
994 // Other overflow doesn't matter because we only care about the bottom
995 // W bits of the result.
996 APInt OddFactorial(W, 1);
997 unsigned T = 1;
998 for (unsigned i = 3; i <= K; ++i) {
999 APInt Mult(W, i);
1000 unsigned TwoFactors = Mult.countTrailingZeros();
1001 T += TwoFactors;
1002 Mult.lshrInPlace(TwoFactors);
1003 OddFactorial *= Mult;
1004 }
1005
1006 // We need at least W + T bits for the multiplication step
1007 unsigned CalculationBits = W + T;
1008
1009 // Calculate 2^T, at width T+W.
1010 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1011
1012 // Calculate the multiplicative inverse of K! / 2^T;
1013 // this multiplication factor will perform the exact division by
1014 // K! / 2^T.
1015 APInt Mod = APInt::getSignedMinValue(W+1);
1016 APInt MultiplyFactor = OddFactorial.zext(W+1);
1017 MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1018 MultiplyFactor = MultiplyFactor.trunc(W);
1019
1020 // Calculate the product, at width T+W
1021 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1022 CalculationBits);
1023 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1024 for (unsigned i = 1; i != K; ++i) {
1025 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1026 Dividend = SE.getMulExpr(Dividend,
1027 SE.getTruncateOrZeroExtend(S, CalculationTy));
1028 }
1029
1030 // Divide by 2^T
1031 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1032
1033 // Truncate the result, and divide by K! / 2^T.
1034
1035 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1036 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1037}
1038
1039/// Return the value of this chain of recurrences at the specified iteration
1040/// number. We can evaluate this recurrence by multiplying each element in the
1041/// chain by the binomial coefficient corresponding to it. In other words, we
1042/// can evaluate {A,+,B,+,C,+,D} as:
1043///
1044/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1045///
1046/// where BC(It, k) stands for binomial coefficient.
1047const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1048 ScalarEvolution &SE) const {
1049 return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1050}
1051
1052const SCEV *
1053SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1054 const SCEV *It, ScalarEvolution &SE) {
1055  assert(Operands.size() > 0);
1056 const SCEV *Result = Operands[0];
1057 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1058 // The computation is correct in the face of overflow provided that the
1059 // multiplication is performed _after_ the evaluation of the binomial
1060 // coefficient.
1061 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1062 if (isa<SCEVCouldNotCompute>(Coeff))
1063 return Coeff;
1064
1065 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1066 }
1067 return Result;
1068}
1069
1070//===----------------------------------------------------------------------===//
1071// SCEV Expression folder implementations
1072//===----------------------------------------------------------------------===//
1073
1074const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1075 unsigned Depth) {
1076  assert(Depth <= 1 &&
1077         "getLosslessPtrToIntExpr() should self-recurse at most once.");
1078
1079 // We could be called with an integer-typed operands during SCEV rewrites.
1080 // Since the operand is an integer already, just perform zext/trunc/self cast.
1081 if (!Op->getType()->isPointerTy())
1082 return Op;
1083
1084 // What would be an ID for such a SCEV cast expression?
1085 FoldingSetNodeID ID;
1086 ID.AddInteger(scPtrToInt);
1087 ID.AddPointer(Op);
1088
1089 void *IP = nullptr;
1090
1091 // Is there already an expression for such a cast?
1092 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1093 return S;
1094
1095 // It isn't legal for optimizations to construct new ptrtoint expressions
1096 // for non-integral pointers.
1097 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1098 return getCouldNotCompute();
1099
1100 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1101
1102 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1103 // is sufficiently wide to represent all possible pointer values.
1104 // We could theoretically teach SCEV to truncate wider pointers, but
1105 // that isn't implemented for now.
1106 if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1107 getDataLayout().getTypeSizeInBits(IntPtrTy))
1108 return getCouldNotCompute();
1109
1110 // If not, is this expression something we can't reduce any further?
1111 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1112 // Perform some basic constant folding. If the operand of the ptr2int cast
1113 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1114 // left as-is), but produce a zero constant.
1115 // NOTE: We could handle a more general case, but lack motivational cases.
1116 if (isa<ConstantPointerNull>(U->getValue()))
1117 return getZero(IntPtrTy);
1118
1119 // Create an explicit cast node.
1120 // We can reuse the existing insert position since if we get here,
1121 // we won't have made any changes which would invalidate it.
1122 SCEV *S = new (SCEVAllocator)
1123 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1124 UniqueSCEVs.InsertNode(S, IP);
1125 registerUser(S, Op);
1126 return S;
1127 }
1128
1129  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1130                       "non-SCEVUnknown's.");
1131
1132 // Otherwise, we've got some expression that is more complex than just a
1133 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1134 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1135 // only, and the expressions must otherwise be integer-typed.
1136 // So sink the cast down to the SCEVUnknown's.
1137
1138 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1139 /// which computes a pointer-typed value, and rewrites the whole expression
1140 /// tree so that *all* the computations are done on integers, and the only
1141 /// pointer-typed operands in the expression are SCEVUnknown.
1142 class SCEVPtrToIntSinkingRewriter
1143 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1144 using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1145
1146 public:
1147 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1148
1149 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1150 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1151 return Rewriter.visit(Scev);
1152 }
1153
1154 const SCEV *visit(const SCEV *S) {
1155 Type *STy = S->getType();
1156 // If the expression is not pointer-typed, just keep it as-is.
1157 if (!STy->isPointerTy())
1158 return S;
1159 // Else, recursively sink the cast down into it.
1160 return Base::visit(S);
1161 }
1162
1163 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1164 SmallVector<const SCEV *, 2> Operands;
1165 bool Changed = false;
1166 for (const auto *Op : Expr->operands()) {
1167 Operands.push_back(visit(Op));
1168 Changed |= Op != Operands.back();
1169 }
1170 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1171 }
1172
1173 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1174 SmallVector<const SCEV *, 2> Operands;
1175 bool Changed = false;
1176 for (const auto *Op : Expr->operands()) {
1177 Operands.push_back(visit(Op));
1178 Changed |= Op != Operands.back();
1179 }
1180 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1181 }
1182
1183 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1184      assert(Expr->getType()->isPointerTy() &&
1185             "Should only reach pointer-typed SCEVUnknown's.");
1186 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1187 }
1188 };
1189
1190 // And actually perform the cast sinking.
1191 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1192  assert(IntOp->getType()->isIntegerTy() &&
1193         "We must have succeeded in sinking the cast, "
1194         "and ending up with an integer-typed expression!");
1195 return IntOp;
1196}
1197
1198const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1199  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1200
1201 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1202 if (isa<SCEVCouldNotCompute>(IntOp))
1203 return IntOp;
1204
1205 return getTruncateOrZeroExtend(IntOp, Ty);
1206}
1207
1208const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1209 unsigned Depth) {
1210  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1211         "This is not a truncating conversion!");
1212 assert(isSCEVable(Ty) &&(static_cast <bool> (isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"
) ? void (0) : __assert_fail ("isSCEVable(Ty) && \"This is not a conversion to a SCEVable type!\""
, "llvm/lib/Analysis/ScalarEvolution.cpp", 1213, __extension__
__PRETTY_FUNCTION__))
1213 "This is not a conversion to a SCEVable type!")(static_cast <bool> (isSCEVable(Ty) && "This is not a conversion to a SCEVable type!"
) ? void (0) : __assert_fail ("isSCEVable(Ty) && \"This is not a conversion to a SCEVable type!\""
, "llvm/lib/Analysis/ScalarEvolution.cpp", 1213, __extension__
__PRETTY_FUNCTION__))
;
1214 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!")(static_cast <bool> (!Op->getType()->isPointerTy(
) && "Can't truncate pointer!") ? void (0) : __assert_fail
("!Op->getType()->isPointerTy() && \"Can't truncate pointer!\""
, "llvm/lib/Analysis/ScalarEvolution.cpp", 1214, __extension__
__PRETTY_FUNCTION__))
;
1215 Ty = getEffectiveSCEVType(Ty);
1216
1217 FoldingSetNodeID ID;
1218 ID.AddInteger(scTruncate);
1219 ID.AddPointer(Op);
1220 ID.AddPointer(Ty);
1221 void *IP = nullptr;
1222 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1223
1224 // Fold if the operand is constant.
1225 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1226 return getConstant(
1227 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1228
1229 // trunc(trunc(x)) --> trunc(x)
1230 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1231 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1232
1233 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1234 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1235 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1236
1237 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1238 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1239 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1240
1241 if (Depth > MaxCastDepth) {
1242 SCEV *S =
1243 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1244 UniqueSCEVs.InsertNode(S, IP);
1245 registerUser(S, Op);
1246 return S;
1247 }
1248
1249 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1250 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1251 // if after transforming we have at most one truncate, not counting truncates
1252 // that replace other casts.
1253 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1254 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1255 SmallVector<const SCEV *, 4> Operands;
1256 unsigned numTruncs = 0;
1257 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1258 ++i) {
1259 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1260 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1261 isa<SCEVTruncateExpr>(S))
1262 numTruncs++;
1263 Operands.push_back(S);
1264 }
1265 if (numTruncs < 2) {
1266 if (isa<SCEVAddExpr>(Op))
1267 return getAddExpr(Operands);
1268 else if (isa<SCEVMulExpr>(Op))
1269 return getMulExpr(Operands);
1270 else
1271 llvm_unreachable("Unexpected SCEV type for Op.");
1272 }
1273 // Although we checked at the beginning that ID is not in the cache, it is
1274 // possible that ID was inserted into the cache during the recursive
1275 // modifications above. So if we find it, just return it.
1276 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1277 return S;
1278 }
1279
1280 // If the input value is a chrec scev, truncate the chrec's operands.
1281 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1282 SmallVector<const SCEV *, 4> Operands;
1283 for (const SCEV *Op : AddRec->operands())
1284 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1285 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1286 }
1287
1288 // Return zero if truncating to known zeros.
1289 uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1290 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1291 return getZero(Ty);
1292
1293 // The cast wasn't folded; create an explicit cast node. We can reuse
1294 // the existing insert position since if we get here, we won't have
1295 // made any changes which would invalidate it.
1296 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1297 Op, Ty);
1298 UniqueSCEVs.InsertNode(S, IP);
1299 registerUser(S, Op);
1300 return S;
1301}
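
To make the folds above concrete, a few illustrative rewrites in SCEV notation (the %-names and types are assumed for the example):

    // trunc i64 (zext i32 %x to i64) to i32  -->  %x              (the casts cancel)
    // trunc i64 (zext i16 %x to i64) to i32  -->  (zext i16 %x to i32)
    // trunc i64 (%a + %b) to i32 is left as an explicit truncate: distributing it
    //   would create two fresh truncates, which the numTruncs < 2 check rejects.
    // trunc i64 {0,+,8}<%loop> to i32        -->  {0,+,8}<%loop> over i32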
1302
1303// Get the limit of a recurrence such that incrementing by Step cannot cause
1304// signed overflow as long as the value of the recurrence within the
1305// loop does not exceed this limit before incrementing.
1306static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1307 ICmpInst::Predicate *Pred,
1308 ScalarEvolution *SE) {
1309 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1310 if (SE->isKnownPositive(Step)) {
1311 *Pred = ICmpInst::ICMP_SLT;
1312 return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1313 SE->getSignedRangeMax(Step));
1314 }
1315 if (SE->isKnownNegative(Step)) {
1316 *Pred = ICmpInst::ICMP_SGT;
1317 return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1318 SE->getSignedRangeMin(Step));
1319 }
1320 return nullptr;
1321}
1322
1323// Get the limit of a recurrence such that incrementing by Step cannot cause
1324// unsigned overflow as long as the value of the recurrence within the loop does
1325// not exceed this limit before incrementing.
1326static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1327 ICmpInst::Predicate *Pred,
1328 ScalarEvolution *SE) {
1329 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1330 *Pred = ICmpInst::ICMP_ULT;
1331
1332 return SE->getConstant(APInt::getMinValue(BitWidth) -
1333 SE->getUnsignedRangeMax(Step));
1334}
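
A small numeric sketch of what these limits mean, using an assumed 8-bit step of 3:

    // Signed:   Pred = SLT, Limit = SignedMin(i8) - 3 = -128 - 3 = +125 (mod 2^8).
    //           If v <s 125 before the increment, then v + 3 <= 127: no signed wrap.
    // Unsigned: Pred = ULT, Limit = 0 - 3 = 253 (mod 2^8).
    //           If v <u 253 before the increment, then v + 3 <= 255: no unsigned wrap.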
1335
1336namespace {
1337
1338struct ExtendOpTraitsBase {
1339 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1340 unsigned);
1341};
1342
1343// Used to make code generic over signed and unsigned overflow.
1344template <typename ExtendOp> struct ExtendOpTraits {
1345 // Members present:
1346 //
1347 // static const SCEV::NoWrapFlags WrapType;
1348 //
1349 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1350 //
1351 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1352 // ICmpInst::Predicate *Pred,
1353 // ScalarEvolution *SE);
1354};
1355
1356template <>
1357struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1358 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1359
1360 static const GetExtendExprTy GetExtendExpr;
1361
1362 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1363 ICmpInst::Predicate *Pred,
1364 ScalarEvolution *SE) {
1365 return getSignedOverflowLimitForStep(Step, Pred, SE);
1366 }
1367};
1368
1369const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1370 SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1371
1372template <>
1373struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1374 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1375
1376 static const GetExtendExprTy GetExtendExpr;
1377
1378 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1379 ICmpInst::Predicate *Pred,
1380 ScalarEvolution *SE) {
1381 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1382 }
1383};
1384
1385const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1386 SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1387
1388} // end anonymous namespace
1389
1390// The recurrence AR has been shown to have no signed/unsigned wrap or something
1391// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1392// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1393// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1394 // Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1395 // expression "Step + sext/zext(PreIncAR)" is congruent with
1396 // "sext/zext(PostIncAR)".
1397template <typename ExtendOpTy>
1398static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1399 ScalarEvolution *SE, unsigned Depth) {
1400 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1401 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1402
1403 const Loop *L = AR->getLoop();
1404 const SCEV *Start = AR->getStart();
1405 const SCEV *Step = AR->getStepRecurrence(*SE);
1406
1407 // Check for a simple looking step prior to loop entry.
1408 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1409 if (!SA)
1410 return nullptr;
1411
1412 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1413 // subtraction is expensive. For this purpose, perform a quick and dirty
1414 // difference, by checking for Step in the operand list.
1415 SmallVector<const SCEV *, 4> DiffOps;
1416 for (const SCEV *Op : SA->operands())
1417 if (Op != Step)
1418 DiffOps.push_back(Op);
1419
1420 if (DiffOps.size() == SA->getNumOperands())
1421 return nullptr;
1422
1423 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1424 // `Step`:
1425
1426 // 1. NSW/NUW flags on the step increment.
1427 auto PreStartFlags =
1428 ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1429 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1430 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1431 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1432
1433 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1434 // "S+X does not sign/unsign-overflow".
1435 //
1436
1437 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1438 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1439 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1440 return PreStart;
1441
1442 // 2. Direct overflow check on the step operation's expression.
1443 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1444 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1445 const SCEV *OperandExtendedStart =
1446 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1447 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1448 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1449 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1450 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1451 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1452 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1453 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1454 }
1455 return PreStart;
1456 }
1457
1458 // 3. Loop precondition.
1459 ICmpInst::Predicate Pred;
1460 const SCEV *OverflowLimit =
1461 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1462
1463 if (OverflowLimit &&
1464 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1465 return PreStart;
1466
1467 return nullptr;
1468}
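
A concrete instance of the normalization described in the comment above, with assumed operands:

    // AR = {%n + 4,+,4}<%loop>, extended with sext.
    // Start = (%n + 4) contains the step, so DiffOps = {%n}, PreStart = %n and
    // PreAR = {%n,+,4}<%loop>.  If any of the three checks proves PreStart + 4
    // does not sign-overflow, getPreStartForExtend returns %n, and
    // getExtendAddRecStart (below) builds the extended start as
    //   sext(4) + sext(%n)   instead of   sext(%n + 4),
    // which is the normalized pre-increment form the comment refers to.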
1469
1470// Get the normalized zero or sign extended expression for this AddRec's Start.
1471template <typename ExtendOpTy>
1472static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1473 ScalarEvolution *SE,
1474 unsigned Depth) {
1475 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1476
1477 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1478 if (!PreStart)
1479 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1480
1481 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1482 Depth),
1483 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1484}
1485
1486// Try to prove away overflow by looking at "nearby" add recurrences. A
1487// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1488// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1489//
1490// Formally:
1491//
1492// {S,+,X} == {S-T,+,X} + T
1493// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1494//
1495// If ({S-T,+,X} + T) does not overflow ... (1)
1496//
1497// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1498//
1499// If {S-T,+,X} does not overflow ... (2)
1500//
1501// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1502// == {Ext(S-T)+Ext(T),+,Ext(X)}
1503//
1504// If (S-T)+T does not overflow ... (3)
1505//
1506// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1507// == {Ext(S),+,Ext(X)} == LHS
1508//
1509// Thus, if (1), (2) and (3) are true for some T, then
1510// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1511//
1512// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1513// does not overflow" restricted to the 0th iteration. Therefore we only need
1514// to check for (1) and (2).
1515//
1516// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1517// is `Delta` (defined below).
1518template <typename ExtendOpTy>
1519bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1520 const SCEV *Step,
1521 const Loop *L) {
1522 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1523
1524 // We restrict `Start` to a constant to prevent SCEV from spending too much
1525 // time here. It is correct (but more expensive) to continue with a
1526 // non-constant `Start` and do a general SCEV subtraction to compute
1527 // `PreStart` below.
1528 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1529 if (!StartC)
1530 return false;
1531
1532 APInt StartAI = StartC->getAPInt();
1533
1534 for (unsigned Delta : {-2, -1, 1, 2}) {
1535 const SCEV *PreStart = getConstant(StartAI - Delta);
1536
1537 FoldingSetNodeID ID;
1538 ID.AddInteger(scAddRecExpr);
1539 ID.AddPointer(PreStart);
1540 ID.AddPointer(Step);
1541 ID.AddPointer(L);
1542 void *IP = nullptr;
1543 const auto *PreAR =
1544 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1545
1546 // Give up if we don't already have the add recurrence we need because
1547 // actually constructing an add recurrence is relatively expensive.
1548 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1549 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1550 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1551 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1552 DeltaS, &Pred, this);
1553 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1554 return true;
1555 }
1556 }
1557
1558 return false;
1559}
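
Restating the motivating example at the top of this function with concrete numbers:

    // Candidate: prove <nuw> for Start = 1, Step = 4, i.e. for {1,+,4}<%loop>.
    // Delta = 1 gives PreStart = 0; the cache may already hold
    // PreAR = {0,+,4}<%loop><nuw>, which is condition (2).
    // DeltaS = 1, so the unsigned limit is 0 - 1 = UINT_MAX with Pred = ULT, and
    // isKnownPredicate(ULT, {0,+,4}, UINT_MAX) is exactly the "{0,+,4} u< -1"
    // fact from the comment, which is condition (1).  Together they allow the
    // caller to mark {1,+,4} as nuw.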
1560
1561// Finds an integer D for an expression (C + x + y + ...) such that the top
1562// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1563// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1564// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1565// the (C + x + y + ...) expression is \p WholeAddExpr.
1566static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1567 const SCEVConstant *ConstantTerm,
1568 const SCEVAddExpr *WholeAddExpr) {
1569 const APInt &C = ConstantTerm->getAPInt();
1570 const unsigned BitWidth = C.getBitWidth();
1571 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1572 uint32_t TZ = BitWidth;
1573 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1574 TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1575 if (TZ) {
1576 // Set D to be as many least significant bits of C as possible while still
1577 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1578 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1579 }
1580 return APInt(BitWidth, 0);
1581}
1582
1583// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1584// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1585// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1586// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1587static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1588 const APInt &ConstantStart,
1589 const SCEV *Step) {
1590 const unsigned BitWidth = ConstantStart.getBitWidth();
1591 const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1592 if (TZ)
1593 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1594 : ConstantStart;
1595 return APInt(BitWidth, 0);
1596}
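
A worked 8-bit instance of the trailing-zero split both overloads perform (constants chosen only for illustration):

    // C = 53 = 0b0011'0101, and every other operand is a multiple of 8, so TZ = 3.
    // D = C.trunc(3).zext(8) = 5 and C - D = 48 = 0b0011'0000.
    // The residual (48 + x + y + ...) remains a multiple of 8, and adding D = 5 < 8
    // only fills the low three bits, so the top-level add can never wrap.  This is
    // what lets the callers emit (zext(D) + zext(C - D + ...)) with <nuw><nsw>.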
1597
1598const SCEV *
1599ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1600 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1601        "This is not an extending conversion!");
1602 assert(isSCEVable(Ty) &&
1603        "This is not a conversion to a SCEVable type!");
1604 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1605 Ty = getEffectiveSCEVType(Ty);
1606
1607 // Fold if the operand is constant.
1608 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1609 return getConstant(
1610 cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1611
1612 // zext(zext(x)) --> zext(x)
1613 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1614 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1615
1616 // Before doing any expensive analysis, check to see if we've already
1617 // computed a SCEV for this Op and Ty.
1618 FoldingSetNodeID ID;
1619 ID.AddInteger(scZeroExtend);
1620 ID.AddPointer(Op);
1621 ID.AddPointer(Ty);
1622 void *IP = nullptr;
1623 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1624 if (Depth > MaxCastDepth) {
1625 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1626 Op, Ty);
1627 UniqueSCEVs.InsertNode(S, IP);
1628 registerUser(S, Op);
1629 return S;
1630 }
1631
1632 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1633 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1634 // It's possible the bits taken off by the truncate were all zero bits. If
1635 // so, we should be able to simplify this further.
1636 const SCEV *X = ST->getOperand();
1637 ConstantRange CR = getUnsignedRange(X);
1638 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1639 unsigned NewBits = getTypeSizeInBits(Ty);
1640 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1641 CR.zextOrTrunc(NewBits)))
1642 return getTruncateOrZeroExtend(X, Ty, Depth);
1643 }
1644
1645 // If the input value is a chrec scev, and we can prove that the value
1646 // did not overflow the old, smaller, value, we can zero extend all of the
1647 // operands (often constants). This allows analysis of something like
1648 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1649 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1650 if (AR->isAffine()) {
1651 const SCEV *Start = AR->getStart();
1652 const SCEV *Step = AR->getStepRecurrence(*this);
1653 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1654 const Loop *L = AR->getLoop();
1655
1656 if (!AR->hasNoUnsignedWrap()) {
1657 auto NewFlags = proveNoWrapViaConstantRanges(AR);
1658 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1659 }
1660
1661 // If we have special knowledge that this addrec won't overflow,
1662 // we don't need to do any further analysis.
1663 if (AR->hasNoUnsignedWrap()) {
1664 Start =
1665 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1666 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1667 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1668 }
1669
1670 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1671 // Note that this serves two purposes: It filters out loops that are
1672 // simply not analyzable, and it covers the case where this code is
1673 // being called from within backedge-taken count analysis, such that
1674 // attempting to ask for the backedge-taken count would likely result
1675 // in infinite recursion. In the latter case, the analysis code will
1676 // cope with a conservative value, and it will take care to purge
1677 // that value once it has finished.
1678 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1679 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1680 // Manually compute the final value for AR, checking for overflow.
1681
1682 // Check whether the backedge-taken count can be losslessly casted to
1683 // the addrec's type. The count is always unsigned.
1684 const SCEV *CastedMaxBECount =
1685 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1686 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1687 CastedMaxBECount, MaxBECount->getType(), Depth);
1688 if (MaxBECount == RecastedMaxBECount) {
1689 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1690 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1691 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1692 SCEV::FlagAnyWrap, Depth + 1);
1693 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1694 SCEV::FlagAnyWrap,
1695 Depth + 1),
1696 WideTy, Depth + 1);
1697 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1698 const SCEV *WideMaxBECount =
1699 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1700 const SCEV *OperandExtendedAdd =
1701 getAddExpr(WideStart,
1702 getMulExpr(WideMaxBECount,
1703 getZeroExtendExpr(Step, WideTy, Depth + 1),
1704 SCEV::FlagAnyWrap, Depth + 1),
1705 SCEV::FlagAnyWrap, Depth + 1);
1706 if (ZAdd == OperandExtendedAdd) {
1707 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1708 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1709 // Return the expression with the addrec on the outside.
1710 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1711 Depth + 1);
1712 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1713 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1714 }
1715 // Similar to above, only this time treat the step value as signed.
1716 // This covers loops that count down.
1717 OperandExtendedAdd =
1718 getAddExpr(WideStart,
1719 getMulExpr(WideMaxBECount,
1720 getSignExtendExpr(Step, WideTy, Depth + 1),
1721 SCEV::FlagAnyWrap, Depth + 1),
1722 SCEV::FlagAnyWrap, Depth + 1);
1723 if (ZAdd == OperandExtendedAdd) {
1724 // Cache knowledge of AR NW, which is propagated to this AddRec.
1725 // Negative step causes unsigned wrap, but it still can't self-wrap.
1726 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1727 // Return the expression with the addrec on the outside.
1728 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1729 Depth + 1);
1730 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1731 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1732 }
1733 }
1734 }
1735
1736 // Normally, in the cases we can prove no-overflow via a
1737 // backedge guarding condition, we can also compute a backedge
1738 // taken count for the loop. The exceptions are assumptions and
1739 // guards present in the loop -- SCEV is not great at exploiting
1740 // these to compute max backedge taken counts, but can still use
1741 // these to prove lack of overflow. Use this fact to avoid
1742 // doing extra work that may not pay off.
1743 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1744 !AC.assumptions().empty()) {
1745
1746 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1747 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1748 if (AR->hasNoUnsignedWrap()) {
1749 // Same as nuw case above - duplicated here to avoid a compile time
1750 // issue. It's not clear that the order of checks matters, but
1751 // it's one of two possible causes for a change which was
1752 // reverted. Be conservative for the moment.
1753 Start =
1754 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1755 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1756 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1757 }
1758
1759 // For a negative step, we can extend the operands iff doing so only
1760 // traverses values in the range zext([0,UINT_MAX]).
1761 if (isKnownNegative(Step)) {
1762 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1763 getSignedRangeMin(Step));
1764 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1765 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1766 // Cache knowledge of AR NW, which is propagated to this
1767 // AddRec. Negative step causes unsigned wrap, but it
1768 // still can't self-wrap.
1769 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1770 // Return the expression with the addrec on the outside.
1771 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1772 Depth + 1);
1773 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1774 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1775 }
1776 }
1777 }
1778
1779 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1780 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1781 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1782 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1783 const APInt &C = SC->getAPInt();
1784 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1785 if (D != 0) {
1786 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1787 const SCEV *SResidual =
1788 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1789 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1790 return getAddExpr(SZExtD, SZExtR,
1791 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1792 Depth + 1);
1793 }
1794 }
1795
1796 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1797 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1798 Start =
1799 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1800 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1801 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1802 }
1803 }
1804
1805 // zext(A % B) --> zext(A) % zext(B)
1806 {
1807 const SCEV *LHS;
1808 const SCEV *RHS;
1809 if (matchURem(Op, LHS, RHS))
1810 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1811 getZeroExtendExpr(RHS, Ty, Depth + 1));
1812 }
1813
1814 // zext(A / B) --> zext(A) / zext(B).
1815 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1816 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1817 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1818
1819 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1820 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1821 if (SA->hasNoUnsignedWrap()) {
1822 // If the addition does not unsign overflow then we can, by definition,
1823 // commute the zero extension with the addition operation.
1824 SmallVector<const SCEV *, 4> Ops;
1825 for (const auto *Op : SA->operands())
1826 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1827 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1828 }
1829
1830 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1831 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1832 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1833 //
1834 // Address arithmetic often contains expressions like
1835 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1836 // This transformation is useful while proving that such expressions are
1837 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1838 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1839 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1840 if (D != 0) {
1841 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1842 const SCEV *SResidual =
1843 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1844 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1845 return getAddExpr(SZExtD, SZExtR,
1846 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1847 Depth + 1);
1848 }
1849 }
1850 }
1851
1852 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1853 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1854 if (SM->hasNoUnsignedWrap()) {
1855 // If the multiply does not unsign overflow then we can, by definition,
1856 // commute the zero extension with the multiply operation.
1857 SmallVector<const SCEV *, 4> Ops;
1858 for (const auto *Op : SM->operands())
1859 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1860 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1861 }
1862
1863 // zext(2^K * (trunc X to iN)) to iM ->
1864 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1865 //
1866 // Proof:
1867 //
1868 // zext(2^K * (trunc X to iN)) to iM
1869 // = zext((trunc X to iN) << K) to iM
1870 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1871 // (because shl removes the top K bits)
1872 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1873 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1874 //
1875 if (SM->getNumOperands() == 2)
1876 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1877 if (MulLHS->getAPInt().isPowerOf2())
1878 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1879 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1880 MulLHS->getAPInt().logBase2();
1881 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1882 return getMulExpr(
1883 getZeroExtendExpr(MulLHS, Ty),
1884 getZeroExtendExpr(
1885 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1886 SCEV::FlagNUW, Depth + 1);
1887 }
1888 }
1889
1890 // The cast wasn't folded; create an explicit cast node.
1891 // Recompute the insert position, as it may have been invalidated.
1892 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1893 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1894 Op, Ty);
1895 UniqueSCEVs.InsertNode(S, IP);
1896 registerUser(S, Op);
1897 return S;
1898}
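
For the addrec path above, the example loop from the comment works out as follows (trip count assumed from the source):

    // for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
    // X is {0,+,1}<%loop> over i8, and the constant max backedge-taken count is 99.
    // Start + Step * MaxBECount = 0 + 1 * 99 = 99 fits in i8 without unsigned
    // wrap, so the addrec is tagged <nuw> and
    //   zext i8 {0,+,1} to i32   folds to   {0,+,1}<nuw><%loop> over i32.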
1899
1900const SCEV *
1901ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1902 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1903        "This is not an extending conversion!");
1904 assert(isSCEVable(Ty) &&
1905        "This is not a conversion to a SCEVable type!");
1906 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1907 Ty = getEffectiveSCEVType(Ty);
1908
1909 // Fold if the operand is constant.
1910 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1911 return getConstant(
1912 cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1913
1914 // sext(sext(x)) --> sext(x)
1915 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1916 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1917
1918 // sext(zext(x)) --> zext(x)
1919 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1920 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1921
1922 // Before doing any expensive analysis, check to see if we've already
1923 // computed a SCEV for this Op and Ty.
1924 FoldingSetNodeID ID;
1925 ID.AddInteger(scSignExtend);
1926 ID.AddPointer(Op);
1927 ID.AddPointer(Ty);
1928 void *IP = nullptr;
1929 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1930 // Limit recursion depth.
1931 if (Depth > MaxCastDepth) {
1932 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1933 Op, Ty);
1934 UniqueSCEVs.InsertNode(S, IP);
1935 registerUser(S, Op);
1936 return S;
1937 }
1938
1939 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1940 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1941 // It's possible the bits taken off by the truncate were all sign bits. If
1942 // so, we should be able to simplify this further.
1943 const SCEV *X = ST->getOperand();
1944 ConstantRange CR = getSignedRange(X);
1945 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1946 unsigned NewBits = getTypeSizeInBits(Ty);
1947 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1948 CR.sextOrTrunc(NewBits)))
1949 return getTruncateOrSignExtend(X, Ty, Depth);
1950 }
1951
1952 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1953 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1954 if (SA->hasNoSignedWrap()) {
1955 // If the addition does not sign overflow then we can, by definition,
1956 // commute the sign extension with the addition operation.
1957 SmallVector<const SCEV *, 4> Ops;
1958 for (const auto *Op : SA->operands())
1959 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1960 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1961 }
1962
1963 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1964 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1965 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1966 //
1967 // For instance, this will bring two seemingly different expressions:
1968 // 1 + sext(5 + 20 * %x + 24 * %y) and
1969 // sext(6 + 20 * %x + 24 * %y)
1970 // to the same form:
1971 // 2 + sext(4 + 20 * %x + 24 * %y)
1972 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1973 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1974 if (D != 0) {
1975 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1976 const SCEV *SResidual =
1977 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1978 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1979 return getAddExpr(SSExtD, SSExtR,
1980 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1981 Depth + 1);
1982 }
1983 }
1984 }
1985 // If the input value is a chrec scev, and we can prove that the value
1986 // did not overflow the old, smaller, value, we can sign extend all of the
1987 // operands (often constants). This allows analysis of something like
1988 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1989 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1990 if (AR->isAffine()) {
1991 const SCEV *Start = AR->getStart();
1992 const SCEV *Step = AR->getStepRecurrence(*this);
1993 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1994 const Loop *L = AR->getLoop();
1995
1996 if (!AR->hasNoSignedWrap()) {
1997 auto NewFlags = proveNoWrapViaConstantRanges(AR);
1998 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1999 }
2000
2001 // If we have special knowledge that this addrec won't overflow,
2002 // we don't need to do any further analysis.
2003 if (AR->hasNoSignedWrap()) {
2004 Start =
2005 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2006 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2007 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2008 }
2009
2010 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2011 // Note that this serves two purposes: It filters out loops that are
2012 // simply not analyzable, and it covers the case where this code is
2013 // being called from within backedge-taken count analysis, such that
2014 // attempting to ask for the backedge-taken count would likely result
2015 // in infinite recursion. In the latter case, the analysis code will
2016 // cope with a conservative value, and it will take care to purge
2017 // that value once it has finished.
2018 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2019 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2020 // Manually compute the final value for AR, checking for
2021 // overflow.
2022
2023 // Check whether the backedge-taken count can be losslessly casted to
2024 // the addrec's type. The count is always unsigned.
2025 const SCEV *CastedMaxBECount =
2026 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2027 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2028 CastedMaxBECount, MaxBECount->getType(), Depth);
2029 if (MaxBECount == RecastedMaxBECount) {
2030 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2031 // Check whether Start+Step*MaxBECount has no signed overflow.
2032 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2033 SCEV::FlagAnyWrap, Depth + 1);
2034 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2035 SCEV::FlagAnyWrap,
2036 Depth + 1),
2037 WideTy, Depth + 1);
2038 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2039 const SCEV *WideMaxBECount =
2040 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2041 const SCEV *OperandExtendedAdd =
2042 getAddExpr(WideStart,
2043 getMulExpr(WideMaxBECount,
2044 getSignExtendExpr(Step, WideTy, Depth + 1),
2045 SCEV::FlagAnyWrap, Depth + 1),
2046 SCEV::FlagAnyWrap, Depth + 1);
2047 if (SAdd == OperandExtendedAdd) {
2048 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2049 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2050 // Return the expression with the addrec on the outside.
2051 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2052 Depth + 1);
2053 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2054 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2055 }
2056 // Similar to above, only this time treat the step value as unsigned.
2057 // This covers loops that count up with an unsigned step.
2058 OperandExtendedAdd =
2059 getAddExpr(WideStart,
2060 getMulExpr(WideMaxBECount,
2061 getZeroExtendExpr(Step, WideTy, Depth + 1),
2062 SCEV::FlagAnyWrap, Depth + 1),
2063 SCEV::FlagAnyWrap, Depth + 1);
2064 if (SAdd == OperandExtendedAdd) {
2065 // If AR wraps around then
2066 //
2067 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2068 // => SAdd != OperandExtendedAdd
2069 //
2070 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2071 // (SAdd == OperandExtendedAdd => AR is NW)
2072
2073 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2074
2075 // Return the expression with the addrec on the outside.
2076 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2077 Depth + 1);
2078 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2079 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2080 }
2081 }
2082 }
2083
2084 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2085 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2086 if (AR->hasNoSignedWrap()) {
2087 // Same as nsw case above - duplicated here to avoid a compile time
2088 // issue. It's not clear that the order of checks matters, but
2089 // it's one of two possible causes for a change which was
2090 // reverted. Be conservative for the moment.
2091 Start =
2092 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2093 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2094 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2095 }
2096
2097 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2098 // if D + (C - D + Step * n) could be proven to not signed wrap
2099 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2100 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2101 const APInt &C = SC->getAPInt();
2102 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2103 if (D != 0) {
2104 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2105 const SCEV *SResidual =
2106 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2107 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2108 return getAddExpr(SSExtD, SSExtR,
2109 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2110 Depth + 1);
2111 }
2112 }
2113
2114 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2115 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2116 Start =
2117 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2118 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2119 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2120 }
2121 }
2122
2123 // If the input value is provably positive and we could not simplify
2124 // away the sext, build a zext instead.
2125 if (isKnownNonNegative(Op))
2126 return getZeroExtendExpr(Op, Ty, Depth + 1);
2127
2128 // The cast wasn't folded; create an explicit cast node.
2129 // Recompute the insert position, as it may have been invalidated.
2130 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2131 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2132 Op, Ty);
2133 UniqueSCEVs.InsertNode(S, IP);
2134 registerUser(S, { Op });
2135 return S;
2136}
2137
2138const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2139 Type *Ty) {
2140 switch (Kind) {
2141 case scTruncate:
2142 return getTruncateExpr(Op, Ty);
2143 case scZeroExtend:
2144 return getZeroExtendExpr(Op, Ty);
2145 case scSignExtend:
2146 return getSignExtendExpr(Op, Ty);
2147 case scPtrToInt:
2148 return getPtrToIntExpr(Op, Ty);
2149 default:
2150 llvm_unreachable("Not a SCEV cast expression!");
2151 }
2152}
2153
2154/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2155/// unspecified bits out to the given type.
2156const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2157 Type *Ty) {
2158 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2159        "This is not an extending conversion!");
2160 assert(isSCEVable(Ty) &&
2161        "This is not a conversion to a SCEVable type!");
2162 Ty = getEffectiveSCEVType(Ty);
2163
2164 // Sign-extend negative constants.
2165 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2166 if (SC->getAPInt().isNegative())
2167 return getSignExtendExpr(Op, Ty);
2168
2169 // Peel off a truncate cast.
2170 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2171 const SCEV *NewOp = T->getOperand();
2172 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2173 return getAnyExtendExpr(NewOp, Ty);
2174 return getTruncateOrNoop(NewOp, Ty);
2175 }
2176
2177 // Next try a zext cast. If the cast is folded, use it.
2178 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2179 if (!isa<SCEVZeroExtendExpr>(ZExt))
2180 return ZExt;
2181
2182 // Next try a sext cast. If the cast is folded, use it.
2183 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2184 if (!isa<SCEVSignExtendExpr>(SExt))
2185 return SExt;
2186
2187 // Force the cast to be folded into the operands of an addrec.
2188 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2189 SmallVector<const SCEV *, 4> Ops;
2190 for (const SCEV *Op : AR->operands())
2191 Ops.push_back(getAnyExtendExpr(Op, Ty));
2192 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2193 }
2194
2195 // If the expression is obviously signed, use the sext cast value.
2196 if (isa<SCEVSMaxExpr>(Op))
2197 return SExt;
2198
2199 // Absent any other information, use the zext cast value.
2200 return ZExt;
2201}
2202
2203/// Process the given Ops list, which is a list of operands to be added under
2204/// the given scale, update the given map. This is a helper function for
2205/// getAddRecExpr. As an example of what it does, given a sequence of operands
2206/// that would form an add expression like this:
2207///
2208/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2209///
2210/// where A and B are constants, update the map with these values:
2211///
2212/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2213///
2214/// and add 13 + A*B*29 to AccumulatedConstant.
2215/// This will allow getAddRecExpr to produce this:
2216///
2217/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2218///
2219/// This form often exposes folding opportunities that are hidden in
2220/// the original operand list.
2221///
2222/// Return true iff it appears that any interesting folding opportunities
2223/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2224/// the common case where no interesting opportunities are present, and
2225/// is also used as a check to avoid infinite recursion.
2226static bool
2227CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2228 SmallVectorImpl<const SCEV *> &NewOps,
2229 APInt &AccumulatedConstant,
2230 const SCEV *const *Ops, size_t NumOperands,
2231 const APInt &Scale,
2232 ScalarEvolution &SE) {
2233 bool Interesting = false;
2234
2235 // Iterate over the add operands. They are sorted, with constants first.
2236 unsigned i = 0;
2237 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2238 ++i;
2239 // Pull a buried constant out to the outside.
2240 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2241 Interesting = true;
2242 AccumulatedConstant += Scale * C->getAPInt();
2243 }
2244
2245 // Next comes everything else. We're especially interested in multiplies
2246 // here, but they're in the middle, so just visit the rest with one loop.
2247 for (; i != NumOperands; ++i) {
2248 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2249 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2250 APInt NewScale =
2251 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2252 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2253 // A multiplication of a constant with another add; recurse.
2254 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2255 Interesting |=
2256 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2257 Add->op_begin(), Add->getNumOperands(),
2258 NewScale, SE);
2259 } else {
2260 // A multiplication of a constant with some other value. Update
2261 // the map.
2262 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2263 const SCEV *Key = SE.getMulExpr(MulOps);
2264 auto Pair = M.insert({Key, NewScale});
2265 if (Pair.second) {
2266 NewOps.push_back(Pair.first->first);
2267 } else {
2268 Pair.first->second += NewScale;
2269 // The map already had an entry for this value, which may indicate
2270 // a folding opportunity.
2271 Interesting = true;
2272 }
2273 }
2274 } else {
2275 // An ordinary operand. Update the map.
2276 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2277 M.insert({Ops[i], Scale});
2278 if (Pair.second) {
2279 NewOps.push_back(Pair.first->first);
2280 } else {
2281 Pair.first->second += Scale;
2282 // The map already had an entry for this value, which may indicate
2283 // a folding opportunity.
2284 Interesting = true;
2285 }
2286 }
2287 }
2288
2289 return Interesting;
2290}
2291
2292bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2293 const SCEV *LHS, const SCEV *RHS,
2294 const Instruction *CtxI) {
2295 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2296 SCEV::NoWrapFlags, unsigned);
2297 switch (BinOp) {
2298 default:
2299 llvm_unreachable("Unsupported binary op");
2300 case Instruction::Add:
2301 Operation = &ScalarEvolution::getAddExpr;
2302 break;
2303 case Instruction::Sub:
2304 Operation = &ScalarEvolution::getMinusSCEV;
2305 break;
2306 case Instruction::Mul:
2307 Operation = &ScalarEvolution::getMulExpr;
2308 break;
2309 }
2310
2311 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2312 Signed ? &ScalarEvolution::getSignExtendExpr
2313 : &ScalarEvolution::getZeroExtendExpr;
2314
2315 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2316 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2317 auto *WideTy =
2318 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2319
2320 const SCEV *A = (this->*Extension)(
2321 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2322 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2323 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2324 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2325 if (A == B)
2326 return true;
2327 // Can we use context to prove the fact we need?
2328 if (!CtxI)
2329 return false;
2330 // We can prove that add(x, constant) doesn't wrap if isKnownPredicateAt can
2331 // guarantee that x <= max_int - constant at the given context.
2332 // TODO: Support other operations.
2333 if (BinOp != Instruction::Add)
2334 return false;
2335 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2336 // TODO: Lift this limitation.
2337 if (!RHSC)
2338 return false;
2339 APInt C = RHSC->getAPInt();
2340 // TODO: Also lift this limitation.
2341 if (Signed && C.isNegative())
2342 return false;
2343 unsigned NumBits = C.getBitWidth();
2344 APInt Max =
2345 Signed ? APInt::getSignedMaxValue(NumBits) : APInt::getMaxValue(NumBits);
2346 APInt Limit = Max - C;
2347 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2348 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2349}
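
A sketch of the context-sensitive fallback at the end of this function, with an assumed dominating guard:

    // BinOp = Add, Signed = false, LHS = %x (i8), RHS = 5, and the symbolic
    // ext(LHS op RHS) == ext(LHS) op ext(RHS) check did not fold.
    // Then Limit = 255 - 5 = 250, Pred = ULE, and the query becomes
    //   isKnownPredicateAt(ULE, %x, 250, CtxI),
    // so a guard such as   if (%x <= 250)   dominating CtxI lets SCEV conclude
    // that %x + 5 cannot unsigned-wrap there.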
2350
2351Optional<SCEV::NoWrapFlags>
2352ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2353 const OverflowingBinaryOperator *OBO) {
2354 // It cannot be done any better.
2355 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2356 return None;
2357
2358 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2359
2360 if (OBO->hasNoUnsignedWrap())
2361 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2362 if (OBO->hasNoSignedWrap())
2363 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2364
2365 bool Deduced = false;
2366
2367 if (OBO->getOpcode() != Instruction::Add &&
2368 OBO->getOpcode() != Instruction::Sub &&
2369 OBO->getOpcode() != Instruction::Mul)
2370 return None;
2371
2372 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2373 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2374
2375 const Instruction *CtxI =
2376 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2377 if (!OBO->hasNoUnsignedWrap() &&
2378 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2379 /* Signed */ false, LHS, RHS, CtxI)) {
2380 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2381 Deduced = true;
2382 }
2383
2384 if (!OBO->hasNoSignedWrap() &&
2385 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2386 /* Signed */ true, LHS, RHS, CtxI)) {
2387 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2388 Deduced = true;
2389 }
2390
2391 if (Deduced)
2392 return Flags;
2393 return None;
2394}
2395
2396// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2397// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2398// can't-overflow flags for the operation if possible.
2399static SCEV::NoWrapFlags
2400StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2401 const ArrayRef<const SCEV *> Ops,
2402 SCEV::NoWrapFlags Flags) {
2403 using namespace std::placeholders;
2404
2405 using OBO = OverflowingBinaryOperator;
2406
2407 bool CanAnalyze =
2408 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2409 (void)CanAnalyze;
2410 assert(CanAnalyze && "don't call from other places!");
2411
2412 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2413 SCEV::NoWrapFlags SignOrUnsignWrap =
2414 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2415
2416 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2417 auto IsKnownNonNegative = [&](const SCEV *S) {
2418 return SE->isKnownNonNegative(S);
2419 };
2420
2421 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2422 Flags =
2423 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2424
2425 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2426
2427 if (SignOrUnsignWrap != SignOrUnsignMask &&
2428 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2429 isa<SCEVConstant>(Ops[0])) {
2430
2431 auto Opcode = [&] {
2432 switch (Type) {
2433 case scAddExpr:
2434 return Instruction::Add;
2435 case scMulExpr:
2436 return Instruction::Mul;
2437 default:
2438 llvm_unreachable("Unexpected SCEV op.");
2439 }
2440 }();
2441
2442 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2443
2444 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2445 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2446 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2447 Opcode, C, OBO::NoSignedWrap);
2448 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2449 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2450 }
2451
2452 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsigned overflow.
2453 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2454 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2455 Opcode, C, OBO::NoUnsignedWrap);
2456 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2457 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2458 }
2459 }
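// Worked example for the region checks above (i8, C = 1, opcode Add): adding 1
// cannot sign-overflow as long as the other operand's signed range stays within
// [-128, 126], and cannot unsigned-overflow while its unsigned range stays within
// [0, 254]; when the computed range fits the region, the matching flag is set.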
2460
2461 // <0,+,nonnegative><nw> is also nuw
2462 // TODO: Add corresponding nsw case
2463 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2464 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2465 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2466 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2467
2468 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2469 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2470 Ops.size() == 2) {
2471 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2472 if (UDiv->getOperand(1) == Ops[1])
2473 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2474 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2475 if (UDiv->getOperand(1) == Ops[0])
2476 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2477 }
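// Rationale for the fold above: (X /u Y) * Y rounds X down to a multiple of Y,
// so the product is never larger than X in the unsigned sense and cannot wrap
// (when Y is zero the udiv is undefined anyway), which is why FlagNUW is safe.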
2478
2479 return Flags;
2480}
2481
2482bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2483 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2484}
2485
2486/// Get a canonical add expression, or something simpler if possible.
2487const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2488 SCEV::NoWrapFlags OrigFlags,
2489 unsigned Depth) {
2490 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2491 "only nuw or nsw allowed");
2492 assert(!Ops.empty() && "Cannot get empty add!");
2493 if (Ops.size() == 1) return Ops[0];
2494#ifndef NDEBUG
2495 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2496 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2497 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2498 "SCEVAddExpr operand types don't match!");
2499 unsigned NumPtrs = count_if(
2500 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2501 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2502#endif
2503
2504 // Sort by complexity, this groups all similar expression types together.
2505 GroupByComplexity(Ops, &LI, DT);
2506
2507 // If there are any constants, fold them together.
2508 unsigned Idx = 0;
2509 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2510 ++Idx;
2511 assert(Idx < Ops.size());
2512 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2513 // We found two constants, fold them together!
2514 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2515 if (Ops.size() == 2) return Ops[0];
2516 Ops.erase(Ops.begin()+1); // Erase the folded element
2517 LHSC = cast<SCEVConstant>(Ops[0]);
2518 }
2519
2520 // If we are left with a constant zero being added, strip it off.
2521 if (LHSC->getValue()->isZero()) {
2522 Ops.erase(Ops.begin());
2523 --Idx;
2524 }
2525
2526 if (Ops.size() == 1) return Ops[0];
2527 }
2528
2529 // Delay expensive flag strengthening until necessary.
2530 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2531 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2532 };
2533
2534 // Limit recursion calls depth.
2535 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2536 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2537
2538 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2539 // Don't strengthen flags if we have no new information.
2540 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2541 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2542 Add->setNoWrapFlags(ComputeFlags(Ops));
2543 return S;
2544 }
2545
2546 // Okay, check to see if the same value occurs in the operand list more than
2547 // once. If so, merge them together into a multiply expression. Since we
2548 // sorted the list, these values are required to be adjacent.
2549 Type *Ty = Ops[0]->getType();
2550 bool FoundMatch = false;
2551 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2552 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2553 // Scan ahead to count how many equal operands there are.
2554 unsigned Count = 2;
2555 while (i+Count != e && Ops[i+Count] == Ops[i])
2556 ++Count;
2557 // Merge the values into a multiply.
2558 const SCEV *Scale = getConstant(Ty, Count);
2559 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2560 if (Ops.size() == Count)
2561 return Mul;
2562 Ops[i] = Mul;
2563 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2564 --i; e -= Count - 1;
2565 FoundMatch = true;
2566 }
2567 if (FoundMatch)
2568 return getAddExpr(Ops, OrigFlags, Depth + 1);
2569
2570 // Check for truncates. If all the operands are truncated from the same
2571 // type, see if factoring out the truncate would permit the result to be
2572 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2573 // if the contents of the resulting outer trunc fold to something simple.
2574 auto FindTruncSrcType = [&]() -> Type * {
2575 // We're ultimately looking to fold an addrec of truncs and muls of only
2576 // constants and truncs, so if we find any other types of SCEV
2577 // as operands of the addrec then we bail and return nullptr here.
2578 // Otherwise, we return the type of the operand of a trunc that we find.
2579 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2580 return T->getOperand()->getType();
2581 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2582 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2583 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2584 return T->getOperand()->getType();
2585 }
2586 return nullptr;
2587 };
2588 if (auto *SrcType = FindTruncSrcType()) {
2589 SmallVector<const SCEV *, 8> LargeOps;
2590 bool Ok = true;
2591 // Check all the operands to see if they can be represented in the
2592 // source type of the truncate.
2593 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2594 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2595 if (T->getOperand()->getType() != SrcType) {
2596 Ok = false;
2597 break;
2598 }
2599 LargeOps.push_back(T->getOperand());
2600 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2601 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2602 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2603 SmallVector<const SCEV *, 8> LargeMulOps;
2604 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2605 if (const SCEVTruncateExpr *T =
2606 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2607 if (T->getOperand()->getType() != SrcType) {
2608 Ok = false;
2609 break;
2610 }
2611 LargeMulOps.push_back(T->getOperand());
2612 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2613 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2614 } else {
2615 Ok = false;
2616 break;
2617 }
2618 }
2619 if (Ok)
2620 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2621 } else {
2622 Ok = false;
2623 break;
2624 }
2625 }
2626 if (Ok) {
2627 // Evaluate the expression in the larger type.
2628 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2629 // If it folds to something simple, use it. Otherwise, don't.
2630 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2631 return getTruncateExpr(Fold, Ty);
2632 }
2633 }
2634
2635 if (Ops.size() == 2) {
2636 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2637 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2638 // C1).
2639 const SCEV *A = Ops[0];
2640 const SCEV *B = Ops[1];
2641 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2642 auto *C = dyn_cast<SCEVConstant>(A);
2643 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2644 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2645 auto C2 = C->getAPInt();
2646 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2647
2648 APInt ConstAdd = C1 + C2;
2649 auto AddFlags = AddExpr->getNoWrapFlags();
2650 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2651 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2652 ConstAdd.ule(C1)) {
2653 PreservedFlags =
2654 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2655 }
2656
2657 // Adding a constant with the same sign and small magnitude is NSW, if the
2658 // original AddExpr was NSW.
2659 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2660 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2661 ConstAdd.abs().ule(C1.abs())) {
2662 PreservedFlags =
2663 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2664 }
2665
2666 if (PreservedFlags != SCEV::FlagAnyWrap) {
2667 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2668 NewOps[0] = getConstant(ConstAdd);
2669 return getAddExpr(NewOps, PreservedFlags);
2670 }
2671 }
2672 }
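// Worked example for the fold above: Ops = {-5, (8 + X)<nuw>} gives C1 = 8,
// C2 = -5 and ConstAdd = 3; since 3 u<= 8 the <nuw> flag survives and the
// expression is rebuilt as (3 + X)<nuw>.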
2673
2674 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2675 if (Ops.size() == 2) {
2676 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2677 if (Mul && Mul->getNumOperands() == 2 &&
2678 Mul->getOperand(0)->isAllOnesValue()) {
2679 const SCEV *X;
2680 const SCEV *Y;
2681 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2682 return getMulExpr(Y, getUDivExpr(X, Y));
2683 }
2684 }
2685 }
2686
2687 // Skip past any other cast SCEVs.
2688 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2689 ++Idx;
2690
2691 // If there are add operands they would be next.
2692 if (Idx < Ops.size()) {
2693 bool DeletedAdd = false;
2694 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2695 // common NUW flag for expression after inlining. Other flags cannot be
2696 // preserved, because they may depend on the original order of operations.
2697 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2698 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2699 if (Ops.size() > AddOpsInlineThreshold ||
2700 Add->getNumOperands() > AddOpsInlineThreshold)
2701 break;
2702 // If we have an add, expand the add operands onto the end of the operands
2703 // list.
2704 Ops.erase(Ops.begin()+Idx);
2705 Ops.append(Add->op_begin(), Add->op_end());
2706 DeletedAdd = true;
2707 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2708 }
2709
2710 // If we deleted at least one add, we added operands to the end of the list,
2711 // and they are not necessarily sorted. Recurse to resort and resimplify
2712 // any operands we just acquired.
2713 if (DeletedAdd)
2714 return getAddExpr(Ops, CommonFlags, Depth + 1);
2715 }
2716
2717 // Skip over the add expression until we get to a multiply.
2718 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2719 ++Idx;
2720
2721 // Check to see if there are any folding opportunities present with
2722 // operands multiplied by constant values.
2723 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2724 uint64_t BitWidth = getTypeSizeInBits(Ty);
2725 DenseMap<const SCEV *, APInt> M;
2726 SmallVector<const SCEV *, 8> NewOps;
2727 APInt AccumulatedConstant(BitWidth, 0);
2728 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2729 Ops.data(), Ops.size(),
2730 APInt(BitWidth, 1), *this)) {
2731 struct APIntCompare {
2732 bool operator()(const APInt &LHS, const APInt &RHS) const {
2733 return LHS.ult(RHS);
2734 }
2735 };
2736
2737 // Some interesting folding opportunity is present, so it's worthwhile to
2738 // re-generate the operands list. Group the operands by constant scale,
2739 // to avoid multiplying by the same constant scale multiple times.
2740 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2741 for (const SCEV *NewOp : NewOps)
2742 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2743 // Re-generate the operands list.
2744 Ops.clear();
2745 if (AccumulatedConstant != 0)
2746 Ops.push_back(getConstant(AccumulatedConstant));
2747 for (auto &MulOp : MulOpLists) {
2748 if (MulOp.first == 1) {
2749 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2750 } else if (MulOp.first != 0) {
2751 Ops.push_back(getMulExpr(
2752 getConstant(MulOp.first),
2753 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2754 SCEV::FlagAnyWrap, Depth + 1));
2755 }
2756 }
2757 if (Ops.empty())
2758 return getZero(Ty);
2759 if (Ops.size() == 1)
2760 return Ops[0];
2761 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2762 }
2763 }
2764
2765 // If we are adding something to a multiply expression, make sure the
2766 // something is not already an operand of the multiply. If so, merge it into
2767 // the multiply.
2768 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2769 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2770 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2771 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2772 if (isa<SCEVConstant>(MulOpSCEV))
2773 continue;
2774 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2775 if (MulOpSCEV == Ops[AddOp]) {
2776 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2777 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2778 if (Mul->getNumOperands() != 2) {
2779 // If the multiply has more than two operands, we must get the
2780 // Y*Z term.
2781 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2782 Mul->op_begin()+MulOp);
2783 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2784 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2785 }
2786 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2787 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2788 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2789 SCEV::FlagAnyWrap, Depth + 1);
2790 if (Ops.size() == 2) return OuterMul;
2791 if (AddOp < Idx) {
2792 Ops.erase(Ops.begin()+AddOp);
2793 Ops.erase(Ops.begin()+Idx-1);
2794 } else {
2795 Ops.erase(Ops.begin()+Idx);
2796 Ops.erase(Ops.begin()+AddOp-1);
2797 }
2798 Ops.push_back(OuterMul);
2799 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2800 }
2801
2802 // Check this multiply against other multiplies being added together.
2803 for (unsigned OtherMulIdx = Idx+1;
2804 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2805 ++OtherMulIdx) {
2806 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2807 // If MulOp occurs in OtherMul, we can fold the two multiplies
2808 // together.
2809 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2810 OMulOp != e; ++OMulOp)
2811 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2812 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2813 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2814 if (Mul->getNumOperands() != 2) {
2815 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2816 Mul->op_begin()+MulOp);
2817 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2818 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2819 }
2820 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2821 if (OtherMul->getNumOperands() != 2) {
2822 SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2823 OtherMul->op_begin()+OMulOp);
2824 MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2825 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2826 }
2827 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2828 const SCEV *InnerMulSum =
2829 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2830 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2831 SCEV::FlagAnyWrap, Depth + 1);
2832 if (Ops.size() == 2) return OuterMul;
2833 Ops.erase(Ops.begin()+Idx);
2834 Ops.erase(Ops.begin()+OtherMulIdx-1);
2835 Ops.push_back(OuterMul);
2836 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2837 }
2838 }
2839 }
2840 }
2841
2842 // If there are any add recurrences in the operands list, see if any other
2843 // added values are loop invariant. If so, we can fold them into the
2844 // recurrence.
2845 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2846 ++Idx;
2847
2848 // Scan over all recurrences, trying to fold loop invariants into them.
2849 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2850 // Scan all of the other operands to this add and add them to the vector if
2851 // they are loop invariant w.r.t. the recurrence.
2852 SmallVector<const SCEV *, 8> LIOps;
2853 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2854 const Loop *AddRecLoop = AddRec->getLoop();
2855 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2856 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2857 LIOps.push_back(Ops[i]);
2858 Ops.erase(Ops.begin()+i);
2859 --i; --e;
2860 }
2861
2862 // If we found some loop invariants, fold them into the recurrence.
2863 if (!LIOps.empty()) {
2864 // Compute nowrap flags for the addition of the loop-invariant ops and
2865 // the addrec. Temporarily push it as an operand for that purpose. These
2866 // flags are valid in the scope of the addrec only.
2867 LIOps.push_back(AddRec);
2868 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2869 LIOps.pop_back();
2870
2871 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2872 LIOps.push_back(AddRec->getStart());
2873
2874 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2875
2876 // It is not in general safe to propagate flags valid on an add within
2877 // the addrec scope to one outside it. We must prove that the inner
2878 // scope is guaranteed to execute if the outer one does to be able to
2879 // safely propagate. We know the program is undefined if poison is
2880 // produced on the inner scoped addrec. We also know that *for this use*
2881 // the outer scoped add can't overflow (because of the flags we just
2882 // computed for the inner scoped add) without the program being undefined.
2883 // Proving that entry to the outer scope necessitates entry to the inner
2884 // scope, thus proves the program undefined if the flags would be violated
2885 // in the outer scope.
2886 SCEV::NoWrapFlags AddFlags = Flags;
2887 if (AddFlags != SCEV::FlagAnyWrap) {
2888 auto *DefI = getDefiningScopeBound(LIOps);
2889 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2890 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2891 AddFlags = SCEV::FlagAnyWrap;
2892 }
2893 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2894
2895 // Build the new addrec. Propagate the NUW and NSW flags if both the
2896 // outer add and the inner addrec are guaranteed to have no overflow.
2897 // Always propagate NW.
2898 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2899 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2900
2901 // If all of the other operands were loop invariant, we are done.
2902 if (Ops.size() == 1) return NewRec;
2903
2904 // Otherwise, add the folded AddRec by the non-invariant parts.
2905 for (unsigned i = 0;; ++i)
2906 if (Ops[i] == AddRec) {
2907 Ops[i] = NewRec;
2908 break;
2909 }
2910 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2911 }
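// For instance, with Ops = {5, {X,+,1}<L>} where 5 is invariant in L, the
// branch above folds the constant into the start and yields 5 + {X,+,1}<L>
// rewritten as {5+X,+,1}<L> before the remaining operands are re-added.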
2912
2913 // Okay, if there weren't any loop invariants to be folded, check to see if
2914 // there are multiple AddRec's with the same loop induction variable being
2915 // added together. If so, we can fold them.
2916 for (unsigned OtherIdx = Idx+1;
2917 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2918 ++OtherIdx) {
2919 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2920 // so that the 1st found AddRecExpr is dominated by all others.
2921 assert(DT.dominates(
2922 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2923 AddRec->getLoop()->getHeader()) &&
2924 "AddRecExprs are not sorted in reverse dominance order?");
2925 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2926 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2927 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2928 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2929 ++OtherIdx) {
2930 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2931 if (OtherAddRec->getLoop() == AddRecLoop) {
2932 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2933 i != e; ++i) {
2934 if (i >= AddRecOps.size()) {
2935 AddRecOps.append(OtherAddRec->op_begin()+i,
2936 OtherAddRec->op_end());
2937 break;
2938 }
2939 SmallVector<const SCEV *, 2> TwoOps = {
2940 AddRecOps[i], OtherAddRec->getOperand(i)};
2941 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2942 }
2943 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2944 }
2945 }
2946 // Step size has changed, so we cannot guarantee no self-wraparound.
2947 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2948 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2949 }
2950 }
2951
2952 // Otherwise couldn't fold anything into this recurrence. Move onto the
2953 // next one.
2954 }
2955
2956 // Okay, it looks like we really DO need an add expr. Check to see if we
2957 // already have one, otherwise create a new one.
2958 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2959}
2960
2961const SCEV *
2962ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2963 SCEV::NoWrapFlags Flags) {
2964 FoldingSetNodeID ID;
2965 ID.AddInteger(scAddExpr);
2966 for (const SCEV *Op : Ops)
2967 ID.AddPointer(Op);
2968 void *IP = nullptr;
2969 SCEVAddExpr *S =
2970 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2971 if (!S) {
2972 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2973 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2974 S = new (SCEVAllocator)
2975 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2976 UniqueSCEVs.InsertNode(S, IP);
2977 registerUser(S, Ops);
2978 }
2979 S->setNoWrapFlags(Flags);
2980 return S;
2981}
2982
2983const SCEV *
2984ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2985 const Loop *L, SCEV::NoWrapFlags Flags) {
2986 FoldingSetNodeID ID;
2987 ID.AddInteger(scAddRecExpr);
2988 for (const SCEV *Op : Ops)
2989 ID.AddPointer(Op);
2990 ID.AddPointer(L);
2991 void *IP = nullptr;
2992 SCEVAddRecExpr *S =
2993 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2999 UniqueSCEVs.InsertNode(S, IP);
3000 LoopUsers[L].push_back(S);
3001 registerUser(S, Ops);
3002 }
3003 setNoWrapFlags(S, Flags);
3004 return S;
3005}
3006
3007const SCEV *
3008ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3009 SCEV::NoWrapFlags Flags) {
3010 FoldingSetNodeID ID;
3011 ID.AddInteger(scMulExpr);
3012 for (const SCEV *Op : Ops)
3013 ID.AddPointer(Op);
3014 void *IP = nullptr;
3015 SCEVMulExpr *S =
3016 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3021 O, Ops.size());
3022 UniqueSCEVs.InsertNode(S, IP);
3023 registerUser(S, Ops);
3024 }
3025 S->setNoWrapFlags(Flags);
3026 return S;
3027}
3028
3029static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3030 uint64_t k = i*j;
3031 if (j > 1 && k / j != i) Overflow = true;
3032 return k;
3033}
3034
3035/// Compute the result of "n choose k", the binomial coefficient. If an
3036/// intermediate computation overflows, Overflow will be set and the return will
3037/// be garbage. Overflow is not cleared on absence of overflow.
3038static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3039 // We use the multiplicative formula:
3040 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3041 // At each iteration, we take the n-th term of the numerator and divide by the
3042 // (k-n)th term of the denominator. This division will always produce an
3043 // integral result, and helps reduce the chance of overflow in the
3044 // intermediate computations. However, we can still overflow even when the
3045 // final result would fit.
3046
3047 if (n == 0 || n == k) return 1;
3048 if (k > n) return 0;
3049
3050 if (k > n/2)
3051 k = n-k;
3052
3053 uint64_t r = 1;
3054 for (uint64_t i = 1; i <= k; ++i) {
3055 r = umul_ov(r, n-(i-1), Overflow);
3056 r /= i;
3057 }
3058 return r;
3059}
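// Example trace of Choose: n = 6, k = 2 keeps k as-is (2 <= 6/2); then
// i = 1: r = 6/1 = 6; i = 2: r = (6*5)/2 = 15, which is indeed C(6,2) = 15,
// and each intermediate division is exact as the comment above promises.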
3060
3061/// Determine if any of the operands in this SCEV are a constant or if
3062/// any of the add or multiply expressions in this SCEV contain a constant.
3063static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3064 struct FindConstantInAddMulChain {
3065 bool FoundConstant = false;
3066
3067 bool follow(const SCEV *S) {
3068 FoundConstant |= isa<SCEVConstant>(S);
3069 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3070 }
3071
3072 bool isDone() const {
3073 return FoundConstant;
3074 }
3075 };
3076
3077 FindConstantInAddMulChain F;
3078 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3079 ST.visitAll(StartExpr);
3080 return F.FoundConstant;
3081}
3082
3083/// Get a canonical multiply expression, or something simpler if possible.
3084const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3085 SCEV::NoWrapFlags OrigFlags,
3086 unsigned Depth) {
3087 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3088 "only nuw or nsw allowed");
3089 assert(!Ops.empty() && "Cannot get empty mul!");
3090 if (Ops.size() == 1) return Ops[0];
3091#ifndef NDEBUG
3092 Type *ETy = Ops[0]->getType();
3093 assert(!ETy->isPointerTy());
3094 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3095 assert(Ops[i]->getType() == ETy &&
3096 "SCEVMulExpr operand types don't match!");
3097#endif
3098
3099 // Sort by complexity, this groups all similar expression types together.
3100 GroupByComplexity(Ops, &LI, DT);
3101
3102 // If there are any constants, fold them together.
3103 unsigned Idx = 0;
3104 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3105 ++Idx;
3106 assert(Idx < Ops.size());
3107 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3108 // We found two constants, fold them together!
3109 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3110 if (Ops.size() == 2) return Ops[0];
3111 Ops.erase(Ops.begin()+1); // Erase the folded element
3112 LHSC = cast<SCEVConstant>(Ops[0]);
3113 }
3114
3115 // If we have a multiply of zero, it will always be zero.
3116 if (LHSC->getValue()->isZero())
3117 return LHSC;
3118
3119 // If we are left with a constant one being multiplied, strip it off.
3120 if (LHSC->getValue()->isOne()) {
3121 Ops.erase(Ops.begin());
3122 --Idx;
3123 }
3124
3125 if (Ops.size() == 1)
3126 return Ops[0];
3127 }
3128
3129 // Delay expensive flag strengthening until necessary.
3130 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3131 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3132 };
3133
3134 // Limit recursion calls depth.
3135 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3136 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3137
3138 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3139 // Don't strengthen flags if we have no new information.
3140 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3141 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3142 Mul->setNoWrapFlags(ComputeFlags(Ops));
3143 return S;
3144 }
3145
3146 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3147 if (Ops.size() == 2) {
3148 // C1*(C2+V) -> C1*C2 + C1*V
3149 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3157 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3158 SCEV::FlagAnyWrap, Depth + 1);
3159 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3160 SCEV::FlagAnyWrap, Depth + 1);
3161 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3168 SmallVector<const SCEV *, 4> NewOps;
3169 bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3180 SmallVector<const SCEV *, 4> Operands;
3181 for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184
3185 return getAddRecExpr(Operands, AddRec->getLoop(),
3186 AddRec->getNoWrapFlags(SCEV::FlagNW));
3187 }
3188 }
3189 }
3190 }
3191
3192 // Skip over the add expression until we get to a multiply.
3193 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3194 ++Idx;
3195
3196 // If there are mul operands inline them all into this expression.
3197 if (Idx < Ops.size()) {
3198 bool DeletedMul = false;
3199 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3200 if (Ops.size() > MulOpsInlineThreshold)
3201 break;
3202 // If we have an mul, expand the mul operands onto the end of the
3203 // operands list.
3204 Ops.erase(Ops.begin()+Idx);
3205 Ops.append(Mul->op_begin(), Mul->op_end());
3206 DeletedMul = true;
3207 }
3208
3209 // If we deleted at least one mul, we added operands to the end of the
3210 // list, and they are not necessarily sorted. Recurse to resort and
3211 // resimplify any operands we just acquired.
3212 if (DeletedMul)
3213 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3214 }
3215
3216 // If there are any add recurrences in the operands list, see if any other
3217 // added values are loop invariant. If so, we can fold them into the
3218 // recurrence.
3219 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3220 ++Idx;
3221
3222 // Scan over all recurrences, trying to fold loop invariants into them.
3223 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3224 // Scan all of the other operands to this mul and add them to the vector
3225 // if they are loop invariant w.r.t. the recurrence.
3226 SmallVector<const SCEV *, 8> LIOps;
3227 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3228 const Loop *AddRecLoop = AddRec->getLoop();
3229 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3230 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3231 LIOps.push_back(Ops[i]);
3232 Ops.erase(Ops.begin()+i);
3233 --i; --e;
3234 }
3235
3236 // If we found some loop invariants, fold them into the recurrence.
3237 if (!LIOps.empty()) {
3238 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3239 SmallVector<const SCEV *, 4> NewOps;
3240 NewOps.reserve(AddRec->getNumOperands());
3241 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3242 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3243 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3244 SCEV::FlagAnyWrap, Depth + 1));
3245
3246 // Build the new addrec. Propagate the NUW and NSW flags if both the
3247 // outer mul and the inner addrec are guaranteed to have no overflow.
3248 //
3249 // No-self-wrap (NW) cannot be guaranteed after changing the step size,
3250 // but it will be inferred if either NUW or NSW is true.
3251 SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3252 const SCEV *NewRec = getAddRecExpr(
3253 NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3254
3255 // If all of the other operands were loop invariant, we are done.
3256 if (Ops.size() == 1) return NewRec;
3257
3258 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3259 for (unsigned i = 0;; ++i)
3260 if (Ops[i] == AddRec) {
3261 Ops[i] = NewRec;
3262 break;
3263 }
3264 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3265 }
3266
3267 // Okay, if there weren't any loop invariants to be folded, check to see
3268 // if there are multiple AddRec's with the same loop induction variable
3269 // being multiplied together. If so, we can fold them.
3270
3271 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3272 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3273 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3274 // ]]],+,...up to x=2n}.
3275 // Note that the arguments to choose() are always integers with values
3276 // known at compile time, never SCEV objects.
3277 //
3278 // The implementation avoids pointless extra computations when the two
3279 // addrec's are of different length (mathematically, it's equivalent to
3280 // an infinite stream of zeros on the right).
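// As a concrete instance of the general formula, two affine recurrences
// multiply as {A,+,B}<L> * {C,+,D}<L> = {A*C,+,A*D+B*C+B*D,+,2*B*D}<L>,
// because (A+B*i)*(C+D*i) = A*C + (A*D+B*C+B*D)*i + 2*B*D*choose(i,2).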
3281 bool OpsModified = false;
3282 for (unsigned OtherIdx = Idx+1;
3283 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3284 ++OtherIdx) {
3285 const SCEVAddRecExpr *OtherAddRec =
3286 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3287 if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3288 continue;
3289
3290 // Limit max number of arguments to avoid creation of unreasonably big
3291 // SCEVAddRecs with very complex operands.
3292 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3293 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3294 continue;
3295
3296 bool Overflow = false;
3297 Type *Ty = AddRec->getType();
3298 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3299 SmallVector<const SCEV*, 7> AddRecOps;
3300 for (int x = 0, xe = AddRec->getNumOperands() +
3301 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3302 SmallVector <const SCEV *, 7> SumOps;
3303 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3304 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3305 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3306 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3307 z < ze && !Overflow; ++z) {
3308 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3309 uint64_t Coeff;
3310 if (LargerThan64Bits)
3311 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3312 else
3313 Coeff = Coeff1*Coeff2;
3314 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3315 const SCEV *Term1 = AddRec->getOperand(y-z);
3316 const SCEV *Term2 = OtherAddRec->getOperand(z);
3317 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3318 SCEV::FlagAnyWrap, Depth + 1));
3319 }
3320 }
3321 if (SumOps.empty())
3322 SumOps.push_back(getZero(Ty));
3323 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3324 }
3325 if (!Overflow) {
3326 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3327 SCEV::FlagAnyWrap);
3328 if (Ops.size() == 2) return NewAddRec;
3329 Ops[Idx] = NewAddRec;
3330 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3331 OpsModified = true;
3332 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3333 if (!AddRec)
3334 break;
3335 }
3336 }
3337 if (OpsModified)
3338 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3339
3340 // Otherwise couldn't fold anything into this recurrence. Move onto the
3341 // next one.
3342 }
3343
3344 // Okay, it looks like we really DO need an mul expr. Check to see if we
3345 // already have one, otherwise create a new one.
3346 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3347}
3348
3349/// Represents an unsigned remainder expression based on unsigned division.
3350const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3351 const SCEV *RHS) {
3352 assert(getEffectiveSCEVType(LHS->getType()) ==
3353 getEffectiveSCEVType(RHS->getType()) &&
3354 "SCEVURemExpr operand types don't match!");
3355
3356 // Short-circuit easy cases
3357 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3358 // If constant is one, the result is trivial
3359 if (RHSC->getValue()->isOne())
3360 return getZero(LHS->getType()); // X urem 1 --> 0
3361
3362 // If constant is a power of two, fold into a zext(trunc(LHS)).
3363 if (RHSC->getAPInt().isPowerOf2()) {
3364 Type *FullTy = LHS->getType();
3365 Type *TruncTy =
3366 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3367 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3368 }
3369 }
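// Example of the power-of-two fold above: for an i32 %x, (%x urem 8) becomes
// zext(trunc %x to i3) to i32, since logBase2(8) = 3 and the low three bits
// are exactly the remainder.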
3370
3371 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3372 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3373 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3374 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3375}
3376
3377/// Get a canonical unsigned division expression, or something simpler if
3378/// possible.
3379const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3380 const SCEV *RHS) {
3381 assert(!LHS->getType()->isPointerTy() &&
3382 "SCEVUDivExpr operand can't be pointer!");
3383 assert(LHS->getType() == RHS->getType() &&
3384 "SCEVUDivExpr operand types don't match!");
3385
3386 FoldingSetNodeID ID;
3387 ID.AddInteger(scUDivExpr);
3388 ID.AddPointer(LHS);
3389 ID.AddPointer(RHS);
3390 void *IP = nullptr;
3391 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3392 return S;
3393
3394 // 0 udiv Y == 0
3395 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3396 if (LHSC->getValue()->isZero())
3397 return LHS;
3398
3399 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3400 if (RHSC->getValue()->isOne())
3401 return LHS; // X udiv 1 --> x
3402 // If the denominator is zero, the result of the udiv is undefined. Don't
3403 // try to analyze it, because the resolution chosen here may differ from
3404 // the resolution chosen in other parts of the compiler.
3405 if (!RHSC->getValue()->isZero()) {
3406 // Determine if the division can be folded into the operands of
3407 // its operands.
3408 // TODO: Generalize this to non-constants by using known-bits information.
3409 Type *Ty = LHS->getType();
3410 unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3411 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3412 // For non-power-of-two values, effectively round the value up to the
3413 // nearest power of two.
3414 if (!RHSC->getAPInt().isPowerOf2())
3415 ++MaxShiftAmt;
3416 IntegerType *ExtTy =
3417 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3418 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3419 if (const SCEVConstant *Step =
3420 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3421 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3422 const APInt &StepInt = Step->getAPInt();
3423 const APInt &DivInt = RHSC->getAPInt();
3424 if (!StepInt.urem(DivInt) &&
3425 getZeroExtendExpr(AR, ExtTy) ==
3426 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3427 getZeroExtendExpr(Step, ExtTy),
3428 AR->getLoop(), SCEV::FlagAnyWrap)) {
3429 SmallVector<const SCEV *, 4> Operands;
3430 for (const SCEV *Op : AR->operands())
3431 Operands.push_back(getUDivExpr(Op, RHS));
3432 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3433 }
3434 /// Get a canonical UDivExpr for a recurrence.
3435 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3436 // We can currently only fold X%N if X is constant.
3437 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3438 if (StartC && !DivInt.urem(StepInt) &&
3439 getZeroExtendExpr(AR, ExtTy) ==
3440 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3441 getZeroExtendExpr(Step, ExtTy),
3442 AR->getLoop(), SCEV::FlagAnyWrap)) {
3443 const APInt &StartInt = StartC->getAPInt();
3444 const APInt &StartRem = StartInt.urem(StepInt);
3445 if (StartRem != 0) {
3446 const SCEV *NewLHS =
3447 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3448 AR->getLoop(), SCEV::FlagNW);
3449 if (LHS != NewLHS) {
3450 LHS = NewLHS;
3451
3452 // Reset the ID to include the new LHS, and check if it is
3453 // already cached.
3454 ID.clear();
3455 ID.AddInteger(scUDivExpr);
3456 ID.AddPointer(LHS);
3457 ID.AddPointer(RHS);
3458 IP = nullptr;
3459 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3460 return S;
3461 }
3462 }
3463 }
3464 }
3465 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3466 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3467 SmallVector<const SCEV *, 4> Operands;
3468 for (const SCEV *Op : M->operands())
3469 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3470 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3471 // Find an operand that's safely divisible.
3472 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3473 const SCEV *Op = M->getOperand(i);
3474 const SCEV *Div = getUDivExpr(Op, RHSC);
3475 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3476 Operands = SmallVector<const SCEV *, 4>(M->operands());
3477 Operands[i] = Div;
3478 return getMulExpr(Operands);
3479 }
3480 }
3481 }
3482
3483 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3484 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3485 if (auto *DivisorConstant =
3486 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3487 bool Overflow = false;
3488 APInt NewRHS =
3489 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3490 if (Overflow) {
3491 return getConstant(RHSC->getType(), 0, false);
3492 }
3493 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3494 }
3495 }
3496
3497 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3498 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3499 SmallVector<const SCEV *, 4> Operands;
3500 for (const SCEV *Op : A->operands())
3501 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3502 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3503 Operands.clear();
3504 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3505 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3506 if (isa<SCEVUDivExpr>(Op) ||
3507 getMulExpr(Op, RHS) != A->getOperand(i))
3508 break;
3509 Operands.push_back(Op);
3510 }
3511 if (Operands.size() == A->getNumOperands())
3512 return getAddExpr(Operands);
3513 }
3514 }
3515
3516 // Fold if both operands are constant.
3517 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3518 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3519 }
3520 }
3521
3522 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3523 // changes). Make sure we get a new one.
3524 IP = nullptr;
3525 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3526 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3527 LHS, RHS);
3528 UniqueSCEVs.InsertNode(S, IP);
3529 registerUser(S, {LHS, RHS});
3530 return S;
3531}
3532
3533APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3534 APInt A = C1->getAPInt().abs();
3535 APInt B = C2->getAPInt().abs();
3536 uint32_t ABW = A.getBitWidth();
3537 uint32_t BBW = B.getBitWidth();
3538
3539 if (ABW > BBW)
3540 B = B.zext(ABW);
3541 else if (ABW < BBW)
3542 A = A.zext(BBW);
3543
3544 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3545}
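// Example of the width handling above: for constants 12 (i8) and 18 (i16), A is
// zero-extended to 16 bits first and the result is
// GreatestCommonDivisor(12, 18) = 6.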
3546
3547/// Get a canonical unsigned division expression, or something simpler if
3548/// possible. There is no representation for an exact udiv in SCEV IR, but we
3549/// can attempt to remove factors from the LHS and RHS. We can't do this when
3550/// it's not exact because the udiv may be clearing bits.
3551const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3552 const SCEV *RHS) {
3553 // TODO: we could try to find factors in all sorts of things, but for now we
3554 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3555 // end of this file for inspiration.
3556
3557 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3558 if (!Mul || !Mul->hasNoUnsignedWrap())
3559 return getUDivExpr(LHS, RHS);
3560
3561 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3562 // If the mulexpr multiplies by a constant, then that constant must be the
3563 // first element of the mulexpr.
3564 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3565 if (LHSCst == RHSCst) {
3566 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3567 return getMulExpr(Operands);
3568 }
3569
3570 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3571 // that there's a factor provided by one of the other terms. We need to
3572 // check.
3573 APInt Factor = gcd(LHSCst, RHSCst);
3574 if (!Factor.isIntN(1)) {
3575 LHSCst =
3576 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3577 RHSCst =
3578 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3579 SmallVector<const SCEV *, 2> Operands;
3580 Operands.push_back(LHSCst);
3581 Operands.append(Mul->op_begin() + 1, Mul->op_end());
3582 LHS = getMulExpr(Operands);
3583 RHS = RHSCst;
3584 Mul = dyn_cast<SCEVMulExpr>(LHS);
3585 if (!Mul)
3586 return getUDivExactExpr(LHS, RHS);
3587 }
3588 }
3589 }
3590
3591 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3592 if (Mul->getOperand(i) == RHS) {
3593 SmallVector<const SCEV *, 2> Operands;
3594 Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3595 Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3596 return getMulExpr(Operands);
3597 }
3598 }
3599
3600 return getUDivExpr(LHS, RHS);
3601}
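// Two sketches of what the exact-division folds above achieve (operands are
// illustrative): (4 * X * Y)<nuw> /u 2 first divides the gcd 2 out of the
// constants, yielding (2 * X * Y); and (X * Y)<nuw> /u Y drops the matching
// operand to yield X. Anything that cannot be factored falls back to a plain
// SCEVUDivExpr.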
3602
3603/// Get an add recurrence expression for the specified loop. Simplify the
3604/// expression as much as possible.
3605const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3606 const Loop *L,
3607 SCEV::NoWrapFlags Flags) {
3608 SmallVector<const SCEV *, 4> Operands;
3609 Operands.push_back(Start);
3610 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3611 if (StepChrec->getLoop() == L) {
3612 Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3613 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3614 }
3615
3616 Operands.push_back(Step);
3617 return getAddRecExpr(Operands, L, Flags);
3618}
3619
3620/// Get an add recurrence expression for the specified loop. Simplify the
3621/// expression as much as possible.
3622const SCEV *
3623ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3624 const Loop *L, SCEV::NoWrapFlags Flags) {
3625 if (Operands.size() == 1) return Operands[0];
3626#ifndef NDEBUG
3627 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3628 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3629 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3630 "SCEVAddRecExpr operand types don't match!");
3631 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3632 }
3633 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3634 assert(isLoopInvariant(Operands[i], L) &&
3635 "SCEVAddRecExpr operand is not loop-invariant!");
3636#endif
3637
3638 if (Operands.back()->isZero()) {
3639 Operands.pop_back();
3640 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3641 }
3642
3643 // It's tempting to call getConstantMaxBackedgeTakenCount here and
3644 // use that information to infer NUW and NSW flags. However, computing a
3645 // BE count requires calling getAddRecExpr, so we may not yet have a
3646 // meaningful BE count at this point (and if we don't, we'd be stuck
3647 // with a SCEVCouldNotCompute as the cached BE count).
3648
3649 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3650
3651 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3652 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3653 const Loop *NestedLoop = NestedAR->getLoop();
3654 if (L->contains(NestedLoop)
3655 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3656 : (!NestedLoop->contains(L) &&
3657 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3658 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3659 Operands[0] = NestedAR->getStart();
3660 // AddRecs require their operands be loop-invariant with respect to their
3661 // loops. Don't perform this transformation if it would break this
3662 // requirement.
3663 bool AllInvariant = all_of(
3664 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3665
3666 if (AllInvariant) {
3667 // Create a recurrence for the outer loop with the same step size.
3668 //
3669 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3670 // inner recurrence has the same property.
3671 SCEV::NoWrapFlags OuterFlags =
3672 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3673
3674 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3675 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3676 return isLoopInvariant(Op, NestedLoop);
3677 });
3678
3679 if (AllInvariant) {
3680 // Ok, both add recurrences are valid after the transformation.
3681 //
3682 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3683 // the outer recurrence has the same property.
3684 SCEV::NoWrapFlags InnerFlags =
3685 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3686 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3687 }
3688 }
3689 // Reset Operands to its original state.
3690 Operands[0] = NestedAR;
3691 }
3692 }
3693
3694 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3695 // already have one, otherwise create a new one.
3696 return getOrCreateAddRecExpr(Operands, L, Flags);
3697}
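
As an illustration of the recurrence-construction path above, here is a minimal caller sketch. It is not part of ScalarEvolution.cpp; it assumes a ScalarEvolution &SE, a Loop *L, and an integer Type *Ty obtained from an already-analyzed function.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Builds the canonical induction variable {0,+,1}<L>. Passing a zero step
// instead would exercise the {X,+,0} --> X fold above.
static const SCEV *makeCanonicalIV(ScalarEvolution &SE, const Loop *L,
                                   Type *Ty) {
  const SCEV *Start = SE.getConstant(Ty, 0);
  const SCEV *Step = SE.getConstant(Ty, 1);
  return SE.getAddRecExpr(Start, Step, L, SCEV::FlagAnyWrap);
}
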
3698
3699const SCEV *
3700ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3701 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3702 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3703 // getSCEV(Base)->getType() has the same address space as Base->getType()
3704 // because SCEV::getType() preserves the address space.
3705 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3706 const bool AssumeInBoundsFlags = [&]() {
3707 if (!GEP->isInBounds())
3708 return false;
3709
3710 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3711 // but to do that, we have to ensure that said flag is valid in the entire
3712 // defined scope of the SCEV.
3713 auto *GEPI = dyn_cast<Instruction>(GEP);
3714 // TODO: non-instructions have global scope. We might be able to prove
3715 // some global scope cases
3716 return GEPI && isSCEVExprNeverPoison(GEPI);
3717 }();
3718
3719 SCEV::NoWrapFlags OffsetWrap =
3720 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3721
3722 Type *CurTy = GEP->getType();
3723 bool FirstIter = true;
3724 SmallVector<const SCEV *, 4> Offsets;
3725 for (const SCEV *IndexExpr : IndexExprs) {
3726 // Compute the (potentially symbolic) offset in bytes for this index.
3727 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3728 // For a struct, add the member offset.
3729 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3730 unsigned FieldNo = Index->getZExtValue();
3731 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3732 Offsets.push_back(FieldOffset);
3733
3734 // Update CurTy to the type of the field at Index.
3735 CurTy = STy->getTypeAtIndex(Index);
3736 } else {
3737 // Update CurTy to its element type.
3738 if (FirstIter) {
3739 assert(isa<PointerType>(CurTy) &&
3740 "The first index of a GEP indexes a pointer");
3741 CurTy = GEP->getSourceElementType();
3742 FirstIter = false;
3743 } else {
3744 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3745 }
3746 // For an array, add the element offset, explicitly scaled.
3747 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3748 // Getelementptr indices are signed.
3749 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3750
3751 // Multiply the index by the element size to compute the element offset.
3752 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3753 Offsets.push_back(LocalOffset);
3754 }
3755 }
3756
3757 // Handle degenerate case of GEP without offsets.
3758 if (Offsets.empty())
3759 return BaseExpr;
3760
3761 // Add the offsets together, assuming nsw if inbounds.
3762 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3763 // Add the base address and the offset. We cannot use the nsw flag, as the
3764 // base address is unsigned. However, if we know that the offset is
3765 // non-negative, we can use nuw.
3766 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3767 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3768 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3769 assert(BaseExpr->getType() == GEPExpr->getType() &&
3770 "GEP should not change type mid-flight.");
3771 return GEPExpr;
3772}
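
To make the offset arithmetic above concrete, consider a GEP such as getelementptr inbounds {i32, i64}, ptr %p, i64 %i, i32 1. Under a typical data layout (struct allocation size 16, field 1 at byte offset 8) the function produces (%p + 16 * %i + 8). The sketch below rebuilds that expression by hand; it is an illustration under those layout assumptions, not part of the file, with P and I standing for the base-pointer and index SCEVs.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// 16 * I scales the first index by the struct size, + 8 is the member
// offset of field 1; inbounds lets both offset operations carry nsw.
static const SCEV *gepOffsetSketch(ScalarEvolution &SE, const SCEV *P,
                                   const SCEV *I) {
  Type *IdxTy = I->getType();
  const SCEV *Scaled =
      SE.getMulExpr(I, SE.getConstant(IdxTy, 16), SCEV::FlagNSW);
  const SCEV *Offset =
      SE.getAddExpr(Scaled, SE.getConstant(IdxTy, 8), SCEV::FlagNSW);
  // Base + offset; nuw is only added when the offset is known non-negative.
  return SE.getAddExpr(P, Offset);
}
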
3773
3774SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3775 ArrayRef<const SCEV *> Ops) {
3776 FoldingSetNodeID ID;
3777 ID.AddInteger(SCEVType);
3778 for (const SCEV *Op : Ops)
3779 ID.AddPointer(Op);
3780 void *IP = nullptr;
3781 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3782}
3783
3784const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3785 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3786 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3787}
3788
3789const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3790 SmallVectorImpl<const SCEV *> &Ops) {
3791 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3792 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3793 if (Ops.size() == 1) return Ops[0];
3794#ifndef NDEBUG
3795 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3796 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3797 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3798 "Operand types don't match!");
3799 assert(Ops[0]->getType()->isPointerTy() ==
3800 Ops[i]->getType()->isPointerTy() &&
3801 "min/max should be consistently pointerish");
3802 }
3803#endif
3804
3805 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3806 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3807
3808 // Sort by complexity, this groups all similar expression types together.
3809 GroupByComplexity(Ops, &LI, DT);
3810
3811 // Check if we have created the same expression before.
3812 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3813 return S;
3814 }
3815
3816 // If there are any constants, fold them together.
3817 unsigned Idx = 0;
3818 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3819 ++Idx;
3820 assert(Idx < Ops.size());
3821 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3822 if (Kind == scSMaxExpr)
3823 return APIntOps::smax(LHS, RHS);
3824 else if (Kind == scSMinExpr)
3825 return APIntOps::smin(LHS, RHS);
3826 else if (Kind == scUMaxExpr)
3827 return APIntOps::umax(LHS, RHS);
3828 else if (Kind == scUMinExpr)
3829 return APIntOps::umin(LHS, RHS);
3830 llvm_unreachable("Unknown SCEV min/max opcode");
3831 };
3832
3833 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3834 // We found two constants, fold them together!
3835 ConstantInt *Fold = ConstantInt::get(
3836 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3837 Ops[0] = getConstant(Fold);
3838 Ops.erase(Ops.begin()+1); // Erase the folded element
3839 if (Ops.size() == 1) return Ops[0];
3840 LHSC = cast<SCEVConstant>(Ops[0]);
3841 }
3842
3843 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3844 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3845
3846 if (IsMax ? IsMinV : IsMaxV) {
3847 // If we are left with a constant minimum(/maximum)-int, strip it off.
3848 Ops.erase(Ops.begin());
3849 --Idx;
3850 } else if (IsMax ? IsMaxV : IsMinV) {
3851 // If we have a max(/min) with a constant maximum(/minimum)-int,
3852 // it will always be the extremum.
3853 return LHSC;
3854 }
3855
3856 if (Ops.size() == 1) return Ops[0];
3857 }
3858
3859 // Find the first operation of the same kind
3860 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3861 ++Idx;
3862
3863 // Check to see if one of the operands is of the same kind. If so, expand its
3864 // operands onto our operand list, and recurse to simplify.
3865 if (Idx < Ops.size()) {
3866 bool DeletedAny = false;
3867 while (Ops[Idx]->getSCEVType() == Kind) {
3868 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3869 Ops.erase(Ops.begin()+Idx);
3870 Ops.append(SMME->op_begin(), SMME->op_end());
3871 DeletedAny = true;
3872 }
3873
3874 if (DeletedAny)
3875 return getMinMaxExpr(Kind, Ops);
3876 }
3877
3878 // Okay, check to see if the same value occurs in the operand list twice. If
3879 // so, delete one. Since we sorted the list, these values are required to
3880 // be adjacent.
3881 llvm::CmpInst::Predicate GEPred =
3882 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3883 llvm::CmpInst::Predicate LEPred =
3884 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3885 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3886 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3887 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3888 if (Ops[i] == Ops[i + 1] ||
3889 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3890 // X op Y op Y --> X op Y
3891 // X op Y --> X, if we know X, Y are ordered appropriately
3892 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3893 --i;
3894 --e;
3895 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3896 Ops[i + 1])) {
3897 // X op Y --> Y, if we know X, Y are ordered appropriately
3898 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3899 --i;
3900 --e;
3901 }
3902 }
3903
3904 if (Ops.size() == 1) return Ops[0];
3905
3906 assert(!Ops.empty() && "Reduced smax down to nothing!");
3907
3908 // Okay, it looks like we really DO need an expr. Check to see if we
3909 // already have one, otherwise create a new one.
3910 FoldingSetNodeID ID;
3911 ID.AddInteger(Kind);
3912 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3913 ID.AddPointer(Ops[i]);
3914 void *IP = nullptr;
3915 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3916 if (ExistingSCEV)
3917 return ExistingSCEV;
3918 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3919 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3920 SCEV *S = new (SCEVAllocator)
3921 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3922
3923 UniqueSCEVs.InsertNode(S, IP);
3924 registerUser(S, Ops);
3925 return S;
3926}
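
A small illustration of the constant folding and flattening performed above; this is a sketch, not part of the file, and it assumes X is an i64 value from the analyzed function.

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// smax(2, smax(3, X)): the nested smax is flattened into the operand list,
// the two constants fold via APIntOps::smax, and smax(3, X) remains.
static const SCEV *smaxFoldSketch(ScalarEvolution &SE, Value *X) {
  const SCEV *Two = SE.getConstant(APInt(64, 2));
  const SCEV *Three = SE.getConstant(APInt(64, 3));
  const SCEV *Inner = SE.getSMaxExpr(Three, SE.getUnknown(X));
  return SE.getSMaxExpr(Two, Inner);
}
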
3927
3928namespace {
3929
3930class SCEVSequentialMinMaxDeduplicatingVisitor final
3931 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3932 Optional<const SCEV *>> {
3933 using RetVal = Optional<const SCEV *>;
3934 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3935
3936 ScalarEvolution &SE;
3937 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3938 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3939 SmallPtrSet<const SCEV *, 16> SeenOps;
3940
3941 bool canRecurseInto(SCEVTypes Kind) const {
3942 // We can only recurse into the SCEV expression of the same effective type
3943 // as the type of our root SCEV expression.
3944 return RootKind == Kind || NonSequentialRootKind == Kind;
3945 };
3946
3947 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3948 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3949 "Only for min/max expressions.");
3950 SCEVTypes Kind = S->getSCEVType();
3951
3952 if (!canRecurseInto(Kind))
3953 return S;
3954
3955 auto *NAry = cast<SCEVNAryExpr>(S);
3956 SmallVector<const SCEV *> NewOps;
3957 bool Changed =
3958 visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
3959
3960 if (!Changed)
3961 return S;
3962 if (NewOps.empty())
3963 return None;
3964
3965 return isa<SCEVSequentialMinMaxExpr>(S)
3966 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3967 : SE.getMinMaxExpr(Kind, NewOps);
3968 }
3969
3970 RetVal visit(const SCEV *S) {
3971 // Has the whole operand been seen already?
3972 if (!SeenOps.insert(S).second)
3973 return None;
3974 return Base::visit(S);
3975 }
3976
3977public:
3978 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3979 SCEVTypes RootKind)
3980 : SE(SE), RootKind(RootKind),
3981 NonSequentialRootKind(
3982 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
3983 RootKind)) {}
3984
3985 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3986 SmallVectorImpl<const SCEV *> &NewOps) {
3987 bool Changed = false;
3988 SmallVector<const SCEV *> Ops;
3989 Ops.reserve(OrigOps.size());
3990
3991 for (const SCEV *Op : OrigOps) {
3992 RetVal NewOp = visit(Op);
3993 if (NewOp != Op)
3994 Changed = true;
3995 if (NewOp)
3996 Ops.emplace_back(*NewOp);
3997 }
3998
3999 if (Changed)
4000 NewOps = std::move(Ops);
4001 return Changed;
4002 }
4003
4004 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4005
4006 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4007
4008 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4009
4010 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4011
4012 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4013
4014 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4015
4016 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4017
4018 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4019
4020 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4021
4022 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4023 return visitAnyMinMaxExpr(Expr);
4024 }
4025
4026 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4027 return visitAnyMinMaxExpr(Expr);
4028 }
4029
4030 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4031 return visitAnyMinMaxExpr(Expr);
4032 }
4033
4034 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4035 return visitAnyMinMaxExpr(Expr);
4036 }
4037
4038 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4039 return visitAnyMinMaxExpr(Expr);
4040 }
4041
4042 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4043
4044 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4045};
4046
4047} // namespace
4048
4049/// Return true if V is poison given that AssumedPoison is already poison.
4050static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4051 // The only way poison may be introduced in a SCEV expression is from a
4052 // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4053 // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4054 // introduce poison -- they encode guaranteed, non-speculated knowledge.
4055 //
4056 // Additionally, all SCEV nodes propagate poison from inputs to outputs,
4057 // with the notable exception of umin_seq, where only poison from the first
4058 // operand is (unconditionally) propagated.
4059 struct SCEVPoisonCollector {
4060 bool LookThroughSeq;
4061 SmallPtrSet<const SCEV *, 4> MaybePoison;
4062 SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {}
4063
4064 bool follow(const SCEV *S) {
4065 // TODO: We can always follow the first operand, but the SCEVTraversal
4066 // API doesn't support this.
4067 if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S))
4068 return false;
4069
4070 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4071 if (!isGuaranteedNotToBePoison(SU->getValue()))
4072 MaybePoison.insert(S);
4073 }
4074 return true;
4075 }
4076 bool isDone() const { return false; }
4077 };
4078
4079 // First collect all SCEVs that might result in AssumedPoison being poison.
4080 // We need to look through umin_seq here, because we want to find all SCEVs
4081 // that *might* result in poison, not only those that are *required* to.
4082 SCEVPoisonCollector PC1(/* LookThroughSeq */ true);
4083 visitAll(AssumedPoison, PC1);
4084
4085 // AssumedPoison is never poison. As the assumption is false, the implication
4086 // is true. Don't bother walking the other SCEV in this case.
4087 if (PC1.MaybePoison.empty())
4088 return true;
4089
4090 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4091 // as well. We cannot look through umin_seq here, as its argument only *may*
4092 // make the result poison.
4093 SCEVPoisonCollector PC2(/* LookThroughSeq */ false);
4094 visitAll(S, PC2);
4095
4096 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4097 // it will also make S poison by being part of PC2.MaybePoison.
4098 return all_of(PC1.MaybePoison,
4099 [&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
4100}
4101
4102const SCEV *
4103ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4104 SmallVectorImpl<const SCEV *> &Ops) {
4105 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4106 "Not a SCEVSequentialMinMaxExpr!");
4107 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4108 if (Ops.size() == 1)
4109 return Ops[0];
4110#ifndef NDEBUG
4111 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4112 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4113 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4114 "Operand types don't match!");
4115 assert(Ops[0]->getType()->isPointerTy() ==
4116 Ops[i]->getType()->isPointerTy() &&
4117 "min/max should be consistently pointerish");
4118 }
4119#endif
4120
4121 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4122 // so we can *NOT* do any kind of sorting of the expressions!
4123
4124 // Check if we have created the same expression before.
4125 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4126 return S;
4127
4128 // FIXME: there are *some* simplifications that we can do here.
4129
4130 // Keep only the first instance of an operand.
4131 {
4132 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4133 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4134 if (Changed)
4135 return getSequentialMinMaxExpr(Kind, Ops);
4136 }
4137
4138 // Check to see if one of the operands is of the same kind. If so, expand its
4139 // operands onto our operand list, and recurse to simplify.
4140 {
4141 unsigned Idx = 0;
4142 bool DeletedAny = false;
4143 while (Idx < Ops.size()) {
4144 if (Ops[Idx]->getSCEVType() != Kind) {
4145 ++Idx;
4146 continue;
4147 }
4148 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4149 Ops.erase(Ops.begin() + Idx);
4150 Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
4151 DeletedAny = true;
4152 }
4153
4154 if (DeletedAny)
4155 return getSequentialMinMaxExpr(Kind, Ops);
4156 }
4157
4158 const SCEV *SaturationPoint;
4159 ICmpInst::Predicate Pred;
4160 switch (Kind) {
4161 case scSequentialUMinExpr:
4162 SaturationPoint = getZero(Ops[0]->getType());
4163 Pred = ICmpInst::ICMP_ULE;
4164 break;
4165 default:
4166 llvm_unreachable("Not a sequential min/max type.");
4167 }
4168
4169 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4170 // We can replace %x umin_seq %y with %x umin %y if either:
4171 // * %y being poison implies %x is also poison.
4172 // * %x cannot be the saturating value (e.g. zero for umin).
4173 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4174 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4175 SaturationPoint)) {
4176 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4177 Ops[i - 1] = getMinMaxExpr(
4178 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4179 SeqOps);
4180 Ops.erase(Ops.begin() + i);
4181 return getSequentialMinMaxExpr(Kind, Ops);
4182 }
4183 // Fold %x umin_seq %y to %x if %x ule %y.
4184 // TODO: We might be able to prove the predicate for a later operand.
4185 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4186 Ops.erase(Ops.begin() + i);
4187 return getSequentialMinMaxExpr(Kind, Ops);
4188 }
4189 }
4190
4191 // Okay, it looks like we really DO need an expr. Check to see if we
4192 // already have one, otherwise create a new one.
4193 FoldingSetNodeID ID;
4194 ID.AddInteger(Kind);
4195 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4196 ID.AddPointer(Ops[i]);
4197 void *IP = nullptr;
4198 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4199 if (ExistingSCEV)
4200 return ExistingSCEV;
4201
4202 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4203 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4204 SCEV *S = new (SCEVAllocator)
4205 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4206
4207 UniqueSCEVs.InsertNode(S, IP);
4208 registerUser(S, Ops);
4209 return S;
4210}
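
The public entry point for the sequential form is getUMinExpr with Sequential set, shown in the sketch below (not part of the file; A and B are assumed to be SCEVs of the same type). When A is provably non-zero, or B's poison is implied by A's, the loop above relaxes A umin_seq B to a plain A umin B.

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// umin_seq evaluates A first and only consults B when A is non-zero, which
// is why poison from B is not propagated unconditionally.
static const SCEV *seqUMinSketch(ScalarEvolution &SE, const SCEV *A,
                                 const SCEV *B) {
  return SE.getUMinExpr(A, B, /*Sequential=*/true);
}
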
4211
4212const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4213 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4214 return getSMaxExpr(Ops);
4215}
4216
4217const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4218 return getMinMaxExpr(scSMaxExpr, Ops);
4219}
4220
4221const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4222 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4223 return getUMaxExpr(Ops);
4224}
4225
4226const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4227 return getMinMaxExpr(scUMaxExpr, Ops);
4228}
4229
4230const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4231 const SCEV *RHS) {
4232 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4233 return getSMinExpr(Ops);
4234}
4235
4236const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4237 return getMinMaxExpr(scSMinExpr, Ops);
4238}
4239
4240const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4241 bool Sequential) {
4242 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4243 return getUMinExpr(Ops, Sequential);
4244}
4245
4246const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4247 bool Sequential) {
4248 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4249 : getMinMaxExpr(scUMinExpr, Ops);
4250}
4251
4252const SCEV *
4253ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4254 ScalableVectorType *ScalableTy) {
4255 Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4256 Constant *One = ConstantInt::get(IntTy, 1);
4257 Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
4258 // Note that the expression we created is the final expression; we don't
4259 // want to simplify it any further. Also, if we call a normal getSCEV(),
4260 // we'll end up in an endless recursion. So just create an SCEVUnknown.
4261 return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4262}
4263
4264const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4265 if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4266 return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4267 // We can bypass creating a target-independent constant expression and then
4268 // folding it back into a ConstantInt. This is just a compile-time
4269 // optimization.
4270 return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4271}
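
The GEP-of-null trick above is how sizeof is spelled for a scalable vector; for fixed-size types the same query folds straight to a constant. A brief usage sketch, as an illustration only and not part of the file:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/DerivedTypes.h"
using namespace llvm;

// For <vscale x 4 x i32> this yields a SCEVUnknown wrapping
// ptrtoint(gep(<vscale x 4 x i32>, null, 1)), i.e. 16 * vscale bytes;
// replacing the type with a fixed vector would yield a plain SCEVConstant.
static const SCEV *sizeOfSketch(ScalarEvolution &SE, LLVMContext &Ctx) {
  Type *I64 = Type::getInt64Ty(Ctx);
  Type *ScalableVec = ScalableVectorType::get(Type::getInt32Ty(Ctx), 4);
  return SE.getSizeOfExpr(I64, ScalableVec);
}
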
4272
4273const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4274 if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4275 return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4276 // We can bypass creating a target-independent constant expression and then
4277 // folding it back into a ConstantInt. This is just a compile-time
4278 // optimization.
4279 return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4280}
4281
4282const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4283 StructType *STy,
4284 unsigned FieldNo) {
4285 // We can bypass creating a target-independent constant expression and then
4286 // folding it back into a ConstantInt. This is just a compile-time
4287 // optimization.
4288 return getConstant(
4289 IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
4290}
4291
4292const SCEV *ScalarEvolution::getUnknown(Value *V) {
4293 // Don't attempt to do anything other than create a SCEVUnknown object
4294 // here. createSCEV only calls getUnknown after checking for all other
4295 // interesting possibilities, and any other code that calls getUnknown
4296 // is doing so in order to hide a value from SCEV canonicalization.
4297
4298 FoldingSetNodeID ID;
4299 ID.AddInteger(scUnknown);
4300 ID.AddPointer(V);
4301 void *IP = nullptr;
4302 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4303 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4304 "Stale SCEVUnknown in uniquing map!");
4305 return S;
4306 }
4307 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4308 FirstUnknown);
4309 FirstUnknown = cast<SCEVUnknown>(S);
4310 UniqueSCEVs.InsertNode(S, IP);
4311 return S;
4312}
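
Because SCEVUnknowns are uniqued through the FoldingSet above, repeated queries for the same Value return the identical node. A tiny check, as an illustration only and not part of the file:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// The second call finds the node the first call inserted via
// FindNodeOrInsertPos, so pointer equality is expected to hold.
static bool unknownIsUniqued(ScalarEvolution &SE, Value *V) {
  return SE.getUnknown(V) == SE.getUnknown(V);
}
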
4313
4314//===----------------------------------------------------------------------===//
4315// Basic SCEV Analysis and PHI Idiom Recognition Code
4316//
4317
4318/// Test if values of the given type are analyzable within the SCEV
4319/// framework. This primarily includes integer types, and it can optionally
4320/// include pointer types if the ScalarEvolution class has access to
4321/// target-specific information.
4322bool ScalarEvolution::isSCEVable(Type *Ty) const {
4323 // Integers and pointers are always SCEVable.
4324 return Ty->isIntOrPtrTy();
4325}
4326
4327/// Return the size in bits of the specified type, for which isSCEVable must
4328/// return true.
4329uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4330 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4331 if (Ty->isPointerTy())
4332 return getDataLayout().getIndexTypeSizeInBits(Ty);
4333 return getDataLayout().getTypeSizeInBits(Ty);
4334}
4335
4336/// Return a type with the same bitwidth as the given type and which represents
4337/// how SCEV will treat the given type, for which isSCEVable must return
4338/// true. For pointer types, this is the pointer index sized integer type.
4339Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4340 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4341
4342 if (Ty->isIntegerTy())
4343 return Ty;
4344
4345 // The only other supported type is pointer.
4346 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4347 return getDataLayout().getIndexType(Ty);
4348}
4349
4350Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4351 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4352}
4353
4354bool ScalarEvolution::instructionCouldExistWitthOperands(const SCEV *A,
4355 const SCEV *B) {
4356 /// For a valid use point to exist, the defining scope of one operand
4357 /// must dominate the other.
4358 bool PreciseA, PreciseB;
4359 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4360 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4361 if (!PreciseA || !PreciseB)
4362 // Can't tell.
4363 return false;
4364 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4365 DT.dominates(ScopeB, ScopeA);
4366}
4367
4368
4369const SCEV *ScalarEvolution::getCouldNotCompute() {
4370 return CouldNotCompute.get();
4371}
4372
4373bool ScalarEvolution::checkValidity(const SCEV *S) const {
4374 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4375 auto *SU = dyn_cast<SCEVUnknown>(S);
4376 return SU && SU->getValue() == nullptr;
4377 });
4378
4379 return !ContainsNulls;
4380}
4381
4382bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4383 HasRecMapType::iterator I = HasRecMap.find(S);
4384 if (I != HasRecMap.end())
4385 return I->second;
4386
4387 bool FoundAddRec =
4388 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4389 HasRecMap.insert({S, FoundAddRec});
4390 return FoundAddRec;
4391}
4392
4393/// Return the ValueOffsetPair set for \p S. \p S can be represented
4394/// by the value and offset from any ValueOffsetPair in the set.
4395ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4396 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4397 if (SI == ExprValueMap.end())
4398 return None;
4399#ifndef NDEBUG
4400 if (VerifySCEVMap) {
4401 // Check there is no dangling Value in the set returned.
4402 for (Value *V : SI->second)
4403 assert(ValueExprMap.count(V));
4404 }
4405#endif
4406 return SI->second.getArrayRef();
4407}
4408
4409/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4410/// cannot be used separately. eraseValueFromMap should be used to remove
4411/// V from ValueExprMap and ExprValueMap at the same time.
4412void ScalarEvolution::eraseValueFromMap(Value *V) {
4413 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4414 if (I != ValueExprMap.end()) {
4415 auto EVIt = ExprValueMap.find(I->second);
4416 bool Removed = EVIt->second.remove(V);
4417 (void) Removed;
4418 assert(Removed && "Value not in ExprValueMap?");
4419 ValueExprMap.erase(I);
4420 }
4421}
4422
4423void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4424 // A recursive query may have already computed the SCEV. It should be
4425 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4426 // inferred nowrap flags.
4427 auto It = ValueExprMap.find_as(V);
4428 if (It == ValueExprMap.end()) {
4429 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4430 ExprValueMap[S].insert(V);
4431 }
4432}
4433
4434/// Return an existing SCEV if it exists, otherwise analyze the expression and
4435/// create a new one.
4436const SCEV *ScalarEvolution::getSCEV(Value *V) {
4437 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
Step 34: Called C++ object pointer is null
4438
4439 if (const SCEV *S = getExistingSCEV(V))
4440 return S;
4441 return createSCEVIter(V);
4442}
4443
4444const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4445 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4446
4447 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4448 if (I != ValueExprMap.end()) {
4449 const SCEV *S = I->second;
4450 assert(checkValidity(S) &&
4451 "existing SCEV has not been properly invalidated");
4452 return S;
4453 }
4454 return nullptr;
4455}
4456
4457/// Return a SCEV corresponding to -V = -1*V
4458const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4459 SCEV::NoWrapFlags Flags) {
4460 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4461 return getConstant(
4462 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4463
4464 Type *Ty = V->getType();
4465 Ty = getEffectiveSCEVType(Ty);
4466 return getMulExpr(V, getMinusOne(Ty), Flags);
4467}
4468
4469/// If Expr computes ~A, return A else return nullptr
4470static const SCEV *MatchNotExpr(const SCEV *Expr) {
4471 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4472 if (!Add || Add->getNumOperands() != 2 ||
4473 !Add->getOperand(0)->isAllOnesValue())
4474 return nullptr;
4475
4476 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4477 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4478 !AddRHS->getOperand(0)->isAllOnesValue())
4479 return nullptr;
4480
4481 return AddRHS->getOperand(1);
4482}
4483
4484/// Return a SCEV corresponding to ~V = -1-V
4485const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4486 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4487
4488 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4489 return getConstant(
4490 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4491
4492 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4493 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4494 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4495 SmallVector<const SCEV *, 2> MatchedOperands;
4496 for (const SCEV *Operand : MME->operands()) {
4497 const SCEV *Matched = MatchNotExpr(Operand);
4498 if (!Matched)
4499 return (const SCEV *)nullptr;
4500 MatchedOperands.push_back(Matched);
4501 }
4502 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4503 MatchedOperands);
4504 };
4505 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4506 return Replaced;
4507 }
4508
4509 Type *Ty = V->getType();
4510 Ty = getEffectiveSCEVType(Ty);
4511 return getMinusSCEV(getMinusOne(Ty), V);
4512}
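
A worked instance of the min/max negation fold above, as a sketch that is not part of the file (X and Y are assumed to be i64 values, and the fold relies on the ~A pattern produced by getNotSCEV being recognized by MatchNotExpr):

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

// Since ~smin(~x, ~y) == smax(x, y), negating an smin of two not-expressions
// is expected to collapse back to an smax of the original values.
static const SCEV *notOfSMinSketch(ScalarEvolution &SE, Value *X, Value *Y) {
  const SCEV *NotX = SE.getNotSCEV(SE.getUnknown(X));
  const SCEV *NotY = SE.getNotSCEV(SE.getUnknown(Y));
  return SE.getNotSCEV(SE.getSMinExpr(NotX, NotY));
}
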
4513
4514const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4515 assert(P->getType()->isPointerTy());
4516
4517 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4518 // The base of an AddRec is the first operand.
4519 SmallVector<const SCEV *> Ops{AddRec->operands()};
4520 Ops[0] = removePointerBase(Ops[0]);
4521 // Don't try to transfer nowrap flags for now. We could in some cases
4522 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4523 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4524 }
4525 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4526 // The base of an Add is the pointer operand.
4527 SmallVector<const SCEV *> Ops{Add->operands()};
4528 const SCEV **PtrOp = nullptr;
4529 for (const SCEV *&AddOp : Ops) {
4530 if (AddOp->getType()->isPointerTy()) {
4531 assert(!PtrOp && "Cannot have multiple pointer ops");
4532 PtrOp = &AddOp;
4533 }
4534 }
4535 *PtrOp = removePointerBase(*PtrOp);
4536 // Don't try to transfer nowrap flags for now. We could in some cases
4537 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4538 return getAddExpr(Ops);
4539 }
4540 // Any other expression must be a pointer base.
4541 return getZero(P->getType());
4542}
4543
4544const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4545 SCEV::NoWrapFlags Flags,
4546 unsigned Depth) {
4547 // Fast path: X - X --> 0.
4548 if (LHS == RHS)
4549 return getZero(LHS->getType());
4550
4551 // If we subtract two pointers with different pointer bases, bail.
4552 // Eventually, we're going to add an assertion to getMulExpr that we
4553 // can't multiply by a pointer.
4554 if (RHS->getType()->isPointerTy()) {
4555 if (!LHS->getType()->isPointerTy() ||
4556 getPointerBase(LHS) != getPointerBase(RHS))
4557 return getCouldNotCompute();
4558 LHS = removePointerBase(LHS);
4559 RHS = removePointerBase(RHS);
4560 }
4561
4562 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4563 // makes it so that we cannot make much use of NUW.
4564 auto AddFlags = SCEV::FlagAnyWrap;
4565 const bool RHSIsNotMinSigned =
4566 !getSignedRangeMin(RHS).isMinSignedValue();
4567 if (hasFlags(Flags, SCEV::FlagNSW)) {
4568 // Let M be the minimum representable signed value. Then (-1)*RHS
4569 // signed-wraps if and only if RHS is M. That can happen even for
4570 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4571 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4572 // (-1)*RHS, we need to prove that RHS != M.
4573 //
4574 // If LHS is non-negative and we know that LHS - RHS does not
4575 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4576 // either by proving that RHS > M or that LHS >= 0.
4577 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4578 AddFlags = SCEV::FlagNSW;
4579 }
4580 }
4581
4582 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4583 // RHS is NSW and LHS >= 0.
4584 //
4585 // The difficulty here is that the NSW flag may have been proven
4586 // relative to a loop that is to be found in a recurrence in LHS and
4587 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4588 // larger scope than intended.
4589 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4590
4591 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4592}
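
The signed-wrap caveat above can be made concrete at i8 width: negating the minimum signed value wraps back onto itself, which is why NSW is only transferred when RHS is provably not that minimum or LHS is known non-negative. A small illustrative check, not part of the file:

#include "llvm/ADT/APInt.h"
#include <cassert>
using namespace llvm;

// (-1) * (-128) does not fit in i8: the multiplication reports overflow,
// mirroring the "(-1)*RHS signed-wraps iff RHS is M" argument above.
static void minSignedWrapCheck() {
  APInt M = APInt::getSignedMinValue(8); // -128
  APInt MinusOne = APInt::getAllOnes(8); // -1
  bool Overflow = false;
  (void)M.smul_ov(MinusOne, Overflow);
  assert(Overflow && "negating INT_MIN must signal signed overflow");
  (void)Overflow;
}
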
4593
4594const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4595 unsigned Depth) {
4596 Type *SrcTy = V->getType();
4597 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4598 "Cannot truncate or zero extend with non-integer arguments!");
4599 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4600 return V; // No conversion
4601 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4602 return getTruncateExpr(V, Ty, Depth);
4603 return getZeroExtendExpr(V, Ty, Depth);
4604}
4605
4606const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4607 unsigned Depth) {
4608 Type *SrcTy = V->getType();
4609 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4610 "Cannot truncate or zero extend with non-integer arguments!");
4611 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4612 return V; // No conversion
4613 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4614 return getTruncateExpr(V, Ty, Depth);
4615 return getSignExtendExpr(V, Ty, Depth);
4616}
4617
4618const SCEV *
4619ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4620 Type *SrcTy = V->getType();
4621 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4622 "Cannot noop or zero extend with non-integer arguments!");
4623 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4624 "getNoopOrZeroExtend cannot truncate!");
4625 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4626 return V; // No conversion
4627 return getZeroExtendExpr(V, Ty);
4628}
4629
4630const SCEV *
4631ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4632 Type *SrcTy = V->getType();
4633 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4634 "Cannot noop or sign extend with non-integer arguments!");
4635 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4636 "getNoopOrSignExtend cannot truncate!");
4637 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4638 return V; // No conversion
4639 return getSignExtendExpr(V, Ty);
4640}
4641
4642const SCEV *
4643ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4644 Type *SrcTy = V->getType();
4645 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4646 "Cannot noop or any extend with non-integer arguments!");
4647 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4648 "getNoopOrAnyExtend cannot truncate!");
4649 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4650 return V; // No conversion
4651 return getAnyExtendExpr(V, Ty);
4652}
4653
4654const SCEV *
4655ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4656 Type *SrcTy = V->getType();
4657 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4658 "Cannot truncate or noop with non-integer arguments!");
4659 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4660 "getTruncateOrNoop cannot extend!");
4661 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4662 return V; // No conversion
4663 return getTruncateExpr(V, Ty);
4664}
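// Illustrative sketch (not part of ScalarEvolution.cpp): how the noop-or-cast
// helpers above compose, assuming a ScalarEvolution &SE, an LLVMContext &Ctx,
// and an i32 expression A. Each helper is a no-op when the source and
// destination bit widths already match.
//   Type *I64 = Type::getInt64Ty(Ctx);
//   const SCEV *S = SE.getNoopOrSignExtend(A, I64);        // (sext i32 A to i64)
//   const SCEV *U = SE.getNoopOrAnyExtend(A, I64);         // extension kind left to SCEV
//   const SCEV *T = SE.getTruncateOrNoop(S, A->getType()); // back to i32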
4665
4666const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4667 const SCEV *RHS) {
4668 const SCEV *PromotedLHS = LHS;
4669 const SCEV *PromotedRHS = RHS;
4670
4671 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4672 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4673 else
4674 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4675
4676 return getUMaxExpr(PromotedLHS, PromotedRHS);
4677}
4678
4679const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4680 const SCEV *RHS,
4681 bool Sequential) {
4682 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4683 return getUMinFromMismatchedTypes(Ops, Sequential);
4684}
4685
4686const SCEV *
4687ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4688 bool Sequential) {
4689 assert(!Ops.empty() && "At least one operand must be!");
4690 // Trivial case.
4691 if (Ops.size() == 1)
4692 return Ops[0];
4693
4694 // Find the max type first.
4695 Type *MaxType = nullptr;
4696 for (const auto *S : Ops)
4697 if (MaxType)
4698 MaxType = getWiderType(MaxType, S->getType());
4699 else
4700 MaxType = S->getType();
4701 assert(MaxType && "Failed to find maximum type!");
4702
4703 // Extend all ops to max type.
4704 SmallVector<const SCEV *, 2> PromotedOps;
4705 for (const auto *S : Ops)
4706 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4707
4708 // Generate umin.
4709 return getUMinExpr(PromotedOps, Sequential);
4710}
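// Illustrative sketch (not part of ScalarEvolution.cpp): operands of mixed
// widths are first promoted to the widest type, then combined. Assuming a
// ScalarEvolution &SE with an i16 expression Narrow and an i64 expression Wide:
//   const SCEV *Min = SE.getUMinFromMismatchedTypes(Narrow, Wide);
//   // behaves like:
//   const SCEV *Min2 =
//       SE.getUMinExpr(SE.getZeroExtendExpr(Narrow, Wide->getType()), Wide);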
4711
4712const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4713 // A pointer operand may evaluate to a nonpointer expression, such as null.
4714 if (!V->getType()->isPointerTy())
4715 return V;
4716
4717 while (true) {
4718 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4719 V = AddRec->getStart();
4720 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4721 const SCEV *PtrOp = nullptr;
4722 for (const SCEV *AddOp : Add->operands()) {
4723 if (AddOp->getType()->isPointerTy()) {
4724 assert(!PtrOp && "Cannot have multiple pointer ops");
4725 PtrOp = AddOp;
4726 }
4727 }
4728 assert(PtrOp && "Must have pointer op");
4729 V = PtrOp;
4730 } else // Not something we can look further into.
4731 return V;
4732 }
4733}
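// Worked example (illustrative): for a SCEV such as {(16 + %p),+,4}<%loop>,
// getPointerBase first peels the AddRec back to its start (16 + %p), then
// picks the single pointer-typed operand of that add, returning %p. A
// non-pointer expression is returned unchanged.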
4734
4735/// Push users of the given Instruction onto the given Worklist.
4736static void PushDefUseChildren(Instruction *I,
4737 SmallVectorImpl<Instruction *> &Worklist,
4738 SmallPtrSetImpl<Instruction *> &Visited) {
4739 // Push the def-use children onto the Worklist stack.
4740 for (User *U : I->users()) {
4741 auto *UserInsn = cast<Instruction>(U);
4742 if (Visited.insert(UserInsn).second)
4743 Worklist.push_back(UserInsn);
4744 }
4745}
4746
4747namespace {
4748
4749 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
4750 /// expression if its loop is L. If its loop is not L, then: if
4751 /// IgnoreOtherLoops is true, use the AddRec itself; otherwise the rewrite
4752 /// cannot be done.
4753 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4754class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4755public:
4756 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4757 bool IgnoreOtherLoops = true) {
4758 SCEVInitRewriter Rewriter(L, SE);
4759 const SCEV *Result = Rewriter.visit(S);
4760 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4761 return SE.getCouldNotCompute();
4762 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4763 ? SE.getCouldNotCompute()
4764 : Result;
4765 }
4766
4767 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4768 if (!SE.isLoopInvariant(Expr, L))
4769 SeenLoopVariantSCEVUnknown = true;
4770 return Expr;
4771 }
4772
4773 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4774 // Only re-write AddRecExprs for this loop.
4775 if (Expr->getLoop() == L)
4776 return Expr->getStart();
4777 SeenOtherLoops = true;
4778 return Expr;
4779 }
4780
4781 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4782
4783 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4784
4785private:
4786 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4787 : SCEVRewriteVisitor(SE), L(L) {}
4788
4789 const Loop *L;
4790 bool SeenLoopVariantSCEVUnknown = false;
4791 bool SeenOtherLoops = false;
4792};
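// Worked example (illustrative): for S = ({%start,+,1}<%L> + %inv) and loop L,
// SCEVInitRewriter::rewrite(S, L, SE) replaces the AddRec with its start value
// and yields (%start + %inv); if %inv were not invariant in L, the rewrite
// would instead return SCEVCouldNotCompute.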
4793
4794 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its
4795 /// post-increment expression if its loop is L. If its loop is not L,
4796 /// use the AddRec itself.
4797 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4798class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4799public:
4800 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4801 SCEVPostIncRewriter Rewriter(L, SE);
4802 const SCEV *Result = Rewriter.visit(S);
4803 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4804 ? SE.getCouldNotCompute()
4805 : Result;
4806 }
4807
4808 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4809 if (!SE.isLoopInvariant(Expr, L))
4810 SeenLoopVariantSCEVUnknown = true;
4811 return Expr;
4812 }
4813
4814 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4815 // Only re-write AddRecExprs for this loop.
4816 if (Expr->getLoop() == L)
4817 return Expr->getPostIncExpr(SE);
4818 SeenOtherLoops = true;
4819 return Expr;
4820 }
4821
4822 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4823
4824 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4825
4826private:
4827 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4828 : SCEVRewriteVisitor(SE), L(L) {}
4829
4830 const Loop *L;
4831 bool SeenLoopVariantSCEVUnknown = false;
4832 bool SeenOtherLoops = false;
4833};
4834
4835 /// This class evaluates the compare condition by matching it against the
4836 /// condition of the loop latch. If there is a match, we assume a true value
4837 /// for the condition while building SCEV nodes.
4838class SCEVBackedgeConditionFolder
4839 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4840public:
4841 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4842 ScalarEvolution &SE) {
4843 bool IsPosBECond = false;
4844 Value *BECond = nullptr;
4845 if (BasicBlock *Latch = L->getLoopLatch()) {
4846 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4847 if (BI && BI->isConditional()) {
4848 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4849 "Both outgoing branches should not target same header!");
4850 BECond = BI->getCondition();
4851 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4852 } else {
4853 return S;
4854 }
4855 }
4856 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4857 return Rewriter.visit(S);
4858 }
4859
4860 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4861 const SCEV *Result = Expr;
4862 bool InvariantF = SE.isLoopInvariant(Expr, L);
4863
4864 if (!InvariantF) {
4865 Instruction *I = cast<Instruction>(Expr->getValue());
4866 switch (I->getOpcode()) {
4867 case Instruction::Select: {
4868 SelectInst *SI = cast<SelectInst>(I);
4869 Optional<const SCEV *> Res =
4870 compareWithBackedgeCondition(SI->getCondition());
4871 if (Res) {
4872 bool IsOne = cast<SCEVConstant>(Res.value())->getValue()->isOne();
4873 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4874 }
4875 break;
4876 }
4877 default: {
4878 Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4879 if (Res)
4880 Result = Res.value();
4881 break;
4882 }
4883 }
4884 }
4885 return Result;
4886 }
4887
4888private:
4889 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4890 bool IsPosBECond, ScalarEvolution &SE)
4891 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4892 IsPositiveBECond(IsPosBECond) {}
4893
4894 Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4895
4896 const Loop *L;
4897 /// Loop back condition.
4898 Value *BackedgeCond = nullptr;
4899 /// Set to true if loop back is on positive branch condition.
4900 bool IsPositiveBECond;
4901};
4902
4903Optional<const SCEV *>
4904SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4905
4906 // If the value matches the backedge condition for the loop latch,
4907 // then return a constant evolution node based on whether the loopback
4908 // branch is taken.
4909 if (BackedgeCond == IC)
4910 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4911 : SE.getZero(Type::getInt1Ty(SE.getContext()));
4912 return None;
4913}
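// Worked example (illustrative): if the latch of L ends in
// "br i1 %cmp, label %header, label %exit", then IsPositiveBECond is true and
// compareWithBackedgeCondition(%cmp) returns the i1 constant 1, so a select
// guarded by %cmp is folded to its true arm while the backedge value's SCEV
// is being built.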
4914
4915class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4916public:
4917 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4918 ScalarEvolution &SE) {
4919 SCEVShiftRewriter Rewriter(L, SE);
4920 const SCEV *Result = Rewriter.visit(S);
4921 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4922 }
4923
4924 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4925 // Only allow AddRecExprs for this loop.
4926 if (!SE.isLoopInvariant(Expr, L))
4927 Valid = false;
4928 return Expr;
4929 }
4930
4931 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4932 if (Expr->getLoop() == L && Expr->isAffine())
4933 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4934 Valid = false;
4935 return Expr;
4936 }
4937
4938 bool isValid() { return Valid; }
4939
4940private:
4941 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4942 : SCEVRewriteVisitor(SE), L(L) {}
4943
4944 const Loop *L;
4945 bool Valid = true;
4946};
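// Worked example (illustrative): SCEVShiftRewriter maps an affine AddRec to
// its value one iteration earlier, e.g. {1,+,1}<%L> becomes
// ({1,+,1}<%L> - 1) = {0,+,1}<%L>. createAddRecFromPHI (below) uses this to
// recognize phis like "i = 0; for (j = 1; ...; ++j) { ...; i = j; }".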
4947
4948} // end anonymous namespace
4949
4950SCEV::NoWrapFlags
4951ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4952 if (!AR->isAffine())
4953 return SCEV::FlagAnyWrap;
4954
4955 using OBO = OverflowingBinaryOperator;
4956
4957 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
4958
4959 if (!AR->hasNoSignedWrap()) {
4960 ConstantRange AddRecRange = getSignedRange(AR);
4961 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4962
4963 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4964 Instruction::Add, IncRange, OBO::NoSignedWrap);
4965 if (NSWRegion.contains(AddRecRange))
4966 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
4967 }
4968
4969 if (!AR->hasNoUnsignedWrap()) {
4970 ConstantRange AddRecRange = getUnsignedRange(AR);
4971 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4972
4973 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
4974 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
4975 if (NUWRegion.contains(AddRecRange))
4976 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
4977 }
4978
4979 return Result;
4980}
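// Illustrative sketch (not part of ScalarEvolution.cpp): the range reasoning
// used above, assuming an i8 increment known to lie in [1,4).
//   ConstantRange IncRange(APInt(8, 1), APInt(8, 4));
//   ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
//       Instruction::Add, IncRange, OverflowingBinaryOperator::NoSignedWrap);
//   // NSWRegion now holds every i8 value X for which X + step cannot overflow
//   // signed arithmetic; if the AddRec's signed range is contained in it,
//   // FlagNSW is justified without looking at the loop structure at all.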
4981
4982SCEV::NoWrapFlags
4983ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
4984 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
4985
4986 if (AR->hasNoSignedWrap())
4987 return Result;
4988
4989 if (!AR->isAffine())
4990 return Result;
4991
4992 // This function can be expensive, only try to prove NSW once per AddRec.
4993 if (!SignedWrapViaInductionTried.insert(AR).second)
4994 return Result;
4995
4996 const SCEV *Step = AR->getStepRecurrence(*this);
4997 const Loop *L = AR->getLoop();
4998
4999 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5000 // Note that this serves two purposes: It filters out loops that are
5001 // simply not analyzable, and it covers the case where this code is
5002 // being called from within backedge-taken count analysis, such that
5003 // attempting to ask for the backedge-taken count would likely result
5004 // in infinite recursion. In the latter case, the analysis code will
5005 // cope with a conservative value, and it will take care to purge
5006 // that value once it has finished.
5007 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5008
5009 // Normally, in the cases we can prove no-overflow via a
5010 // backedge guarding condition, we can also compute a backedge
5011 // taken count for the loop. The exceptions are assumptions and
5012 // guards present in the loop -- SCEV is not great at exploiting
5013 // these to compute max backedge taken counts, but can still use
5014 // these to prove lack of overflow. Use this fact to avoid
5015 // doing extra work that may not pay off.
5016
5017 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5018 AC.assumptions().empty())
5019 return Result;
5020
5021 // If the backedge is guarded by a comparison with the pre-inc value the
5022 // addrec is safe. Also, if the entry is guarded by a comparison with the
5023 // start value and the backedge is guarded by a comparison with the post-inc
5024 // value, the addrec is safe.
5025 ICmpInst::Predicate Pred;
5026 const SCEV *OverflowLimit =
5027 getSignedOverflowLimitForStep(Step, &Pred, this);
5028 if (OverflowLimit &&
5029 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5030 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5031 Result = setFlags(Result, SCEV::FlagNSW);
5032 }
5033 return Result;
5034}
5035SCEV::NoWrapFlags
5036ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5037 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5038
5039 if (AR->hasNoUnsignedWrap())
5040 return Result;
5041
5042 if (!AR->isAffine())
5043 return Result;
5044
5045 // This function can be expensive, only try to prove NUW once per AddRec.
5046 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5047 return Result;
5048
5049 const SCEV *Step = AR->getStepRecurrence(*this);
5050 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5051 const Loop *L = AR->getLoop();
5052
5053 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5054 // Note that this serves two purposes: It filters out loops that are
5055 // simply not analyzable, and it covers the case where this code is
5056 // being called from within backedge-taken count analysis, such that
5057 // attempting to ask for the backedge-taken count would likely result
5058 // in infinite recursion. In the latter case, the analysis code will
5059 // cope with a conservative value, and it will take care to purge
5060 // that value once it has finished.
5061 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5062
5063 // Normally, in the cases we can prove no-overflow via a
5064 // backedge guarding condition, we can also compute a backedge
5065 // taken count for the loop. The exceptions are assumptions and
5066 // guards present in the loop -- SCEV is not great at exploiting
5067 // these to compute max backedge taken counts, but can still use
5068 // these to prove lack of overflow. Use this fact to avoid
5069 // doing extra work that may not pay off.
5070
5071 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5072 AC.assumptions().empty())
5073 return Result;
5074
5075 // If the backedge is guarded by a comparison with the pre-inc value the
5076 // addrec is safe. Also, if the entry is guarded by a comparison with the
5077 // start value and the backedge is guarded by a comparison with the post-inc
5078 // value, the addrec is safe.
5079 if (isKnownPositive(Step)) {
5080 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5081 getUnsignedRangeMax(Step));
5082 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5083 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5084 Result = setFlags(Result, SCEV::FlagNUW);
5085 }
5086 }
5087
5088 return Result;
5089}
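// Worked example (illustrative): for an i8 AddRec whose step has unsigned
// maximum 3, N = 0 - 3 = 253 (mod 2^8). If the backedge guarantees AR <u 253
// on every iteration, then AR + step <= 252 + 3 = 255, so the increment can
// never wrap unsigned and FlagNUW is justified.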
5090
5091namespace {
5092
5093/// Represents an abstract binary operation. This may exist as a
5094/// normal instruction or constant expression, or may have been
5095/// derived from an expression tree.
5096struct BinaryOp {
5097 unsigned Opcode;
5098 Value *LHS;
5099 Value *RHS;
5100 bool IsNSW = false;
5101 bool IsNUW = false;
5102
5103 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5104 /// constant expression.
5105 Operator *Op = nullptr;
5106
5107 explicit BinaryOp(Operator *Op)
5108 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5109 Op(Op) {
5110 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5111 IsNSW = OBO->hasNoSignedWrap();
5112 IsNUW = OBO->hasNoUnsignedWrap();
5113 }
5114 }
5115
5116 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5117 bool IsNUW = false)
5118 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5119};
5120
5121} // end anonymous namespace
5122
5123/// Try to map \p V into a BinaryOp, and return \c None on failure.
5124static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
5125 auto *Op = dyn_cast<Operator>(V);
8. Assuming 'V' is a 'CastReturnType'
5126 if (!Op)
8.1. 'Op' is non-null
9. Taking false branch
5127 return None;
5128
5129 // Implementation detail: all the cleverness here should happen without
5130 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5131 // SCEV expressions when possible, and we should not break that.
5132
5133 switch (Op->getOpcode()) {
10. Control jumps to 'case ExtractValue:' at line 5174
5134 case Instruction::Add:
5135 case Instruction::Sub:
5136 case Instruction::Mul:
5137 case Instruction::UDiv:
5138 case Instruction::URem:
5139 case Instruction::And:
5140 case Instruction::Or:
5141 case Instruction::AShr:
5142 case Instruction::Shl:
5143 return BinaryOp(Op);
5144
5145 case Instruction::Xor:
5146 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5147 // If the RHS of the xor is a signmask, then this is just an add.
5148 // Instcombine turns add of signmask into xor as a strength reduction step.
5149 if (RHSC->getValue().isSignMask())
5150 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5151 // Binary `xor` is a bit-wise `add`.
5152 if (V->getType()->isIntegerTy(1))
5153 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5154 return BinaryOp(Op);
5155
5156 case Instruction::LShr:
5157 // Turn a logical shift right by a constant into an unsigned divide.
5158 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5159 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5160
5161 // If the shift count is not less than the bitwidth, the result of
5162 // the shift is undefined. Don't try to analyze it, because the
5163 // resolution chosen here may differ from the resolution chosen in
5164 // other parts of the compiler.
5165 if (SA->getValue().ult(BitWidth)) {
5166 Constant *X =
5167 ConstantInt::get(SA->getContext(),
5168 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5169 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5170 }
5171 }
5172 return BinaryOp(Op);
5173
5174 case Instruction::ExtractValue: {
5175 auto *EVI = cast<ExtractValueInst>(Op);
11. 'Op' is a 'CastReturnType'
5176 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
12. Assuming the condition is false
13. Assuming the condition is false
14. Taking false branch
5177 break;
5178
5179 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5180 if (!WO)
15. Assuming 'WO' is non-null
16. Taking false branch
5181 break;
5182
5183 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5184 bool Signed = WO->isSigned();
5185 // TODO: Should add nuw/nsw flags for mul as well.
5186 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
17. Assuming 'BinOp' is not equal to Mul
18. Assuming the condition is true
19. Taking true branch
5187 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
20. Calling constructor for 'Optional<(anonymous namespace)::BinaryOp>'
24. Returning from constructor for 'Optional<(anonymous namespace)::BinaryOp>'
5188
5189 // Now that we know that all uses of the arithmetic-result component of
5190 // CI are guarded by the overflow check, we can go ahead and pretend
5191 // that the arithmetic is non-overflowing.
5192 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5193 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5194 }
5195
5196 default:
5197 break;
5198 }
5199
5200 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5201 // semantics as a Sub, return a binary sub expression.
5202 if (auto *II = dyn_cast<IntrinsicInst>(V))
5203 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5204 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5205
5206 return None;
5207}
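// Worked examples (illustrative) for the two strength-reduction cases above:
// "xor i32 %x, 0x80000000" only flips the sign bit, which equals
// "add i32 %x, 0x80000000" modulo 2^32, so it is matched as an Add; and
// "lshr i32 %x, 3" divides by 2^3, so it is matched as "udiv i32 %x, 8"
// whenever the shift amount is below the bit width.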
5208
5209/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5210/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5211/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5212/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5213/// follows one of the following patterns:
5214/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5215/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5216/// If the SCEV expression of \p Op conforms with one of the expected patterns
5217/// we return the type of the truncation operation, and indicate whether the
5218/// truncated type should be treated as signed/unsigned by setting
5219/// \p Signed to true/false, respectively.
5220static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5221 bool &Signed, ScalarEvolution &SE) {
5222 // The case where Op == SymbolicPHI (that is, with no type conversions on
5223 // the way) is handled by the regular add recurrence creating logic and
5224 // would have already been triggered in createAddRecForPHI. Reaching it here
5225 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5226 // because one of the other operands of the SCEVAddExpr updating this PHI is
5227 // not invariant).
5228 //
5229 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5230 // this case predicates that allow us to prove that Op == SymbolicPHI will
5231 // be added.
5232 if (Op == SymbolicPHI)
5233 return nullptr;
5234
5235 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5236 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5237 if (SourceBits != NewBits)
5238 return nullptr;
5239
5240 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5241 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5242 if (!SExt && !ZExt)
5243 return nullptr;
5244 const SCEVTruncateExpr *Trunc =
5245 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5246 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5247 if (!Trunc)
5248 return nullptr;
5249 const SCEV *X = Trunc->getOperand();
5250 if (X != SymbolicPHI)
5251 return nullptr;
5252 Signed = SExt != nullptr;
5253 return Trunc->getType();
5254}
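// Worked example (illustrative): with %X = phi i64 and
// Op = (sext i32 (trunc i64 %X to i32) to i64), the bit widths match
// (SourceBits == NewBits == 64), the sext/trunc pair peels back to %X itself,
// so the function returns the truncation type i32 and sets Signed to true.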
5255
5256static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5257 if (!PN->getType()->isIntegerTy())
5258 return nullptr;
5259 const Loop *L = LI.getLoopFor(PN->getParent());
5260 if (!L || L->getHeader() != PN->getParent())
5261 return nullptr;
5262 return L;
5263}
5264
5265// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5266// computation that updates the phi follows the following pattern:
5267// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5268 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5269// If so, try to see if it can be rewritten as an AddRecExpr under some
5270// Predicates. If successful, return them as a pair. Also cache the results
5271// of the analysis.
5272//
5273// Example usage scenario:
5274// Say the Rewriter is called for the following SCEV:
5275// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5276// where:
5277// %X = phi i64 (%Start, %BEValue)
5278// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5279// and call this function with %SymbolicPHI = %X.
5280//
5281// The analysis will find that the value coming around the backedge has
5282// the following SCEV:
5283// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5284// Upon concluding that this matches the desired pattern, the function
5285// will return the pair {NewAddRec, SmallPredsVec} where:
5286// NewAddRec = {%Start,+,%Step}
5287// SmallPredsVec = {P1, P2, P3} as follows:
5288// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5289// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5290// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5291// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5292// under the predicates {P1,P2,P3}.
5293// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5294// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5295//
5296// TODO's:
5297//
5298// 1) Extend the Induction descriptor to also support inductions that involve
5299// casts: When needed (namely, when we are called in the context of the
5300// vectorizer induction analysis), a Set of cast instructions will be
5301// populated by this method, and provided back to isInductionPHI. This is
5302// needed to allow the vectorizer to properly record them to be ignored by
5303// the cost model and to avoid vectorizing them (otherwise these casts,
5304// which are redundant under the runtime overflow checks, will be
5305// vectorized, which can be costly).
5306//
5307// 2) Support additional induction/PHISCEV patterns: We also want to support
5308// inductions where the sext-trunc / zext-trunc operations (partly) occur
5309// after the induction update operation (the induction increment):
5310//
5311// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5312 // which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5313//
5314// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5315 // which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5316//
5317// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5318Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5319ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5320 SmallVector<const SCEVPredicate *, 3> Predicates;
5321
5322 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5323 // return an AddRec expression under some predicate.
5324
5325 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5326 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5327 assert(L && "Expecting an integer loop header phi");
5328
5329 // The loop may have multiple entrances or multiple exits; we can analyze
5330 // this phi as an addrec if it has a unique entry value and a unique
5331 // backedge value.
5332 Value *BEValueV = nullptr, *StartValueV = nullptr;
5333 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5334 Value *V = PN->getIncomingValue(i);
5335 if (L->contains(PN->getIncomingBlock(i))) {
5336 if (!BEValueV) {
5337 BEValueV = V;
5338 } else if (BEValueV != V) {
5339 BEValueV = nullptr;
5340 break;
5341 }
5342 } else if (!StartValueV) {
5343 StartValueV = V;
5344 } else if (StartValueV != V) {
5345 StartValueV = nullptr;
5346 break;
5347 }
5348 }
5349 if (!BEValueV || !StartValueV)
5350 return None;
5351
5352 const SCEV *BEValue = getSCEV(BEValueV);
5353
5354 // If the value coming around the backedge is an add with the symbolic
5355 // value we just inserted, possibly with casts that we can ignore under
5356 // an appropriate runtime guard, then we found a simple induction variable!
5357 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5358 if (!Add)
5359 return None;
5360
5361 // If there is a single occurrence of the symbolic value, possibly
5362 // casted, replace it with a recurrence.
5363 unsigned FoundIndex = Add->getNumOperands();
5364 Type *TruncTy = nullptr;
5365 bool Signed;
5366 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5367 if ((TruncTy =
5368 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5369 if (FoundIndex == e) {
5370 FoundIndex = i;
5371 break;
5372 }
5373
5374 if (FoundIndex == Add->getNumOperands())
5375 return None;
5376
5377 // Create an add with everything but the specified operand.
5378 SmallVector<const SCEV *, 8> Ops;
5379 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5380 if (i != FoundIndex)
5381 Ops.push_back(Add->getOperand(i));
5382 const SCEV *Accum = getAddExpr(Ops);
5383
5384 // The runtime checks will not be valid if the step amount is
5385 // varying inside the loop.
5386 if (!isLoopInvariant(Accum, L))
5387 return None;
5388
5389 // *** Part2: Create the predicates
5390
5391 // Analysis was successful: we have a phi-with-cast pattern for which we
5392 // can return an AddRec expression under the following predicates:
5393 //
5394 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5395 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5396 // P2: An Equal predicate that guarantees that
5397 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5398 // P3: An Equal predicate that guarantees that
5399 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5400 //
5401 // As we next prove, the above predicates guarantee that:
5402 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5403 //
5404 //
5405 // More formally, we want to prove that:
5406 // Expr(i+1) = Start + (i+1) * Accum
5407 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5408 //
5409 // Given that:
5410 // 1) Expr(0) = Start
5411 // 2) Expr(1) = Start + Accum
5412 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5413 // 3) Induction hypothesis (step i):
5414 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5415 //
5416 // Proof:
5417 // Expr(i+1) =
5418 // = Start + (i+1)*Accum
5419 // = (Start + i*Accum) + Accum
5420 // = Expr(i) + Accum
5421 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5422 // :: from step i
5423 //
5424 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5425 //
5426 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5427 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5428 // + Accum :: from P3
5429 //
5430 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5431 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5432 //
5433 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5434 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5435 //
5436 // By induction, the same applies to all iterations 1<=i<n:
5437 //
5438
5439 // Create a truncated addrec for which we will add a no overflow check (P1).
5440 const SCEV *StartVal = getSCEV(StartValueV);
5441 const SCEV *PHISCEV =
5442 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5443 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5444
5445 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5446 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5447 // will be constant.
5448 //
5449 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5450 // add P1.
5451 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5452 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5453 Signed ? SCEVWrapPredicate::IncrementNSSW
5454 : SCEVWrapPredicate::IncrementNUSW;
5455 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5456 Predicates.push_back(AddRecPred);
5457 }
5458
5459 // Create the Equal Predicates P2,P3:
5460
5461 // It is possible that the predicates P2 and/or P3 are computable at
5462 // compile time due to StartVal and/or Accum being constants.
5463 // If either one is, then we can check that now and escape if either P2
5464 // or P3 is false.
5465
5466 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5467 // for each of StartVal and Accum
5468 auto getExtendedExpr = [&](const SCEV *Expr,
5469 bool CreateSignExtend) -> const SCEV * {
5470 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5471 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5472 const SCEV *ExtendedExpr =
5473 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5474 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5475 return ExtendedExpr;
5476 };
5477
5478 // Given:
5479 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5480 // = getExtendedExpr(Expr)
5481 // Determine whether the predicate P: Expr == ExtendedExpr
5482 // is known to be false at compile time
5483 auto PredIsKnownFalse = [&](const SCEV *Expr,
5484 const SCEV *ExtendedExpr) -> bool {
5485 return Expr != ExtendedExpr &&
5486 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5487 };
5488
5489 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5490 if (PredIsKnownFalse(StartVal, StartExtended)) {
5491 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5492 return None;
5493 }
5494
5495 // The Step is always Signed (because the overflow checks are either
5496 // NSSW or NUSW)
5497 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5498 if (PredIsKnownFalse(Accum, AccumExtended)) {
5499 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5500 return None;
5501 }
5502
5503 auto AppendPredicate = [&](const SCEV *Expr,
5504 const SCEV *ExtendedExpr) -> void {
5505 if (Expr != ExtendedExpr &&
5506 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5507 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5508 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5509 Predicates.push_back(Pred);
5510 }
5511 };
5512
5513 AppendPredicate(StartVal, StartExtended);
5514 AppendPredicate(Accum, AccumExtended);
5515
5516 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5517 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5518 // into NewAR if it will also add the runtime overflow checks specified in
5519 // Predicates.
5520 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5521
5522 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5523 std::make_pair(NewAR, Predicates);
5524 // Remember the result of the analysis for this SCEV at this location.
5525 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5526 return PredRewrite;
5527}
5528
5529Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5530ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5531 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5532 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5533 if (!L)
5534 return None;
5535
5536 // Check to see if we already analyzed this PHI.
5537 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5538 if (I != PredicatedSCEVRewrites.end()) {
5539 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5540 I->second;
5541 // Analysis was done before and failed to create an AddRec:
5542 if (Rewrite.first == SymbolicPHI)
5543 return None;
5544 // Analysis was done before and succeeded to create an AddRec under
5545 // a predicate:
5546 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5547 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5548 return Rewrite;
5549 }
5550
5551 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5552 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5553
5554 // Record in the cache that the analysis failed
5555 if (!Rewrite) {
5556 SmallVector<const SCEVPredicate *, 3> Predicates;
5557 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5558 return None;
5559 }
5560
5561 return Rewrite;
5562}
5563
5564// FIXME: This utility is currently required because the Rewriter currently
5565// does not rewrite this expression:
5566// {0, +, (sext ix (trunc iy to ix) to iy)}
5567// into {0, +, %step},
5568// even when the following Equal predicate exists:
5569// "%step == (sext ix (trunc iy to ix) to iy)".
5570bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5571 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5572 if (AR1 == AR2)
5573 return true;
5574
5575 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5576 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5577 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5578 return false;
5579 return true;
5580 };
5581
5582 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5583 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5584 return false;
5585 return true;
5586}
5587
5588/// A helper function for createAddRecFromPHI to handle simple cases.
5589///
5590/// This function tries to find an AddRec expression for the simplest (yet most
5591/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5592/// If it fails, createAddRecFromPHI will use a more general, but slow,
5593/// technique for finding the AddRec expression.
5594const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5595 Value *BEValueV,
5596 Value *StartValueV) {
5597 const Loop *L = LI.getLoopFor(PN->getParent());
5598 assert(L && L->getHeader() == PN->getParent());
1. Assuming 'L' is non-null
2. Assuming the condition is true
3. '?' condition is true
5599 assert(BEValueV && StartValueV);
4. Assuming 'BEValueV' is non-null
5. Assuming the condition is true
6. '?' condition is true
5600
5601 auto BO = MatchBinaryOp(BEValueV, DT);
7. Calling 'MatchBinaryOp'
25. Returning from 'MatchBinaryOp'
5602 if (!BO)
26. Taking false branch
5603 return nullptr;
5604
5605 if (BO->Opcode != Instruction::Add)
27. Assuming field 'Opcode' is equal to Add
28. Taking false branch
5606 return nullptr;
5607
5608 const SCEV *Accum = nullptr;
5609 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
28.1. 'PN' is not equal to field 'LHS'
5610 Accum = getSCEV(BO->RHS);
5611 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
29. Assuming 'PN' is equal to field 'RHS'
30. Assuming the condition is true
31. Taking true branch
5612 Accum = getSCEV(BO->LHS);
32. Passing null pointer value via 1st parameter 'V'
33. Calling 'ScalarEvolution::getSCEV'
5613
5614 if (!Accum)
5615 return nullptr;
5616
5617 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5618 if (BO->IsNUW)
5619 Flags = setFlags(Flags, SCEV::FlagNUW);
5620 if (BO->IsNSW)
5621 Flags = setFlags(Flags, SCEV::FlagNSW);
5622
5623 const SCEV *StartVal = getSCEV(StartValueV);
5624 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5625 insertValueToMap(PN, PHISCEV);
5626
5627 // We can add Flags to the post-inc expression only if we
5628 // know that it is *undefined behavior* for BEValueV to
5629 // overflow.
5630 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5631 assert(isLoopInvariant(Accum, L) &&
5632 "Accum is defined outside L, but is not invariant?");
5633 if (isAddRecNeverPoison(BEInst, L))
5634 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5635 }
5636
5637 return PHISCEV;
5638}
5639
5640const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5641 const Loop *L = LI.getLoopFor(PN->getParent());
5642 if (!L || L->getHeader() != PN->getParent())
5643 return nullptr;
5644
5645 // The loop may have multiple entrances or multiple exits; we can analyze
5646 // this phi as an addrec if it has a unique entry value and a unique
5647 // backedge value.
5648 Value *BEValueV = nullptr, *StartValueV = nullptr;
5649 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5650 Value *V = PN->getIncomingValue(i);
5651 if (L->contains(PN->getIncomingBlock(i))) {
5652 if (!BEValueV) {
5653 BEValueV = V;
5654 } else if (BEValueV != V) {
5655 BEValueV = nullptr;
5656 break;
5657 }
5658 } else if (!StartValueV) {
5659 StartValueV = V;
5660 } else if (StartValueV != V) {
5661 StartValueV = nullptr;
5662 break;
5663 }
5664 }
5665 if (!BEValueV || !StartValueV)
5666 return nullptr;
5667
5668 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5669 "PHI node already processed?");
5670
5671 // First, try to find an AddRec expression without creating a fictitious
5672 // symbolic value for PN.
5673 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5674 return S;
5675
5676 // Handle PHI node value symbolically.
5677 const SCEV *SymbolicName = getUnknown(PN);
5678 insertValueToMap(PN, SymbolicName);
5679
5680 // Using this symbolic name for the PHI, analyze the value coming around
5681 // the back-edge.
5682 const SCEV *BEValue = getSCEV(BEValueV);
5683
5684 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5685 // has a special value for the first iteration of the loop.
5686
5687 // If the value coming around the backedge is an add with the symbolic
5688 // value we just inserted, then we found a simple induction variable!
5689 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5690 // If there is a single occurrence of the symbolic value, replace it
5691 // with a recurrence.
5692 unsigned FoundIndex = Add->getNumOperands();
5693 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5694 if (Add->getOperand(i) == SymbolicName)
5695 if (FoundIndex == e) {
5696 FoundIndex = i;
5697 break;
5698 }
5699
5700 if (FoundIndex != Add->getNumOperands()) {
5701 // Create an add with everything but the specified operand.
5702 SmallVector<const SCEV *, 8> Ops;
5703 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5704 if (i != FoundIndex)
5705 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5706 L, *this));
5707 const SCEV *Accum = getAddExpr(Ops);
5708
5709 // This is not a valid addrec if the step amount is varying each
5710 // loop iteration, but is not itself an addrec in this loop.
5711 if (isLoopInvariant(Accum, L) ||
5712 (isa<SCEVAddRecExpr>(Accum) &&
5713 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5714 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5715
5716 if (auto BO = MatchBinaryOp(BEValueV, DT)) {
5717 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5718 if (BO->IsNUW)
5719 Flags = setFlags(Flags, SCEV::FlagNUW);
5720 if (BO->IsNSW)
5721 Flags = setFlags(Flags, SCEV::FlagNSW);
5722 }
5723 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5724 // If the increment is an inbounds GEP, then we know the address
5725 // space cannot be wrapped around. We cannot make any guarantee
5726 // about signed or unsigned overflow because pointers are
5727 // unsigned but we may have a negative index from the base
5728 // pointer. We can guarantee that no unsigned wrap occurs if the
5729 // indices form a positive value.
5730 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5731 Flags = setFlags(Flags, SCEV::FlagNW);
5732
5733 const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
5734 if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
5735 Flags = setFlags(Flags, SCEV::FlagNUW);
5736 }
5737
5738 // We cannot transfer nuw and nsw flags from subtraction
5739 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5740 // for instance.
5741 }
5742
5743 const SCEV *StartVal = getSCEV(StartValueV);
5744 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5745
5746 // Okay, for the entire analysis of this edge we assumed the PHI
5747 // to be symbolic. We now need to go back and purge all of the
5748 // entries for the scalars that use the symbolic expression.
5749 forgetMemoizedResults(SymbolicName);
5750 insertValueToMap(PN, PHISCEV);
5751
5752 // We can add Flags to the post-inc expression only if we
5753 // know that it is *undefined behavior* for BEValueV to
5754 // overflow.
5755 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5756 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5757 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5758
5759 return PHISCEV;
5760 }
5761 }
5762 } else {
5763 // Otherwise, this could be a loop like this:
5764 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5765 // In this case, j = {1,+,1} and BEValue is j.
5766 // Because the other in-value of i (0) fits the evolution of BEValue,
5767 // i really is an addrec evolution.
5768 //
5769 // We can generalize this saying that i is the shifted value of BEValue
5770 // by one iteration:
5771 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5772 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5773 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5774 if (Shifted != getCouldNotCompute() &&
5775 Start != getCouldNotCompute()) {
5776 const SCEV *StartVal = getSCEV(StartValueV);
5777 if (Start == StartVal) {
5778 // Okay, for the entire analysis of this edge we assumed the PHI
5779 // to be symbolic. We now need to go back and purge all of the
5780 // entries for the scalars that use the symbolic expression.
5781 forgetMemoizedResults(SymbolicName);
5782 insertValueToMap(PN, Shifted);
5783 return Shifted;
5784 }
5785 }
5786 }
5787
5788 // Remove the temporary PHI node SCEV that has been inserted while intending
5789 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5790 // as it will prevent later (possibly simpler) SCEV expressions from being added
5791 // to the ValueExprMap.
5792 eraseValueFromMap(PN);
5793
5794 return nullptr;
5795}
5796
5797// Checks if the SCEV S is available at BB. S is considered available at BB
5798// if S can be materialized at BB without introducing a fault.
5799static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
5800 BasicBlock *BB) {
5801 struct CheckAvailable {
5802 bool TraversalDone = false;
5803 bool Available = true;
5804
5805 const Loop *L = nullptr; // The loop BB is in (can be nullptr)
5806 BasicBlock *BB = nullptr;
5807 DominatorTree &DT;
5808
5809 CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
5810 : L(L), BB(BB), DT(DT) {}
5811
5812 bool setUnavailable() {
5813 TraversalDone = true;
5814 Available = false;
5815 return false;
5816 }
5817
5818 bool follow(const SCEV *S) {
5819 switch (S->getSCEVType()) {
5820 case scConstant:
5821 case scPtrToInt:
5822 case scTruncate:
5823 case scZeroExtend:
5824 case scSignExtend:
5825 case scAddExpr:
5826 case scMulExpr:
5827 case scUMaxExpr:
5828 case scSMaxExpr:
5829 case scUMinExpr:
5830 case scSMinExpr:
5831 case scSequentialUMinExpr:
5832 // These expressions are available if their operand(s) is/are.
5833 return true;
5834
5835 case scAddRecExpr: {
5836 // We allow add recurrences on the loop that BB is in, or on some
5837 // outer loop. This guarantees availability because the value of the
5838 // add recurrence at BB is simply the "current" value of the induction
5839 // variable. We can relax this in the future; for instance an add
5840 // recurrence on a sibling dominating loop is also available at BB.
5841 const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
5842 if (L && (ARLoop == L || ARLoop->contains(L)))
5843 return true;
5844
5845 return setUnavailable();
5846 }
5847
5848 case scUnknown: {
5849 // For SCEVUnknown, we check for simple dominance.
5850 const auto *SU = cast<SCEVUnknown>(S);
5851 Value *V = SU->getValue();
5852
5853 if (isa<Argument>(V))
5854 return false;
5855
5856 if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
5857 return false;
5858
5859 return setUnavailable();
5860 }
5861
5862 case scUDivExpr:
5863 case scCouldNotCompute:
5864 // We do not try to be smart about these at all.
5865 return setUnavailable();
5866 }
5867 llvm_unreachable("Unknown SCEV kind!");
5868 }
5869
5870 bool isDone() { return TraversalDone; }
5871 };
5872
5873 CheckAvailable CA(L, BB, DT);
5874 SCEVTraversal<CheckAvailable> ST(CA);
5875
5876 ST.visitAll(S);
5877 return CA.Available;
5878}
5879
5880// Try to match a control flow sequence that branches out at BI and merges back
5881// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5882// match.
5883static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5884 Value *&C, Value *&LHS, Value *&RHS) {
5885 C = BI->getCondition();
5886
5887 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5888 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5889
5890 if (!LeftEdge.isSingleEdge())
5891 return false;
5892
5893 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5894
5895 Use &LeftUse = Merge->getOperandUse(0);
5896 Use &RightUse = Merge->getOperandUse(1);
5897
5898 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5899 LHS = LeftUse;
5900 RHS = RightUse;
5901 return true;
5902 }
5903
5904 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5905 LHS = RightUse;
5906 RHS = LeftUse;
5907 return true;
5908 }
5909
5910 return false;
5911}
5912
5913const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5914 auto IsReachable =
5915 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
5916 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
5917 const Loop *L = LI.getLoopFor(PN->getParent());
5918
5919 // We don't want to break LCSSA, even in a SCEV expression tree.
5920 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
5921 if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
5922 return nullptr;
5923
5924 // Try to match
5925 //
5926 // br %cond, label %left, label %right
5927 // left:
5928 // br label %merge
5929 // right:
5930 // br label %merge
5931 // merge:
5932 // V = phi [ %x, %left ], [ %y, %right ]
5933 //
5934 // as "select %cond, %x, %y"
5935
5936 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
5937 assert(IDom && "At least the entry block should dominate PN");
5938
5939 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
5940 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
5941
5942 if (BI && BI->isConditional() &&
5943 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
5944 IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
5945 IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
5946 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
5947 }
5948
5949 return nullptr;
5950}
5951
5952const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
5953 if (const SCEV *S = createAddRecFromPHI(PN))
5954 return S;
5955
5956 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
5957 return S;
5958
5959 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
5960 return getSCEV(V);
5961
5962 // If it's not a loop phi, we can't handle it yet.
5963 return getUnknown(PN);
5964}
5965
5966bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
5967 SCEVTypes RootKind) {
5968 struct FindClosure {
5969 const SCEV *OperandToFind;
5970 const SCEVTypes RootKind; // Must be a sequential min/max expression.
5971 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
5972
5973 bool Found = false;
5974
5975 bool canRecurseInto(SCEVTypes Kind) const {
5976 // We can only recurse into the SCEV expression of the same effective type
5977 // as the type of our root SCEV expression, and into zero-extensions.
5978 return RootKind == Kind || NonSequentialRootKind == Kind ||
5979 scZeroExtend == Kind;
5980 };
5981
5982 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
5983 : OperandToFind(OperandToFind), RootKind(RootKind),
5984 NonSequentialRootKind(
5985 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
5986 RootKind)) {}
5987
5988 bool follow(const SCEV *S) {
5989 Found = S == OperandToFind;
5990
5991 return !isDone() && canRecurseInto(S->getSCEVType());
5992 }
5993
5994 bool isDone() const { return Found; }
5995 };
5996
5997 FindClosure FC(OperandToFind, RootKind);
5998 visitAll(Root, FC);
5999 return FC.Found;
6000}
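// A minimal standalone sketch of the follow()/isDone() protocol the closure
// above relies on, using a hypothetical binary-tree node instead of a SCEV
// (ExampleNode, exampleVisitAll and ExampleFinder are illustrative names only,
// and the kind-based canRecurseInto filtering is not modelled here): follow()
// records a match and says whether to descend further, and isDone() lets the
// driver cut the walk short as soon as the operand has been found.
namespace {
struct ExampleNode {
  int Value;
  const ExampleNode *Left = nullptr;
  const ExampleNode *Right = nullptr;
};
template <typename Visitor>
void exampleVisitAll(const ExampleNode *N, Visitor &V) {
  if (!N || V.isDone() || !V.follow(N))
    return;
  exampleVisitAll(N->Left, V);
  exampleVisitAll(N->Right, V);
}
struct ExampleFinder {
  int Target;
  bool Found = false;
  bool follow(const ExampleNode *N) {
    Found |= (N->Value == Target);
    return !isDone();
  }
  bool isDone() const { return Found; }
};
} // namespace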
6001
6002const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(
6003 Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) {
6004 // Try to match some simple smax or umax patterns.
6005 auto *ICI = Cond;
6006
6007 Value *LHS = ICI->getOperand(0);
6008 Value *RHS = ICI->getOperand(1);
6009
6010 switch (ICI->getPredicate()) {
6011 case ICmpInst::ICMP_SLT:
6012 case ICmpInst::ICMP_SLE:
6013 case ICmpInst::ICMP_ULT:
6014 case ICmpInst::ICMP_ULE:
6015 std::swap(LHS, RHS);
6016 [[fallthrough]];
6017 case ICmpInst::ICMP_SGT:
6018 case ICmpInst::ICMP_SGE:
6019 case ICmpInst::ICMP_UGT:
6020 case ICmpInst::ICMP_UGE:
6021 // a > b ? a+x : b+x -> max(a, b)+x
6022 // a > b ? b+x : a+x -> min(a, b)+x
6023 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
6024 bool Signed = ICI->isSigned();
6025 const SCEV *LA = getSCEV(TrueVal);
6026 const SCEV *RA = getSCEV(FalseVal);
6027 const SCEV *LS = getSCEV(LHS);
6028 const SCEV *RS = getSCEV(RHS);
6029 if (LA->getType()->isPointerTy()) {
6030 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6031 // Need to make sure we can't produce weird expressions involving
6032 // negated pointers.
6033 if (LA == LS && RA == RS)
6034 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6035 if (LA == RS && RA == LS)
6036 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6037 }
6038 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6039 if (Op->getType()->isPointerTy()) {
6040 Op = getLosslessPtrToIntExpr(Op);
6041 if (isa<SCEVCouldNotCompute>(Op))
6042 return Op;
6043 }
6044 if (Signed)
6045 Op = getNoopOrSignExtend(Op, I->getType());
6046 else
6047 Op = getNoopOrZeroExtend(Op, I->getType());
6048 return Op;
6049 };
6050 LS = CoerceOperand(LS);
6051 RS = CoerceOperand(RS);
6052 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6053 break;
6054 const SCEV *LDiff = getMinusSCEV(LA, LS);
6055 const SCEV *RDiff = getMinusSCEV(RA, RS);
6056 if (LDiff == RDiff)
6057 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6058 LDiff);
6059 LDiff = getMinusSCEV(LA, RS);
6060 RDiff = getMinusSCEV(RA, LS);
6061 if (LDiff == RDiff)
6062 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6063 LDiff);
6064 }
6065 break;
6066 case ICmpInst::ICMP_NE:
6067 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6068 std::swap(TrueVal, FalseVal);
6069 [[fallthrough]];
6070 case ICmpInst::ICMP_EQ:
6071 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6072 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
6073 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6074 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
6075 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6076 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6077 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6078 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6079 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6080 return getAddExpr(getUMaxExpr(X, C), Y);
6081 }
6082 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6083 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6084 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6085 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6086 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6087 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6088 const SCEV *X = getSCEV(LHS);
6089 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6090 X = ZExt->getOperand();
6091 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(I->getType())) {
6092 const SCEV *FalseValExpr = getSCEV(FalseVal);
6093 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6094 return getUMinExpr(getNoopOrZeroExtend(X, I->getType()), FalseValExpr,
6095 /*Sequential=*/true);
6096 }
6097 }
6098 break;
6099 default:
6100 break;
6101 }
6102
6103 return getUnknown(I);
6104}
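// Worked example for the min/max rewrite above, checked on sample constants
// (a = 7, b = 3, x = 5 are illustrative values only): "a > b ? a+x : b+x"
// equals max(a, b) + x, and "a > b ? b+x : a+x" equals min(a, b) + x.
static_assert((7 > 3 ? 7 + 5 : 3 + 5) == (7 > 3 ? 7 : 3) + 5,
              "a > b ? a+x : b+x  ==  max(a, b) + x  for a=7, b=3, x=5");
static_assert((7 > 3 ? 3 + 5 : 7 + 5) == (7 > 3 ? 3 : 7) + 5,
              "a > b ? b+x : a+x  ==  min(a, b) + x  for a=7, b=3, x=5");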
6105
6106static Optional<const SCEV *>
6107createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6108 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6109   assert(CondExpr->getType()->isIntegerTy(1) &&
6110          TrueExpr->getType() == FalseExpr->getType() &&
6111          TrueExpr->getType()->isIntegerTy(1) &&
6112          "Unexpected operands of a select.");
6113
6114 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6115 // --> C + (umin_seq cond, x - C)
6116 //
6117 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6118 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6119 // --> C + (umin_seq ~cond, x - C)
6120
6121 // FIXME: while we can't legally model the case where both of the hands
6122 // are fully variable, we only require that the *difference* is constant.
6123 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6124 return None;
6125
6126 const SCEV *X, *C;
6127 if (isa<SCEVConstant>(TrueExpr)) {
6128 CondExpr = SE->getNotSCEV(CondExpr);
6129 X = FalseExpr;
6130 C = TrueExpr;
6131 } else {
6132 X = TrueExpr;
6133 C = FalseExpr;
6134 }
6135 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6136 /*Sequential=*/true));
6137}
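// A minimal compile-time check of the i1 identity rewritten above, under the
// simplifying assumption that 0/1 integers stand in for i1 values, so "+" and
// "-" are XOR and the umin of two i1 values is a logical AND (the
// poison-blocking property that makes the *sequential* form necessary is not
// modelled here; selectViaUMinSeqExampleHolds is an illustrative name):
//   cond ? x : C  ==  C + umin(cond, x - C)   over i1.
constexpr bool selectViaUMinSeqExampleHolds() {
  for (int Cond = 0; Cond <= 1; ++Cond)
    for (int X = 0; X <= 1; ++X)
      for (int C = 0; C <= 1; ++C)
        if ((Cond ? X : C) != (C ^ (Cond & (X ^ C))))
          return false;
  return true;
}
static_assert(selectViaUMinSeqExampleHolds(),
              "i1 select folds to C + umin_seq(cond, x - C)");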
6138
6139static Optional<const SCEV *> createNodeForSelectViaUMinSeq(ScalarEvolution *SE,
6140 Value *Cond,
6141 Value *TrueVal,
6142 Value *FalseVal) {
6143 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6144 return None;
6145
6146 const auto *SECond = SE->getSCEV(Cond);
6147 const auto *SETrue = SE->getSCEV(TrueVal);
6148 const auto *SEFalse = SE->getSCEV(FalseVal);
6149 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6150}
6151
6152const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6153 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6154   assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6155   assert(TrueVal->getType() == FalseVal->getType() &&
6156          V->getType() == TrueVal->getType() &&
6157          "Types of select hands and of the result must match.");
6158
6159 // For now, only deal with i1-typed `select`s.
6160 if (!V->getType()->isIntegerTy(1))
6161 return getUnknown(V);
6162
6163 if (Optional<const SCEV *> S =
6164 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6165 return *S;
6166
6167 return getUnknown(V);
6168}
6169
6170const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6171 Value *TrueVal,
6172 Value *FalseVal) {
6173 // Handle "constant" branch or select. This can occur for instance when a
6174 // loop pass transforms an inner loop and moves on to process the outer loop.
6175 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6176 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6177
6178 if (auto *I = dyn_cast<Instruction>(V)) {
6179 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6180 const SCEV *S = createNodeForSelectOrPHIInstWithICmpInstCond(
6181 I, ICI, TrueVal, FalseVal);
6182 if (!isa<SCEVUnknown>(S))
6183 return S;
6184 }
6185 }
6186
6187 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6188}
6189
6190/// Expand GEP instructions into add and multiply operations. This allows them
6191/// to be analyzed by regular SCEV code.
6192const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6193   assert(GEP->getSourceElementType()->isSized() &&
6194          "GEP source element type must be sized");
6195
6196 SmallVector<const SCEV *, 4> IndexExprs;
6197 for (Value *Index : GEP->indices())
6198 IndexExprs.push_back(getSCEV(Index));
6199 return getGEPExpr(GEP, IndexExprs);
6200}
6201
6202uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
6203 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6204 return C->getAPInt().countTrailingZeros();
6205
6206 if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
6207 return GetMinTrailingZeros(I->getOperand());
6208
6209 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
6210 return std::min(GetMinTrailingZeros(T->getOperand()),
6211 (uint32_t)getTypeSizeInBits(T->getType()));
6212
6213 if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
6214 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6215 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6216 ? getTypeSizeInBits(E->getType())
6217 : OpRes;
6218 }
6219
6220 if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
6221 uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
6222 return OpRes == getTypeSizeInBits(E->getOperand()->getType())
6223 ? getTypeSizeInBits(E->getType())
6224 : OpRes;
6225 }
6226
6227 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
6228 // The result is the min of all operands results.
6229 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6230 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6231 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6232 return MinOpRes;
6233 }
6234
6235 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
6236 // The result is the sum of all operands results.
6237 uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
6238 uint32_t BitWidth = getTypeSizeInBits(M->getType());
6239 for (unsigned i = 1, e = M->getNumOperands();
6240 SumOpRes != BitWidth && i != e; ++i)
6241 SumOpRes =
6242 std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
6243 return SumOpRes;
6244 }
6245
6246 if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
6247 // The result is the min of all operands results.
6248 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
6249 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
6250 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
6251 return MinOpRes;
6252 }
6253
6254 if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
6255 // The result is the min of all operands results.
6256 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6257 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6258 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6259 return MinOpRes;
6260 }
6261
6262 if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
6263 // The result is the min of all operands results.
6264 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
6265 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
6266 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
6267 return MinOpRes;
6268 }
6269
6270 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6271 // For a SCEVUnknown, ask ValueTracking.
6272 KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
6273 return Known.countMinTrailingZeros();
6274 }
6275
6276 // SCEVUDivExpr
6277 return 0;
6278}
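// Worked example for the add and mul cases above, on the sample constants
// 8 (three trailing zero bits) and 12 (two trailing zero bits): a sum is
// divisible by the smaller power of two, min(3, 2) = 2, while a product is
// divisible by the product of both powers, 3 + 2 = 5 trailing zero bits.
static_assert((8 + 12) % (1 << 2) == 0, "8 + 12 = 20 keeps min(3, 2) = 2 trailing zeros");
static_assert((8 * 12) % (1 << 5) == 0, "8 * 12 = 96 keeps 3 + 2 = 5 trailing zeros");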
6279
6280uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
6281 auto I = MinTrailingZerosCache.find(S);
6282 if (I != MinTrailingZerosCache.end())
6283 return I->second;
6284
6285 uint32_t Result = GetMinTrailingZerosImpl(S);
6286 auto InsertPair = MinTrailingZerosCache.insert({S, Result});
6287   assert(InsertPair.second && "Should insert a new key");
6288 return InsertPair.first->second;
6289}
6290
6291/// Helper method to assign a range to V from metadata present in the IR.
6292static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6293 if (Instruction *I = dyn_cast<Instruction>(V))
6294 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6295 return getConstantRangeFromMetadata(*MD);
6296
6297 return None;
6298}
6299
6300void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6301 SCEV::NoWrapFlags Flags) {
6302 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6303 AddRec->setNoWrapFlags(Flags);
6304 UnsignedRanges.erase(AddRec);
6305 SignedRanges.erase(AddRec);
6306 }
6307}
6308
6309ConstantRange ScalarEvolution::
6310getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6311 const DataLayout &DL = getDataLayout();
6312
6313 unsigned BitWidth = getTypeSizeInBits(U->getType());
6314 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6315
6316 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6317 // use information about the trip count to improve our available range. Note
6318 // that the trip count independent cases are already handled by known bits.
6319 // WARNING: The definition of recurrence used here is subtly different than
6320 // the one used by AddRec (and thus most of this file). Step is allowed to
6321 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6322 // and other addrecs in the same loop (for non-affine addrecs). The code
6323 // below intentionally handles the case where step is not loop invariant.
6324 auto *P = dyn_cast<PHINode>(U->getValue());
6325 if (!P)
6326 return FullSet;
6327
6328 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6329 // even the values that are not available in these blocks may come from them,
6330   // and this leads to a false-positive recurrence test.
6331 for (auto *Pred : predecessors(P->getParent()))
6332 if (!DT.isReachableFromEntry(Pred))
6333 return FullSet;
6334
6335 BinaryOperator *BO;
6336 Value *Start, *Step;
6337 if (!matchSimpleRecurrence(P, BO, Start, Step))
6338 return FullSet;
6339
6340 // If we found a recurrence in reachable code, we must be in a loop. Note
6341 // that BO might be in some subloop of L, and that's completely okay.
6342 auto *L = LI.getLoopFor(P->getParent());
6343   assert(L && L->getHeader() == P->getParent());
6344 if (!L->contains(BO->getParent()))
6345 // NOTE: This bailout should be an assert instead. However, asserting
6346 // the condition here exposes a case where LoopFusion is querying SCEV
6347     // with malformed loop information in the midst of the transform.
6348 // There doesn't appear to be an obvious fix, so for the moment bailout
6349 // until the caller issue can be fixed. PR49566 tracks the bug.
6350 return FullSet;
6351
6352 // TODO: Extend to other opcodes such as mul, and div
6353 switch (BO->getOpcode()) {
6354 default:
6355 return FullSet;
6356 case Instruction::AShr:
6357 case Instruction::LShr:
6358 case Instruction::Shl:
6359 break;
6360 };
6361
6362 if (BO->getOperand(0) != P)
6363 // TODO: Handle the power function forms some day.
6364 return FullSet;
6365
6366 unsigned TC = getSmallConstantMaxTripCount(L);
6367 if (!TC || TC >= BitWidth)
6368 return FullSet;
6369
6370 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6371 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6372   assert(KnownStart.getBitWidth() == BitWidth &&
6373          KnownStep.getBitWidth() == BitWidth);
6374
6375 // Compute total shift amount, being careful of overflow and bitwidths.
6376 auto MaxShiftAmt = KnownStep.getMaxValue();
6377 APInt TCAP(BitWidth, TC-1);
6378 bool Overflow = false;
6379 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6380 if (Overflow)
6381 return FullSet;
6382
6383 switch (BO->getOpcode()) {
6384 default:
6385     llvm_unreachable("filtered out above");
6386 case Instruction::AShr: {
6387 // For each ashr, three cases:
6388 // shift = 0 => unchanged value
6389 // saturation => 0 or -1
6390 // other => a value closer to zero (of the same sign)
6391 // Thus, the end value is closer to zero than the start.
6392 auto KnownEnd = KnownBits::ashr(KnownStart,
6393 KnownBits::makeConstant(TotalShift));
6394 if (KnownStart.isNonNegative())
6395 // Analogous to lshr (simply not yet canonicalized)
6396 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6397 KnownStart.getMaxValue() + 1);
6398 if (KnownStart.isNegative())
6399 // End >=u Start && End <=s Start
6400 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6401 KnownEnd.getMaxValue() + 1);
6402 break;
6403 }
6404 case Instruction::LShr: {
6405 // For each lshr, three cases:
6406 // shift = 0 => unchanged value
6407 // saturation => 0
6408 // other => a smaller positive number
6409 // Thus, the low end of the unsigned range is the last value produced.
6410 auto KnownEnd = KnownBits::lshr(KnownStart,
6411 KnownBits::makeConstant(TotalShift));
6412 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6413 KnownStart.getMaxValue() + 1);
6414 }
6415 case Instruction::Shl: {
6416 // Iff no bits are shifted out, value increases on every shift.
6417 auto KnownEnd = KnownBits::shl(KnownStart,
6418 KnownBits::makeConstant(TotalShift));
6419 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6420 return ConstantRange(KnownStart.getMinValue(),
6421 KnownEnd.getMaxValue() + 1);
6422 break;
6423 }
6424 };
6425 return FullSet;
6426}
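// Worked example for the LShr case above, on illustrative numbers and with
// plain unsigned arithmetic standing in for KnownBits: if the start value is
// known to lie in [32, 64], the shift amount is 1 per iteration and the loop
// runs at most 3 more iterations, the total shift is at most 3, so every
// value of the recurrence lies in [32 >> 3, 64 + 1) = [4, 65).
constexpr bool lshrRecurrenceExampleHolds() {
  for (unsigned Start = 32; Start <= 64; ++Start)
    for (unsigned Iters = 0; Iters <= 3; ++Iters)
      if ((Start >> Iters) < 4 || (Start >> Iters) >= 65)
        return false;
  return true;
}
static_assert(lshrRecurrenceExampleHolds(),
              "the shift recurrence stays within [4, 65)");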
6427
6428/// Determine the range for a particular SCEV. If SignHint is
6429/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6430/// with a "cleaner" unsigned (resp. signed) representation.
6431const ConstantRange &
6432ScalarEvolution::getRangeRef(const SCEV *S,
6433 ScalarEvolution::RangeSignHint SignHint) {
6434 DenseMap<const SCEV *, ConstantRange> &Cache =
6435 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6436 : SignedRanges;
6437 ConstantRange::PreferredRangeType RangeType =
6438 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED
6439 ? ConstantRange::Unsigned : ConstantRange::Signed;
6440
6441 // See if we've computed this range already.
6442 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6443 if (I != Cache.end())
6444 return I->second;
6445
6446 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6447 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6448
6449 unsigned BitWidth = getTypeSizeInBits(S->getType());
6450 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6451 using OBO = OverflowingBinaryOperator;
6452
6453 // If the value has known zeros, the maximum value will have those known zeros
6454 // as well.
6455 uint32_t TZ = GetMinTrailingZeros(S);
6456 if (TZ != 0) {
6457 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
6458 ConservativeResult =
6459 ConstantRange(APInt::getMinValue(BitWidth),
6460 APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
6461 else
6462 ConservativeResult = ConstantRange(
6463 APInt::getSignedMinValue(BitWidth),
6464 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6465 }
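  // Worked example for the known-trailing-zeros clamp above, on illustrative
  // numbers and with plain unsigned arithmetic standing in for APInt: for an
  // 8-bit value with TZ = 3, the largest representable value is 0xF8, so the
  // unsigned case above yields the range [0, 0xF8 + 1).
  static_assert((0xFFu >> 3 << 3) == 0xF8u,
                "largest 8-bit value with 3 trailing zeros is 0xF8");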
6466
6467 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
6468 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint);
6469 unsigned WrapType = OBO::AnyWrap;
6470 if (Add->hasNoSignedWrap())
6471 WrapType |= OBO::NoSignedWrap;
6472 if (Add->hasNoUnsignedWrap())
6473 WrapType |= OBO::NoUnsignedWrap;
6474 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6475 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint),
6476 WrapType, RangeType);
6477 return setRange(Add, SignHint,
6478 ConservativeResult.intersectWith(X, RangeType));
6479 }
6480
6481 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
6482 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint);
6483 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6484 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint));
6485 return setRange(Mul, SignHint,
6486 ConservativeResult.intersectWith(X, RangeType));
6487 }
6488
6489 if (isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) {
6490 Intrinsic::ID ID;
6491 switch (S->getSCEVType()) {
6492 case scUMaxExpr:
6493 ID = Intrinsic::umax;
6494 break;
6495 case scSMaxExpr:
6496 ID = Intrinsic::smax;
6497 break;
6498 case scUMinExpr:
6499 case scSequentialUMinExpr:
6500 ID = Intrinsic::umin;
6501 break;
6502 case scSMinExpr:
6503 ID = Intrinsic::smin;
6504 break;
6505 default:
6506       llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6507 }
6508
6509 const auto *NAry = cast<SCEVNAryExpr>(S);
6510 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint);
6511 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6512 X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)});
6513 return setRange(S, SignHint,
6514 ConservativeResult.intersectWith(X, RangeType));
6515 }
6516
6517 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
6518 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint);
6519 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint);
6520 return setRange(UDiv, SignHint,
6521 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6522 }
6523
6524 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
6525 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint);
6526 return setRange(ZExt, SignHint,
6527 ConservativeResult.intersectWith(X.zeroExtend(BitWidth),
6528 RangeType));
6529 }
6530
6531 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
6532 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint);
6533 return setRange(SExt, SignHint,
6534 ConservativeResult.intersectWith(X.signExtend(BitWidth),
6535 RangeType));
6536 }
6537
6538 if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
6539 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
6540 return setRange(PtrToInt, SignHint, X);
6541 }
6542
6543 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
6544 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
6545 return setRange(Trunc, SignHint,
6546 ConservativeResult.intersectWith(X.truncate(BitWidth),
6547 RangeType));
6548 }
6549
6550 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
6551 // If there's no unsigned wrap, the value will never be less than its
6552 // initial value.
6553 if (AddRec->hasNoUnsignedWrap()) {
6554 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6555 if (!UnsignedMinValue.isZero())
6556 ConservativeResult = ConservativeResult.intersectWith(
6557 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6558 }
6559
6560 // If there's no signed wrap, and all the operands except initial value have
6561 // the same sign or zero, the value won't ever be:
6562     // 1: smaller than the initial value if the operands are non-negative,
6563     // 2: bigger than the initial value if the operands are non-positive.
6564     // In both cases, the value cannot cross the signed min/max boundary.
6565 if (AddRec->hasNoSignedWrap()) {
6566 bool AllNonNeg = true;
6567 bool AllNonPos = true;
6568 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6569 if (!isKnownNonNegative(AddRec->getOperand(i)))
6570 AllNonNeg = false;
6571 if (!isKnownNonPositive(AddRec->getOperand(i)))
6572 AllNonPos = false;
6573 }
6574 if (AllNonNeg)
6575 ConservativeResult = ConservativeResult.intersectWith(
6576 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6577 APInt::getSignedMinValue(BitWidth)),
6578 RangeType);
6579 else if (AllNonPos)
6580 ConservativeResult = ConservativeResult.intersectWith(
6581 ConstantRange::getNonEmpty(
6582 APInt::getSignedMinValue(BitWidth),
6583 getSignedRangeMax(AddRec->getStart()) + 1),
6584 RangeType);
6585 }
6586
6587 // TODO: non-affine addrec
6588 if (AddRec->isAffine()) {
6589 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6590 if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
6591 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
6592 auto RangeFromAffine = getRangeForAffineAR(
6593 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6594 BitWidth);
6595 ConservativeResult =
6596 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6597
6598 auto RangeFromFactoring = getRangeViaFactoring(
6599 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
6600 BitWidth);
6601 ConservativeResult =
6602 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6603 }
6604
6605 // Now try symbolic BE count and more powerful methods.
6606 if (UseExpensiveRangeSharpening) {
6607 const SCEV *SymbolicMaxBECount =
6608 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6609 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6610 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6611 AddRec->hasNoSelfWrap()) {
6612 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6613 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6614 ConservativeResult =
6615 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6616 }
6617 }
6618 }
6619
6620 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6621 }
6622
6623 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
6624
6625 // Check if the IR explicitly contains !range metadata.
6626 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
6627 if (MDRange)
6628 ConservativeResult =
6629 ConservativeResult.intersectWith(MDRange.value(), RangeType);
6630
6631 // Use facts about recurrences in the underlying IR. Note that add
6632 // recurrences are AddRecExprs and thus don't hit this path. This
6633 // primarily handles shift recurrences.
6634 auto CR = getRangeForUnknownRecurrence(U);
6635 ConservativeResult = ConservativeResult.intersectWith(CR);
6636
6637 // See if ValueTracking can give us a useful range.
6638 const DataLayout &DL = getDataLayout();
6639 KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6640 if (Known.getBitWidth() != BitWidth)
6641 Known = Known.zextOrTrunc(BitWidth);
6642
6643 // ValueTracking may be able to compute a tighter result for the number of
6644 // sign bits than for the value of those sign bits.
6645 unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT);
6646 if (U->getType()->isPointerTy()) {
6647 // If the pointer size is larger than the index size type, this can cause
6648 // NS to be larger than BitWidth. So compensate for this.
6649 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6650 int ptrIdxDiff = ptrSize - BitWidth;
6651 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6652 NS -= ptrIdxDiff;
6653 }
6654
6655 if (NS > 1) {
6656 // If we know any of the sign bits, we know all of the sign bits.
6657 if (!Known.Zero.getHiBits(NS).isZero())
6658 Known.Zero.setHighBits(NS);
6659 if (!Known.One.getHiBits(NS).isZero())
6660 Known.One.setHighBits(NS);
6661 }
6662
6663 if (Known.getMinValue() != Known.getMaxValue() + 1)
6664 ConservativeResult = ConservativeResult.intersectWith(
6665 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6666 RangeType);
6667 if (NS > 1)
6668 ConservativeResult = ConservativeResult.intersectWith(
6669 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6670 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6671 RangeType);
6672
6673 // A range of Phi is a subset of union of all ranges of its input.
6674 if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
6675 // Make sure that we do not run over cycled Phis.
6676 if (PendingPhiRanges.insert(Phi).second) {
6677 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6678 for (const auto &Op : Phi->operands()) {
6679 auto OpRange = getRangeRef(getSCEV(Op), SignHint);
6680 RangeFromOps = RangeFromOps.unionWith(OpRange);
6681 // No point to continue if we already have a full set.
6682 if (RangeFromOps.isFullSet())
6683 break;
6684 }
6685 ConservativeResult =
6686 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6687 bool Erased = PendingPhiRanges.erase(Phi);
6688         assert(Erased && "Failed to erase Phi properly?");
6689 (void) Erased;
6690 }
6691 }
6692
6693 // vscale can't be equal to zero
6694 if (const auto *II = dyn_cast<IntrinsicInst>(U->getValue()))
6695 if (II->getIntrinsicID() == Intrinsic::vscale) {
6696 ConstantRange Disallowed = APInt::getZero(BitWidth);
6697 ConservativeResult = ConservativeResult.difference(Disallowed);
6698 }
6699
6700 return setRange(U, SignHint, std::move(ConservativeResult));
6701 }
6702
6703 return setRange(S, SignHint, std::move(ConservativeResult));
6704}
6705
6706// Given a StartRange, Step and MaxBECount for an expression compute a range of
6707// values that the expression can take. Initially, the expression has a value
6708// from StartRange and then is changed by Step up to MaxBECount times. Signed
6709// argument defines if we treat Step as signed or unsigned.
6710static ConstantRange getRangeForAffineARHelper(APInt Step,
6711 const ConstantRange &StartRange,
6712 const APInt &MaxBECount,
6713 unsigned BitWidth, bool Signed) {
6714 // If either Step or MaxBECount is 0, then the expression won't change, and we
6715 // just need to return the initial range.
6716 if (Step == 0 || MaxBECount == 0)
6717 return StartRange;
6718
6719 // If we don't know anything about the initial value (i.e. StartRange is
6720 // FullRange), then we don't know anything about the final range either.
6721 // Return FullRange.
6722 if (StartRange.isFullSet())
6723 return ConstantRange::getFull(BitWidth);
6724
6725 // If Step is signed and negative, then we use its absolute value, but we also
6726 // note that we're moving in the opposite direction.
6727 bool Descending = Signed && Step.isNegative();
6728
6729 if (Signed)
6730 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6731 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6732     // These equations hold true due to the well-defined wrap-around behavior of
6733 // APInt.
6734 Step = Step.abs();
6735
6736 // Check if Offset is more than full span of BitWidth. If it is, the
6737 // expression is guaranteed to overflow.
6738 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6739 return ConstantRange::getFull(BitWidth);
6740
6741 // Offset is by how much the expression can change. Checks above guarantee no
6742 // overflow here.
6743 APInt Offset = Step * MaxBECount;
6744
6745 // Minimum value of the final range will match the minimal value of StartRange
6746 // if the expression is increasing and will be decreased by Offset otherwise.
6747 // Maximum value of the final range will match the maximal value of StartRange
6748 // if the expression is decreasing and will be increased by Offset otherwise.
6749 APInt StartLower = StartRange.getLower();
6750 APInt StartUpper = StartRange.getUpper() - 1;
6751 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6752 : (StartUpper + std::move(Offset));
6753
6754 // It's possible that the new minimum/maximum value will fall into the initial
6755 // range (due to wrap around). This means that the expression can take any
6756 // value in this bitwidth, and we have to return full range.
6757 if (StartRange.contains(MovedBoundary))
6758 return ConstantRange::getFull(BitWidth);
6759
6760 APInt NewLower =
6761 Descending ? std::move(MovedBoundary) : std::move(StartLower);
6762 APInt NewUpper =
6763 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
6764 NewUpper += 1;
6765
6766 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
6767 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
6768}
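// Worked example for the helper above, on illustrative numbers: with a start
// range of [10, 20), a constant unsigned step of 3 and MaxBECount = 5, the
// offset is 3 * 5 = 15, so the resulting range is [10, 19 + 15 + 1) = [10, 35).
constexpr bool affineARRangeExampleHolds() {
  for (unsigned Start = 10; Start < 20; ++Start)
    for (unsigned Iters = 0; Iters <= 5; ++Iters)
      if (Start + 3 * Iters < 10 || Start + 3 * Iters >= 35)
        return false;
  return true;
}
static_assert(affineARRangeExampleHolds(),
              "{[10,20),+,3} stays within [10, 35) for at most 5 backedges");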
6769
6770ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
6771 const SCEV *Step,
6772 const SCEV *MaxBECount,
6773 unsigned BitWidth) {
6774   assert(!isa<SCEVCouldNotCompute>(MaxBECount) &&
6775          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
6776          "Precondition!");
6777
6778 MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());
6779 APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount);
6780
6781 // First, consider step signed.
6782 ConstantRange StartSRange = getSignedRange(Start);
6783 ConstantRange StepSRange = getSignedRange(Step);
6784
6785 // If Step can be both positive and negative, we need to find ranges for the
6786 // maximum absolute step values in both directions and union them.
6787 ConstantRange SR =
6788 getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange,
6789 MaxBECountValue, BitWidth, /* Signed = */ true);
6790 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
6791 StartSRange, MaxBECountValue,
6792 BitWidth, /* Signed = */ true));
6793
6794 // Next, consider step unsigned.
6795 ConstantRange UR = getRangeForAffineARHelper(
6796 getUnsignedRangeMax(Step), getUnsignedRange(Start),
6797 MaxBECountValue, BitWidth, /* Signed = */ false);
6798
6799 // Finally, intersect signed and unsigned ranges.
6800 return SR.intersectWith(UR, ConstantRange::Smallest);
6801}
6802
6803ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
6804 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
6805 ScalarEvolution::RangeSignHint SignHint) {
6806   assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
6807   assert(AddRec->hasNoSelfWrap() &&
6808          "This only works for non-self-wrapping AddRecs!");
6809 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
6810 const SCEV *Step = AddRec->getStepRecurrence(*this);
6811 // Only deal with constant step to save compile time.
6812 if (!isa<SCEVConstant>(Step))
6813 return ConstantRange::getFull(BitWidth);
6814 // Let's make sure that we can prove that we do not self-wrap during
6815 // MaxBECount iterations. We need this because MaxBECount is a maximum
6816 // iteration count estimate, and we might infer nw from some exit for which we
6817 // do not know max exit count (or any other side reasoning).
6818 // TODO: Turn into assert at some point.
6819 if (getTypeSizeInBits(MaxBECount->getType()) >
6820 getTypeSizeInBits(AddRec->getType()))
6821 return ConstantRange::getFull(BitWidth);
6822 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
6823 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
6824 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
6825 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
6826 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
6827 MaxItersWithoutWrap))
6828 return ConstantRange::getFull(BitWidth);
6829
6830 ICmpInst::Predicate LEPred =
6831 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
6832 ICmpInst::Predicate GEPred =
6833 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
6834 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
6835
6836 // We know that there is no self-wrap. Let's take Start and End values and
6837 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
6838 // the iteration. They either lie inside the range [Min(Start, End),
6839 // Max(Start, End)] or outside it:
6840 //
6841 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
6842 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
6843 //
6844 // No self wrap flag guarantees that the intermediate values cannot be BOTH
6845 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
6846 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
6847 // Start <= End and step is positive, or Start >= End and step is negative.
6848 const SCEV *Start = AddRec->getStart();
6849 ConstantRange StartRange = getRangeRef(Start, SignHint);
6850 ConstantRange EndRange = getRangeRef(End, SignHint);
6851 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
6852 // If they already cover full iteration space, we will know nothing useful
6853 // even if we prove what we want to prove.
6854 if (RangeBetween.isFullSet())
6855 return RangeBetween;
6856 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
6857 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
6858 : RangeBetween.isWrappedSet();
6859 if (IsWrappedSet)
6860 return ConstantRange::getFull(BitWidth);
6861
6862 if (isKnownPositive(Step) &&
6863 isKnownPredicateViaConstantRanges(LEPred, Start, End))
6864 return RangeBetween;
6865 else if (isKnownNegative(Step) &&
6866 isKnownPredicateViaConstantRanges(GEPred, Start, End))
6867 return RangeBetween;
6868 return ConstantRange::getFull(BitWidth);
6869}
6870
6871ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
6872 const SCEV *Step,
6873 const SCEV *MaxBECount,
6874 unsigned BitWidth) {
6875 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
6876 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
6877
6878 struct SelectPattern {
6879 Value *Condition = nullptr;
6880 APInt TrueValue;
6881 APInt FalseValue;
6882
6883 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
6884 const SCEV *S) {
6885 Optional<unsigned> CastOp;
6886 APInt Offset(BitWidth, 0);
6887
6888       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
6889              "Should be!");
6890
6891 // Peel off a constant offset:
6892 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
6893 // In the future we could consider being smarter here and handle
6894 // {Start+Step,+,Step} too.
6895 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
6896 return;
6897
6898 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
6899 S = SA->getOperand(1);
6900 }
6901
6902 // Peel off a cast operation
6903 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
6904 CastOp = SCast->getSCEVType();
6905 S = SCast->getOperand();
6906 }
6907
6908 using namespace llvm::PatternMatch;
6909
6910 auto *SU = dyn_cast<SCEVUnknown>(S);
6911 const APInt *TrueVal, *FalseVal;
6912 if (!SU ||
6913 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
6914 m_APInt(FalseVal)))) {
6915 Condition = nullptr;
6916 return;
6917 }
6918
6919 TrueValue = *TrueVal;
6920 FalseValue = *FalseVal;
6921
6922 // Re-apply the cast we peeled off earlier
6923 if (CastOp)
6924 switch (*CastOp) {
6925 default:
6926           llvm_unreachable("Unknown SCEV cast type!");
6927
6928 case scTruncate:
6929 TrueValue = TrueValue.trunc(BitWidth);
6930 FalseValue = FalseValue.trunc(BitWidth);
6931 break;
6932 case scZeroExtend:
6933 TrueValue = TrueValue.zext(BitWidth);
6934 FalseValue = FalseValue.zext(BitWidth);
6935 break;
6936 case scSignExtend:
6937 TrueValue = TrueValue.sext(BitWidth);
6938 FalseValue = FalseValue.sext(BitWidth);
6939 break;
6940 }
6941
6942 // Re-apply the constant offset we peeled off earlier
6943 TrueValue += Offset;
6944 FalseValue += Offset;
6945 }
6946
6947 bool isRecognized() { return Condition != nullptr; }
6948 };
6949
6950 SelectPattern StartPattern(*this, BitWidth, Start);
6951 if (!StartPattern.isRecognized())
6952 return ConstantRange::getFull(BitWidth);
6953
6954 SelectPattern StepPattern(*this, BitWidth, Step);
6955 if (!StepPattern.isRecognized())
6956 return ConstantRange::getFull(BitWidth);
6957
6958 if (StartPattern.Condition != StepPattern.Condition) {
6959 // We don't handle this case today; but we could, by considering four
6960 // possibilities below instead of two. I'm not sure if there are cases where
6961 // that will help over what getRange already does, though.
6962 return ConstantRange::getFull(BitWidth);
6963 }
6964
6965 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
6966 // construct arbitrary general SCEV expressions here. This function is called
6967 // from deep in the call stack, and calling getSCEV (on a sext instruction,
6968 // say) can end up caching a suboptimal value.
6969
6970 // FIXME: without the explicit `this` receiver below, MSVC errors out with
6971 // C2352 and C2512 (otherwise it isn't needed).
6972
6973 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
6974 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
6975 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
6976 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
6977
6978 ConstantRange TrueRange =
6979 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth);
6980 ConstantRange FalseRange =
6981 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth);
6982
6983 return TrueRange.unionWith(FalseRange);
6984}
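// Worked example for the factoring above, on illustrative numbers: for
// {C ? 0 : 10, +, C ? 1 : 2} with at most 3 backedges, the true arm {0,+,1}
// stays in [0, 4), the false arm {10,+,2} stays in [10, 17), and the whole
// recurrence is covered by the union of the two per-arm ranges
// (rangeViaFactoringExampleHolds is an illustrative name).
constexpr bool rangeViaFactoringExampleHolds() {
  for (int Cond = 0; Cond <= 1; ++Cond)
    for (int Iters = 0; Iters <= 3; ++Iters) {
      int V = Cond ? 0 + 1 * Iters : 10 + 2 * Iters;
      bool InTrueArm = V >= 0 && V < 4;
      bool InFalseArm = V >= 10 && V < 17;
      if (!InTrueArm && !InFalseArm)
        return false;
    }
  return true;
}
static_assert(rangeViaFactoringExampleHolds(),
              "RangeOf({C?0:10,+,C?1:2}) is covered by [0, 4) union [10, 17)");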
6985
6986SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
6987 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
6988 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
6989
6990 // Return early if there are no flags to propagate to the SCEV.
6991 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
6992 if (BinOp->hasNoUnsignedWrap())
6993 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
6994 if (BinOp->hasNoSignedWrap())
6995 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
6996 if (Flags == SCEV::FlagAnyWrap)
6997 return SCEV::FlagAnyWrap;
6998
6999 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7000}
7001
7002const Instruction *
7003ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7004 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7005 return &*AddRec->getLoop()->getHeader()->begin();
7006 if (auto *U = dyn_cast<SCEVUnknown>(S))
7007 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7008 return I;
7009 return nullptr;
7010}
7011
7012/// Fills \p Ops with unique operands of \p S, if it has operands. If not,
7013/// \p Ops remains unmodified.
7014static void collectUniqueOps(const SCEV *S,
7015 SmallVectorImpl<const SCEV *> &Ops) {
7016 SmallPtrSet<const SCEV *, 4> Unique;
7017 auto InsertUnique = [&](const SCEV *S) {
7018 if (Unique.insert(S).second)
7019 Ops.push_back(S);
7020 };
7021 if (auto *S2 = dyn_cast<SCEVCastExpr>(S))
7022 for (const auto *Op : S2->operands())
7023 InsertUnique(Op);
7024 else if (auto *S2 = dyn_cast<SCEVNAryExpr>(S))
7025 for (const auto *Op : S2->operands())
7026 InsertUnique(Op);
7027 else if (auto *S2 = dyn_cast<SCEVUDivExpr>(S))
7028 for (const auto *Op : S2->operands())
7029 InsertUnique(Op);
7030}
7031
7032const Instruction *
7033ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7034 bool &Precise) {
7035 Precise = true;
7036 // Do a bounded search of the def relation of the requested SCEVs.
7037 SmallSet<const SCEV *, 16> Visited;
7038 SmallVector<const SCEV *> Worklist;
7039 auto pushOp = [&](const SCEV *S) {
7040 if (!Visited.insert(S).second)
7041 return;
7042 // Threshold of 30 here is arbitrary.
7043 if (Visited.size() > 30) {
7044 Precise = false;
7045 return;
7046 }
7047 Worklist.push_back(S);
7048 };
7049
7050 for (const auto *S : Ops)
7051 pushOp(S);
7052
7053 const Instruction *Bound = nullptr;
7054 while (!Worklist.empty()) {
7055 auto *S = Worklist.pop_back_val();
7056 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7057 if (!Bound || DT.dominates(Bound, DefI))
7058 Bound = DefI;
7059 } else {
7060 SmallVector<const SCEV *, 4> Ops;
7061 collectUniqueOps(S, Ops);
7062 for (const auto *Op : Ops)
7063 pushOp(Op);
7064 }
7065 }
7066 return Bound ? Bound : &*F.getEntryBlock().begin();
7067}
7068
7069const Instruction *
7070ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7071 bool Discard;
7072 return getDefiningScopeBound(Ops, Discard);
7073}
7074
7075bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7076 const Instruction *B) {
7077 if (A->getParent() == B->getParent() &&
7078 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7079 B->getIterator()))
7080 return true;
7081
7082 auto *BLoop = LI.getLoopFor(B->getParent());
7083 if (BLoop && BLoop->getHeader() == B->getParent() &&
7084 BLoop->getLoopPreheader() == A->getParent() &&
7085 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7086 A->getParent()->end()) &&
7087 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7088 B->getIterator()))
7089 return true;
7090 return false;
7091}
7092
7093
7094bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7095 // Only proceed if we can prove that I does not yield poison.
7096 if (!programUndefinedIfPoison(I))
7097 return false;
7098
7099 // At this point we know that if I is executed, then it does not wrap
7100 // according to at least one of NSW or NUW. If I is not executed, then we do
7101 // not know if the calculation that I represents would wrap. Multiple
7102 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7103 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7104 // derived from other instructions that map to the same SCEV. We cannot make
7105 // that guarantee for cases where I is not executed. So we need to find a
7106 // upper bound on the defining scope for the SCEV, and prove that I is
7107 // executed every time we enter that scope. When the bounding scope is a
7108 // loop (the common case), this is equivalent to proving I executes on every
7109 // iteration of that loop.
7110 SmallVector<const SCEV *> SCEVOps;
7111 for (const Use &Op : I->operands()) {
7112 // I could be an extractvalue from a call to an overflow intrinsic.
7113 // TODO: We can do better here in some cases.
7114 if (isSCEVable(Op->getType()))
7115 SCEVOps.push_back(getSCEV(Op));
7116 }
7117 auto *DefI = getDefiningScopeBound(SCEVOps);
7118 return isGuaranteedToTransferExecutionTo(DefI, I);
7119}
7120
7121bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7122 // If we know that \c I can never be poison period, then that's enough.
7123 if (isSCEVExprNeverPoison(I))
7124 return true;
7125
7126 // For an add recurrence specifically, we assume that infinite loops without
7127 // side effects are undefined behavior, and then reason as follows:
7128 //
7129 // If the add recurrence is poison in any iteration, it is poison on all
7130 // future iterations (since incrementing poison yields poison). If the result
7131 // of the add recurrence is fed into the loop latch condition and the loop
7132 // does not contain any throws or exiting blocks other than the latch, we now
7133 // have the ability to "choose" whether the backedge is taken or not (by
7134 // choosing a sufficiently evil value for the poison feeding into the branch)
7135 // for every iteration including and after the one in which \p I first became
7136 // poison. There are two possibilities (let's call the iteration in which \p
7137 // I first became poison as K):
7138 //
7139 // 1. In the set of iterations including and after K, the loop body executes
7140   //    no side effects. In this case executing the backedge an infinite number
7141 // of times will yield undefined behavior.
7142 //
7143 // 2. In the set of iterations including and after K, the loop body executes
7144 // at least one side effect. In this case, that specific instance of side
7145 // effect is control dependent on poison, which also yields undefined
7146 // behavior.
7147
7148 auto *ExitingBB = L->getExitingBlock();
7149 auto *LatchBB = L->getLoopLatch();
7150 if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
7151 return false;
7152
7153 SmallPtrSet<const Instruction *, 16> Pushed;
7154 SmallVector<const Instruction *, 8> PoisonStack;
7155
7156 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7157 // things that are known to be poison under that assumption go on the
7158 // PoisonStack.
7159 Pushed.insert(I);
7160 PoisonStack.push_back(I);
7161
7162 bool LatchControlDependentOnPoison = false;
7163 while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
7164 const Instruction *Poison = PoisonStack.pop_back_val();
7165
7166 for (const auto *PoisonUser : Poison->users()) {
7167 if (propagatesPoison(cast<Operator>(PoisonUser))) {
7168 if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
7169 PoisonStack.push_back(cast<Instruction>(PoisonUser));
7170 } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
7171         assert(BI->isConditional() && "Only possibility!");
7172 if (BI->getParent() == LatchBB) {
7173 LatchControlDependentOnPoison = true;
7174 break;
7175 }
7176 }
7177 }
7178 }
7179
7180 return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
7181}
7182
7183ScalarEvolution::LoopProperties
7184ScalarEvolution::getLoopProperties(const Loop *L) {
7185 using LoopProperties = ScalarEvolution::LoopProperties;
7186
7187 auto Itr = LoopPropertiesCache.find(L);
7188 if (Itr == LoopPropertiesCache.end()) {
7189 auto HasSideEffects = [](Instruction *I) {
7190 if (auto *SI = dyn_cast<StoreInst>(I))
7191 return !SI->isSimple();
7192
7193 return I->mayThrow() || I->mayWriteToMemory();
7194 };
7195
7196 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7197 /*HasNoSideEffects*/ true};
7198
7199 for (auto *BB : L->getBlocks())
7200 for (auto &I : *BB) {
7201 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7202 LP.HasNoAbnormalExits = false;
7203 if (HasSideEffects(&I))
7204 LP.HasNoSideEffects = false;
7205 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7206 break; // We're already as pessimistic as we can get.
7207 }
7208
7209 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7210 assert(InsertPair.second && "We just checked!");
7211 Itr = InsertPair.first;
7212 }
7213
7214 return Itr->second;
7215}
7216
7217bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7218 // A mustprogress loop without side effects must be finite.
7219 // TODO: The check used here is very conservative. It's only *specific*
7220 // side effects which are well defined in infinite loops.
7221 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7222}
7223
7224const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7225 // Worklist item with a Value and a bool indicating whether all operands have
7226 // been visited already.
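 // E.g. for "%a = add %b, %c", the (%a, false) entry queues %b and %c; once
 // SCEVs exist for both operands, the (%a, true) entry builds the SCEV for %a.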
7227 using PointerTy = PointerIntPair<Value *, 1, bool>;
7228 SmallVector<PointerTy> Stack;
7229
7230 Stack.emplace_back(V, true);
7231 Stack.emplace_back(V, false);
7232 while (!Stack.empty()) {
7233 auto E = Stack.pop_back_val();
7234 Value *CurV = E.getPointer();
7235
7236 if (getExistingSCEV(CurV))
7237 continue;
7238
7239 SmallVector<Value *> Ops;
7240 const SCEV *CreatedSCEV = nullptr;
7241 // If all operands have been visited already, create the SCEV.
7242 if (E.getInt()) {
7243 CreatedSCEV = createSCEV(CurV);
7244 } else {
7245 // Otherwise get the operands we need to create SCEVs for before creating
7246 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7247 // just use it.
7248 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7249 }
7250
7251 if (CreatedSCEV) {
7252 insertValueToMap(CurV, CreatedSCEV);
7253 } else {
7254 // Queue CurV for SCEV creation, followed by its operands, which need to
7255 // be constructed first.
7256 Stack.emplace_back(CurV, true);
7257 for (Value *Op : Ops)
7258 Stack.emplace_back(Op, false);
7259 }
7260 }
7261
7262 return getExistingSCEV(V);
7263}
7264
7265const SCEV *
7266ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7267 if (!isSCEVable(V->getType()))
7268 return getUnknown(V);
7269
7270 if (Instruction *I = dyn_cast<Instruction>(V)) {
7271 // Don't attempt to analyze instructions in blocks that aren't
7272 // reachable. Such instructions don't matter, and they aren't required
7273 // to obey basic rules for definitions dominating uses which this
7274 // analysis depends on.
7275 if (!DT.isReachableFromEntry(I->getParent()))
7276 return getUnknown(PoisonValue::get(V->getType()));
7277 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7278 return getConstant(CI);
7279 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
7280 if (!GA->isInterposable()) {
7281 Ops.push_back(GA->getAliasee());
7282 return nullptr;
7283 }
7284 return getUnknown(V);
7285 } else if (!isa<ConstantExpr>(V))
7286 return getUnknown(V);
7287
7288 Operator *U = cast<Operator>(V);
7289 if (auto BO = MatchBinaryOp(U, DT)) {
7290 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7291 switch (BO->Opcode) {
7292 case Instruction::Add:
7293 case Instruction::Mul: {
7294 // For additions and multiplications, traverse add/mul chains for which we
7295 // can potentially create a single SCEV, to reduce the number of
7296 // get{Add,Mul}Expr calls.
7297 do {
7298 if (BO->Op) {
7299 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7300 Ops.push_back(BO->Op);
7301 break;
7302 }
7303 }
7304 Ops.push_back(BO->RHS);
7305 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7306 if (!NewBO ||
7307 (U->getOpcode() == Instruction::Add &&
7308 (NewBO->Opcode != Instruction::Add &&
7309 NewBO->Opcode != Instruction::Sub)) ||
7310 (U->getOpcode() == Instruction::Mul &&
7311 NewBO->Opcode != Instruction::Mul)) {
7312 Ops.push_back(BO->LHS);
7313 break;
7314 }
7315 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7316 // requires a SCEV for the LHS.
7317 if (NewBO->Op && (NewBO->IsNSW || NewBO->IsNUW)) {
7318 auto *I = dyn_cast<Instruction>(NewBO->Op);
7319 if (I && programUndefinedIfPoison(I)) {
7320 Ops.push_back(BO->LHS);
7321 break;
7322 }
7323 }
7324 BO = NewBO;
7325 } while (true);
7326 return nullptr;
7327 }
7328 case Instruction::Sub:
7329 case Instruction::UDiv:
7330 case Instruction::URem:
7331 break;
7332 case Instruction::AShr:
7333 case Instruction::Shl:
7334 case Instruction::Xor:
7335 if (!IsConstArg)
7336 return nullptr;
7337 break;
7338 case Instruction::And:
7339 case Instruction::Or:
7340 if (!IsConstArg && BO->LHS->getType()->isIntegerTy(1))
7341 return nullptr;
7342 break;
7343 case Instruction::LShr:
7344 return getUnknown(V);
7345 default:
7346 llvm_unreachable("Unhandled binop");
7347 break;
7348 }
7349
7350 Ops.push_back(BO->LHS);
7351 Ops.push_back(BO->RHS);
7352 return nullptr;
7353 }
7354
7355 switch (U->getOpcode()) {
7356 case Instruction::Trunc:
7357 case Instruction::ZExt:
7358 case Instruction::SExt:
7359 case Instruction::PtrToInt:
7360 Ops.push_back(U->getOperand(0));
7361 return nullptr;
7362
7363 case Instruction::BitCast:
7364 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7365 Ops.push_back(U->getOperand(0));
7366 return nullptr;
7367 }
7368 return getUnknown(V);
7369
7370 case Instruction::SDiv:
7371 case Instruction::SRem:
7372 Ops.push_back(U->getOperand(0));
7373 Ops.push_back(U->getOperand(1));
7374 return nullptr;
7375
7376 case Instruction::GetElementPtr:
7377 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7378        "GEP source element type must be sized");
7379 for (Value *Index : U->operands())
7380 Ops.push_back(Index);
7381 return nullptr;
7382
7383 case Instruction::IntToPtr:
7384 return getUnknown(V);
7385
7386 case Instruction::PHI:
7387 // Keep constructing SCEVs for phis recursively for now.
7388 return nullptr;
7389
7390 case Instruction::Select: {
7391 // Check if U is a select that can be simplified to a SCEVUnknown.
7392 auto CanSimplifyToUnknown = [this, U]() {
7393 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7394 return false;
7395
7396 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7397 if (!ICI)
7398 return false;
7399 Value *LHS = ICI->getOperand(0);
7400 Value *RHS = ICI->getOperand(1);
7401 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7402 ICI->getPredicate() == CmpInst::ICMP_NE) {
7403 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7404 return true;
7405 } else if (getTypeSizeInBits(LHS->getType()) >
7406 getTypeSizeInBits(U->getType()))
7407 return true;
7408 return false;
7409 };
7410 if (CanSimplifyToUnknown())
7411 return getUnknown(U);
7412
7413 for (Value *Inc : U->operands())
7414 Ops.push_back(Inc);
7415 return nullptr;
7416 break;
7417 }
7418 case Instruction::Call:
7419 case Instruction::Invoke:
7420 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7421 Ops.push_back(RV);
7422 return nullptr;
7423 }
7424
7425 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7426 switch (II->getIntrinsicID()) {
7427 case Intrinsic::abs:
7428 Ops.push_back(II->getArgOperand(0));
7429 return nullptr;
7430 case Intrinsic::umax:
7431 case Intrinsic::umin:
7432 case Intrinsic::smax:
7433 case Intrinsic::smin:
7434 case Intrinsic::usub_sat:
7435 case Intrinsic::uadd_sat:
7436 Ops.push_back(II->getArgOperand(0));
7437 Ops.push_back(II->getArgOperand(1));
7438 return nullptr;
7439 case Intrinsic::start_loop_iterations:
7440 case Intrinsic::annotation:
7441 case Intrinsic::ptr_annotation:
7442 Ops.push_back(II->getArgOperand(0));
7443 return nullptr;
7444 default:
7445 break;
7446 }
7447 }
7448 break;
7449 }
7450
7451 return nullptr;
7452}
7453
7454const SCEV *ScalarEvolution::createSCEV(Value *V) {
7455 if (!isSCEVable(V->getType()))
7456 return getUnknown(V);
7457
7458 if (Instruction *I = dyn_cast<Instruction>(V)) {
7459 // Don't attempt to analyze instructions in blocks that aren't
7460 // reachable. Such instructions don't matter, and they aren't required
7461 // to obey basic rules for definitions dominating uses which this
7462 // analysis depends on.
7463 if (!DT.isReachableFromEntry(I->getParent()))
7464 return getUnknown(PoisonValue::get(V->getType()));
7465 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7466 return getConstant(CI);
7467 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
7468 return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
7469 else if (!isa<ConstantExpr>(V))
7470 return getUnknown(V);
7471
7472 const SCEV *LHS;
7473 const SCEV *RHS;
7474
7475 Operator *U = cast<Operator>(V);
7476 if (auto BO = MatchBinaryOp(U, DT)) {
7477 switch (BO->Opcode) {
7478 case Instruction::Add: {
7479 // The simple thing to do would be to just call getSCEV on both operands
7480 // and call getAddExpr with the result. However if we're looking at a
7481 // bunch of things all added together, this can be quite inefficient,
7482 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7483 // Instead, gather up all the operands and make a single getAddExpr call.
7484 // LLVM IR canonical form means we need only traverse the left operands.
7485 SmallVector<const SCEV *, 4> AddOps;
7486 do {
7487 if (BO->Op) {
7488 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7489 AddOps.push_back(OpSCEV);
7490 break;
7491 }
7492
7493 // If a NUW or NSW flag can be applied to the SCEV for this
7494 // addition, then compute the SCEV for this addition by itself
7495 // with a separate call to getAddExpr. We need to do that
7496 // instead of pushing the operands of the addition onto AddOps,
7497 // since the flags are only known to apply to this particular
7498 // addition - they may not apply to other additions that can be
7499 // formed with operands from AddOps.
7500 const SCEV *RHS = getSCEV(BO->RHS);
7501 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7502 if (Flags != SCEV::FlagAnyWrap) {
7503 const SCEV *LHS = getSCEV(BO->LHS);
7504 if (BO->Opcode == Instruction::Sub)
7505 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7506 else
7507 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7508 break;
7509 }
7510 }
7511
7512 if (BO->Opcode == Instruction::Sub)
7513 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7514 else
7515 AddOps.push_back(getSCEV(BO->RHS));
7516
7517 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7518 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7519 NewBO->Opcode != Instruction::Sub)) {
7520 AddOps.push_back(getSCEV(BO->LHS));
7521 break;
7522 }
7523 BO = NewBO;
7524 } while (true);
7525
7526 return getAddExpr(AddOps);
7527 }
7528
7529 case Instruction::Mul: {
7530 SmallVector<const SCEV *, 4> MulOps;
7531 do {
7532 if (BO->Op) {
7533 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7534 MulOps.push_back(OpSCEV);
7535 break;
7536 }
7537
7538 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7539 if (Flags != SCEV::FlagAnyWrap) {
7540 LHS = getSCEV(BO->LHS);
7541 RHS = getSCEV(BO->RHS);
7542 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7543 break;
7544 }
7545 }
7546
7547 MulOps.push_back(getSCEV(BO->RHS));
7548 auto NewBO = MatchBinaryOp(BO->LHS, DT);
7549 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7550 MulOps.push_back(getSCEV(BO->LHS));
7551 break;
7552 }
7553 BO = NewBO;
7554 } while (true);
7555
7556 return getMulExpr(MulOps);
7557 }
7558 case Instruction::UDiv:
7559 LHS = getSCEV(BO->LHS);
7560 RHS = getSCEV(BO->RHS);
7561 return getUDivExpr(LHS, RHS);
7562 case Instruction::URem:
7563 LHS = getSCEV(BO->LHS);
7564 RHS = getSCEV(BO->RHS);
7565 return getURemExpr(LHS, RHS);
7566 case Instruction::Sub: {
7567 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7568 if (BO->Op)
7569 Flags = getNoWrapFlagsFromUB(BO->Op);
7570 LHS = getSCEV(BO->LHS);
7571 RHS = getSCEV(BO->RHS);
7572 return getMinusSCEV(LHS, RHS, Flags);
7573 }
7574 case Instruction::And:
7575 // For an expression like x&255 that merely masks off the high bits,
7576 // use zext(trunc(x)) as the SCEV expression.
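 // E.g. for an i32 value %x, "%x & 255" is modeled as the low 8 bits of %x,
 // i.e. zext(trunc(%x) to i8) back to i32.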
7577 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7578 if (CI->isZero())
7579 return getSCEV(BO->RHS);
7580 if (CI->isMinusOne())
7581 return getSCEV(BO->LHS);
7582 const APInt &A = CI->getValue();
7583
7584 // Instcombine's ShrinkDemandedConstant may strip bits out of
7585 // constants, obscuring what would otherwise be a low-bits mask.
7586 // Use computeKnownBits to compute what ShrinkDemandedConstant
7587 // knew about to reconstruct a low-bits mask value.
7588 unsigned LZ = A.countLeadingZeros();
7589 unsigned TZ = A.countTrailingZeros();
7590 unsigned BitWidth = A.getBitWidth();
7591 KnownBits Known(BitWidth);
7592 computeKnownBits(BO->LHS, Known, getDataLayout(),
7593 0, &AC, nullptr, &DT);
7594
7595 APInt EffectiveMask =
7596 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7597 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7598 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7599 const SCEV *LHS = getSCEV(BO->LHS);
7600 const SCEV *ShiftedLHS = nullptr;
7601 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7602 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7603 // For an expression like (x * 8) & 8, simplify the multiply.
7604 unsigned MulZeros = OpC->getAPInt().countTrailingZeros();
7605 unsigned GCD = std::min(MulZeros, TZ);
7606 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7607 SmallVector<const SCEV*, 4> MulOps;
7608 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7609 MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end());
7610 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7611 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7612 }
7613 }
7614 if (!ShiftedLHS)
7615 ShiftedLHS = getUDivExpr(LHS, MulCount);
7616 return getMulExpr(
7617 getZeroExtendExpr(
7618 getTruncateExpr(ShiftedLHS,
7619 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7620 BO->LHS->getType()),
7621 MulCount);
7622 }
7623 }
7624 // Binary `and` is a bit-wise `umin`.
7625 if (BO->LHS->getType()->isIntegerTy(1)) {
7626 LHS = getSCEV(BO->LHS);
7627 RHS = getSCEV(BO->RHS);
7628 return getUMinExpr(LHS, RHS);
7629 }
7630 break;
7631
7632 case Instruction::Or:
7633 // If the RHS of the Or is a constant, we may have something like:
7634 // X*4+1 which got turned into X*4|1. Handle this as an Add so loop
7635 // optimizations will transparently handle this case.
7636 //
7637 // In order for this transformation to be safe, the LHS must be of the
7638 // form X*(2^n) and the Or constant must be less than 2^n.
7639 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7640 const SCEV *LHS = getSCEV(BO->LHS);
7641 const APInt &CIVal = CI->getValue();
7642 if (GetMinTrailingZeros(LHS) >=
7643 (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
7644 // Build a plain add SCEV.
7645 return getAddExpr(LHS, getSCEV(CI),
7646 (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
7647 }
7648 }
7649 // Binary `or` is a bit-wise `umax`.
7650 if (BO->LHS->getType()->isIntegerTy(1)) {
7651 LHS = getSCEV(BO->LHS);
7652 RHS = getSCEV(BO->RHS);
7653 return getUMaxExpr(LHS, RHS);
7654 }
7655 break;
7656
7657 case Instruction::Xor:
7658 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7659 // If the RHS of xor is -1, then this is a not operation.
7660 if (CI->isMinusOne())
7661 return getNotSCEV(getSCEV(BO->LHS));
7662
7663 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7664 // This is a variant of the check for xor with -1, and it handles
7665 // the case where instcombine has trimmed non-demanded bits out
7666 // of an xor with -1.
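 // E.g. with C == 0x0f: ((x & 0x0f) ^ 0x0f) == (~x & 0x0f).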
7667 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7668 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7669 if (LBO->getOpcode() == Instruction::And &&
7670 LCI->getValue() == CI->getValue())
7671 if (const SCEVZeroExtendExpr *Z =
7672 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7673 Type *UTy = BO->LHS->getType();
7674 const SCEV *Z0 = Z->getOperand();
7675 Type *Z0Ty = Z0->getType();
7676 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7677
7678 // If C is a low-bits mask, the zero extend is serving to
7679 // mask off the high bits. Complement the operand and
7680 // re-apply the zext.
7681 if (CI->getValue().isMask(Z0TySize))
7682 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7683
7684 // If C is a single bit, it may be in the sign-bit position
7685 // before the zero-extend. In this case, represent the xor
7686 // using an add, which is equivalent, and re-apply the zext.
7687 APInt Trunc = CI->getValue().trunc(Z0TySize);
7688 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7689 Trunc.isSignMask())
7690 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7691 UTy);
7692 }
7693 }
7694 break;
7695
7696 case Instruction::Shl:
7697 // Turn shift left of a constant amount into a multiply.
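 // E.g. "shl i32 %x, 3" is modeled as the SCEV (8 * %x).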
7698 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7699 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7700
7701 // If the shift count is not less than the bitwidth, the result of
7702 // the shift is undefined. Don't try to analyze it, because the
7703 // resolution chosen here may differ from the resolution chosen in
7704 // other parts of the compiler.
7705 if (SA->getValue().uge(BitWidth))
7706 break;
7707
7708 // We can safely preserve the nuw flag in all cases. It's also safe to
7709 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7710 // requires special handling. It can be preserved as long as we're not
7711 // left shifting by bitwidth - 1.
7712 auto Flags = SCEV::FlagAnyWrap;
7713 if (BO->Op) {
7714 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7715 if ((MulFlags & SCEV::FlagNSW) &&
7716 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7717 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
7718 if (MulFlags & SCEV::FlagNUW)
7719 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
7720 }
7721
7722 ConstantInt *X = ConstantInt::get(
7723 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7724 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7725 }
7726 break;
7727
7728 case Instruction::AShr: {
7729 // AShr X, C, where C is a constant.
7730 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7731 if (!CI)
7732 break;
7733
7734 Type *OuterTy = BO->LHS->getType();
7735 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
7736 // If the shift count is not less than the bitwidth, the result of
7737 // the shift is undefined. Don't try to analyze it, because the
7738 // resolution chosen here may differ from the resolution chosen in
7739 // other parts of the compiler.
7740 if (CI->getValue().uge(BitWidth))
7741 break;
7742
7743 if (CI->isZero())
7744 return getSCEV(BO->LHS); // shift by zero --> noop
7745
7746 uint64_t AShrAmt = CI->getZExtValue();
7747 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7748
7749 Operator *L = dyn_cast<Operator>(BO->LHS);
7750 if (L && L->getOpcode() == Instruction::Shl) {
7751 // X = Shl A, n
7752 // Y = AShr X, m
7753 // Both n and m are constant.
7754
7755 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7756 if (L->getOperand(1) == BO->RHS)
7757 // For a two-shift sext-inreg, i.e. n = m,
7758 // use sext(trunc(x)) as the SCEV expression.
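 // E.g. on i32, "ashr (shl %a, 24), 24" is modeled as sext i8 (trunc %a) to i32.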
7759 return getSignExtendExpr(
7760 getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
7761
7762 ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7763 if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
7764 uint64_t ShlAmt = ShlAmtCI->getZExtValue();
7765 if (ShlAmt > AShrAmt) {
7766 // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7767 // expression. We already checked that ShlAmt < BitWidth, so
7768 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7769 // ShlAmt - AShrAmt < Amt.
7770 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7771 ShlAmt - AShrAmt);
7772 return getSignExtendExpr(
7773 getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
7774 getConstant(Mul)), OuterTy);
7775 }
7776 }
7777 }
7778 break;
7779 }
7780 }
7781 }
7782
7783 switch (U->getOpcode()) {
7784 case Instruction::Trunc:
7785 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
7786
7787 case Instruction::ZExt:
7788 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7789
7790 case Instruction::SExt:
7791 if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
7792 // The NSW flag of a subtract does not always survive the conversion to
7793 // A + (-1)*B. By pushing sign extension onto its operands we are much
7794 // more likely to preserve NSW and allow later AddRec optimisations.
7795 //
7796 // NOTE: This is effectively duplicating this logic from getSignExtend:
7797 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
7798 // but by that point the NSW information has potentially been lost.
7799 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
7800 Type *Ty = U->getType();
7801 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
7802 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
7803 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
7804 }
7805 }
7806 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
7807
7808 case Instruction::BitCast:
7809 // BitCasts are no-op casts so we just eliminate the cast.
7810 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
7811 return getSCEV(U->getOperand(0));
7812 break;
7813
7814 case Instruction::PtrToInt: {
7815 // Pointer to integer cast is straight-forward, so do model it.
7816 const SCEV *Op = getSCEV(U->getOperand(0));
7817 Type *DstIntTy = U->getType();
7818 // But only if effective SCEV (integer) type is wide enough to represent
7819 // all possible pointer values.
7820 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
7821 if (isa<SCEVCouldNotCompute>(IntOp))
7822 return getUnknown(V);
7823 return IntOp;
7824 }
7825 case Instruction::IntToPtr:
7826 // Just don't deal with inttoptr casts.
7827 return getUnknown(V);
7828
7829 case Instruction::SDiv:
7830 // If both operands are non-negative, this is just an udiv.
7831 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7832 isKnownNonNegative(getSCEV(U->getOperand(1))))
7833 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7834 break;
7835
7836 case Instruction::SRem:
7837 // If both operands are non-negative, this is just an urem.
7838 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
7839 isKnownNonNegative(getSCEV(U->getOperand(1))))
7840 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
7841 break;
7842
7843 case Instruction::GetElementPtr:
7844 return createNodeForGEP(cast<GEPOperator>(U));
7845
7846 case Instruction::PHI:
7847 return createNodeForPHI(cast<PHINode>(U));
7848
7849 case Instruction::Select:
7850 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
7851 U->getOperand(2));
7852
7853 case Instruction::Call:
7854 case Instruction::Invoke:
7855 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
7856 return getSCEV(RV);
7857
7858 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7859 switch (II->getIntrinsicID()) {
7860 case Intrinsic::abs:
7861 return getAbsExpr(
7862 getSCEV(II->getArgOperand(0)),
7863 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
7864 case Intrinsic::umax:
7865 LHS = getSCEV(II->getArgOperand(0));
7866 RHS = getSCEV(II->getArgOperand(1));
7867 return getUMaxExpr(LHS, RHS);
7868 case Intrinsic::umin:
7869 LHS = getSCEV(II->getArgOperand(0));
7870 RHS = getSCEV(II->getArgOperand(1));
7871 return getUMinExpr(LHS, RHS);
7872 case Intrinsic::smax:
7873 LHS = getSCEV(II->getArgOperand(0));
7874 RHS = getSCEV(II->getArgOperand(1));
7875 return getSMaxExpr(LHS, RHS);
7876 case Intrinsic::smin:
7877 LHS = getSCEV(II->getArgOperand(0));
7878 RHS = getSCEV(II->getArgOperand(1));
7879 return getSMinExpr(LHS, RHS);
7880 case Intrinsic::usub_sat: {
7881 const SCEV *X = getSCEV(II->getArgOperand(0));
7882 const SCEV *Y = getSCEV(II->getArgOperand(1));
7883 const SCEV *ClampedY = getUMinExpr(X, Y);
7884 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
7885 }
7886 case Intrinsic::uadd_sat: {
7887 const SCEV *X = getSCEV(II->getArgOperand(0));
7888 const SCEV *Y = getSCEV(II->getArgOperand(1));
7889 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
7890 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
7891 }
7892 case Intrinsic::start_loop_iterations:
7893 case Intrinsic::annotation:
7894 case Intrinsic::ptr_annotation:
7895 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
7896 // just equivalent to the first operand for SCEV purposes.
7897 return getSCEV(II->getArgOperand(0));
7898 default:
7899 break;
7900 }
7901 }
7902 break;
7903 }
7904
7905 return getUnknown(V);
7906}
7907
7908//===----------------------------------------------------------------------===//
7909// Iteration Count Computation Code
7910//
7911
7912const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
7913 bool Extend) {
7914 if (isa<SCEVCouldNotCompute>(ExitCount))
7915 return getCouldNotCompute();
7916
7917 auto *ExitCountType = ExitCount->getType();
7918 assert(ExitCountType->isIntegerTy());
7919
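 // The trip count is the exit (backedge-taken) count plus one. When Extend is
 // set, the count is first widened by one bit so that an exit count equal to
 // the maximum value of its type does not wrap to zero.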
7920 if (!Extend)
7921 return getAddExpr(ExitCount, getOne(ExitCountType));
7922
7923 auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
7924 1 + ExitCountType->getScalarSizeInBits());
7925 return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
7926 getOne(WiderType));
7927}
7928
7929static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
7930 if (!ExitCount)
7931 return 0;
7932
7933 ConstantInt *ExitConst = ExitCount->getValue();
7934
7935 // Guard against huge trip counts.
7936 if (ExitConst->getValue().getActiveBits() > 32)
7937 return 0;
7938
7939 // In case of integer overflow, this returns 0, which is correct.
7940 return ((unsigned)ExitConst->getZExtValue()) + 1;
7941}
7942
7943unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
7944 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
7945 return getConstantTripCount(ExitCount);
7946}
7947
7948unsigned
7949ScalarEvolution::getSmallConstantTripCount(const Loop *L,
7950 const BasicBlock *ExitingBlock) {
7951 assert(ExitingBlock && "Must pass a non-null exiting block!");
7952 assert(L->isLoopExiting(ExitingBlock) &&
7953        "Exiting block must actually branch out of the loop!");
7954 const SCEVConstant *ExitCount =
7955 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
7956 return getConstantTripCount(ExitCount);
7957}
7958
7959unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
7960 const auto *MaxExitCount =
7961 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
7962 return getConstantTripCount(MaxExitCount);
7963}
7964
7965const SCEV *ScalarEvolution::getConstantMaxTripCountFromArray(const Loop *L) {
7966 // We can't infer a bound from array accesses in an irregular loop.
7967 // FIXME: It's hard to infer a loop bound from arrays accessed in a nested loop.
7968 if (!L->isLoopSimplifyForm() || !L->isInnermost())
7969 return getCouldNotCompute();
7970
7971 // FIXME: To keep the analysis simple, we only analyze loops that have one
7972 // exiting block, and that block must be the latch. This makes it easier to
7973 // guarantee that the memory accesses we look at are executed on every
7974 // iteration.
7975 const BasicBlock *LoopLatch = L->getLoopLatch();
7976 assert(LoopLatch && "See definition of simplified form loop.");
7977 if (L->getExitingBlock() != LoopLatch)
7978 return getCouldNotCompute();
7979
7980 const DataLayout &DL = getDataLayout();
7981 SmallVector<const SCEV *> InferCountColl;
7982 for (auto *BB : L->getBlocks()) {
7983 // At this point we know the loop is in simplified form and has a single
7984 // exiting block. Make sure the memory operations we infer from are executed
7985 // on every iteration, so that the maximum execution count of MemAccessBB
7986 // also bounds the maximum execution count of the latch.
7987 // If MemAccessBB does not dominate the latch, skip it.
7988 // Entry
7989 // │
7990 // ┌─────▼─────┐
7991 // │Loop Header◄─────┐
7992 // └──┬──────┬─┘ │
7993 // │ │ │
7994 // ┌────────▼──┐ ┌─▼─────┐ │
7995 // │MemAccessBB│ │OtherBB│ │
7996 // └────────┬──┘ └─┬─────┘ │
7997 // │ │ │
7998 // ┌─▼──────▼─┐ │
7999 // │Loop Latch├─────┘
8000 // └────┬─────┘
8001 // ▼
8002 // Exit
8003 if (!DT.dominates(BB, LoopLatch))
8004 continue;
8005
8006 for (Instruction &Inst : *BB) {
8007 // Find Memory Operation Instruction.
8008 auto *GEP = getLoadStorePointerOperand(&Inst);
8009 if (!GEP)
8010 continue;
8011
8012 auto *ElemSize = dyn_cast<SCEVConstant>(getElementSize(&Inst));
8013 // Do not infer from a scalar type, e.g. "ElemSize = sizeof()".
8014 if (!ElemSize)
8015 continue;
8016
8017 // Use an existing polynomial recurrence on the trip count.
8018 auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(GEP));
8019 if (!AddRec)
8020 continue;
8021 auto *ArrBase = dyn_cast<SCEVUnknown>(getPointerBase(AddRec));
8022 auto *Step = dyn_cast<SCEVConstant>(AddRec->getStepRecurrence(*this));
8023 if (!ArrBase || !Step)
8024 continue;
8025 assert(isLoopInvariant(ArrBase, L) && "See addrec definition");
8026
8027 // Only handle { %array + step },
8028 // FIXME: {(SCEVAddRecExpr) + step } could not be analysed here.
8029 if (AddRec->getStart() != ArrBase)
8030 continue;
8031
8032 // Skip memory operation patterns which have gaps,
8033 // repeated memory operations,
8034 // or GEP indices that wrap around.
8035 if (Step->getAPInt().getActiveBits() > 32 ||
8036 Step->getAPInt().getZExtValue() !=
8037 ElemSize->getAPInt().getZExtValue() ||
8038 Step->isZero() || Step->getAPInt().isNegative())
8039 continue;
8040
8041 // Only infer from a stack array that has a known size.
8042 // Make sure the alloca instruction is not executed inside the loop.
8043 AllocaInst *AllocateInst = dyn_cast<AllocaInst>(ArrBase->getValue());
8044 if (!AllocateInst || L->contains(AllocateInst->getParent()))
8045 continue;
8046
8047 // Make sure we only handle a normal array.
8048 auto *Ty = dyn_cast<ArrayType>(AllocateInst->getAllocatedType());
8049 auto *ArrSize = dyn_cast<ConstantInt>(AllocateInst->getArraySize());
8050 if (!Ty || !ArrSize || !ArrSize->isOne())
8051 continue;
8052
8053 // FIXME: Since GEP indices are silently zext to the indexing type, a
8054 // narrow GEP index may wrap around rather than increase strictly; we
8055 // should ensure that the step increases strictly with each loop
8056 // iteration.
8057 // Now we can infer a max execution time by MemLength/StepLength.
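 // E.g. for "alloca [100 x i32]" accessed with a 4-byte step, MemSize is 400,
 // MaxExeCount is 100, and the inferred bound is 101.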
8058 const SCEV *MemSize =
8059 getConstant(Step->getType(), DL.getTypeAllocSize(Ty));
8060 auto *MaxExeCount =
8061 dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
8062 if (!MaxExeCount || MaxExeCount->getAPInt().getActiveBits() > 32)
8063 continue;
8064
8065 // If the loop reaches the maximum number of executions, we cannot
8066 // access bytes starting outside the statically allocated size without
8067 // immediate UB. But it is allowed to enter the loop header one more
8068 // time.
8069 auto *InferCount = dyn_cast<SCEVConstant>(
8070 getAddExpr(MaxExeCount, getOne(MaxExeCount->getType())));
8071 // Discard inferred trip counts that do not fit into 32 bits.
8072 if (!InferCount || InferCount->getAPInt().getActiveBits() > 32)
8073 continue;
8074
8075 InferCountColl.push_back(InferCount);
8076 }
8077 }
8078
8079 if (InferCountColl.size() == 0)
8080 return getCouldNotCompute();
8081
8082 return getUMinFromMismatchedTypes(InferCountColl);
8083}
8084
8085unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8086 SmallVector<BasicBlock *, 8> ExitingBlocks;
8087 L->getExitingBlocks(ExitingBlocks);
8088
8089 Optional<unsigned> Res;
8090 for (auto *ExitingBB : ExitingBlocks) {
8091 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8092 if (!Res)
8093 Res = Multiple;
8094 Res = (unsigned)std::gcd(*Res, Multiple);
8095 }
8096 return Res.value_or(1);
8097}
8098
8099unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8100 const SCEV *ExitCount) {
8101 if (ExitCount == getCouldNotCompute())
8102 return 1;
8103
8104 // Get the trip count
8105 const SCEV *TCExpr = getTripCountFromExitCount(ExitCount);
8106
8107 const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
8108 if (!TC)
8109 // Attempt to factor more general cases. Returns the greatest power of
8110 // two divisor. If overflow happens, the trip count expression is still
8111 // divisible by the greatest power of 2 divisor returned.
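 // E.g. a trip count expression known to be a multiple of 8 (three trailing
 // zero bits) yields a multiple of 8 here.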
8112 return 1U << std::min((uint32_t)31,
8113 GetMinTrailingZeros(applyLoopGuards(TCExpr, L)));
8114
8115 ConstantInt *Result = TC->getValue();
8116
8117 // Guard against huge trip counts (this requires checking
8118 // for zero to handle the case where the trip count == -1 and the
8119 // addition wraps).
8120 if (!Result || Result->getValue().getActiveBits() > 32 ||
8121 Result->getValue().getActiveBits() == 0)
8122 return 1;
8123
8124 return (unsigned)Result->getZExtValue();
8125}
8126
8127/// Returns the largest constant divisor of the trip count of this loop as a
8128/// normal unsigned value, if possible. This means that the actual trip count is
8129/// always a multiple of the returned value (don't forget the trip count could
8130/// very well be zero as well!).
8131///
8132/// Returns 1 if the trip count is unknown or not guaranteed to be the
8133/// multiple of a constant (which is also the case if the trip count is simply
8134/// constant, use getSmallConstantTripCount for that case), Will also return 1
8135/// if the trip count is very large (>= 2^32).
8136///
8137/// As explained in the comments for getSmallConstantTripCount, this assumes
8138/// that control exits the loop via ExitingBlock.
8139unsigned
8140ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8141 const BasicBlock *ExitingBlock) {
8142 assert(ExitingBlock && "Must pass a non-null exiting block!");
8143 assert(L->isLoopExiting(ExitingBlock) &&
8144        "Exiting block must actually branch out of the loop!");
8145 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8146 return getSmallConstantTripMultiple(L, ExitCount);
8147}
8148
8149const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8150 const BasicBlock *ExitingBlock,
8151 ExitCountKind Kind) {
8152 switch (Kind) {
8153 case Exact:
8154 case SymbolicMaximum:
8155 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8156 case ConstantMaximum:
8157 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8158 };
8159 llvm_unreachable("Invalid ExitCountKind!");
8160}
8161
8162const SCEV *
8163ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
8164 SmallVector<const SCEVPredicate *, 4> &Preds) {
8165 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8166}
8167
8168const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8169 ExitCountKind Kind) {
8170 switch (Kind) {
8171 case Exact:
8172 return getBackedgeTakenInfo(L).getExact(L, this);
8173 case ConstantMaximum:
8174 return getBackedgeTakenInfo(L).getConstantMax(this);
8175 case SymbolicMaximum:
8176 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8177 };
8178 llvm_unreachable("Invalid ExitCountKind!");
8179}
8180
8181bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8182 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8183}
8184
8185/// Push PHI nodes in the header of the given loop onto the given Worklist.
8186static void PushLoopPHIs(const Loop *L,
8187 SmallVectorImpl<Instruction *> &Worklist,
8188 SmallPtrSetImpl<Instruction *> &Visited) {
8189 BasicBlock *Header = L->getHeader();
8190
8191 // Push all Loop-header PHIs onto the Worklist stack.
8192 for (PHINode &PN : Header->phis())
8193 if (Visited.insert(&PN).second)
8194 Worklist.push_back(&PN);
8195}
8196
8197const ScalarEvolution::BackedgeTakenInfo &
8198ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8199 auto &BTI = getBackedgeTakenInfo(L);
8200 if (BTI.hasFullInfo())
8201 return BTI;
8202
8203 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8204
8205 if (!Pair.second)
8206 return Pair.first->second;
8207
8208 BackedgeTakenInfo Result =
8209 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8210
8211 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8212}
8213
8214ScalarEvolution::BackedgeTakenInfo &
8215ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8216 // Initially insert an invalid entry for this loop. If the insertion
8217 // succeeds, proceed to actually compute a backedge-taken count and
8218 // update the value. The temporary CouldNotCompute value tells SCEV
8219 // code elsewhere that it shouldn't attempt to request a new
8220 // backedge-taken count, which could result in infinite recursion.
8221 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8222 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8223 if (!Pair.second)
8224 return Pair.first->second;
8225
8226 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8227 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8228 // must be cleared in this scope.
8229 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8230
8231 // In a product build, there is no use of statistics.
8232 (void)NumTripCountsComputed;
8233 (void)NumTripCountsNotComputed;
8234#if LLVM_ENABLE_STATS || !defined(NDEBUG)
8235 const SCEV *BEExact = Result.getExact(L, this);
8236 if (BEExact != getCouldNotCompute()) {
8237 assert(isLoopInvariant(BEExact, L) &&
8238        isLoopInvariant(Result.getConstantMax(this), L) &&
8239        "Computed backedge-taken count isn't loop invariant for loop!");
8240 ++NumTripCountsComputed;
8241 } else if (Result.getConstantMax(this) == getCouldNotCompute() &&
8242 isa<PHINode>(L->getHeader()->begin())) {
8243 // Only count loops that have phi nodes as not being computable.
8244 ++NumTripCountsNotComputed;
8245 }
8246#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
8247
8248 // Now that we know more about the trip count for this loop, forget any
8249 // existing SCEV values for PHI nodes in this loop since they are only
8250 // conservative estimates made without the benefit of trip count
8251 // information. This invalidation is not necessary for correctness, and is
8252 // only done to produce more precise results.
8253 if (Result.hasAnyInfo()) {
8254 // Invalidate any expression using an addrec in this loop.
8255 SmallVector<const SCEV *, 8> ToForget;
8256 auto LoopUsersIt = LoopUsers.find(L);
8257 if (LoopUsersIt != LoopUsers.end())
8258 append_range(ToForget, LoopUsersIt->second);
8259 forgetMemoizedResults(ToForget);
8260
8261 // Invalidate constant-evolved loop header phis.
8262 for (PHINode &PN : L->getHeader()->phis())
8263 ConstantEvolutionLoopExitValue.erase(&PN);
8264 }
8265
8266 // Re-lookup the insert position, since the call to
8267 // computeBackedgeTakenCount above could result in a
8268 // recursive call to getBackedgeTakenInfo (on a different
8269 // loop), which would invalidate the iterator computed
8270 // earlier.
8271 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8272}
8273
8274void ScalarEvolution::forgetAllLoops() {
8275 // This method is intended to forget all info about loops. It should
8276 // invalidate caches as if the following happened:
8277 // - The trip counts of all loops have changed arbitrarily
8278 // - Every llvm::Value has been updated in place to produce a different
8279 // result.
8280 BackedgeTakenCounts.clear();
8281 PredicatedBackedgeTakenCounts.clear();
8282 BECountUsers.clear();
8283 LoopPropertiesCache.clear();
8284 ConstantEvolutionLoopExitValue.clear();
8285 ValueExprMap.clear();
8286 ValuesAtScopes.clear();
8287 ValuesAtScopesUsers.clear();
8288 LoopDispositions.clear();
8289 BlockDispositions.clear();
8290 UnsignedRanges.clear();
8291 SignedRanges.clear();
8292 ExprValueMap.clear();
8293 HasRecMap.clear();
8294 MinTrailingZerosCache.clear();
8295 PredicatedSCEVRewrites.clear();
8296}
8297
8298void ScalarEvolution::forgetLoop(const Loop *L) {
8299 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8300 SmallVector<Instruction *, 32> Worklist;
8301 SmallPtrSet<Instruction *, 16> Visited;
8302 SmallVector<const SCEV *, 16> ToForget;
8303
8304 // Iterate over all the loops and sub-loops to drop SCEV information.
8305 while (!LoopWorklist.empty()) {
8306 auto *CurrL = LoopWorklist.pop_back_val();
8307
8308 // Drop any stored trip count value.
8309 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8310 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8311
8312 // Drop information about predicated SCEV rewrites for this loop.
8313 for (auto I = PredicatedSCEVRewrites.begin();
8314 I != PredicatedSCEVRewrites.end();) {
8315 std::pair<const SCEV *, const Loop *> Entry = I->first;
8316 if (Entry.second == CurrL)
8317 PredicatedSCEVRewrites.erase(I++);
8318 else
8319 ++I;
8320 }
8321
8322 auto LoopUsersItr = LoopUsers.find(CurrL);
8323 if (LoopUsersItr != LoopUsers.end()) {
8324 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8325 LoopUsersItr->second.end());
8326 }
8327
8328 // Drop information about expressions based on loop-header PHIs.
8329 PushLoopPHIs(CurrL, Worklist, Visited);
8330
8331 while (!Worklist.empty()) {
8332 Instruction *I = Worklist.pop_back_val();
8333
8334 ValueExprMapType::iterator It =
8335 ValueExprMap.find_as(static_cast<Value *>(I));
8336 if (It != ValueExprMap.end()) {
8337 eraseValueFromMap(It->first);
8338 ToForget.push_back(It->second);
8339 if (PHINode *PN = dyn_cast<PHINode>(I))
8340 ConstantEvolutionLoopExitValue.erase(PN);
8341 }
8342
8343 PushDefUseChildren(I, Worklist, Visited);
8344 }
8345
8346 LoopPropertiesCache.erase(CurrL);
8347 // Forget all contained loops too, to avoid dangling entries in the
8348 // ValuesAtScopes map.
8349 LoopWorklist.append(CurrL->begin(), CurrL->end());
8350 }
8351 forgetMemoizedResults(ToForget);
8352}
8353
8354void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8355 forgetLoop(L->getOutermostLoop());
8356}
8357
8358void ScalarEvolution::forgetValue(Value *V) {
8359 Instruction *I = dyn_cast<Instruction>(V);
8360 if (!I) return;
8361
8362 // Drop information about expressions based on loop-header PHIs.
8363 SmallVector<Instruction *, 16> Worklist;
8364 SmallPtrSet<Instruction *, 8> Visited;
8365 SmallVector<const SCEV *, 8> ToForget;
8366 Worklist.push_back(I);
8367 Visited.insert(I);
8368
8369 while (!Worklist.empty()) {
8370 I = Worklist.pop_back_val();
8371 ValueExprMapType::iterator It =
8372 ValueExprMap.find_as(static_cast<Value *>(I));
8373 if (It != ValueExprMap.end()) {
8374 eraseValueFromMap(It->first);
8375 ToForget.push_back(It->second);
8376 if (PHINode *PN = dyn_cast<PHINode>(I))
8377 ConstantEvolutionLoopExitValue.erase(PN);
8378 }
8379
8380 PushDefUseChildren(I, Worklist, Visited);
8381 }
8382 forgetMemoizedResults(ToForget);
8383}
8384
8385void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8386
8387void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8388 // Unless a specific value is passed to invalidation, completely clear both
8389 // caches.
8390 if (!V) {
8391 BlockDispositions.clear();
8392 LoopDispositions.clear();
8393 return;
8394 }
8395
8396 if (!isSCEVable(V->getType()))
8397 return;
8398
8399 const SCEV *S = getExistingSCEV(V);
8400 if (!S)
8401 return;
8402
8403 // Invalidate the block and loop dispositions cached for S. Dispositions of
8404 // S's users may change if S's disposition changes (i.e. a user may change to
8405 // loop-invariant, if S changes to loop invariant), so also invalidate
8406 // dispositions of S's users recursively.
8407 SmallVector<const SCEV *, 8> Worklist = {S};
8408 SmallPtrSet<const SCEV *, 8> Seen = {S};
8409 while (!Worklist.empty()) {
8410 const SCEV *Curr = Worklist.pop_back_val();
8411 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8412 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8413 if (!LoopDispoRemoved && !BlockDispoRemoved)
8414 continue;
8415 auto Users = SCEVUsers.find(Curr);
8416 if (Users != SCEVUsers.end())
8417 for (const auto *User : Users->second)
8418 if (Seen.insert(User).second)
8419 Worklist.push_back(User);
8420 }
8421}
8422
8423/// Get the exact loop backedge taken count considering all loop exits. A
8424/// computable result can only be returned for loops with all exiting blocks
8425/// dominating the latch. howFarToZero assumes that the limit of each loop test
8426/// is never skipped. This is a valid assumption as long as the loop exits via
8427/// that test. For precise results, it is the caller's responsibility to specify
8428/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8429const SCEV *
8430ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8431 SmallVector<const SCEVPredicate *, 4> *Preds) const {
8432 // If any exits were not computable, the loop is not computable.
8433 if (!isComplete() || ExitNotTaken.empty())
8434 return SE->getCouldNotCompute();
8435
8436 const BasicBlock *Latch = L->getLoopLatch();
8437 // All exiting blocks we have collected must dominate the only backedge.
8438 if (!Latch)
8439 return SE->getCouldNotCompute();
8440
8441 // All exiting blocks we have gathered dominate the loop's latch, so the
8442 // exact trip count is simply the minimum of all the calculated exit counts.
8443 SmallVector<const SCEV *, 2> Ops;
8444 for (const auto &ENT : ExitNotTaken) {
8445 const SCEV *BECount = ENT.ExactNotTaken;
8446 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8447 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8448        "We should only have known counts for exiting blocks that dominate "
8449        "latch!");
8450
8451 Ops.push_back(BECount);
8452
8453 if (Preds)
8454 for (const auto *P : ENT.Predicates)
8455 Preds->push_back(P);
8456
8457 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8458        "Predicate should be always true!");
8459 }
8460
8461 // If an earlier exit exits on the first iteration (exit count zero), then
8462 // a later poison exit count should not propagate into the result. These are
8463 // exactly the semantics provided by umin_seq.
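 // E.g. umin_seq(0, poison) is 0, whereas a plain umin would be poison.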
8464 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8465}
8466
8467/// Get the exact not taken count for this loop exit.
8468const SCEV *
8469ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8470 ScalarEvolution *SE) const {
8471 for (const auto &ENT : ExitNotTaken)
8472 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8473 return ENT.ExactNotTaken;
8474
8475 return SE->getCouldNotCompute();
8476}
8477
8478const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8479 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8480 for (const auto &ENT : ExitNotTaken)
8481 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8482 return ENT.MaxNotTaken;
8483
8484 return SE->getCouldNotCompute();
8485}
8486
8487/// getConstantMax - Get the constant max backedge taken count for the loop.
8488const SCEV *
8489ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8490 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8491 return !ENT.hasAlwaysTruePredicate();
8492 };
8493
8494 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8495 return SE->getCouldNotCompute();
8496
8497 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8498         isa<SCEVConstant>(getConstantMax())) &&
8499        "No point in having a non-constant max backedge taken count!");
8500 return getConstantMax();
8501}
8502
8503const SCEV *
8504ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8505 ScalarEvolution *SE) {
8506 if (!SymbolicMax)
8507 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8508 return SymbolicMax;
8509}
8510
8511bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8512 ScalarEvolution *SE) const {
8513 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8514 return !ENT.hasAlwaysTruePredicate();
8515 };
8516 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8517}
8518
8519ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8520 : ExitLimit(E, E, false, None) {
8521}
8522
8523ScalarEvolution::ExitLimit::ExitLimit(
8524 const SCEV *E, const SCEV *M, bool MaxOrZero,
8525 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8526 : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) {
8527 // If we prove the max count is zero, so is the symbolic bound. This happens
8528 // in practice due to differences in a) how context sensitive we've chosen
8529 // to be and b) how we reason about bounds implied by UB.
8530 if (MaxNotTaken->isZero())
8531 ExactNotTaken = MaxNotTaken;
8532
8533  assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8534          !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
8535         "Exact is not allowed to be less precise than Max");
8536  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
8537          isa<SCEVConstant>(MaxNotTaken)) &&
8538         "No point in having a non-constant max backedge taken count!");
8539 for (const auto *PredSet : PredSetList)
8540 for (const auto *P : *PredSet)
8541 addPredicate(P);
8542  assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8543         "Backedge count should be int");
8544  assert((isa<SCEVCouldNotCompute>(M) || !M->getType()->isPointerTy()) &&
8545         "Max backedge count should be int");
8546}
8547
8548ScalarEvolution::ExitLimit::ExitLimit(
8549 const SCEV *E, const SCEV *M, bool MaxOrZero,
8550 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8551 : ExitLimit(E, M, MaxOrZero, {&PredSet}) {
8552}
8553
8554ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
8555 bool MaxOrZero)
8556 : ExitLimit(E, M, MaxOrZero, None) {
8557}
8558
8559/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8560/// computable exit into a persistent ExitNotTakenInfo array.
8561ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8562 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8563 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8564 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8565 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8566
8567 ExitNotTaken.reserve(ExitCounts.size());
8568 std::transform(
8569 ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
8570 [&](const EdgeExitInfo &EEI) {
8571 BasicBlock *ExitBB = EEI.first;
8572 const ExitLimit &EL = EEI.second;
8573 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
8574 EL.Predicates);
8575 });
8576  assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8577          isa<SCEVConstant>(ConstantMax)) &&
8578         "No point in having a non-constant max backedge taken count!");
8579}
8580
8581/// Compute the number of times the backedge of the specified loop will execute.
8582ScalarEvolution::BackedgeTakenInfo
8583ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8584 bool AllowPredicates) {
8585 SmallVector<BasicBlock *, 8> ExitingBlocks;
8586 L->getExitingBlocks(ExitingBlocks);
8587
8588 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8589
8590 SmallVector<EdgeExitInfo, 4> ExitCounts;
8591 bool CouldComputeBECount = true;
8592 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8593 const SCEV *MustExitMaxBECount = nullptr;
8594 const SCEV *MayExitMaxBECount = nullptr;
8595 bool MustExitMaxOrZero = false;
8596
8597 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8598 // and compute maxBECount.
8599 // Do a union of all the predicates here.
8600 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8601 BasicBlock *ExitBB = ExitingBlocks[i];
8602
8603 // We canonicalize untaken exits to br (constant), ignore them so that
8604 // proving an exit untaken doesn't negatively impact our ability to reason
8605 // about the loop as a whole.
8606 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8607 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8608 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8609 if (ExitIfTrue == CI->isZero())
8610 continue;
8611 }
8612
8613 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8614
8615    assert((AllowPredicates || EL.Predicates.empty()) &&
8616           "Predicated exit limit when predicates are not allowed!");
8617
8618 // 1. For each exit that can be computed, add an entry to ExitCounts.
8619 // CouldComputeBECount is true only if all exits can be computed.
8620 if (EL.ExactNotTaken == getCouldNotCompute())
8621 // We couldn't compute an exact value for this exit, so
8622 // we won't be able to compute an exact value for the loop.
8623 CouldComputeBECount = false;
8624 else
8625 ExitCounts.emplace_back(ExitBB, EL);
8626
8627 // 2. Derive the loop's MaxBECount from each exit's max number of
8628 // non-exiting iterations. Partition the loop exits into two kinds:
8629 // LoopMustExits and LoopMayExits.
8630 //
8631 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8632 // is a LoopMayExit. If any computable LoopMustExit is found, then
8633 // MaxBECount is the minimum EL.MaxNotTaken of computable
8634 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8635 // EL.MaxNotTaken, where CouldNotCompute is considered greater than any
8636 // computable EL.MaxNotTaken.
8637 if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&
8638 DT.dominates(ExitBB, Latch)) {
8639 if (!MustExitMaxBECount) {
8640 MustExitMaxBECount = EL.MaxNotTaken;
8641 MustExitMaxOrZero = EL.MaxOrZero;
8642 } else {
8643 MustExitMaxBECount =
8644 getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);
8645 }
8646 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8647 if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute())
8648 MayExitMaxBECount = EL.MaxNotTaken;
8649 else {
8650 MayExitMaxBECount =
8651 getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);
8652 }
8653 }
8654 }
8655 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8656 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8657 // The loop backedge will be taken the maximum or zero times if there's
8658 // a single exit that must be taken the maximum or zero times.
8659 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8660
8661 // Remember which SCEVs are used in exit limits for invalidation purposes.
8662 // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
8663 // and MaxBECount, which must be SCEVConstant.
8664 for (const auto &Pair : ExitCounts)
8665 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8666 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8667 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8668 MaxBECount, MaxOrZero);
8669}
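// Illustrative sketch, not part of ScalarEvolution.cpp: a loop shape that
// exercises the must/may-exit partition above (names hypothetical):
//   for (unsigned i = 0; i != n; ++i) {   // this test's block dominates the
//     if (data[i] == key)                 //   latch: a LoopMustExit
//       break;                            // data-dependent: a LoopMayExit
//   }
// The "i != n" exit alone bounds MaxBECount by n; the break can only make the
// exact backedge-taken count smaller, never larger.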
8670
8671ScalarEvolution::ExitLimit
8672ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8673 bool AllowPredicates) {
8674  assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8675 // If our exiting block does not dominate the latch, then its connection with
8676 // the loop's exit limit may be far from trivial.
8677 const BasicBlock *Latch = L->getLoopLatch();
8678 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8679 return getCouldNotCompute();
8680
8681 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8682 Instruction *Term = ExitingBlock->getTerminator();
8683 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8684    assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8685 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8686    assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8687           "It should have one successor in loop and one exit block!");
8688 // Proceed to the next level to examine the exit condition expression.
8689 return computeExitLimitFromCond(
8690 L, BI->getCondition(), ExitIfTrue,
8691 /*ControlsExit=*/IsOnlyExit, AllowPredicates);
8692 }
8693
8694 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8695 // For switch, make sure that there is a single exit from the loop.
8696 BasicBlock *Exit = nullptr;
8697 for (auto *SBB : successors(ExitingBlock))
8698 if (!L->contains(SBB)) {
8699 if (Exit) // Multiple exit successors.
8700 return getCouldNotCompute();
8701 Exit = SBB;
8702 }
8703    assert(Exit && "Exiting block must have at least one exit");
8704 return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
8705 /*ControlsExit=*/IsOnlyExit);
8706 }
8707
8708 return getCouldNotCompute();
8709}
8710
8711ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8712 const Loop *L, Value *ExitCond, bool ExitIfTrue,
8713 bool ControlsExit, bool AllowPredicates) {
8714 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8715 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8716 ControlsExit, AllowPredicates);
8717}
8718
8719Optional<ScalarEvolution::ExitLimit>
8720ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8721 bool ExitIfTrue, bool ControlsExit,
8722 bool AllowPredicates) {
8723 (void)this->L;
8724 (void)this->ExitIfTrue;
8725 (void)this->AllowPredicates;
8726
8727  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8728         this->AllowPredicates == AllowPredicates &&
8729         "Variance in assumed invariant key components!");
8730 auto Itr = TripCountMap.find({ExitCond, ControlsExit});
8731 if (Itr == TripCountMap.end())
8732 return None;
8733 return Itr->second;
8734}
8735
8736void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8737 bool ExitIfTrue,
8738 bool ControlsExit,
8739 bool AllowPredicates,
8740 const ExitLimit &EL) {
8741  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8742         this->AllowPredicates == AllowPredicates &&
8743         "Variance in assumed invariant key components!");
8744
8745 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
8746  assert(InsertResult.second && "Expected successful insertion!");
8747 (void)InsertResult;
8748 (void)ExitIfTrue;
8749}
8750
8751ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8752 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8753 bool ControlsExit, bool AllowPredicates) {
8754
8755 if (auto MaybeEL =
8756 Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8757 return *MaybeEL;
8758
8759 ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
8760 ControlsExit, AllowPredicates);
8761 Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
8762 return EL;
8763}
8764
8765ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8766 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8767 bool ControlsExit, bool AllowPredicates) {
8768 // Handle BinOp conditions (And, Or).
8769 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8770 Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
8771 return *LimitFromBinOp;
8772
8773 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8774 // Proceed to the next level to examine the icmp.
8775 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8776 ExitLimit EL =
8777 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
8778 if (EL.hasFullInfo() || !AllowPredicates)
8779 return EL;
8780
8781 // Try again, but use SCEV predicates this time.
8782 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
8783 /*AllowPredicates=*/true);
8784 }
8785
8786 // Check for a constant condition. These are normally stripped out by
8787 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8788 // preserve the CFG and is temporarily leaving constant conditions
8789 // in place.
8790 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8791 if (ExitIfTrue == !CI->getZExtValue())
8792 // The backedge is always taken.
8793 return getCouldNotCompute();
8794 else
8795 // The backedge is never taken.
8796 return getZero(CI->getType());
8797 }
8798
8799 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8800 // with a constant step, we can form an equivalent icmp predicate and figure
8801 // out how many iterations will be taken before we exit.
8802 const WithOverflowInst *WO;
8803 const APInt *C;
8804 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8805 match(WO->getRHS(), m_APInt(C))) {
8806 ConstantRange NWR =
8807 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8808 WO->getNoWrapKind());
8809 CmpInst::Predicate Pred;
8810 APInt NewRHSC, Offset;
8811 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8812 if (!ExitIfTrue)
8813 Pred = ICmpInst::getInversePredicate(Pred);
8814 auto *LHS = getSCEV(WO->getLHS());
8815 if (Offset != 0)
8816 LHS = getAddExpr(LHS, getConstant(Offset));
8817 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8818 ControlsExit, AllowPredicates);
8819 if (EL.hasAnyInfo()) return EL;
8820 }
8821
8822 // If it's not an integer or pointer comparison then compute it the hard way.
8823 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8824}
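// Illustrative sketch, not part of ScalarEvolution.cpp: exiting on the
// overflow bit of an x.with.overflow intrinsic, e.g.
//   %s  = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %iv, i8 1)
//   %ov = extractvalue { i8, i1 } %s, 1
//   br i1 %ov, label %exit, label %loop
// is turned above into an equivalent icmp: "uadd %iv, 1" wraps exactly when
// %iv == 255, so the exit condition becomes an ordinary compare that
// computeExitLimitFromICmp can analyze.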
8825
8826Optional<ScalarEvolution::ExitLimit>
8827ScalarEvolution::computeExitLimitFromCondFromBinOp(
8828 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8829 bool ControlsExit, bool AllowPredicates) {
8830 // Check if the controlling expression for this loop is an And or Or.
8831 Value *Op0, *Op1;
8832 bool IsAnd = false;
8833 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
8834 IsAnd = true;
8835 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
8836 IsAnd = false;
8837 else
8838 return None;
8839
8840 // EitherMayExit is true in these two cases:
8841 // br (and Op0 Op1), loop, exit
8842 // br (or Op0 Op1), exit, loop
8843 bool EitherMayExit = IsAnd ^ ExitIfTrue;
8844 ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
8845 ControlsExit && !EitherMayExit,
8846 AllowPredicates);
8847 ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue,
8848 ControlsExit && !EitherMayExit,
8849 AllowPredicates);
8850
8851 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
8852 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
8853 if (isa<ConstantInt>(Op1))
8854 return Op1 == NeutralElement ? EL0 : EL1;
8855 if (isa<ConstantInt>(Op0))
8856 return Op0 == NeutralElement ? EL1 : EL0;
8857
8858 const SCEV *BECount = getCouldNotCompute();
8859 const SCEV *MaxBECount = getCouldNotCompute();
8860 if (EitherMayExit) {
8861 // Both conditions must be the same for the loop to continue executing.
8862 // Choose the less conservative count.
8863 if (EL0.ExactNotTaken != getCouldNotCompute() &&
8864 EL1.ExactNotTaken != getCouldNotCompute()) {
8865 BECount = getUMinFromMismatchedTypes(
8866 EL0.ExactNotTaken, EL1.ExactNotTaken,
8867 /*Sequential=*/!isa<BinaryOperator>(ExitCond));
8868 }
8869 if (EL0.MaxNotTaken == getCouldNotCompute())
8870 MaxBECount = EL1.MaxNotTaken;
8871 else if (EL1.MaxNotTaken == getCouldNotCompute())
8872 MaxBECount = EL0.MaxNotTaken;
8873 else
8874 MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
8875 } else {
8876 // Both conditions must be the same at the same time for the loop to exit.
8877 // For now, be conservative.
8878 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
8879 BECount = EL0.ExactNotTaken;
8880 }
8881
8882 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
8883 // to be more aggressive when computing BECount than when computing
8884 // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
8885 // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
8886 // to not.
8887 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
8888 !isa<SCEVCouldNotCompute>(BECount))
8889 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
8890
8891 return ExitLimit(BECount, MaxBECount, false,
8892 { &EL0.Predicates, &EL1.Predicates });
8893}
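// Illustrative sketch, not part of ScalarEvolution.cpp: for
//   br i1 (and i1 %a, %b), label %loop, label %exit
// the loop keeps running only while both %a and %b hold, so whichever operand
// fails first ends the loop and the per-operand counts are combined with a
// (sequential) umin; for
//   br i1 (or i1 %a, %b), label %exit, label %loop
// the roles mirror, which is why EitherMayExit is computed above as
// IsAnd ^ ExitIfTrue.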
8894
8895ScalarEvolution::ExitLimit
8896ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8897 ICmpInst *ExitCond,
8898 bool ExitIfTrue,
8899 bool ControlsExit,
8900 bool AllowPredicates) {
8901 // If the condition was exit on true, convert the condition to exit on false
8902 ICmpInst::Predicate Pred;
8903 if (!ExitIfTrue)
8904 Pred = ExitCond->getPredicate();
8905 else
8906 Pred = ExitCond->getInversePredicate();
8907 const ICmpInst::Predicate OriginalPred = Pred;
8908
8909 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
8910 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
8911
8912 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsExit,
8913 AllowPredicates);
8914 if (EL.hasAnyInfo()) return EL;
8915
8916 auto *ExhaustiveCount =
8917 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8918
8919 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
8920 return ExhaustiveCount;
8921
8922 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
8923 ExitCond->getOperand(1), L, OriginalPred);
8924}
8925ScalarEvolution::ExitLimit
8926ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
8927 ICmpInst::Predicate Pred,
8928 const SCEV *LHS, const SCEV *RHS,
8929 bool ControlsExit,
8930 bool AllowPredicates) {
8931
8932 // Try to evaluate any dependencies out of the loop.
8933 LHS = getSCEVAtScope(LHS, L);
8934 RHS = getSCEVAtScope(RHS, L);
8935
8936 // At this point, we would like to compute how many iterations of the
8937 // loop the predicate will return true for these inputs.
8938 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
8939 // If there is a loop-invariant, force it into the RHS.
8940 std::swap(LHS, RHS);
8941 Pred = ICmpInst::getSwappedPredicate(Pred);
8942 }
8943
8944 bool ControllingFiniteLoop =
8945 ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L);
8946 // Simplify the operands before analyzing them.
8947 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0,
8948 (EnableFiniteLoopControl ? ControllingFiniteLoop
8949 : false));
8950
8951 // If we have a comparison of a chrec against a constant, try to use value
8952 // ranges to answer this query.
8953 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
8954 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
8955 if (AddRec->getLoop() == L) {
8956 // Form the constant range.
8957 ConstantRange CompRange =
8958 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
8959
8960 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
8961 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
8962 }
8963
8964 // If this loop must exit based on this condition (or execute undefined
8965 // behaviour), and we can prove the test sequence produced must repeat
8966 // the same values on self-wrap of the IV, then we can infer that IV
8967 // doesn't self wrap because if it did, we'd have an infinite (undefined)
8968 // loop.
8969 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
8970 // TODO: We can peel off any functions which are invertible *in L*. Loop
8971 // invariant terms are effectively constants for our purposes here.
8972 auto *InnerLHS = LHS;
8973 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
8974 InnerLHS = ZExt->getOperand();
8975 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
8976 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
8977 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
8978 StrideC && StrideC->getAPInt().isPowerOf2()) {
8979 auto Flags = AR->getNoWrapFlags();
8980 Flags = setFlags(Flags, SCEV::FlagNW);
8981 SmallVector<const SCEV*> Operands{AR->operands()};
8982 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
8983 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
8984 }
8985 }
8986 }
8987
8988 switch (Pred) {
8989 case ICmpInst::ICMP_NE: { // while (X != Y)
8990 // Convert to: while (X-Y != 0)
8991 if (LHS->getType()->isPointerTy()) {
8992 LHS = getLosslessPtrToIntExpr(LHS);
8993 if (isa<SCEVCouldNotCompute>(LHS))
8994 return LHS;
8995 }
8996 if (RHS->getType()->isPointerTy()) {
8997 RHS = getLosslessPtrToIntExpr(RHS);
8998 if (isa<SCEVCouldNotCompute>(RHS))
8999 return RHS;
9000 }
9001 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
9002 AllowPredicates);
9003 if (EL.hasAnyInfo()) return EL;
9004 break;
9005 }
9006 case ICmpInst::ICMP_EQ: { // while (X == Y)
9007 // Convert to: while (X-Y == 0)
9008 if (LHS->getType()->isPointerTy()) {
9009 LHS = getLosslessPtrToIntExpr(LHS);
9010 if (isa<SCEVCouldNotCompute>(LHS))
9011 return LHS;
9012 }
9013 if (RHS->getType()->isPointerTy()) {
9014 RHS = getLosslessPtrToIntExpr(RHS);
9015 if (isa<SCEVCouldNotCompute>(RHS))
9016 return RHS;
9017 }
9018 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9019 if (EL.hasAnyInfo()) return EL;
9020 break;
9021 }
9022 case ICmpInst::ICMP_SLT:
9023 case ICmpInst::ICMP_ULT: { // while (X < Y)
9024 bool IsSigned = Pred == ICmpInst::ICMP_SLT;
9025 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
9026 AllowPredicates);
9027 if (EL.hasAnyInfo()) return EL;
9028 break;
9029 }
9030 case ICmpInst::ICMP_SGT:
9031 case ICmpInst::ICMP_UGT: { // while (X > Y)
9032 bool IsSigned = Pred == ICmpInst::ICMP_SGT;
9033 ExitLimit EL =
9034 howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
9035 AllowPredicates);
9036 if (EL.hasAnyInfo()) return EL;
9037 break;
9038 }
9039 default:
9040 break;
9041 }
9042
9043 return getCouldNotCompute();
9044}
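// Illustrative sketch, not part of ScalarEvolution.cpp: a condition such as
//   while (X != Y)
// is handled above as howFarToZero(X - Y). For example, with X = {0,+,1} and
// a loop-invariant Y = n, the difference is {-n,+,1}, which reaches zero
// after exactly n steps, giving a backedge-taken count of n.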
9045
9046ScalarEvolution::ExitLimit
9047ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9048 SwitchInst *Switch,
9049 BasicBlock *ExitingBlock,
9050 bool ControlsExit) {
9051  assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9052
9053 // Give up if the exit is the default dest of a switch.
9054 if (Switch->getDefaultDest() == ExitingBlock)
9055 return getCouldNotCompute();
9056
9057  assert(L->contains(Switch->getDefaultDest()) &&
9058         "Default case must not exit the loop!");
9059 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9060 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9061
9062 // while (X != Y) --> while (X-Y != 0)
9063 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
9064 if (EL.hasAnyInfo())
9065 return EL;
9066
9067 return getCouldNotCompute();
9068}
9069
9070static ConstantInt *
9071EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9072 ScalarEvolution &SE) {
9073 const SCEV *InVal = SE.getConstant(C);
9074 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9075  assert(isa<SCEVConstant>(Val) &&
9076         "Evaluation of SCEV at constant didn't fold correctly?");
9077 return cast<SCEVConstant>(Val)->getValue();
9078}
9079
9080ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9081 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9082 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9083 if (!RHS)
9084 return getCouldNotCompute();
9085
9086 const BasicBlock *Latch = L->getLoopLatch();
9087 if (!Latch)
9088 return getCouldNotCompute();
9089
9090 const BasicBlock *Predecessor = L->getLoopPredecessor();
9091 if (!Predecessor)
9092 return getCouldNotCompute();
9093
9094 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9095 // Return LHS in OutLHS and shift_op in OutOpCode.
9096 auto MatchPositiveShift =
9097 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9098
9099 using namespace PatternMatch;
9100
9101 ConstantInt *ShiftAmt;
9102 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9103 OutOpCode = Instruction::LShr;
9104 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9105 OutOpCode = Instruction::AShr;
9106 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9107 OutOpCode = Instruction::Shl;
9108 else
9109 return false;
9110
9111 return ShiftAmt->getValue().isStrictlyPositive();
9112 };
9113
9114 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9115 //
9116 // loop:
9117 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9118 // %iv.shifted = lshr i32 %iv, <positive constant>
9119 //
9120 // Return true on a successful match. Return the corresponding PHI node (%iv
9121 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9122 auto MatchShiftRecurrence =
9123 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9124 Optional<Instruction::BinaryOps> PostShiftOpCode;
9125
9126 {
9127 Instruction::BinaryOps OpC;
9128 Value *V;
9129
9130 // If we encounter a shift instruction, "peel off" the shift operation,
9131 // and remember that we did so. Later when we inspect %iv's backedge
9132 // value, we will make sure that the backedge value uses the same
9133 // operation.
9134 //
9135 // Note: the peeled shift operation does not have to be the same
9136 // instruction as the one feeding into the PHI's backedge value. We only
9137 // really care about it being the same *kind* of shift instruction --
9138 // that's all that is required for our later inferences to hold.
9139 if (MatchPositiveShift(LHS, V, OpC)) {
9140 PostShiftOpCode = OpC;
9141 LHS = V;
9142 }
9143 }
9144
9145 PNOut = dyn_cast<PHINode>(LHS);
9146 if (!PNOut || PNOut->getParent() != L->getHeader())
9147 return false;
9148
9149 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9150 Value *OpLHS;
9151
9152 return
9153 // The backedge value for the PHI node must be a shift by a positive
9154 // amount
9155 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9156
9157 // of the PHI node itself
9158 OpLHS == PNOut &&
9159
9160 // and the kind of shift should match the kind of shift we peeled
9161 // off, if any.
9162 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9163 };
9164
9165 PHINode *PN;
9166 Instruction::BinaryOps OpCode;
9167 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9168 return getCouldNotCompute();
9169
9170 const DataLayout &DL = getDataLayout();
9171
9172 // The key rationale for this optimization is that for some kinds of shift
9173 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9174 // within a finite number of iterations. If the condition guarding the
9175 // backedge (in the sense that the backedge is taken if the condition is true)
9176 // is false for the value the shift recurrence stabilizes to, then we know
9177 // that the backedge is taken only a finite number of times.
9178
9179 ConstantInt *StableValue = nullptr;
9180 switch (OpCode) {
9181 default:
9182    llvm_unreachable("Impossible case!");
9183
9184 case Instruction::AShr: {
9185 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9186 // bitwidth(K) iterations.
9187 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9188 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9189 Predecessor->getTerminator(), &DT);
9190 auto *Ty = cast<IntegerType>(RHS->getType());
9191 if (Known.isNonNegative())
9192 StableValue = ConstantInt::get(Ty, 0);
9193 else if (Known.isNegative())
9194 StableValue = ConstantInt::get(Ty, -1, true);
9195 else
9196 return getCouldNotCompute();
9197
9198 break;
9199 }
9200 case Instruction::LShr:
9201 case Instruction::Shl:
9202 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9203 // stabilize to 0 in at most bitwidth(K) iterations.
9204 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9205 break;
9206 }
9207
9208 auto *Result =
9209 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9210  assert(Result->getType()->isIntegerTy(1) &&
9211         "Otherwise cannot be an operand to a branch instruction");
9212
9213 if (Result->isZeroValue()) {
9214 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9215 const SCEV *UpperBound =
9216 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9217 return ExitLimit(getCouldNotCompute(), UpperBound, false);
9218 }
9219
9220 return getCouldNotCompute();
9221}
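// Illustrative sketch, not part of ScalarEvolution.cpp: a shift recurrence of
// the kind matched above (names hypothetical):
//   loop:
//     %iv     = phi i32 [ %start, %preheader ], [ %iv.shr, %loop ]
//     %iv.shr = ashr i32 %iv, 1
//     %c      = icmp sgt i32 %iv.shr, 0
//     br i1 %c, label %loop, label %exit
// {%start,ashr,1} stabilizes to signum(%start) within at most 32 iterations,
// so once %start is known non-negative the guard eventually goes false and 32
// is a safe (if loose) upper bound on the backedge-taken count.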
9222
9223/// Return true if we can constant fold an instruction of the specified type,
9224/// assuming that all operands were constants.
9225static bool CanConstantFold(const Instruction *I) {
9226 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9227 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9228 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9229 return true;
9230
9231 if (const CallInst *CI = dyn_cast<CallInst>(I))
9232 if (const Function *F = CI->getCalledFunction())
9233 return canConstantFoldCallTo(CI, F);
9234 return false;
9235}
9236
9237/// Determine whether this instruction can constant evolve within this loop
9238/// assuming its operands can all constant evolve.
9239static bool canConstantEvolve(Instruction *I, const Loop *L) {
9240 // An instruction outside of the loop can't be derived from a loop PHI.
9241 if (!L->contains(I)) return false;
9242
9243 if (isa<PHINode>(I)) {
9244 // We don't currently keep track of the control flow needed to evaluate
9245 // PHIs, so we cannot handle PHIs inside of loops.
9246 return L->getHeader() == I->getParent();
9247 }
9248
9249 // If we won't be able to constant fold this expression even if the operands
9250 // are constants, bail early.
9251 return CanConstantFold(I);
9252}
9253
9254/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9255/// recursing through each instruction operand until reaching a loop header phi.
9256static PHINode *
9257getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9258 DenseMap<Instruction *, PHINode *> &PHIMap,
9259 unsigned Depth) {
9260 if (Depth > MaxConstantEvolvingDepth)
9261 return nullptr;
9262
9263 // Otherwise, we can evaluate this instruction if all of its operands are
9264 // constant or derived from a PHI node themselves.
9265 PHINode *PHI = nullptr;
9266 for (Value *Op : UseInst->operands()) {
9267 if (isa<Constant>(Op)) continue;
9268
9269 Instruction *OpInst = dyn_cast<Instruction>(Op);
9270 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9271
9272 PHINode *P = dyn_cast<PHINode>(OpInst);
9273 if (!P)
9274 // If this operand is already visited, reuse the prior result.
9275 // We may have P != PHI if this is the deepest point at which the
9276 // inconsistent paths meet.
9277 P = PHIMap.lookup(OpInst);
9278 if (!P) {
9279 // Recurse and memoize the results, whether a phi is found or not.
9280 // This recursive call invalidates pointers into PHIMap.
9281 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9282 PHIMap[OpInst] = P;
9283 }
9284 if (!P)
9285 return nullptr; // Not evolving from PHI
9286 if (PHI && PHI != P)
9287 return nullptr; // Evolving from multiple different PHIs.
9288 PHI = P;
9289 }
9290 // This is an expression evolving from a constant PHI!
9291 return PHI;
9292}
9293
9294/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9295/// in the loop that V is derived from. We allow arbitrary operations along the
9296/// way, but the operands of an operation must either be constants or a value
9297/// derived from a constant PHI. If this expression does not fit with these
9298/// constraints, return null.
9299static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9300 Instruction *I = dyn_cast<Instruction>(V);
9301 if (!I || !canConstantEvolve(I, L)) return nullptr;
9302
9303 if (PHINode *PN = dyn_cast<PHINode>(I))
9304 return PN;
9305
9306 // Record non-constant instructions contained by the loop.
9307 DenseMap<Instruction *, PHINode *> PHIMap;
9308 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9309}
9310
9311/// EvaluateExpression - Given an expression that passes the
9312/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9313/// in the loop has the value PHIVal. If we can't fold this expression for some
9314/// reason, return null.
9315static Constant *EvaluateExpression(Value *V, const Loop *L,
9316 DenseMap<Instruction *, Constant *> &Vals,
9317 const DataLayout &DL,
9318 const TargetLibraryInfo *TLI) {
9319 // Convenient constant check, but redundant for recursive calls.
9320 if (Constant *C = dyn_cast<Constant>(V)) return C;
9321 Instruction *I = dyn_cast<Instruction>(V);
9322 if (!I) return nullptr;
9323
9324 if (Constant *C = Vals.lookup(I)) return C;
9325
9326 // An instruction inside the loop depends on a value outside the loop that we
9327 // weren't given a mapping for, or a value such as a call inside the loop.
9328 if (!canConstantEvolve(I, L)) return nullptr;
9329
9330 // An unmapped PHI can be due to a branch or another loop inside this loop,
9331 // or due to this not being the initial iteration through a loop where we
9332 // couldn't compute the evolution of this particular PHI last time.
9333 if (isa<PHINode>(I)) return nullptr;
9334
9335 std::vector<Constant*> Operands(I->getNumOperands());
9336
9337 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9338 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9339 if (!Operand) {
9340 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9341 if (!Operands[i]) return nullptr;
9342 continue;
9343 }
9344 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9345 Vals[Operand] = C;
9346 if (!C) return nullptr;
9347 Operands[i] = C;
9348 }
9349
9350 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9351}
9352
9353
9354// If every incoming value to PN except the one for BB is a specific Constant,
9355// return that, else return nullptr.
9356static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9357 Constant *IncomingVal = nullptr;
9358
9359 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9360 if (PN->getIncomingBlock(i) == BB)
9361 continue;
9362
9363 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9364 if (!CurrentVal)
9365 return nullptr;
9366
9367 if (IncomingVal != CurrentVal) {
9368 if (IncomingVal)
9369 return nullptr;
9370 IncomingVal = CurrentVal;
9371 }
9372 }
9373
9374 return IncomingVal;
9375}
9376
9377/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9378/// in the header of its containing loop, we know the loop executes a
9379/// constant number of times, and the PHI node is just a recurrence
9380/// involving constants, fold it.
9381Constant *
9382ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9383 const APInt &BEs,
9384 const Loop *L) {
9385 auto I = ConstantEvolutionLoopExitValue.find(PN);
9386 if (I != ConstantEvolutionLoopExitValue.end())
9387 return I->second;
9388
9389 if (BEs.ugt(MaxBruteForceIterations))
9390 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9391
9392 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9393
9394 DenseMap<Instruction *, Constant *> CurrentIterVals;
9395 BasicBlock *Header = L->getHeader();
9396  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9397
9398 BasicBlock *Latch = L->getLoopLatch();
9399 if (!Latch)
9400 return nullptr;
9401
9402 for (PHINode &PHI : Header->phis()) {
9403 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9404 CurrentIterVals[&PHI] = StartCST;
9405 }
9406 if (!CurrentIterVals.count(PN))
9407 return RetVal = nullptr;
9408
9409 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9410
9411 // Execute the loop symbolically to determine the exit value.
9412  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9413         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9414
9415 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9416 unsigned IterationNum = 0;
9417 const DataLayout &DL = getDataLayout();
9418 for (; ; ++IterationNum) {
9419 if (IterationNum == NumIterations)
9420 return RetVal = CurrentIterVals[PN]; // Got exit value!
9421
9422 // Compute the value of the PHIs for the next iteration.
9423 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9424 DenseMap<Instruction *, Constant *> NextIterVals;
9425 Constant *NextPHI =
9426 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9427 if (!NextPHI)
9428 return nullptr; // Couldn't evaluate!
9429 NextIterVals[PN] = NextPHI;
9430
9431 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9432
9433 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9434 // cease to be able to evaluate one of them or if they stop evolving,
9435 // because that doesn't necessarily prevent us from computing PN.
9436 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9437 for (const auto &I : CurrentIterVals) {
9438 PHINode *PHI = dyn_cast<PHINode>(I.first);
9439 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9440 PHIsToCompute.emplace_back(PHI, I.second);
9441 }
9442 // We use two distinct loops because EvaluateExpression may invalidate any
9443 // iterators into CurrentIterVals.
9444 for (const auto &I : PHIsToCompute) {
9445 PHINode *PHI = I.first;
9446 Constant *&NextPHI = NextIterVals[PHI];
9447 if (!NextPHI) { // Not already computed.
9448 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9449 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9450 }
9451 if (NextPHI != I.second)
9452 StoppedEvolving = false;
9453 }
9454
9455 // If all entries in CurrentIterVals == NextIterVals then we can stop
9456 // iterating, the loop can't continue to change.
9457 if (StoppedEvolving)
9458 return RetVal = CurrentIterVals[PN];
9459
9460 CurrentIterVals.swap(NextIterVals);
9461 }
9462}
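// Illustrative sketch, not part of ScalarEvolution.cpp: for a header PHI
//   %x      = phi i32 [ 1, %preheader ], [ %x.next, %latch ]
//   %x.next = mul i32 %x, 3
// and a known backedge-taken count of 4, the routine above replays the
// recurrence value by value (1, 3, 9, 27, 81) and returns the constant 81 as
// the exit value, without ever forming a SCEV for the closed form.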
9463
9464const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9465 Value *Cond,
9466 bool ExitWhen) {
9467 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9468 if (!PN) return getCouldNotCompute();
9469
9470 // If the loop is canonicalized, the PHI will have exactly two entries.
9471 // That's the only form we support here.
9472 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9473
9474 DenseMap<Instruction *, Constant *> CurrentIterVals;
9475 BasicBlock *Header = L->getHeader();
9476  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9477
9478 BasicBlock *Latch = L->getLoopLatch();
9479  assert(Latch && "Should follow from NumIncomingValues == 2!");
9480
9481 for (PHINode &PHI : Header->phis()) {
9482 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9483 CurrentIterVals[&PHI] = StartCST;
9484 }
9485 if (!CurrentIterVals.count(PN))
9486 return getCouldNotCompute();
9487
9488 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9489 // the loop symbolically to determine when the condition gets a value of
9490 // "ExitWhen".
9491 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9492 const DataLayout &DL = getDataLayout();
9493 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9494 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9495 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9496
9497 // Couldn't symbolically evaluate.
9498 if (!CondVal) return getCouldNotCompute();
9499
9500 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9501 ++NumBruteForceTripCountsComputed;
9502 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9503 }
9504
9505 // Update all the PHI nodes for the next iteration.
9506 DenseMap<Instruction *, Constant *> NextIterVals;
9507
9508 // Create a list of which PHIs we need to compute. We want to do this before
9509 // calling EvaluateExpression on them because that may invalidate iterators
9510 // into CurrentIterVals.
9511 SmallVector<PHINode *, 8> PHIsToCompute;
9512 for (const auto &I : CurrentIterVals) {
9513 PHINode *PHI = dyn_cast<PHINode>(I.first);
9514 if (!PHI || PHI->getParent() != Header) continue;
9515 PHIsToCompute.push_back(PHI);
9516 }
9517 for (PHINode *PHI : PHIsToCompute) {
9518 Constant *&NextPHI = NextIterVals[PHI];
9519 if (NextPHI) continue; // Already computed!
9520
9521 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9522 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9523 }
9524 CurrentIterVals.swap(NextIterVals);
9525 }
9526
9527 // Too many iterations were needed to evaluate.
9528 return getCouldNotCompute();
9529}
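// Illustrative sketch, not part of ScalarEvolution.cpp: with
//   %i      = phi i8 [ 0, %preheader ], [ %i.next, %loop ]
//   %i.next = add i8 %i, 7
//   %c      = icmp eq i8 %i.next, 35
//   br i1 %c, label %exit, label %loop     ; so ExitWhen is true
// the exhaustive evaluator above simulates %c one iteration at a time and
// finds it first becomes true on iteration 4, so it returns the constant 4.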
9530
9531const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9532 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9533 ValuesAtScopes[V];
9534 // Check to see if we've folded this expression at this loop before.
9535 for (auto &LS : Values)
9536 if (LS.first == L)
9537 return LS.second ? LS.second : V;
9538
9539 Values.emplace_back(L, nullptr);
9540
9541 // Otherwise compute it.
9542 const SCEV *C = computeSCEVAtScope(V, L);
9543 for (auto &LS : reverse(ValuesAtScopes[V]))
9544 if (LS.first == L) {
9545 LS.second = C;
9546 if (!isa<SCEVConstant>(C))
9547 ValuesAtScopesUsers[C].push_back({L, V});
9548 break;
9549 }
9550 return C;
9551}
9552
9553/// This builds up a Constant using the ConstantExpr interface. That way, we
9554/// will return Constants for objects which aren't represented by a
9555/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9556/// Returns NULL if the SCEV isn't representable as a Constant.
9557static Constant *BuildConstantFromSCEV(const SCEV *V) {
9558 switch (V->getSCEVType()) {
9559 case scCouldNotCompute:
9560 case scAddRecExpr:
9561 return nullptr;
9562 case scConstant:
9563 return cast<SCEVConstant>(V)->getValue();
9564 case scUnknown:
9565 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9566 case scSignExtend: {
9567 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
9568 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
9569 return ConstantExpr::getSExt(CastOp, SS->getType());
9570 return nullptr;
9571 }
9572 case scZeroExtend: {
9573 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
9574 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
9575 return ConstantExpr::getZExt(CastOp, SZ->getType());
9576 return nullptr;
9577 }
9578 case scPtrToInt: {
9579 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9580 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9581 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9582
9583 return nullptr;
9584 }
9585 case scTruncate: {
9586 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9587 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9588 return ConstantExpr::getTrunc(CastOp, ST->getType());
9589 return nullptr;
9590 }
9591 case scAddExpr: {
9592 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9593 Constant *C = nullptr;
9594 for (const SCEV *Op : SA->operands()) {
9595 Constant *OpC = BuildConstantFromSCEV(Op);
9596 if (!OpC)
9597 return nullptr;
9598 if (!C) {
9599 C = OpC;
9600 continue;
9601 }
9602      assert(!C->getType()->isPointerTy() &&
9603             "Can only have one pointer, and it must be last");
9604 if (auto *PT = dyn_cast<PointerType>(OpC->getType())) {
9605 // The offsets have been converted to bytes. We can add bytes to an
9606 // i8* by GEP with the byte count in the first index.
9607 Type *DestPtrTy =
9608 Type::getInt8PtrTy(PT->getContext(), PT->getAddressSpace());
9609 OpC = ConstantExpr::getBitCast(OpC, DestPtrTy);
9610 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9611 OpC, C);
9612 } else {
9613 C = ConstantExpr::getAdd(C, OpC);
9614 }
9615 }
9616 return C;
9617 }
9618 case scMulExpr: {
9619 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
9620 Constant *C = nullptr;
9621 for (const SCEV *Op : SM->operands()) {
9622 assert(!Op->getType()->isPointerTy() && "Can't multiply pointers");
9623 Constant *OpC = BuildConstantFromSCEV(Op);
9624 if (!OpC)
9625 return nullptr;
9626 C = C ? ConstantExpr::getMul(C, OpC) : OpC;
9627 }
9628 return C;
9629 }
9630 case scUDivExpr:
9631 case scSMaxExpr:
9632 case scUMaxExpr:
9633 case scSMinExpr:
9634 case scUMinExpr:
9635 case scSequentialUMinExpr:
9636 return nullptr; // TODO: smax, umax, smin, umin, umin_seq.
9637 }
9638 llvm_unreachable("Unknown SCEV kind!");
9639}
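// Editorial note (illustrative sketch, not part of the original source): for a
// hypothetical SCEV add expression (4 + %p), where %p is a pointer, the
// scAddExpr case above first accumulates the constant 4 into C and, on
// reaching the trailing pointer operand, builds roughly
//   getelementptr(i8, bitcast(%p to i8*), 4)
// instead of an integer add, since pointer operands cannot be fed to
// ConstantExpr::getAdd.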
9640
9641const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9642 if (isa<SCEVConstant>(V)) return V;
9643
9644 // If this instruction is evolved from a constant-evolving PHI, compute the
9645 // exit value from the loop without using SCEVs.
9646 if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
9647 if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
9648 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9649 const Loop *CurrLoop = this->LI[I->getParent()];
9650 // Looking for loop exit value.
9651 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9652 PN->getParent() == CurrLoop->getHeader()) {
9653 // Okay, there is no closed form solution for the PHI node. Check
9654 // to see if the loop that contains it has a known backedge-taken
9655 // count. If so, we may be able to force computation of the exit
9656 // value.
9657 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9658 // This trivial case can show up in some degenerate cases where
9659 // the incoming IR has not yet been fully simplified.
9660 if (BackedgeTakenCount->isZero()) {
9661 Value *InitValue = nullptr;
9662 bool MultipleInitValues = false;
9663 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9664 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9665 if (!InitValue)
9666 InitValue = PN->getIncomingValue(i);
9667 else if (InitValue != PN->getIncomingValue(i)) {
9668 MultipleInitValues = true;
9669 break;
9670 }
9671 }
9672 }
9673 if (!MultipleInitValues && InitValue)
9674 return getSCEV(InitValue);
9675 }
9676 // Do we have a loop invariant value flowing around the backedge
9677 // for a loop which must execute the backedge?
9678 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9679 isKnownPositive(BackedgeTakenCount) &&
9680 PN->getNumIncomingValues() == 2) {
9681
9682 unsigned InLoopPred =
9683 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9684 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9685 if (CurrLoop->isLoopInvariant(BackedgeVal))
9686 return getSCEV(BackedgeVal);
9687 }
9688 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9689 // Okay, we know how many times the containing loop executes. If
9690 // this is a constant evolving PHI node, get the final value at
9691 // the specified iteration number.
9692 Constant *RV = getConstantEvolutionLoopExitValue(
9693 PN, BTCC->getAPInt(), CurrLoop);
9694 if (RV) return getSCEV(RV);
9695 }
9696 }
9697
9698 // If there is a single-input Phi, evaluate it at our scope. If we can
9699 // prove that this replacement does not break LCSSA form, use new value.
9700 if (PN->getNumOperands() == 1) {
9701 const SCEV *Input = getSCEV(PN->getOperand(0));
9702 const SCEV *InputAtScope = getSCEVAtScope(Input, L);
9703 // TODO: We can generalize this using LI.replacementPreservesLCSSAForm;
9704 // for now, in the simplest case, just support constants.
9705 if (isa<SCEVConstant>(InputAtScope)) return InputAtScope;
9706 }
9707 }
9708
9709 // Okay, this is an expression that we cannot symbolically evaluate
9710 // into a SCEV. Check to see if it's possible to symbolically evaluate
9711 // the arguments into constants, and if so, try to constant propagate the
9712 // result. This is particularly useful for computing loop exit values.
9713 if (CanConstantFold(I)) {
9714 SmallVector<Constant *, 4> Operands;
9715 bool MadeImprovement = false;
9716 for (Value *Op : I->operands()) {
9717 if (Constant *C = dyn_cast<Constant>(Op)) {
9718 Operands.push_back(C);
9719 continue;
9720 }
9721
9722 // If an operand is non-constant and its type is neither integer
9723 // nor pointer (i.e. it is not SCEVable), don't even try to analyze
9724 // it with SCEV techniques.
9725 if (!isSCEVable(Op->getType()))
9726 return V;
9727
9728 const SCEV *OrigV = getSCEV(Op);
9729 const SCEV *OpV = getSCEVAtScope(OrigV, L);
9730 MadeImprovement |= OrigV != OpV;
9731
9732 Constant *C = BuildConstantFromSCEV(OpV);
9733 if (!C) return V;
9734 if (C->getType() != Op->getType())
9735 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
9736 Op->getType(),
9737 false),
9738 C, Op->getType());
9739 Operands.push_back(C);
9740 }
9741
9742 // Check to see if getSCEVAtScope actually made an improvement.
9743 if (MadeImprovement) {
9744 Constant *C = nullptr;
9745 const DataLayout &DL = getDataLayout();
9746 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
9747 if (!C) return V;
9748 return getSCEV(C);
9749 }
9750 }
9751 }
9752
9753 // This is some other type of SCEVUnknown, just return it.
9754 return V;
9755 }
9756
9757 if (isa<SCEVCommutativeExpr>(V) || isa<SCEVSequentialMinMaxExpr>(V)) {
9758 const auto *Comm = cast<SCEVNAryExpr>(V);
9759 // Avoid performing the look-up in the common case where the specified
9760 // expression has no loop-variant portions.
9761 for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
9762 const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9763 if (OpAtScope != Comm->getOperand(i)) {
9764 // Okay, at least one of these operands is loop variant but might be
9765 // foldable. Build a new instance of the folded commutative expression.
9766 SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
9767 Comm->op_begin()+i);
9768 NewOps.push_back(OpAtScope);
9769
9770 for (++i; i != e; ++i) {
9771 OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
9772 NewOps.push_back(OpAtScope);
9773 }
9774 if (isa<SCEVAddExpr>(Comm))
9775 return getAddExpr(NewOps, Comm->getNoWrapFlags());
9776 if (isa<SCEVMulExpr>(Comm))
9777 return getMulExpr(NewOps, Comm->getNoWrapFlags());
9778 if (isa<SCEVMinMaxExpr>(Comm))
9779 return getMinMaxExpr(Comm->getSCEVType(), NewOps);
9780 if (isa<SCEVSequentialMinMaxExpr>(Comm))
9781 return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
9782 llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
9783 }
9784 }
9785 // If we got here, all operands are loop invariant.
9786 return Comm;
9787 }
9788
9789 if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
9790 const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
9791 const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
9792 if (LHS == Div->getLHS() && RHS == Div->getRHS())
9793 return Div; // must be loop invariant
9794 return getUDivExpr(LHS, RHS);
9795 }
9796
9797 // If this is a loop recurrence for a loop that does not contain L, then we
9798 // are dealing with the final value computed by the loop.
9799 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
9800 // First, attempt to evaluate each operand.
9801 // Avoid performing the look-up in the common case where the specified
9802 // expression has no loop-variant portions.
9803 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9804 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9805 if (OpAtScope == AddRec->getOperand(i))
9806 continue;
9807
9808 // Okay, at least one of these operands is loop variant but might be
9809 // foldable. Build a new instance of the folded add recurrence.
9810 SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
9811 AddRec->op_begin()+i);
9812 NewOps.push_back(OpAtScope);
9813 for (++i; i != e; ++i)
9814 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9815
9816 const SCEV *FoldedRec =
9817 getAddRecExpr(NewOps, AddRec->getLoop(),
9818 AddRec->getNoWrapFlags(SCEV::FlagNW));
9819 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9820 // The addrec may be folded to a nonrecurrence, for example, if the
9821 // induction variable is multiplied by zero after constant folding. Go
9822 // ahead and return the folded value.
9823 if (!AddRec)
9824 return FoldedRec;
9825 break;
9826 }
9827
9828 // If the scope is outside the addrec's loop, evaluate it by using the
9829 // loop exit value of the addrec.
9830 if (!AddRec->getLoop()->contains(L)) {
9831 // To evaluate this recurrence, we need to know how many times the AddRec
9832 // loop iterates. Compute this now.
9833 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9834 if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
9835
9836 // Then, evaluate the AddRec.
9837 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9838 }
9839
9840 return AddRec;
9841 }
9842
9843 if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
9844 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
9845 if (Op == Cast->getOperand())
9846 return Cast; // must be loop invariant
9847 return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
9848 }
9849
9850 llvm_unreachable("Unknown SCEV type!");
9851}
9852
9853const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
9854 return getSCEVAtScope(getSCEV(V), L);
9855}
9856
9857const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
9858 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
9859 return stripInjectiveFunctions(ZExt->getOperand());
9860 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
9861 return stripInjectiveFunctions(SExt->getOperand());
9862 return S;
9863}
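// Editorial note (illustrative, not part of the original source): zext and
// sext are injective and map zero to zero, so for the zero-crossing analysis
// in howFarToZero below, a hypothetical expression zext({0,+,1}) reaches zero
// on exactly the same iteration as the inner addrec {0,+,1}; stripping the
// cast therefore does not change the answer.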
9864
9865/// Finds the minimum unsigned root of the following equation:
9866///
9867/// A * X = B (mod N)
9868///
9869/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
9870/// A and B isn't important.
9871///
9872/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
9873static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
9874 ScalarEvolution &SE) {
9875 uint32_t BW = A.getBitWidth();
9876 assert(BW == SE.getTypeSizeInBits(B->getType()));
9877 assert(A != 0 && "A must be non-zero.");
9878
9879 // 1. D = gcd(A, N)
9880 //
9881 // The gcd of A and N may have only one prime factor: 2. The number of
9882 // trailing zeros in A is its multiplicity
9883 uint32_t Mult2 = A.countTrailingZeros();
9884 // D = 2^Mult2
9885
9886 // 2. Check if B is divisible by D.
9887 //
9888 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
9889 // is not less than multiplicity of this prime factor for D.
9890 if (SE.GetMinTrailingZeros(B) < Mult2)
9891 return SE.getCouldNotCompute();
9892
9893 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
9894 // modulo (N / D).
9895 //
9896 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
9897 // (N / D) in general. The inverse itself always fits into BW bits, though,
9898 // so we immediately truncate it.
9899 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
9900 APInt Mod(BW + 1, 0);
9901 Mod.setBit(BW - Mult2); // Mod = N / D
9902 APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
9903
9904 // 4. Compute the minimum unsigned root of the equation:
9905 // I * (B / D) mod (N / D)
9906 // To simplify the computation, we factor out the divide by D:
9907 // (I * B mod N) / D
9908 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
9909 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
9910}
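// Editorial worked example (illustrative, not part of the original source),
// tracing the steps above with hypothetical values BW = 8, A = 12, B = 20
// (so N = 2^8 = 256):
//   1. Mult2 = countTrailingZeros(12) = 2, so D = 2^2 = 4.
//   2. GetMinTrailingZeros(20) = 2 >= Mult2, so B is divisible by D.
//   3. AD = 12 >> 2 = 3, Mod = 2^(8-2) = 64, I = 3^-1 (mod 64) = 43.
//   4. Result = ((43 * 20) mod 256) / 4 = (860 mod 256) / 4 = 92 / 4 = 23,
//      and indeed 12 * 23 = 276 = 20 (mod 256), with 23 the smallest such root.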
9911
9912/// For a given quadratic addrec, generate coefficients of the corresponding
9913/// quadratic equation, multiplied by a common value to ensure that they are
9914/// integers.
9915/// The returned value is a tuple { A, B, C, M, BitWidth }, where
9916/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
9917/// were multiplied by, and BitWidth is the bit width of the original addrec
9918/// coefficients.
9919/// This function returns None if the addrec coefficients are not compile-
9920/// time constants.
9921static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
9922GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
9923 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
9924 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
9925 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
9926 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
9927 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
9928 << *AddRec << '\n');
9929
9930 // We currently can only solve this if the coefficients are constants.
9931 if (!LC || !MC || !NC) {
9932 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
9933 return None;
9934 }
9935
9936 APInt L = LC->getAPInt();
9937 APInt M = MC->getAPInt();
9938 APInt N = NC->getAPInt();
9939 assert(!N.isZero() && "This is not a quadratic addrec");
9940
9941 unsigned BitWidth = LC->getAPInt().getBitWidth();
9942 unsigned NewWidth = BitWidth + 1;
9943 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
9944 << BitWidth << '\n');
9945 // The sign-extension (as opposed to a zero-extension) here matches the
9946 // extension used in SolveQuadraticEquationWrap (with the same motivation).
9947 N = N.sext(NewWidth);
9948 M = M.sext(NewWidth);
9949 L = L.sext(NewWidth);
9950
9951 // The increments are M, M+N, M+2N, ..., so the accumulated values are
9952 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
9953 // L+M, L+2M+N, L+3M+3N, ...
9954 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
9955 //
9956 // The equation Acc = 0 is then
9957 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
9958 // In a quadratic form it becomes:
9959 // N n^2 + (2M-N) n + 2L = 0.
9960
9961 APInt A = N;
9962 APInt B = 2 * M - A;
9963 APInt C = 2 * L;
9964 APInt T = APInt(NewWidth, 2);
9965 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
9966 << "x + " << C << ", coeff bw: " << NewWidth
9967 << ", multiplied by " << T << '\n');
9968 return std::make_tuple(A, B, C, T, BitWidth);
9969}
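// Editorial check of the derivation above (illustrative, not part of the
// original source), using the hypothetical addrec {2,+,3,+,4}:
//   L = 2, M = 3, N = 4, so c(n) = 2 + 3n + 4*n(n-1)/2 = 2n^2 + n + 2.
//   The returned coefficients are A = N = 4, B = 2M - N = 2, C = 2L = 4,
//   which is exactly 2 * c(n); the fourth tuple element (the common
//   multiplier) is accordingly 2.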
9970
9971/// Helper function to compare optional APInts:
9972/// (a) if X and Y both exist, return min(X, Y),
9973/// (b) if neither X nor Y exist, return None,
9974/// (c) if exactly one of X and Y exists, return that value.
9975static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
9976 if (X && Y) {
9977 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
9978 APInt XW = X->sext(W);
9979 APInt YW = Y->sext(W);
9980 return XW.slt(YW) ? *X : *Y;
9981 }
9982 if (!X && !Y)
9983 return None;
9984 return X ? *X : *Y;
9985}
9986
9987/// Helper function to truncate an optional APInt to a given BitWidth.
9988/// When solving addrec-related equations, it is preferable to return a value
9989/// that has the same bit width as the original addrec's coefficients. If the
9990/// solution fits in the original bit width, truncate it (except for i1).
9991/// Returning a value of a different bit width may inhibit some optimizations.
9992///
9993/// In general, a solution to a quadratic equation generated from an addrec
9994/// may require BW+1 bits, where BW is the bit width of the addrec's
9995/// coefficients. The reason is that the coefficients of the quadratic
9996/// equation are BW+1 bits wide (to avoid truncation when converting from
9997/// the addrec to the equation).
9998static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
9999 if (!X)
10000 return None;
10001 unsigned W = X->getBitWidth();
10002 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10003 return X->trunc(BitWidth);
10004 return X;
10005}
10006
10007/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10008/// iterations. The values L, M, N are assumed to be signed, and they
10009/// should all have the same bit widths.
10010/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10011/// where BW is the bit width of the addrec's coefficients.
10012/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10013/// returned as such, otherwise the bit width of the returned value may
10014/// be greater than BW.
10015///
10016/// This function returns None if
10017/// (a) the addrec coefficients are not constant, or
10018/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10019/// like x^2 = 5, no integer solutions exist, in other cases an integer
10020/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10021static Optional<APInt>
10022SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10023 APInt A, B, C, M;
10024 unsigned BitWidth;
10025 auto T = GetQuadraticEquation(AddRec);
10026 if (!T)
10027 return None;
10028
10029 std::tie(A, B, C, M, BitWidth) = *T;
10030 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10031 Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
10032 if (!X)
10033 return None;
10034
10035 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10036 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10037 if (!V->isZero())
10038 return None;
10039
10040 return TruncIfPossible(X, BitWidth);
10041}
10042
10043/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10044/// iterations. The values M, N are assumed to be signed, and they
10045/// should all have the same bit widths.
10046/// Find the least n such that c(n) does not belong to the given range,
10047/// while c(n-1) does.
10048///
10049/// This function returns None if
10050/// (a) the addrec coefficients are not constant, or
10051/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10052/// bounds of the range.
10053static Optional<APInt>
10054SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10055 const ConstantRange &Range, ScalarEvolution &SE) {
10056 assert(AddRec->getOperand(0)->isZero() &&
10057 "Starting value of addrec should be 0");
10058 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10059 << Range << ", addrec " << *AddRec << '\n');
10060 // This case is handled in getNumIterationsInRange. Here we can assume that
10061 // we start in the range.
10062 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10063 "Addrec's initial value should be in range");
10064
10065 APInt A, B, C, M;
10066 unsigned BitWidth;
10067 auto T = GetQuadraticEquation(AddRec);
10068 if (!T)
10069 return None;
10070
10071 // Be careful about the return value: there can be two reasons for not
10072 // returning an actual number. First, if no solutions to the equations
10073 // were found, and second, if the solutions don't leave the given range.
10074 // The first case means that the actual solution is "unknown", the second
10075 // means that it's known, but not valid. If the solution is unknown, we
10076 // cannot make any conclusions.
10077 // Return a pair: the optional solution and a flag indicating if the
10078 // solution was found.
10079 auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
10080 // Solve for signed overflow and unsigned overflow, pick the lower
10081 // solution.
10082 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10083 << Bound << " (before multiplying by " << M << ")\n");
10084 Bound *= M; // The quadratic equation multiplier.
10085
10086 Optional<APInt> SO;
10087 if (BitWidth > 1) {
10088 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10089 "signed overflow\n");
10090 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10091 }
10092 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10093 "unsigned overflow\n");
10094 Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
10095 BitWidth+1);
10096
10097 auto LeavesRange = [&] (const APInt &X) {
10098 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10099 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10100 if (Range.contains(V0->getValue()))
10101 return false;
10102 // X should be at least 1, so X-1 is non-negative.
10103 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10104 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10105 if (Range.contains(V1->getValue()))
10106 return true;
10107 return false;
10108 };
10109
10110 // If SolveQuadraticEquationWrap returns None, it means that there can
10111 // be a solution, but the function failed to find it. We cannot treat it
10112 // as "no solution".
10113 if (!SO || !UO)
10114 return { None, false };
10115
10116 // Check the smaller value first to see if it leaves the range.
10117 // At this point, both SO and UO must have values.
10118 Optional<APInt> Min = MinOptional(SO, UO);
10119 if (LeavesRange(*Min))
10120 return { Min, true };
10121 Optional<APInt> Max = Min == SO ? UO : SO;
10122 if (LeavesRange(*Max))
10123 return { Max, true };
10124
10125 // Solutions were found, but were eliminated, hence the "true".
10126 return { None, true };
10127 };
10128
10129 std::tie(A, B, C, M, BitWidth) = *T;
10130 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10131 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10132 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10133 auto SL = SolveForBoundary(Lower);
10134 auto SU = SolveForBoundary(Upper);
10135 // If either of the solutions was unknown, no meaningful conclusions can
10136 // be made.
10137 if (!SL.second || !SU.second)
10138 return None;
10139
10140 // Claim: The correct solution is not some value between Min and Max.
10141 //
10142 // Justification: Assuming that Min and Max are different values, one of
10143 // them is when the first signed overflow happens, the other is when the
10144 // first unsigned overflow happens. Crossing the range boundary is only
10145 // possible via an overflow (treating 0 as a special case of it, modeling
10146 // an overflow as crossing k*2^W for some k).
10147 //
10148 // The interesting case here is when Min was eliminated as an invalid
10149 // solution, but Max was not. The argument is that if there was another
10150 // overflow between Min and Max, it would also have been eliminated if
10151 // it was considered.
10152 //
10153 // For a given boundary, it is possible to have two overflows of the same
10154 // type (signed/unsigned) without having the other type in between: this
10155 // can happen when the vertex of the parabola is between the iterations
10156 // corresponding to the overflows. This is only possible when the two
10157 // overflows cross k*2^W for the same k. In such case, if the second one
10158 // left the range (and was the first one to do so), the first overflow
10159 // would have to enter the range, which would mean that either we had left
10160 // the range before or that we started outside of it. Both of these cases
10161 // are contradictions.
10162 //
10163 // Claim: In the case where SolveForBoundary returns None, the correct
10164 // solution is not some value between the Max for this boundary and the
10165 // Min of the other boundary.
10166 //
10167 // Justification: Assume that we had such Max_A and Min_B corresponding
10168 // to range boundaries A and B and such that Max_A < Min_B. If there was
10169 // a solution between Max_A and Min_B, it would have to be caused by an
10170 // overflow corresponding to either A or B. It cannot correspond to B,
10171 // since Min_B is the first occurrence of such an overflow. If it
10172 // corresponded to A, it would have to be either a signed or an unsigned
10173 // overflow that is larger than both eliminated overflows for A. But
10174 // between the eliminated overflows and this overflow, the values would
10175 // cover the entire value space, thus crossing the other boundary, which
10176 // is a contradiction.
10177
10178 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10179}
10180
10181ScalarEvolution::ExitLimit
10182ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
10183 bool AllowPredicates) {
10184
10185 // This is only used for loops with a "x != y" exit test. The exit condition
10186 // is now expressed as a single expression, V = x-y. So the exit test is
10187 // effectively V != 0. We know and take advantage of the fact that this
10188 // expression is only used in a comparison-with-zero context.
10189
10190 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
10191 // If the value is a constant
10192 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10193 // If the value is already zero, the branch will execute zero times.
10194 if (C->getValue()->isZero()) return C;
10195 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10196 }
10197
10198 const SCEVAddRecExpr *AddRec =
10199 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10200
10201 if (!AddRec && AllowPredicates)
10202 // Try to make this an AddRec using runtime tests, in the first X
10203 // iterations of this loop, where X is the SCEV expression found by the
10204 // algorithm below.
10205 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10206
10207 if (!AddRec || AddRec->getLoop() != L)
10208 return getCouldNotCompute();
10209
10210 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10211 // the quadratic equation to solve it.
10212 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10213 // We can only use this value if the chrec ends up with an exact zero
10214 // value at this index. When solving for "X*X != 5", for example, we
10215 // should not accept a root of 2.
10216 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10217 const auto *R = cast<SCEVConstant>(getConstant(*S));
10218 return ExitLimit(R, R, false, Predicates);
10219 }
10220 return getCouldNotCompute();
10221 }
10222
10223 // Otherwise we can only handle this if it is affine.
10224 if (!AddRec->isAffine())
10225 return getCouldNotCompute();
10226
10227 // If this is an affine expression, the execution count of this branch is
10228 // the minimum unsigned root of the following equation:
10229 //
10230 // Start + Step*N = 0 (mod 2^BW)
10231 //
10232 // equivalent to:
10233 //
10234 // Step*N = -Start (mod 2^BW)
10235 //
10236 // where BW is the common bit width of Start and Step.
10237
10238 // Get the initial value for the loop.
10239 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10240 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10241
10242 // For now we handle only constant steps.
10243 //
10244 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
10245 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
10246 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
10247 // We have not yet seen any such cases.
10248 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10249 if (!StepC || StepC->getValue()->isZero())
10250 return getCouldNotCompute();
10251
10252 // For positive steps (counting up until unsigned overflow):
10253 // N = -Start/Step (as unsigned)
10254 // For negative steps (counting down to zero):
10255 // N = Start/-Step
10256 // First compute the unsigned distance from zero in the direction of Step.
10257 bool CountDown = StepC->getAPInt().isNegative();
10258 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10259
10260 // Handle unitary steps, which cannot wraparound.
10261 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10262 // N = Distance (as unsigned)
10263 if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
10264 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10265 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10266
10267 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10268 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10269 // case, and see if we can improve the bound.
10270 //
10271 // Explicitly handling this here is necessary because getUnsignedRange
10272 // isn't context-sensitive; it doesn't know that we only care about the
10273 // range inside the loop.
10274 const SCEV *Zero = getZero(Distance->getType());
10275 const SCEV *One = getOne(Distance->getType());
10276 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10277 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10278 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10279 // as "unsigned_max(Distance + 1) - 1".
10280 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10281 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10282 }
10283 return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
10284 }
10285
10286 // If the condition controls loop exit (the loop exits only if the expression
10287 // is true) and the addition is no-wrap we can use unsigned divide to
10288 // compute the backedge count. In this case, the step may not divide the
10289 // distance, but we don't care because if the condition is "missed" the loop
10290 // will have undefined behavior due to wrapping.
10291 if (ControlsExit && AddRec->hasNoSelfWrap() &&
10292 loopHasNoAbnormalExits(AddRec->getLoop())) {
10293 const SCEV *Exact =
10294 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10295 const SCEV *Max = getCouldNotCompute();
10296 if (Exact != getCouldNotCompute()) {
10297 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10298 Max = getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10299 }
10300 return ExitLimit(Exact, Max, false, Predicates);
10301 }
10302
10303 // Solve the general equation.
10304 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10305 getNegativeSCEV(Start), *this);
10306
10307 const SCEV *M = E;
10308 if (E != getCouldNotCompute()) {
10309 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10310 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10311 }
10312 return ExitLimit(E, M, false, Predicates);
10313}
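// Editorial examples (illustrative, not part of the original source), assuming
// a hypothetical i8 induction variable:
//   * Unit step: for {10,+,-1}, CountDown is true and Distance = Start = 10,
//     so the exact backedge-taken count is 10.
//   * General step: for {20,+,-2}, Step = -2 = 254 and -Start = -20 = 236
//     (mod 256); solving 254 * N = 236 (mod 256) gives N = 10, matching the
//     ten backedges taken before the value reaches zero.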
10314
10315ScalarEvolution::ExitLimit
10316ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10317 // Loops that look like: while (X == 0) are very strange indeed. We don't
10318 // handle them yet except for the trivial case. This could be expanded in the
10319 // future as needed.
10320
10321 // If the value is a constant, check to see if it is known to be non-zero
10322 // already. If so, the backedge will execute zero times.
10323 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10324 if (!C->getValue()->isZero())
10325 return getZero(C->getType());
10326 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10327 }
10328
10329 // We could implement others, but I really doubt anyone writes loops like
10330 // this, and if they did, they would already be constant folded.
10331 return getCouldNotCompute();
10332}
10333
10334std::pair<const BasicBlock *, const BasicBlock *>
10335ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10336 const {
10337 // If the block has a unique predecessor, then there is no path from the
10338 // predecessor to the block that does not go through the direct edge
10339 // from the predecessor to the block.
10340 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10341 return {Pred, BB};
10342
10343 // A loop's header is defined to be a block that dominates the loop.
10344 // If the header has a unique predecessor outside the loop, it must be
10345 // a block that has exactly one successor that can reach the loop.
10346 if (const Loop *L = LI.getLoopFor(BB))
10347 return {L->getLoopPredecessor(), L->getHeader()};
10348
10349 return {nullptr, nullptr};
10350}
10351
10352/// SCEV structural equivalence is usually sufficient for testing whether two
10353/// expressions are equal, however for the purposes of looking for a condition
10354/// guarding a loop, it can be useful to be a little more general, since a
10355/// front-end may have replicated the controlling expression.
10356static bool HasSameValue(const SCEV *A, const SCEV *B) {
10357 // Quick check to see if they are the same SCEV.
10358 if (A == B) return true;
10359
10360 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10361 // Not all instructions that are "identical" compute the same value. For
10362 // instance, two distinct alloca instructions allocating the same type are
10363 // identical and do not read memory; but compute distinct values.
10364 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10365 };
10366
10367 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10368 // two different instructions with the same value. Check for this case.
10369 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10370 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10371 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10372 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10373 if (ComputesEqualValues(AI, BI))
10374 return true;
10375
10376 // Otherwise assume they may have a different value.
10377 return false;
10378}
10379
10380bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10381 const SCEV *&LHS, const SCEV *&RHS,
10382 unsigned Depth,
10383 bool ControllingFiniteLoop) {
10384 bool Changed = false;
10385 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10386 // '0 != 0'.
10387 auto TrivialCase = [&](bool TriviallyTrue) {
10388 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10389 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10390 return true;
10391 };
10392 // If we hit the max recursion limit bail out.
10393 if (Depth >= 3)
10394 return false;
10395
10396 // Canonicalize a constant to the right side.
10397 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10398 // Check for both operands constant.
10399 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10400 if (ConstantExpr::getICmp(Pred,
10401 LHSC->getValue(),
10402 RHSC->getValue())->isNullValue())
10403 return TrivialCase(false);
10404 else
10405 return TrivialCase(true);
10406 }
10407 // Otherwise swap the operands to put the constant on the right.
10408 std::swap(LHS, RHS);
10409 Pred = ICmpInst::getSwappedPredicate(Pred);
10410 Changed = true;
10411 }
10412
10413 // If we're comparing an addrec with a value which is loop-invariant in the
10414 // addrec's loop, put the addrec on the left. Also make a dominance check,
10415 // as both operands could be addrecs loop-invariant in each other's loop.
10416 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10417 const Loop *L = AR->getLoop();
10418 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10419 std::swap(LHS, RHS);
10420 Pred = ICmpInst::getSwappedPredicate(Pred);
10421 Changed = true;
10422 }
10423 }
10424
10425 // If there's a constant operand, canonicalize comparisons with boundary
10426 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10427 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10428 const APInt &RA = RC->getAPInt();
10429
10430 bool SimplifiedByConstantRange = false;
10431
10432 if (!ICmpInst::isEquality(Pred)) {
10433 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10434 if (ExactCR.isFullSet())
10435 return TrivialCase(true);
10436 else if (ExactCR.isEmptySet())
10437 return TrivialCase(false);
10438
10439 APInt NewRHS;
10440 CmpInst::Predicate NewPred;
10441 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10442 ICmpInst::isEquality(NewPred)) {
10443 // We were able to convert an inequality to an equality.
10444 Pred = NewPred;
10445 RHS = getConstant(NewRHS);
10446 Changed = SimplifiedByConstantRange = true;
10447 }
10448 }
10449
10450 if (!SimplifiedByConstantRange) {
10451 switch (Pred) {
10452 default:
10453 break;
10454 case ICmpInst::ICMP_EQ:
10455 case ICmpInst::ICMP_NE:
10456 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10457 if (!RA)
10458 if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
10459 if (const SCEVMulExpr *ME =
10460 dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
10461 if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
10462 ME->getOperand(0)->isAllOnesValue()) {
10463 RHS = AE->getOperand(1);
10464 LHS = ME->getOperand(1);
10465 Changed = true;
10466 }
10467 break;
10468
10469
10470 // The "Should have been caught earlier!" messages refer to the fact
10471 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10472 // should have fired on the corresponding cases, and canonicalized the
10473 // check to a trivial case.
10474
10475 case ICmpInst::ICMP_UGE:
10476 assert(!RA.isMinValue() && "Should have been caught earlier!");
10477 Pred = ICmpInst::ICMP_UGT;
10478 RHS = getConstant(RA - 1);
10479 Changed = true;
10480 break;
10481 case ICmpInst::ICMP_ULE:
10482 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10483 Pred = ICmpInst::ICMP_ULT;
10484 RHS = getConstant(RA + 1);
10485 Changed = true;
10486 break;
10487 case ICmpInst::ICMP_SGE:
10488 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10489 Pred = ICmpInst::ICMP_SGT;
10490 RHS = getConstant(RA - 1);
10491 Changed = true;
10492 break;
10493 case ICmpInst::ICMP_SLE:
10494 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10495 Pred = ICmpInst::ICMP_SLT;
10496 RHS = getConstant(RA + 1);
10497 Changed = true;
10498 break;
10499 }
10500 }
10501 }
10502
10503 // Check for obvious equality.
10504 if (HasSameValue(LHS, RHS)) {
10505 if (ICmpInst::isTrueWhenEqual(Pred))
10506 return TrivialCase(true);
10507 if (ICmpInst::isFalseWhenEqual(Pred))
10508 return TrivialCase(false);
10509 }
10510
10511 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10512 // adding or subtracting 1 from one of the operands. This can be done for
10513 // one of two reasons:
10514 // 1) The range of the RHS does not include the (signed/unsigned) boundaries
10515 // 2) The loop is finite, with this comparison controlling the exit. Since the
10516 // loop is finite, the bound cannot include the corresponding boundary
10517 // (otherwise it would loop forever).
10518 switch (Pred) {
10519 case ICmpInst::ICMP_SLE:
10520 if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) {
10521 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10522 SCEV::FlagNSW);
10523 Pred = ICmpInst::ICMP_SLT;
10524 Changed = true;
10525 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10526 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10527 SCEV::FlagNSW);
10528 Pred = ICmpInst::ICMP_SLT;
10529 Changed = true;
10530 }
10531 break;
10532 case ICmpInst::ICMP_SGE:
10533 if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) {
10534 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10535 SCEV::FlagNSW);
10536 Pred = ICmpInst::ICMP_SGT;
10537 Changed = true;
10538 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10539 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10540 SCEV::FlagNSW);
10541 Pred = ICmpInst::ICMP_SGT;
10542 Changed = true;
10543 }
10544 break;
10545 case ICmpInst::ICMP_ULE:
10546 if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) {
10547 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10548 SCEV::FlagNUW);
10549 Pred = ICmpInst::ICMP_ULT;
10550 Changed = true;
10551 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10552 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10553 Pred = ICmpInst::ICMP_ULT;
10554 Changed = true;
10555 }
10556 break;
10557 case ICmpInst::ICMP_UGE:
10558 if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) {
10559 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10560 Pred = ICmpInst::ICMP_UGT;
10561 Changed = true;
10562 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10563 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10564 SCEV::FlagNUW);
10565 Pred = ICmpInst::ICMP_UGT;
10566 Changed = true;
10567 }
10568 break;
10569 default:
10570 break;
10571 }
10572
10573 // TODO: More simplifications are possible here.
10574
10575 // Recursively simplify until we either hit a recursion limit or nothing
10576 // changes.
10577 if (Changed)
10578 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1,
10579 ControllingFiniteLoop);
10580
10581 return Changed;
10582}
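// Editorial note (illustrative, not part of the original source): for a
// hypothetical comparison (%x u<= 5), the exact-range checks above do not
// trivialize it (5 is neither the minimum nor the maximum unsigned value) and
// yield no equality form, so the switch on the constant RHS canonicalizes the
// *-or-equal form into the strict (%x u< 6) by bumping the constant by one.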
10583
10584bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10585 return getSignedRangeMax(S).isNegative();
10586}
10587
10588bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10589 return getSignedRangeMin(S).isStrictlyPositive();
10590}
10591
10592bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10593 return !getSignedRangeMin(S).isNegative();
10594}
10595
10596bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10597 return !getSignedRangeMax(S).isStrictlyPositive();
10598}
10599
10600bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10601 return getUnsignedRangeMin(S) != 0;
10602}
10603
10604std::pair<const SCEV *, const SCEV *>
10605ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10606 // Compute SCEV on entry of loop L.
10607 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10608 if (Start == getCouldNotCompute())
10609 return { Start, Start };
10610 // Compute post increment SCEV for loop L.
10611 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10612 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10613 return { Start, PostInc };
10614}
10615
10616bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10617 const SCEV *LHS, const SCEV *RHS) {
10618 // First collect all loops.
10619 SmallPtrSet<const Loop *, 8> LoopsUsed;
10620 getUsedLoops(LHS, LoopsUsed);
10621 getUsedLoops(RHS, LoopsUsed);
10622
10623 if (LoopsUsed.empty())
10624 return false;
10625
10626 // Domination relationship must be a linear order on collected loops.
10627#ifndef NDEBUG
10628 for (const auto *L1 : LoopsUsed)
10629 for (const auto *L2 : LoopsUsed)
10630 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10631 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10632 "Domination relationship is not a linear order");
10633#endif
10634
10635 const Loop *MDL =
10636 *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
10637 [&](const Loop *L1, const Loop *L2) {
10638 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10639 });
10640
10641 // Get init and post increment value for LHS.
10642 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10643 // If LHS contains an unknown non-invariant SCEV, bail out.
10644 if (SplitLHS.first == getCouldNotCompute())
10645 return false;
10646 assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10647 // Get init and post increment value for RHS.
10648 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10649 // If RHS contains an unknown non-invariant SCEV, bail out.
10650 if (SplitRHS.first == getCouldNotCompute())
10651 return false;
10652 assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10653 // It is possible that init SCEV contains an invariant load but it does
10654 // not dominate MDL and is not available at MDL loop entry, so we should
10655 // check it here.
10656 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10657 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10658 return false;
10659
10660 // The backedge guard check seems to be faster than the entry check, so in
10661 // some cases it can speed up the whole estimation by short-circuiting.
10662 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10663 SplitRHS.second) &&
10664 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10665}
10666
10667bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10668 const SCEV *LHS, const SCEV *RHS) {
10669 // Canonicalize the inputs first.
10670 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10671
10672 if (isKnownViaInduction(Pred, LHS, RHS))
10673 return true;
10674
10675 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10676 return true;
10677
10678 // Otherwise see what can be done with some simple reasoning.
10679 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10680}
10681
10682Optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10683 const SCEV *LHS,
10684 const SCEV *RHS) {
10685 if (isKnownPredicate(Pred, LHS, RHS))
10686 return true;
10687 else if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10688 return false;
10689 return None;
10690}
10691
10692bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10693 const SCEV *LHS, const SCEV *RHS,
10694 const Instruction *CtxI) {
10695 // TODO: Analyze guards and assumes from Context's block.
10696 return isKnownPredicate(Pred, LHS, RHS) ||
10697 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10698}
10699
10700Optional<bool> ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred,
10701 const SCEV *LHS,
10702 const SCEV *RHS,
10703 const Instruction *CtxI) {
10704 Optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10705 if (KnownWithoutContext)
10706 return KnownWithoutContext;
10707
10708 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10709 return true;
10710 else if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10711 ICmpInst::getInversePredicate(Pred),
10712 LHS, RHS))
10713 return false;
10714 return None;
10715}
10716
10717bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10718 const SCEVAddRecExpr *LHS,
10719 const SCEV *RHS) {
10720 const Loop *L = LHS->getLoop();
10721 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10722 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10723}
10724
10725Optional<ScalarEvolution::MonotonicPredicateType>
10726ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10727 ICmpInst::Predicate Pred) {
10728 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10729
10730#ifndef NDEBUG
10731 // Verify an invariant: inverting the predicate should turn a monotonically
10732 // increasing change to a monotonically decreasing one, and vice versa.
10733 if (Result) {
10734 auto ResultSwapped =
10735 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10736
10737 assert(ResultSwapped && "should be able to analyze both!");
10738 assert(ResultSwapped.value() != Result.value() &&
10739        "monotonicity should flip as we flip the predicate");
10740 }
10741#endif
10742
10743 return Result;
10744}
10745
10746Optional<ScalarEvolution::MonotonicPredicateType>
10747ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
10748 ICmpInst::Predicate Pred) {
10749 // A zero step value for LHS means the induction variable is essentially a
10750 // loop invariant value. We don't really depend on the predicate actually
10751 // flipping from false to true (for increasing predicates, and the other way
10752 // around for decreasing predicates); all we care about is that *if* the
10753 // predicate changes then it only changes from false to true.
10754 //
10755 // A zero step value in itself is not very useful, but there may be places
10756 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
10757 // as general as possible.
10758
10759 // Only handle LE/LT/GE/GT predicates.
10760 if (!ICmpInst::isRelational(Pred))
10761 return None;
10762
10763 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
10764 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
10765        "Should be greater or less!");
10766
10767 // Check that AR does not wrap.
10768 if (ICmpInst::isUnsigned(Pred)) {
10769 if (!LHS->hasNoUnsignedWrap())
10770 return None;
10771 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10772 } else {
10773 assert(ICmpInst::isSigned(Pred) &&
10774        "Relational predicate is either signed or unsigned!");
10775 if (!LHS->hasNoSignedWrap())
10776 return None;
10777
10778 const SCEV *Step = LHS->getStepRecurrence(*this);
10779
10780 if (isKnownNonNegative(Step))
10781 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10782
10783 if (isKnownNonPositive(Step))
10784 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10785
10786 return None;
10787 }
10788}
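The monotonicity fact relied on above can be illustrated numerically. The sketch below is not LLVM code; it assumes the chosen constants never overflow a long and simply checks that an affine recurrence {Start,+,Step} with a non-negative step never decreases, so a signed "greater than" comparison against a fixed bound can only flip from false to true as the iteration count grows.

#include <cassert>

int main() {
  const long Start = -5, Step = 2, Bound = 3; // the recurrence {-5,+,2} vs. 3
  long Prev = Start;
  bool PrevHolds = Prev > Bound;              // "AR >s Bound" on iteration 0
  for (long It = 1; It < 100; ++It) {
    const long Cur = Start + Step * It;       // value on iteration It
    assert(Cur >= Prev);                      // monotonically increasing
    const bool CurHolds = Cur > Bound;
    assert(!(PrevHolds && !CurHolds));        // once true, it stays true
    Prev = Cur;
    PrevHolds = CurHolds;
  }
}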
10789
10790Optional<ScalarEvolution::LoopInvariantPredicate>
10791ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
10792 const SCEV *LHS, const SCEV *RHS,
10793 const Loop *L,
10794 const Instruction *CtxI) {
10795 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10796 if (!isLoopInvariant(RHS, L)) {
10797 if (!isLoopInvariant(LHS, L))
10798 return None;
10799
10800 std::swap(LHS, RHS);
10801 Pred = ICmpInst::getSwappedPredicate(Pred);
10802 }
10803
10804 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
10805 if (!ArLHS || ArLHS->getLoop() != L)
10806 return None;
10807
10808 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
10809 if (!MonotonicType)
10810 return None;
10811 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
10812 // true as the loop iterates, and the backedge is control dependent on
10813 // "ArLHS `Pred` RHS" == true then we can reason as follows:
10814 //
10815 // * if the predicate was false in the first iteration then the predicate
10816 // is never evaluated again, since the loop exits without taking the
10817 // backedge.
10818 // * if the predicate was true in the first iteration then it will
10819 // continue to be true for all future iterations since it is
10820 // monotonically increasing.
10821 //
10822 // For both the above possibilities, we can replace the loop varying
10823 // predicate with its value on the first iteration of the loop (which is
10824 // loop invariant).
10825 //
10826 // A similar reasoning applies for a monotonically decreasing predicate, by
10827 // replacing true with false and false with true in the above two bullets.
10828 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
10829 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
10830
10831 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
10832 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
10833 RHS);
10834
10835 if (!CtxI)
10836 return None;
10837 // Try to prove via context.
10838 // TODO: Support other cases.
10839 switch (Pred) {
10840 default:
10841 break;
10842 case ICmpInst::ICMP_ULE:
10843 case ICmpInst::ICMP_ULT: {
10844 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
10845 // Given preconditions
10846 // (1) ArLHS does not cross the border of positive and negative parts of
10847 // range because of:
10848 // - Positive step; (TODO: lift this limitation)
10849 // - nuw - does not cross zero boundary;
10850 // - nsw - does not cross SINT_MAX boundary;
10851 // (2) ArLHS <s RHS
10852 // (3) RHS >=s 0
10853 // we can replace the loop variant ArLHS <u RHS condition with loop
10854 // invariant Start(ArLHS) <u RHS.
10855 //
10856 // Because of (1) there are two options:
10857 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
10858 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
10859 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
10860 // Because of (2) ArLHS <u RHS is trivially true.
10861 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
10862 // We can strengthen this to Start(ArLHS) <u RHS.
10863 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
10864 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
10865 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
10866 isKnownNonNegative(RHS) &&
10867 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
10868 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
10869 RHS);
10870 }
10871 }
10872
10873 return None;
10874}
10875
10876Optional<ScalarEvolution::LoopInvariantPredicate>
10877ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
10878 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
10879 const Instruction *CtxI, const SCEV *MaxIter) {
10880 // Try to prove the following set of facts:
10881 // - The predicate is monotonic in the iteration space.
10882 // - If the check does not fail on the 1st iteration:
10883 // - No overflow will happen during first MaxIter iterations;
10884 // - It will not fail on the MaxIter'th iteration.
10885 // If the check does fail on the 1st iteration, we leave the loop and no
10886 // other checks matter.
10887
10888 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10889 if (!isLoopInvariant(RHS, L)) {
10890 if (!isLoopInvariant(LHS, L))
10891 return None;
10892
10893 std::swap(LHS, RHS);
10894 Pred = ICmpInst::getSwappedPredicate(Pred);
10895 }
10896
10897 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
10898 if (!AR || AR->getLoop() != L)
10899 return None;
10900
10901 // The predicate must be relational (i.e. <, <=, >=, >).
10902 if (!ICmpInst::isRelational(Pred))
10903 return None;
10904
10905 // TODO: Support steps other than +/- 1.
10906 const SCEV *Step = AR->getStepRecurrence(*this);
10907 auto *One = getOne(Step->getType());
10908 auto *MinusOne = getNegativeSCEV(One);
10909 if (Step != One && Step != MinusOne)
10910 return None;
10911
10912 // A type mismatch here means that MaxIter is potentially larger than the max
10913 // unsigned value in the start type, which means we cannot prove no-wrap for
10914 // the indvar.
10915 if (AR->getType() != MaxIter->getType())
10916 return None;
10917
10918 // Value of IV on suggested last iteration.
10919 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
10920 // Does it still meet the requirement?
10921 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
10922 return None;
10923 // Because the step is +/- 1 and MaxIter has the same type as Start (i.e. it
10924 // does not exceed the max unsigned value of this type), this effectively proves
10925 // that there is no wrap during the iteration. To prove that there is no
10926 // signed/unsigned wrap, we need to check that
10927 // Start <= Last for step = 1 or Start >= Last for step = -1.
10928 ICmpInst::Predicate NoOverflowPred =
10929 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
10930 if (Step == MinusOne)
10931 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
10932 const SCEV *Start = AR->getStart();
10933 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
10934 return None;
10935
10936 // Everything is fine.
10937 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
10938}
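The "no wrap" step for a unit step can be checked exhaustively at a small bit width. The sketch below is plain C++ (no LLVM APIs) and uses 8 bits only to keep the search space tiny: with step +1, if the value on the MaxIter'th iteration is still >=u Start, the induction variable cannot have wrapped during the first MaxIter iterations.

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Start = 0; Start <= 255; ++Start)
    for (unsigned MaxIter = 0; MaxIter <= 255; ++MaxIter) {
      // Last is the value on the MaxIter'th iteration, computed with
      // wrapping 8-bit arithmetic, exactly as the hardware would.
      const uint8_t Last = static_cast<uint8_t>(Start + MaxIter);
      if (Last >= Start)                // the "Start <=u Last" check above
        assert(Start + MaxIter <= 255); // hence no unsigned wrap occurred
    }
}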
10939
10940bool ScalarEvolution::isKnownPredicateViaConstantRanges(
10941 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
10942 if (HasSameValue(LHS, RHS))
10943 return ICmpInst::isTrueWhenEqual(Pred);
10944
10945 // This code is split out from isKnownPredicate because it is called from
10946 // within isLoopEntryGuardedByCond.
10947
10948 auto CheckRanges = [&](const ConstantRange &RangeLHS,
10949 const ConstantRange &RangeRHS) {
10950 return RangeLHS.icmp(Pred, RangeRHS);
10951 };
10952
10953 // The check at the top of the function catches the case where the values are
10954 // known to be equal.
10955 if (Pred == CmpInst::ICMP_EQ)
10956 return false;
10957
10958 if (Pred == CmpInst::ICMP_NE) {
10959 auto SL = getSignedRange(LHS);
10960 auto SR = getSignedRange(RHS);
10961 if (CheckRanges(SL, SR))
10962 return true;
10963 auto UL = getUnsignedRange(LHS);
10964 auto UR = getUnsignedRange(RHS);
10965 if (CheckRanges(UL, UR))
10966 return true;
10967 auto *Diff = getMinusSCEV(LHS, RHS);
10968 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
10969 }
10970
10971 if (CmpInst::isSigned(Pred)) {
10972 auto SL = getSignedRange(LHS);
10973 auto SR = getSignedRange(RHS);
10974 return CheckRanges(SL, SR);
10975 }
10976
10977 auto UL = getUnsignedRange(LHS);
10978 auto UR = getUnsignedRange(RHS);
10979 return CheckRanges(UL, UR);
10980}
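A toy model of the range-based proof, assuming only simple inclusive intervals rather than LLVM's ConstantRange (which also supports wrapped ranges): a predicate is known when it holds for every pair of values drawn from the two ranges.

#include <cassert>
#include <cstdint>

// Simplified stand-in for a constant range: an inclusive [Min, Max] interval.
struct Interval {
  uint64_t Min, Max; // invariant: Min <= Max
};

// "LHS u< RHS" holds for all values in the two intervals exactly when the
// largest possible LHS is still below the smallest possible RHS.
static bool knownULT(Interval LHS, Interval RHS) { return LHS.Max < RHS.Min; }

int main() {
  assert(knownULT({0, 9}, {10, 20}));   // provable from the ranges alone
  assert(!knownULT({0, 15}, {10, 20})); // ranges overlap: not provable
}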
10981
10982bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
10983 const SCEV *LHS,
10984 const SCEV *RHS) {
10985 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
10986 // C1 and C2 are constant integers. If either X or Y is not an add expression,
10987 // it is treated as X + 0 or Y + 0 respectively. C1 and C2 are returned via
10988 // OutC1 and OutC2.
10989 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
10990 APInt &OutC1, APInt &OutC2,
10991 SCEV::NoWrapFlags ExpectedFlags) {
10992 const SCEV *XNonConstOp, *XConstOp;
10993 const SCEV *YNonConstOp, *YConstOp;
10994 SCEV::NoWrapFlags XFlagsPresent;
10995 SCEV::NoWrapFlags YFlagsPresent;
10996
10997 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
10998 XConstOp = getZero(X->getType());
10999 XNonConstOp = X;
11000 XFlagsPresent = ExpectedFlags;
11001 }
11002 if (!isa<SCEVConstant>(XConstOp) ||
11003 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11004 return false;
11005
11006 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11007 YConstOp = getZero(Y->getType());
11008 YNonConstOp = Y;
11009 YFlagsPresent = ExpectedFlags;
11010 }
11011
11012 if (!isa<SCEVConstant>(YConstOp) ||
11013 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11014 return false;
11015
11016 if (YNonConstOp != XNonConstOp)
11017 return false;
11018
11019 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11020 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11021
11022 return true;
11023 };
11024
11025 APInt C1;
11026 APInt C2;
11027
11028 switch (Pred) {
11029 default:
11030 break;
11031
11032 case ICmpInst::ICMP_SGE:
11033 std::swap(LHS, RHS);
11034 [[fallthrough]];
11035 case ICmpInst::ICMP_SLE:
11036 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11037 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11038 return true;
11039
11040 break;
11041
11042 case ICmpInst::ICMP_SGT:
11043 std::swap(LHS, RHS);
11044 [[fallthrough]];
11045 case ICmpInst::ICMP_SLT:
11046 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11047 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11048 return true;
11049
11050 break;
11051
11052 case ICmpInst::ICMP_UGE:
11053 std::swap(LHS, RHS);
11054 [[fallthrough]];
11055 case ICmpInst::ICMP_ULE:
11056 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11057 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
11058 return true;
11059
11060 break;
11061
11062 case ICmpInst::ICMP_UGT:
11063 std::swap(LHS, RHS);
11064 [[fallthrough]];
11065 case ICmpInst::ICMP_ULT:
11066 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11067 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
11068 return true;
11069 break;
11070 }
11071
11072 return false;
11073}
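The <nuw> half of the rule above ((X + C1)<nuw> u<= (X + C2)<nuw> whenever C1 u<= C2) can be sanity-checked by brute force at a small width. This is a plain C++ sketch at 8 bits, not a proof for arbitrary widths; the <nsw> half is analogous with signed arithmetic.

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X <= 255; ++X)
    for (unsigned C1 = 0; C1 <= 255; ++C1)
      for (unsigned C2 = 0; C2 <= 255; ++C2) {
        // The "nuw" premise: neither addition leaves the 8-bit range.
        const bool NoWrap = (X + C1 <= 255) && (X + C2 <= 255);
        if (NoWrap && C1 <= C2)
          assert(static_cast<uint8_t>(X + C1) <= static_cast<uint8_t>(X + C2));
      }
}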
11074
11075bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11076 const SCEV *LHS,
11077 const SCEV *RHS) {
11078 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11079 return false;
11080
11081 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11082 // the stack can result in exponential time complexity.
11083 SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
11084
11085 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11086 //
11087 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11088 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11089 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11090 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11091 // use isKnownPredicate later if needed.
11092 return isKnownNonNegative(RHS) &&
11093 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11094 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11095}
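The identity behind the splitting rule, namely that for a signed-non-negative L the comparison I u< L is equivalent to I >=s 0 && I s< L, can be verified exhaustively at 8 bits. Plain C++ sketch, no LLVM APIs:

#include <cassert>
#include <cstdint>

int main() {
  for (int I = -128; I <= 127; ++I)
    for (int L = -128; L <= 127; ++L) {
      if (L < 0)
        continue;                                 // premise: L is non-negative
      const uint8_t UI = static_cast<uint8_t>(I); // reinterpret as unsigned
      const uint8_t UL = static_cast<uint8_t>(L);
      assert((UI < UL) == (I >= 0 && I < L));     // u< matches (>=s 0 && s<)
    }
}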
11096
11097bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11098 ICmpInst::Predicate Pred,
11099 const SCEV *LHS, const SCEV *RHS) {
11100 // No need to even try if we know the module has no guards.
11101 if (!HasGuards)
11102 return false;
11103
11104 return any_of(*BB, [&](const Instruction &I) {
11105 using namespace llvm::PatternMatch;
11106
11107 Value *Condition;
11108 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11109 m_Value(Condition))) &&
11110 isImpliedCond(Pred, LHS, RHS, Condition, false);
11111 });
11112}
11113
11114/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11115/// protected by a conditional between LHS and RHS. This is used to
11116/// eliminate casts.
11117bool
11118ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11119 ICmpInst::Predicate Pred,
11120 const SCEV *LHS, const SCEV *RHS) {
11121 // Interpret a null as meaning no loop, where there is obviously no guard
11122 // (interprocedural conditions notwithstanding). Do not bother about
11123 // unreachable loops.
11124 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11125 return true;
11126
11127 if (VerifyIR)
11128 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11129        "This cannot be done on broken IR!");
11130
11131
11132 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11133 return true;
11134
11135 BasicBlock *Latch = L->getLoopLatch();
11136 if (!Latch)
11137 return false;
11138
11139 BranchInst *LoopContinuePredicate =
11140 dyn_cast<BranchInst>(Latch->getTerminator());
11141 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11142 isImpliedCond(Pred, LHS, RHS,
11143 LoopContinuePredicate->getCondition(),
11144 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11145 return true;
11146
11147 // We don't want more than one activation of the following loops on the stack
11148 // -- that can lead to O(n!) time complexity.
11149 if (WalkingBEDominatingConds)
11150 return false;
11151
11152 SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
11153
11154 // See if we can exploit a trip count to prove the predicate.
11155 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11156 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11157 if (LatchBECount != getCouldNotCompute()) {
11158 // We know that Latch branches back to the loop header exactly
11159 // LatchBECount times. This means the backedge condition at Latch is
11160 // equivalent to "{0,+,1} u< LatchBECount".
11161 Type *Ty = LatchBECount->getType();
11162 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11163 const SCEV *LoopCounter =
11164 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11165 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11166 LatchBECount))
11167 return true;
11168 }
11169
11170 // Check conditions due to any @llvm.assume intrinsics.
11171 for (auto &AssumeVH : AC.assumptions()) {
11172 if (!AssumeVH)
11173 continue;
11174 auto *CI = cast<CallInst>(AssumeVH);
11175 if (!DT.dominates(CI, Latch->getTerminator()))
11176 continue;
11177
11178 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11179 return true;
11180 }
11181
11182 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11183 return true;
11184
11185 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11186 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11187 assert(DTN && "should reach the loop header before reaching the root!");
11188
11189 BasicBlock *BB = DTN->getBlock();
11190 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11191 return true;
11192
11193 BasicBlock *PBB = BB->getSinglePredecessor();
11194 if (!PBB)
11195 continue;
11196
11197 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11198 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11199 continue;
11200
11201 Value *Condition = ContinuePredicate->getCondition();
11202
11203 // If we have an edge `E` within the loop body that dominates the only
11204 // latch, the condition guarding `E` also guards the backedge. This
11205 // reasoning works only for loops with a single latch.
11206
11207 BasicBlockEdge DominatingEdge(PBB, BB);
11208 if (DominatingEdge.isSingleEdge()) {
11209 // We're constructively (and conservatively) enumerating edges within the
11210 // loop body that dominate the latch. The dominator tree better agree
11211 // with us on this:
11212 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11213
11214 if (isImpliedCond(Pred, LHS, RHS, Condition,
11215 BB != ContinuePredicate->getSuccessor(0)))
11216 return true;
11217 }
11218 }
11219
11220 return false;
11221}
11222
11223bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11224 ICmpInst::Predicate Pred,
11225 const SCEV *LHS,
11226 const SCEV *RHS) {
11227 // Do not bother proving facts for unreachable code.
11228 if (!DT.isReachableFromEntry(BB))
11229 return true;
11230 if (VerifyIR)
11231 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11232        "This cannot be done on broken IR!");
11233
11234 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11235 // the facts (a >= b && a != b) separately. A typical situation is when the
11236 // non-strict comparison is known from ranges and non-equality is known from
11237 // dominating predicates. If we are proving strict comparison, we always try
11238 // to prove non-equality and non-strict comparison separately.
11239 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11240 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11241 bool ProvedNonStrictComparison = false;
11242 bool ProvedNonEquality = false;
11243
11244 auto SplitAndProve =
11245 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11246 if (!ProvedNonStrictComparison)
11247 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11248 if (!ProvedNonEquality)
11249 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11250 if (ProvedNonStrictComparison && ProvedNonEquality)
11251 return true;
11252 return false;
11253 };
11254
11255 if (ProvingStrictComparison) {
11256 auto ProofFn = [&](ICmpInst::Predicate P) {
11257 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11258 };
11259 if (SplitAndProve(ProofFn))
11260 return true;
11261 }
11262
11263 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11264 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11265 const Instruction *CtxI = &BB->front();
11266 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11267 return true;
11268 if (ProvingStrictComparison) {
11269 auto ProofFn = [&](ICmpInst::Predicate P) {
11270 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11271 };
11272 if (SplitAndProve(ProofFn))
11273 return true;
11274 }
11275 return false;
11276 };
11277
11278 // Starting at the block's predecessor, climb up the predecessor chain, as
11279 // long as we can find predecessors that have unique successors leading to
11280 // the original block.
11281 const Loop *ContainingLoop = LI.getLoopFor(BB);
11282 const BasicBlock *PredBB;
11283 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11284 PredBB = ContainingLoop->getLoopPredecessor();
11285 else
11286 PredBB = BB->getSinglePredecessor();
11287 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11288 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11289 const BranchInst *BlockEntryPredicate =
11290 dyn_cast<BranchInst>(Pair.first->getTerminator());
11291 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11292 continue;
11293
11294 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11295 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11296 return true;
11297 }
11298
11299 // Check conditions due to any @llvm.assume intrinsics.
11300 for (auto &AssumeVH : AC.assumptions()) {
11301 if (!AssumeVH)
11302 continue;
11303 auto *CI = cast<CallInst>(AssumeVH);
11304 if (!DT.dominates(CI, BB))
11305 continue;
11306
11307 if (ProveViaCond(CI->getArgOperand(0), false))
11308 return true;
11309 }
11310
11311 // Check conditions due to any @llvm.experimental.guard intrinsics.
11312 auto *GuardDecl = F.getParent()->getFunction(
11313 Intrinsic::getName(Intrinsic::experimental_guard));
11314 if (GuardDecl)
11315 for (const auto *GU : GuardDecl->users())
11316 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11317 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11318 if (ProveViaCond(Guard->getArgOperand(0), false))
11319 return true;
11320 return false;
11321}
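The strict-to-non-strict split used by SplitAndProve above rests on a simple identity: a strict comparison holds exactly when its non-strict form and non-equality both hold. A small sketch checking it for signed integers (plain C++, small range chosen only for brevity):

#include <cassert>

int main() {
  for (int A = -8; A <= 8; ++A)
    for (int B = -8; B <= 8; ++B)
      assert((A > B) == (A >= B && A != B)); // a > b  <=>  a >= b && a != b
}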
11322
11323bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11324 ICmpInst::Predicate Pred,
11325 const SCEV *LHS,
11326 const SCEV *RHS) {
11327 // Interpret a null as meaning no loop, where there is obviously no guard
11328 // (interprocedural conditions notwithstanding).
11329 if (!L)
11330 return false;
11331
11332 // Both LHS and RHS must be available at loop entry.
11333 assert(isAvailableAtLoopEntry(LHS, L) &&
11334        "LHS is not available at Loop Entry");
11335 assert(isAvailableAtLoopEntry(RHS, L) &&
11336        "RHS is not available at Loop Entry");
11337
11338 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11339 return true;
11340
11341 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11342}
11343
11344bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11345 const SCEV *RHS,
11346 const Value *FoundCondValue, bool Inverse,
11347 const Instruction *CtxI) {
11348 // A false condition implies anything. Do not bother analyzing it further.
11349 if (FoundCondValue ==
11350 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11351 return true;
11352
11353 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11354 return false;
11355
11356 auto ClearOnExit =
11357 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11358
11359 // Recursively handle And and Or conditions.
11360 const Value *Op0, *Op1;
11361 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11362 if (!Inverse)
11363 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11364 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11365 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11366 if (Inverse)
11367 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11368 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11369 }
11370
11371 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11372 if (!ICI) return false;
11373
11374 // We have now found a conditional branch that dominates the loop or controls
11375 // the loop latch. Check to see if it is the comparison we are looking for.
11376 ICmpInst::Predicate FoundPred;
11377 if (Inverse)
11378 FoundPred = ICI->getInversePredicate();
11379 else
11380 FoundPred = ICI->getPredicate();
11381
11382 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11383 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11384
11385 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11386}
11387
11388bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11389 const SCEV *RHS,
11390 ICmpInst::Predicate FoundPred,
11391 const SCEV *FoundLHS, const SCEV *FoundRHS,
11392 const Instruction *CtxI) {
11393 // Balance the types.
11394 if (getTypeSizeInBits(LHS->getType()) <
11395 getTypeSizeInBits(FoundLHS->getType())) {
11396 // For unsigned and equality predicates, try to prove that both found
11397 // operands fit into narrow unsigned range. If so, try to prove facts in
11398 // narrow types.
11399 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11400 !FoundRHS->getType()->isPointerTy()) {
11401 auto *NarrowType = LHS->getType();
11402 auto *WideType = FoundLHS->getType();
11403 auto BitWidth = getTypeSizeInBits(NarrowType);
11404 const SCEV *MaxValue = getZeroExtendExpr(
11405 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11406 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11407 MaxValue) &&
11408 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11409 MaxValue)) {
11410 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11411 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11412 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11413 TruncFoundRHS, CtxI))
11414 return true;
11415 }
11416 }
11417
11418 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11419 return false;
11420 if (CmpInst::isSigned(Pred)) {
11421 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11422 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11423 } else {
11424 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11425 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11426 }
11427 } else if (getTypeSizeInBits(LHS->getType()) >
11428 getTypeSizeInBits(FoundLHS->getType())) {
11429 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11430 return false;
11431 if (CmpInst::isSigned(FoundPred)) {
11432 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11433 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11434 } else {
11435 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11436 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11437 }
11438 }
11439 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11440 FoundRHS, CtxI);
11441}
11442
11443bool ScalarEvolution::isImpliedCondBalancedTypes(
11444 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11445 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11446 const Instruction *CtxI) {
11447 assert(getTypeSizeInBits(LHS->getType()) ==
11448            getTypeSizeInBits(FoundLHS->getType()) &&
11449        "Types should be balanced!");
11450 // Canonicalize the query to match the way instcombine will have
11451 // canonicalized the comparison.
11452 if (SimplifyICmpOperands(Pred, LHS, RHS))
11453 if (LHS == RHS)
11454 return CmpInst::isTrueWhenEqual(Pred);
11455 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11456 if (FoundLHS == FoundRHS)
11457 return CmpInst::isFalseWhenEqual(FoundPred);
11458
11459 // Check to see if we can make the LHS or RHS match.
11460 if (LHS == FoundRHS || RHS == FoundLHS) {
11461 if (isa<SCEVConstant>(RHS)) {
11462 std::swap(FoundLHS, FoundRHS);
11463 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11464 } else {
11465 std::swap(LHS, RHS);
11466 Pred = ICmpInst::getSwappedPredicate(Pred);
11467 }
11468 }
11469
11470 // Check whether the found predicate is the same as the desired predicate.
11471 if (FoundPred == Pred)
11472 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11473
11474 // Check whether swapping the found predicate makes it the same as the
11475 // desired predicate.
11476 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11477 // We can write the implication
11478 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11479 // using one of the following ways:
11480 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11481 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11482 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11483 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11484 // Forms 1. and 2. require swapping the operands of one condition. Don't
11485 // do this if it would break canonical constant/addrec ordering.
11486 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11487 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11488 CtxI);
11489 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11490 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11491
11492 // There's no clear preference between forms 3. and 4.; try both. Avoid
11493 // forming getNotSCEV of pointer values as the resulting subtract is
11494 // not legal.
11495 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11496 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11497 FoundLHS, FoundRHS, CtxI))
11498 return true;
11499
11500 if (!FoundLHS->getType()->isPointerTy() &&
11501 !FoundRHS->getType()->isPointerTy() &&
11502 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11503 getNotSCEV(FoundRHS), CtxI))
11504 return true;
11505
11506 return false;
11507 }
11508
11509 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11510 CmpInst::Predicate P2) {
11511 assert(P1 != P2 && "Handled earlier!");
11512 return CmpInst::isRelational(P2) &&
11513 P1 == CmpInst::getFlippedSignednessPredicate(P2);
11514 };
11515 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11516 // Unsigned comparison is the same as signed comparison when both the
11517 // operands are non-negative or negative.
11518 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11519 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11520 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11521 // Create local copies that we can freely swap and canonicalize our
11522 // conditions to "le/lt".
11523 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11524 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11525 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11526 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11527 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11528 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11529 std::swap(CanonicalLHS, CanonicalRHS);
11530 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11531 }
11532 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11533        "Must be!");
11534 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11535         ICmpInst::isLE(CanonicalFoundPred)) &&
11536        "Must be!");
11537 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11538 // Use implication:
11539 // x <u y && y >=s 0 --> x <s y.
11540 // If we can prove the left part, the right part is also proven.
11541 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11542 CanonicalRHS, CanonicalFoundLHS,
11543 CanonicalFoundRHS);
11544 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11545 // Use implication:
11546 // x <s y && y <s 0 --> x <u y.
11547 // If we can prove the left part, the right part is also proven.
11548 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11549 CanonicalRHS, CanonicalFoundLHS,
11550 CanonicalFoundRHS);
11551 }
11552
11553 // Check if we can make progress by sharpening ranges.
11554 if (FoundPred == ICmpInst::ICMP_NE &&
11555 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11556
11557 const SCEVConstant *C = nullptr;
11558 const SCEV *V = nullptr;
11559
11560 if (isa<SCEVConstant>(FoundLHS)) {
11561 C = cast<SCEVConstant>(FoundLHS);
11562 V = FoundRHS;
11563 } else {
11564 C = cast<SCEVConstant>(FoundRHS);
11565 V = FoundLHS;
11566 }
11567
11568 // The guarding predicate tells us that C != V. If the known range
11569 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11570 // range we consider has to correspond to same signedness as the
11571 // predicate we're interested in folding.
11572
11573 APInt Min = ICmpInst::isSigned(Pred) ?
11574 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11575
11576 if (Min == C->getAPInt()) {
11577 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11578 // This is true even if (Min + 1) wraps around -- in case of
11579 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11580
11581 APInt SharperMin = Min + 1;
11582
11583 switch (Pred) {
11584 case ICmpInst::ICMP_SGE:
11585 case ICmpInst::ICMP_UGE:
11586 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11587 // RHS, we're done.
11588 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11589 CtxI))
11590 return true;
11591 [[fallthrough]];
11592
11593 case ICmpInst::ICMP_SGT:
11594 case ICmpInst::ICMP_UGT:
11595 // We know from the range information that (V `Pred` Min ||
11596 // V == Min). We know from the guarding condition that !(V
11597 // == Min). This gives us
11598 //
11599 // V `Pred` Min || V == Min && !(V == Min)
11600 // => V `Pred` Min
11601 //
11602 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11603
11604 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11605 return true;
11606 break;
11607
11608 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11609 case ICmpInst::ICMP_SLE:
11610 case ICmpInst::ICMP_ULE:
11611 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11612 LHS, V, getConstant(SharperMin), CtxI))
11613 return true;
11614 [[fallthrough]];
11615
11616 case ICmpInst::ICMP_SLT:
11617 case ICmpInst::ICMP_ULT:
11618 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11619 LHS, V, getConstant(Min), CtxI))
11620 return true;
11621 break;
11622
11623 default:
11624 // No change
11625 break;
11626 }
11627 }
11628 }
11629
11630 // Check whether the actual condition is beyond sufficient.
11631 if (FoundPred == ICmpInst::ICMP_EQ)
11632 if (ICmpInst::isTrueWhenEqual(Pred))
11633 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11634 return true;
11635 if (Pred == ICmpInst::ICMP_NE)
11636 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11637 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11638 return true;
11639
11640 // Otherwise assume the worst.
11641 return false;
11642}
11643
11644bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11645 const SCEV *&L, const SCEV *&R,
11646 SCEV::NoWrapFlags &Flags) {
11647 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11648 if (!AE || AE->getNumOperands() != 2)
11649 return false;
11650
11651 L = AE->getOperand(0);
11652 R = AE->getOperand(1);
11653 Flags = AE->getNoWrapFlags();
11654 return true;
11655}
11656
11657Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
11658 const SCEV *Less) {
11659 // We avoid subtracting expressions here because this function is usually
11660 // fairly deep in the call stack (i.e. is called many times).
11661
11662 // X - X = 0.
11663 if (More == Less)
11664 return APInt(getTypeSizeInBits(More->getType()), 0);
11665
11666 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11667 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11668 const auto *MAR = cast<SCEVAddRecExpr>(More);
11669
11670 if (LAR->getLoop() != MAR->getLoop())
11671 return None;
11672
11673 // We look at affine expressions only; not for correctness but to keep
11674 // getStepRecurrence cheap.
11675 if (!LAR->isAffine() || !MAR->isAffine())
11676 return None;
11677
11678 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11679 return None;
11680
11681 Less = LAR->getStart();
11682 More = MAR->getStart();
11683
11684 // fall through
11685 }
11686
11687 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11688 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11689 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11690 return M - L;
11691 }
11692
11693 SCEV::NoWrapFlags Flags;
11694 const SCEV *LLess = nullptr, *RLess = nullptr;
11695 const SCEV *LMore = nullptr, *RMore = nullptr;
11696 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11697 // Compare (X + C1) vs X.
11698 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11699 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11700 if (RLess == More)
11701 return -(C1->getAPInt());
11702
11703 // Compare X vs (X + C2).
11704 if (splitBinaryAdd(More, LMore, RMore, Flags))
11705 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11706 if (RMore == Less)
11707 return C2->getAPInt();
11708
11709 // Compare (X + C1) vs (X + C2).
11710 if (C1 && C2 && RLess == RMore)
11711 return C2->getAPInt() - C1->getAPInt();
11712
11713 return None;
11714}
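The shape of computeConstantDifference, matching two expressions that share a non-constant part and differ only by constants, can be modeled with a toy symbolic value. The sketch below is illustrative only: real SCEVs are trees and the constants are arbitrary-precision APInts.

#include <cassert>
#include <cstdint>
#include <optional>

// Toy symbolic value: an opaque base (identified by address) plus a constant.
struct Sym {
  const void *Base; // identity of the non-constant part
  int64_t Offset;   // the constant addend
};

// "More - Less" when both share the same non-constant part, else nullopt,
// mirroring the cheap, non-subtracting style of the function above.
static std::optional<int64_t> constantDifference(Sym More, Sym Less) {
  if (More.Base != Less.Base)
    return std::nullopt;
  return More.Offset - Less.Offset;
}

int main() {
  int X = 0, Y = 0;
  assert(*constantDifference({&X, 7}, {&X, 3}) == 4); // (X+7) - (X+3) == 4
  assert(!constantDifference({&X, 7}, {&Y, 3}));      // different bases
}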
11715
11716bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
11717 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11718 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11719 // Try to recognize the following pattern:
11720 //
11721 // FoundRHS = ...
11722 // ...
11723 // loop:
11724 // FoundLHS = {Start,+,W}
11725 // context_bb: // Basic block from the same loop
11726 // known(Pred, FoundLHS, FoundRHS)
11727 //
11728 // If some predicate is known in the context of a loop, it is also known on
11729 // each iteration of this loop, including the first iteration. Therefore, in
11730 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
11731 // prove the original pred using this fact.
11732 if (!CtxI)
11733 return false;
11734 const BasicBlock *ContextBB = CtxI->getParent();
11735 // Make sure AR varies in the context block.
11736 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
11737 const Loop *L = AR->getLoop();
11738 // Make sure that context belongs to the loop and executes on 1st iteration
11739 // (if it ever executes at all).
11740 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11741 return false;
11742 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
11743 return false;
11744 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
11745 }
11746
11747 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
11748 const Loop *L = AR->getLoop();
11749 // Make sure that context belongs to the loop and executes on 1st iteration
11750 // (if it ever executes at all).
11751 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11752 return false;
11753 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
11754 return false;
11755 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
11756 }
11757
11758 return false;
11759}
11760
11761bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
11762 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11763 const SCEV *FoundLHS, const SCEV *FoundRHS) {
11764 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
11765 return false;
11766
11767 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11768 if (!AddRecLHS)
11769 return false;
11770
11771 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
11772 if (!AddRecFoundLHS)
11773 return false;
11774
11775 // We'd like to let SCEV reason about control dependencies, so we constrain
11776 // both the inequalities to be about add recurrences on the same loop. This
11777 // way we can use isLoopEntryGuardedByCond later.
11778
11779 const Loop *L = AddRecFoundLHS->getLoop();
11780 if (L != AddRecLHS->getLoop())
11781 return false;
11782
11783 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
11784 //
11785 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
11786 // ... (2)
11787 //
11788 // Informal proof for (2), assuming (1) [*]:
11789 //
11790 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
11791 //
11792 // Then
11793 //
11794 // FoundLHS s< FoundRHS s< INT_MIN - C
11795 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
11796 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
11797 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
11798 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
11799 // <=> FoundLHS + C s< FoundRHS + C
11800 //
11801 // [*]: (1) can be proved by ruling out overflow.
11802 //
11803 // [**]: This can be proved by analyzing all the four possibilities:
11804 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
11805 // (A s>= 0, B s>= 0).
11806 //
11807 // Note:
11808 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
11809 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
11810 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
11811 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
11812 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
11813 // C)".
11814
11815 Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
11816 Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
11817 if (!LDiff || !RDiff || *LDiff != *RDiff)
11818 return false;
11819
11820 if (LDiff->isMinValue())
11821 return true;
11822
11823 APInt FoundRHSLimit;
11824
11825 if (Pred == CmpInst::ICMP_ULT) {
11826 FoundRHSLimit = -(*RDiff);
11827 } else {
11828 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
11829 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
11830 }
11831
11832 // Try to prove (1) or (2), as needed.
11833 return isAvailableAtLoopEntry(FoundRHS, L) &&
11834 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
11835 getConstant(FoundRHSLimit));
11836}
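Fact (1) from the informal proof above lends itself to an exhaustive check at a small width. The sketch below is plain C++ at 8 bits (all arithmetic modulo 2^8), chosen only to keep the search space tiny:

#include <cassert>
#include <cstdint>

int main() {
  // A u< B u< -C  implies  (A + C) u< (B + C), with wrapping 8-bit arithmetic.
  for (unsigned A = 0; A <= 255; ++A)
    for (unsigned B = 0; B <= 255; ++B)
      for (unsigned C = 0; C <= 255; ++C) {
        const uint8_t NegC = static_cast<uint8_t>(0u - C); // -C modulo 2^8
        if (A < B && B < NegC)
          assert(static_cast<uint8_t>(A + C) < static_cast<uint8_t>(B + C));
      }
}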
11837
11838bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
11839 const SCEV *LHS, const SCEV *RHS,
11840 const SCEV *FoundLHS,
11841 const SCEV *FoundRHS, unsigned Depth) {
11842 const PHINode *LPhi = nullptr, *RPhi = nullptr;
11843
11844 auto ClearOnExit = make_scope_exit([&]() {
11845 if (LPhi) {
11846 bool Erased = PendingMerges.erase(LPhi);
11847 assert(Erased && "Failed to erase LPhi!");
11848 (void)Erased;
11849 }
11850 if (RPhi) {
11851 bool Erased = PendingMerges.erase(RPhi);
11852 assert(Erased && "Failed to erase RPhi!");
11853 (void)Erased;
11854 }
11855 });
11856
11857 // Find the respective Phis and check that they are not already pending.
11858 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
11859 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
11860 if (!PendingMerges.insert(Phi).second)
11861 return false;
11862 LPhi = Phi;
11863 }
11864 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
11865 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
11866 // If we detect a loop of Phi nodes being processed by this method, for
11867 // example:
11868 //
11869 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
11870 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
11871 //
11872 // we don't want to deal with a case that complex, so we return the
11873 // conservative answer false.
11874 if (!PendingMerges.insert(Phi).second)
11875 return false;
11876 RPhi = Phi;
11877 }
11878
11879 // If none of LHS, RHS is a Phi, nothing to do here.
11880 if (!LPhi && !RPhi)
11881 return false;
11882
11883 // If there is a SCEVUnknown Phi we are interested in, make it left.
11884 if (!LPhi) {
11885 std::swap(LHS, RHS);
11886 std::swap(FoundLHS, FoundRHS);
11887 std::swap(LPhi, RPhi);
11888 Pred = ICmpInst::getSwappedPredicate(Pred);
11889 }
11890
11891 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
11892 const BasicBlock *LBB = LPhi->getParent();
11893 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
11894
11895 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
11896 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
11897 isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
11898 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
11899 };
11900
11901 if (RPhi && RPhi->getParent() == LBB) {
11902 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
11903 // If we compare two Phis from the same block, and for each entry block
11904 // the predicate is true for incoming values from this block, then the
11905 // predicate is also true for the Phis.
11906 for (const BasicBlock *IncBB : predecessors(LBB)) {
11907 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
11908 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
11909 if (!ProvedEasily(L, R))
11910 return false;
11911 }
11912 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
11913 // Case two: RHS is also a Phi from the same basic block, and it is an
11914 // AddRec. It means that there is a loop which has both AddRec and Unknown
11915 // PHIs; for it, we can compare the incoming values of the AddRec from above
11916 // the loop and from the latch with the respective incoming values of LPhi.
11917 // TODO: Generalize to handle loops with many inputs in a header.
11918 if (LPhi->getNumIncomingValues() != 2) return false;
11919
11920 auto *RLoop = RAR->getLoop();
11921 auto *Predecessor = RLoop->getLoopPredecessor();
11922 assert(Predecessor && "Loop with AddRec with no predecessor?");
11923 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
11924 if (!ProvedEasily(L1, RAR->getStart()))
11925 return false;
11926 auto *Latch = RLoop->getLoopLatch();
11927 assert(Latch && "Loop with AddRec with no latch?");
11928 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
11929 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
11930 return false;
11931 } else {
11932 // In all other cases, go over the inputs of LHS and compare each of them to
11933 // RHS: the predicate is true for (LHS, RHS) if it is true for all such pairs.
11934 // At this point RHS is either a non-Phi, or it is a Phi from some block
11935 // different from LBB.
11936 for (const BasicBlock *IncBB : predecessors(LBB)) {
11937 // Check that RHS is available in this block.
11938 if (!dominates(RHS, IncBB))
11939 return false;
11940 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
11941 // Make sure L does not refer to a value from a potentially previous
11942 // iteration of a loop.
11943 if (!properlyDominates(L, LBB))
11944 return false;
11945 if (!ProvedEasily(L, RHS))
11946 return false;
11947 }
11948 }
11949 return true;
11950}
11951
11952bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
11953 const SCEV *LHS,
11954 const SCEV *RHS,
11955 const SCEV *FoundLHS,
11956 const SCEV *FoundRHS) {
11957 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
11958 // sure that we are dealing with the same LHS.
11959 if (RHS == FoundRHS) {
11960 std::swap(LHS, RHS);
11961 std::swap(FoundLHS, FoundRHS);
11962 Pred = ICmpInst::getSwappedPredicate(Pred);
11963 }
11964 if (LHS != FoundLHS)
11965 return false;
11966
11967 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
11968 if (!SUFoundRHS)
11969 return false;
11970
11971 Value *Shiftee, *ShiftValue;
11972
11973 using namespace PatternMatch;
11974 if (match(SUFoundRHS->getValue(),
11975 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
11976 auto *ShifteeS = getSCEV(Shiftee);
11977 // Prove one of the following:
11978 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
11979 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
11980 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
11981 // ---> LHS <s RHS
11982 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
11983 // ---> LHS <=s RHS
11984 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
11985 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
11986 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
11987 if (isKnownNonNegative(ShifteeS))
11988 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
11989 }
11990
11991 return false;
11992}
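// A minimal worked instance of the lshr rule proved above (illustrative only,
// not part of ScalarEvolution.cpp; plain unsigned arithmetic stands in for the
// SCEV expressions). Take shiftee = 40, shiftvalue = 2, RHS = 50: a logical
// shift right never increases an unsigned value, so
//   LHS <u (40 >> 2) == 10  together with  40 <=u 50  gives  LHS <u 50.
static_assert((40u >> 2) == 10u && 40u <= 50u,
              "lshr keeps the shiftee's upper bound, so LHS <u RHS follows");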
11993
11994bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
11995 const SCEV *LHS, const SCEV *RHS,
11996 const SCEV *FoundLHS,
11997 const SCEV *FoundRHS,
11998 const Instruction *CtxI) {
11999 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
12000 return true;
12001
12002 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12003 return true;
12004
12005 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12006 return true;
12007
12008 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12009 CtxI))
12010 return true;
12011
12012 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12013 FoundLHS, FoundRHS);
12014}
12015
12016/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12017template <typename MinMaxExprType>
12018static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12019 const SCEV *Candidate) {
12020 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12021 if (!MinMaxExpr)
12022 return false;
12023
12024 return is_contained(MinMaxExpr->operands(), Candidate);
12025}
12026
12027static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12028 ICmpInst::Predicate Pred,
12029 const SCEV *LHS, const SCEV *RHS) {
12030 // If both sides are affine addrecs for the same loop, with equal
12031 // steps, and we know the recurrences don't wrap, then we only
12032 // need to check the predicate on the starting values.
12033
12034 if (!ICmpInst::isRelational(Pred))
12035 return false;
12036
12037 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12038 if (!LAR)
12039 return false;
12040 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12041 if (!RAR)
12042 return false;
12043 if (LAR->getLoop() != RAR->getLoop())
12044 return false;
12045 if (!LAR->isAffine() || !RAR->isAffine())
12046 return false;
12047
12048 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12049 return false;
12050
12051 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12052 SCEV::FlagNSW : SCEV::FlagNUW;
12053 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12054 return false;
12055
12056 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12057}
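// Concrete sketch of the rule above (illustrative only, using plain integers
// rather than SCEV addrecs): {4,+,2} and {10,+,2} are affine recurrences over
// the same loop with equal steps and no unsigned wrap. Since the starts
// satisfy 4 <u 10, the values stay ordered on every iteration: 4<10, 6<12, ...
static_assert(4u + 2u * 3u < 10u + 2u * 3u,
              "equal steps preserve the ordering of the starts (iteration 3)");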
12058
12059/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12060/// expression?
12061static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12062 ICmpInst::Predicate Pred,
12063 const SCEV *LHS, const SCEV *RHS) {
12064 switch (Pred) {
12065 default:
12066 return false;
12067
12068 case ICmpInst::ICMP_SGE:
12069 std::swap(LHS, RHS);
12070 [[fallthrough]];
12071 case ICmpInst::ICMP_SLE:
12072 return
12073 // min(A, ...) <= A
12074 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12075 // A <= max(A, ...)
12076 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12077
12078 case ICmpInst::ICMP_UGE:
12079 std::swap(LHS, RHS);
12080 [[fallthrough]];
12081 case ICmpInst::ICMP_ULE:
12082 return
12083 // min(A, ...) <= A
12084 // FIXME: what about umin_seq?
12085 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12086 // A <= max(A, ...)
12087 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12088 }
12089
12090 llvm_unreachable("covered switch fell through?!");
12091}
12092
12093bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12094 const SCEV *LHS, const SCEV *RHS,
12095 const SCEV *FoundLHS,
12096 const SCEV *FoundRHS,
12097 unsigned Depth) {
12098 assert(getTypeSizeInBits(LHS->getType()) ==
12099            getTypeSizeInBits(RHS->getType()) &&
12100        "LHS and RHS have different sizes?");
12101 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12102            getTypeSizeInBits(FoundRHS->getType()) &&
12103        "FoundLHS and FoundRHS have different sizes?");
12104 // We want to avoid hurting compile time by analyzing overly large trees.
12105 if (Depth > MaxSCEVOperationsImplicationDepth)
12106 return false;
12107
12108 // We only want to work with GT comparison so far.
12109 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12110 Pred = CmpInst::getSwappedPredicate(Pred);
12111 std::swap(LHS, RHS);
12112 std::swap(FoundLHS, FoundRHS);
12113 }
12114
12115 // For unsigned, try to reduce it to corresponding signed comparison.
12116 if (Pred == ICmpInst::ICMP_UGT)
12117 // We can replace unsigned predicate with its signed counterpart if all
12118 // involved values are non-negative.
12119 // TODO: We could have better support for unsigned.
12120 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12121 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12122 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12123 // use this fact to prove that LHS and RHS are non-negative.
12124 const SCEV *MinusOne = getMinusOne(LHS->getType());
12125 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12126 FoundRHS) &&
12127 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12128 FoundRHS))
12129 Pred = ICmpInst::ICMP_SGT;
12130 }
12131
12132 if (Pred != ICmpInst::ICMP_SGT)
12133 return false;
12134
12135 auto GetOpFromSExt = [&](const SCEV *S) {
12136 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12137 return Ext->getOperand();
12138 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12139 // the constant in some cases.
12140 return S;
12141 };
12142
12143 // Acquire values from extensions.
12144 auto *OrigLHS = LHS;
12145 auto *OrigFoundLHS = FoundLHS;
12146 LHS = GetOpFromSExt(LHS);
12147 FoundLHS = GetOpFromSExt(FoundLHS);
12148
12149 // Whether the SGT predicate can be proved trivially or using the found context.
12150 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12151 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12152 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12153 FoundRHS, Depth + 1);
12154 };
12155
12156 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12157 // We want to avoid creation of any new non-constant SCEV. Since we are
12158 // going to compare the operands to RHS, we should be certain that we don't
12159 // need any size extensions for this. So let's decline all cases when the
12160 // sizes of types of LHS and RHS do not match.
12161 // TODO: Maybe try to get RHS from sext to catch more cases?
12162 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12163 return false;
12164
12165 // Should not overflow.
12166 if (!LHSAddExpr->hasNoSignedWrap())
12167 return false;
12168
12169 auto *LL = LHSAddExpr->getOperand(0);
12170 auto *LR = LHSAddExpr->getOperand(1);
12171 auto *MinusOne = getMinusOne(RHS->getType());
12172
12173 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12174 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12175 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12176 };
12177 // Try to prove the following rule:
12178 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12179 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12180 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12181 return true;
12182 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12183 Value *LL, *LR;
12184 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12185
12186 using namespace llvm::PatternMatch;
12187
12188 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12189 // Rules for division.
12190 // We are going to perform some comparisons with Denominator and its
12191 // derivative expressions. In the general case, creating a SCEV for it may
12192 // lead to a complex analysis of the entire graph, and in particular it
12193 // can request trip count recalculation for the same loop. This would
12194 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12195 // this, we only want to create SCEVs that are constants in this section.
12196 // So we bail if Denominator is not a constant.
12197 if (!isa<ConstantInt>(LR))
12198 return false;
12199
12200 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12201
12202 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12203 // then a SCEV for the numerator already exists and matches with FoundLHS.
12204 auto *Numerator = getExistingSCEV(LL);
12205 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12206 return false;
12207
12208 // Make sure that the numerator matches with FoundLHS and the denominator
12209 // is positive.
12210 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12211 return false;
12212
12213 auto *DTy = Denominator->getType();
12214 auto *FRHSTy = FoundRHS->getType();
12215 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12216 // One of types is a pointer and another one is not. We cannot extend
12217 // them properly to a wider type, so let us just reject this case.
12218 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12219 // to avoid this check.
12220 return false;
12221
12222 // Given that:
12223 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12224 auto *WTy = getWiderType(DTy, FRHSTy);
12225 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12226 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12227
12228 // Try to prove the following rule:
12229 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12230 // For example, given that FoundLHS > 2, FoundLHS is at least 3. If we
12231 // divide it by Denominator < 4, we will have at least 1.
12232 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12233 if (isKnownNonPositive(RHS) &&
12234 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12235 return true;
12236
12237 // Try to prove the following rule:
12238 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12239 // For example, given that FoundLHS > -3, FoundLHS is at least -2.
12240 // If we divide it by Denominator > 2, then:
12241 // 1. If FoundLHS is negative, then the result is 0.
12242 // 2. If FoundLHS is non-negative, then the result is non-negative.
12243 // Anyways, the result is non-negative.
12244 auto *MinusOne = getMinusOne(WTy);
12245 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12246 if (isKnownNegative(RHS) &&
12247 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12248 return true;
12249 }
12250 }
12251
12252 // If our expression contained SCEVUnknown Phis, and we split it down and now
12253 // need to prove something for them, try to prove the predicate for every
12254 // possible incoming values of those Phis.
12255 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12256 return true;
12257
12258 return false;
12259}
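// Worked instance of the first division rule above (illustrative only, with
// plain unsigned arithmetic in place of SCEVs): take Denominator = 4 and
// FoundRHS = 3 > Denominator - 2. From FoundLHS > FoundRHS we get
// FoundLHS >= 4, so LHS = FoundLHS / Denominator >= 1, and with RHS <= 0 this
// gives LHS > RHS.
static_assert(3u > 4u - 2u && 4u / 4u == 1u,
              "FoundRHS > Denominator - 2 forces the quotient to be >= 1");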
12260
12261static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12262 const SCEV *LHS, const SCEV *RHS) {
12263 // zext x u<= sext x, sext x s<= zext x
12264 switch (Pred) {
12265 case ICmpInst::ICMP_SGE:
12266 std::swap(LHS, RHS);
12267 [[fallthrough]];
12268 case ICmpInst::ICMP_SLE: {
12269 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12270 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12271 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12272 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12273 return true;
12274 break;
12275 }
12276 case ICmpInst::ICMP_UGE:
12277 std::swap(LHS, RHS);
12278 [[fallthrough]];
12279 case ICmpInst::ICMP_ULE: {
12280 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12281 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12282 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12283 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12284 return true;
12285 break;
12286 }
12287 default:
12288 break;
12289 };
12290 return false;
12291}
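// Illustrative instance of the extend idiom above for an i8 widened to i16
// (not part of the original source; plain integer casts stand in for the
// zext/sext SCEVs, assuming the usual <cstdint> typedefs are in scope here).
// For x = -1: zext(x) == 255 and sext(x) == -1, so sext(x) s<= zext(x), and
// zext(x) u<= sext(x) because -1 reads as 65535 when compared unsigned.
static_assert(static_cast<uint16_t>(static_cast<uint8_t>(-1)) == 255u &&
                  static_cast<uint16_t>(static_cast<int16_t>(-1)) == 65535u,
              "zext x u<= sext x and sext x s<= zext x for x = -1");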
12292
12293bool
12294ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12295 const SCEV *LHS, const SCEV *RHS) {
12296 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12297 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12298 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12299 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12300 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12301}
12302
12303bool
12304ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12305 const SCEV *LHS, const SCEV *RHS,
12306 const SCEV *FoundLHS,
12307 const SCEV *FoundRHS) {
12308 switch (Pred) {
12309 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!")::llvm::llvm_unreachable_internal("Unexpected ICmpInst::Predicate value!"
, "llvm/lib/Analysis/ScalarEvolution.cpp", 12309)
;
12310 case ICmpInst::ICMP_EQ:
12311 case ICmpInst::ICMP_NE:
12312 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12313 return true;
12314 break;
12315 case ICmpInst::ICMP_SLT:
12316 case ICmpInst::ICMP_SLE:
12317 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12318 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12319 return true;
12320 break;
12321 case ICmpInst::ICMP_SGT:
12322 case ICmpInst::ICMP_SGE:
12323 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12324 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12325 return true;
12326 break;
12327 case ICmpInst::ICMP_ULT:
12328 case ICmpInst::ICMP_ULE:
12329 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12330 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12331 return true;
12332 break;
12333 case ICmpInst::ICMP_UGT:
12334 case ICmpInst::ICMP_UGE:
12335 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12336 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12337 return true;
12338 break;
12339 }
12340
12341 // Maybe it can be proved via operations?
12342 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12343 return true;
12344
12345 return false;
12346}
12347
12348bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12349 const SCEV *LHS,
12350 const SCEV *RHS,
12351 const SCEV *FoundLHS,
12352 const SCEV *FoundRHS) {
12353 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12354 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12355 // reduce the compile time impact of this optimization.
12356 return false;
12357
12358 Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12359 if (!Addend)
12360 return false;
12361
12362 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12363
12364 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12365 // antecedent "`FoundLHS` `Pred` `FoundRHS`".
12366 ConstantRange FoundLHSRange =
12367 ConstantRange::makeExactICmpRegion(Pred, ConstFoundRHS);
12368
12369 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12370 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12371
12372 // We can also compute the range of values for `LHS` that satisfy the
12373 // consequent, "`LHS` `Pred` `RHS`":
12374 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12375 // The antecedent implies the consequent if every value of `LHS` that
12376 // satisfies the antecedent also satisfies the consequent.
12377 return LHSRange.icmp(Pred, ConstRHS);
12378}
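// Numeric sketch of the range argument above (illustrative only): suppose the
// antecedent is "FoundLHS <u 10", so FoundLHS lies in [0, 10), and suppose
// computeConstantDifference tells us LHS == FoundLHS + 5. Then LHS lies in
// [5, 15), and a consequent such as "LHS <u 20" holds for every value in that
// range, so the implication is accepted.
static_assert(9u + 5u < 20u,
              "the largest possible LHS value (14) still satisfies LHS <u 20");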
12379
12380bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12381 bool IsSigned) {
12382 assert(isKnownPositive(Stride) && "Positive stride expected!");
12383
12384 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12385 const SCEV *One = getOne(Stride->getType());
12386
12387 if (IsSigned) {
12388 APInt MaxRHS = getSignedRangeMax(RHS);
12389 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12390 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12391
12392 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12393 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12394 }
12395
12396 APInt MaxRHS = getUnsignedRangeMax(RHS);
12397 APInt MaxValue = APInt::getMaxValue(BitWidth);
12398 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12399
12400 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12401 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12402}
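// Worked 8-bit example for the unsigned branch above (illustrative only):
// with getUnsignedRangeMax(RHS) == 250 and a stride of at most 10,
// MaxStrideMinusOne == 9 and the unsigned maximum is 255. Since
// 255 - 9 == 246 < 250, the IV may step from a value just below RHS past 255,
// i.e. it can overflow.
static_assert(255u - 9u < 250u,
              "UMaxRHS + UMaxStrideMinusOne would exceed the 8-bit maximum");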
12403
12404bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12405 bool IsSigned) {
12406
12407 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12408 const SCEV *One = getOne(Stride->getType());
12409
12410 if (IsSigned) {
12411 APInt MinRHS = getSignedRangeMin(RHS);
12412 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12413 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12414
12415 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12416 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12417 }
12418
12419 APInt MinRHS = getUnsignedRangeMin(RHS);
12420 APInt MinValue = APInt::getMinValue(BitWidth);
12421 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12422
12423 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12424 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12425}
12426
12427const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12428 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12429 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12430 // expression fixes the case of N=0.
12431 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12432 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12433 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12434}
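// Quick check of the ceiling-division formula above with plain unsigned
// arithmetic (illustrative only; the real computation builds SCEV nodes):
//   N = 7, D = 3:  umin(7,1) = 1 and (7 - 1) / 3 = 2, giving 3 == ceil(7/3).
//   N = 0, D = 3:  umin(0,1) = 0 and (0 - 0) / 3 = 0, avoiding the N - 1
//   underflow that the naive "1 + (N - 1) / D" form would hit.
static_assert(1u + (7u - 1u) / 3u == 3u && 0u + (0u - 0u) / 3u == 0u,
              "umin(N,1) + (N - umin(N,1)) / D equals ceil(N / D)");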
12435
12436const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12437 const SCEV *Stride,
12438 const SCEV *End,
12439 unsigned BitWidth,
12440 bool IsSigned) {
12441 // The logic in this function assumes we can represent a positive stride.
12442 // If we can't, the backedge-taken count must be zero.
12443 if (IsSigned && BitWidth == 1)
12444 return getZero(Stride->getType());
12445
12446 // The code below has only been closely audited for negative strides in the
12447 // unsigned comparison case; it may be correct for signed comparisons, but
12448 // that needs to be established.
12449 if (IsSigned && isKnownNegative(Stride))
12450 return getCouldNotCompute();
12451
12452 // Calculate the maximum backedge count based on the range of values
12453 // permitted by Start, End, and Stride.
12454 APInt MinStart =
12455 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12456
12457 APInt MinStride =
12458 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12459
12460 // We assume either the stride is positive, or the backedge-taken count
12461 // is zero. So force StrideForMaxBECount to be at least one.
12462 APInt One(BitWidth, 1);
12463 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12464 : APIntOps::umax(One, MinStride);
12465
12466 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12467 : APInt::getMaxValue(BitWidth);
12468 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12469
12470 // Although End can be a MAX expression we estimate MaxEnd considering only
12471 // the case End = RHS of the loop termination condition. This is safe because
12472 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12473 // taken count.
12474 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12475 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12476
12477 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12478 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12479 : APIntOps::umax(MaxEnd, MinStart);
12480
12481 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12482 getConstant(StrideForMaxBECount) /* Step */);
12483}
12484
12485ScalarEvolution::ExitLimit
12486ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12487 const Loop *L, bool IsSigned,
12488 bool ControlsExit, bool AllowPredicates) {
12489 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
12490
12491 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12492 bool PredicatedIV = false;
12493
12494 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12495 // Can we prove this loop *must* be UB if overflow of IV occurs?
12496 // Reasoning goes as follows:
12497 // * Suppose the IV did self wrap.
12498 // * If Stride evenly divides the iteration space, then once wrap
12499 // occurs, the loop must revisit the same values.
12500 // * We know that RHS is invariant, and that none of those values
12501 // caused this exit to be taken previously. Thus, this exit is
12502 // dynamically dead.
12503 // * If this is the sole exit, then a dead exit implies the loop
12504 // must be infinite if there are no abnormal exits.
12505 // * If the loop were infinite, then it must either not be mustprogress
12506 // or have side effects. Otherwise, it must be UB.
12507 // * It can't (by assumption), be UB so we have contradicted our
12508 // premise and can conclude the IV did not in fact self-wrap.
12509 if (!isLoopInvariant(RHS, L))
12510 return false;
12511
12512 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12513 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12514 return false;
12515
12516 if (!ControlsExit || !loopHasNoAbnormalExits(L))
12517 return false;
12518
12519 return loopIsFiniteByAssumption(L);
12520 };
12521
12522 if (!IV) {
12523 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12524 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12525 if (AR && AR->getLoop() == L && AR->isAffine()) {
12526 auto canProveNUW = [&]() {
12527 if (!isLoopInvariant(RHS, L))
12528 return false;
12529
12530 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12531 // We need the sequence defined by AR to strictly increase in the
12532 // unsigned integer domain for the logic below to hold.
12533 return false;
12534
12535 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12536 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12537 // If RHS <=u Limit, then there must exist a value V in the sequence
12538 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12539 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12540 // overflow occurs. This limit also implies that a signed comparison
12541 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12542 // the high bits on both sides must be zero.
12543 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12544 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12545 Limit = Limit.zext(OuterBitWidth);
12546 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12547 };
12548 auto Flags = AR->getNoWrapFlags();
12549 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12550 Flags = setFlags(Flags, SCEV::FlagNUW);
12551
12552 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12553 if (AR->hasNoUnsignedWrap()) {
12554 // Emulate what getZeroExtendExpr would have done during construction
12555 // if we'd been able to infer the fact just above at that time.
12556 const SCEV *Step = AR->getStepRecurrence(*this);
12557 Type *Ty = ZExt->getType();
12558 auto *S = getAddRecExpr(
12559 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12560 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12561 IV = dyn_cast<SCEVAddRecExpr>(S);
12562 }
12563 }
12564 }
12565 }
12566
12567
12568 if (!IV && AllowPredicates) {
12569 // Try to make this an AddRec using runtime tests, in the first X
12570 // iterations of this loop, where X is the SCEV expression found by the
12571 // algorithm below.
12572 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12573 PredicatedIV = true;
12574 }
12575
12576 // Avoid weird loops
12577 if (!IV || IV->getLoop() != L || !IV->isAffine())
12578 return getCouldNotCompute();
12579
12580 // A precondition of this method is that the condition being analyzed
12581 // reaches an exiting branch which dominates the latch. Given that, we can
12582 // assume that an increment which violates the nowrap specification and
12583 // produces poison must cause undefined behavior when the resulting poison
12584 // value is branched upon and thus we can conclude that the backedge is
12585 // taken no more often than would be required to produce that poison value.
12586 // Note that a well defined loop can exit on the iteration which violates
12587 // the nowrap specification if there is another exit (either explicit or
12588 // implicit/exceptional) which causes the loop to execute before the
12589 // exiting instruction we're analyzing would trigger UB.
12590 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12591 bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
12592 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12593
12594 const SCEV *Stride = IV->getStepRecurrence(*this);
12595
12596 bool PositiveStride = isKnownPositive(Stride);
12597
12598 // Avoid negative or zero stride values.
12599 if (!PositiveStride) {
12600 // We can compute the correct backedge taken count for loops with unknown
12601 // strides if we can prove that the loop is not an infinite loop with side
12602 // effects. Here's the loop structure we are trying to handle -
12603 //
12604 // i = start
12605 // do {
12606 // A[i] = i;
12607 // i += s;
12608 // } while (i < end);
12609 //
12610 // The backedge taken count for such loops is evaluated as -
12611 // (max(end, start + stride) - start - 1) /u stride
12612 //
12613 // The additional preconditions that we need to check to prove correctness
12614 // of the above formula is as follows -
12615 //
12616 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12617 // NoWrap flag).
12618 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12619 // no side effects within the loop)
12620 // c) loop has a single static exit (with no abnormal exits)
12621 //
12622 // Precondition a) implies that if the stride is negative, this is a single
12623 // trip loop. The backedge taken count formula reduces to zero in this case.
12624 //
12625 // Precondition b) and c) combine to imply that if rhs is invariant in L,
12626 // then a zero stride means the backedge can't be taken without executing
12627 // undefined behavior.
12628 //
12629 // The positive stride case is the same as isKnownPositive(Stride) returning
12630 // true (original behavior of the function).
12631 //
12632 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12633 !loopHasNoAbnormalExits(L))
12634 return getCouldNotCompute();
12635
12636 if (!isKnownNonZero(Stride)) {
12637 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12638 // if it might eventually be greater than start and if so, on which
12639 // iteration. We can't even produce a useful upper bound.
12640 if (!isLoopInvariant(RHS, L))
12641 return getCouldNotCompute();
12642
12643 // We allow a potentially zero stride, but we need to divide by stride
12644 // below. Since the loop can't be infinite and this check must control
12645 // the sole exit, we can infer the exit must be taken on the first
12646 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12647 // we know the numerator in the divides below must be zero, so we can
12648 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12649 // and produce the right result.
12650 // FIXME: Handle the case where Stride is poison?
12651 auto wouldZeroStrideBeUB = [&]() {
12652 // Proof by contradiction. Suppose the stride were zero. If we can
12653 // prove that the backedge *is* taken on the first iteration, then since
12654 // we know this condition controls the sole exit, we must have an
12655 // infinite loop. We can't have a (well defined) infinite loop per
12656 // check just above.
12657 // Note: The (Start - Stride) term is used to get the start' term from
12658 // (start' + stride,+,stride). Remember that we only care about the
12659 // result of this expression when stride == 0 at runtime.
12660 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12661 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12662 };
12663 if (!wouldZeroStrideBeUB()) {
12664 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12665 }
12666 }
12667 } else if (!Stride->isOne() && !NoWrap) {
12668 auto isUBOnWrap = [&]() {
12669 // From no-self-wrap, we need to then prove no-(un)signed-wrap. This
12670 // follows trivially from the fact that every (un)signed-wrapped, but
12671 // not self-wrapped, value must be less than the last value before
12672 // (un)signed wrap. Since we know that last value didn't exit, neither
12673 // will any smaller one.
12674 return canAssumeNoSelfWrap(IV);
12675 };
12676
12677 // Avoid proven overflow cases: this will ensure that the backedge taken
12678 // count will not generate any unsigned overflow. Relaxed no-overflow
12679 // conditions exploit NoWrapFlags, allowing us to optimize in the presence
12680 // of undefined behavior, as in the C language.
12681 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12682 return getCouldNotCompute();
12683 }
12684
12685 // On all paths just preceding, we established the following invariant:
12686 // IV can be assumed not to overflow up to and including the exiting
12687 // iteration. We proved this in one of two ways:
12688 // 1) We can show overflow doesn't occur before the exiting iteration
12689 // 1a) canIVOverflowOnLT, and b) step of one
12690 // 2) We can show that if overflow occurs, the loop must execute UB
12691 // before any possible exit.
12692 // Note that we have not yet proved RHS invariant (in general).
12693
12694 const SCEV *Start = IV->getStart();
12695
12696 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12697 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12698 // Use integer-typed versions for actual computation; we can't subtract
12699 // pointers in general.
12700 const SCEV *OrigStart = Start;
12701 const SCEV *OrigRHS = RHS;
12702 if (Start->getType()->isPointerTy()) {
12703 Start = getLosslessPtrToIntExpr(Start);
12704 if (isa<SCEVCouldNotCompute>(Start))
12705 return Start;
12706 }
12707 if (RHS->getType()->isPointerTy()) {
12708 RHS = getLosslessPtrToIntExpr(RHS);
12709 if (isa<SCEVCouldNotCompute>(RHS))
12710 return RHS;
12711 }
12712
12713 // When the RHS is not invariant, we do not know the end bound of the loop and
12714 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
12715 // calculate the MaxBECount, given the start, stride and max value for the end
12716 // bound of the loop (RHS), and the fact that IV does not overflow (which is
12717 // checked above).
12718 if (!isLoopInvariant(RHS, L)) {
12719 const SCEV *MaxBECount = computeMaxBECountForLT(
12720 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12721 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
12722 false /*MaxOrZero*/, Predicates);
12723 }
12724
12725 // We use the expression (max(End,Start)-Start)/Stride to describe the
12726 // backedge count, as if the backedge is taken at least once max(End,Start)
12727 // is End and so the result is as above, and if not max(End,Start) is Start
12728 // so we get a backedge count of zero.
12729 const SCEV *BECount = nullptr;
12730 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
12731 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
12732 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
12733 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
12734 // Can we prove (max(RHS,Start) > Start - Stride)?
12735 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
12736 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
12737 // In this case, we can use a refined formula for computing backedge taken
12738 // count. The general formula remains:
12739 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
12740 // We want to use the alternate formula:
12741 // "((End - 1) - (Start - Stride)) /u Stride"
12742 // Let's do a quick case analysis to show these are equivalent under
12743 // our precondition that max(RHS,Start) > Start - Stride.
12744 // * For RHS <= Start, the backedge-taken count must be zero.
12745 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12746 //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
12747 //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
12748 //   of Stride. For 0 stride, we've used umin(1,Stride) above, reducing
12749 // this to the stride of 1 case.
12750 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
12751 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12752 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
12753 // "((RHS - (Start - Stride) - 1) /u Stride".
12754 // Our preconditions trivially imply no overflow in that form.
12755 const SCEV *MinusOne = getMinusOne(Stride->getType());
12756 const SCEV *Numerator =
12757 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
12758 BECount = getUDivExpr(Numerator, Stride);
12759 }
12760
12761 const SCEV *BECountIfBackedgeTaken = nullptr;
12762 if (!BECount) {
12763 auto canProveRHSGreaterThanEqualStart = [&]() {
12764 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
12765 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart))
12766 return true;
12767
12768 // (RHS > Start - 1) implies RHS >= Start.
12769 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
12770 // "Start - 1" doesn't overflow.
12771 // * For signed comparison, if Start - 1 does overflow, it's equal
12772 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
12773 // * For unsigned comparison, if Start - 1 does overflow, it's equal
12774 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
12775 //
12776 // FIXME: Should isLoopEntryGuardedByCond do this for us?
12777 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
12778 auto *StartMinusOne = getAddExpr(OrigStart,
12779 getMinusOne(OrigStart->getType()));
12780 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
12781 };
12782
12783 // If we know that RHS >= Start in the context of loop, then we know that
12784 // max(RHS, Start) = RHS at this point.
12785 const SCEV *End;
12786 if (canProveRHSGreaterThanEqualStart()) {
12787 End = RHS;
12788 } else {
12789 // If RHS < Start, the backedge will be taken zero times. So in
12790 // general, we can write the backedge-taken count as:
12791 //
12792 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
12793 //
12794 // We convert it to the following to make it more convenient for SCEV:
12795 //
12796 // ceil(max(RHS, Start) - Start) / Stride
12797 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
12798
12799 // See what would happen if we assume the backedge is taken. This is
12800 // used to compute MaxBECount.
12801 BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
12802 }
12803
12804 // At this point, we know:
12805 //
12806 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
12807 // 2. The index variable doesn't overflow.
12808 //
12809 // Therefore, we know N exists such that
12810 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
12811 // doesn't overflow.
12812 //
12813 // Using this information, try to prove whether the addition in
12814 // "(Start - End) + (Stride - 1)" has unsigned overflow.
12815 const SCEV *One = getOne(Stride->getType());
12816 bool MayAddOverflow = [&] {
12817 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
12818 if (StrideC->getAPInt().isPowerOf2()) {
12819 // Suppose Stride is a power of two, and Start/End are unsigned
12820 // integers. Let UMAX be the largest representable unsigned
12821 // integer.
12822 //
12823 // By the preconditions of this function, we know
12824 // "(Start + Stride * N) >= End", and this doesn't overflow.
12825 // As a formula:
12826 //
12827 // End <= (Start + Stride * N) <= UMAX
12828 //
12829 // Subtracting Start from all the terms:
12830 //
12831 // End - Start <= Stride * N <= UMAX - Start
12832 //
12833 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
12834 //
12835 // End - Start <= Stride * N <= UMAX
12836 //
12837 // Stride * N is a multiple of Stride. Therefore,
12838 //
12839 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
12840 //
12841 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
12842 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
12843 //
12844 // End - Start <= Stride * N <= UMAX - Stride - 1
12845 //
12846 // Dropping the middle term:
12847 //
12848 // End - Start <= UMAX - Stride - 1
12849 //
12850 // Adding Stride - 1 to both sides:
12851 //
12852 // (End - Start) + (Stride - 1) <= UMAX
12853 //
12854 // In other words, the addition doesn't have unsigned overflow.
12855 //
12856 // A similar proof works if we treat Start/End as signed values.
12857 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
12858 // use signed max instead of unsigned max. Note that we're trying
12859 // to prove a lack of unsigned overflow in either case.
12860 return false;
12861 }
12862 }
12863 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
12864 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
12865 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
12866 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
12867 //
12868 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
12869 return false;
12870 }
12871 return true;
12872 }();
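// A concrete 8-bit instance of the power-of-two argument above (illustrative
// only): with UMAX = 255 and Stride = 4, UMAX mod Stride equals
// Stride - 1 == 3, so End - Start <= 255 - 3 == 252 and
// (End - Start) + (Stride - 1) <= 252 + 3 == 255 never wraps.
static_assert(255u % 4u == 3u && 252u + 3u == 255u,
              "power-of-two stride bounds (End-Start)+(Stride-1) by UMAX");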
12873
12874 const SCEV *Delta = getMinusSCEV(End, Start);
12875 if (!MayAddOverflow) {
12876 // floor((D + (S - 1)) / S)
12877 // We prefer this formulation if it's legal because it's fewer operations.
12878 BECount =
12879 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
12880 } else {
12881 BECount = getUDivCeilSCEV(Delta, Stride);
12882 }
12883 }
12884
12885 const SCEV *MaxBECount;
12886 bool MaxOrZero = false;
12887 if (isa<SCEVConstant>(BECount)) {
12888 MaxBECount = BECount;
12889 } else if (BECountIfBackedgeTaken &&
12890 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
12891 // If we know exactly how many times the backedge will be taken if it's
12892 // taken at least once, then the backedge count will either be that or
12893 // zero.
12894 MaxBECount = BECountIfBackedgeTaken;
12895 MaxOrZero = true;
12896 } else {
12897 MaxBECount = computeMaxBECountForLT(
12898 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12899 }
12900
12901 if (isa<SCEVCouldNotCompute>(MaxBECount) &&
12902 !isa<SCEVCouldNotCompute>(BECount))
12903 MaxBECount = getConstant(getUnsignedRangeMax(BECount));
12904
12905 return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);
12906}
12907
12908ScalarEvolution::ExitLimit
12909ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
12910 const Loop *L, bool IsSigned,
12911 bool ControlsExit, bool AllowPredicates) {
12912 SmallPtrSet<const SCEVPredicate *, 4> Predicates;
12913 // We handle only IV > Invariant
12914 if (!isLoopInvariant(RHS, L))
12915 return getCouldNotCompute();
12916
12917 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12918 if (!IV && AllowPredicates)
12919 // Try to make this an AddRec using runtime tests, in the first X
12920 // iterations of this loop, where X is the SCEV expression found by the
12921 // algorithm below.
12922 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12923
12924 // Avoid weird loops
12925 if (!IV || IV->getLoop() != L || !IV->isAffine())
12926 return getCouldNotCompute();
12927
12928 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12929 bool NoWrap = ControlsExit && IV->getNoWrapFlags(WrapType);
12930 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
12931
12932 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
12933
12934 // Avoid negative or zero stride values
12935 if (!isKnownPositive(Stride))
12936 return getCouldNotCompute();
12937
12938 // Avoid proven overflow cases: this will ensure that the backedge taken count
12939 // will not generate any unsigned overflow. Relaxed no-overflow conditions
12940 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
12941 // behavior, as in the C language.
12942 if (!Stride->isOne() && !NoWrap)
12943 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
12944 return getCouldNotCompute();
12945
12946 const SCEV *Start = IV->getStart();
12947 const SCEV *End = RHS;
12948 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
12949 // If we know that Start >= RHS in the context of loop, then we know that
12950 // min(RHS, Start) = RHS at this point.
12951 if (isLoopEntryGuardedByCond(
12952 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
12953 End = RHS;
12954 else
12955 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
12956 }
12957
12958 if (Start->getType()->isPointerTy()) {
12959 Start = getLosslessPtrToIntExpr(Start);
12960 if (isa<SCEVCouldNotCompute>(Start))
12961 return Start;
12962 }
12963 if (End->getType()->isPointerTy()) {
12964 End = getLosslessPtrToIntExpr(End);
12965 if (isa<SCEVCouldNotCompute>(End))
12966 return End;
12967 }
12968
12969 // Compute ((Start - End) + (Stride - 1)) / Stride.
12970 // FIXME: This can overflow. Holding off on fixing this for now;
12971 // howManyGreaterThans will hopefully be gone soon.
12972 const SCEV *One = getOne(Stride->getType());
12973 const SCEV *BECount = getUDivExpr(
12974 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
12975
12976 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
12977 : getUnsignedRangeMax(Start);
12978
12979 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
12980 : getUnsignedRangeMin(Stride);
12981
12982 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
12983 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
12984 : APInt::getMinValue(BitWidth) + (MinStride - 1);
12985
12986 // Although End can be a MIN expression we estimate MinEnd considering only
12987 // the case End = RHS. This is safe because in the other case (Start - End)
12988 // is zero, leading to a zero maximum backedge taken count.
12989 APInt MinEnd =
12990 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
12991 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
12992
12993 const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
12994 ? BECount
12995 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
12996 getConstant(MinStride));
12997
12998 if (isa<SCEVCouldNotCompute>(MaxBECount))
12999 MaxBECount = BECount;
13000
13001 return ExitLimit(BECount, MaxBECount, false, Predicates);
13002}
13003
13004const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13005 ScalarEvolution &SE) const {
13006 if (Range.isFullSet()) // Infinite loop.
13007 return SE.getCouldNotCompute();
13008
13009 // If the start is a non-zero constant, shift the range to simplify things.
13010 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13011 if (!SC->getValue()->isZero()) {
13012 SmallVector<const SCEV *, 4> Operands(operands());
13013 Operands[0] = SE.getZero(SC->getType());
13014 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13015 getNoWrapFlags(FlagNW));
13016 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13017 return ShiftedAddRec->getNumIterationsInRange(
13018 Range.subtract(SC->getAPInt()), SE);
13019 // This is strange and shouldn't happen.
13020 return SE.getCouldNotCompute();
13021 }
13022
13023 // The only time we can solve this is when we have all constant indices.
13024 // Otherwise, we cannot determine the overflow conditions.
13025 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13026 return SE.getCouldNotCompute();
13027
13028 // Okay at this point we know that all elements of the chrec are constants and
13029 // that the start element is zero.
13030
13031 // First check to see if the range contains zero. If not, the first
13032 // iteration exits.
13033 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13034 if (!Range.contains(APInt(BitWidth, 0)))
13035 return SE.getZero(getType());
13036
13037 if (isAffine()) {
13038 // If this is an affine expression then we have this situation:
13039 // Solve {0,+,A} in Range === Ax in Range
13040
13041 // We know that zero is in the range. If A is positive then we know that
13042 // the upper value of the range must be the first possible exit value.
13043 // If A is negative then the lower of the range is the last possible loop
13044 // value. Also note that we already checked for a full range.
13045 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13046 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13047
13048 // The exit value should be (End+A)/A.
13049 APInt ExitVal = (End + A).udiv(A);
13050 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13051
13052 // Evaluate at the exit value. If we really did fall out of the valid
13053 // range, then we computed our trip count, otherwise wrap around or other
13054 // things must have happened.
13055 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13056 if (Range.contains(Val->getValue()))
13057 return SE.getCouldNotCompute(); // Something strange happened
13058
13059 // Ensure that the previous value is in the range.
13060 assert(Range.contains(
13061            EvaluateConstantChrecAtConstant(this,
13062                ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13063        "Linear scev computation is off in a bad way!");
13064 return SE.getConstant(ExitValue);
13065 }
13066
13067 if (isQuadratic()) {
13068 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13069 return SE.getConstant(*S);
13070 }
13071
13072 return SE.getCouldNotCompute();
13073}
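// Worked affine example for the code above (illustrative only): the chrec
// {0,+,3} against Range = [0, 10). Here A = 3 and End = Range.getUpper() - 1
// = 9, so ExitVal = (9 + 3) /u 3 = 4. Evaluating the chrec: 3 * 4 = 12 falls
// outside [0, 10), while the previous value 3 * 3 = 9 is still inside, so the
// computed iteration count is 4.
static_assert((9u + 3u) / 3u == 4u && 3u * 4u >= 10u && 3u * 3u < 10u,
              "{0,+,3} leaves [0,10) exactly at iteration 4");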
13074
13075const SCEVAddRecExpr *
13076SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13077 assert(getNumOperands() > 1 && "AddRec with zero step?");
13078 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13079 // but in this case we cannot guarantee that the value returned will be an
13080 // AddRec because SCEV does not have a fixed point where it stops
13081 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13082 // may happen if we reach arithmetic depth limit while simplifying. So we
13083 // construct the returned value explicitly.
13084 SmallVector<const SCEV *, 3> Ops;
13085 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13086 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13087 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13088 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13089 // We know that the last operand is not a constant zero (otherwise it would
13090 // have been popped out earlier). This guarantees us that if the result has
13091 // the same last operand, then it will also not be popped out, meaning that
13092 // the returned value will be an AddRec.
13093 const SCEV *Last = getOperand(getNumOperands() - 1);
13094 assert(!Last->isZero() && "Recurrency with zero step?");
13095 Ops.push_back(Last);
13096 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13097 SCEV::FlagAnyWrap));
13098}
13099
13100 // Return true when S contains at least one undef value.
13101bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13102 return SCEVExprContains(S, [](const SCEV *S) {
13103 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13104 return isa<UndefValue>(SU->getValue());
13105 return false;
13106 });
13107}
13108
13109// Return true when S contains a value that is a nullptr.
13110bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13111 return SCEVExprContains(S, [](const SCEV *S) {
13112 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13113 return SU->getValue() == nullptr;
13114 return false;
13115 });
13116}
13117
13118/// Return the size of an element read or written by Inst.
13119const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13120 Type *Ty;
13121 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13122 Ty = Store->getValueOperand()->getType();
13123 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13124 Ty = Load->getType();
13125 else
13126 return nullptr;
13127
13128 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13129 return getSizeOfExpr(ETy, Ty);
13130}
13131
13132//===----------------------------------------------------------------------===//
13133// SCEVCallbackVH Class Implementation
13134//===----------------------------------------------------------------------===//
13135
13136void ScalarEvolution::SCEVCallbackVH::deleted() {
13137 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13138 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13139 SE->ConstantEvolutionLoopExitValue.erase(PN);
13140 SE->eraseValueFromMap(getValPtr());
13141 // this now dangles!
13142}
13143
13144void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13145 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13146
13147 // Forget all the expressions associated with users of the old value,
13148 // so that future queries will recompute the expressions using the new
13149 // value.
13150 Value *Old = getValPtr();
13151 SmallVector<User *, 16> Worklist(Old->users());
13152 SmallPtrSet<User *, 8> Visited;
13153 while (!Worklist.empty()) {
13154 User *U = Worklist.pop_back_val();
13155 // Deleting the Old value will cause this to dangle. Postpone
13156 // that until everything else is done.
13157 if (U == Old)
13158 continue;
13159 if (!Visited.insert(U).second)
13160 continue;
13161 if (PHINode *PN = dyn_cast<PHINode>(U))
13162 SE->ConstantEvolutionLoopExitValue.erase(PN);
13163 SE->eraseValueFromMap(U);
13164 llvm::append_range(Worklist, U->users());
13165 }
13166 // Delete the Old value.
13167 if (PHINode *PN = dyn_cast<PHINode>(Old))
13168 SE->ConstantEvolutionLoopExitValue.erase(PN);
13169 SE->eraseValueFromMap(Old);
13170 // this now dangles!
13171}
13172
13173ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13174 : CallbackVH(V), SE(se) {}
13175
13176//===----------------------------------------------------------------------===//
13177// ScalarEvolution Class Implementation
13178//===----------------------------------------------------------------------===//
13179
13180ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13181 AssumptionCache &AC, DominatorTree &DT,
13182 LoopInfo &LI)
13183 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13184 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13185 LoopDispositions(64), BlockDispositions(64) {
13186 // To use guards for proving predicates, we need to scan every instruction in
13187 // relevant basic blocks, and not just terminators. Doing this is a waste of
13188 // time if the IR does not actually contain any calls to
13189 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13190 //
13191 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13192 // to _add_ guards to the module when there weren't any before, and wants
13193 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13194 // efficient in lieu of being smart in that rather obscure case.
13195
13196 auto *GuardDecl = F.getParent()->getFunction(
13197 Intrinsic::getName(Intrinsic::experimental_guard));
13198 HasGuards = GuardDecl && !GuardDecl->use_empty();
13199}
13200
13201ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13202 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13203 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13204 ValueExprMap(std::move(Arg.ValueExprMap)),
13205 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13206 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13207 PendingMerges(std::move(Arg.PendingMerges)),
13208 MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
13209 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13210 PredicatedBackedgeTakenCounts(
13211 std::move(Arg.PredicatedBackedgeTakenCounts)),
13212 BECountUsers(std::move(Arg.BECountUsers)),
13213 ConstantEvolutionLoopExitValue(
13214 std::move(Arg.ConstantEvolutionLoopExitValue)),
13215 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13216 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13217 LoopDispositions(std::move(Arg.LoopDispositions)),
13218 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13219 BlockDispositions(std::move(Arg.BlockDispositions)),
13220 SCEVUsers(std::move(Arg.SCEVUsers)),
13221 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13222 SignedRanges(std::move(Arg.SignedRanges)),
13223 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13224 UniquePreds(std::move(Arg.UniquePreds)),
13225 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13226 LoopUsers(std::move(Arg.LoopUsers)),
13227 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13228 FirstUnknown(Arg.FirstUnknown) {
13229 Arg.FirstUnknown = nullptr;
13230}
13231
13232ScalarEvolution::~ScalarEvolution() {
13233 // Iterate through all the SCEVUnknown instances and call their
13234 // destructors, so that they release their references to their values.
13235 for (SCEVUnknown *U = FirstUnknown; U;) {
13236 SCEVUnknown *Tmp = U;
13237 U = U->Next;
13238 Tmp->~SCEVUnknown();
13239 }
13240 FirstUnknown = nullptr;
13241
13242 ExprValueMap.clear();
13243 ValueExprMap.clear();
13244 HasRecMap.clear();
13245 BackedgeTakenCounts.clear();
13246 PredicatedBackedgeTakenCounts.clear();
13247
13248 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13249 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13250 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13251 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13252 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13253}
13254
13255bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13256 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13257}
13258
13259static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13260 const Loop *L) {
13261 // Print all inner loops first
13262 for (Loop *I : *L)
13263 PrintLoopInfo(OS, SE, I);
13264
13265 OS << "Loop ";
13266 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13267 OS << ": ";
13268
13269 SmallVector<BasicBlock *, 8> ExitingBlocks;
13270 L->getExitingBlocks(ExitingBlocks);
13271 if (ExitingBlocks.size() != 1)
13272 OS << "<multiple exits> ";
13273
13274 if (SE->hasLoopInvariantBackedgeTakenCount(L))
13275 OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L) << "\n";
13276 else
13277 OS << "Unpredictable backedge-taken count.\n";
13278
13279 if (ExitingBlocks.size() > 1)
13280 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13281 OS << " exit count for " << ExitingBlock->getName() << ": "
13282 << *SE->getExitCount(L, ExitingBlock) << "\n";
13283 }
13284
13285 OS << "Loop ";
13286 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13287 OS << ": ";
13288
13289 if (!isa<SCEVCouldNotCompute>(SE->getConstantMaxBackedgeTakenCount(L))) {
13290 OS << "max backedge-taken count is " << *SE->getConstantMaxBackedgeTakenCount(L);
13291 if (SE->isBackedgeTakenCountMaxOrZero(L))
13292 OS << ", actual taken count either this or zero.";
13293 } else {
13294 OS << "Unpredictable max backedge-taken count. ";
13295 }
13296
13297 OS << "\n"
13298 "Loop ";
13299 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13300 OS << ": ";
13301
13302 SmallVector<const SCEVPredicate *, 4> Preds;
13303 auto PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13304 if (!isa<SCEVCouldNotCompute>(PBT)) {
13305 OS << "Predicated backedge-taken count is " << *PBT << "\n";
13306 OS << " Predicates:\n";
13307 for (const auto *P : Preds)
13308 P->print(OS, 4);
13309 } else {
13310 OS << "Unpredictable predicated backedge-taken count. ";
13311 }
13312 OS << "\n";
13313
13314 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13315 OS << "Loop ";
13316 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13317 OS << ": ";
13318 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13319 }
13320}
13321
13322static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
13323 switch (LD) {
13324 case ScalarEvolution::LoopVariant:
13325 return "Variant";
13326 case ScalarEvolution::LoopInvariant:
13327 return "Invariant";
13328 case ScalarEvolution::LoopComputable:
13329 return "Computable";
13330 }
13331 llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!");
13332}
13333
13334void ScalarEvolution::print(raw_ostream &OS) const {
13335 // ScalarEvolution's implementation of the print method is to print
13336 // out SCEV values of all instructions that are interesting. Doing
13337 // this potentially causes it to create new SCEV objects though,
13338 // which technically conflicts with the const qualifier. This isn't
13339 // observable from outside the class though, so casting away the
13340 // const isn't dangerous.
13341 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13342
13343 if (ClassifyExpressions) {
13344 OS << "Classifying expressions for: ";
13345 F.printAsOperand(OS, /*PrintType=*/false);
13346 OS << "\n";
13347 for (Instruction &I : instructions(F))
13348 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13349 OS << I << '\n';
13350 OS << " --> ";
13351 const SCEV *SV = SE.getSCEV(&I);
13352 SV->print(OS);
13353 if (!isa<SCEVCouldNotCompute>(SV)) {
13354 OS << " U: ";
13355 SE.getUnsignedRange(SV).print(OS);
13356 OS << " S: ";
13357 SE.getSignedRange(SV).print(OS);
13358 }
13359
13360 const Loop *L = LI.getLoopFor(I.getParent());
13361
13362 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13363 if (AtUse != SV) {
13364 OS << " --> ";
13365 AtUse->print(OS);
13366 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13367 OS << " U: ";
13368 SE.getUnsignedRange(AtUse).print(OS);
13369 OS << " S: ";
13370 SE.getSignedRange(AtUse).print(OS);
13371 }
13372 }
13373
13374 if (L) {
13375 OS << "\t\t" "Exits: ";
13376 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13377 if (!SE.isLoopInvariant(ExitValue, L)) {
13378 OS << "<<Unknown>>";
13379 } else {
13380 OS << *ExitValue;
13381 }
13382
13383 bool First = true;
13384 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13385 if (First) {
13386 OS << "\t\t" "LoopDispositions: { ";
13387 First = false;
13388 } else {
13389 OS << ", ";
13390 }
13391
13392 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13393 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter));
13394 }
13395
13396 for (const auto *InnerL : depth_first(L)) {
13397 if (InnerL == L)
13398 continue;
13399 if (First) {
13400 OS << "\t\t" "LoopDispositions: { ";
13401 First = false;
13402 } else {
13403 OS << ", ";
13404 }
13405
13406 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13407 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL));
13408 }
13409
13410 OS << " }";
13411 }
13412
13413 OS << "\n";
13414 }
13415 }
13416
13417 OS << "Determining loop execution counts for: ";
13418 F.printAsOperand(OS, /*PrintType=*/false);
13419 OS << "\n";
13420 for (Loop *I : LI)
13421 PrintLoopInfo(OS, &SE, I);
13422}
13423
13424ScalarEvolution::LoopDisposition
13425ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13426 auto &Values = LoopDispositions[S];
13427 for (auto &V : Values) {
13428 if (V.getPointer() == L)
13429 return V.getInt();
13430 }
13431 Values.emplace_back(L, LoopVariant);
13432 LoopDisposition D = computeLoopDisposition(S, L);
13433 auto &Values2 = LoopDispositions[S];
13434 for (auto &V : llvm::reverse(Values2)) {
13435 if (V.getPointer() == L) {
13436 V.setInt(D);
13437 break;
13438 }
13439 }
13440 return D;
13441}
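// Illustrative sketch (editor's addition, not part of the original source):
// the lookup above memoizes one result per (SCEV, Loop) pair. It installs a
// conservative LoopVariant placeholder before computing, so any re-entrant
// query for the same pair sees a safe default, and it re-fetches
// LoopDispositions[S] afterwards because the recursive computation may have
// grown the map and invalidated the first reference. A standalone model of
// that caching shape (integers stand in for SCEVs and loops; Compute is a
// caller-supplied recursive computation):
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

namespace sketch {
using DispositionCache =
    std::unordered_map<int, std::vector<std::pair<int, int>>>;

int getCached(DispositionCache &C, int Expr, int Loop,
              const std::function<int(DispositionCache &, int, int)> &Compute) {
  for (auto &Entry : C[Expr])
    if (Entry.first == Loop)
      return Entry.second;                 // cache hit
  C[Expr].emplace_back(Loop, 0);           // conservative placeholder
  int Result = Compute(C, Expr, Loop);     // may recurse back into getCached
  auto &Entries = C[Expr];                 // re-fetch; old reference may be stale
  for (auto It = Entries.rbegin(), E = Entries.rend(); It != E; ++It)
    if (It->first == Loop) {
      It->second = Result;
      break;
    }
  return Result;
}
} // namespace sketch
// (end of illustrative sketch)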
13442
13443ScalarEvolution::LoopDisposition
13444ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13445 switch (S->getSCEVType()) {
13446 case scConstant:
13447 return LoopInvariant;
13448 case scPtrToInt:
13449 case scTruncate:
13450 case scZeroExtend:
13451 case scSignExtend:
13452 return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
13453 case scAddRecExpr: {
13454 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13455
13456 // If L is the addrec's loop, it's computable.
13457 if (AR->getLoop() == L)
13458 return LoopComputable;
13459
13460 // Add recurrences are never invariant in the function-body (null loop).
13461 if (!L)
13462 return LoopVariant;
13463
13464 // Everything that is not defined at loop entry is variant.
13465 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13466 return LoopVariant;
13467 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13468 " dominate the contained loop's header?");
13469
13470 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13471 if (AR->getLoop()->contains(L))
13472 return LoopInvariant;
13473
13474 // This recurrence is variant w.r.t. L if any of its operands
13475 // are variant.
13476 for (const auto *Op : AR->operands())
13477 if (!isLoopInvariant(Op, L))
13478 return LoopVariant;
13479
13480 // Otherwise it's loop-invariant.
13481 return LoopInvariant;
13482 }
13483 case scAddExpr:
13484 case scMulExpr:
13485 case scUMaxExpr:
13486 case scSMaxExpr:
13487 case scUMinExpr:
13488 case scSMinExpr:
13489 case scSequentialUMinExpr: {
13490 bool HasVarying = false;
13491 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
13492 LoopDisposition D = getLoopDisposition(Op, L);
13493 if (D == LoopVariant)
13494 return LoopVariant;
13495 if (D == LoopComputable)
13496 HasVarying = true;
13497 }
13498 return HasVarying ? LoopComputable : LoopInvariant;
13499 }
13500 case scUDivExpr: {
13501 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13502 LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
13503 if (LD == LoopVariant)
13504 return LoopVariant;
13505 LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
13506 if (RD == LoopVariant)
13507 return LoopVariant;
13508 return (LD == LoopInvariant && RD == LoopInvariant) ?
13509 LoopInvariant : LoopComputable;
13510 }
13511 case scUnknown:
13512 // All non-instruction values are loop invariant. All instructions are loop
13513 // invariant if they are not contained in the specified loop.
13514 // Instructions are never considered invariant in the function body
13515 // (null loop) because they are defined within the "loop".
13516 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13517 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13518 return LoopInvariant;
13519 case scCouldNotCompute:
13520 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13521 }
13522 llvm_unreachable("Unknown SCEV kind!");
13523}
13524
13525bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13526 return getLoopDisposition(S, L) == LoopInvariant;
13527}
13528
13529bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13530 return getLoopDisposition(S, L) == LoopComputable;
13531}
13532
13533ScalarEvolution::BlockDisposition
13534ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13535 auto &Values = BlockDispositions[S];
13536 for (auto &V : Values) {
13537 if (V.getPointer() == BB)
13538 return V.getInt();
13539 }
13540 Values.emplace_back(BB, DoesNotDominateBlock);
13541 BlockDisposition D = computeBlockDisposition(S, BB);
13542 auto &Values2 = BlockDispositions[S];
13543 for (auto &V : llvm::reverse(Values2)) {
13544 if (V.getPointer() == BB) {
13545 V.setInt(D);
13546 break;
13547 }
13548 }
13549 return D;
13550}
13551
13552ScalarEvolution::BlockDisposition
13553ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13554 switch (S->getSCEVType()) {
13555 case scConstant:
13556 return ProperlyDominatesBlock;
13557 case scPtrToInt:
13558 case scTruncate:
13559 case scZeroExtend:
13560 case scSignExtend:
13561 return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
13562 case scAddRecExpr: {
13563 // This uses a "dominates" query instead of "properly dominates" query
13564 // to test for proper dominance too, because the instruction which
13565 // produces the addrec's value is a PHI, and a PHI effectively properly
13566 // dominates its entire containing block.
13567 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13568 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13569 return DoesNotDominateBlock;
13570
13571 // Fall through into SCEVNAryExpr handling.
13572 [[fallthrough]];
13573 }
13574 case scAddExpr:
13575 case scMulExpr:
13576 case scUMaxExpr:
13577 case scSMaxExpr:
13578 case scUMinExpr:
13579 case scSMinExpr:
13580 case scSequentialUMinExpr: {
13581 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
13582 bool Proper = true;
13583 for (const SCEV *NAryOp : NAry->operands()) {
13584 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13585 if (D == DoesNotDominateBlock)
13586 return DoesNotDominateBlock;
13587 if (D == DominatesBlock)
13588 Proper = false;
13589 }
13590 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13591 }
13592 case scUDivExpr: {
13593 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
13594 const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
13595 BlockDisposition LD = getBlockDisposition(LHS, BB);
13596 if (LD == DoesNotDominateBlock)
13597 return DoesNotDominateBlock;
13598 BlockDisposition RD = getBlockDisposition(RHS, BB);
13599 if (RD == DoesNotDominateBlock)
13600 return DoesNotDominateBlock;
13601 return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
13602 ProperlyDominatesBlock : DominatesBlock;
13603 }
13604 case scUnknown:
13605 if (Instruction *I =
13606 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13607 if (I->getParent() == BB)
13608 return DominatesBlock;
13609 if (DT.properlyDominates(I->getParent(), BB))
13610 return ProperlyDominatesBlock;
13611 return DoesNotDominateBlock;
13612 }
13613 return ProperlyDominatesBlock;
13614 case scCouldNotCompute:
13615 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13616 }
13617 llvm_unreachable("Unknown SCEV kind!");
13618}
13619
13620bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13621 return getBlockDisposition(S, BB) >= DominatesBlock;
13622}
13623
13624bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13625 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13626}
13627
13628bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13629 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13630}
13631
13632void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13633 bool Predicated) {
13634 auto &BECounts =
13635 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13636 auto It = BECounts.find(L);
13637 if (It != BECounts.end()) {
13638 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13639 if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13640 auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13641 assert(UserIt != BECountUsers.end());
13642 UserIt->second.erase({L, Predicated});
13643 }
13644 }
13645 BECounts.erase(It);
13646 }
13647}
13648
13649void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13650 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13651 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13652
13653 while (!Worklist.empty()) {
13654 const SCEV *Curr = Worklist.pop_back_val();
13655 auto Users = SCEVUsers.find(Curr);
13656 if (Users != SCEVUsers.end())
13657 for (const auto *User : Users->second)
13658 if (ToForget.insert(User).second)
13659 Worklist.push_back(User);
13660 }
13661
13662 for (const auto *S : ToForget)
13663 forgetMemoizedResultsImpl(S);
13664
13665 for (auto I = PredicatedSCEVRewrites.begin();
13666 I != PredicatedSCEVRewrites.end();) {
13667 std::pair<const SCEV *, const Loop *> Entry = I->first;
13668 if (ToForget.count(Entry.first))
13669 PredicatedSCEVRewrites.erase(I++);
13670 else
13671 ++I;
13672 }
13673}
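// Illustrative sketch (editor's addition, not part of the original source):
// the loop above is a plain transitive-closure walk. Starting from the given
// expressions it keeps popping a worklist and adding the registered users of
// each entry until nothing new appears; everything collected is then
// invalidated. A self-contained model over an integer "user graph", where
// Users maps a node to the nodes that depend on it:
#include <set>
#include <unordered_map>
#include <vector>

namespace sketch {
std::set<int>
collectTransitiveUsers(const std::unordered_map<int, std::vector<int>> &Users,
                       const std::vector<int> &Seeds) {
  std::set<int> ToForget(Seeds.begin(), Seeds.end());
  std::vector<int> Worklist(ToForget.begin(), ToForget.end());
  while (!Worklist.empty()) {
    int Curr = Worklist.back();
    Worklist.pop_back();
    auto It = Users.find(Curr);
    if (It == Users.end())
      continue;
    for (int User : It->second)
      if (ToForget.insert(User).second) // each node is enqueued at most once
        Worklist.push_back(User);
  }
  return ToForget; // every node whose cached data must be dropped
}
} // namespace sketch
// (end of illustrative sketch)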
13674
13675void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13676 LoopDispositions.erase(S);
13677 BlockDispositions.erase(S);
13678 UnsignedRanges.erase(S);
13679 SignedRanges.erase(S);
13680 HasRecMap.erase(S);
13681 MinTrailingZerosCache.erase(S);
13682
13683 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
13684 UnsignedWrapViaInductionTried.erase(AR);
13685 SignedWrapViaInductionTried.erase(AR);
13686 }
13687
13688 auto ExprIt = ExprValueMap.find(S);
13689 if (ExprIt != ExprValueMap.end()) {
13690 for (Value *V : ExprIt->second) {
13691 auto ValueIt = ValueExprMap.find_as(V);
13692 if (ValueIt != ValueExprMap.end())
13693 ValueExprMap.erase(ValueIt);
13694 }
13695 ExprValueMap.erase(ExprIt);
13696 }
13697
13698 auto ScopeIt = ValuesAtScopes.find(S);
13699 if (ScopeIt != ValuesAtScopes.end()) {
13700 for (const auto &Pair : ScopeIt->second)
13701 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13702 erase_value(ValuesAtScopesUsers[Pair.second],
13703 std::make_pair(Pair.first, S));
13704 ValuesAtScopes.erase(ScopeIt);
13705 }
13706
13707 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13708 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13709 for (const auto &Pair : ScopeUserIt->second)
13710 erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13711 ValuesAtScopesUsers.erase(ScopeUserIt);
13712 }
13713
13714 auto BEUsersIt = BECountUsers.find(S);
13715 if (BEUsersIt != BECountUsers.end()) {
13716 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13717 auto Copy = BEUsersIt->second;
13718 for (const auto &Pair : Copy)
13719 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13720 BECountUsers.erase(BEUsersIt);
13721 }
13722}
13723
13724void
13725ScalarEvolution::getUsedLoops(const SCEV *S,
13726 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13727 struct FindUsedLoops {
13728 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13729 : LoopsUsed(LoopsUsed) {}
13730 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13731 bool follow(const SCEV *S) {
13732 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
13733 LoopsUsed.insert(AR->getLoop());
13734 return true;
13735 }
13736
13737 bool isDone() const { return false; }
13738 };
13739
13740 FindUsedLoops F(LoopsUsed);
13741 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
13742}
13743
13744void ScalarEvolution::getReachableBlocks(
13745 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
13746 SmallVector<BasicBlock *> Worklist;
13747 Worklist.push_back(&F.getEntryBlock());
13748 while (!Worklist.empty()) {
13749 BasicBlock *BB = Worklist.pop_back_val();
13750 if (!Reachable.insert(BB).second)
13751 continue;
13752
13753 Value *Cond;
13754 BasicBlock *TrueBB, *FalseBB;
13755 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
13756 m_BasicBlock(FalseBB)))) {
13757 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
13758 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
13759 continue;
13760 }
13761
13762 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
13763 const SCEV *L = getSCEV(Cmp->getOperand(0));
13764 const SCEV *R = getSCEV(Cmp->getOperand(1));
13765 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
13766 Worklist.push_back(TrueBB);
13767 continue;
13768 }
13769 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
13770 R)) {
13771 Worklist.push_back(FalseBB);
13772 continue;
13773 }
13774 }
13775 }
13776
13777 append_range(Worklist, successors(BB));
13778 }
13779}
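// Illustrative sketch (editor's addition, not part of the original source):
// the reachability walk above starts at the entry block and prunes edges
// whose branch condition is a constant or an integer compare that the
// constant-range reasoning can decide. A standalone model where each block
// optionally carries a single known successor, simulating a provably
// one-sided conditional branch:
#include <optional>
#include <set>
#include <unordered_map>
#include <vector>

namespace sketch {
struct Block {
  std::vector<int> Succs;       // all successors
  std::optional<int> KnownSucc; // set when the branch condition is known
};

std::set<int> reachable(const std::unordered_map<int, Block> &CFG, int Entry) {
  std::set<int> Reachable;
  std::vector<int> Worklist{Entry};
  while (!Worklist.empty()) {
    int BB = Worklist.back();
    Worklist.pop_back();
    if (!Reachable.insert(BB).second)
      continue;                        // already visited
    const Block &B = CFG.at(BB);
    if (B.KnownSucc) {                 // branch folds to one side
      Worklist.push_back(*B.KnownSucc);
      continue;
    }
    for (int S : B.Succs)              // otherwise follow every edge
      Worklist.push_back(S);
  }
  return Reachable;
}
} // namespace sketch
// (end of illustrative sketch)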
13780
13781void ScalarEvolution::verify() const {
13782 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13783 ScalarEvolution SE2(F, TLI, AC, DT, LI);
13784
13785 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
13786
13787 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
13788 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
13789 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
13790
13791 const SCEV *visitConstant(const SCEVConstant *Constant) {
13792 return SE.getConstant(Constant->getAPInt());
13793 }
13794
13795 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
13796 return SE.getUnknown(Expr->getValue());
13797 }
13798
13799 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
13800 return SE.getCouldNotCompute();
13801 }
13802 };
13803
13804 SCEVMapper SCM(SE2);
13805 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
13806 SE2.getReachableBlocks(ReachableBlocks, F);
13807
13808 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
13809 if (containsUndefs(Old) || containsUndefs(New)) {
13810 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
13811 // not propagate undef aggressively). This means we can (and do) fail
13812 // verification in cases where a transform makes a value go from "undef"
13813 // to "undef+1" (say). The transform is fine, since in both cases the
13814 // result is "undef", but SCEV thinks the value increased by 1.
13815 return nullptr;
13816 }
13817
13818 // Unless VerifySCEVStrict is set, we only compare constant deltas.
13819 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
13820 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
13821 return nullptr;
13822
13823 return Delta;
13824 };
13825
13826 while (!LoopStack.empty()) {
13827 auto *L = LoopStack.pop_back_val();
13828 llvm::append_range(LoopStack, *L);
13829
13830 // Only verify BECounts in reachable loops. For an unreachable loop,
13831 // any BECount is legal.
13832 if (!ReachableBlocks.contains(L->getHeader()))
13833 continue;
13834
13835 // Only verify cached BECounts. Computing new BECounts may change the
13836 // results of subsequent SCEV uses.
13837 auto It = BackedgeTakenCounts.find(L);
13838 if (It == BackedgeTakenCounts.end())
13839 continue;
13840
13841 auto *CurBECount =
13842 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
13843 auto *NewBECount = SE2.getBackedgeTakenCount(L);
13844
13845 if (CurBECount == SE2.getCouldNotCompute() ||
13846 NewBECount == SE2.getCouldNotCompute()) {
13847 // NB! This situation is legal, but is very suspicious -- whatever pass
13848 // changed the loop to make a trip count go from could not compute to
13849 // computable or vice-versa *should have* invalidated SCEV. However, we
13850 // choose not to assert here (for now) since we don't want false
13851 // positives.
13852 continue;
13853 }
13854
13855 if (SE.getTypeSizeInBits(CurBECount->getType()) >
13856 SE.getTypeSizeInBits(NewBECount->getType()))
13857 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
13858 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
13859 SE.getTypeSizeInBits(NewBECount->getType()))
13860 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
13861
13862 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
13863 if (Delta && !Delta->isZero()) {
13864 dbgs() << "Trip Count for " << *L << " Changed!\n";
13865 dbgs() << "Old: " << *CurBECount << "\n";
13866 dbgs() << "New: " << *NewBECount << "\n";
13867 dbgs() << "Delta: " << *Delta << "\n";
13868 std::abort();
13869 }
13870 }
13871
13872 // Collect all valid loops currently in LoopInfo.
13873 SmallPtrSet<Loop *, 32> ValidLoops;
13874 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
13875 while (!Worklist.empty()) {
13876 Loop *L = Worklist.pop_back_val();
13877 if (ValidLoops.insert(L).second)
13878 Worklist.append(L->begin(), L->end());
13879 }
13880 for (const auto &KV : ValueExprMap) {
13881#ifndef NDEBUG
13882 // Check for SCEV expressions referencing invalid/deleted loops.
13883 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
13884 assert(ValidLoops.contains(AR->getLoop()) &&
13885 "AddRec references invalid loop");
13886 }
13887#endif
13888
13889 // Check that the value is also part of the reverse map.
13890 auto It = ExprValueMap.find(KV.second);
13891 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
13892 dbgs() << "Value " << *KV.first
13893 << " is in ValueExprMap but not in ExprValueMap\n";
13894 std::abort();
13895 }
13896
13897 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
13898 if (!ReachableBlocks.contains(I->getParent()))
13899 continue;
13900 const SCEV *OldSCEV = SCM.visit(KV.second);
13901 const SCEV *NewSCEV = SE2.getSCEV(I);
13902 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
13903 if (Delta && !Delta->isZero()) {
13904 dbgs() << "SCEV for value " << *I << " changed!\n"
13905 << "Old: " << *OldSCEV << "\n"
13906 << "New: " << *NewSCEV << "\n"
13907 << "Delta: " << *Delta << "\n";
13908 std::abort();
13909 }
13910 }
13911 }
13912
13913 for (const auto &KV : ExprValueMap) {
13914 for (Value *V : KV.second) {
13915 auto It = ValueExprMap.find_as(V);
13916 if (It == ValueExprMap.end()) {
13917 dbgs() << "Value " << *V
13918 << " is in ExprValueMap but not in ValueExprMap\n";
13919 std::abort();
13920 }
13921 if (It->second != KV.first) {
13922 dbgs() << "Value " << *V << " mapped to " << *It->second
13923 << " rather than " << *KV.first << "\n";
13924 std::abort();
13925 }
13926 }
13927 }
13928
13929 // Verify integrity of SCEV users.
13930 for (const auto &S : UniqueSCEVs) {
13931 SmallVector<const SCEV *, 4> Ops;
13932 collectUniqueOps(&S, Ops);
13933 for (const auto *Op : Ops) {
13934 // We do not store dependencies of constants.
13935 if (isa<SCEVConstant>(Op))
13936 continue;
13937 auto It = SCEVUsers.find(Op);
13938 if (It != SCEVUsers.end() && It->second.count(&S))
13939 continue;
13940 dbgs() << "Use of operand " << *Op << " by user " << S
13941 << " is not being tracked!\n";
13942 std::abort();
13943 }
13944 }
13945
13946 // Verify integrity of ValuesAtScopes users.
13947 for (const auto &ValueAndVec : ValuesAtScopes) {
13948 const SCEV *Value = ValueAndVec.first;
13949 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
13950 const Loop *L = LoopAndValueAtScope.first;
13951 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
13952 if (!isa<SCEVConstant>(ValueAtScope)) {
13953 auto It = ValuesAtScopesUsers.find(ValueAtScope);
13954 if (It != ValuesAtScopesUsers.end() &&
13955 is_contained(It->second, std::make_pair(L, Value)))
13956 continue;
13957 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13958 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
13959 std::abort();
13960 }
13961 }
13962 }
13963
13964 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
13965 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
13966 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
13967 const Loop *L = LoopAndValue.first;
13968 const SCEV *Value = LoopAndValue.second;
13969 assert(!isa<SCEVConstant>(Value));
13970 auto It = ValuesAtScopes.find(Value);
13971 if (It != ValuesAtScopes.end() &&
13972 is_contained(It->second, std::make_pair(L, ValueAtScope)))
13973 continue;
13974 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
13975 << *ValueAtScope << " missing in ValuesAtScopes\n";
13976 std::abort();
13977 }
13978 }
13979
13980 // Verify integrity of BECountUsers.
13981 auto VerifyBECountUsers = [&](bool Predicated) {
13982 auto &BECounts =
13983 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13984 for (const auto &LoopAndBEInfo : BECounts) {
13985 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
13986 if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
13987 auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
13988 if (UserIt != BECountUsers.end() &&
13989 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
13990 continue;
13991 dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
13992 << *LoopAndBEInfo.first << " missing from BECountUsers\n";
13993 std::abort();
13994 }
13995 }
13996 }
13997 };
13998 VerifyBECountUsers(/* Predicated */ false);
13999 VerifyBECountUsers(/* Predicated */ true);
14000
14001 // Verify integrity of the loop disposition cache.
14002 for (const auto &It : LoopDispositions) {
14003 const SCEV *S = It.first;
14004 auto &Values = It.second;
14005 for (auto &V : Values) {
14006 auto CachedDisposition = V.getInt();
14007 const auto *Loop = V.getPointer();
14008 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14009 if (CachedDisposition != RecomputedDisposition) {
14010 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14011 << " is incorrect: cached "
14012 << loopDispositionToStr(CachedDisposition) << ", actual "
14013 << loopDispositionToStr(RecomputedDisposition) << "\n";
14014 std::abort();
14015 }
14016 }
14017 }
14018
14019 // Verify integrity of the block disposition cache.
14020 for (const auto &It : BlockDispositions) {
14021 const SCEV *S = It.first;
14022 auto &Values = It.second;
14023 for (auto &V : Values) {
14024 auto CachedDisposition = V.getInt();
14025 const BasicBlock *BB = V.getPointer();
14026 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14027 if (CachedDisposition != RecomputedDisposition) {
14028 dbgs() << "Cached disposition of " << *S << " for block %"
14029 << BB->getName() << " is incorrect! \n";
14030 std::abort();
14031 }
14032 }
14033 }
14034}
14035
14036bool ScalarEvolution::invalidate(
14037 Function &F, const PreservedAnalyses &PA,
14038 FunctionAnalysisManager::Invalidator &Inv) {
14039 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14040 // of its dependencies is invalidated.
14041 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14042 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14043 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14044 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14045 Inv.invalidate<LoopAnalysis>(F, PA);
14046}
14047
14048AnalysisKey ScalarEvolutionAnalysis::Key;
14049
14050ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14051 FunctionAnalysisManager &AM) {
14052 return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),
14053 AM.getResult<AssumptionAnalysis>(F),
14054 AM.getResult<DominatorTreeAnalysis>(F),
14055 AM.getResult<LoopAnalysis>(F));
14056}
14057
14058PreservedAnalyses
14059ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14060 AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14061 return PreservedAnalyses::all();
14062}
14063
14064PreservedAnalyses
14065ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14066 // For compatibility with opt's -analyze feature under legacy pass manager
14067 // which was not ported to NPM. This keeps tests using
14068 // update_analyze_test_checks.py working.
14069 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14070 << F.getName() << "':\n";
14071 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14072 return PreservedAnalyses::all();
14073}
14074
14075INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution",
14076 "Scalar Evolution Analysis", false, true)
14077INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
14078INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
14079INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
14080INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
14081INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
14082 "Scalar Evolution Analysis", false, true)
14083
14084char ScalarEvolutionWrapperPass::ID = 0;
14085
14086ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
14087 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry());
14088}
14089
14090bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14091 SE.reset(new ScalarEvolution(
14092 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14093 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14094 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14095 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14096 return false;
14097}
14098
14099void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); }
14100
14101void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14102 SE->print(OS);
14103}
14104
14105void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14106 if (!VerifySCEV)
14107 return;
14108
14109 SE->verify();
14110}
14111
14112void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
14113 AU.setPreservesAll();
14114 AU.addRequiredTransitive<AssumptionCacheTracker>();
14115 AU.addRequiredTransitive<LoopInfoWrapperPass>();
14116 AU.addRequiredTransitive<DominatorTreeWrapperPass>();
14117 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
14118}
14119
14120const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14121 const SCEV *RHS) {
14122 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14123}
14124
14125const SCEVPredicate *
14126ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14127 const SCEV *LHS, const SCEV *RHS) {
14128 FoldingSetNodeID ID;
14129 assert(LHS->getType() == RHS->getType() &&
14130 "Type mismatch between LHS and RHS");
14131 // Unique this node based on the arguments
14132 ID.AddInteger(SCEVPredicate::P_Compare);
14133 ID.AddInteger(Pred);
14134 ID.AddPointer(LHS);
14135 ID.AddPointer(RHS);
14136 void *IP = nullptr;
14137 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14138 return S;
14139 SCEVComparePredicate *Eq = new (SCEVAllocator)
14140 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14141 UniquePreds.InsertNode(Eq, IP);
14142 return Eq;
14143}
14144
14145const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14146 const SCEVAddRecExpr *AR,
14147 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14148 FoldingSetNodeID ID;
14149 // Unique this node based on the arguments
14150 ID.AddInteger(SCEVPredicate::P_Wrap);
14151 ID.AddPointer(AR);
14152 ID.AddInteger(AddedFlags);
14153 void *IP = nullptr;
14154 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14155 return S;
14156 auto *OF = new (SCEVAllocator)
14157 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14158 UniquePreds.InsertNode(OF, IP);
14159 return OF;
14160}
14161
14162namespace {
14163
14164class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14165public:
14166
14167 /// Rewrites \p S in the context of a loop L and the SCEV predication
14168 /// infrastructure.
14169 ///
14170 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14171 /// equivalences present in \p Pred.
14172 ///
14173 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14174 /// \p NewPreds such that the result will be an AddRecExpr.
14175 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14176 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14177 const SCEVPredicate *Pred) {
14178 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14179 return Rewriter.visit(S);
14180 }
14181
14182 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14183 if (Pred) {
14184 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14185 for (const auto *Pred : U->getPredicates())
14186 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14187 if (IPred->getLHS() == Expr &&
14188 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14189 return IPred->getRHS();
14190 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14191 if (IPred->getLHS() == Expr &&
14192 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14193 return IPred->getRHS();
14194 }
14195 }
14196 return convertToAddRecWithPreds(Expr);
14197 }
14198
14199 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14200 const SCEV *Operand = visit(Expr->getOperand());
14201 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14202 if (AR && AR->getLoop() == L && AR->isAffine()) {
14203 // This couldn't be folded because the operand didn't have the nuw
14204 // flag. Add the nusw flag as an assumption that we could make.
14205 const SCEV *Step = AR->getStepRecurrence(SE);
14206 Type *Ty = Expr->getType();
14207 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14208 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14209 SE.getSignExtendExpr(Step, Ty), L,
14210 AR->getNoWrapFlags());
14211 }
14212 return SE.getZeroExtendExpr(Operand, Expr->getType());
14213 }
14214
14215 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14216 const SCEV *Operand = visit(Expr->getOperand());
14217 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14218 if (AR && AR->getLoop() == L && AR->isAffine()) {
14219 // This couldn't be folded because the operand didn't have the nsw
14220 // flag. Add the nssw flag as an assumption that we could make.
14221 const SCEV *Step = AR->getStepRecurrence(SE);
14222 Type *Ty = Expr->getType();
14223 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14224 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14225 SE.getSignExtendExpr(Step, Ty), L,
14226 AR->getNoWrapFlags());
14227 }
14228 return SE.getSignExtendExpr(Operand, Expr->getType());
14229 }
14230
14231private:
14232 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14233 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
14234 const SCEVPredicate *Pred)
14235 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14236
14237 bool addOverflowAssumption(const SCEVPredicate *P) {
14238 if (!NewPreds) {
14239 // Check if we've already made this assumption.
14240 return Pred && Pred->implies(P);
14241 }
14242 NewPreds->insert(P);
14243 return true;
14244 }
14245
14246 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14247 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14248 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14249 return addOverflowAssumption(A);
14250 }
14251
14252 // If \p Expr represents a PHINode, we try to see if it can be represented
14253 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14254 // to add this predicate as a runtime overflow check, we return the AddRec.
14255 // If \p Expr does not meet these conditions (is not a PHI node, or we
14256 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14257 // return \p Expr.
14258 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14259 if (!isa<PHINode>(Expr->getValue()))
14260 return Expr;
14261 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14262 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14263 if (!PredicatedRewrite)
14264 return Expr;
14265 for (const auto *P : PredicatedRewrite->second){
14266 // Wrap predicates from outer loops are not supported.
14267 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14268 if (L != WP->getExpr()->getLoop())
14269 return Expr;
14270 }
14271 if (!addOverflowAssumption(P))
14272 return Expr;
14273 }
14274 return PredicatedRewrite->first;
14275 }
14276
14277 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
14278 const SCEVPredicate *Pred;
14279 const Loop *L;
14280};
14281
14282} // end anonymous namespace
14283
14284const SCEV *
14285ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14286 const SCEVPredicate &Preds) {
14287 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14288}
14289
14290const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
14291 const SCEV *S, const Loop *L,
14292 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14293 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14294 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14295 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14296
14297 if (!AddRec)
14298 return nullptr;
14299
14300 // Since the transformation was successful, we can now transfer the SCEV
14301 // predicates.
14302 for (const auto *P : TransformPreds)
14303 Preds.insert(P);
14304
14305 return AddRec;
14306}
14307
14308/// SCEV predicates
14309SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
14310 SCEVPredicateKind Kind)
14311 : FastID(ID), Kind(Kind) {}
14312
14313SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
14314 const ICmpInst::Predicate Pred,
14315 const SCEV *LHS, const SCEV *RHS)
14316 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14317 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14318 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14319}
14320
14321bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14322 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14323
14324 if (!Op)
14325 return false;
14326
14327 if (Pred != ICmpInst::ICMP_EQ)
14328 return false;
14329
14330 return Op->LHS == LHS && Op->RHS == RHS;
14331}
14332
14333bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14334
14335void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
14336 if (Pred == ICmpInst::ICMP_EQ)
14337 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14338 else
14339 OS.indent(Depth) << "Compare predicate: " << *LHS
14340 << " " << CmpInst::getPredicateName(Pred) << ") "
14341 << *RHS << "\n";
14342
14343}
14344
14345SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
14346 const SCEVAddRecExpr *AR,
14347 IncrementWrapFlags Flags)
14348 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14349
14350const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14351
14352bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14353 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14354
14355 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14356}
14357
14358bool SCEVWrapPredicate::isAlwaysTrue() const {
14359 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14360 IncrementWrapFlags IFlags = Flags;
14361
14362 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14363 IFlags = clearFlags(IFlags, IncrementNSSW);
14364
14365 return IFlags == IncrementAnyWrap;
14366}
14367
14368void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
14369 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14370 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
14371 OS << "<nusw>";
14372 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
14373 OS << "<nssw>";
14374 OS << "\n";
14375}
14376
14377SCEVWrapPredicate::IncrementWrapFlags
14378SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
14379 ScalarEvolution &SE) {
14380 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14381 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14382
14383 // We can safely transfer the NSW flag as NSSW.
14384 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14385 ImpliedFlags = IncrementNSSW;
14386
14387 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14388 // If the increment is positive, the SCEV NUW flag will also imply the
14389 // WrapPredicate NUSW flag.
14390 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14391 if (Step->getValue()->getValue().isNonNegative())
14392 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14393 }
14394
14395 return ImpliedFlags;
14396}
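// Illustrative sketch (editor's addition, not part of the original source):
// the flag transfer above restated as a plain bitmask. NSW on the recurrence
// always implies the NSSW wrap-predicate flag; NUW implies NUSW only when the
// constant step is known to be non-negative. The enumerator values here are
// illustrative, not the real ones:
namespace sketch {
enum WrapFlags : unsigned { AnyWrap = 0, NUSW = 1u << 0, NSSW = 1u << 1 };

unsigned impliedFlags(bool HasNSW, bool HasNUW, bool StepKnownNonNegative) {
  unsigned Implied = AnyWrap;
  if (HasNSW)
    Implied |= NSSW;               // NSW can always be carried over as NSSW
  if (HasNUW && StepKnownNonNegative)
    Implied |= NUSW;               // NUW plus a non-negative step gives NUSW
  return Implied;
}
} // namespace sketch
// (end of illustrative sketch)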
14397
14398/// Union predicates don't get cached so create a dummy set ID for it.
14399SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14400 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14401 for (const auto *P : Preds)
14402 add(P);
14403}
14404
14405bool SCEVUnionPredicate::isAlwaysTrue() const {
14406 return all_of(Preds,
14407 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14408}
14409
14410bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
14411 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14412 return all_of(Set->Preds,
14413 [this](const SCEVPredicate *I) { return this->implies(I); });
14414
14415 return any_of(Preds,
14416 [N](const SCEVPredicate *I) { return I->implies(N); });
14417}
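// Illustrative sketch (editor's addition, not part of the original source):
// the implication rule above, with leaf predicates reduced to integer "facts"
// whose only implication rule is equality. A union implies another union when
// it implies every member, and implies a single fact when any of its own
// members does:
#include <algorithm>
#include <vector>

namespace sketch {
using Fact = int;

bool unionImpliesFact(const std::vector<Fact> &U, Fact F) {
  return std::any_of(U.begin(), U.end(), [F](Fact G) { return G == F; });
}

bool unionImpliesUnion(const std::vector<Fact> &U, const std::vector<Fact> &V) {
  return std::all_of(V.begin(), V.end(),
                     [&U](Fact F) { return unionImpliesFact(U, F); });
}
} // namespace sketch
// (end of illustrative sketch)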
14418
14419void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
14420 for (const auto *Pred : Preds)
14421 Pred->print(OS, Depth);
14422}
14423
14424void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14425 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14426 for (const auto *Pred : Set->Preds)
14427 add(Pred);
14428 return;
14429 }
14430
14431 Preds.push_back(N);
14432}
14433
14434PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
14435 Loop &L)
14436 : SE(SE), L(L) {
14437 SmallVector<const SCEVPredicate*, 4> Empty;
14438 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14439}
14440
14441void ScalarEvolution::registerUser(const SCEV *User,
14442 ArrayRef<const SCEV *> Ops) {
14443 for (const auto *Op : Ops)
14444 // We do not expect that forgetting cached data for SCEVConstants will ever
14445 // open any prospects for sharpening or introduce any correctness issues,
14446 // so we don't bother storing their dependencies.
14447 if (!isa<SCEVConstant>(Op))
14448 SCEVUsers[Op].insert(User);
14449}
14450
14451const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
14452 const SCEV *Expr = SE.getSCEV(V);
14453 RewriteEntry &Entry = RewriteMap[Expr];
14454
14455 // If we already have an entry and the version matches, return it.
14456 if (Entry.second && Generation == Entry.first)
14457 return Entry.second;
14458
14459 // We found an entry but it's stale. Rewrite the stale entry
14460 // according to the current predicate.
14461 if (Entry.second)
14462 Expr = Entry.second;
14463
14464 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14465 Entry = {Generation, NewSCEV};
14466
14467 return NewSCEV;
14468}
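// Illustrative sketch (editor's addition, not part of the original source):
// the cache above stamps every rewritten expression with the generation of
// the predicate set that produced it. A stale entry is not discarded; it is
// used as the starting point for the next rewrite, so predicates added since
// then are applied on top. A standalone model with strings standing in for
// SCEV expressions and a caller-supplied rewrite step:
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

namespace sketch {
struct RewriteCache {
  unsigned Generation = 1; // bumped whenever a new predicate is added
  std::unordered_map<std::string, std::pair<unsigned, std::string>> Map;

  std::string
  get(const std::string &Expr,
      const std::function<std::string(const std::string &)> &Rewrite) {
    auto &Entry = Map[Expr];
    if (!Entry.second.empty() && Entry.first == Generation)
      return Entry.second;                    // entry is current, reuse it
    std::string Start = Entry.second.empty() ? Expr : Entry.second;
    std::string New = Rewrite(Start);         // apply the current predicates
    Entry = {Generation, New};
    return New;
  }
};
} // namespace sketch
// (end of illustrative sketch)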
14469
14470const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
14471 if (!BackedgeCount) {
14472 SmallVector<const SCEVPredicate *, 4> Preds;
14473 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14474 for (const auto *P : Preds)
14475 addPredicate(*P);
14476 }
14477 return BackedgeCount;
14478}
14479
14480void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
14481 if (Preds->implies(&Pred))
14482 return;
14483
14484 auto &OldPreds = Preds->getPredicates();
14485 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14486 NewPreds.push_back(&Pred);
14487 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14488 updateGeneration();
14489}
14490
14491const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
14492 return *Preds;
14493}
14494
14495void PredicatedScalarEvolution::updateGeneration() {
14496 // If the generation number wrapped, recompute everything.
14497 if (++Generation == 0) {
14498 for (auto &II : RewriteMap) {
14499 const SCEV *Rewritten = II.second.second;
14500 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14501 }
14502 }
14503}
14504
14505void PredicatedScalarEvolution::setNoOverflow(
14506 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14507 const SCEV *Expr = getSCEV(V);
14508 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14509
14510 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14511
14512 // Clear the statically implied flags.
14513 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14514 addPredicate(*SE.getWrapPredicate(AR, Flags));
14515
14516 auto II = FlagsMap.insert({V, Flags});
14517 if (!II.second)
14518 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14519}
14520
14521bool PredicatedScalarEvolution::hasNoOverflow(
14522 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
14523 const SCEV *Expr = getSCEV(V);
14524 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14525
14526 Flags = SCEVWrapPredicate::clearFlags(
14527 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14528
14529 auto II = FlagsMap.find(V);
14530
14531 if (II != FlagsMap.end())
14532 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14533
14534 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14535}
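The two functions above are intended to be used as a pair: setNoOverflow records (and, if needed, emits a wrap predicate for) an overflow guarantee, and hasNoOverflow later reports whether that guarantee is already covered. A minimal sketch, assuming IV evaluates to an add recurrence in the analysed loop (the wrapper name is invented for illustration and is not part of this file):

#include "llvm/Analysis/ScalarEvolution.h"

// Hedged usage sketch: request NSSW for IV, then confirm it is covered either
// statically or by the wrap predicate recorded in FlagsMap.
static bool ensureNoSignedWrap(llvm::PredicatedScalarEvolution &PSE,
                               llvm::Value *IV) {
  using WP = llvm::SCEVWrapPredicate;
  PSE.setNoOverflow(IV, WP::IncrementNSSW);
  return PSE.hasNoOverflow(IV, WP::IncrementNSSW); // expected to be true now
}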
14536
14537const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
14538 const SCEV *Expr = this->getSCEV(V);
14539 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14540 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14541
14542 if (!New)
14543 return nullptr;
14544
14545 for (const auto *P : NewPreds)
14546 addPredicate(*P);
14547
14548 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14549 return New;
14550}
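For orientation, a client of PredicatedScalarEvolution typically drives it along the lines below; a minimal sketch assuming SE, L and a SCEV-able Value *Ptr already exist (the helper name is invented for illustration, not taken from this file):

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"

// Try to view Ptr as an add recurrence, letting PSE collect whatever runtime
// predicates (e.g. wrap predicates) are needed to make that valid.
static const llvm::SCEVAddRecExpr *tryGetAddRec(llvm::ScalarEvolution &SE,
                                                llvm::Loop &L,
                                                llvm::Value *Ptr) {
  llvm::PredicatedScalarEvolution PSE(SE, L);
  const llvm::SCEVAddRecExpr *AR = PSE.getAsAddRec(Ptr);
  if (!AR)
    return nullptr;          // no add recurrence even with predicates
  (void)PSE.getPredicate();  // the predicates a caller must check at runtime
  return AR;
}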
14551
14552PredicatedScalarEvolution::PredicatedScalarEvolution(
14553 const PredicatedScalarEvolution &Init)
14554 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14555 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
14556 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
14557 for (auto I : Init.FlagsMap)
14558 FlagsMap.insert(I);
14559}
14560
14561void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
14562 // For each block.
14563 for (auto *BB : L.getBlocks())
14564 for (auto &I : *BB) {
14565 if (!SE.isSCEVable(I.getType()))
14566 continue;
14567
14568 auto *Expr = SE.getSCEV(&I);
14569 auto II = RewriteMap.find(Expr);
14570
14571 if (II == RewriteMap.end())
14572 continue;
14573
14574 // Don't print things that are not interesting.
14575 if (II->second.second == Expr)
14576 continue;
14577
14578 OS.indent(Depth) << "[PSE]" << I << ":\n";
14579 OS.indent(Depth + 2) << *Expr << "\n";
14580 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
14581 }
14582}
14583
14584// Match the mathematical pattern A - (A / B) * B, where A and B can be
14585// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
14586// for URem with constant power-of-2 second operands.
14587// Matching is not always straightforward, as A and B can be folded (imagine A
14588// is X / 2 and B is 4; A / B then becomes X / 8).
14589bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
14590 const SCEV *&RHS) {
14591 // Try to match 'zext (trunc A to iB) to iY', which is used
14592 // for URem with constant power-of-2 second operands. Make sure the size of
14593 // the operand A matches the size of the whole expression.
14594 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14595 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14596 LHS = Trunc->getOperand();
14597 // Bail out if the type of the LHS is larger than the type of the
14598 // expression for now.
14599 if (getTypeSizeInBits(LHS->getType()) >
14600 getTypeSizeInBits(Expr->getType()))
14601 return false;
14602 if (LHS->getType() != Expr->getType())
14603 LHS = getZeroExtendExpr(LHS, Expr->getType());
14604 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14605 << getTypeSizeInBits(Trunc->getType()));
14606 return true;
14607 }
14608 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14609 if (Add == nullptr || Add->getNumOperands() != 2)
14610 return false;
14611
14612 const SCEV *A = Add->getOperand(1);
14613 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14614
14615 if (Mul == nullptr)
14616 return false;
14617
14618 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14619 // (SomeExpr + (-(SomeExpr / B) * B)).
14620 if (Expr == getURemExpr(A, B)) {
14621 LHS = A;
14622 RHS = B;
14623 return true;
14624 }
14625 return false;
14626 };
14627
14628 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14629 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14630 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14631 MatchURemWithDivisor(Mul->getOperand(2));
14632
14633 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14634 if (Mul->getNumOperands() == 2)
14635 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14636 MatchURemWithDivisor(Mul->getOperand(0)) ||
14637 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14638 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14639 return false;
14640}
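The two shapes matched above reduce to simple unsigned arithmetic identities, which the following standalone snippet checks with concrete values (the numbers are illustrative only, not from the source):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t A = 29, B = 8;
  // The add/mul/udiv form: A + (-(A /u B) * B) equals A urem B.
  assert(A - (A / B) * B == A % B);
  // The zext(trunc A to i8) form encodes A urem 2^8.
  assert(static_cast<uint32_t>(static_cast<uint8_t>(A)) == A % 256u);
  return 0;
}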
14641
14642const SCEV *
14643ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14644 SmallVector<BasicBlock*, 16> ExitingBlocks;
14645 L->getExitingBlocks(ExitingBlocks);
14646
14647 // Form an expression for the maximum exit count possible for this loop. We
14648 // merge the max and exact information to approximate a version of
14649 // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
14650 SmallVector<const SCEV*, 4> ExitCounts;
14651 for (BasicBlock *ExitingBB : ExitingBlocks) {
14652 const SCEV *ExitCount = getExitCount(L, ExitingBB);
14653 if (isa<SCEVCouldNotCompute>(ExitCount))
14654 ExitCount = getExitCount(L, ExitingBB,
14655 ScalarEvolution::ConstantMaximum);
14656 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
14657 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
14658        "We should only have known counts for exiting blocks that "
14659        "dominate latch!");
14660 ExitCounts.push_back(ExitCount);
14661 }
14662 }
14663 if (ExitCounts.empty())
14664 return getCouldNotCompute();
14665 return getUMinFromMismatchedTypes(ExitCounts);
14666}
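Worked example (hypothetical loop, not from this file): for a loop with two exiting blocks that dominate the latch, one with exact exit count %n and one whose only known bound is the constant 100, the loop above collects both counts and returns getUMinFromMismatchedTypes over them, i.e. a symbolic maximum backedge-taken count of (%n umin 100); if no exiting block yields a computable count, the function returns SCEVCouldNotCompute.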
14667
14668/// A rewriter to replace SCEV expressions in Map with the corresponding entry
14669/// in the map. It skips AddRecExpr because we cannot guarantee that the
14670/// replacement is loop invariant in the loop of the AddRec.
14671///
14672/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
14673/// supported.
14674class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
14675 const DenseMap<const SCEV *, const SCEV *> &Map;
14676
14677public:
14678 SCEVLoopGuardRewriter(ScalarEvolution &SE,
14679 DenseMap<const SCEV *, const SCEV *> &M)
14680 : SCEVRewriteVisitor(SE), Map(M) {}
14681
14682 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
14683
14684 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14685 auto I = Map.find(Expr);
14686 if (I == Map.end())
14687 return Expr;
14688 return I->second;
14689 }
14690
14691 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14692 auto I = Map.find(Expr);
14693 if (I == Map.end())
14694 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
14695 Expr);
14696 return I->second;
14697 }
14698};
14699
14700const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
14701 SmallVector<const SCEV *> ExprsToRewrite;
14702 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
14703 const SCEV *RHS,
14704 DenseMap<const SCEV *, const SCEV *>
14705 &RewriteMap) {
14706 // WARNING: It is generally unsound to apply any wrap flags to the proposed
14707 // replacement SCEV which isn't directly implied by the structure of that
14708 // SCEV. In particular, using contextual facts to imply flags is *NOT*
14709 // legal. See the scoping rules for flags in the header to understand why.
14710
14711 // If LHS is a constant, apply information to the other expression.
14712 if (isa<SCEVConstant>(LHS)) {
14713 std::swap(LHS, RHS);
14714 Predicate = CmpInst::getSwappedPredicate(Predicate);
14715 }
14716
14717 // Check for a condition of the form (-C1 + X < C2). InstCombine will
14718 // create this form when combining two checks of the form (X u< C2 + C1) and
14719 // (X >=u C1).
14720 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
14721 &ExprsToRewrite]() {
14722 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
14723 if (!AddExpr || AddExpr->getNumOperands() != 2)
14724 return false;
14725
14726 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
14727 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
14728 auto *C2 = dyn_cast<SCEVConstant>(RHS);
14729 if (!C1 || !C2 || !LHSUnknown)
14730 return false;
14731
14732 auto ExactRegion =
14733 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
14734 .sub(C1->getAPInt());
14735
14736 // Bail out, unless we have a non-wrapping, monotonic range.
14737 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
14738 return false;
14739 auto I = RewriteMap.find(LHSUnknown);
14740 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
14741 RewriteMap[LHSUnknown] = getUMaxExpr(
14742 getConstant(ExactRegion.getUnsignedMin()),
14743 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
14744 ExprsToRewrite.push_back(LHSUnknown);
14745 return true;
14746 };
14747 if (MatchRangeCheckIdiom())
14748 return;
14749
14750 // If we have LHS == 0, check whether LHS is computing a property of some
14751 // unknown SCEV %v that we can make explicit by rewriting %v.
14752 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
14753 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
14754 RHSC->getValue()->isNullValue()) {
14755 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
14756 // explicitly express that.
14757 const SCEV *URemLHS = nullptr;
14758 const SCEV *URemRHS = nullptr;
14759 if (matchURem(LHS, URemLHS, URemRHS)) {
14760 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
14761 auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
14762 RewriteMap[LHSUnknown] = Multiple;
14763 ExprsToRewrite.push_back(LHSUnknown);
14764 return;
14765 }
14766 }
14767 }
14768
14769 // Do not apply information for constants or if RHS contains an AddRec.
14770 if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
14771 return;
14772
14773 // If RHS is SCEVUnknown, make sure the information is applied to it.
14774 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
14775 std::swap(LHS, RHS);
14776 Predicate = CmpInst::getSwappedPredicate(Predicate);
14777 }
14778
14779 // Limit to expressions that can be rewritten.
14780 if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
14781 return;
14782
14783 // Check whether LHS has already been rewritten. In that case we want to
14784 // chain further rewrites onto the already rewritten value.
14785 auto I = RewriteMap.find(LHS);
14786 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
14787
14788 const SCEV *RewrittenRHS = nullptr;
14789 switch (Predicate) {
14790 case CmpInst::ICMP_ULT:
14791 RewrittenRHS =
14792 getUMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
14793 break;
14794 case CmpInst::ICMP_SLT:
14795 RewrittenRHS =
14796 getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
14797 break;
14798 case CmpInst::ICMP_ULE:
14799 RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
14800 break;
14801 case CmpInst::ICMP_SLE:
14802 RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
14803 break;
14804 case CmpInst::ICMP_UGT:
14805 RewrittenRHS =
14806 getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
14807 break;
14808 case CmpInst::ICMP_SGT:
14809 RewrittenRHS =
14810 getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
14811 break;
14812 case CmpInst::ICMP_UGE:
14813 RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
14814 break;
14815 case CmpInst::ICMP_SGE:
14816 RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
14817 break;
14818 case CmpInst::ICMP_EQ:
14819 if (isa<SCEVConstant>(RHS))
14820 RewrittenRHS = RHS;
14821 break;
14822 case CmpInst::ICMP_NE:
14823 if (isa<SCEVConstant>(RHS) &&
14824 cast<SCEVConstant>(RHS)->getValue()->isNullValue())
14825 RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
14826 break;
14827 default:
14828 break;
14829 }
14830
14831 if (RewrittenRHS) {
14832 RewriteMap[LHS] = RewrittenRHS;
14833 if (LHS == RewrittenLHS)
14834 ExprsToRewrite.push_back(LHS);
14835 }
14836 };
14837
14838 SmallVector<std::pair<Value *, bool>> Terms;
14839 // First, collect information from assumptions dominating the loop.
14840 for (auto &AssumeVH : AC.assumptions()) {
14841 if (!AssumeVH)
14842 continue;
14843 auto *AssumeI = cast<CallInst>(AssumeVH);
14844 if (!DT.dominates(AssumeI, L->getHeader()))
14845 continue;
14846 Terms.emplace_back(AssumeI->getOperand(0), true);
14847 }
14848
14849 // Second, collect conditions from dominating branches. Starting at the loop
14850 // predecessor, climb up the predecessor chain as long as we can find
14851 // predecessors that have a unique successor leading to the original
14852 // header.
14853 // TODO: share this logic with isLoopEntryGuardedByCond.
14854 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
14855 L->getLoopPredecessor(), L->getHeader());
14856 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
14857
14858 const BranchInst *LoopEntryPredicate =
14859 dyn_cast<BranchInst>(Pair.first->getTerminator());
14860 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
14861 continue;
14862
14863 Terms.emplace_back(LoopEntryPredicate->getCondition(),
14864 LoopEntryPredicate->getSuccessor(0) == Pair.second);
14865 }
14866
14867 // Now apply the information from the collected conditions to RewriteMap.
14868 // Conditions are processed in reverse order, so the earliest condition is
14869 // processed first. This ensures the SCEVs with the shortest dependency chains
14870 // are constructed first.
14871 DenseMap<const SCEV *, const SCEV *> RewriteMap;
14872 for (auto &E : reverse(Terms)) {
14873 bool EnterIfTrue = E.second;
14874 SmallVector<Value *, 8> Worklist;
14875 SmallPtrSet<Value *, 8> Visited;
14876 Worklist.push_back(E.first);
14877 while (!Worklist.empty()) {
14878 Value *Cond = Worklist.pop_back_val();
14879 if (!Visited.insert(Cond).second)
14880 continue;
14881
14882 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14883 auto Predicate =
14884 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
14885 const auto *LHS = getSCEV(Cmp->getOperand(0));
14886 const auto *RHS = getSCEV(Cmp->getOperand(1));
14887 CollectCondition(Predicate, LHS, RHS, RewriteMap);
14888 continue;
14889 }
14890
14891 Value *L, *R;
14892 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
14893 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
14894 Worklist.push_back(L);
14895 Worklist.push_back(R);
14896 }
14897 }
14898 }
14899
14900 if (RewriteMap.empty())
14901 return Expr;
14902
14903 // Now that all rewrite information is collected, rewrite the collected
14904 // expressions with the information in the map. This applies information to
14905 // sub-expressions.
14906 if (ExprsToRewrite.size() > 1) {
14907 for (const SCEV *Expr : ExprsToRewrite) {
14908 const SCEV *RewriteTo = RewriteMap[Expr];
14909 RewriteMap.erase(Expr);
14910 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
14911 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
14912 }
14913 }
14914
14915 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
14916 return Rewriter.visit(Expr);
14917}
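Hypothetical illustration of the rewriting above (the guard and values are invented for the example): if the loop is entered under a dominating guard %n u< 16, CollectCondition takes the ICMP_ULT case and records RewriteMap[%n] = (%n umin (16 - 1)) = (%n umin 15); a subsequent applyLoopGuards call on an expression such as (1 + %n) then returns (1 + (%n umin 15)), giving later range queries the tighter bound implied by the guard.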

/build/source/llvm/include/llvm/ADT/Optional.h

1//===- Optional.h - Simple variant for passing optional values --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file provides Optional, a template class modeled in the spirit of
11/// OCaml's 'opt' variant. The idea is to strongly type whether or not
12/// a value can be optional.
13///
14//===----------------------------------------------------------------------===//
15
16#ifndef LLVM_ADT_OPTIONAL_H
17#define LLVM_ADT_OPTIONAL_H
18
19#include "llvm/ADT/Hashing.h"
20#include "llvm/ADT/None.h"
21#include "llvm/ADT/STLForwardCompat.h"
22#include "llvm/Support/Compiler.h"
23#include "llvm/Support/type_traits.h"
24#include <cassert>
25#include <new>
26#include <utility>
27
28namespace llvm {
29
30class raw_ostream;
31
32namespace optional_detail {
33
34/// Storage for any type.
35//
36// The specialization condition intentionally uses
37// llvm::is_trivially_{copy/move}_constructible instead of
38// std::is_trivially_{copy/move}_constructible. GCC versions prior to 7.4 may
39// instantiate the copy/move constructor of `T` when
40// std::is_trivially_{copy/move}_constructible is instantiated. This causes
41// compilation to fail if we query the trivially copy/move constructible
42// property of a class which is not copy/move constructible.
43//
44// The current implementation of OptionalStorage insists that in order to use
45// the trivial specialization, the value_type must be trivially copy
46// constructible and trivially copy assignable due to =default implementations
47// of the copy/move constructor/assignment. It does not follow that this is
48// necessarily the case when std::is_trivially_copyable is true (hence the expanded
49// specialization condition).
50//
51// The move constructible / assignable conditions emulate the remaining behavior
52// of std::is_trivially_copyable.
53template <typename T,
54 bool = (llvm::is_trivially_copy_constructible<T>::value &&
55 std::is_trivially_copy_assignable<T>::value &&
56 (llvm::is_trivially_move_constructible<T>::value ||
57 !std::is_move_constructible<T>::value) &&
58 (std::is_trivially_move_assignable<T>::value ||
59 !std::is_move_assignable<T>::value))>
60class OptionalStorage {
61 union {
62 char empty;
63 T val;
64 };
65 bool hasVal = false;
66
67public:
68 ~OptionalStorage() { reset(); }
69
70 constexpr OptionalStorage() noexcept : empty() {}
71
72 constexpr OptionalStorage(OptionalStorage const &other) : OptionalStorage() {
73 if (other.has_value()) {
74 emplace(other.val);
75 }
76 }
77 constexpr OptionalStorage(OptionalStorage &&other) : OptionalStorage() {
78 if (other.has_value()) {
79 emplace(std::move(other.val));
80 }
81 }
82
83 template <class... Args>
84 constexpr explicit OptionalStorage(std::in_place_t, Args &&...args)
85 : val(std::forward<Args>(args)...), hasVal(true) {}
86
87 void reset() noexcept {
88 if (hasVal) {
89 val.~T();
90 hasVal = false;
91 }
92 }
93
94 constexpr bool has_value() const noexcept { return hasVal; }
95 LLVM_DEPRECATED("Use has_value instead.", "has_value")
96 constexpr bool hasValue() const noexcept {
97 return hasVal;
98 }
99
100 T &value() &noexcept {
101 assert(hasVal);
102 return val;
103 }
104 LLVM_DEPRECATED("Use value instead.", "value") T &getValue() &noexcept {
105 assert(hasVal);
106 return val;
107 }
108 constexpr T const &value() const &noexcept {
109 assert(hasVal);
110 return val;
111 }
112 LLVM_DEPRECATED("Use value instead.", "value")
113 constexpr T const &getValue() const &noexcept {
114 assert(hasVal);
115 return val;
116 }
117 T &&value() &&noexcept {
118 assert(hasVal);
119 return std::move(val);
120 }
121 LLVM_DEPRECATED("Use value instead.", "value") T &&getValue() &&noexcept {
122 assert(hasVal);
123 return std::move(val);
124 }
125
126 template <class... Args> void emplace(Args &&...args) {
127 reset();
128 ::new ((void *)std::addressof(val)) T(std::forward<Args>(args)...);
129 hasVal = true;
130 }
131
132 OptionalStorage &operator=(T const &y) {
133 if (has_value()) {
134 val = y;
135 } else {
136 ::new ((void *)std::addressof(val)) T(y);
137 hasVal = true;
138 }
139 return *this;
140 }
141 OptionalStorage &operator=(T &&y) {
142 if (has_value()) {
143 val = std::move(y);
144 } else {
145 ::new ((void *)std::addressof(val)) T(std::move(y));
146 hasVal = true;
147 }
148 return *this;
149 }
150
151 OptionalStorage &operator=(OptionalStorage const &other) {
152 if (other.has_value()) {
153 if (has_value()) {
154 val = other.val;
155 } else {
156 ::new ((void *)std::addressof(val)) T(other.val);
157 hasVal = true;
158 }
159 } else {
160 reset();
161 }
162 return *this;
163 }
164
165 OptionalStorage &operator=(OptionalStorage &&other) {
166 if (other.has_value()) {
167 if (has_value()) {
168 val = std::move(other.val);
169 } else {
170 ::new ((void *)std::addressof(val)) T(std::move(other.val));
171 hasVal = true;
172 }
173 } else {
174 reset();
175 }
176 return *this;
177 }
178};
179
180template <typename T> class OptionalStorage<T, true> {
181 union {
182 char empty;
183 T val;
184 };
185 bool hasVal = false;
186
187public:
188 ~OptionalStorage() = default;
189
190 constexpr OptionalStorage() noexcept : empty{} {}
191
192 constexpr OptionalStorage(OptionalStorage const &other) = default;
193 constexpr OptionalStorage(OptionalStorage &&other) = default;
194
195 OptionalStorage &operator=(OptionalStorage const &other) = default;
196 OptionalStorage &operator=(OptionalStorage &&other) = default;
197
198 template <class... Args>
199 constexpr explicit OptionalStorage(std::in_place_t, Args &&...args)
200 : val(std::forward<Args>(args)...), hasVal(true) {}
22  Null pointer value stored to 'BO.Storage..val.LHS'
201
202 void reset() noexcept {
203 if (hasVal) {
204 val.~T();
205 hasVal = false;
206 }
207 }
208
209 constexpr bool has_value() const noexcept { return hasVal; }
210 LLVM_DEPRECATED("Use has_value instead.", "has_value")
211 constexpr bool hasValue() const noexcept {
212 return hasVal;
213 }
214
215 T &value() &noexcept {
216 assert(hasVal);
217 return val;
218 }
219 LLVM_DEPRECATED("Use value instead.", "value") T &getValue() &noexcept {
220 assert(hasVal);
221 return val;
222 }
223 constexpr T const &value() const &noexcept {
224 assert(hasVal);
225 return val;
226 }
227 LLVM_DEPRECATED("Use value instead.", "value")
228 constexpr T const &getValue() const &noexcept {
229 assert(hasVal);
230 return val;
231 }
232 T &&value() &&noexcept {
233 assert(hasVal);
234 return std::move(val);
235 }
236 LLVM_DEPRECATED("Use value instead.", "value") T &&getValue() &&noexcept {
237 assert(hasVal);
238 return std::move(val);
239 }
240
241 template <class... Args> void emplace(Args &&...args) {
242 reset();
243 ::new ((void *)std::addressof(val)) T(std::forward<Args>(args)...);
244 hasVal = true;
245 }
246
247 OptionalStorage &operator=(T const &y) {
248 if (has_value()) {
249 val = y;
250 } else {
251 ::new ((void *)std::addressof(val)) T(y);
252 hasVal = true;
253 }
254 return *this;
255 }
256 OptionalStorage &operator=(T &&y) {
257 if (has_value()) {
258 val = std::move(y);
259 } else {
260 ::new ((void *)std::addressof(val)) T(std::move(y));
261 hasVal = true;
262 }
263 return *this;
264 }
265};
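A small, hedged illustration of the specialization dispatch described in the comment at the top of this namespace (the concrete types are chosen purely for demonstration and are not part of this header): a trivially copyable type such as int picks the <T, true> specialization above, while a type with non-trivial copy/move operations such as std::string stays on the primary template.

#include "llvm/ADT/Optional.h"
#include <string>
#include <type_traits>

using llvm::optional_detail::OptionalStorage;

static_assert(std::is_same<OptionalStorage<int>,
                           OptionalStorage<int, true>>::value,
              "int uses the trivial OptionalStorage specialization");
static_assert(std::is_same<OptionalStorage<std::string>,
                           OptionalStorage<std::string, false>>::value,
              "std::string uses the non-trivial primary template");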
266
267} // namespace optional_detail
268
269template <typename T> class Optional {
270 optional_detail::OptionalStorage<T> Storage;
271
272public:
273 using value_type = T;
274
275 constexpr Optional() = default;
276 constexpr Optional(NoneType) {}
277
278 constexpr Optional(const T &y) : Storage(std::in_place, y) {}
279 constexpr Optional(const Optional &O) = default;
280
281 constexpr Optional(T &&y) : Storage(std::in_place, std::move(y)) {}
21  Calling constructor for 'OptionalStorage<(anonymous namespace)::BinaryOp, true>'
23  Returning from constructor for 'OptionalStorage<(anonymous namespace)::BinaryOp, true>'
282 constexpr Optional(Optional &&O) = default;
283
284 template <typename... ArgTypes>
285 constexpr Optional(std::in_place_t, ArgTypes &&...Args)
286 : Storage(std::in_place, std::forward<ArgTypes>(Args)...) {}
287
288 Optional &operator=(T &&y) {
289 Storage = std::move(y);
290 return *this;
291 }
292 Optional &operator=(Optional &&O) = default;
293
294 /// Create a new object by constructing it in place with the given arguments.
295 template <typename... ArgTypes> void emplace(ArgTypes &&... Args) {
296 Storage.emplace(std::forward<ArgTypes>(Args)...);
297 }
298
299 static constexpr Optional create(const T *y) {
300 return y ? Optional(*y) : Optional();
301 }
302
303 Optional &operator=(const T &y) {
304 Storage = y;
305 return *this;
306 }
307 Optional &operator=(const Optional &O) = default;
308
309 void reset() { Storage.reset(); }
310
311 constexpr const T *getPointer() const { return &Storage.value(); }
312 T *getPointer() { return &Storage.value(); }
313 constexpr const T &value() const & { return Storage.value(); }
314 LLVM_DEPRECATED("Use value instead.", "value")
315 constexpr const T &getValue() const & {
316 return Storage.value();
317 }
318 T &value() & { return Storage.value(); }
319 LLVM_DEPRECATED("Use value instead.", "value") T &getValue() & {
320 return Storage.value();
321 }
322
323 constexpr explicit operator bool() const { return has_value(); }
324 constexpr bool has_value() const { return Storage.has_value(); }
325 LLVM_DEPRECATED("Use has_value instead.", "has_value")
326 constexpr bool hasValue() const {
327 return Storage.has_value();
328 }
329 constexpr const T *operator->() const { return getPointer(); }
330 T *operator->() { return getPointer(); }
331 constexpr const T &operator*() const & { return value(); }
332 T &operator*() & { return value(); }
333
334 template <typename U> constexpr T value_or(U &&alt) const & {
335 return has_value() ? value() : std::forward<U>(alt);
336 }
337 template <typename U>
338 LLVM_DEPRECATED("Use value_or instead.", "value_or")
339 constexpr T getValueOr(U &&alt) const & {
340 return has_value() ? value() : std::forward<U>(alt);
341 }
342
343 /// Apply a function to the value if present; otherwise return None.
344 template <class Function>
345 auto transform(const Function &F) const & -> Optional<decltype(F(value()))> {
346 if (*this)
347 return F(value());
348 return None;
349 }
350 template <class Function>
351 LLVM_DEPRECATED("Use transform instead.", "transform")
352 auto map(const Function &F) const & -> Optional<decltype(F(value()))> {
353 if (*this)
354 return F(value());
355 return None;
356 }
357
358 T &&value() && { return std::move(Storage.value()); }
359 LLVM_DEPRECATED("Use value instead.", "value") T &&getValue() && {
360 return std::move(Storage.value());
361 }
362 T &&operator*() && { return std::move(Storage.value()); }
363
364 template <typename U> T value_or(U &&alt) && {
365 return has_value() ? std::move(value()) : std::forward<U>(alt);
366 }
367 template <typename U>
368 LLVM_DEPRECATED("Use value_or instead.", "value_or")
369 T getValueOr(U &&alt) && {
370 return has_value() ? std::move(value()) : std::forward<U>(alt);
371 }
372
373 /// Apply a function to the value if present; otherwise return None.
374 template <class Function>
375 auto transform(
376 const Function &F) && -> Optional<decltype(F(std::move(*this).value()))> {
377 if (*this)
378 return F(std::move(*this).value());
379 return None;
380 }
381 template <class Function>
382 LLVM_DEPRECATED("Use transform instead.", "transform")
383 auto map(const Function &F)
384 && -> Optional<decltype(F(std::move(*this).value()))> {
385 if (*this)
386 return F(std::move(*this).value());
387 return None;
388 }
389};
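For orientation, typical use of Optional looks roughly like the following; a minimal sketch with an invented helper, not part of this header:

#include "llvm/ADT/Optional.h"
#include <cstdint>

// Returns the first power of two >= N, or None if it does not fit in 32 bits.
static llvm::Optional<uint32_t> nextPow2(uint32_t N) {
  for (uint32_t P = 1; P != 0; P <<= 1)
    if (P >= N)
      return P;
  return llvm::None;
}

static void demo() {
  llvm::Optional<uint32_t> R = nextPow2(100);
  if (R)                             // operator bool / has_value()
    (void)*R;                        // operator* yields 128 here
  uint32_t Fallback = R.value_or(0); // value_or supplies a default
  (void)Fallback;
}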
390
391template<typename T>
392Optional(const T&) -> Optional<T>;
393
394template <class T> llvm::hash_code hash_value(const Optional<T> &O) {
395 return O ? hash_combine(true, *O) : hash_value(false);
396}
397
398template <typename T, typename U>
399constexpr bool operator==(const Optional<T> &X, const Optional<U> &Y) {
400 if (X && Y)
401 return *X == *Y;
402 return X.has_value() == Y.has_value();
403}
404
405template <typename T, typename U>
406constexpr bool operator!=(const Optional<T> &X, const Optional<U> &Y) {
407 return !(X == Y);
408}
409
410template <typename T, typename U>
411constexpr bool operator<(const Optional<T> &X, const Optional<U> &Y) {
412 if (X && Y)
413 return *X < *Y;
414 return X.has_value() < Y.has_value();
415}
416
417template <typename T, typename U>
418constexpr bool operator<=(const Optional<T> &X, const Optional<U> &Y) {
419 return !(Y < X);
420}
421
422template <typename T, typename U>
423constexpr bool operator>(const Optional<T> &X, const Optional<U> &Y) {
424 return Y < X;
425}
426
427template <typename T, typename U>
428constexpr bool operator>=(const Optional<T> &X, const Optional<U> &Y) {
429 return !(X < Y);
430}
431
432template <typename T>
433constexpr bool operator==(const Optional<T> &X, NoneType) {
434 return !X;
435}
436
437template <typename T>
438constexpr bool operator==(NoneType, const Optional<T> &X) {
439 return X == None;
440}
441
442template <typename T>
443constexpr bool operator!=(const Optional<T> &X, NoneType) {
444 return !(X == None);
445}
446
447template <typename T>
448constexpr bool operator!=(NoneType, const Optional<T> &X) {
449 return X != None;
450}
451
452template <typename T> constexpr bool operator<(const Optional<T> &, NoneType) {
453 return false;
454}
455
456template <typename T> constexpr bool operator<(NoneType, const Optional<T> &X) {
457 return X.has_value();
458}
459
460template <typename T>
461constexpr bool operator<=(const Optional<T> &X, NoneType) {
462 return !(None < X);
463}
464
465template <typename T>
466constexpr bool operator<=(NoneType, const Optional<T> &X) {
467 return !(X < None);
468}
469
470template <typename T> constexpr bool operator>(const Optional<T> &X, NoneType) {
471 return None < X;
472}
473
474template <typename T> constexpr bool operator>(NoneType, const Optional<T> &X) {
475 return X < None;
476}
477
478template <typename T>
479constexpr bool operator>=(const Optional<T> &X, NoneType) {
480 return None <= X;
481}
482
483template <typename T>
484constexpr bool operator>=(NoneType, const Optional<T> &X) {
485 return X <= None;
486}
487
488template <typename T>
489constexpr bool operator==(const Optional<T> &X, const T &Y) {
490 return X && *X == Y;
491}
492
493template <typename T>
494constexpr bool operator==(const T &X, const Optional<T> &Y) {
495 return Y && X == *Y;
496}
497
498template <typename T>
499constexpr bool operator!=(const Optional<T> &X, const T &Y) {
500 return !(X == Y);
501}
502
503template <typename T>
504constexpr bool operator!=(const T &X, const Optional<T> &Y) {
505 return !(X == Y);
506}
507
508template <typename T>
509constexpr bool operator<(const Optional<T> &X, const T &Y) {
510 return !X || *X < Y;
511}
512
513template <typename T>
514constexpr bool operator<(const T &X, const Optional<T> &Y) {
515 return Y && X < *Y;
516}
517
518template <typename T>
519constexpr bool operator<=(const Optional<T> &X, const T &Y) {
520 return !(Y < X);
521}
522
523template <typename T>
524constexpr bool operator<=(const T &X, const Optional<T> &Y) {
525 return !(Y < X);
526}
527
528template <typename T>
529constexpr bool operator>(const Optional<T> &X, const T &Y) {
530 return Y < X;
531}
532
533template <typename T>
534constexpr bool operator>(const T &X, const Optional<T> &Y) {
535 return Y < X;
536}
537
538template <typename T>
539constexpr bool operator>=(const Optional<T> &X, const T &Y) {
540 return !(X < Y);
541}
542
543template <typename T>
544constexpr bool operator>=(const T &X, const Optional<T> &Y) {
545 return !(X < Y);
546}
547
548raw_ostream &operator<<(raw_ostream &OS, NoneType);
549
550template <typename T, typename = decltype(std::declval<raw_ostream &>()
551 << std::declval<const T &>())>
552raw_ostream &operator<<(raw_ostream &OS, const Optional<T> &O) {
553 if (O)
554 OS << *O;
555 else
556 OS << None;
557 return OS;
558}
559
560} // end namespace llvm
561
562#endif // LLVM_ADT_OPTIONAL_H