48#include <unordered_map>
54#define DEBUG_TYPE "pgo-icall-prom"
56STATISTIC(NumOfPGOICallPromotion,
"Number of indirect call promotions.");
57STATISTIC(NumOfPGOICallsites,
"Number of indirect call candidate sites.");
68 cl::desc(
"Disable indirect call promotion"));
76 cl::desc(
"Max number of promotions for this compilation"));
82 cl::desc(
"Skip Callsite up to this number for this compilation"));
88 cl::desc(
"Run indirect-call promotion in LTO "
95 cl::desc(
"Run indirect-call promotion in SamplePGO mode"));
101 cl::desc(
"Run indirect-call promotion for call instructions "
108 cl::desc(
"Run indirect-call promotion for "
109 "invoke instruction only"));
115 cl::desc(
"Dump IR after transformation happens"));
121 cl::desc(
"The percentage threshold of vtable-count / function-count for "
122 "cost-benefit analysis."));
133 cl::desc(
"The maximum number of vtable for the last candidate."));
140using VTableAddressPointOffsetValMap =
144struct VirtualCallSiteInfo {
154using VirtualCallSiteTypeInfoMap =
165static std::optional<uint64_t>
169 VTableVar.
getMetadata(LLVMContext::MD_type, Types);
172 if (
auto *TypeId = dyn_cast<MDString>(
Type->getOperand(1).get());
173 TypeId && TypeId->getString() == CompatibleType)
174 return cast<ConstantInt>(
175 cast<ConstantAsMetadata>(
Type->getOperand(0))->getValue())
187 assert(AddressPointOffset <
188 M.getDataLayout().getTypeAllocSize(
VTable->getValueType()) &&
189 "Out-of-bound access");
198 if (
PHINode *PN = dyn_cast<PHINode>(UserInst))
199 return PN->getIncomingBlock(U);
214 "Guaranteed by ICP transformation");
225 UserBB = getUserBasicBlock(
Use, UserInst);
228 if (UserBB != DestBB)
231 return UserBB !=
nullptr;
239 if (!isDestBBSuitableForSink(
I, DestBlock))
244 if (isa<PHINode>(
I) ||
I->isEHPad() ||
I->mayThrow() || !
I->willReturn() ||
249 if (
const auto *
C = dyn_cast<CallBase>(
I))
250 if (
C->isInlineAsm() ||
C->cannotMerge() ||
C->isConvergent())
254 if (
I->mayWriteToMemory())
259 if (
I->mayReadFromMemory()) {
262 E =
I->getParent()->end();
267 if (Scan->mayWriteToMemory())
273 I->moveBefore(*DestBlock, InsertPos);
284static int tryToSinkInstructions(
BasicBlock *OriginalBB,
293 if (tryToSinkInstruction(&
I, IndirectCallBB))
301class IndirectCallPromoter {
312 const bool SamplePGO;
315 const VirtualCallSiteTypeInfoMap &VirtualCSInfo;
317 VTableAddressPointOffsetValMap &VTableAddressPointOffsetVal;
322 struct PromotionCandidate {
333 VTableGUIDCountsMap VTableGUIDAndCounts;
344 std::vector<PromotionCandidate> getPromotionCandidatesForCallSite(
356 VTableGUIDCountsMap &VTableGUIDCounts);
361 bool tryToPromoteWithVTableCmp(
365 VTableGUIDCountsMap &VTableGUIDCounts);
368 bool isProfitableToCompareVTables(
const CallBase &CB,
380 VTableGUIDCountsMap &VTableGUIDCounts,
381 std::vector<PromotionCandidate> &Candidates);
390 VTableGUIDCountsMap &VTableGUIDCounts);
393 IndirectCallPromoter(
396 const VirtualCallSiteTypeInfoMap &VirtualCSInfo,
397 VTableAddressPointOffsetValMap &VTableAddressPointOffsetVal,
399 :
F(
Func),
M(
M), PSI(PSI), Symtab(Symtab), SamplePGO(SamplePGO),
400 VirtualCSInfo(VirtualCSInfo),
401 VTableAddressPointOffsetVal(VTableAddressPointOffsetVal), ORE(ORE) {}
402 IndirectCallPromoter(
const IndirectCallPromoter &) =
delete;
403 IndirectCallPromoter &operator=(
const IndirectCallPromoter &) =
delete;
412std::vector<IndirectCallPromoter::PromotionCandidate>
413IndirectCallPromoter::getPromotionCandidatesForCallSite(
416 std::vector<PromotionCandidate>
Ret;
418 LLVM_DEBUG(
dbgs() <<
" \nWork on callsite #" << NumOfPGOICallsites << CB
419 <<
" Num_targets: " << ValueDataRef.
size()
420 <<
" Num_candidates: " << NumCandidates <<
"\n");
421 NumOfPGOICallsites++;
429 assert(Count <= TotalCount);
433 <<
" Target_func: " <<
Target <<
"\n");
439 <<
" Not promote: User options";
447 <<
" Not promote: User options";
455 <<
" Not promote: Cutoff reached";
468 if (TargetFunction ==
nullptr || TargetFunction->
isDeclaration()) {
472 <<
"Cannot promote indirect call: target with md5sum "
478 const char *Reason =
nullptr;
484 <<
"Cannot promote indirect call to "
485 <<
NV(
"TargetFunction", TargetFunction) <<
" with count of "
486 <<
NV(
"Count", Count) <<
": " << Reason;
491 Ret.push_back(PromotionCandidate(TargetFunction, Count));
497Constant *IndirectCallPromoter::getOrCreateVTableAddressPointVar(
500 VTableAddressPointOffsetVal[GV].try_emplace(AddressPointOffset,
nullptr);
502 Iter->second = getVTableAddressPointOffset(GV, AddressPointOffset);
506Instruction *IndirectCallPromoter::computeVTableInfos(
507 const CallBase *CB, VTableGUIDCountsMap &GUIDCountsMap,
508 std::vector<PromotionCandidate> &Candidates) {
536 auto Iter = VirtualCSInfo.find(CB);
537 if (Iter == VirtualCSInfo.end())
541 << NumOfPGOICallsites <<
"\n");
543 const auto &VirtualCallInfo = Iter->second;
547 for (
size_t I = 0;
I < Candidates.size();
I++)
548 CalleeIndexMap[Candidates[
I].TargetFunction] =
I;
551 auto VTableValueDataArray =
554 if (VTableValueDataArray.empty())
558 for (
const auto &V : VTableValueDataArray) {
560 GUIDCountsMap[VTableVal] =
V.Count;
563 LLVM_DEBUG(
dbgs() <<
" Cannot find vtable definition for " << VTableVal
564 <<
"; maybe the vtable isn't imported\n");
568 std::optional<uint64_t> MaybeAddressPointOffset =
569 getAddressPointOffset(*VTableVar, VirtualCallInfo.CompatibleTypeStr);
570 if (!MaybeAddressPointOffset)
573 const uint64_t AddressPointOffset = *MaybeAddressPointOffset;
577 VTableVar, AddressPointOffset + VirtualCallInfo.FunctionOffset, M);
580 auto CalleeIndexIter = CalleeIndexMap.
find(Callee);
581 if (CalleeIndexIter == CalleeIndexMap.
end())
584 auto &Candidate = Candidates[CalleeIndexIter->second];
588 Candidate.VTableGUIDAndCounts[VTableVal] =
V.Count;
589 Candidate.AddressPoints.push_back(
590 getOrCreateVTableAddressPointVar(VTableVar, AddressPointOffset));
608 bool AttachProfToDirectCall,
614 if (AttachProfToDirectCall)
623 <<
"Promote indirect call to " << NV(
"DirectCallee", DirectCallee)
624 <<
" with count " << NV(
"Count", Count) <<
" out of "
625 << NV(
"TotalCount", TotalCount);
631bool IndirectCallPromoter::tryToPromoteWithFuncCmp(
634 uint32_t NumCandidates, VTableGUIDCountsMap &VTableGUIDCounts) {
637 for (
const auto &
C : Candidates) {
641 assert(TotalCount >= FuncCount);
642 TotalCount -= FuncCount;
643 NumOfPGOICallPromotion++;
654 for (
const auto &[GUID, VTableCount] :
C.VTableGUIDAndCounts)
655 SumVTableCount += VTableCount;
657 for (
const auto &[GUID, VTableCount] :
C.VTableGUIDAndCounts) {
658 APInt APFuncCount((
unsigned)128, FuncCount,
false );
659 APFuncCount *= VTableCount;
660 VTableGUIDCounts[GUID] -= APFuncCount.udiv(SumVTableCount).getZExtValue();
663 if (NumPromoted == 0)
666 assert(NumPromoted <= ICallProfDataRef.
size() &&
667 "Number of promoted functions should not be greater than the number "
668 "of values in profile metadata");
671 updateFuncValueProfiles(CB, ICallProfDataRef.
slice(NumPromoted), TotalCount,
673 updateVPtrValueProfiles(VPtr, VTableGUIDCounts);
677void IndirectCallPromoter::updateFuncValueProfiles(
688void IndirectCallPromoter::updateVPtrValueProfiles(
689 Instruction *VPtr, VTableGUIDCountsMap &VTableGUIDCounts) {
694 std::vector<InstrProfValueData> VTableValueProfiles;
696 for (
auto [GUID, Count] : VTableGUIDCounts) {
700 VTableValueProfiles.push_back({
GUID, Count});
701 TotalVTableCount += Count;
704 [](
const InstrProfValueData &LHS,
const InstrProfValueData &RHS) {
705 return LHS.Count >
RHS.Count;
709 IPVK_VTableTarget, VTableValueProfiles.size());
712bool IndirectCallPromoter::tryToPromoteWithVTableCmp(
716 VTableGUIDCountsMap &VTableGUIDCounts) {
719 for (
const auto &Candidate : Candidates) {
720 for (
auto &[GUID, Count] : Candidate.VTableGUIDAndCounts)
721 VTableGUIDCounts[
GUID] -= Count;
728 CB, VPtr, Candidate.TargetFunction, Candidate.AddressPoints,
730 TotalFuncCount - Candidate.Count));
732 int SinkCount = tryToSinkInstructions(OriginalBB, CB.
getParent());
737 const auto &VTableGUIDAndCounts = Candidate.VTableGUIDAndCounts;
738 Remark <<
"Promote indirect call to "
739 <<
ore::NV(
"DirectCallee", Candidate.TargetFunction)
740 <<
" with count " <<
ore::NV(
"Count", Candidate.Count)
741 <<
" out of " <<
ore::NV(
"TotalCount", TotalFuncCount) <<
", sink "
742 <<
ore::NV(
"SinkCount", SinkCount)
743 <<
" instruction(s) and compare "
744 <<
ore::NV(
"VTable", VTableGUIDAndCounts.size())
748 std::set<uint64_t> GUIDSet;
749 for (
auto [GUID, Count] : VTableGUIDAndCounts)
750 GUIDSet.insert(GUID);
751 for (
auto Iter = GUIDSet.begin(); Iter != GUIDSet.end(); Iter++) {
752 if (Iter != GUIDSet.begin())
754 Remark <<
ore::NV(
"VTable", Symtab->getGlobalVariable(*Iter));
762 PromotedFuncCount.
push_back(Candidate.Count);
764 assert(TotalFuncCount >= Candidate.Count &&
765 "Within one prof metadata, total count is the sum of counts from "
766 "individual <target, count> pairs");
770 TotalFuncCount -= std::min(TotalFuncCount, Candidate.Count);
771 NumOfPGOICallPromotion++;
774 if (PromotedFuncCount.
empty())
783 for (
size_t I = 0;
I < PromotedFuncCount.
size();
I++)
784 ICallProfDataRef[
I].Count -=
785 std::max(PromotedFuncCount[
I], ICallProfDataRef[
I].Count);
788 const InstrProfValueData &RHS) {
789 return LHS.Count >
RHS.Count;
793 ICallProfDataRef.
begin(),
795 [](
uint64_t Count,
const InstrProfValueData &ProfData) {
796 return ProfData.Count <= Count;
798 updateFuncValueProfiles(CB, VDs, TotalFuncCount, NumCandidates);
799 updateVPtrValueProfiles(VPtr, VTableGUIDCounts);
806 bool Changed =
false;
812 CB, TotalCount, NumCandidates);
813 if (!NumCandidates ||
817 auto PromotionCandidates = getPromotionCandidatesForCallSite(
818 *CB, ICallProfDataRef, TotalCount, NumCandidates);
820 VTableGUIDCountsMap VTableGUIDCounts;
822 computeVTableInfos(CB, VTableGUIDCounts, PromotionCandidates);
824 if (isProfitableToCompareVTables(*CB, PromotionCandidates, TotalCount))
825 Changed |= tryToPromoteWithVTableCmp(*CB, VPtr, PromotionCandidates,
826 TotalCount, NumCandidates,
827 ICallProfDataRef, VTableGUIDCounts);
829 Changed |= tryToPromoteWithFuncCmp(*CB, VPtr, PromotionCandidates,
830 TotalCount, ICallProfDataRef,
831 NumCandidates, VTableGUIDCounts);
838bool IndirectCallPromoter::isProfitableToCompareVTables(
843 LLVM_DEBUG(
dbgs() <<
"\nEvaluating vtable profitability for callsite #"
844 << NumOfPGOICallsites << CB <<
"\n");
845 uint64_t RemainingVTableCount = TotalCount;
846 const size_t CandidateSize = Candidates.
size();
847 for (
size_t I = 0;
I < CandidateSize;
I++) {
848 auto &Candidate = Candidates[
I];
849 auto &VTableGUIDAndCounts = Candidate.VTableGUIDAndCounts;
852 << Candidate.Count <<
", VTableCounts:");
854 for ([[maybe_unused]]
auto &[GUID, Count] : VTableGUIDAndCounts)
855 LLVM_DEBUG(
dbgs() <<
" {" << Symtab->getGlobalVariable(GUID)->getName()
856 <<
", " << Count <<
"}");
860 for (
auto &[GUID, Count] : VTableGUIDAndCounts)
861 CandidateVTableCount += Count;
865 dbgs() <<
" function count " << Candidate.Count
866 <<
" and its vtable sum count " << CandidateVTableCount
867 <<
" have discrepancies. Bail out vtable comparison.\n");
871 RemainingVTableCount -= Candidate.Count;
879 int MaxNumVTable = 1;
880 if (
I == CandidateSize - 1)
883 if ((
int)Candidate.AddressPoints.size() > MaxNumVTable) {
884 LLVM_DEBUG(
dbgs() <<
" allow at most " << MaxNumVTable <<
" and got "
885 << Candidate.AddressPoints.size()
886 <<
" vtables. Bail out for vtable comparison.\n");
894 LLVM_DEBUG(
dbgs() <<
" Indirect fallback basic block is not cold. Bail "
895 "out for vtable comparison.\n");
910 VirtualCallSiteTypeInfoMap &VirtualCSInfo) {
921 if (!TypeTestFunc || TypeTestFunc->
use_empty())
930 auto *CI = dyn_cast<CallInst>(U.getUser());
933 auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1));
936 auto *CompatibleTypeId = dyn_cast<MDString>(TypeMDVal->getMetadata());
937 if (!CompatibleTypeId)
944 auto &DT = LookupDomTree(*CI->getFunction());
947 for (
auto &DevirtCall : DevirtCalls) {
955 VirtualCSInfo[&CB] = {DevirtCall.Offset, VTablePtr,
956 CompatibleTypeId->getString()};
968 std::string SymtabFailure =
toString(std::move(E));
969 M.getContext().emitError(
"Failed to create symtab: " + SymtabFailure);
972 bool Changed =
false;
973 VirtualCallSiteTypeInfoMap VirtualCSInfo;
985 VTableAddressPointOffsetValMap VTableAddressPointOffsetVal;
988 if (
F.isDeclaration() ||
F.hasOptNone())
995 IndirectCallPromoter CallPromoter(
F, M, PSI, &Symtab, SamplePGO,
997 VTableAddressPointOffsetVal, ORE);
998 bool FuncChanged = CallPromoter.processFunction(PSI);
1003 Changed |= FuncChanged;
This file defines the DenseMap class.
This file provides the interface for IR based instrumentation passes ( (profile-gen,...
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
This header defines various interfaces for pass management in LLVM.
This file contains the declarations for profiling metadata utility functions.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Class for arbitrary precision integers.
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
size_t size() const
size - Get the array size.
bool empty() const
empty - Check if the array is empty.
ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array.
LLVM Basic Block Representation.
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
const BasicBlock * getUniquePredecessor() const
Return the predecessor of this block if it has a unique predecessor block.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
static Constant * getInBoundsGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList)
Create an "inbounds" getelementptr.
This is an important base class in LLVM.
iterator find(const_arg_type_t< KeyT > Val)
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Lightweight error class with error context and mandatory checking.
const Function & getFunction() const
MDNode * getMetadata(unsigned KindID) const
Get the current metadata attachments for the given kind, if any.
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
A symbol table used for function [IR]PGO name look-up with keys (such as pointers,...
Error create(object::SectionRef &Section)
Create InstrProfSymtab from an object file section which contains function PGO names.
bool isDebugOrPseudoInst() const LLVM_READONLY
Return true if the instruction is a DbgInfoIntrinsic or PseudoProbeInst.
unsigned getNumSuccessors() const LLVM_READONLY
Return the number of successors that this instruction has.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
void setMetadata(unsigned KindID, MDNode *Node)
Set the metadata of the specified kind to the specified node.
This is an important class for using LLVM in a threaded context.
MDNode * createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight, bool IsExpected=false)
Return metadata containing two branch weights.
A Module instance is used to store all the information related to an LLVM module.
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
An analysis pass based on the new PM to deliver ProfileSummaryInfo.
Analysis providing profile information.
bool hasProfileSummary() const
Returns true if profile summary is available.
bool isColdCount(uint64_t C) const
Returns true if count C is considered cold.
bool isHotCount(uint64_t C) const
Returns true if count C is considered hot.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e.
Target - Wrapper for Target specific information.
The instances of the Type class are immutable: once they are created, they are never changed.
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
A Use represents the edge between a Value definition and its users.
User * getUser() const
Returns the User that contains this Use.
LLVMContext & getContext() const
All values hold a context through their type.
iterator_range< use_iterator > uses()
const ParentTy * getParent() const
@ C
The default llvm calling convention, compatible with C.
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
initializer< Ty > init(const Ty &Val)
std::optional< const char * > toString(const std::optional< DWARFFormValue > &V)
Take an optional DWARFFormValue and try to extract a string value from it.
DiagnosticInfoOptimizationBase::Argument NV
CallBase & promoteIndirectCall(CallBase &CB, Function *F, uint64_t Count, uint64_t TotalCount, bool AttachProfToDirectCall, OptimizationRemarkEmitter *ORE)
NodeAddr< FuncNode * > Func
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
void stable_sort(R &&Range)
bool isLegalToPromote(const CallBase &CB, Function *Callee, const char **FailureReason=nullptr)
Return true if the given indirect call site can be made to call Callee.
std::vector< CallBase * > findIndirectCalls(Function &F)
CallBase & promoteCallWithIfThenElse(CallBase &CB, Function *Callee, MDNode *BranchWeights=nullptr)
Promote the given indirect call site to conditionally call Callee.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
auto upper_bound(R &&Range, T &&Value)
Provide wrappers to std::upper_bound which take ranges instead of having to pass begin/end explicitly...
cl::opt< bool > EnableVTableProfileUse("enable-vtable-profile-use", cl::init(false), cl::desc("If ThinLTO and WPD is enabled and this option is true, vtable " "profiles will be used by ICP pass for more efficient indirect " "call sequence. If false, type profiles won't be used."))
void annotateValueSite(Module &M, Instruction &Inst, const InstrProfRecord &InstrProfR, InstrProfValueKind ValueKind, uint32_t SiteIndx, uint32_t MaxMDCount=3)
Get the value profile data for value site SiteIdx from InstrProfR and annotate the instruction Inst w...
auto reverse(ContainerTy &&C)
void setBranchWeights(Instruction &I, ArrayRef< uint32_t > Weights, bool IsExpected)
Create a new branch_weights metadata node and add or overwrite a prof metadata reference to instructi...
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
SmallVector< InstrProfValueData, 4 > getValueProfDataFromInst(const Instruction &Inst, InstrProfValueKind ValueKind, uint32_t MaxNumValueData, uint64_t &TotalC, bool GetNoICPValue=false)
Extract the value profile data from Inst and returns them if Inst is annotated with value profile dat...
static uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale)
Scale an individual branch count.
static uint64_t calculateCountScale(uint64_t MaxCount)
Calculate what to divide by to scale counts.
CallBase & promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, Function *Callee, ArrayRef< Constant * > AddressPoints, MDNode *BranchWeights)
This is similar to promoteCallWithIfThenElse except that the condition to promote a virtual call is t...
void findDevirtualizableCallsForTypeTest(SmallVectorImpl< DevirtCallSite > &DevirtCalls, SmallVectorImpl< CallInst * > &Assumes, const CallInst *CI, DominatorTree &DT)
Given a call to the intrinsic @llvm.type.test, find all devirtualizable call sites based on the call ...
std::pair< Function *, Constant * > getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset, Module &M)
Given a vtable and a specified offset, returns the function and the trivial pointer at the specified ...
static Instruction * tryGetVTableInstruction(CallBase *CB)