73 using namespace PatternMatch;
77 static const unsigned UnknownAddressSpace = ~0u;
92 : CandidateKind(
Invalid), Base(nullptr), Index(nullptr),
93 Stride(nullptr),
Ins(nullptr), Basis(nullptr) {}
96 : CandidateKind(CT), Base(B), Index(Idx), Stride(S),
Ins(I),
127 StraightLineStrengthReduce()
140 bool doInitialization(
Module &M)
override {
145 bool runOnFunction(
Function &
F)
override;
150 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
156 bool isSimplestForm(
const Candidate &
C);
161 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
164 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
167 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
170 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
177 Value *S, uint64_t ElementSize,
185 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
189 void factorArrayIndex(
Value *ArrayIdx,
const SCEV *Base, uint64_t ElementSize,
195 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
197 bool &BumpWithUglyGEP);
203 std::list<Candidate> Candidates;
207 std::vector<Instruction *> UnlinkedInstructions;
213 "Straight line strength reduction",
false,
false)
221 return new StraightLineStrengthReduce();
224 bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
225 const Candidate &
C) {
226 return (Basis.Ins != C.Ins &&
229 Basis.Ins->getType() == C.Ins->getType() &&
231 DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) &&
233 Basis.Base == C.Base && Basis.Stride == C.Stride &&
234 Basis.CandidateKind == C.CandidateKind);
255 bool StraightLineStrengthReduce::isFoldable(
const Candidate &C,
267 unsigned NumNonZeroIndices = 0;
270 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
273 return NumNonZeroIndices <= 1;
276 bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &C) {
279 return C.Index->isOne() || C.Index->isMinusOne();
281 if (C.CandidateKind == Candidate::Mul) {
283 return C.Index->isZero();
287 return ((C.Index->isOne() || C.Index->isMinusOne()) &&
300 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
303 Candidate
C(CT, B, Idx, S, I);
317 if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) {
319 unsigned NumIterations = 0;
321 static const unsigned MaxNumIterations = 50;
322 for (
auto Basis = Candidates.rbegin();
323 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
324 ++Basis, ++NumIterations) {
325 if (isBasisFor(*Basis, C)) {
333 Candidates.push_back(C);
336 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
340 allocateCandidatesAndFindBasisForAdd(I);
342 case Instruction::Mul:
343 allocateCandidatesAndFindBasisForMul(I);
345 case Instruction::GetElementPtr:
346 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
351 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
354 if (!isa<IntegerType>(I->
getType()))
359 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
361 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
364 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
370 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
375 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
379 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), One, RHS,
396 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
403 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
409 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
413 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS),
Zero, RHS,
418 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
422 if (!isa<IntegerType>(I->
getType()))
427 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
430 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
434 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
443 IntPtrTy, Idx->
getSExtValue() * (int64_t)ElementSize,
true);
444 allocateCandidatesAndFindBasis(
Candidate::GEP, B, ScaledIdx, S, I);
447 void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
449 uint64_t ElementSize,
452 allocateCandidatesAndFindBasisForGEP(
454 ArrayIdx, ElementSize,
GEP);
455 Value *LHS =
nullptr;
471 allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP);
478 allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP);
482 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
497 const SCEV *OrigIndexExpr = IndexExprs[I - 1];
498 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->
getType());
502 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
509 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
514 Value *TruncatedArrayIdx =
nullptr;
520 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
523 IndexExprs[I - 1] = OrigIndexExpr;
535 Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
539 bool &BumpWithUglyGEP) {
540 APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
542 APInt IndexOffset = Idx - BasisIdx;
544 BumpWithUglyGEP =
false;
549 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
555 BumpWithUglyGEP =
true;
560 if (IndexOffset == 1)
574 return Builder.
CreateShl(ExtendedStride, Exponent);
576 if ((-IndexOffset).isPowerOf2()) {
583 return Builder.
CreateMul(ExtendedStride, Delta);
586 void StraightLineStrengthReduce::rewriteCandidateWithBasis(
587 const Candidate &C,
const Candidate &Basis) {
588 assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
589 C.Stride == Basis.Stride);
592 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
598 if (!C.Ins->getParent())
602 bool BumpWithUglyGEP;
603 Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP);
604 Value *Reduced =
nullptr;
605 switch (C.CandidateKind) {
626 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
632 bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
633 if (BumpWithUglyGEP) {
635 unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
651 Reduced = Builder.
CreateGEP(
nullptr, Basis.Ins, Bump);
658 Reduced->takeName(C.Ins);
659 C.Ins->replaceAllUsesWith(Reduced);
662 C.Ins->removeFromParent();
663 UnlinkedInstructions.push_back(C.Ins);
666 bool StraightLineStrengthReduce::runOnFunction(
Function &
F) {
670 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
671 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
672 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
676 for (
auto &I : *(Node->getBlock()))
677 allocateCandidatesAndFindBasis(&I);
681 while (!Candidates.empty()) {
682 const Candidate &C = Candidates.back();
683 if (C.Basis !=
nullptr) {
684 rewriteCandidateWithBasis(C, *C.Basis);
686 Candidates.pop_back();
690 for (
auto *UnlinkedInst : UnlinkedInstructions) {
691 for (
unsigned I = 0,
E = UnlinkedInst->getNumOperands(); I !=
E; ++
I) {
692 Value *
Op = UnlinkedInst->getOperand(I);
693 UnlinkedInst->setOperand(I,
nullptr);
698 bool Ret = !UnlinkedInstructions.empty();
699 UnlinkedInstructions.clear();
FunctionPass * createStraightLineStrengthReducePass()
void push_back(const T &Elt)
A parsed version of the target data layout string, and methods for querying it. ...
Type * getIndexedType() const
Value * getPointerOperand()
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Type * getSourceElementType() const
A Module instance is used to store all the information related to an LLVM module. ...
unsigned getBitWidth() const
getBitWidth - Return the bitwidth of this constant.
unsigned getNumOperands() const
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
The main scalar evolution driver.
static void unifyBitWidth(APInt &A, APInt &B)
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
bool match(Val *V, const Pattern &P)
AnalysisUsage & addRequired()
#define INITIALIZE_PASS_DEPENDENCY(depName)
static const Value * getNegArgument(const Value *BinOp)
Helper functions to extract the unary argument of a NEG, FNEG or NOT operation implemented via Sub...
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
const APInt & getValue() const
Return the constant as an APInt value reference.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
unsigned logBase2(const APInt &APIVal)
Returns the floor log base 2 of the specified APInt value.
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
void initializeStraightLineStrengthReducePass(PassRegistry &)
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
static GCRegistry::Add< OcamlGC > B("ocaml","ocaml 3.10-compatible GC")
INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce,"slsr","Straight line strength reduction", false, false) INITIALIZE_PASS_END(StraightLineStrengthReduce
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr)
Return true if LHS and RHS have no common bits set.
Value * CreateInBoundsGEP(Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
static GCRegistry::Add< CoreCLRGC > E("coreclr","CoreCLR-compatible GC")
An instruction for type-safe pointer arithmetic to access elements of arrays and structs ...
The instances of the Type class are immutable: once they are created, they are never changed...
BinaryOp_match< LHS, RHS, Instruction::Or > m_Or(const LHS &L, const RHS &R)
Type * getType() const
Return the LLVM type of this SCEV expression.
bool isVectorTy() const
True if this is an instance of VectorType.
This is an important base class in LLVM.
Straight line strength reduction
APInt sext(unsigned width) const
Sign extend to a new width.
Represent the analysis usage information of a pass.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,"Assign register bank of generic virtual registers", false, false) RegBankSelect
unsigned getBitWidth() const
Return the number of bits in the APInt.
FunctionPass class - This class is used to implement most global optimizations.
Value * getOperand(unsigned i) const
unsigned getIntegerBitWidth() const
Class to represent integer types.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr)
If the specified value is a trivially dead instruction, delete it.
LLVMContext & getContext() const
All values hold a context through their type.
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
unsigned getAddressSpace() const
Returns the address space of this instruction's pointer type.
CastClass_match< OpTy, Instruction::SExt > m_SExt(const OpTy &Op)
Matches SExt.
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space...
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
This is the shared class of boolean and integer constants.
uint64_t getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment pad...
unsigned logBase2() const
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
bool isZero() const
This is just a convenience method to make client code smaller for a common case.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Straight line strength false
static GCRegistry::Add< ShadowStackGC > C("shadow-stack","Very portable GC for uncooperative code generators")
Value * CreateGEP(Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
static bool isNeg(const Value *V)
Check if the given Value is a NEG, FNeg, or NOT instruction.
Class for arbitrary precision integers.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
bool isAllOnesValue() const
Determine if all bits are set.
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
This class represents an analyzed expression in the program.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
unsigned getPointerSizeInBits(unsigned AS=0) const
Layout pointer size, in bits FIXME: The defaults need to be removed once all of the backends/clients ...
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
iterator_range< df_iterator< T > > depth_first(const T &G)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
Legacy analysis pass which computes a DominatorTree.
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
static GCRegistry::Add< ErlangGC > A("erlang","erlang-compatible garbage collector")
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
gep_type_iterator gep_type_begin(const User *GEP)