LLVM 20.0.0git
PGOCtxProfFlattening.cpp
Go to the documentation of this file.
1//===- PGOCtxProfFlattening.cpp - Contextual Instr. Flattening ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Flattens the contextual profile and lowers it to MD_prof.
10// This should happen after all IPO (which is assumed to have maintained the
11// contextual profile) happened. Flattening consists of summing the values at
12// the same index of the counters belonging to all the contexts of a function.
13// The lowering consists of materializing the counter values to function
14// entrypoint counts and branch probabilities.
15//
16// This pass also removes contextual instrumentation, which has been kept around
17// to facilitate its functionality.
18//
19//===----------------------------------------------------------------------===//
20
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
26#include "llvm/IR/Analysis.h"
27#include "llvm/IR/CFG.h"
28#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Module.h"
32#include "llvm/IR/PassManager.h"
38#include <deque>
39
40using namespace llvm;
41
42namespace {
43
44class ProfileAnnotator final {
45 class BBInfo;
46 struct EdgeInfo {
47 BBInfo *const Src;
48 BBInfo *const Dest;
49 std::optional<uint64_t> Count;
50
51 explicit EdgeInfo(BBInfo &Src, BBInfo &Dest) : Src(&Src), Dest(&Dest) {}
52 };
53
54 class BBInfo {
55 std::optional<uint64_t> Count;
56 // OutEdges is dimensioned to match the number of terminator operands.
57 // Entries in the vector match the index in the terminator operand list. In
58 // some cases - see `shouldExcludeEdge` and its implementation - an entry
59 // will be nullptr.
60 // InEdges doesn't have the above constraint.
63 size_t UnknownCountOutEdges = 0;
64 size_t UnknownCountInEdges = 0;
65
66 // Pass AssumeAllKnown when we try to propagate counts from edges to BBs -
67 // because all the edge counters must be known.
68 // Return std::nullopt if there were no edges to sum. The user can decide
69 // how to interpret that.
70 std::optional<uint64_t> getEdgeSum(const SmallVector<EdgeInfo *> &Edges,
71 bool AssumeAllKnown) const {
72 std::optional<uint64_t> Sum;
73 for (const auto *E : Edges) {
74 // `Edges` may be `OutEdges`, case in which `E` could be nullptr.
75 if (E) {
76 if (!Sum.has_value())
77 Sum = 0;
78 *Sum += (AssumeAllKnown ? *E->Count : E->Count.value_or(0U));
79 }
80 }
81 return Sum;
82 }
83
84 bool computeCountFrom(const SmallVector<EdgeInfo *> &Edges) {
85 assert(!Count.has_value());
86 Count = getEdgeSum(Edges, true);
87 return Count.has_value();
88 }
89
90 void setSingleUnknownEdgeCount(SmallVector<EdgeInfo *> &Edges) {
91 uint64_t KnownSum = getEdgeSum(Edges, false).value_or(0U);
92 uint64_t EdgeVal = *Count > KnownSum ? *Count - KnownSum : 0U;
93 EdgeInfo *E = nullptr;
94 for (auto *I : Edges)
95 if (I && !I->Count.has_value()) {
96 E = I;
97#ifdef NDEBUG
98 break;
99#else
100 assert((!E || E == I) &&
101 "Expected exactly one edge to have an unknown count, "
102 "found a second one");
103 continue;
104#endif
105 }
106 assert(E && "Expected exactly one edge to have an unknown count");
107 assert(!E->Count.has_value());
108 E->Count = EdgeVal;
109 assert(E->Src->UnknownCountOutEdges > 0);
110 assert(E->Dest->UnknownCountInEdges > 0);
111 --E->Src->UnknownCountOutEdges;
112 --E->Dest->UnknownCountInEdges;
113 }
114
115 public:
116 BBInfo(size_t NumInEdges, size_t NumOutEdges, std::optional<uint64_t> Count)
117 : Count(Count) {
118 // For in edges, we just want to pre-allocate enough space, since we know
119 // it at this stage. For out edges, we will insert edges at the indices
120 // corresponding to positions in this BB's terminator instruction, so we
121 // construct a default (nullptr values)-initialized vector. A nullptr edge
122 // corresponds to those that are excluded (see shouldExcludeEdge).
123 InEdges.reserve(NumInEdges);
124 OutEdges.resize(NumOutEdges);
125 }
126
127 bool tryTakeCountFromKnownOutEdges(const BasicBlock &BB) {
128 if (!UnknownCountOutEdges) {
129 return computeCountFrom(OutEdges);
130 }
131 return false;
132 }
133
134 bool tryTakeCountFromKnownInEdges(const BasicBlock &BB) {
135 if (!UnknownCountInEdges) {
136 return computeCountFrom(InEdges);
137 }
138 return false;
139 }
140
141 void addInEdge(EdgeInfo &Info) {
142 InEdges.push_back(&Info);
143 ++UnknownCountInEdges;
144 }
145
146 // For the out edges, we care about the position we place them in, which is
147 // the position in terminator instruction's list (at construction). Later,
148 // we build branch_weights metadata with edge frequency values matching
149 // these positions.
150 void addOutEdge(size_t Index, EdgeInfo &Info) {
151 OutEdges[Index] = &Info;
152 ++UnknownCountOutEdges;
153 }
154
155 bool hasCount() const { return Count.has_value(); }
156
157 uint64_t getCount() const { return *Count; }
158
159 bool trySetSingleUnknownInEdgeCount() {
160 if (UnknownCountInEdges == 1) {
161 setSingleUnknownEdgeCount(InEdges);
162 return true;
163 }
164 return false;
165 }
166
167 bool trySetSingleUnknownOutEdgeCount() {
168 if (UnknownCountOutEdges == 1) {
169 setSingleUnknownEdgeCount(OutEdges);
170 return true;
171 }
172 return false;
173 }
174 size_t getNumOutEdges() const { return OutEdges.size(); }
175
176 uint64_t getEdgeCount(size_t Index) const {
177 if (auto *E = OutEdges[Index])
178 return *E->Count;
179 return 0U;
180 }
181 };
182
183 Function &F;
185 // To be accessed through getBBInfo() after construction.
186 std::map<const BasicBlock *, BBInfo> BBInfos;
187 std::vector<EdgeInfo> EdgeInfos;
189
190 // This is an adaptation of PGOUseFunc::populateCounters.
191 // FIXME(mtrofin): look into factoring the code to share one implementation.
192 void propagateCounterValues(const SmallVectorImpl<uint64_t> &Counters) {
193 bool KeepGoing = true;
194 while (KeepGoing) {
195 KeepGoing = false;
196 for (const auto &BB : F) {
197 auto &Info = getBBInfo(BB);
198 if (!Info.hasCount())
199 KeepGoing |= Info.tryTakeCountFromKnownOutEdges(BB) ||
200 Info.tryTakeCountFromKnownInEdges(BB);
201 if (Info.hasCount()) {
202 KeepGoing |= Info.trySetSingleUnknownOutEdgeCount();
203 KeepGoing |= Info.trySetSingleUnknownInEdgeCount();
204 }
205 }
206 }
207 }
208 // The only criteria for exclusion is faux suspend -> exit edges in presplit
209 // coroutines. The API serves for readability, currently.
210 bool shouldExcludeEdge(const BasicBlock &Src, const BasicBlock &Dest) const {
211 return llvm::isPresplitCoroSuspendExitEdge(Src, Dest);
212 }
213
214 BBInfo &getBBInfo(const BasicBlock &BB) { return BBInfos.find(&BB)->second; }
215
216 const BBInfo &getBBInfo(const BasicBlock &BB) const {
217 return BBInfos.find(&BB)->second;
218 }
219
220 // validation function after we propagate the counters: all BBs and edges'
221 // counters must have a value.
222 bool allCountersAreAssigned() const {
223 for (const auto &BBInfo : BBInfos)
224 if (!BBInfo.second.hasCount())
225 return false;
226 for (const auto &EdgeInfo : EdgeInfos)
227 if (!EdgeInfo.Count.has_value())
228 return false;
229 return true;
230 }
231
// Validation helper, only called from an assert in assignProfileData().
// Walks the CFG breadth-first from the entry block. A block with exactly one
// successor is followed unconditionally; for a block with more successors,
// only non-excluded out edges with a count > 0 are followed.
// Returns false if a reached block without successors ends in `unreachable`,
// or if a reached block with 2+ successors has no followable out edge.
// Otherwise returns whether at least one block without successors was reached.
232 /// Check that all paths from the entry basic block that use edges with
233 /// non-zero counts arrive at a basic block with no successors (i.e. "exit")
234 bool allTakenPathsExit() const {
235 std::deque<const BasicBlock *> Worklist;
// NOTE(review): `Visited` is used below, but its declaration is not visible in
// this listing (source line 236 is absent) - presumably a set of
// `const BasicBlock *`; confirm against the real file.
237 Worklist.push_back(&F.getEntryBlock());
238 bool HitExit = false;
239 while (!Worklist.empty()) {
240 const auto *BB = Worklist.front();
241 Worklist.pop_front();
// Process each block at most once.
242 if (!Visited.insert(BB).second)
243 continue;
// No successors: this is an exit, unless the block ends in `unreachable`.
244 if (succ_size(BB) == 0) {
245 if (isa<UnreachableInst>(BB->getTerminator()))
246 return false;
247 HitExit = true;
248 continue;
249 }
// A single successor is followed without looking at edge counts.
250 if (succ_size(BB) == 1) {
251 Worklist.push_back(BB->getUniqueSuccessor());
252 continue;
253 }
254 const auto &BBInfo = getBBInfo(*BB);
255 bool HasAWayOut = false;
// `I` doubles as the successor index and the out-edge slot in BBInfo.
256 for (auto I = 0U; I < BB->getTerminator()->getNumSuccessors(); ++I) {
257 const auto *Succ = BB->getTerminator()->getSuccessor(I);
258 if (!shouldExcludeEdge(*BB, *Succ)) {
259 if (BBInfo.getEdgeCount(I) > 0) {
260 HasAWayOut = true;
261 Worklist.push_back(Succ);
262 }
263 }
264 }
265 if (!HasAWayOut)
266 return false;
267 }
268 return HitExit;
269 }
270
271 bool allNonColdSelectsHaveProfile() const {
272 for (const auto &BB : F) {
273 if (getBBInfo(BB).getCount() > 0) {
274 for (const auto &I : BB) {
275 if (const auto *SI = dyn_cast<SelectInst>(&I)) {
276 if (!SI->getMetadata(LLVMContext::MD_prof)) {
277 return false;
278 }
279 }
280 }
281 }
282 }
283 return true;
284 }
285
286public:
// Builds the per-BB and per-edge bookkeeping for `F`.
// First pass: creates one BBInfo per basic block, seeded with a counter value
// from `Counters`, with 0 for blocks terminated by `unreachable`, or with no
// count; it also counts the non-excluded edges.
// Second pass: creates one EdgeInfo per non-excluded edge and links it into
// the out edges of its source and the in edges of its destination.
// Requires `F` to be a definition and `Counters` to be non-empty.
// NOTE(review): a third parameter is not visible in this listing (source line
// 288 is absent); the initializer list below reads it as `PB`.
287 ProfileAnnotator(Function &F, const SmallVectorImpl<uint64_t> &Counters,
289 : F(F), Counters(Counters), PB(PB) {
290 assert(!F.isDeclaration());
291 assert(!Counters.empty());
292 size_t NrEdges = 0;
293 for (const auto &BB : F) {
294 std::optional<uint64_t> Count;
// NOTE(review): the head of this `if` is not visible (source line 295 is
// absent). It presumably binds `Ins` to the BB's instrumentation instruction,
// e.g. via CtxProfAnalysis::getBBInstrumentation - confirm.
296 const_cast<BasicBlock &>(BB))) {
297 auto Index = Ins->getIndex()->getZExtValue();
298 assert(Index < Counters.size() &&
299 "The index must be inside the counters vector by construction - "
300 "tripping this assertion indicates a bug in how the contextual "
301 "profile is managed by IPO transforms");
// `Index` is otherwise only read by the assert above; the cast avoids an
// unused-variable warning when asserts are compiled out.
302 (void)Index;
// NOTE(review): the index is recomputed here instead of reusing `Index`.
303 Count = Counters[Ins->getIndex()->getZExtValue()];
304 } else if (isa<UnreachableInst>(BB.getTerminator())) {
305 // The program presumably didn't crash.
306 Count = 0;
307 }
// A block matching neither case above is inserted with an unknown count.
308 auto [It, Ins] =
309 BBInfos.insert({&BB, {pred_size(&BB), succ_size(&BB), Count}});
310 (void)Ins;
311 assert(Ins && "We iterate through the function's BBs, no reason to "
312 "insert one more than once");
// Only edges that are not excluded get an EdgeInfo later on.
313 NrEdges += llvm::count_if(successors(&BB), [&](const auto *Succ) {
314 return !shouldExcludeEdge(BB, *Succ);
315 });
316 }
317 // Pre-allocate the vector, we want references to its contents to be stable.
318 EdgeInfos.reserve(NrEdges);
319 for (const auto &BB : F) {
320 auto &Info = getBBInfo(BB);
// `I` is the successor index, used as the out-edge slot in the source BBInfo.
321 for (auto I = 0U; I < BB.getTerminator()->getNumSuccessors(); ++I) {
322 const auto *Succ = BB.getTerminator()->getSuccessor(I);
323 if (!shouldExcludeEdge(BB, *Succ)) {
324 auto &EI = EdgeInfos.emplace_back(getBBInfo(BB), getBBInfo(*Succ));
325 Info.addOutEdge(I, EI);
326 getBBInfo(*Succ).addInEdge(EI);
327 }
328 }
329 }
// BBInfo objects hold raw pointers into EdgeInfos, hence the capacity check.
330 assert(EdgeInfos.capacity() == NrEdges &&
331 "The capacity of EdgeInfos should have stayed unchanged it was "
332 "populated, because we need pointers to its contents to be stable");
333 }
334
335 void setProfileForSelectInstructions(BasicBlock &BB, const BBInfo &BBInfo) {
336 if (BBInfo.getCount() == 0)
337 return;
338
339 for (auto &I : BB) {
340 if (auto *SI = dyn_cast<SelectInst>(&I)) {
341 if (auto *Step = CtxProfAnalysis::getSelectInstrumentation(*SI)) {
342 auto Index = Step->getIndex()->getZExtValue();
343 assert(Index < Counters.size() &&
344 "The index of the step instruction must be inside the "
345 "counters vector by "
346 "construction - tripping this assertion indicates a bug in "
347 "how the contextual profile is managed by IPO transforms");
348 auto TotalCount = BBInfo.getCount();
349 auto TrueCount = Counters[Index];
350 auto FalseCount =
351 (TotalCount > TrueCount ? TotalCount - TrueCount : 0U);
352 setProfMetadata(F.getParent(), SI, {TrueCount, FalseCount},
353 std::max(TrueCount, FalseCount));
354 PB.addInternalCount(TrueCount);
355 PB.addInternalCount(FalseCount);
356 }
357 }
358 }
359 }
360
361 /// Assign branch weights and function entry count. Also update the PSI
362 /// builder.
363 void assignProfileData() {
364 assert(!Counters.empty());
365 propagateCounterValues(Counters);
366 F.setEntryCount(Counters[0]);
367 PB.addEntryCount(Counters[0]);
368
369 for (auto &BB : F) {
370 const auto &BBInfo = getBBInfo(BB);
371 setProfileForSelectInstructions(BB, BBInfo);
372 if (succ_size(&BB) < 2)
373 continue;
374 auto *Term = BB.getTerminator();
375 SmallVector<uint64_t, 2> EdgeCounts(Term->getNumSuccessors(), 0);
376 uint64_t MaxCount = 0;
377
378 for (unsigned SuccIdx = 0, Size = BBInfo.getNumOutEdges(); SuccIdx < Size;
379 ++SuccIdx) {
380 uint64_t EdgeCount = BBInfo.getEdgeCount(SuccIdx);
381 if (EdgeCount > MaxCount)
382 MaxCount = EdgeCount;
383 EdgeCounts[SuccIdx] = EdgeCount;
384 PB.addInternalCount(EdgeCount);
385 }
386
387 if (MaxCount != 0)
388 setProfMetadata(F.getParent(), Term, EdgeCounts, MaxCount);
389 }
390 assert(allCountersAreAssigned() &&
391 "[ctx-prof] Expected all counters have been assigned.");
392 assert(allTakenPathsExit() &&
393 "[ctx-prof] Encountered a BB with more than one successor, where "
394 "all outgoing edges have a 0 count. This occurs in non-exiting "
395 "functions (message pumps, usually) which are not supported in the "
396 "contextual profiling case");
397 assert(allNonColdSelectsHaveProfile() &&
398 "[ctx-prof] All non-cold select instructions were expected to have "
399 "a profile.");
400 }
401};
402
// Returns true if the dominator tree of `F` reports every basic block of `F`
// as reachable from the entry block. In the visible part of the file it is
// only called from inside an assert, hence [[maybe_unused]].
403[[maybe_unused]] bool areAllBBsReachable(const Function &F,
// NOTE(review): the second parameter line is not visible in this listing
// (source line 404 is absent); the body reads it as `FAM`, presumably a
// FunctionAnalysisManager reference - confirm.
// The const_cast is needed because `F` is const here while it is passed to
// getResult as a non-const Function reference.
405 auto &DT = FAM.getResult<DominatorTreeAnalysis>(const_cast<Function &>(F));
406 return llvm::all_of(
407 F, [&](const BasicBlock &BB) { return DT.isReachableFromEntry(&BB); });
408}
409
410void clearColdFunctionProfile(Function &F) {
411 for (auto &BB : F)
412 BB.getTerminator()->setMetadata(LLVMContext::MD_prof, nullptr);
413 F.setEntryCount(0U);
414}
415
416void removeInstrumentation(Function &F) {
417 for (auto &BB : F)
418 for (auto &I : llvm::make_early_inc_range(BB))
419 if (isa<InstrProfCntrInstBase>(I))
420 I.eraseFromParent();
421}
422
423} // namespace
424
427 // Ensure in all cases the instrumentation is removed: if this module had no
428 // roots, the contextual profile would evaluate to false, but there would
429 // still be instrumentation.
430 // Note: in such cases we leave as-is any other profile info (if present -
431 // e.g. synthetic weights, etc) because it wouldn't interfere with the
432 // contextual - based one (which would be in other modules)
433 auto OnExit = llvm::make_scope_exit([&]() {
434 for (auto &F : M)
435 removeInstrumentation(F);
436 });
437 auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
438 if (!CtxProf)
440
441 const auto FlattenedProfile = CtxProf.flatten();
442
444 for (auto &F : M) {
445 if (F.isDeclaration())
446 continue;
447
448 assert(areAllBBsReachable(
450 .getManager()) &&
451 "Function has unreacheable basic blocks. The expectation was that "
452 "DCE was run before.");
453
454 auto It = FlattenedProfile.find(AssignGUIDPass::getGUID(F));
455 // If this function didn't appear in the contextual profile, it's cold.
456 if (It == FlattenedProfile.end())
457 clearColdFunctionProfile(F);
458 else {
459 ProfileAnnotator S(F, It->second, PB);
460 S.assignProfileData();
461 }
462 }
463
464 auto &PSI = MAM.getResult<ProfileSummaryAnalysis>(M);
465
466 M.setProfileSummary(PB.getSummary()->getMD(M.getContext()),
468 PSI.refresh();
470}
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
uint64_t Size
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Module.h This file contains the declarations for the Module class.
This header defines various interfaces for pass management in LLVM.
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides the interface for IR based instrumentation passes ( (profile-gen,...
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
PassBuilder PB(Machine, PassOpts->PTO, std::nullopt, &PIC)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file contains some templates that are useful if you are working with the STL at all.
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:410
static uint64_t getGUID(const Function &F)
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
const Instruction & front() const
Definition: BasicBlock.h:471
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:497
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:239
static InstrProfIncrementInst * getBBInstrumentation(BasicBlock &BB)
Get the instruction instrumenting a BB, or nullptr if not present.
static InstrProfIncrementInstStep * getSelectInstrumentation(SelectInst &SI)
Get the step instrumentation associated with a select
Implements a dense probed hash-table based set.
Definition: DenseSet.h:278
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:567
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
void setMetadata(unsigned KindID, MDNode *Node)
Set the metadata of the specified kind to the specified node.
Definition: Metadata.cpp:1679
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
An analysis pass based on the new PM to deliver ProfileSummaryInfo.
static const ArrayRef< uint32_t > DefaultCutoffs
A vector of useful cutoff values for detailed summary.
Definition: ProfileCommon.h:70
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void reserve(size_type N)
Definition: SmallVector.h:663
void resize(size_type N)
Definition: SmallVector.h:638
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
std::pair< iterator, bool > insert(const ValueT &V)
Definition: DenseSet.h:213
Pass manager infrastructure for declaring and invalidating analyses.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1739
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
auto successors(const MachineBasicBlock *BB)
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition: STLExtras.h:657
auto pred_size(const MachineBasicBlock *BB)
auto succ_size(const MachineBasicBlock *BB)
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1945
bool isPresplitCoroSuspendExitEdge(const BasicBlock &Src, const BasicBlock &Dest)
void setProfMetadata(Module *M, Instruction *TI, ArrayRef< uint64_t > EdgeCounts, uint64_t MaxCount)