LLVM  15.0.0git
ScalarEvolutionDivision.cpp
Go to the documentation of this file.
1 //===- ScalarEvolutionDivision.cpp - SCEV division --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SCEVDivision class, which divides SCEV expressions.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>

namespace llvm {
class Type;
}

using namespace llvm;
27 
28 namespace {
29 
30 static inline int sizeOfSCEV(const SCEV *S) {
31  struct FindSCEVSize {
32  int Size = 0;
33 
34  FindSCEVSize() = default;
35 
36  bool follow(const SCEV *S) {
37  ++Size;
38  // Keep looking at all operands of S.
39  return true;
40  }
41 
42  bool isDone() const { return false; }
43  };
44 
45  FindSCEVSize F;
47  ST.visitAll(S);
48  return F.Size;
49 }
50 
51 } // namespace
52 
53 // Computes the Quotient and Remainder of the division of Numerator by
54 // Denominator.
55 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
56  const SCEV *Denominator, const SCEV **Quotient,
57  const SCEV **Remainder) {
58  assert(Numerator && Denominator && "Uninitialized SCEV");
59 
60  SCEVDivision D(SE, Numerator, Denominator);
61 
62  // Check for the trivial case here to avoid having to check for it in the
63  // rest of the code.
64  if (Numerator == Denominator) {
65  *Quotient = D.One;
66  *Remainder = D.Zero;
67  return;
68  }
69 
70  if (Numerator->isZero()) {
71  *Quotient = D.Zero;
72  *Remainder = D.Zero;
73  return;
74  }
75 
76  // A simple case when N/1. The quotient is N.
77  if (Denominator->isOne()) {
78  *Quotient = Numerator;
79  *Remainder = D.Zero;
80  return;
81  }
82 
83  // Split the Denominator when it is a product.
84  if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
85  const SCEV *Q, *R;
86  *Quotient = Numerator;
87  for (const SCEV *Op : T->operands()) {
88  divide(SE, *Quotient, Op, &Q, &R);
89  *Quotient = Q;
90 
91  // Bail out when the Numerator is not divisible by one of the terms of
92  // the Denominator.
93  if (!R->isZero()) {
94  *Quotient = D.Zero;
95  *Remainder = Numerator;
96  return;
97  }
98  }
99  *Remainder = D.Zero;
100  return;
101  }
102 
103  D.visit(Numerator);
104  *Quotient = D.Quotient;
105  *Remainder = D.Remainder;
106 }
107 
109  if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
110  APInt NumeratorVal = Numerator->getAPInt();
111  APInt DenominatorVal = D->getAPInt();
112  uint32_t NumeratorBW = NumeratorVal.getBitWidth();
113  uint32_t DenominatorBW = DenominatorVal.getBitWidth();
114 
115  if (NumeratorBW > DenominatorBW)
116  DenominatorVal = DenominatorVal.sext(NumeratorBW);
117  else if (NumeratorBW < DenominatorBW)
118  NumeratorVal = NumeratorVal.sext(DenominatorBW);
119 
120  APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
121  APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
122  APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
123  Quotient = SE.getConstant(QuotientVal);
124  Remainder = SE.getConstant(RemainderVal);
125  return;
126  }
127 }
128 
130  const SCEV *StartQ, *StartR, *StepQ, *StepR;
131  if (!Numerator->isAffine())
132  return cannotDivide(Numerator);
133  divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
134  divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
135  // Bail out if the types do not match.
136  Type *Ty = Denominator->getType();
137  if (Ty != StartQ->getType() || Ty != StartR->getType() ||
138  Ty != StepQ->getType() || Ty != StepR->getType())
139  return cannotDivide(Numerator);
140  Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
141  Numerator->getNoWrapFlags());
142  Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
143  Numerator->getNoWrapFlags());
144 }
145 
146 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
148  Type *Ty = Denominator->getType();
149 
150  for (const SCEV *Op : Numerator->operands()) {
151  const SCEV *Q, *R;
152  divide(SE, Op, Denominator, &Q, &R);
153 
154  // Bail out if types do not match.
155  if (Ty != Q->getType() || Ty != R->getType())
156  return cannotDivide(Numerator);
157 
158  Qs.push_back(Q);
159  Rs.push_back(R);
160  }
161 
162  if (Qs.size() == 1) {
163  Quotient = Qs[0];
164  Remainder = Rs[0];
165  return;
166  }
167 
168  Quotient = SE.getAddExpr(Qs);
169  Remainder = SE.getAddExpr(Rs);
170 }
171 
172 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
174  Type *Ty = Denominator->getType();
175 
176  bool FoundDenominatorTerm = false;
177  for (const SCEV *Op : Numerator->operands()) {
178  // Bail out if types do not match.
179  if (Ty != Op->getType())
180  return cannotDivide(Numerator);
181 
182  if (FoundDenominatorTerm) {
183  Qs.push_back(Op);
184  continue;
185  }
186 
187  // Check whether Denominator divides one of the product operands.
188  const SCEV *Q, *R;
189  divide(SE, Op, Denominator, &Q, &R);
190  if (!R->isZero()) {
191  Qs.push_back(Op);
192  continue;
193  }
194 
195  // Bail out if types do not match.
196  if (Ty != Q->getType())
197  return cannotDivide(Numerator);
198 
199  FoundDenominatorTerm = true;
200  Qs.push_back(Q);
201  }
202 
203  if (FoundDenominatorTerm) {
204  Remainder = Zero;
205  if (Qs.size() == 1)
206  Quotient = Qs[0];
207  else
208  Quotient = SE.getMulExpr(Qs);
209  return;
210  }
211 
212  if (!isa<SCEVUnknown>(Denominator))
213  return cannotDivide(Numerator);
214 
215  // The Remainder is obtained by replacing Denominator by 0 in Numerator.
216  ValueToSCEVMapTy RewriteMap;
217  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
218  Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
219 
220  if (Remainder->isZero()) {
221  // The Quotient is obtained by replacing Denominator by 1 in Numerator.
222  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
223  Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
224  return;
225  }
226 
227  // Quotient is (Numerator - Remainder) divided by Denominator.
228  const SCEV *Q, *R;
229  const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
230  // This SCEV does not seem to simplify: fail the division here.
231  if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
232  return cannotDivide(Numerator);
233  divide(SE, Diff, Denominator, &Q, &R);
234  if (R != Zero)
235  return cannotDivide(Numerator);
236  Quotient = Q;
237 }
238 
239 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
240  const SCEV *Denominator)
241  : SE(S), Denominator(Denominator) {
242  Zero = SE.getZero(Denominator->getType());
243  One = SE.getOne(Denominator->getType());
244 
245  // We generally do not know how to divide Expr by Denominator. We initialize
246  // the division to a "cannot divide" state to simplify the rest of the code.
247  cannotDivide(Numerator);
248 }
249 
250 // Convenience function for giving up on the division. We set the quotient to
251 // be equal to zero and the remainder to be equal to the numerator.
252 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
253  Quotient = Zero;
254  Remainder = Numerator;
255 }
llvm::Check::Size
@ Size
Definition: FileCheck.h:77
llvm::SCEVDivision::visitConstant
void visitConstant(const SCEVConstant *Numerator)
Definition: ScalarEvolutionDivision.cpp:108
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:17
llvm::SCEVDivision::visitMulExpr
void visitMulExpr(const SCEVMulExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:172
llvm::SCEVAddRecExpr::isAffine
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
Definition: ScalarEvolutionExpressions.h:370
llvm::SCEVAddRecExpr::getStart
const SCEV * getStart() const
Definition: ScalarEvolutionExpressions.h:353
T
llvm::SCEVParameterRewriter::rewrite
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, ValueToSCEVMapTy &Map)
Definition: ScalarEvolutionExpressions.h:913
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1185
llvm::ScalarEvolution::getAddRecExpr
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
Definition: ScalarEvolution.cpp:3571
llvm::ScalarEvolution
The main scalar evolution driver.
Definition: ScalarEvolution.h:449
APInt.h
ScalarEvolution.h
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
DenseMap.h
llvm::APInt::getBitWidth
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1411
llvm::SCEVNAryExpr::operands
op_range operands() const
Definition: ScalarEvolutionExpressions.h:211
F
#define F(x, y, z)
Definition: MD5.cpp:55
llvm::ScalarEvolution::getMulExpr
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
Definition: ScalarEvolution.cpp:3050
llvm::SCEVTraversal
Visit all nodes in the expression tree using worklist traversal.
Definition: ScalarEvolutionExpressions.h:666
llvm::ScalarEvolution::getOne
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
Definition: ScalarEvolution.h:645
ScalarEvolutionDivision.h
llvm::SCEVMulExpr
This node represents multiplication of some number of SCEVs.
Definition: ScalarEvolutionExpressions.h:281
llvm::SCEVDivision
Definition: ScalarEvolutionDivision.h:26
llvm::SCEVDivision::visitAddExpr
void visitAddExpr(const SCEVAddExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:146
llvm::SCEV
This class represents an analyzed expression in the program.
Definition: ScalarEvolution.h:75
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::APInt::sdivrem
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
Definition: APInt.cpp:1888
llvm::DenseMap
Definition: DenseMap.h:716
llvm::SCEVConstant
This class represents a constant integer value.
Definition: ScalarEvolutionExpressions.h:60
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::APInt
Class for arbitrary precision integers.
Definition: APInt.h:75
llvm::SCEVNAryExpr::getNoWrapFlags
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
Definition: ScalarEvolutionExpressions.h:213
llvm::ScalarEvolution::getConstant
const SCEV * getConstant(ConstantInt *V)
Definition: ScalarEvolution.cpp:461
uint32_t
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
llvm::SCEVDivision::divide
static void divide(ScalarEvolution &SE, const SCEV *Numerator, const SCEV *Denominator, const SCEV **Quotient, const SCEV **Remainder)
Definition: ScalarEvolutionDivision.cpp:55
llvm::SCEVAddRecExpr::getLoop
const Loop * getLoop() const
Definition: ScalarEvolutionExpressions.h:354
llvm::ScalarEvolution::getMinusSCEV
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
Definition: ScalarEvolution.cpp:4510
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:348
llvm::SCEVAddRecExpr
This node represents a polynomial recurrence on the trip count of the specified loop.
Definition: ScalarEvolutionExpressions.h:342
Casting.h
llvm::SCEVDivision::visitAddRecExpr
void visitAddRecExpr(const SCEVAddRecExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:129
llvm::SCEV::isOne
bool isOne() const
Return true if the expression is a constant one.
Definition: ScalarEvolution.cpp:430
llvm::APInt::sext
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:946
llvm::SCEVConstant::getAPInt
const APInt & getAPInt() const
Definition: ScalarEvolutionExpressions.h:70
llvm::SCEVAddExpr
This node represents an addition of some number of SCEVs.
Definition: ScalarEvolutionExpressions.h:257
SmallVector.h
llvm::ScalarEvolution::getZero
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
Definition: ScalarEvolution.h:642
llvm::SCEV::getType
Type * getType() const
Return the LLVM type of this SCEV expression.
Definition: ScalarEvolution.cpp:392
llvm::ScalarEvolution::getAddExpr
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
Definition: ScalarEvolution.cpp:2453
llvm::SCEV::isZero
bool isZero() const
Return true if the expression is a constant zero.
Definition: ScalarEvolution.cpp:424
llvm::SCEVAddRecExpr::getStepRecurrence
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
Definition: ScalarEvolutionExpressions.h:360