92using namespace PatternMatch;
95 std::numeric_limits<unsigned>::max();
98 "Controls whether rewriteCandidateWithBasis is executed.");
102class StraightLineStrengthReduceLegacyPass :
public FunctionPass {
122 DL = &
M.getDataLayout();
129class StraightLineStrengthReduce {
145 Candidate() =
default;
159 Value *Stride =
nullptr;
179 Candidate *Basis =
nullptr;
187 bool isBasisFor(
const Candidate &Basis,
const Candidate &
C);
195 bool isSimplestForm(
const Candidate &
C);
202 void allocateCandidatesAndFindBasisForAdd(
Instruction *
I);
206 void allocateCandidatesAndFindBasisForAdd(
Value *LHS,
Value *RHS,
209 void allocateCandidatesAndFindBasisForMul(
Instruction *
I);
213 void allocateCandidatesAndFindBasisForMul(
Value *LHS,
Value *RHS,
227 void allocateCandidatesAndFindBasis(Candidate::Kind CT,
const SCEV *
B,
232 void rewriteCandidateWithBasis(
const Candidate &
C,
const Candidate &Basis);
241 static Value *emitBump(
const Candidate &Basis,
const Candidate &
C,
248 std::list<Candidate> Candidates;
253 std::vector<Instruction *> UnlinkedInstructions;
258char StraightLineStrengthReduceLegacyPass::ID = 0;
261 "Straight line strength reduction",
false,
false)
269 return new StraightLineStrengthReduceLegacyPass();
272bool StraightLineStrengthReduce::isBasisFor(
const Candidate &Basis,
273 const Candidate &
C) {
274 return (Basis.Ins !=
C.Ins &&
277 Basis.Ins->getType() ==
C.Ins->getType() &&
279 DT->
dominates(Basis.Ins->getParent(),
C.Ins->getParent()) &&
281 Basis.Base ==
C.Base && Basis.Stride ==
C.Stride &&
282 Basis.CandidateKind ==
C.CandidateKind);
296 return Index->getBitWidth() <= 64 &&
301bool StraightLineStrengthReduce::isFoldable(
const Candidate &
C,
304 if (
C.CandidateKind == Candidate::Add)
306 if (
C.CandidateKind == Candidate::GEP)
313 unsigned NumNonZeroIndices = 0;
316 if (ConstIdx ==
nullptr || !ConstIdx->
isZero())
319 return NumNonZeroIndices <= 1;
322bool StraightLineStrengthReduce::isSimplestForm(
const Candidate &
C) {
323 if (
C.CandidateKind == Candidate::Add) {
325 return C.Index->isOne() ||
C.Index->isMinusOne();
327 if (
C.CandidateKind == Candidate::Mul) {
329 return C.Index->isZero();
331 if (
C.CandidateKind == Candidate::GEP) {
333 return ((
C.Index->isOne() ||
C.Index->isMinusOne()) &&
346void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
349 Candidate
C(CT,
B,
Idx, S,
I);
363 if (!isFoldable(
C,
TTI,
DL) && !isSimplestForm(
C)) {
365 unsigned NumIterations = 0;
367 static const unsigned MaxNumIterations = 50;
368 for (
auto Basis = Candidates.rbegin();
369 Basis != Candidates.rend() && NumIterations < MaxNumIterations;
370 ++Basis, ++NumIterations) {
371 if (isBasisFor(*Basis,
C)) {
379 Candidates.push_back(
C);
382void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
384 switch (
I->getOpcode()) {
385 case Instruction::Add:
386 allocateCandidatesAndFindBasisForAdd(
I);
388 case Instruction::Mul:
389 allocateCandidatesAndFindBasisForMul(
I);
391 case Instruction::GetElementPtr:
392 allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(
I));
397void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
400 if (!isa<IntegerType>(
I->getType()))
403 assert(
I->getNumOperands() == 2 &&
"isn't I an add?");
405 allocateCandidatesAndFindBasisForAdd(LHS, RHS,
I);
407 allocateCandidatesAndFindBasisForAdd(RHS, LHS,
I);
410void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
416 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
420 Idx = ConstantInt::get(
Idx->getContext(), One <<
Idx->getValue());
421 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS),
Idx, S,
I);
424 ConstantInt *One = ConstantInt::get(cast<IntegerType>(
I->getType()), 1);
425 allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
440void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
447 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
453 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(
B),
Idx, RHS,
I);
457 allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
462void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
466 if (!isa<IntegerType>(
I->getType()))
469 assert(
I->getNumOperands() == 2 &&
"isn't I a mul?");
471 allocateCandidatesAndFindBasisForMul(LHS, RHS,
I);
474 allocateCandidatesAndFindBasisForMul(RHS, LHS,
I);
478void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
485 IntegerType *PtrIdxTy = cast<IntegerType>(
DL->getIndexType(
I->getType()));
487 PtrIdxTy,
Idx->getSExtValue() * (int64_t)ElementSize,
true);
488 allocateCandidatesAndFindBasis(Candidate::GEP,
B, ScaledIdx, S,
I);
491void StraightLineStrengthReduce::factorArrayIndex(
Value *ArrayIdx,
496 allocateCandidatesAndFindBasisForGEP(
497 Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->
getType()), 1),
498 ArrayIdx, ElementSize,
GEP);
515 allocateCandidatesAndFindBasisForGEP(
Base, RHS, LHS, ElementSize,
GEP);
522 allocateCandidatesAndFindBasisForGEP(
Base, PowerOf2, LHS, ElementSize,
GEP);
526void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
529 if (
GEP->getType()->isVectorTy())
537 for (
unsigned I = 1, E =
GEP->getNumOperands();
I != E; ++
I, ++GTI) {
541 const SCEV *OrigIndexExpr = IndexExprs[
I - 1];
542 IndexExprs[
I - 1] = SE->getZero(OrigIndexExpr->
getType());
546 const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(
GEP), IndexExprs);
550 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
553 factorArrayIndex(ArrayIdx, BaseExpr, ElementSize,
GEP);
558 Value *TruncatedArrayIdx =
nullptr;
561 DL->getIndexSizeInBits(
GEP->getAddressSpace())) {
564 factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize,
GEP);
567 IndexExprs[
I - 1] = OrigIndexExpr;
573 if (
A.getBitWidth() <
B.getBitWidth())
574 A =
A.sext(
B.getBitWidth());
575 else if (
A.getBitWidth() >
B.getBitWidth())
576 B =
B.sext(
A.getBitWidth());
579Value *StraightLineStrengthReduce::emitBump(
const Candidate &Basis,
583 APInt Idx =
C.Index->getValue(), BasisIdx = Basis.Index->getValue();
589 if (IndexOffset == 1)
608 ConstantInt::get(DeltaType, (-IndexOffset).logBase2());
611 Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
612 return Builder.
CreateMul(ExtendedStride, Delta);
615void StraightLineStrengthReduce::rewriteCandidateWithBasis(
616 const Candidate &
C,
const Candidate &Basis) {
620 assert(
C.CandidateKind == Basis.CandidateKind &&
C.Base == Basis.Base &&
621 C.Stride == Basis.Stride);
624 assert(Basis.Ins->getParent() !=
nullptr &&
"the basis is unlinked");
630 if (!
C.Ins->getParent())
634 Value *Bump = emitBump(Basis,
C, Builder,
DL);
635 Value *Reduced =
nullptr;
636 switch (
C.CandidateKind) {
638 case Candidate::Mul: {
643 Reduced = Builder.
CreateSub(Basis.Ins, NegBump);
657 Reduced = Builder.
CreateAdd(Basis.Ins, Bump);
661 case Candidate::GEP: {
662 bool InBounds = cast<GetElementPtrInst>(
C.Ins)->isInBounds();
664 Reduced = Builder.
CreatePtrAdd(Basis.Ins, Bump,
"", InBounds);
671 C.Ins->replaceAllUsesWith(Reduced);
674 C.Ins->removeFromParent();
675 UnlinkedInstructions.push_back(
C.Ins);
678bool StraightLineStrengthReduceLegacyPass::runOnFunction(
Function &
F) {
682 auto *
TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
683 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
684 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
685 return StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F);
688bool StraightLineStrengthReduce::runOnFunction(
Function &
F) {
692 for (
auto &
I : *(
Node->getBlock()))
693 allocateCandidatesAndFindBasis(&
I);
697 while (!Candidates.empty()) {
698 const Candidate &
C = Candidates.back();
699 if (
C.Basis !=
nullptr) {
700 rewriteCandidateWithBasis(
C, *
C.Basis);
702 Candidates.pop_back();
706 for (
auto *UnlinkedInst : UnlinkedInstructions) {
707 for (
unsigned I = 0, E = UnlinkedInst->getNumOperands();
I != E; ++
I) {
708 Value *
Op = UnlinkedInst->getOperand(
I);
709 UnlinkedInst->setOperand(
I,
nullptr);
712 UnlinkedInst->deleteValue();
714 bool Ret = !UnlinkedInstructions.empty();
715 UnlinkedInstructions.clear();
728 if (!StraightLineStrengthReduce(
DL, DT, SE,
TTI).runOnFunction(
F))
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx. Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
This file provides an implementation of debug counters.
#define DEBUG_COUNTER(VARNAME, COUNTERNAME, DESC)
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool runOnFunction(Function &F, bool PostInlining)
Module.h This file contains the declarations for the Module class.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
static bool matchesOr(Value *A, Value *&B, ConstantInt *&C)
static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, TargetTransformInfo *TTI)
static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP)
static void unifyBitWidth(APInt &A, APInt &B)
static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C)
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
static const unsigned UnknownAddressSpace
Straight line strength reduction
Class for arbitrary precision integers.
bool isNegatedPowerOf2() const
Check if this APInt's negated value is a power of two greater than zero.
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
unsigned getBitWidth() const
Return the number of bits in the APInt.
unsigned logBase2() const
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Represents analyses that only rely on functions' control flow.
This is the shared class of boolean and integer constants.
bool isZero() const
This is just a convenience method to make client code smaller for a common case.
This is an important base class in LLVM.
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
static bool shouldExecute(unsigned CounterName)
Analysis pass which computes a DominatorTree.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Value * CreatePtrAdd(Value *Ptr, Value *Offset, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
Value * CreateNeg(Value *V, const Twine &Name="", bool HasNSW=false)
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateShl(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateSExtOrTrunc(Value *V, Type *DestTy, const Twine &Name="")
Create a SExt or Trunc from the integer value V to DestTy.
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Class to represent integer types.
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
A Module instance is used to store all the information related to an LLVM module.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
virtual bool doInitialization(Module &)
doInitialization - Virtual method overridden by subclasses to do any necessary initialization before ...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserveSet()
Mark an analysis set as preserved.
void preserve()
Mark an analysis as preserved.
This class represents an analyzed expression in the program.
Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Analysis pass providing the TargetTransformInfo.
unsigned getIntegerBitWidth() const
A Use represents the edge between a Value definition and its users.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVMContext & getContext() const
All values hold a context through their type.
void takeName(Value *V)
Transfer the name from V to this value.
TypeSize getSequentialElementStride(const DataLayout &DL) const
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
BinaryOp_match< cst_pred_ty< is_zero_int >, ValTy, Instruction::Sub > m_Neg(const ValTy &V)
Matches a 'Neg' as 'sub 0, V'.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Shl, OverflowingBinaryOperator::NoSignedWrap > m_NSWShl(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches an Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
CastInst_match< OpTy, SExtInst > m_SExt(const OpTy &Op)
Matches SExt.
BinaryOp_match< LHS, RHS, Instruction::Or, true > m_c_Or(const LHS &L, const RHS &R)
Matches an Or with LHS and RHS in either order.
OverflowingBinaryOp_match< LHS, RHS, Instruction::Mul, OverflowingBinaryOperator::NoSignedWrap > m_NSWMul(const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
bool haveNoCommonBitsSet(const WithCache< const Value * > &LHSCache, const WithCache< const Value * > &RHSCache, const SimplifyQuery &SQ)
Return true if LHS and RHS have no common bits set.
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
If the specified value is a trivially dead instruction, delete it.
void initializeStraightLineStrengthReduceLegacyPassPass(PassRegistry &)
gep_type_iterator gep_type_begin(const User *GEP)
iterator_range< df_iterator< T > > depth_first(const T &G)
FunctionPass * createStraightLineStrengthReducePass()