67#define DEBUG_TYPE "mergeicmps"
75 :
GEP(
GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}
77 BCEAtom(
const BCEAtom &) =
delete;
78 BCEAtom &operator=(
const BCEAtom &) =
delete;
80 BCEAtom(BCEAtom &&that) =
default;
81 BCEAtom &operator=(BCEAtom &&that) {
87 Offset = std::move(that.Offset);
102 return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
113class BaseIdentifier {
119 const auto Insertion = BaseToIndex.try_emplace(
Base, Order);
120 if (Insertion.second)
122 return Insertion.first->second;
133BCEAtom visitICmpLoadOperand(
Value *
const Val, BaseIdentifier &BaseId) {
134 auto *
const LoadI = dyn_cast<LoadInst>(Val);
138 if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
143 if (!LoadI->isSimple()) {
148 if (
Addr->getType()->getPointerAddressSpace() != 0) {
152 const auto &
DL = LoadI->getModule()->getDataLayout();
162 auto *
GEP = dyn_cast<GetElementPtrInst>(
Addr);
165 if (
GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
171 Base =
GEP->getPointerOperand();
188 BCECmp(BCEAtom L, BCEAtom R,
int SizeBits,
const ICmpInst *CmpI)
203 BCECmpBlock(BCECmp Cmp,
BasicBlock *BB, InstructionSet BlockInsts)
206 const BCEAtom &Lhs()
const {
return Cmp.Lhs; }
207 const BCEAtom &Rhs()
const {
return Cmp.Rhs; }
208 int SizeBits()
const {
return Cmp.SizeBits; }
211 bool doesOtherWork()
const;
231 InstructionSet BlockInsts;
233 bool RequireSplit =
false;
235 unsigned OrigOrder = 0;
241bool BCECmpBlock::canSinkBCECmpInst(
const Instruction *Inst,
246 auto MayClobber = [&](
LoadInst *LI) {
252 if (MayClobber(
Cmp.Lhs.LoadI) || MayClobber(
Cmp.Rhs.LoadI))
258 const Instruction *OpI = dyn_cast<Instruction>(Op);
259 return OpI && BlockInsts.contains(OpI);
266 if (BlockInsts.count(&Inst))
268 assert(canSinkBCECmpInst(&Inst, AA) &&
"Split unsplittable block");
281 if (!BlockInsts.count(&Inst)) {
282 if (!canSinkBCECmpInst(&Inst, AA))
289bool BCECmpBlock::doesOtherWork()
const {
295 if (!BlockInsts.count(&Inst))
303std::optional<BCECmp> visitICmp(
const ICmpInst *
const CmpI,
305 BaseIdentifier &BaseId) {
318 << (ExpectedPredicate == ICmpInst::ICMP_EQ ?
"eq" :
"ne")
320 auto Lhs = visitICmpLoadOperand(CmpI->
getOperand(0), BaseId);
323 auto Rhs = visitICmpLoadOperand(CmpI->
getOperand(1), BaseId);
327 return BCECmp(std::move(Lhs), std::move(Rhs),
333std::optional<BCECmpBlock> visitCmpBlock(
Value *
const Val,
336 BaseIdentifier &BaseId) {
339 auto *
const BranchI = dyn_cast<BranchInst>(
Block->getTerminator());
345 if (BranchI->isUnconditional()) {
351 ExpectedPredicate = ICmpInst::ICMP_EQ;
355 const auto *
const Const = cast<ConstantInt>(Val);
357 if (!
Const->isZero())
360 assert(BranchI->getNumSuccessors() == 2 &&
"expecting a cond branch");
361 BasicBlock *
const FalseBlock = BranchI->getSuccessor(1);
362 Cond = BranchI->getCondition();
364 FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
367 auto *CmpI = dyn_cast<ICmpInst>(
Cond);
372 std::optional<BCECmp>
Result = visitICmp(CmpI, ExpectedPredicate, BaseId);
381 BlockInsts.insert(
Result->Rhs.GEP);
382 return BCECmpBlock(std::move(*Result), Block, BlockInsts);
385static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
386 BCECmpBlock &&Comparison) {
388 <<
"': Found cmp of " << Comparison.SizeBits()
389 <<
" bits between " << Comparison.Lhs().BaseId <<
" + "
390 << Comparison.Lhs().Offset <<
" and "
391 << Comparison.Rhs().BaseId <<
" + "
392 << Comparison.Rhs().Offset <<
"\n");
394 Comparison.OrigOrder = Comparisons.size();
395 Comparisons.push_back(std::move(Comparison));
401 using ContiguousBlocks = std::vector<BCECmpBlock>;
403 BCECmpChain(
const std::vector<BasicBlock *> &Blocks,
PHINode &Phi,
409 bool atLeastOneMerged()
const {
410 return any_of(MergedBlocks_,
411 [](
const auto &Blocks) {
return Blocks.size() > 1; });
417 std::vector<ContiguousBlocks> MergedBlocks_;
422static bool areContiguous(
const BCECmpBlock &First,
const BCECmpBlock &Second) {
423 return First.Lhs().BaseId == Second.Lhs().BaseId &&
424 First.Rhs().BaseId == Second.Rhs().BaseId &&
425 First.Lhs().Offset +
First.SizeBits() / 8 == Second.Lhs().Offset &&
426 First.Rhs().Offset +
First.SizeBits() / 8 == Second.Rhs().Offset;
429static unsigned getMinOrigOrder(
const BCECmpChain::ContiguousBlocks &Blocks) {
430 unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
431 for (
const BCECmpBlock &Block : Blocks)
432 MinOrigOrder = std::min(MinOrigOrder,
Block.OrigOrder);
438static std::vector<BCECmpChain::ContiguousBlocks>
439mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
440 std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;
444 [](
const BCECmpBlock &LhsBlock,
const BCECmpBlock &RhsBlock) {
445 return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
446 std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
449 BCECmpChain::ContiguousBlocks *LastMergedBlock =
nullptr;
450 for (BCECmpBlock &Block : Blocks) {
451 if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
452 MergedBlocks.emplace_back();
453 LastMergedBlock = &MergedBlocks.back();
456 << LastMergedBlock->back().BB->getName() <<
"\n");
458 LastMergedBlock->push_back(std::move(Block));
463 llvm::sort(MergedBlocks, [](
const BCECmpChain::ContiguousBlocks &LhsBlocks,
464 const BCECmpChain::ContiguousBlocks &RhsBlocks) {
465 return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
471BCECmpChain::BCECmpChain(
const std::vector<BasicBlock *> &Blocks,
PHINode &Phi,
474 assert(!Blocks.empty() &&
"a chain should have at least one block");
476 std::vector<BCECmpBlock> Comparisons;
477 BaseIdentifier BaseId;
479 assert(Block &&
"invalid block");
480 std::optional<BCECmpBlock> Comparison = visitCmpBlock(
483 LLVM_DEBUG(
dbgs() <<
"chain with invalid BCECmpBlock, no merge.\n");
486 if (Comparison->doesOtherWork()) {
488 <<
"' does extra work besides compare\n");
489 if (Comparisons.empty()) {
503 if (Comparison->canSplit(AA)) {
505 <<
"Split initial block '" << Comparison->BB->getName()
506 <<
"' that does extra work besides compare\n");
507 Comparison->RequireSplit =
true;
508 enqueueBlock(Comparisons, std::move(*Comparison));
511 <<
"ignoring initial block '" << Comparison->BB->getName()
512 <<
"' that does extra work besides compare\n");
541 enqueueBlock(Comparisons, std::move(*Comparison));
545 if (Comparisons.empty()) {
546 LLVM_DEBUG(
dbgs() <<
"chain with no BCE basic blocks, no merge\n");
549 EntryBlock_ = Comparisons[0].BB;
550 MergedBlocks_ = mergeBlocks(std::move(Comparisons));
557class MergedBlockName {
563 :
Name(makeName(Comparisons)) {}
570 if (Comparisons.
size() == 1)
571 return Comparisons[0].BB->getName();
572 const int size = std::accumulate(Comparisons.
begin(), Comparisons.
end(), 0,
573 [](
int i,
const BCECmpBlock &Cmp) {
574 return i + Cmp.BB->getName().size();
585 Scratch.
append(str.begin(), str.end());
587 append(Comparisons[0].BB->getName());
588 for (
int I = 1,
E = Comparisons.
size();
I <
E; ++
I) {
595 return Scratch.
str();
606 assert(!Comparisons.
empty() &&
"merging zero comparisons");
608 const BCECmpBlock &FirstCmp = Comparisons[0];
613 NextCmpBlock->
getParent(), InsertBefore);
617 if (FirstCmp.Lhs().GEP)
618 Lhs =
Builder.Insert(FirstCmp.Lhs().GEP->clone());
620 Lhs = FirstCmp.Lhs().LoadI->getPointerOperand();
621 if (FirstCmp.Rhs().GEP)
622 Rhs =
Builder.Insert(FirstCmp.Rhs().GEP->clone());
624 Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
626 Value *IsEqual =
nullptr;
634 Comparisons, [](
const BCECmpBlock &
B) {
return B.RequireSplit; });
635 if (ToSplit != Comparisons.
end()) {
637 ToSplit->split(BB, AA);
640 if (Comparisons.
size() == 1) {
642 Value *
const LhsLoad =
643 Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs);
644 Value *
const RhsLoad =
645 Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
647 IsEqual =
Builder.CreateICmpEQ(LhsLoad, RhsLoad);
649 const unsigned TotalSizeBits = std::accumulate(
650 Comparisons.
begin(), Comparisons.
end(), 0u,
651 [](
int Size,
const BCECmpBlock &
C) { return Size + C.SizeBits(); });
663 IsEqual =
Builder.CreateICmpEQ(
669 if (NextCmpBlock == PhiBB) {
676 Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
678 DTU.
applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
679 {DominatorTree::Insert, BB, PhiBB}});
686 assert(atLeastOneMerged() &&
"simplifying trivial BCECmpChain");
687 LLVM_DEBUG(
dbgs() <<
"Simplifying comparison chain starting at block "
688 << EntryBlock_->getName() <<
"\n");
694 for (
const auto &Blocks :
reverse(MergedBlocks_)) {
695 InsertBefore = NextCmpBlock = mergeComparisons(
696 Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
707 DTU.
applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
708 {DominatorTree::Insert, Pred, NextCmpBlock}});
713 const bool ChainEntryIsFnEntry = EntryBlock_->isEntryBlock();
714 if (ChainEntryIsFnEntry && DTU.
hasDomTree()) {
716 << EntryBlock_->getName() <<
" to "
717 << NextCmpBlock->
getName() <<
"\n");
719 DTU.
applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
721 EntryBlock_ =
nullptr;
725 for (
const auto &Blocks : MergedBlocks_) {
726 for (
const BCECmpBlock &Block : Blocks) {
734 MergedBlocks_.clear();
738std::vector<BasicBlock *> getOrderedBlocks(
PHINode &Phi,
742 std::vector<BasicBlock *> Blocks(NumBlocks);
743 assert(LastBlock &&
"invalid last block");
745 for (
int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
750 <<
" has its address taken\n");
753 Blocks[BlockIndex] = CurBlock;
755 if (!SinglePredecessor) {
758 <<
" has two or more predecessors\n");
764 <<
" does not link back to the phi\n");
767 CurBlock = SinglePredecessor;
769 Blocks[0] = CurBlock;
814 <<
"skip: non-constant value not from cmp or not from last block.\n");
831 if (Blocks.empty())
return false;
832 BCECmpChain CmpChain(Blocks, Phi, AA);
834 if (!CmpChain.atLeastOneMerged()) {
839 return CmpChain.simplify(TLI, AA, DTU);
853 if (!TLI.
has(LibFunc_memcmp))
857 DomTreeUpdater::UpdateStrategy::Eager);
859 bool MadeChange =
false;
863 if (
auto *
const Phi = dyn_cast<PHINode>(&*BB.
begin()))
864 MadeChange |= processPhi(*Phi, TLI, AA, DTU);
880 const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
881 const auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
884 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
885 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
886 return runImpl(
F, TLI,
TTI, AA, DTWP ? &DTWP->getDomTree() :
nullptr);
901char MergeICmpsLegacyPass::ID = 0;
903 "Merge contiguous icmps into a memcmp",
false,
false)
918 const bool MadeChanges =
runImpl(
F, TLI,
TTI, AA, DT);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
SmallVector< MachineOperand, 4 > Cond
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static Error split(StringRef Str, char Separator, std::pair< StringRef, StringRef > &Split)
Checked version of split, to ensure mandatory subparts.
static bool runImpl(Function &F, const TargetLowering &TLI)
This is the interface for a simple mod/ref and alias analysis over globals.
Merge contiguous icmps into a memcmp
return ToRemove size() > 0
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
ModRefInfo getModRefInfo(const Instruction *I, const std::optional< MemoryLocation > &OptLoc)
Check whether or not an instruction may read or write the optionally specified memory location.
Class for arbitrary precision integers.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
bool empty() const
empty - Check if the array is empty.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
bool hasAddressTaken() const
Returns true if there are any uses of this basic block other than direct branches,...
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
const Function * getParent() const
Return the enclosing method, or null if none.
LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Predicate getPredicate() const
Return the predicate for this instruction.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
static ConstantInt * getFalse(LLVMContext &Context)
bool hasDomTree() const
Returns true if it holds a DominatorTree.
void applyUpdates(ArrayRef< DominatorTree::UpdateType > Updates)
Submit updates to all available trees.
DominatorTree & getDomTree()
Flush DomTree updates and return DomTree.
Analysis pass which computes a DominatorTree.
DomTreeNodeBase< NodeT > * setNewRoot(NodeT *BB)
Add a new node to the forward dominator tree and make it a new root.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Legacy wrapper pass to provide the GlobalsAAResult object.
This instruction compares its operands according to the predicate given to the constructor.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not...
const BasicBlock * getParent() const
bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instructi...
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
int getBasicBlockIndex(const BasicBlock *BB) const
Return the first index of the specified basic block in the value list for this PHI.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
Implements a dense probed hash-table based set with some number of buckets stored inline.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
void append(StringRef RHS)
Append from a StringRef.
StringRef str() const
Explicit conversion to StringRef.
void reserve(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
constexpr bool empty() const
empty - Check if the string is empty.
Analysis pass providing the TargetTransformInfo.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
bool has(LibFunc F) const
Tests whether a library function is available.
unsigned getSizeTSize(const Module &M) const
Returns the size of the size_t type in bits.
unsigned getIntSize() const
Get size of a C-level int or unsigned int, in bits.
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
StringRef getName() const
Return a constant reference to the value's name.
std::pair< iterator, bool > insert(const ValueT &V)
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
void append(SmallVectorImpl< char > &path, const Twine &a, const Twine &b="", const Twine &c="", const Twine &d="")
Append to path.
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
bool operator<(int64_t V1, const APSInt &V2)
void initializeMergeICmpsLegacyPassPass(PassRegistry &)
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Value * emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI)
Emit a call to the memcmp function.
auto reverse(ContainerTy &&C)
bool isModSet(const ModRefInfo MRI)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool none_of(R &&Range, UnaryPredicate P)
Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.
Interval::pred_iterator pred_begin(Interval *I)
pred_begin/pred_end - define methods so that Intervals may be used just like BasicBlocks can with the...
Pass * createMergeICmpsLegacyPass()
bool isDereferenceablePointer(const Value *V, Type *Ty, const DataLayout &DL, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if this is always a dereferenceable pointer.
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool pred_empty(const BasicBlock *BB)
void DeleteDeadBlocks(ArrayRef< BasicBlock * > BBs, DomTreeUpdater *DTU=nullptr, bool KeepOneInputPHIs=false)
Delete the specified blocks from BB.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)