LLVM 19.0.0git
PGOCtxProfLowering.cpp
Go to the documentation of this file.
1//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9
12#include "llvm/IR/Analysis.h"
14#include "llvm/IR/IRBuilder.h"
17#include "llvm/IR/PassManager.h"
19#include <utility>
20
21using namespace llvm;
22
23#define DEBUG_TYPE "ctx-instr-lower"
24
// Arguments of the `ContextRoots` command-line option: the flag name
// ("profile-context-root"), the cl::Hidden marker, and the pieces of the
// option's description string.
// NOTE(review): the line that opens this declaration is not part of this
// listing - confirm against the original file.
26 "profile-context-root", cl::Hidden,
28 "A function name, assumed to be global, which will be treated as the "
29 "root of an interesting graph, which will be profiled independently "
30 "from other similar graphs."));
31
// Body of a predicate that reports whether at least one context root was
// given, i.e. whether `ContextRoots` is non-empty.
// NOTE(review): the signature line of this function is not part of this
// listing - confirm against the original file.
33 return !ContextRoots.empty();
34}
35
36// The names of the symbols we expect compiler-rt to provide. They are grouped
37// in a namespace for readability.
// NOTE(review): the line opening the namespace is not part of this listing;
// only its closing brace is visible below.
// Runtime functions. GetCtx is looked up by this name via
// Module::getOrInsertFunction in the lowerer's constructor; StartCtx and
// ReleaseCtx are presumably used the same way - confirm.
39static auto StartCtx = "__llvm_ctx_profile_start_context";
40static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
41static auto GetCtx = "__llvm_ctx_profile_get_context";
// Presumably the names of the two thread-local variables the lowering writes
// to - the statements that use them are not visible here; confirm.
42static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
43static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
44} // namespace CompilerRtAPINames
45
46namespace {
47// The lowering logic and state.
// One instance is constructed per module (from the module and its analysis
// manager); lowerFunction() is then called for each function of that module.
// NOTE(review): the constructor and lowerFunction() also use members named
// `MAM` and `ContextRootMap`, whose declarations are not part of this listing.
48class CtxInstrumentationLowerer final {
// The module being lowered.
49 Module &M;
// Struct types describing the runtime's context node header and context root;
// both are created in the constructor.
51 Type *ContextNodeTy = nullptr;
52 Type *ContextRootTy = nullptr;
53
// Declarations of the runtime functions called by the lowered code; created
// in the constructor.
55 Function *StartCtx = nullptr;
56 Function *GetCtx = nullptr;
57 Function *ReleaseCtx = nullptr;
// Thread-local globals written at each lowered callsite: the expected callee
// and the address of the callsite's slot in the caller's context.
58 GlobalVariable *ExpectedCalleeTLS = nullptr;
59 GlobalVariable *CallsiteInfoTLS = nullptr;
60
61public:
62 CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
63 // return true if lowering happened (i.e. a change was made)
64 bool lowerFunction(Function &F);
65};
66
67// llvm.instrprof.increment[.step] captures the total number of counters as one
68// of its parameters, and llvm.instrprof.callsite captures the total number of
69// callsites. Those values are the same for instances of those intrinsics in
70// this function. Find the first instance of each and return them.
71std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
72 uint32_t NrCounters = 0;
73 uint32_t NrCallsites = 0;
74 for (const auto &BB : F) {
75 for (const auto &I : BB) {
76 if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
77 uint32_t V =
78 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
79 assert((!NrCounters || V == NrCounters) &&
80 "expected all llvm.instrprof.increment[.step] intrinsics to "
81 "have the same total nr of counters parameter");
82 NrCounters = V;
83 } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
84 uint32_t V =
85 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
86 assert((!NrCallsites || V == NrCallsites) &&
87 "expected all llvm.instrprof.callsite intrinsics to have the "
88 "same total nr of callsites parameter");
89 NrCallsites = V;
90 }
91#if NDEBUG
92 if (NrCounters && NrCallsites)
93 return std::make_pair(NrCounters, NrCallsites);
94#endif
95 }
96 }
97 return {NrCounters, NrCallsites};
98}
99} // namespace
100
101// Set up the tie-in with compiler-rt.
102// NOTE!!!
103// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
//
// The constructor prepares everything lowerFunction() uses: the struct types
// for context roots and context nodes, one zero-initialized "<name>_ctx_root"
// global per requested root function defined in this module, declarations of
// the runtime functions, and the thread-local globals.
// NOTE(review): several lines of this definition (the second signature line,
// the name arguments of two getOrInsertFunction calls, and the statements that
// create the TLS globals) are not part of this listing.
104CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
106 : M(M), MAM(MAM) {
107 auto *PointerTy = PointerType::get(M.getContext(), 0);
// The root's mutex field is described as a single i8.
108 auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
109 auto *I32Ty = Type::getInt32Ty(M.getContext());
110 auto *I64Ty = Type::getInt64Ty(M.getContext());
111
112 // The ContextRoot type
113 ContextRootTy =
114 StructType::get(M.getContext(), {
115 PointerTy, /*FirstNode*/
116 PointerTy, /*FirstMemBlock*/
117 PointerTy, /*CurrentMem*/
118 SanitizerMutexType, /*Taken*/
119 });
120 // The Context header.
121 ContextNodeTy = StructType::get(M.getContext(), {
122 I64Ty, /*Guid*/
123 PointerTy, /*Next*/
124 I32Ty, /*NrCounters*/
125 I32Ty, /*NrCallsites*/
126 });
127
128 // Define a global for each entrypoint. We'll reuse the entrypoint's name as
129 // prefix. We assume the entrypoint names to be unique.
130 for (const auto &Fname : ContextRoots) {
// Names that do not resolve to a function in this module are ignored.
131 if (const auto *F = M.getFunction(Fname)) {
// Roots that are only declared here get no root global.
132 if (F->isDeclaration())
133 continue;
134 auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
135 cast<GlobalVariable>(G)->setInitializer(
136 Constant::getNullValue(ContextRootTy));
137 ContextRootMap.insert(std::make_pair(F, G));
// Report an error through the LLVMContext for every musttail call found in
// a root function.
138 for (const auto &BB : *F)
139 for (const auto &I : BB)
140 if (const auto *CB = dyn_cast<CallBase>(&I))
141 if (CB->isMustTailCall()) {
142 M.getContext().emitError(
143 "The function " + Fname +
144 " was indicated as a context root, but it features musttail "
145 "calls, which is not supported.");
146 }
147 }
148 }
149
150 // Declare the functions we will call.
// StartCtx: (ContextRoot*, i64 Guid, i32 NrCounters, i32 NrCallsites) ->
// ContextNode*.
151 StartCtx = cast<Function>(
152 M.getOrInsertFunction(
154 FunctionType::get(ContextNodeTy->getPointerTo(),
155 {ContextRootTy->getPointerTo(), /*ContextRoot*/
156 I64Ty, /*Guid*/ I32Ty,
157 /*NrCounters*/ I32Ty /*NrCallsites*/},
158 false))
159 .getCallee());
// GetCtx: (ptr Callee, i64 Guid, i32 NrCounters, i32 NrCallsites) ->
// ContextNode*.
160 GetCtx = cast<Function>(
161 M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
162 FunctionType::get(ContextNodeTy->getPointerTo(),
163 {PointerTy, /*Callee*/
164 I64Ty, /*Guid*/
165 I32Ty, /*NrCounters*/
166 I32Ty}, /*NrCallsites*/
167 false))
168 .getCallee());
// ReleaseCtx: (ContextRoot*) -> void.
169 ReleaseCtx = cast<Function>(
170 M.getOrInsertFunction(
172 FunctionType::get(Type::getVoidTy(M.getContext()),
173 {
174 ContextRootTy->getPointerTo(), /*ContextRoot*/
175 },
176 false))
177 .getCallee());
178
179 // Declare the TLSes we will need to use.
// NOTE(review): the expressions that create the two globals are not visible
// in this listing; only part of their configuration is shown below.
180 CallsiteInfoTLS =
183 CallsiteInfoTLS->setThreadLocal(true);
184 CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
188 ExpectedCalleeTLS->setThreadLocal(true);
190}
191
// Body of the pass entry point: builds one lowerer for the module, lowers
// every function in it, and reports no analyses preserved if any function
// changed (all analyses preserved otherwise).
// NOTE(review): the signature line of this function is not part of this
// listing - confirm against the original file.
194 CtxInstrumentationLowerer Lowerer(M, MAM);
195 bool Changed = false;
196 for (auto &F : M)
197 Changed |= Lowerer.lowerFunction(F);
198 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
199}
200
// Lowers the instrprof intrinsics of F into code that obtains this function's
// context from the runtime and updates counters / callsite slots in it.
//
// Returns false, leaving F untouched, when F is a declaration or when its
// entry block has no llvm.instrprof.increment[.step]. Returns true otherwise.
//
// For a function registered as a context root, a call to ReleaseCtx is also
// inserted before every `ret`; an error is emitted if no `ret` was found.
201bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
202 if (F.isDeclaration())
203 return false;
204 auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
// NOTE(review): `ORE`, used for the remarks below, is not declared in any line
// visible in this listing - presumably obtained from FAM; confirm.
206
207 Value *Guid = nullptr;
208 auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
209
// Context is the value returned by the runtime (its LSB may be set, see
// below); RealContext is the same address with the LSB cleared.
210 Value *Context = nullptr;
211 Value *RealContext = nullptr;
212
213 StructType *ThisContextType = nullptr;
// Non-null only when F is a context root.
214 Value *TheRootContext = nullptr;
// Set only when NrCallsites > 0.
215 Value *ExpectedCalleeTLSAddr = nullptr;
216 Value *CallsiteInfoTLSAddr = nullptr;
217
218 auto &Head = F.getEntryBlock();
219 for (auto &I : Head) {
220 // Find the increment intrinsic in the entry basic block.
221 if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
222 assert(Mark->getIndex()->isZero());
223
224 IRBuilder<> Builder(Mark);
225 // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
226 Guid = Builder.getInt64(F.getGUID());
227 // The type of the context of this function is now knowable since we have
228 // NrCallsites and NrCounters. We declare it here because it's more
229 // convenient - we have the Builder.
230 ThisContextType = StructType::get(
231 F.getContext(),
232 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
233 ArrayType::get(Builder.getPtrTy(), NrCallsites)});
234 // Figure out which way we obtain the context object for this function -
235 // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
236 // former case, we also set TheRootContext since we need to release it
237 // at the end (plus it can be used to know if we have an entrypoint or a
238 // regular function)
239 auto Iter = ContextRootMap.find(&F);
240 if (Iter != ContextRootMap.end()) {
241 TheRootContext = Iter->second;
242 Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
243 Builder.getInt32(NrCounters),
244 Builder.getInt32(NrCallsites)});
245 ORE.emit(
246 [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
247 } else {
248 Context =
249 Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
250 Builder.getInt32(NrCallsites)});
251 ORE.emit([&] {
252 return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
253 });
254 }
255 // The context could be scratch.
256 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
257 if (NrCallsites > 0) {
258 // Figure out which index of the TLS 2-element buffers to use.
259 // Scratch context => we use index == 1. Real contexts => index == 0.
260 auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
261 // The GEPs corresponding to that index, in the respective TLS.
262 ExpectedCalleeTLSAddr = Builder.CreateGEP(
263 Builder.getInt8Ty()->getPointerTo(),
264 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
265 CallsiteInfoTLSAddr = Builder.CreateGEP(
266 Builder.getInt32Ty(),
267 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
268 }
269 // Because the context pointer may have LSB set (to indicate scratch),
270 // clear it for the value we use as base address for the counter vector.
271 // This way, if later we want to have "real" (not clobbered) buffers
272 // acting as scratch, the lowering (at least this part of it that deals
273 // with counters) stays the same.
274 RealContext = Builder.CreateIntToPtr(
275 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
276 ThisContextType->getPointerTo());
// The entry marker is erased here without emitting a counter update for
// it. Leaving the loop right away keeps the range-for from touching the
// erased instruction.
277 I.eraseFromParent();
278 break;
279 }
280 }
// Context is only set when the entry block contained the marker above.
281 if (!Context) {
282 ORE.emit([&] {
283 return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
284 << "Function doesn't have instrumentation, skipping";
285 });
286 return false;
287 }
288
289 bool ContextWasReleased = false;
290 for (auto &BB : F) {
// Early-inc iteration, because instructions are erased while walking.
291 for (auto &I : llvm::make_early_inc_range(BB)) {
292 if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
293 IRBuilder<> Builder(Instr);
294 switch (Instr->getIntrinsicID()) {
295 case llvm::Intrinsic::instrprof_increment:
296 case llvm::Intrinsic::instrprof_increment_step: {
297 // Increments (or increment-steps) are just a typical load - increment
298 // - store in the RealContext.
299 auto *AsStep = cast<InstrProfIncrementInst>(Instr);
300 auto *GEP = Builder.CreateGEP(
301 ThisContextType, RealContext,
302 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
303 Builder.CreateStore(
304 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
305 AsStep->getStep()),
306 GEP);
307 } break;
308 case llvm::Intrinsic::instrprof_callsite:
309 // callsite lowering: write the called value in the expected callee
310 // TLS. We treat the TLS as volatile because of signal handlers and to
311 // avoid these being moved away from the callsite they decorate.
// NOTE(review): the dyn_cast result is dereferenced without a null check;
// cast<> would assert on a mismatch instead - confirm intent.
312 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
313 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
314 true);
315 // write the GEP of the slot in the sub-contexts portion of the
316 // context in TLS. Now, here, we use the actual Context value - as
317 // returned from compiler-rt - which may have the LSB set if the
318 // Context was scratch. Since the header of the context object and
319 // then the values are all 8-aligned (or, really, insofar as we care,
320 // they are even) - if the context is scratch (meaning, an odd value),
321 // so will the GEP. This is important because this is then visible to
322 // compiler-rt which will produce scratch contexts for callers that
323 // have a scratch context.
324 Builder.CreateStore(
325 Builder.CreateGEP(ThisContextType, Context,
326 {Builder.getInt32(0), Builder.getInt32(2),
327 CSIntrinsic->getIndex()}),
328 CallsiteInfoTLSAddr, true);
329 break;
330 }
// The intrinsic is erased whether or not one of the cases above handled it.
331 I.eraseFromParent();
332 } else if (TheRootContext && isa<ReturnInst>(I)) {
333 // Remember to release the context if we are an entrypoint.
334 IRBuilder<> Builder(&I);
335 Builder.CreateCall(ReleaseCtx, {TheRootContext});
336 ContextWasReleased = true;
337 }
338 }
339 }
340 // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
341 // to disallow this, (so this then stays as an error), another is to detect
342 // that and then do a wrapper or disallow the tail call. This only affects
343 // instrumentation, when we want to detect the call graph.
344 if (TheRootContext && !ContextWasReleased)
345 F.getContext().emitError(
346 "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
347 "instructions above which to release the context: " +
348 F.getName());
349 return true;
350}
#define DEBUG_TYPE
Hexagon Common GEP
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
#define G(x, y, z)
Definition: MD5.cpp:56
LLVMContext & Context
static cl::list< std::string > ContextRoots("profile-context-root", cl::Hidden, cl::desc("A function name, assumed to be global, which will be treated as the " "root of an interesting graph, which will be profiled independently " "from other similar graphs."))
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
This header defines various interfaces for pass management in LLVM.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:321
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:473
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:370
@ HiddenVisibility
The GV is hidden.
Definition: GlobalValue.h:67
@ ExternalLinkage
Externally visible function.
Definition: GlobalValue.h:51
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2666
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:631
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
Class to represent struct types.
Definition: DerivedTypes.h:216
static StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
Definition: Type.cpp:373
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
LLVM Value Representation.
Definition: Value.h:74
Pass manager infrastructure for declaring and invalidating analyses.
NodeAddr< InstrNode * > Instr
Definition: RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void * PointerTy
Definition: GenericValue.h:21
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:656