60#include <forward_list>
66#define LLE_OPTION "loop-load-elim"
67#define DEBUG_TYPE LLE_OPTION
70 "runtime-check-per-loop-load-elim",
cl::Hidden,
71 cl::desc(
"Max number of memchecks allowed per eliminated load on average"),
76 cl::desc(
"The maximum number of SCEV checks allowed for Loop "
79STATISTIC(NumLoopLoadEliminted,
"Number of loads eliminated by LLE");
84struct StoreToLoadForwardingCandidate {
89 : Load(Load), Store(Store) {}
96 Value *LoadPtr =
Load->getPointerOperand();
99 auto &
DL =
Load->getDataLayout();
103 DL.getTypeSizeInBits(LoadType) ==
105 "Should be a known dependence");
107 int64_t StrideLoad =
getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0);
108 int64_t StrideStore =
getPtrStride(PSE, LoadType, StorePtr, L).value_or(0);
109 if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
119 if (std::abs(StrideLoad) != 1)
122 unsigned TypeByteSize =
DL.getTypeAllocSize(
const_cast<Type *
>(LoadType));
124 auto *LoadPtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(LoadPtr));
125 auto *StorePtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(StorePtr));
129 auto *Dist = dyn_cast<SCEVConstant>(
133 const APInt &Val = Dist->getAPInt();
134 return Val == TypeByteSize * StrideLoad;
137 Value *getLoadPtr()
const {
return Load->getPointerOperand(); }
141 const StoreToLoadForwardingCandidate &Cand) {
142 OS << *Cand.Store <<
" -->\n";
156 L->getLoopLatches(Latches);
164 return Load->getParent() != L->getHeader();
170class LoadEliminationForLoop {
175 :
L(
L), LI(LI), LAI(LAI), DT(DT),
BFI(
BFI), PSI(PSI), PSE(LAI.getPSE()) {}
182 std::forward_list<StoreToLoadForwardingCandidate>
184 std::forward_list<StoreToLoadForwardingCandidate> Candidates;
197 for (
const auto &Dep : *Deps) {
199 Instruction *Destination = Dep.getDestination(DepChecker);
203 if (isa<LoadInst>(Source))
204 LoadsWithUnknownDependence.
insert(Source);
205 if (isa<LoadInst>(Destination))
206 LoadsWithUnknownDependence.
insert(Destination);
210 if (Dep.isBackward())
216 assert(Dep.isForward() &&
"Needs to be a forward dependence");
218 auto *
Store = dyn_cast<StoreInst>(Source);
221 auto *
Load = dyn_cast<LoadInst>(Destination);
228 Store->getDataLayout()))
231 Candidates.emplace_front(Load, Store);
234 if (!LoadsWithUnknownDependence.
empty())
235 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &
C) {
236 return LoadsWithUnknownDependence.
count(
C.Load);
244 auto I = InstOrder.find(Inst);
245 assert(
I != InstOrder.end() &&
"No index for instruction");
268 void removeDependencesFromMultipleStores(
269 std::forward_list<StoreToLoadForwardingCandidate> &Candidates) {
272 using LoadToSingleCandT =
274 LoadToSingleCandT LoadToSingleCand;
276 for (
const auto &Cand : Candidates) {
278 LoadToSingleCandT::iterator Iter;
280 std::tie(Iter, NewElt) =
281 LoadToSingleCand.
insert(std::make_pair(Cand.Load, &Cand));
283 const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
285 if (OtherCand ==
nullptr)
291 if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
292 Cand.isDependenceDistanceOfOne(PSE, L) &&
293 OtherCand->isDependenceDistanceOfOne(PSE, L)) {
295 if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
302 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &Cand) {
303 if (LoadToSingleCand[Cand.Load] != &Cand) {
305 dbgs() <<
"Removing from candidates: \n"
307 <<
" The load may have multiple stores forwarding to "
320 bool needsChecking(
unsigned PtrIdx1,
unsigned PtrIdx2,
327 return ((PtrsWrittenOnFwdingPath.
count(Ptr1) && CandLoadPtrs.
count(Ptr2)) ||
328 (PtrsWrittenOnFwdingPath.
count(Ptr2) && CandLoadPtrs.
count(Ptr1)));
356 [&](
const StoreToLoadForwardingCandidate &
A,
357 const StoreToLoadForwardingCandidate &
B) {
358 return getInstrIndex(
A.Load) <
359 getInstrIndex(
B.Load);
364 [&](
const StoreToLoadForwardingCandidate &
A,
365 const StoreToLoadForwardingCandidate &
B) {
366 return getInstrIndex(
A.Store) <
367 getInstrIndex(
B.Store);
377 if (
auto *S = dyn_cast<StoreInst>(
I))
378 PtrsWrittenOnFwdingPath.
insert(S->getPointerOperand());
381 std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,
382 MemInstrs.end(), InsertStorePtr);
383 std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],
386 return PtrsWrittenOnFwdingPath;
395 findPointersWrittenOnForwardingPath(Candidates);
399 for (
const auto &Candidate : Candidates)
400 CandLoadPtrs.
insert(Candidate.getLoadPtr());
405 copy_if(AllChecks, std::back_inserter(Checks),
407 for (
auto PtrIdx1 :
Check.first->Members)
408 for (
auto PtrIdx2 :
Check.second->Members)
409 if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
424 propagateStoredValueToLoadUsers(
const StoreToLoadForwardingCandidate &Cand,
441 Value *
Ptr = Cand.Load->getPointerOperand();
442 auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(
Ptr));
443 auto *PH =
L->getLoopPreheader();
444 assert(PH &&
"Preheader should exist!");
445 Value *InitialPtr =
SEE.expandCodeFor(PtrSCEV->getStart(),
Ptr->getType(),
446 PH->getTerminator());
448 new LoadInst(Cand.Load->getType(), InitialPtr,
"load_initial",
449 false, Cand.Load->getAlign(),
450 PH->getTerminator()->getIterator());
457 PHI->insertBefore(
L->getHeader()->begin());
458 PHI->addIncoming(Initial, PH);
461 Type *StoreType = Cand.Store->getValueOperand()->getType();
462 auto &
DL = Cand.Load->getDataLayout();
465 assert(
DL.getTypeSizeInBits(LoadType) ==
DL.getTypeSizeInBits(StoreType) &&
466 "The type sizes should match!");
468 Value *StoreValue = Cand.Store->getValueOperand();
469 if (LoadType != StoreType) {
471 "store_forward_cast",
472 Cand.Store->getIterator());
476 cast<Instruction>(StoreValue)->setDebugLoc(Cand.Load->getDebugLoc());
479 PHI->addIncoming(StoreValue,
L->getLoopLatch());
481 Cand.Load->replaceAllUsesWith(
PHI);
482 PHI->setDebugLoc(Cand.Load->getDebugLoc());
488 LLVM_DEBUG(
dbgs() <<
"\nIn \"" <<
L->getHeader()->getParent()->getName()
489 <<
"\" checking " << *L <<
"\n");
510 auto StoreToLoadDependences = findStoreToLoadDependences(LAI);
511 if (StoreToLoadDependences.empty())
520 removeDependencesFromMultipleStores(StoreToLoadDependences);
521 if (StoreToLoadDependences.empty())
526 for (
const StoreToLoadForwardingCandidate &Cand : StoreToLoadDependences) {
542 if (!Cand.isDependenceDistanceOfOne(PSE, L))
545 assert(isa<SCEVAddRecExpr>(PSE.
getSCEV(Cand.Load->getPointerOperand())) &&
546 "Loading from something other than indvar?");
548 isa<SCEVAddRecExpr>(PSE.
getSCEV(Cand.Store->getPointerOperand())) &&
549 "Storing to something other than indvar?");
555 <<
". Valid store-to-load forwarding across the loop backedge\n");
557 if (Candidates.
empty())
576 if (!
L->isLoopSimplifyForm()) {
584 "convergent calls\n");
588 auto *HeaderBB =
L->getHeader();
589 auto *
F = HeaderBB->getParent();
590 bool OptForSize =
F->hasOptSize() ||
592 PGSOQueryType::IRPass);
595 dbgs() <<
"Versioning is needed but not allowed when optimizing "
608 auto NoLongerGoodCandidate = [
this](
609 const StoreToLoadForwardingCandidate &Cand) {
610 return !isa<SCEVAddRecExpr>(
611 PSE.
getSCEV(Cand.Load->getPointerOperand())) ||
612 !isa<SCEVAddRecExpr>(
613 PSE.
getSCEV(Cand.Store->getPointerOperand()));
622 for (
const auto &Cand : Candidates)
623 propagateStoredValueToLoadUsers(Cand,
SEE);
624 NumLoopLoadEliminted += Candidates.size();
660 bool Changed =
false;
662 for (
Loop *TopLevelLoop : LI)
664 Changed |=
simplifyLoop(L, &DT, &LI, SE, AC,
nullptr,
false);
666 if (L->isInnermost())
671 for (
Loop *L : Worklist) {
673 if (!L->isRotatedForm() || !L->getExitingBlock())
676 LoadEliminationForLoop LEL(L, &LI, LAIs.
getInfo(*L), &DT, BFI, PSI);
677 Changed |= LEL.processLoop();
696 auto *BFI = (PSI && PSI->hasProfileSummary()) ?
for(const MachineOperand &MO :llvm::drop_begin(OldMI.operands(), Desc.getNumOperands()))
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This is the interface for a simple mod/ref and alias analysis over globals.
This header provides classes for managing per-loop analyses.
static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, ScalarEvolution *SE, AssumptionCache *AC, LoopAccessInfoManager &LAIs)
static cl::opt< unsigned > LoadElimSCEVCheckThreshold("loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, cl::desc("The maximum number of SCEV checks allowed for Loop " "Load Elimination"))
static bool isLoadConditional(LoadInst *Load, Loop *L)
Return true if the load is not executed on all paths in the loop.
static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT)
Check if the store dominates all latches, so as long as there is no intervening store this value will...
static cl::opt< unsigned > CheckPerElim("runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1))
This header defines the LoopLoadEliminationPass object.
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Class for arbitrary precision integers.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
Analysis pass which computes BlockFrequencyInfo.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...
static bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
static CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPtr cast instruction.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
An instruction for reading from memory.
This analysis provides dependence information for the memory accesses of a loop.
const LoopAccessInfo & getInfo(Loop &L)
Drive the analysis of memory accesses in the loop.
const MemoryDepChecker & getDepChecker() const
the Memory Dependence Checker which can determine the loop-independent and loop-carried dependences b...
const RuntimePointerChecking * getRuntimePointerChecking() const
const PredicatedScalarEvolution & getPSE() const
Used to add runtime SCEV checks.
bool hasConvergentOp() const
Return true if there is a convergent operation in the loop.
Analysis pass that exposes the LoopInfo for a function.
This class emits a version of the loop where run-time checks ensure that may-alias pointers can't ove...
Represents a single loop in the control flow graph.
const SmallVectorImpl< Instruction * > & getMemoryInstructions() const
The vector of memory access instructions.
const SmallVectorImpl< Dependence > * getDependences() const
Returns the memory dependences.
DenseMap< Instruction *, unsigned > generateInstructionOrderMap() const
Generate a mapping between the memory instructions and their indices according to program order.
An analysis over an "inner" IR unit that provides access to an analysis manager over an "outer" IR uni...
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
ScalarEvolution * getSE() const
Returns the ScalarEvolution analysis used.
const SCEVPredicate & getPredicate() const
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
An analysis pass based on the new PM to deliver ProfileSummaryInfo.
Analysis providing profile information.
void printChecks(raw_ostream &OS, const SmallVectorImpl< RuntimePointerCheck > &Checks, unsigned Depth=0) const
Print Checks.
const SmallVectorImpl< RuntimePointerCheck > & getChecks() const
Returns the checks that generateChecks created.
const PointerInfo & getPointerInfo(unsigned PtrIdx) const
Return PointerInfo for pointer at index PtrIdx.
This class uses information about analyzed scalars to rewrite expressions in canonical form.
virtual unsigned getComplexity() const
Returns the estimated complexity of this predicate.
virtual bool isAlwaysTrue() const =0
Returns true if the predicate is always true.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or fewer elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
The instances of the Type class are immutable: once they are created, they are never changed.
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
This class implements an extremely fast bulk output stream that can only output to a stream.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
@ C
The default llvm calling convention, compatible with C.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
bool simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, MemorySSAUpdater *MSSAU, bool PreserveLCSSA)
Simplify each loop in a loop nest recursively.
auto min_element(R &&Range)
Provide wrappers to std::min_element which take ranges instead of having to pass begin/end explicitly...
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
std::pair< const RuntimeCheckingPtrGroup *, const RuntimeCheckingPtrGroup * > RuntimePointerCheck
A memcheck which is made up of a pair of grouped pointers.
bool shouldOptimizeForSize(const MachineFunction *MF, ProfileSummaryInfo *PSI, const MachineBlockFrequencyInfo *BFI, PGSOQueryType QueryType=PGSOQueryType::Other)
Returns true if machine function MF is suggested to be size-optimized based on the profile.
OutputIt copy_if(R &&Range, OutputIt Out, UnaryPredicate P)
Provide wrappers to std::copy_if which take ranges instead of having to pass begin/end explicitly.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
std::optional< int64_t > getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, const DenseMap< Value *, const SCEV * > &StridesMap=DenseMap< Value *, const SCEV * >(), bool Assume=false, bool ShouldCheckWrap=true)
If the pointer has a constant stride return it in units of the access type size.
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
void erase_if(Container &C, UnaryPredicate P)
Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...
Type * getLoadStoreType(const Value *I)
A helper function that returns the type of a load or store instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
TrackingVH< Value > PointerValue
Holds the pointer value that we need to check.