LLVM 22.0.0git
ScalarEvolutionDivision.cpp
Go to the documentation of this file.
1//===- ScalarEvolutionDivision.cpp - Class to divide SCEVs ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the class that knows how to divide SCEVs.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/APInt.h"
15#include "llvm/ADT/DenseMap.h"
21#include <cassert>
22#include <cstdint>
23
24#define DEBUG_TYPE "scev-division"
25
26namespace llvm {
27class Type;
28} // namespace llvm
29
30using namespace llvm;
31
32namespace {
33
34static inline int sizeOfSCEV(const SCEV *S) {
35 struct FindSCEVSize {
36 int Size = 0;
37
38 FindSCEVSize() = default;
39
40 bool follow(const SCEV *S) {
41 ++Size;
42 // Keep looking at all operands of S.
43 return true;
44 }
45
46 bool isDone() const { return false; }
47 };
48
49 FindSCEVSize F;
50 SCEVTraversal<FindSCEVSize> ST(F);
51 ST.visitAll(S);
52 return F.Size;
53}
54
55} // namespace
56
57// Computes the Quotient and Remainder of the division of Numerator by
58// Denominator.
59void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
60 const SCEV *Denominator, const SCEV **Quotient,
61 const SCEV **Remainder) {
62 assert(Numerator && Denominator && "Uninitialized SCEV");
63
64 SCEVDivision D(SE, Numerator, Denominator);
65
66 // Check for the trivial case here to avoid having to check for it in the
67 // rest of the code.
68 if (Numerator == Denominator) {
69 *Quotient = D.One;
70 *Remainder = D.Zero;
71 return;
72 }
73
74 if (Numerator->isZero()) {
75 *Quotient = D.Zero;
76 *Remainder = D.Zero;
77 return;
78 }
79
80 // A simple case when N/1. The quotient is N.
81 if (Denominator->isOne()) {
82 *Quotient = Numerator;
83 *Remainder = D.Zero;
84 return;
85 }
86
87 // Split the Denominator when it is a product.
88 if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
89 const SCEV *Q, *R;
90 *Quotient = Numerator;
91 for (const SCEV *Op : T->operands()) {
92 divide(SE, *Quotient, Op, &Q, &R);
93 *Quotient = Q;
94
95 // Bail out when the Numerator is not divisible by one of the terms of
96 // the Denominator.
97 if (!R->isZero()) {
98 *Quotient = D.Zero;
99 *Remainder = Numerator;
100 return;
101 }
102 }
103 *Remainder = D.Zero;
104 return;
105 }
106
107 D.visit(Numerator);
108 *Quotient = D.Quotient;
109 *Remainder = D.Remainder;
110}
111
113 if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
114 APInt NumeratorVal = Numerator->getAPInt();
115 APInt DenominatorVal = D->getAPInt();
116 uint32_t NumeratorBW = NumeratorVal.getBitWidth();
117 uint32_t DenominatorBW = DenominatorVal.getBitWidth();
118
119 if (NumeratorBW > DenominatorBW)
120 DenominatorVal = DenominatorVal.sext(NumeratorBW);
121 else if (NumeratorBW < DenominatorBW)
122 NumeratorVal = NumeratorVal.sext(DenominatorBW);
123
124 APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
125 APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
126 APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
127 Quotient = SE.getConstant(QuotientVal);
128 Remainder = SE.getConstant(RemainderVal);
129 return;
130 }
131}
132
133void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
134 return cannotDivide(Numerator);
135}
136
138 const SCEV *StartQ, *StartR, *StepQ, *StepR;
139 if (!Numerator->isAffine())
140 return cannotDivide(Numerator);
141 divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
142 divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
143 // Bail out if the types do not match.
144 Type *Ty = Denominator->getType();
145 if (Ty != StartQ->getType() || Ty != StartR->getType() ||
146 Ty != StepQ->getType() || Ty != StepR->getType())
147 return cannotDivide(Numerator);
148 Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
149 Numerator->getNoWrapFlags());
150 Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
151 Numerator->getNoWrapFlags());
152}
153
// Divides a sum operand by operand: each operand of Numerator is divided by
// Denominator, and the per-operand quotients (Qs) and remainders (Rs) are
// added back together. If any per-operand quotient or remainder does not have
// the denominator's type, the whole division falls back to cannotDivide.
156 Type *Ty = Denominator->getType();
157
158 for (const SCEV *Op : Numerator->operands()) {
159 const SCEV *Q, *R;
160 divide(SE, Op, Denominator, &Q, &R);
161
162 // Bail out if types do not match.
163 if (Ty != Q->getType() || Ty != R->getType())
164 return cannotDivide(Numerator);
165
166 Qs.push_back(Q);
167 Rs.push_back(R);
168 }
169
// With a single collected operand, its quotient and remainder are used
// directly instead of building an add expression.
170 if (Qs.size() == 1) {
171 Quotient = Qs[0];
172 Remainder = Rs[0];
173 return;
174 }
175
// Sum the partial results; the summed remainder is not divided again.
176 Quotient = SE.getAddExpr(Qs);
177 Remainder = SE.getAddExpr(Rs);
178}
179
// Divides a product. Strategies, tried in order:
//  1. Find the first operand of Numerator that Denominator divides with a
//     zero remainder, replace it by its quotient and keep all other operands
//     unchanged; the product of the result is the quotient, remainder zero.
//  2. Otherwise, if Denominator is a SCEVUnknown, substitute it with 0 in
//     Numerator to obtain the remainder; if that remainder is zero,
//     substitute it with 1 to obtain the quotient.
//  3. If that remainder is non-zero, divide (Numerator - Remainder) by
//     Denominator, giving up if the difference is larger than Numerator (as
//     measured by sizeOfSCEV) or does not divide exactly.
// An operand of Numerator, or a per-operand quotient, whose type differs from
// the denominator's type makes the division give up via cannotDivide.
182 Type *Ty = Denominator->getType();
183
184 bool FoundDenominatorTerm = false;
185 for (const SCEV *Op : Numerator->operands()) {
186 // Bail out if types do not match.
187 if (Ty != Op->getType())
188 return cannotDivide(Numerator);
189
// Once one operand has absorbed the denominator, the remaining operands are
// copied into the quotient unchanged.
190 if (FoundDenominatorTerm) {
191 Qs.push_back(Op);
192 continue;
193 }
194
195 // Check whether Denominator divides one of the product operands.
196 const SCEV *Q, *R;
197 divide(SE, Op, Denominator, &Q, &R);
// Not evenly divisible: keep this operand as an unchanged factor.
198 if (!R->isZero()) {
199 Qs.push_back(Op);
200 continue;
201 }
202
203 // Bail out if types do not match.
204 if (Ty != Q->getType())
205 return cannotDivide(Numerator);
206
207 FoundDenominatorTerm = true;
208 Qs.push_back(Q);
209 }
210
// Strategy 1 succeeded: Qs holds the operands with one factor divided out.
211 if (FoundDenominatorTerm) {
212 Remainder = Zero;
213 if (Qs.size() == 1)
214 Quotient = Qs[0];
215 else
216 Quotient = SE.getMulExpr(Qs);
217 return;
218 }
219
// No operand absorbed the denominator. The substitution strategies below key
// the rewrite map on the denominator's IR value, so they only apply when
// Denominator is a SCEVUnknown.
220 if (!isa<SCEVUnknown>(Denominator))
221 return cannotDivide(Numerator);
222
223 // The Remainder is obtained by replacing Denominator by 0 in Numerator.
224 ValueToSCEVMapTy RewriteMap;
225 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
226 Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
227
228 if (Remainder->isZero()) {
229 // The Quotient is obtained by replacing Denominator by 1 in Numerator.
230 RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
231 Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
232 return;
233 }
234
235 // Quotient is (Numerator - Remainder) divided by Denominator.
// On success Remainder keeps the value computed by the rewrite above.
236 const SCEV *Q, *R;
237 const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
238 // This SCEV does not seem to simplify: fail the division here.
239 if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
240 return cannotDivide(Numerator);
241 divide(SE, Diff, Denominator, &Q, &R);
// NOTE(review): this compares R against the cached Zero by pointer, whereas
// the rest of the file tests R->isZero(); presumably equivalent because R has
// the denominator's type — confirm.
242 if (R != Zero)
243 return cannotDivide(Numerator);
244 Quotient = Q;
245}
246
247SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
248 const SCEV *Denominator)
249 : SE(S), Denominator(Denominator) {
250 Zero = SE.getZero(Denominator->getType());
251 One = SE.getOne(Denominator->getType());
252
253 // We generally do not know how to divide Expr by Denominator. We initialize
254 // the division to a "cannot divide" state to simplify the rest of the code.
255 cannotDivide(Numerator);
256}
257
258// Convenience function for giving up on the division. We set the quotient to
259// be equal to zero and the remainder to be equal to the numerator.
260void SCEVDivision::cannotDivide(const SCEV *Numerator) {
261 Quotient = Zero;
262 Remainder = Numerator;
263}
264
265void SCEVDivisionPrinterPass::runImpl(Function &F, ScalarEvolution &SE) {
266 OS << "Printing analysis 'Scalar Evolution Division' for function '"
267 << F.getName() << "':\n";
268 for (Instruction &Inst : instructions(F)) {
269 BinaryOperator *Div = dyn_cast<BinaryOperator>(&Inst);
270 if (!Div || Div->getOpcode() != Instruction::SDiv)
271 continue;
272
273 const SCEV *Numerator = SE.getSCEV(Div->getOperand(0));
274 const SCEV *Denominator = SE.getSCEV(Div->getOperand(1));
275 const SCEV *Quotient, *Remainder;
276 SCEVDivision::divide(SE, Numerator, Denominator, &Quotient, &Remainder);
277
278 OS << "Instruction: " << *Div << "\n";
279 OS.indent(2) << "Numerator: " << *Numerator << "\n";
280 OS.indent(2) << "Denominator: " << *Denominator << "\n";
281 OS.indent(2) << "Quotient: " << *Quotient << "\n";
282 OS.indent(2) << "Remainder: " << *Remainder << "\n";
283 }
284}
285
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
Expand Atomic instructions
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
This file defines the DenseMap class.
#define F(x, y, z)
Definition MD5.cpp:55
#define T
This file defines the SmallVector class.
Class for arbitrary precision integers.
Definition APInt.h:78
static LLVM_ABI void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
Definition APInt.cpp:1890
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1488
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
This class represents a constant integer value.
const APInt & getAPInt() const
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
This node represents multiplication of some number of SCEVs.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
ArrayRef< const SCEV * > operands() const
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, ValueToSCEVMapTy &Map)
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an analyzed expression in the program.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
Value * getOperand(unsigned i) const
Definition User.h:232
This is an optimization pass for GlobalISel generic memory operations.
DenseMap< const Value *, const SCEV * > ValueToSCEVMapTy
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:548
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
static void divide(ScalarEvolution &SE, const SCEV *Numerator, const SCEV *Denominator, const SCEV **Quotient, const SCEV **Remainder)
Computes the Quotient and Remainder of the division of Numerator by Denominator.
void visitVScale(const SCEVVScale *Numerator)
void visitAddRecExpr(const SCEVAddRecExpr *Numerator)
void visitConstant(const SCEVConstant *Numerator)
void visitAddExpr(const SCEVAddExpr *Numerator)
void visitMulExpr(const SCEVMulExpr *Numerator)