24#define DEBUG_TYPE "ctx-instr-lower"
29 "A function name, assumed to be global, which will be treated as the "
30 "root of an interesting graph, which will be profiled independently "
31 "from other similar graphs."));
40static auto StartCtx =
"__llvm_ctx_profile_start_context";
41static auto ReleaseCtx =
"__llvm_ctx_profile_release_context";
42static auto GetCtx =
"__llvm_ctx_profile_get_context";
49class CtxInstrumentationLowerer final {
52 Type *ContextNodeTy =
nullptr;
53 Type *ContextRootTy =
nullptr;
72std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(
const Function &
F) {
75 for (
const auto &BB :
F) {
76 for (
const auto &
I : BB) {
77 if (
const auto *Incr = dyn_cast<InstrProfIncrementInst>(&
I)) {
79 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
80 assert((!NrCounters || V == NrCounters) &&
81 "expected all llvm.instrprof.increment[.step] intrinsics to "
82 "have the same total nr of counters parameter");
84 }
else if (
const auto *CSIntr = dyn_cast<InstrProfCallsite>(&
I)) {
86 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
87 assert((!NrCallsites || V == NrCallsites) &&
88 "expected all llvm.instrprof.callsite intrinsics to have the "
89 "same total nr of callsites parameter");
93 if (NrCounters && NrCallsites)
94 return std::make_pair(NrCounters, NrCallsites);
98 return {NrCounters, NrCallsites};
105CtxInstrumentationLowerer::CtxInstrumentationLowerer(
Module &M,
108 auto *
PointerTy = PointerType::get(
M.getContext(), 0);
132 if (
const auto *
F =
M.getFunction(Fname)) {
133 if (
F->isDeclaration())
135 auto *
G =
M.getOrInsertGlobal(Fname +
"_ctx_root", ContextRootTy);
136 cast<GlobalVariable>(
G)->setInitializer(
138 ContextRootMap.insert(std::make_pair(
F,
G));
139 for (
const auto &BB : *
F)
140 for (
const auto &
I : BB)
141 if (
const auto *CB = dyn_cast<CallBase>(&
I))
142 if (CB->isMustTailCall()) {
143 M.getContext().emitError(
144 "The function " + Fname +
145 " was indicated as a context root, but it features musttail "
146 "calls, which is not supported.");
153 M.getOrInsertFunction(
155 FunctionType::get(ContextNodeTy->getPointerTo(),
156 {ContextRootTy->getPointerTo(),
163 FunctionType::get(ContextNodeTy->getPointerTo(),
171 M.getOrInsertFunction(
175 ContextRootTy->getPointerTo(),
184 CallsiteInfoTLS->setThreadLocal(
true);
195 CtxInstrumentationLowerer Lowerer(M,
MAM);
196 bool Changed =
false;
198 Changed |= Lowerer.lowerFunction(
F);
202bool CtxInstrumentationLowerer::lowerFunction(
Function &
F) {
203 if (
F.isDeclaration())
209 auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(
F);
211 Value *Context =
nullptr;
212 Value *RealContext =
nullptr;
215 Value *TheRootContext =
nullptr;
216 Value *ExpectedCalleeTLSAddr =
nullptr;
217 Value *CallsiteInfoTLSAddr =
nullptr;
219 auto &Head =
F.getEntryBlock();
220 for (
auto &
I : Head) {
222 if (
auto *Mark = dyn_cast<InstrProfIncrementInst>(&
I)) {
223 assert(Mark->getIndex()->isZero());
227 Guid = Builder.getInt64(
F.getGUID());
233 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
234 ArrayType::get(Builder.getPtrTy(), NrCallsites)});
240 auto Iter = ContextRootMap.find(&
F);
241 if (Iter != ContextRootMap.end()) {
242 TheRootContext = Iter->second;
243 Context = Builder.CreateCall(StartCtx, {TheRootContext,
Guid,
244 Builder.getInt32(NrCounters),
245 Builder.getInt32(NrCallsites)});
250 Builder.CreateCall(GetCtx, {&
F,
Guid, Builder.getInt32(NrCounters),
251 Builder.getInt32(NrCallsites)});
257 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
258 if (NrCallsites > 0) {
261 auto *
Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
263 ExpectedCalleeTLSAddr = Builder.CreateGEP(
264 Builder.getInt8Ty()->getPointerTo(),
265 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {
Index});
266 CallsiteInfoTLSAddr = Builder.CreateGEP(
267 Builder.getInt32Ty(),
268 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {
Index});
275 RealContext = Builder.CreateIntToPtr(
276 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
285 <<
"Function doesn't have instrumentation, skipping";
290 bool ContextWasReleased =
false;
293 if (
auto *Instr = dyn_cast<InstrProfCntrInstBase>(&
I)) {
295 switch (
Instr->getIntrinsicID()) {
296 case llvm::Intrinsic::instrprof_increment:
297 case llvm::Intrinsic::instrprof_increment_step: {
300 auto *AsStep = cast<InstrProfIncrementInst>(Instr);
301 auto *
GEP = Builder.CreateGEP(
302 ThisContextType, RealContext,
303 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
305 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(),
GEP),
309 case llvm::Intrinsic::instrprof_callsite:
313 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
314 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
326 Builder.CreateGEP(ThisContextType, Context,
327 {Builder.getInt32(0), Builder.getInt32(2),
328 CSIntrinsic->getIndex()}),
329 CallsiteInfoTLSAddr,
true);
333 }
else if (TheRootContext && isa<ReturnInst>(
I)) {
336 Builder.CreateCall(ReleaseCtx, {TheRootContext});
337 ContextWasReleased =
true;
345 if (TheRootContext && !ContextWasReleased)
346 F.getContext().emitError(
347 "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
348 "instructions above which to release the context: " +
Module.h This file contains the declarations for the Module class.
static cl::list< std::string > ContextRoots("profile-context-root", cl::Hidden, cl::desc("A function name, assumed to be global, which will be treated as the " "root of an interesting graph, which will be profiled independently " "from other similar graphs."))
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
This header defines various interfaces for pass management in LLVM.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
@ HiddenVisibility
The GV is hidden.
@ ExternalLinkage
Externally visible function.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
A Module instance is used to store all the information related to an LLVM module.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
static bool isContextualIRPGOEnabled()
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Class to represent struct types.
static StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
The instances of the Type class are immutable: once they are created, they are never changed.
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
LLVM Value Representation.
Pass manager infrastructure for declaring and invalidating analyses.
static auto ExpectedCalleeTLS
NodeAddr< InstrNode * > Instr
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...