LLVM 17.0.0git
ScalarEvolutionDivision.cpp
Go to the documentation of this file.
//===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the class that knows how to divide SCEVs.
10//
11//===----------------------------------------------------------------------===//
12
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
21
// Only pointers to llvm::Type are used in this file, so a forward
// declaration is sufficient here.
namespace llvm {
class Type;
} // end namespace llvm

using namespace llvm;
27
28namespace {
29
30static inline int sizeOfSCEV(const SCEV *S) {
31 struct FindSCEVSize {
32 int Size = 0;
33
34 FindSCEVSize() = default;
35
36 bool follow(const SCEV *S) {
37 ++Size;
38 // Keep looking at all operands of S.
39 return true;
40 }
41
42 bool isDone() const { return false; }
43 };
44
45 FindSCEVSize F;
47 ST.visitAll(S);
48 return F.Size;
49}
50
51} // namespace
52
53// Computes the Quotient and Remainder of the division of Numerator by
54// Denominator.
55void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56 const SCEV *Denominator, const SCEV **Quotient,
57 const SCEV **Remainder) {
58 assert(Numerator && Denominator && "Uninitialized SCEV");
59
60 SCEVDivision D(SE, Numerator, Denominator);
61
62 // Check for the trivial case here to avoid having to check for it in the
63 // rest of the code.
64 if (Numerator == Denominator) {
65 *Quotient = D.One;
66 *Remainder = D.Zero;
67 return;
68 }
69
70 if (Numerator->isZero()) {
71 *Quotient = D.Zero;
72 *Remainder = D.Zero;
73 return;
74 }
75
76 // A simple case when N/1. The quotient is N.
77 if (Denominator->isOne()) {
78 *Quotient = Numerator;
79 *Remainder = D.Zero;
80 return;
81 }
82
83 // Split the Denominator when it is a product.
84 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
85 const SCEV *Q, *R;
86 *Quotient = Numerator;
87 for (const SCEV *Op : T->operands()) {
88 divide(SE, *Quotient, Op, &Q, &R);
89 *Quotient = Q;
90
91 // Bail out when the Numerator is not divisible by one of the terms of
92 // the Denominator.
93 if (!R->isZero()) {
94 *Quotient = D.Zero;
95 *Remainder = Numerator;
96 return;
97 }
98 }
99 *Remainder = D.Zero;
100 return;
101 }
102
103 D.visit(Numerator);
104 *Quotient = D.Quotient;
105 *Remainder = D.Remainder;
106}
107
109 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
110 APInt NumeratorVal = Numerator->getAPInt();
111 APInt DenominatorVal = D->getAPInt();
112 uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113 uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114
115 if (NumeratorBW > DenominatorBW)
116 DenominatorVal = DenominatorVal.sext(NumeratorBW);
117 else if (NumeratorBW < DenominatorBW)
118 NumeratorVal = NumeratorVal.sext(DenominatorBW);
119
120 APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121 APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
123 Quotient = SE.getConstant(QuotientVal);
124 Remainder = SE.getConstant(RemainderVal);
125 return;
126 }
127}
128
130 const SCEV *StartQ, *StartR, *StepQ, *StepR;
131 if (!Numerator->isAffine())
132 return cannotDivide(Numerator);
133 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
134 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
135 // Bail out if the types do not match.
136 Type *Ty = Denominator->getType();
137 if (Ty != StartQ->getType() || Ty != StartR->getType() ||
138 Ty != StepQ->getType() || Ty != StepR->getType())
139 return cannotDivide(Numerator);
140 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
141 Numerator->getNoWrapFlags());
142 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
143 Numerator->getNoWrapFlags());
144}
145
148 Type *Ty = Denominator->getType();
149
150 for (const SCEV *Op : Numerator->operands()) {
151 const SCEV *Q, *R;
152 divide(SE, Op, Denominator, &Q, &R);
153
154 // Bail out if types do not match.
155 if (Ty != Q->getType() || Ty != R->getType())
156 return cannotDivide(Numerator);
157
158 Qs.push_back(Q);
159 Rs.push_back(R);
160 }
161
162 if (Qs.size() == 1) {
163 Quotient = Qs[0];
164 Remainder = Rs[0];
165 return;
166 }
167
168 Quotient = SE.getAddExpr(Qs);
169 Remainder = SE.getAddExpr(Rs);
170}
171
174 Type *Ty = Denominator->getType();
175
176 bool FoundDenominatorTerm = false;
177 for (const SCEV *Op : Numerator->operands()) {
178 // Bail out if types do not match.
179 if (Ty != Op->getType())
180 return cannotDivide(Numerator);
181
182 if (FoundDenominatorTerm) {
183 Qs.push_back(Op);
184 continue;
185 }
186
187 // Check whether Denominator divides one of the product operands.
188 const SCEV *Q, *R;
189 divide(SE, Op, Denominator, &Q, &R);
190 if (!R->isZero()) {
191 Qs.push_back(Op);
192 continue;
193 }
194
195 // Bail out if types do not match.
196 if (Ty != Q->getType())
197 return cannotDivide(Numerator);
198
199 FoundDenominatorTerm = true;
200 Qs.push_back(Q);
201 }
202
203 if (FoundDenominatorTerm) {
204 Remainder = Zero;
205 if (Qs.size() == 1)
206 Quotient = Qs[0];
207 else
208 Quotient = SE.getMulExpr(Qs);
209 return;
210 }
211
212 if (!isa<SCEVUnknown>(Denominator))
213 return cannotDivide(Numerator);
214
215 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
216 ValueToSCEVMapTy RewriteMap;
217 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
218 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
219
220 if (Remainder->isZero()) {
221 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
222 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
223 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
224 return;
225 }
226
227 // Quotient is (Numerator - Remainder) divided by Denominator.
228 const SCEV *Q, *R;
229 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
230 // This SCEV does not seem to simplify: fail the division here.
231 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
232 return cannotDivide(Numerator);
233 divide(SE, Diff, Denominator, &Q, &R);
234 if (R != Zero)
235 return cannotDivide(Numerator);
236 Quotient = Q;
237}
238
239SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
240 const SCEV *Denominator)
241 : SE(S), Denominator(Denominator) {
242 Zero = SE.getZero(Denominator->getType());
243 One = SE.getOne(Denominator->getType());
244
245 // We generally do not know how to divide Expr by Denominator. We initialize
246 // the division to a "cannot divide" state to simplify the rest of the code.
247 cannotDivide(Numerator);
248}
249
250// Convenience function for giving up on the division. We set the quotient to
251// be equal to zero and the remainder to be equal to the numerator.
252void SCEVDivision::cannotDivide(const SCEV *Numerator) {
253 Quotient = Zero;
254 Remainder = Numerator;
255}
This file implements a class to represent arbitrary precision integral constant values and operations...
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
This file defines the DenseMap class.
uint64_t Size
#define F(x, y, z)
Definition: MD5.cpp:55
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
Class for arbitrary precision integers.
Definition: APInt.h:75
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
Definition: APInt.cpp:1888
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1439
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:946
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
This class represents a constant integer value.
const APInt & getAPInt() const
This node represents multiplication of some number of SCEVs.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
ArrayRef< const SCEV * > operands() const
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, ValueToSCEVMapTy &Map)
Visit all nodes in the expression tree using worklist traversal.
This class represents an analyzed expression in the program.
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
size_t size() const
Definition: SmallVector.h:91
void push_back(const T &Elt)
Definition: SmallVector.h:416
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
static void divide(ScalarEvolution &SE, const SCEV *Numerator, const SCEV *Denominator, const SCEV **Quotient, const SCEV **Remainder)
void visitAddRecExpr(const SCEVAddRecExpr *Numerator)
void visitConstant(const SCEVConstant *Numerator)
void visitAddExpr(const SCEVAddExpr *Numerator)
void visitMulExpr(const SCEVMulExpr *Numerator)