LLVM  14.0.0git
ScalarEvolutionDivision.cpp
Go to the documentation of this file.
1 //===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SCEVDivision class, which knows how to divide SCEVs.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "llvm/Analysis/ScalarEvolutionDivision.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
23 
// Type is only ever used through pointers in this file, so a forward
// declaration is all that is required here.
namespace llvm {
class Type;
} // end namespace llvm

using namespace llvm;
29 
30 namespace {
31 
32 static inline int sizeOfSCEV(const SCEV *S) {
33  struct FindSCEVSize {
34  int Size = 0;
35 
36  FindSCEVSize() = default;
37 
38  bool follow(const SCEV *S) {
39  ++Size;
40  // Keep looking at all operands of S.
41  return true;
42  }
43 
44  bool isDone() const { return false; }
45  };
46 
47  FindSCEVSize F;
49  ST.visitAll(S);
50  return F.Size;
51 }
52 
53 } // namespace
54 
55 // Computes the Quotient and Remainder of the division of Numerator by
56 // Denominator.
57 void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
58  const SCEV *Denominator, const SCEV **Quotient,
59  const SCEV **Remainder) {
60  assert(Numerator && Denominator && "Uninitialized SCEV");
61 
62  SCEVDivision D(SE, Numerator, Denominator);
63 
64  // Check for the trivial case here to avoid having to check for it in the
65  // rest of the code.
66  if (Numerator == Denominator) {
67  *Quotient = D.One;
68  *Remainder = D.Zero;
69  return;
70  }
71 
72  if (Numerator->isZero()) {
73  *Quotient = D.Zero;
74  *Remainder = D.Zero;
75  return;
76  }
77 
78  // A simple case when N/1. The quotient is N.
79  if (Denominator->isOne()) {
80  *Quotient = Numerator;
81  *Remainder = D.Zero;
82  return;
83  }
84 
85  // Split the Denominator when it is a product.
86  if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
87  const SCEV *Q, *R;
88  *Quotient = Numerator;
89  for (const SCEV *Op : T->operands()) {
90  divide(SE, *Quotient, Op, &Q, &R);
91  *Quotient = Q;
92 
93  // Bail out when the Numerator is not divisible by one of the terms of
94  // the Denominator.
95  if (!R->isZero()) {
96  *Quotient = D.Zero;
97  *Remainder = Numerator;
98  return;
99  }
100  }
101  *Remainder = D.Zero;
102  return;
103  }
104 
105  D.visit(Numerator);
106  *Quotient = D.Quotient;
107  *Remainder = D.Remainder;
108 }
109 
111  if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
112  APInt NumeratorVal = Numerator->getAPInt();
113  APInt DenominatorVal = D->getAPInt();
114  uint32_t NumeratorBW = NumeratorVal.getBitWidth();
115  uint32_t DenominatorBW = DenominatorVal.getBitWidth();
116 
117  if (NumeratorBW > DenominatorBW)
118  DenominatorVal = DenominatorVal.sext(NumeratorBW);
119  else if (NumeratorBW < DenominatorBW)
120  NumeratorVal = NumeratorVal.sext(DenominatorBW);
121 
122  APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
123  APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
124  APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
125  Quotient = SE.getConstant(QuotientVal);
126  Remainder = SE.getConstant(RemainderVal);
127  return;
128  }
129 }
130 
132  const SCEV *StartQ, *StartR, *StepQ, *StepR;
133  if (!Numerator->isAffine())
134  return cannotDivide(Numerator);
135  divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
136  divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
137  // Bail out if the types do not match.
138  Type *Ty = Denominator->getType();
139  if (Ty != StartQ->getType() || Ty != StartR->getType() ||
140  Ty != StepQ->getType() || Ty != StepR->getType())
141  return cannotDivide(Numerator);
142  Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
143  Numerator->getNoWrapFlags());
144  Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
145  Numerator->getNoWrapFlags());
146 }
147 
148 void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
150  Type *Ty = Denominator->getType();
151 
152  for (const SCEV *Op : Numerator->operands()) {
153  const SCEV *Q, *R;
154  divide(SE, Op, Denominator, &Q, &R);
155 
156  // Bail out if types do not match.
157  if (Ty != Q->getType() || Ty != R->getType())
158  return cannotDivide(Numerator);
159 
160  Qs.push_back(Q);
161  Rs.push_back(R);
162  }
163 
164  if (Qs.size() == 1) {
165  Quotient = Qs[0];
166  Remainder = Rs[0];
167  return;
168  }
169 
170  Quotient = SE.getAddExpr(Qs);
171  Remainder = SE.getAddExpr(Rs);
172 }
173 
174 void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
176  Type *Ty = Denominator->getType();
177 
178  bool FoundDenominatorTerm = false;
179  for (const SCEV *Op : Numerator->operands()) {
180  // Bail out if types do not match.
181  if (Ty != Op->getType())
182  return cannotDivide(Numerator);
183 
184  if (FoundDenominatorTerm) {
185  Qs.push_back(Op);
186  continue;
187  }
188 
189  // Check whether Denominator divides one of the product operands.
190  const SCEV *Q, *R;
191  divide(SE, Op, Denominator, &Q, &R);
192  if (!R->isZero()) {
193  Qs.push_back(Op);
194  continue;
195  }
196 
197  // Bail out if types do not match.
198  if (Ty != Q->getType())
199  return cannotDivide(Numerator);
200 
201  FoundDenominatorTerm = true;
202  Qs.push_back(Q);
203  }
204 
205  if (FoundDenominatorTerm) {
206  Remainder = Zero;
207  if (Qs.size() == 1)
208  Quotient = Qs[0];
209  else
210  Quotient = SE.getMulExpr(Qs);
211  return;
212  }
213 
214  if (!isa<SCEVUnknown>(Denominator))
215  return cannotDivide(Numerator);
216 
217  // The Remainder is obtained by replacing Denominator by 0 in Numerator.
218  ValueToSCEVMapTy RewriteMap;
219  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
220  Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
221 
222  if (Remainder->isZero()) {
223  // The Quotient is obtained by replacing Denominator by 1 in Numerator.
224  RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
225  Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
226  return;
227  }
228 
229  // Quotient is (Numerator - Remainder) divided by Denominator.
230  const SCEV *Q, *R;
231  const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
232  // This SCEV does not seem to simplify: fail the division here.
233  if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
234  return cannotDivide(Numerator);
235  divide(SE, Diff, Denominator, &Q, &R);
236  if (R != Zero)
237  return cannotDivide(Numerator);
238  Quotient = Q;
239 }
240 
241 SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
242  const SCEV *Denominator)
243  : SE(S), Denominator(Denominator) {
244  Zero = SE.getZero(Denominator->getType());
245  One = SE.getOne(Denominator->getType());
246 
247  // We generally do not know how to divide Expr by Denominator. We initialize
248  // the division to a "cannot divide" state to simplify the rest of the code.
249  cannotDivide(Numerator);
250 }
251 
252 // Convenience function for giving up on the division. We set the quotient to
253 // be equal to zero and the remainder to be equal to the numerator.
254 void SCEVDivision::cannotDivide(const SCEV *Numerator) {
255  Quotient = Zero;
256  Remainder = Numerator;
257 }
llvm::Check::Size
@ Size
Definition: FileCheck.h:73
llvm::SCEVDivision::visitConstant
void visitConstant(const SCEVConstant *Numerator)
Definition: ScalarEvolutionDivision.cpp:110
llvm
This file implements support for optimizing divisions by a constant.
Definition: AllocatorList.h:23
llvm::SCEVDivision::visitMulExpr
void visitMulExpr(const SCEVMulExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:174
llvm::SCEVAddRecExpr::isAffine
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
Definition: ScalarEvolutionExpressions.h:379
llvm::SCEVAddRecExpr::getStart
const SCEV * getStart() const
Definition: ScalarEvolutionExpressions.h:363
T
llvm::SCEVParameterRewriter::rewrite
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, ValueToSCEVMapTy &Map)
Definition: ScalarEvolutionExpressions.h:864
llvm::SmallVector
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1168
llvm::ScalarEvolution::getAddRecExpr
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
Definition: ScalarEvolution.cpp:3526
ErrorHandling.h
llvm::ScalarEvolution
The main scalar evolution driver.
Definition: ScalarEvolution.h:460
APInt.h
ScalarEvolution.h
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
DenseMap.h
llvm::APInt::getBitWidth
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1399
llvm::SCEVNAryExpr::operands
op_range operands() const
Definition: ScalarEvolutionExpressions.h:209
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::ScalarEvolution::getMulExpr
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
Definition: ScalarEvolution.cpp:3002
llvm::SCEVTraversal
Visit all nodes in the expression tree using worklist traversal.
Definition: ScalarEvolutionExpressions.h:626
llvm::ScalarEvolution::getOne
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
Definition: ScalarEvolution.h:628
Constants.h
ScalarEvolutionDivision.h
llvm::SCEVMulExpr
This node represents multiplication of some number of SCEVs.
Definition: ScalarEvolutionExpressions.h:286
llvm::SCEVDivision
Definition: ScalarEvolutionDivision.h:26
llvm::SCEVDivision::visitAddExpr
void visitAddExpr(const SCEVAddExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:148
llvm::SCEV
This class represents an analyzed expression in the program.
Definition: ScalarEvolution.h:77
D
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::APInt::sdivrem
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
Definition: APInt.cpp:1882
llvm::DenseMap
Definition: DenseMap.h:714
llvm::SCEVConstant
This class represents a constant integer value.
Definition: ScalarEvolutionExpressions.h:47
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::APInt
Class for arbitrary precision integers.
Definition: APInt.h:75
llvm::SCEVNAryExpr::getNoWrapFlags
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
Definition: ScalarEvolutionExpressions.h:213
llvm::ScalarEvolution::getConstant
const SCEV * getConstant(ConstantInt *V)
Definition: ScalarEvolution.cpp:446
uint32_t
S
add sub stmia L5 ldr r0 bl L_printf $stub Instead of a and a wouldn t it be better to do three moves *Return an aggregate type is even return S
Definition: README.txt:210
llvm::SCEVDivision::divide
static void divide(ScalarEvolution &SE, const SCEV *Numerator, const SCEV *Denominator, const SCEV **Quotient, const SCEV **Remainder)
Definition: ScalarEvolutionDivision.cpp:57
llvm::SCEVAddRecExpr::getLoop
const Loop * getLoop() const
Definition: ScalarEvolutionExpressions.h:364
llvm::ScalarEvolution::getMinusSCEV
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
Definition: ScalarEvolution.cpp:4223
llvm::AMDGPU::SendMsg::Op
Op
Definition: SIDefines.h:324
llvm::SCEVAddRecExpr
This node represents a polynomial recurrence on the trip count of the specified loop.
Definition: ScalarEvolutionExpressions.h:352
Casting.h
llvm::SCEVDivision::visitAddRecExpr
void visitAddRecExpr(const SCEVAddRecExpr *Numerator)
Definition: ScalarEvolutionDivision.cpp:131
llvm::SCEV::isOne
bool isOne() const
Return true if the expression is a constant one.
Definition: ScalarEvolution.cpp:415
llvm::APInt::sext
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:928
llvm::SCEVConstant::getAPInt
const APInt & getAPInt() const
Definition: ScalarEvolutionExpressions.h:57
llvm::SCEVAddExpr
This node represents an addition of some number of SCEVs.
Definition: ScalarEvolutionExpressions.h:260
SmallVector.h
llvm::ScalarEvolution::getZero
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
Definition: ScalarEvolution.h:625
llvm::SCEV::getType
Type * getType() const
Return the LLVM type of this SCEV expression.
Definition: ScalarEvolution.cpp:379
llvm::ScalarEvolution::getAddExpr
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
Definition: ScalarEvolution.cpp:2419
llvm::SCEV::isZero
bool isZero() const
Return true if the expression is a constant zero.
Definition: ScalarEvolution.cpp:409
llvm::SCEVAddRecExpr::getStepRecurrence
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
Definition: ScalarEvolutionExpressions.h:370