LLVM 20.0.0git
PGOCtxProfLowering.cpp
Go to the documentation of this file.
1//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9
13#include "llvm/IR/Analysis.h"
15#include "llvm/IR/IRBuilder.h"
18#include "llvm/IR/Module.h"
19#include "llvm/IR/PassManager.h"
22#include <utility>
23
24using namespace llvm;
25
26#define DEBUG_TYPE "ctx-instr-lower"
27
29 "profile-context-root", cl::Hidden,
31 "A function name, assumed to be global, which will be treated as the "
32 "root of an interesting graph, which will be profiled independently "
33 "from other similar graphs."));
34
36 return !ContextRoots.empty();
37}
38
39// the names of symbols we expect in compiler-rt. Using a namespace for
40// readability.
42static auto StartCtx = "__llvm_ctx_profile_start_context";
43static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
44static auto GetCtx = "__llvm_ctx_profile_get_context";
45static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
46static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
47} // namespace CompilerRtAPINames
48
49namespace {
50// The lowering logic and state.
51class CtxInstrumentationLowerer final {
52 Module &M;
54 Type *ContextNodeTy = nullptr;
55 Type *ContextRootTy = nullptr;
56
58 Function *StartCtx = nullptr;
59 Function *GetCtx = nullptr;
60 Function *ReleaseCtx = nullptr;
61 GlobalVariable *ExpectedCalleeTLS = nullptr;
62 GlobalVariable *CallsiteInfoTLS = nullptr;
63
64public:
65 CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
66 // return true if lowering happened (i.e. a change was made)
67 bool lowerFunction(Function &F);
68};
69
70// llvm.instrprof.increment[.step] captures the total number of counters as one
71// of its parameters, and llvm.instrprof.callsite captures the total number of
72// callsites. Those values are the same for instances of those intrinsics in
73// this function. Find the first instance of each and return them.
74std::pair<uint32_t, uint32_t> getNumCountersAndCallsites(const Function &F) {
75 uint32_t NumCounters = 0;
76 uint32_t NumCallsites = 0;
77 for (const auto &BB : F) {
78 for (const auto &I : BB) {
79 if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
80 uint32_t V =
81 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
82 assert((!NumCounters || V == NumCounters) &&
83 "expected all llvm.instrprof.increment[.step] intrinsics to "
84 "have the same total nr of counters parameter");
85 NumCounters = V;
86 } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
87 uint32_t V =
88 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
89 assert((!NumCallsites || V == NumCallsites) &&
90 "expected all llvm.instrprof.callsite intrinsics to have the "
91 "same total nr of callsites parameter");
92 NumCallsites = V;
93 }
94#if NDEBUG
95 if (NumCounters && NumCallsites)
96 return std::make_pair(NumCounters, NumCallsites);
97#endif
98 }
99 }
100 return {NumCounters, NumCallsites};
101}
102} // namespace
103
104// set up tie-in with compiler-rt.
105// NOTE!!!
106// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
107CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
109 : M(M), MAM(MAM) {
110 auto *PointerTy = PointerType::get(M.getContext(), 0);
111 auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
112 auto *I32Ty = Type::getInt32Ty(M.getContext());
113 auto *I64Ty = Type::getInt64Ty(M.getContext());
114
115 // The ContextRoot type
116 ContextRootTy =
117 StructType::get(M.getContext(), {
118 PointerTy, /*FirstNode*/
119 PointerTy, /*FirstMemBlock*/
120 PointerTy, /*CurrentMem*/
121 SanitizerMutexType, /*Taken*/
122 });
123 // The Context header.
124 ContextNodeTy = StructType::get(M.getContext(), {
125 I64Ty, /*Guid*/
126 PointerTy, /*Next*/
127 I32Ty, /*NumCounters*/
128 I32Ty, /*NumCallsites*/
129 });
130
131 // Define a global for each entrypoint. We'll reuse the entrypoint's name as
132 // prefix. We assume the entrypoint names to be unique.
133 for (const auto &Fname : ContextRoots) {
134 if (const auto *F = M.getFunction(Fname)) {
135 if (F->isDeclaration())
136 continue;
137 auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
138 cast<GlobalVariable>(G)->setInitializer(
139 Constant::getNullValue(ContextRootTy));
140 ContextRootMap.insert(std::make_pair(F, G));
141 for (const auto &BB : *F)
142 for (const auto &I : BB)
143 if (const auto *CB = dyn_cast<CallBase>(&I))
144 if (CB->isMustTailCall()) {
145 M.getContext().emitError(
146 "The function " + Fname +
147 " was indicated as a context root, but it features musttail "
148 "calls, which is not supported.");
149 }
150 }
151 }
152
153 // Declare the functions we will call.
154 StartCtx = cast<Function>(
155 M.getOrInsertFunction(
157 FunctionType::get(PointerTy,
158 {PointerTy, /*ContextRoot*/
159 I64Ty, /*Guid*/ I32Ty,
160 /*NumCounters*/ I32Ty /*NumCallsites*/},
161 false))
162 .getCallee());
163 GetCtx = cast<Function>(
164 M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
165 FunctionType::get(PointerTy,
166 {PointerTy, /*Callee*/
167 I64Ty, /*Guid*/
168 I32Ty, /*NumCounters*/
169 I32Ty}, /*NumCallsites*/
170 false))
171 .getCallee());
172 ReleaseCtx = cast<Function>(
173 M.getOrInsertFunction(CompilerRtAPINames::ReleaseCtx,
174 FunctionType::get(Type::getVoidTy(M.getContext()),
175 {
176 PointerTy, /*ContextRoot*/
177 },
178 false))
179 .getCallee());
180
181 // Declare the TLSes we will need to use.
182 CallsiteInfoTLS =
185 CallsiteInfoTLS->setThreadLocal(true);
186 CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
190 ExpectedCalleeTLS->setThreadLocal(true);
192}
193
196 CtxInstrumentationLowerer Lowerer(M, MAM);
197 bool Changed = false;
198 for (auto &F : M)
199 Changed |= Lowerer.lowerFunction(F);
200 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
201}
202
203bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
204 if (F.isDeclaration())
205 return false;
206 auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
208
209 Value *Guid = nullptr;
210 auto [NumCounters, NumCallsites] = getNumCountersAndCallsites(F);
211
212 Value *Context = nullptr;
213 Value *RealContext = nullptr;
214
215 StructType *ThisContextType = nullptr;
216 Value *TheRootContext = nullptr;
217 Value *ExpectedCalleeTLSAddr = nullptr;
218 Value *CallsiteInfoTLSAddr = nullptr;
219
220 auto &Head = F.getEntryBlock();
221 for (auto &I : Head) {
222 // Find the increment intrinsic in the entry basic block.
223 if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
224 assert(Mark->getIndex()->isZero());
225
226 IRBuilder<> Builder(Mark);
227
228 Guid = Builder.getInt64(
229 AssignGUIDPass::getGUID(cast<Function>(*Mark->getNameValue())));
230 // The type of the context of this function is now knowable since we have
231 // NumCallsites and NumCounters. We delcare it here because it's more
232 // convenient - we have the Builder.
233 ThisContextType = StructType::get(
234 F.getContext(),
235 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
236 ArrayType::get(Builder.getPtrTy(), NumCallsites)});
237 // Figure out which way we obtain the context object for this function -
238 // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
239 // former case, we also set TheRootContext since we need to release it
240 // at the end (plus it can be used to know if we have an entrypoint or a
241 // regular function)
242 auto Iter = ContextRootMap.find(&F);
243 if (Iter != ContextRootMap.end()) {
244 TheRootContext = Iter->second;
245 Context = Builder.CreateCall(
246 StartCtx, {TheRootContext, Guid, Builder.getInt32(NumCounters),
247 Builder.getInt32(NumCallsites)});
248 ORE.emit(
249 [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
250 } else {
251 Context =
252 Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NumCounters),
253 Builder.getInt32(NumCallsites)});
254 ORE.emit([&] {
255 return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
256 });
257 }
258 // The context could be scratch.
259 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
260 if (NumCallsites > 0) {
261 // Figure out which index of the TLS 2-element buffers to use.
262 // Scratch context => we use index == 1. Real contexts => index == 0.
263 auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
264 // The GEPs corresponding to that index, in the respective TLS.
265 ExpectedCalleeTLSAddr = Builder.CreateGEP(
266 PointerType::getUnqual(F.getContext()),
267 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
268 CallsiteInfoTLSAddr = Builder.CreateGEP(
269 Builder.getInt32Ty(),
270 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
271 }
272 // Because the context pointer may have LSB set (to indicate scratch),
273 // clear it for the value we use as base address for the counter vector.
274 // This way, if later we want to have "real" (not clobbered) buffers
275 // acting as scratch, the lowering (at least this part of it that deals
276 // with counters) stays the same.
277 RealContext = Builder.CreateIntToPtr(
278 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
279 PointerType::getUnqual(F.getContext()));
280 I.eraseFromParent();
281 break;
282 }
283 }
284 if (!Context) {
285 ORE.emit([&] {
286 return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
287 << "Function doesn't have instrumentation, skipping";
288 });
289 return false;
290 }
291
292 bool ContextWasReleased = false;
293 for (auto &BB : F) {
294 for (auto &I : llvm::make_early_inc_range(BB)) {
295 if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
296 IRBuilder<> Builder(Instr);
297 switch (Instr->getIntrinsicID()) {
298 case llvm::Intrinsic::instrprof_increment:
299 case llvm::Intrinsic::instrprof_increment_step: {
300 // Increments (or increment-steps) are just a typical load - increment
301 // - store in the RealContext.
302 auto *AsStep = cast<InstrProfIncrementInst>(Instr);
303 auto *GEP = Builder.CreateGEP(
304 ThisContextType, RealContext,
305 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
306 Builder.CreateStore(
307 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
308 AsStep->getStep()),
309 GEP);
310 } break;
311 case llvm::Intrinsic::instrprof_callsite:
312 // callsite lowering: write the called value in the expected callee
313 // TLS we treat the TLS as volatile because of signal handlers and to
314 // avoid these being moved away from the callsite they decorate.
315 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
316 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
317 true);
318 // write the GEP of the slot in the sub-contexts portion of the
319 // context in TLS. Now, here, we use the actual Context value - as
320 // returned from compiler-rt - which may have the LSB set if the
321 // Context was scratch. Since the header of the context object and
322 // then the values are all 8-aligned (or, really, insofar as we care,
323 // they are even) - if the context is scratch (meaning, an odd value),
324 // so will the GEP. This is important because this is then visible to
325 // compiler-rt which will produce scratch contexts for callers that
326 // have a scratch context.
327 Builder.CreateStore(
328 Builder.CreateGEP(ThisContextType, Context,
329 {Builder.getInt32(0), Builder.getInt32(2),
330 CSIntrinsic->getIndex()}),
331 CallsiteInfoTLSAddr, true);
332 break;
333 }
334 I.eraseFromParent();
335 } else if (TheRootContext && isa<ReturnInst>(I)) {
336 // Remember to release the context if we are an entrypoint.
337 IRBuilder<> Builder(&I);
338 Builder.CreateCall(ReleaseCtx, {TheRootContext});
339 ContextWasReleased = true;
340 }
341 }
342 }
343 // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
344 // to disallow this, (so this then stays as an error), another is to detect
345 // that and then do a wrapper or disallow the tail call. This only affects
346 // instrumentation, when we want to detect the call graph.
347 if (TheRootContext && !ContextWasReleased)
348 F.getContext().emitError(
349 "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
350 "instructions above which to release the context: " +
351 F.getName());
352 return true;
353}
#define DEBUG_TYPE
Hexagon Common GEP
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
static cl::list< std::string > ContextRoots("profile-context-root", cl::Hidden, cl::desc("A function name, assumed to be global, which will be treated as the " "root of an interesting graph, which will be profiled independently " "from other similar graphs."))
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
static uint64_t getGUID(const Function &F)
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
@ HiddenVisibility
The GV is hidden.
Definition: GlobalValue.h:68
@ ExternalLinkage
Externally visible function.
Definition: GlobalValue.h:52
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2697
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:686
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
Class to represent struct types.
Definition: DerivedTypes.h:218
static StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
Definition: Type.cpp:406
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
LLVM Value Representation.
Definition: Value.h:74
Pass manager infrastructure for declaring and invalidating analyses.
NodeAddr< InstrNode * > Instr
Definition: RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void * PointerTy
Definition: GenericValue.h:21
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657