48 #include <forward_list>
// LLE_OPTION expands to the string "loop-load-elim"; DEBUG_TYPE is defined in
// terms of it, so both macros yield the same string.
55 #define LLE_OPTION "loop-load-elim"
56 #define DEBUG_TYPE LLE_OPTION
61 "runtime-check-per-loop-load-elim",
cl::Hidden,
62 cl::desc(
"Max number of memchecks allowed per eliminated load on average"),
67 cl::desc(
"The maximum number of SCEV checks allowed for Loop "
// Statistic described as "Number of loads eliminated by LLE".
// NOTE(review): "Eliminted" looks like a misspelling of "Eliminated"; a rename
// must also update the `NumLoopLoadEliminted += NumForwarding` use site.
70 STATISTIC(NumLoopLoadEliminted,
"Number of loads eliminated by LLE");
75 struct StoreToLoadForwardingCandidate {
80 : Load(Load), Store(Store) {}
86 Value *LoadPtr =
Load->getPointerOperand();
94 "Should be a known dependence");
103 auto &
DL =
Load->getParent()->getModule()->getDataLayout();
104 unsigned TypeByteSize =
DL.getTypeAllocSize(const_cast<Type *>(LoadType));
106 auto *LoadPtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(LoadPtr));
107 auto *StorePtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(StorePtr));
111 auto *Dist = cast<SCEVConstant>(
113 const APInt &Val = Dist->getAPInt();
114 return Val == TypeByteSize;
/// Returns the pointer operand of this candidate's load instruction.
117 Value *getLoadPtr()
const {
return Load->getPointerOperand(); }
121 const StoreToLoadForwardingCandidate &Cand) {
122 OS << *Cand.Store <<
" -->\n";
123 OS.
indent(2) << *Cand.Load <<
"\n";
131 bool doesStoreDominatesAllLatches(
BasicBlock *StoreBlock,
Loop *L,
146 class LoadEliminationForLoop {
150 : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.getPSE()) {}
157 std::forward_list<StoreToLoadForwardingCandidate>
159 std::forward_list<StoreToLoadForwardingCandidate> Candidates;
171 for (
const auto &Dep : *Deps) {
173 Instruction *Destination = Dep.getDestination(LAI);
176 if (isa<LoadInst>(Source))
177 LoadsWithUnknownDepedence.
insert(Source);
178 if (isa<LoadInst>(Destination))
179 LoadsWithUnknownDepedence.
insert(Destination);
183 if (Dep.isBackward())
189 assert(Dep.isForward() &&
"Needs to be a forward dependence");
199 if (Store->getPointerOperand()->getType() !=
203 Candidates.emplace_front(Load, Store);
206 if (!LoadsWithUnknownDepedence.
empty())
207 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &
C) {
208 return LoadsWithUnknownDepedence.
count(C.Load);
216 auto I = InstOrder.find(Inst);
217 assert(
I != InstOrder.end() &&
"No index for instruction");
240 void removeDependencesFromMultipleStores(
241 std::forward_list<StoreToLoadForwardingCandidate> &Candidates) {
246 LoadToSingleCandT LoadToSingleCand;
248 for (
const auto &Cand : Candidates) {
250 LoadToSingleCandT::iterator Iter;
252 std::tie(Iter, NewElt) =
253 LoadToSingleCand.
insert(std::make_pair(Cand.Load, &Cand));
255 const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
257 if (OtherCand ==
nullptr)
263 if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
264 Cand.isDependenceDistanceOfOne(PSE, L) &&
265 OtherCand->isDependenceDistanceOfOne(PSE, L)) {
267 if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
274 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &Cand) {
275 if (LoadToSingleCand[Cand.Load] != &Cand) {
276 DEBUG(
dbgs() <<
"Removing from candidates: \n" << Cand
277 <<
" The load may have multiple stores forwarding to "
290 bool needsChecking(
unsigned PtrIdx1,
unsigned PtrIdx2,
292 const std::set<Value *> &CandLoadPtrs) {
297 return ((PtrsWrittenOnFwdingPath.
count(Ptr1) && CandLoadPtrs.count(Ptr2)) ||
298 (PtrsWrittenOnFwdingPath.
count(Ptr2) && CandLoadPtrs.count(Ptr1)));
325 std::max_element(Candidates.
begin(), Candidates.
end(),
326 [&](
const StoreToLoadForwardingCandidate &
A,
327 const StoreToLoadForwardingCandidate &
B) {
328 return getInstrIndex(A.Load) < getInstrIndex(
B.Load);
332 std::min_element(Candidates.
begin(), Candidates.
end(),
333 [&](
const StoreToLoadForwardingCandidate &
A,
334 const StoreToLoadForwardingCandidate &
B) {
335 return getInstrIndex(A.Store) <
336 getInstrIndex(
B.Store);
346 if (
auto *S = dyn_cast<StoreInst>(
I))
347 PtrsWrittenOnFwdingPath.
insert(S->getPointerOperand());
350 std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,
351 MemInstrs.end(), InsertStorePtr);
352 std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],
355 return PtrsWrittenOnFwdingPath;
364 findPointersWrittenOnForwardingPath(Candidates);
368 std::set<Value *> CandLoadPtrs;
370 std::inserter(CandLoadPtrs, CandLoadPtrs.begin()),
371 std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr));
376 std::copy_if(AllChecks.begin(), AllChecks.end(), std::back_inserter(Checks),
378 for (
auto PtrIdx1 :
Check.first->Members)
379 for (
auto PtrIdx2 :
Check.second->Members)
380 if (needsChecking(PtrIdx1, PtrIdx2,
381 PtrsWrittenOnFwdingPath, CandLoadPtrs))
386 DEBUG(
dbgs() <<
"\nPointer Checks (count: " << Checks.
size() <<
"):\n");
394 propagateStoredValueToLoadUsers(
const StoreToLoadForwardingCandidate &Cand,
412 Value *
Ptr = Cand.Load->getPointerOperand();
413 auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.
getSCEV(Ptr));
416 PH->getTerminator());
418 new LoadInst(InitialPtr,
"load_initial",
false,
419 Cand.Load->getAlignment(), PH->getTerminator());
424 PHI->addIncoming(Cand.Store->getOperand(0), L->
getLoopLatch());
426 Cand.Load->replaceAllUsesWith(PHI);
433 <<
"\" checking " << *L <<
"\n");
453 auto StoreToLoadDependences = findStoreToLoadDependences(LAI);
454 if (StoreToLoadDependences.empty())
463 removeDependencesFromMultipleStores(StoreToLoadDependences);
464 if (StoreToLoadDependences.empty())
469 unsigned NumForwarding = 0;
470 for (
const StoreToLoadForwardingCandidate Cand : StoreToLoadDependences) {
475 if (!doesStoreDominatesAllLatches(Cand.Store->getParent(),
L, DT))
481 if (isLoadConditional(Cand.Load, L))
486 if (!Cand.isDependenceDistanceOfOne(PSE, L))
492 <<
". Valid store-to-load forwarding across the loop backedge\n");
495 if (Candidates.
empty())
501 collectMemchecks(Candidates);
505 DEBUG(
dbgs() <<
"Too many run-time checks needed.\n");
511 DEBUG(
dbgs() <<
"Too many SCEV run-time checks needed.\n");
516 if (L->
getHeader()->getParent()->optForSize()) {
517 DEBUG(
dbgs() <<
"Versioning is needed but not allowed when optimizing "
523 DEBUG(
dbgs() <<
"Loop is not is loop-simplify form");
540 for (
const auto &Cand : Candidates)
541 propagateStoredValueToLoadUsers(Cand, SEE);
542 NumLoopLoadEliminted += NumForwarding;
573 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
574 auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
575 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
582 for (
Loop *TopLevelLoop : *LI)
589 bool Changed =
false;
590 for (
Loop *L : Worklist) {
593 LoadEliminationForLoop LEL(L, LI, LAI, DT);
594 Changed |= LEL.processLoop();
618 static const char LLE_name[] =
"Loop Load Elimination";
631 return new LoopLoadElimination();
Legacy wrapper pass to provide the GlobalsAAResult object.
static bool Check(DecodeStatus &Out, DecodeStatus In)
TrackingVH< Value > PointerValue
Holds the pointer value that we need to check.
void push_back(const T &Elt)
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
STATISTIC(NumFunctions,"Total number of functions")
This is the interface for a simple mod/ref and alias analysis over globals.
const SmallVectorImpl< Instruction * > & getMemoryInstructions() const
The vector of memory access instructions.
void getLoopLatches(SmallVectorImpl< BlockT * > &LoopLatches) const
Return all loop latch blocks of this loop.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
int64_t getPtrStride(PredicatedScalarEvolution &PSE, Value *Ptr, const Loop *Lp, const ValueToValueMap &StridesMap=ValueToValueMap(), bool Assume=false, bool ShouldCheckWrap=true)
If the pointer has a constant stride return it in units of its element size.
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly...
An instruction for reading from memory.
BlockT * getHeader() const
Type * getPointerElementType() const
BlockT * getLoopLatch() const
If there is a single latch block for this loop, return it.
void printChecks(raw_ostream &OS, const SmallVectorImpl< PointerCheck > &Checks, unsigned Depth=0) const
Print Checks.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
AnalysisUsage & addRequired()
#define INITIALIZE_PASS_DEPENDENCY(depName)
LLVM_NODISCARD bool empty() const
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
const PredicatedScalarEvolution & getPSE() const
Used to add runtime SCEV checks.
static const char LLE_name[]
bool isLoopSimplifyForm() const
Return true if the Loop is in the form that the LoopSimplify form transforms loops to...
This file implements a class to represent arbitrary precision integral constant values and operations...
LLVM_NODISCARD bool empty() const
const RuntimePointerChecking * getRuntimePointerChecking() const
Function Alias Analysis false
void initializeLoopLoadEliminationPass(PassRegistry &)
static GCRegistry::Add< OcamlGC > B("ocaml","ocaml 3.10-compatible GC")
An instruction for storing to memory.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree...
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
initializer< Ty > init(const Ty &Val)
BlockT * getLoopPreheader() const
If there is a preheader for this loop, return it.
LLVM Basic Block Representation.
The instances of the Type class are immutable: once they are created, they are never changed...
This analysis provides dependence information for the memory accesses of a loop.
LLVM_ATTRIBUTE_ALWAYS_INLINE iterator begin()
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
const PointerInfo & getPointerInfo(unsigned PtrIdx) const
Return PointerInfo for pointer at index PtrIdx.
Represent the analysis usage information of a pass.
const SmallVector< PointerCheck, 4 > & getChecks() const
Returns the checks that generateChecks created.
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return LHS-RHS. Minus is represented in SCEV as A+B*-1.
INITIALIZE_PASS_END(RegBankSelect, DEBUG_TYPE,"Assign register bank of generic virtual registers", false, false) RegBankSelect
Value * expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I)
Insert code to directly compute the specified SCEV expression into the program.
FunctionPass class - This class is used to implement most global optimizations.
Value * getPointerOperand()
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
std::pair< NoneType, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
void setAliasChecks(SmallVector< RuntimePointerChecking::PointerCheck, 4 > Checks)
Sets the runtime alias checks for versioning the loop.
std::pair< const CheckingPtrGroup *, const CheckingPtrGroup * > PointerCheck
A memcheck which is made up of a pair of grouped pointers.
bool dominates(const Instruction *Def, const Use &U) const
Return true if Def dominates a use in User.
size_type count(const T &V) const
count - Return 1 if the element is in the set, 0 otherwise.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
AnalysisUsage & addRequiredID(const void *ID)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
FunctionPass * createLoopLoadEliminationPass()
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
const MemoryDepChecker & getDepChecker() const
The Memory Dependence Checker which can determine the loop-independent and loop-carried dependences b...
Drive the analysis of memory accesses in the loop.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", Instruction *InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
This class emits a version of the loop where run-time checks ensure that may-alias pointers can't ove...
static GCRegistry::Add< ShadowStackGC > C("shadow-stack","Very portable GC for uncooperative code generators")
const SmallVectorImpl< Dependence > * getDependences() const
Returns the memory dependences.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Class for arbitrary precision integers.
This class uses information about analyzed scalars to rewrite expressions in canonical form...
unsigned getComplexity() const override
We estimate the complexity of a union predicate as the number of predicates in the union...
const SCEVUnionPredicate & getUnionPredicate() const
LLVM_ATTRIBUTE_ALWAYS_INLINE iterator end()
Represents a single loop in the control flow graph.
ScalarEvolution * getSE() const
Returns the ScalarEvolution analysis used.
LLVM_ATTRIBUTE_ALWAYS_INLINE size_type size() const
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
DenseMap< Instruction *, unsigned > generateInstructionOrderMap() const
Generate a mapping between the memory instructions and their indices according to program order...
raw_ostream & operator<<(raw_ostream &OS, const APInt &I)
OutputIt transform(R &&Range, OutputIt d_first, UnaryPredicate P)
Wrapper function around std::transform to apply a function to a range and store the result elsewhere...
iterator_range< df_iterator< T > > depth_first(const T &G)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static cl::opt< unsigned > CheckPerElim("runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1))
LLVM Value Representation.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass...
This class implements an extremely fast bulk output stream that can only output to a stream...
The legacy pass manager's analysis pass to compute loop information.
Legacy analysis pass which computes a DominatorTree.
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static GCRegistry::Add< ErlangGC > A("erlang","erlang-compatible garbage collector")
const BasicBlock * getParent() const
static cl::opt< unsigned > LoadElimSCEVCheckThreshold("loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, cl::desc("The maximum number of SCEV checks allowed for Loop ""Load Elimination"))