109 Candidate() =
default;
112 : CandidateKind(CT),
Base(B),
Index(Idx), Stride(S),
Ins(I) {}
114 Kind CandidateKind = Invalid;
123 Value *Stride =
nullptr;
143 Candidate *Basis =
nullptr;
160 bool doInitialization(
Module &M)
override {
170 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
178 bool isSimplestForm(
const Candidate &
C);
185 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
189 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
192 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
196 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
205 Value *S, uint64_t ElementSize,
215 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
220 void factorArrayIndex(
Value *ArrayIdx,
const SCEV *
Base, uint64_t ElementSize,
227 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
229 bool &BumpWithUglyGEP);
235 std::list<Candidate> Candidates;
240 std::vector<Instruction *> UnlinkedInstructions;
248 "Straight line strength reduction",
false,
false)
256 return new StraightLineStrengthReduce();
259 bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
260 const Candidate &
C) {
261 return (Basis.Ins != C.Ins &&
264 Basis.Ins->getType() == C.Ins->getType() &&
266 DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) &&
268 Basis.Base == C.Base && Basis.Stride == C.Stride &&
269 Basis.CandidateKind == C.CandidateKind);
290 bool StraightLineStrengthReduce::isFoldable(
const Candidate &C,
302 unsigned NumNonZeroIndices = 0;
305 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
308 return NumNonZeroIndices <= 1;
311 bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &C) {
314 return C.Index->isOne() || C.Index->isMinusOne();
316 if (C.CandidateKind == Candidate::Mul) {
318 return C.Index->isZero();
322 return ((C.Index->isOne() || C.Index->isMinusOne()) &&
335 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
338 Candidate
C(CT, B, Idx, S, I);
352 if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) {
354 unsigned NumIterations = 0;
356 static const unsigned MaxNumIterations = 50;
357 for (
auto Basis = Candidates.rbegin();
358 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
359 ++Basis, ++NumIterations) {
360 if (isBasisFor(*Basis, C)) {
368 Candidates.push_back(C);
371 void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
375 allocateCandidatesAndFindBasisForAdd(I);
377 case Instruction::Mul:
378 allocateCandidatesAndFindBasisForMul(I);
380 case Instruction::GetElementPtr:
381 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
386 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
389 if (!isa<IntegerType>(I->
getType()))
394 allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
396 allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
399 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
405 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
410 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), Idx, S,
I);
414 allocateCandidatesAndFindBasis(
Candidate::Add, SE->getSCEV(LHS), One, RHS,
431 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
438 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
444 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS,
I);
448 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
453 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
457 if (!isa<IntegerType>(I->
getType()))
462 allocateCandidatesAndFindBasisForMul(LHS, RHS, I);
465 allocateCandidatesAndFindBasisForMul(RHS, LHS, I);
469 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
478 IntPtrTy, Idx->
getSExtValue() * (int64_t)ElementSize,
true);
479 allocateCandidatesAndFindBasis(
Candidate::GEP, B, ScaledIdx, S, I);
482 void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
484 uint64_t ElementSize,
487 allocateCandidatesAndFindBasisForGEP(
489 ArrayIdx, ElementSize,
GEP);
490 Value *LHS =
nullptr;
506 allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP);
513 allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP);
517 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
532 const SCEV *OrigIndexExpr = IndexExprs[I - 1];
533 IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->
getType());
537 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs);
544 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
549 Value *TruncatedArrayIdx =
nullptr;
555 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
558 IndexExprs[I - 1] = OrigIndexExpr;
570 Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
574 bool &BumpWithUglyGEP) {
575 APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue();
577 APInt IndexOffset = Idx - BasisIdx;
579 BumpWithUglyGEP =
false;
584 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType()));
590 BumpWithUglyGEP =
true;
595 if (IndexOffset == 1)
609 return Builder.
CreateShl(ExtendedStride, Exponent);
611 if ((-IndexOffset).isPowerOf2()) {
618 return Builder.
CreateMul(ExtendedStride, Delta);
621 void StraightLineStrengthReduce::rewriteCandidateWithBasis(
622 const Candidate &C,
const Candidate &Basis) {
623 assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
624 C.Stride == Basis.Stride);
627 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
633 if (!C.Ins->getParent())
637 bool BumpWithUglyGEP;
638 Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP);
639 Value *Reduced =
nullptr;
640 switch (C.CandidateKind) {
642 case Candidate::Mul: {
647 Reduced = Builder.
CreateSub(Basis.Ins, NegBump);
661 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
668 bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
669 if (BumpWithUglyGEP) {
671 unsigned AS = Basis.Ins->getType()->getPointerAddressSpace();
686 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(),
690 cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(),
699 C.Ins->replaceAllUsesWith(Reduced);
702 C.Ins->removeFromParent();
703 UnlinkedInstructions.push_back(C.Ins);
710 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
711 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
712 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
716 for (
auto &I : *(Node->getBlock()))
717 allocateCandidatesAndFindBasis(&I);
721 while (!Candidates.empty()) {
722 const Candidate &C = Candidates.back();
723 if (C.Basis !=
nullptr) {
724 rewriteCandidateWithBasis(C, *C.Basis);
726 Candidates.pop_back();
730 for (
auto *UnlinkedInst : UnlinkedInstructions) {
731 for (
unsigned I = 0,
E = UnlinkedInst->getNumOperands(); I !=
E; ++
I) {
732 Value *
Op = UnlinkedInst->getOperand(I);
733 UnlinkedInst->setOperand(I,
nullptr);
736 UnlinkedInst->deleteValue();
738 bool Ret = !UnlinkedInstructions.empty();
739 UnlinkedInstructions.clear();
Value * CreateInBoundsGEP(Value *Ptr, ArrayRef< Value *> IdxList, const Twine &Name="")
FunctionPass * createStraightLineStrengthReducePass()
A parsed version of the target data layout string, and methods for querying it. ...
Value * getPointerOperand()
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at application launch and destroyed by llvm_shutdown.
APInt sext(unsigned width) const
Sign extend to a new width.
This class represents lattice values for constants.
A Module instance is used to store all the information related to an LLVM module. ...
void push_back(const T &Elt)
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
The main scalar evolution driver.
static void unifyBitWidth(APInt &A, APInt &B)
static void sdivrem(const APInt &LHS, const APInt &RHS, APInt &Quotient, APInt &Remainder)
LLVMContext & getContext() const
All values hold a context through their type.
unsigned getPointerSizeInBits(unsigned AS=0) const
Layout pointer size, in bits. FIXME: The defaults need to be removed once all of the backends/clients ...
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
bool isVectorTy() const
True if this is an instance of VectorType.
unsigned getBitWidth() const
getBitWidth - Return the bitwidth of this constant.
unsigned getBitWidth() const
Return the number of bits in the APInt.
bool match(Val *V, const Pattern &P)
AnalysisUsage & addRequired()
#define INITIALIZE_PASS_DEPENDENCY(depName)
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Type * getSourceElementType() const
This file implements a class to represent arbitrary precision integral constant values and operations...
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
void initializeStraightLineStrengthReducePass(PassRegistry &)
Value * CreateBitCast(Value *V, Type *DestTy, const Twine &Name="")
Type * getType() const
All values are typed, get the type of this value.
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
const APInt & getValue() const
Return the constant as an APInt value reference.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) INITIALIZE_PASS_END(StraightLineStrengthReduce
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
void takeName(Value *V)
Transfer the name from V to this value.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
Value * getOperand(unsigned i) const
unsigned getAddressSpace() const
Returns the address space of this instruction's pointer type.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs ...
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space...
static bool runOnFunction(Function &F, bool PostInlining)
bool isAllOnesValue() const
Determine if all bits are set.
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
The instances of the Type class are immutable: once they are created, they are never changed...
BinaryOp_match< LHS, RHS, Instruction::Or > m_Or(const LHS &L, const RHS &R)
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This is an important base class in LLVM.
Straight line strength reduction
This file contains the declarations for the subclasses of Constant, which represent the different fla...
static const unsigned UnknownAddressSpace
TypeSize getTypeAllocSize(Type *Ty) const
Returns the offset in bytes between successive objects of the specified type, including alignment padding.
Represent the analysis usage information of a pass.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
FunctionPass class - This class is used to implement most global optimizations.
Class to represent integer types.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is a trivially dead instruction, delete it.
Type * getIndexedType() const
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE, "Assign register bank of generic virtual registers", false, false) RegBankSelect
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
CastClass_match< OpTy, Instruction::SExt > m_SExt(const OpTy &Op)
Matches SExt.
Value * CreateGEP(Value *Ptr, ArrayRef< Value *> IdxList, const Twine &Name="")
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
unsigned getNumOperands() const
This is the shared class of boolean and integer constants.
Type * getType() const
Return the LLVM type of this SCEV expression.
Align max(MaybeAlign Lhs, Align Rhs)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
unsigned logBase2() const
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
Class for arbitrary precision integers.
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This class represents an analyzed expression in the program.
unsigned getIntegerBitWidth() const
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if LHS and RHS have no common bits set.
bool isZero() const
This is just a convenience method to make client code smaller for a common case.
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
iterator_range< df_iterator< T > > depth_first(const T &G)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
Legacy analysis pass which computes a DominatorTree.
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
gep_type_iterator gep_type_begin(const User *GEP)