91using namespace PatternMatch;
94 std::numeric_limits<unsigned>::max();
98class StraightLineStrengthReduceLegacyPass :
public FunctionPass {
118 DL = &
M.getDataLayout();
125class StraightLineStrengthReduce {
141 Candidate() =
default;
155 Value *Stride =
nullptr;
175 Candidate *Basis =
nullptr;
183 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
191 bool isSimplestForm(
const Candidate &
C);
198 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
202 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
205 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
209 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
223 void allocateCandidatesAndFindBasis(Candidate::Kind CT,
const SCEV *
B,
228 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
237 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
244 std::list<Candidate> Candidates;
249 std::vector<Instruction *> UnlinkedInstructions;
254char StraightLineStrengthReduceLegacyPass::ID = 0;
257 "Straight line strength reduction",
false,
false)
265 return new StraightLineStrengthReduceLegacyPass();
268bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
269 const Candidate &
C) {
270 return (Basis.Ins !=
C.Ins &&
273 Basis.Ins->getType() ==
C.Ins->getType() &&
275 DT->dominates(Basis.Ins->getParent(),
C.Ins->getParent()) &&
277 Basis.Base ==
C.Base && Basis.Stride ==
C.Stride &&
278 Basis.CandidateKind ==
C.CandidateKind);
292 return Index->getBitWidth() <= 64 &&
297bool StraightLineStrengthReduce::isFoldable(
const Candidate &
C,
300 if (
C.CandidateKind == Candidate::Add)
302 if (
C.CandidateKind == Candidate::GEP)
309 unsigned NumNonZeroIndices = 0;
312 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
315 return NumNonZeroIndices <= 1;
318bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &
C) {
319 if (
C.CandidateKind == Candidate::Add) {
321 return C.Index->isOne() ||
C.Index->isMinusOne();
323 if (
C.CandidateKind == Candidate::Mul) {
325 return C.Index->isZero();
327 if (
C.CandidateKind == Candidate::GEP) {
329 return ((
C.Index->isOne() ||
C.Index->isMinusOne()) &&
342void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
345 Candidate
C(CT,
B,
Idx, S,
I);
359 if (!isFoldable(
C,
TTI,
DL) && !isSimplestForm(
C)) {
361 unsigned NumIterations = 0;
363 static const unsigned MaxNumIterations = 50;
364 for (
auto Basis = Candidates.rbegin();
365 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
366 ++Basis, ++NumIterations) {
367 if (isBasisFor(*Basis,
C)) {
375 Candidates.push_back(
C);
378void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
380 switch (
I->getOpcode()) {
381 case Instruction::Add:
382 allocateCandidatesAndFindBasisForAdd(
I);
384 case Instruction::Mul:
385 allocateCandidatesAndFindBasisForMul(
I);
387 case Instruction::GetElementPtr:
388 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(
I));
393void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
396 if (!isa<IntegerType>(
I->getType()))
399 assert(
I->getNumOperands() == 2 &&
"isn't I an add?");
401 allocateCandidatesAndFindBasisForAdd(LHS, RHS,
I);
403 allocateCandidatesAndFindBasisForAdd(RHS, LHS,
I);
406void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
412 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
416 Idx = ConstantInt::get(
Idx->getContext(), One <<
Idx->getValue());
417 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
420 ConstantInt *One = ConstantInt::get(cast<IntegerType>(
I->getType()), 1);
421 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
436void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
443 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
449 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
453 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
458void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
462 if (!isa<IntegerType>(
I->getType()))
465 assert(
I->getNumOperands() == 2 &&
"isn't I a mul?");
467 allocateCandidatesAndFindBasisForMul(LHS, RHS,
I);
470 allocateCandidatesAndFindBasisForMul(RHS, LHS,
I);
474void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
481 IntegerType *PtrIdxTy = cast<IntegerType>(
DL->getIndexType(
I->getType()));
483 PtrIdxTy,
Idx->getSExtValue() * (int64_t)ElementSize,
true);
484 allocateCandidatesAndFindBasis(Candidate::GEP,
B, ScaledIdx, S,
I);
487void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
492 allocateCandidatesAndFindBasisForGEP(
493 Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->
getType()), 1),
494 ArrayIdx, ElementSize,
GEP);
511 allocateCandidatesAndFindBasisForGEP(
Base, RHS, LHS, ElementSize,
GEP);
518 allocateCandidatesAndFindBasisForGEP(
Base, PowerOf2, LHS, ElementSize,
GEP);
522void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
525 if (
GEP->getType()->isVectorTy())
533 for (
unsigned I = 1, E =
GEP->getNumOperands();
I != E; ++
I, ++GTI) {
537 const SCEV *OrigIndexExpr = IndexExprs[
I - 1];
538 IndexExprs[
I - 1] = SE->getZero(OrigIndexExpr->
getType());
542 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(
GEP), IndexExprs);
546 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
549 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize,
GEP);
554 Value *TruncatedArrayIdx =
nullptr;
557 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
560 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize,
GEP);
563 IndexExprs[
I - 1] = OrigIndexExpr;
569 if (
A.getBitWidth() <
B.getBitWidth())
570 A =
A.sext(
B.getBitWidth());
571 else if (
A.getBitWidth() >
B.getBitWidth())
572 B =
B.sext(
A.getBitWidth());
575Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
579 APInt Idx =
C.Index->getValue(), BasisIdx = Basis.Index->getValue();
585 if (IndexOffset == 1)
604 ConstantInt::get(DeltaType, (-IndexOffset).logBase2());
607 Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
608 return Builder.
CreateMul(ExtendedStride, Delta);
611void StraightLineStrengthReduce::rewriteCandidateWithBasis(
612 const Candidate &
C,
const Candidate &Basis) {
613 assert(
C.CandidateKind == Basis.CandidateKind &&
C.Base == Basis.Base &&
614 C.Stride == Basis.Stride);
617 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
623 if (!
C.Ins->getParent())
627 Value *Bump = emitBump(Basis,
C, Builder,
DL);
628 Value *Reduced =
nullptr;
629 switch (
C.CandidateKind) {
631 case Candidate::Mul: {
636 Reduced = Builder.
CreateSub(Basis.Ins, NegBump);
650 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
654 case Candidate::GEP: {
655 bool InBounds = cast<GetElementPtrInst>(
C.Ins)->isInBounds();
657 Reduced = Builder.
CreatePtrAdd(Basis.Ins, Bump,
"", InBounds);
664 C.Ins->replaceAllUsesWith(Reduced);
667 C.Ins->removeFromParent();
668 UnlinkedInstructions.push_back(
C.Ins);
671bool StraightLineStrengthReduceLegacyPass::runOnFunction(
Function &
F) {
675 auto *
TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
676 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
677 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
678 return StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F);
681bool StraightLineStrengthReduce::runOnFunction(
Function &
F) {
685 for (
auto &
I : *(
Node->getBlock()))
686 allocateCandidatesAndFindBasis(&
I);
690 while (!Candidates.empty()) {
691 const Candidate &
C = Candidates.back();
692 if (
C.Basis !=
nullptr) {
693 rewriteCandidateWithBasis(
C, *
C.Basis);
695 Candidates.pop_back();
699 for (
auto *UnlinkedInst : UnlinkedInstructions) {
700 for (
unsigned I = 0, E = UnlinkedInst->getNumOperands();
I != E; ++
I) {
701 Value *
Op = UnlinkedInst->getOperand(
I);
702 UnlinkedInst->setOperand(
I,
nullptr);
705 UnlinkedInst->deleteValue();
707 bool Ret = !UnlinkedInstructions.empty();
708 UnlinkedInstructions.clear();
721 if (!StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F))
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
static void unifyBitWidth(APInt &A, APInt &B)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
static const unsigned UnknownAddressSpace
Straight line strength reduction
Class for arbitrary precision integers.
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
unsigned getBitWidth() const
Return the number of bits in the APInt.
unsigned logBase2() const
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
This is an important base class in LLVM.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreatePtrAdd(Value *Ptr, Value *Offset, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNSW=false)
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Class to represent integer types.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
virtual bool doInitialization(Module &)
doInitialization - Virtual method overridden by subclasses to do any necessary initialization before ...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserveSet()
Mark an analysis set as preserved.
void preserve()
Mark an analysis as preserved.
This class represents an analyzed expression in the program.
Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
unsigned getIntegerBitWidth() const
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVMContext & getContext() const
All values hold a context through their type.
void takeName(Value *V)
Transfer the name from V to this value.
TypeSize getSequentialElementStride(const DataLayout &DL) const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
gep_type_iterator gep_type_begin(const User *GEP)
iterator_range< df_iterator< T > > depth_first(const T &G)
FunctionPass * createStraightLineStrengthReducePass()