88 using namespace PatternMatch;
90 #define DEBUG_TYPE "nary-reassociate"
101 bool doInitialization(
Module &M)
override {
104 bool runOnFunction(
Function &
F)
override;
125 "Nary reassociation",
false,
false)
135 return new NaryReassociateLegacyPass();
138 bool NaryReassociateLegacyPass::runOnFunction(
Function &
F) {
142 auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
143 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
144 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
145 auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
146 auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
148 return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
159 bool Changed =
runImpl(F, AC, DT, SE, TLI, TTI);
187 bool Changed =
false, ChangedInThisIteration;
189 ChangedInThisIteration = doOneIteration(F);
190 Changed |= ChangedInThisIteration;
191 }
while (ChangedInThisIteration);
199 case Instruction::GetElementPtr:
200 case Instruction::Mul:
207 bool NaryReassociatePass::doOneIteration(
Function &F) {
208 bool Changed =
false;
217 const SCEV *OldSCEV = SE->getSCEV(&*
I);
220 SE->forgetValue(&*
I);
221 I->replaceAllUsesWith(NewI);
225 I = NewI->getIterator();
229 const SCEV *NewSCEV = SE->getSCEV(&*
I);
230 SeenExprs[NewSCEV].push_back(
WeakVH(&*
I));
249 if (NewSCEV != OldSCEV)
250 SeenExprs[OldSCEV].push_back(
WeakVH(&*
I));
260 case Instruction::Mul:
261 return tryReassociateBinaryOp(cast<BinaryOperator>(I));
262 case Instruction::GetElementPtr:
263 return tryReassociateGEP(cast<GetElementPtrInst>(I));
265 llvm_unreachable(
"should be filtered out by isPotentiallyNaryReassociable");
286 if (
auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1,
295 bool NaryReassociatePass::requiresSignExtension(
Value *Index,
297 unsigned PointerSizeInBits =
304 unsigned I,
Type *IndexedType) {
306 if (
SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
307 IndexToSplit = SExt->getOperand(0);
308 }
else if (
ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
311 IndexToSplit = ZExt->getOperand(0);
314 if (
AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
318 if (requiresSignExtension(IndexToSplit, GEP) &&
323 Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
325 if (
auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType))
330 tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType))
339 unsigned I,
Value *LHS,
345 IndexExprs.
push_back(SE->getSCEV(*Index));
347 IndexExprs[
I] = SE->getSCEV(LHS);
358 const SCEV *CandidateExpr = SE->getGEPExpr(cast<GEPOperator>(GEP),
361 Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
362 if (Candidate ==
nullptr)
369 Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->
getType());
373 uint64_t IndexedSize =
DL->getTypeAllocSize(IndexedType);
375 uint64_t ElementSize =
DL->getTypeAllocSize(ElementType);
390 if (IndexedSize % ElementSize != 0)
395 if (RHS->
getType() != IntPtrTy)
396 RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
397 if (IndexedSize != ElementSize) {
398 RHS = Builder.CreateMul(
402 cast<GetElementPtrInst>(Builder.CreateGEP(Candidate, RHS));
410 if (
auto *NewI = tryReassociateBinaryOp(LHS, RHS, I))
412 if (
auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))
419 Value *
A =
nullptr, *
B =
nullptr;
422 if (LHS->
hasOneUse() && matchTernaryOp(I, LHS, A,
B)) {
425 const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(
B);
426 const SCEV *RHSExpr = SE->getSCEV(RHS);
427 if (BExpr != RHSExpr) {
429 tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr),
B, I))
432 if (AExpr != RHSExpr) {
434 tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
441 Instruction *NaryReassociatePass::tryReassociatedBinaryOp(
const SCEV *LHSExpr,
446 auto *LHS = findClosestMatchingDominator(LHSExpr, I);
455 case Instruction::Mul:
470 case Instruction::Mul:
483 return SE->getAddExpr(LHS, RHS);
484 case Instruction::Mul:
485 return SE->getMulExpr(LHS, RHS);
493 NaryReassociatePass::findClosestMatchingDominator(
const SCEV *CandidateExpr,
495 auto Pos = SeenExprs.find(CandidateExpr);
496 if (Pos == SeenExprs.end())
499 auto &Candidates = Pos->second;
504 while (!Candidates.empty()) {
507 if (
Value *Candidate = Candidates.back()) {
508 Instruction *CandidateInstruction = cast<Instruction>(Candidate);
509 if (DT->dominates(CandidateInstruction, Dominatee))
510 return CandidateInstruction;
512 Candidates.pop_back();
static unsigned getBitWidth(Type *Ty, const DataLayout &DL)
Returns the bitwidth of the given scalar or pointer type (if unknown returns 0).
void push_back(const T &Elt)
Type * getIndexedType() const
Value * getPointerOperand()
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Type * getSourceElementType() const
A Module instance is used to store all the information related to an LLVM module. ...
This class represents zero extension of integer types.
unsigned getNumOperands() const
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
The main scalar evolution driver.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of .assume calls within a function.
Analysis pass providing the TargetTransformInfo.
Analysis pass which computes a DominatorTree.
This class represents a sign extension of integer types.
bool isSequential() const
iterator begin()
Instruction iterator methods.
bool match(Val *V, const Pattern &P)
AnalysisUsage & addRequired()
#define INITIALIZE_PASS_DEPENDENCY(depName)
void setIsInBounds(bool b=true)
Set or clear the inbounds flag on this GEP instruction.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
BinaryOp_match< LHS, RHS, Instruction::Add > m_Add(const LHS &L, const RHS &R)
Value handle that is nullable, but tries to track the Value.
static GCRegistry::Add< OcamlGC > B("ocaml","ocaml 3.10-compatible GC")
void takeName(Value *V)
Transfer the name from V to this value.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
static BinaryOperator * CreateAdd(Value *S1, Value *S2, const Twine &Name, Instruction *InsertBefore, Value *FlagsOp)
bool isInBounds() const
Determine whether the GEP has the inbounds flag.
static GCRegistry::Add< CoreCLRGC > E("coreclr","CoreCLR-compatible GC")
an instruction for type-safe pointer arithmetic to access elements of arrays and structs ...
OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr)
A set of analyses that are preserved following a run of a transformation pass.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs...ExtraArgs)
Get the result of an analysis pass for a given IR unit.
LLVM Basic Block Representation.
The instances of the Type class are immutable: once they are created, they are never changed...
bool isKnownNonNegative(const Value *V, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr)
Returns true if the give value is known to be non-negative.
Represent the analysis usage information of a pass.
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,"Assign register bank of generic virtual registers", false, false) RegBankSelect
FunctionPass class - This class is used to implement most global optimizations.
Value * getOperand(unsigned i) const
bool RecursivelyDeleteTriviallyDeadInstructions(Value *V, const TargetLibraryInfo *TLI=nullptr)
If the specified value is a trivially dead instruction, delete it.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
A function analysis which provides an AssumptionCache.
BinaryOps getOpcode() const
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass,"nary-reassociate","Nary reassociation", false, false) INITIALIZE_PASS_END(NaryReassociateLegacyPass
Provides information about what library functions are available for the current target.
bool runImpl(Function &F, AssumptionCache *AC_, DominatorTree *DT_, ScalarEvolution *SE_, TargetLibraryInfo *TLI_, TargetTransformInfo *TTI_)
void invalidate(IRUnitT &IR)
Invalidate a specific analysis pass for an IR module.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
void setPreservesCFG()
This function should be called by the pass, iff they do not:
Analysis pass that exposes the ScalarEvolution for a function.
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
This class represents an analyzed expression in the program.
bool hasOneUse() const
Return true if there is exactly one user of this value.
void initializeNaryReassociateLegacyPassPass(PassRegistry &)
void preserve()
Mark an analysis as preserved.
Analysis pass providing the TargetLibraryInfo.
iterator_range< df_iterator< T > > depth_first(const T &G)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
Module * getParent()
Get the module that this global value is contained inside of...
LLVM Value Representation.
unsigned getOpcode() const
Returns a member of one of the enums like Instruction::Add.
static bool isPotentiallyNaryReassociable(Instruction *I)
static bool isGEPFoldable(GetElementPtrInst *GEP, const TargetTransformInfo *TTI)
A container for analyses that lazily runs them and caches their results.
Legacy analysis pass which computes a DominatorTree.
static BinaryOperator * CreateMul(Value *S1, Value *S2, const Twine &Name, Instruction *InsertBefore, Value *FlagsOp)
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static GCRegistry::Add< ErlangGC > A("erlang","erlang-compatible garbage collector")
Type * getResultElementType() const
FunctionPass * createNaryReassociatePass()
gep_type_iterator gep_type_begin(const User *GEP)