26#define DEBUG_TYPE "ctx-instr-lower"
31 "A function name, assumed to be global, which will be treated as the "
32 "root of an interesting graph, which will be profiled independently "
33 "from other similar graphs."));
42static auto StartCtx =
"__llvm_ctx_profile_start_context";
43static auto ReleaseCtx =
"__llvm_ctx_profile_release_context";
44static auto GetCtx =
"__llvm_ctx_profile_get_context";
51class CtxInstrumentationLowerer final {
54 Type *ContextNodeTy =
nullptr;
55 Type *ContextRootTy =
nullptr;
74std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(
const Function &
F) {
77 for (
const auto &BB :
F) {
78 for (
const auto &
I : BB) {
79 if (
const auto *Incr = dyn_cast<InstrProfIncrementInst>(&
I)) {
81 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
82 assert((!NumCounters || V == NumCounters) &&
83 "expected all llvm.instrprof.increment[.step] intrinsics to "
84 "have the same total nr of counters parameter");
86 }
else if (
const auto *CSIntr = dyn_cast<InstrProfCallsite>(&
I)) {
88 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
89 assert((!NumCallsites || V == NumCallsites) &&
90 "expected all llvm.instrprof.callsite intrinsics to have the "
91 "same total nr of callsites parameter");
95 if (NumCounters && NumCallsites)
96 return std::make_pair(NumCounters, NumCallsites);
100 return {NumCounters, NumCallsites};
107CtxInstrumentationLowerer::CtxInstrumentationLowerer(
Module &M,
110 auto *
PointerTy = PointerType::get(
M.getContext(), 0);
134 if (
const auto *
F =
M.getFunction(Fname)) {
135 if (
F->isDeclaration())
137 auto *
G =
M.getOrInsertGlobal(Fname +
"_ctx_root", ContextRootTy);
138 cast<GlobalVariable>(
G)->setInitializer(
140 ContextRootMap.insert(std::make_pair(
F,
G));
141 for (
const auto &BB : *
F)
142 for (
const auto &
I : BB)
143 if (
const auto *CB = dyn_cast<CallBase>(&
I))
144 if (CB->isMustTailCall()) {
145 M.getContext().emitError(
146 "The function " + Fname +
147 " was indicated as a context root, but it features musttail "
148 "calls, which is not supported.");
155 M.getOrInsertFunction(
185 CallsiteInfoTLS->setThreadLocal(
true);
196 CtxInstrumentationLowerer Lowerer(M,
MAM);
197 bool Changed =
false;
199 Changed |= Lowerer.lowerFunction(
F);
203bool CtxInstrumentationLowerer::lowerFunction(
Function &
F) {
204 if (
F.isDeclaration())
210 auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(
F);
212 Value *Context =
nullptr;
213 Value *RealContext =
nullptr;
216 Value *TheRootContext =
nullptr;
217 Value *ExpectedCalleeTLSAddr =
nullptr;
218 Value *CallsiteInfoTLSAddr =
nullptr;
220 auto &Head =
F.getEntryBlock();
221 for (
auto &
I : Head) {
223 if (
auto *Mark = dyn_cast<InstrProfIncrementInst>(&
I)) {
224 assert(Mark->getIndex()->isZero());
228 Guid = Builder.getInt64(
235 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
236 ArrayType::get(Builder.getPtrTy(), NumCallsites)});
242 auto Iter = ContextRootMap.find(&
F);
243 if (Iter != ContextRootMap.end()) {
244 TheRootContext = Iter->second;
245 Context = Builder.CreateCall(
246 StartCtx, {TheRootContext,
Guid, Builder.getInt32(NumCounters),
247 Builder.getInt32(NumCallsites)});
252 Builder.CreateCall(GetCtx, {&
F,
Guid, Builder.getInt32(NumCounters),
253 Builder.getInt32(NumCallsites)});
259 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
260 if (NumCallsites > 0) {
263 auto *
Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
265 ExpectedCalleeTLSAddr = Builder.CreateGEP(
267 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {
Index});
268 CallsiteInfoTLSAddr = Builder.CreateGEP(
269 Builder.getInt32Ty(),
270 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {
Index});
277 RealContext = Builder.CreateIntToPtr(
278 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
287 <<
"Function doesn't have instrumentation, skipping";
292 bool ContextWasReleased =
false;
295 if (
auto *Instr = dyn_cast<InstrProfCntrInstBase>(&
I)) {
297 switch (
Instr->getIntrinsicID()) {
298 case llvm::Intrinsic::instrprof_increment:
299 case llvm::Intrinsic::instrprof_increment_step: {
302 auto *AsStep = cast<InstrProfIncrementInst>(Instr);
303 auto *
GEP = Builder.CreateGEP(
304 ThisContextType, RealContext,
305 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
307 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(),
GEP),
311 case llvm::Intrinsic::instrprof_callsite:
315 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
316 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
328 Builder.CreateGEP(ThisContextType, Context,
329 {Builder.getInt32(0), Builder.getInt32(2),
330 CSIntrinsic->getIndex()}),
331 CallsiteInfoTLSAddr,
true);
335 }
else if (TheRootContext && isa<ReturnInst>(
I)) {
338 Builder.CreateCall(ReleaseCtx, {TheRootContext});
339 ContextWasReleased =
true;
347 if (TheRootContext && !ContextWasReleased)
348 F.getContext().emitError(
349 "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
350 "instructions above which to release the context: " +
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
static cl::list< std::string > ContextRoots("profile-context-root", cl::Hidden, cl::desc("A function name, assumed to be global, which will be treated as the " "root of an interesting graph, which will be profiled independently " "from other similar graphs."))
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
static uint64_t getGUID(const Function &F)
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
@ HiddenVisibility
The GV is hidden.
@ ExternalLinkage
Externally visible function.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
A Module instance is used to store all the information related to an LLVM module.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
static bool isCtxIRPGOInstrEnabled()
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Class to represent struct types.
static StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
The instances of the Type class are immutable: once they are created, they are never changed.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
LLVM Value Representation.
Pass manager infrastructure for declaring and invalidating analyses.
static auto ExpectedCalleeTLS
NodeAddr< InstrNode * > Instr
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...