LLVM 20.0.0git
ProfDataUtils.cpp
Go to the documentation of this file.
1//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with Profiling Metadata.
10//
11//===----------------------------------------------------------------------===//
12
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/Function.h"
18#include "llvm/IR/LLVMContext.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/IR/Metadata.h"
22
23using namespace llvm;
24
25namespace {
26
27// MD_prof nodes have the following layout
28//
29// In general:
30// { String name, Array of i32 }
31//
32// In terms of Types:
33// { MDString, [i32, i32, ...]}
34//
35// Concretely for Branch Weights
36// { "branch_weights", [i32 1, i32 10000]}
37//
38// We maintain some constants here to ensure that we access the branch weights
39// correctly, and can change the behavior in the future if the layout changes
40
41// the minimum number of operands for MD_prof nodes with branch weights
42constexpr unsigned MinBWOps = 3;
43
44// the minimum number of operands for MD_prof nodes with value profiles
45constexpr unsigned MinVPOps = 5;
46
47// We may want to add support for other MD_prof types, so provide an abstraction
48// for checking the metadata type.
49bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
50 // TODO: This routine may be simplified if MD_prof used an enum instead of a
51 // string to differentiate the types of MD_prof nodes.
52 if (!ProfData || !Name || MinOps < 2)
53 return false;
54
55 unsigned NOps = ProfData->getNumOperands();
56 if (NOps < MinOps)
57 return false;
58
59 auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
60 if (!ProfDataName)
61 return false;
62
63 return ProfDataName->getString() == Name;
64}
65
66template <typename T,
67 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
68static void extractFromBranchWeightMD(const MDNode *ProfileData,
69 SmallVectorImpl<T> &Weights) {
70 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
71
72 unsigned NOps = ProfileData->getNumOperands();
73 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
74 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
75 Weights.resize(NOps - WeightsIdx);
76
77 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
78 ConstantInt *Weight =
79 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
80 assert(Weight && "Malformed branch_weight in MD_prof node");
81 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
82 "Too many bits for MD_prof branch_weight");
83 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
84 }
85}
86
87} // namespace
88
89namespace llvm {
90
91bool hasProfMD(const Instruction &I) {
92 return I.hasMetadata(LLVMContext::MD_prof);
93}
94
95bool isBranchWeightMD(const MDNode *ProfileData) {
96 return isTargetMD(ProfileData, "branch_weights", MinBWOps);
97}
98
99bool isValueProfileMD(const MDNode *ProfileData) {
100 return isTargetMD(ProfileData, "VP", MinVPOps);
101}
102
104 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
105 return isBranchWeightMD(ProfileData);
106}
107
109 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
110 // Value profiles record count-type information.
111 if (isValueProfileMD(ProfileData))
112 return true;
113 // Conservatively assume non CallBase instruction only get taken/not-taken
114 // branch probability, so not interpret them as count.
115 return isa<CallBase>(I) && !isBranchWeightMD(ProfileData);
116}
117
120}
121
123 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
124 return hasBranchWeightOrigin(ProfileData);
125}
126
127bool hasBranchWeightOrigin(const MDNode *ProfileData) {
128 if (!isBranchWeightMD(ProfileData))
129 return false;
130 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(1));
131 // NOTE: if we ever have more types of branch weight provenance,
132 // we need to check the string value is "expected". For now, we
133 // supply a more generic API, and avoid the spurious comparisons.
134 assert(ProfDataName == nullptr || ProfDataName->getString() == "expected");
135 return ProfDataName != nullptr;
136}
137
138unsigned getBranchWeightOffset(const MDNode *ProfileData) {
139 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
140}
141
142unsigned getNumBranchWeights(const MDNode &ProfileData) {
143 return ProfileData.getNumOperands() - getBranchWeightOffset(&ProfileData);
144}
145
147 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
148 if (!isBranchWeightMD(ProfileData))
149 return nullptr;
150 return ProfileData;
151}
152
154 auto *ProfileData = getBranchWeightMDNode(I);
155 if (ProfileData && getNumBranchWeights(*ProfileData) == I.getNumSuccessors())
156 return ProfileData;
157 return nullptr;
158}
159
160void extractFromBranchWeightMD32(const MDNode *ProfileData,
161 SmallVectorImpl<uint32_t> &Weights) {
162 extractFromBranchWeightMD(ProfileData, Weights);
163}
164
165void extractFromBranchWeightMD64(const MDNode *ProfileData,
166 SmallVectorImpl<uint64_t> &Weights) {
167 extractFromBranchWeightMD(ProfileData, Weights);
168}
169
170bool extractBranchWeights(const MDNode *ProfileData,
171 SmallVectorImpl<uint32_t> &Weights) {
172 if (!isBranchWeightMD(ProfileData))
173 return false;
174 extractFromBranchWeightMD(ProfileData, Weights);
175 return true;
176}
177
179 SmallVectorImpl<uint32_t> &Weights) {
180 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
181 return extractBranchWeights(ProfileData, Weights);
182}
183
185 uint64_t &FalseVal) {
186 assert((I.getOpcode() == Instruction::Br ||
187 I.getOpcode() == Instruction::Select) &&
188 "Looking for branch weights on something besides branch, select, or "
189 "switch");
190
192 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
193 if (!extractBranchWeights(ProfileData, Weights))
194 return false;
195
196 if (Weights.size() > 2)
197 return false;
198
199 TrueVal = Weights[0];
200 FalseVal = Weights[1];
201 return true;
202}
203
204bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
205 TotalVal = 0;
206 if (!ProfileData)
207 return false;
208
209 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
210 if (!ProfDataName)
211 return false;
212
213 if (ProfDataName->getString() == "branch_weights") {
214 unsigned Offset = getBranchWeightOffset(ProfileData);
215 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
216 auto *V = mdconst::extract<ConstantInt>(ProfileData->getOperand(Idx));
217 TotalVal += V->getValue().getZExtValue();
218 }
219 return true;
220 }
221
222 if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
223 TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
224 ->getValue()
225 .getZExtValue();
226 return true;
227 }
228 return false;
229}
230
232 return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
233}
234
236 bool IsExpected) {
237 MDBuilder MDB(I.getContext());
238 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
239 I.setMetadata(LLVMContext::MD_prof, BranchWeights);
240}
241
243 assert(T != 0 && "Caller should guarantee");
244 auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
245 if (ProfileData == nullptr)
246 return;
247
248 auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
249 if (!ProfDataName || (ProfDataName->getString() != "branch_weights" &&
250 ProfDataName->getString() != "VP"))
251 return;
252
253 if (!hasCountTypeMD(I))
254 return;
255
256 LLVMContext &C = I.getContext();
257
258 MDBuilder MDB(C);
260 Vals.push_back(ProfileData->getOperand(0));
261 APInt APS(128, S), APT(128, T);
262 if (ProfDataName->getString() == "branch_weights" &&
263 ProfileData->getNumOperands() > 0) {
264 // Using APInt::div may be expensive, but most cases should fit 64 bits.
265 APInt Val(128,
266 mdconst::dyn_extract<ConstantInt>(
267 ProfileData->getOperand(getBranchWeightOffset(ProfileData)))
268 ->getValue()
269 .getZExtValue());
270 Val *= APS;
271 Vals.push_back(MDB.createConstant(ConstantInt::get(
272 Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
273 } else if (ProfDataName->getString() == "VP")
274 for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
275 // The first value is the key of the value profile, which will not change.
276 Vals.push_back(ProfileData->getOperand(i));
277 uint64_t Count =
278 mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
279 ->getValue()
280 .getZExtValue();
281 // Don't scale the magic number.
282 if (Count == NOMORE_ICP_MAGICNUM) {
283 Vals.push_back(ProfileData->getOperand(i + 1));
284 continue;
285 }
286 // Using APInt::div may be expensive, but most cases should fit 64 bits.
287 APInt Val(128, Count);
288 Val *= APS;
289 Vals.push_back(MDB.createConstant(ConstantInt::get(
290 Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
291 }
292 I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
293}
294
295} // namespace llvm
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx. It should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx.
std::string Name
#define I(x, y, z)
Definition: MD5.cpp:58
This file contains the declarations for metadata subclasses.
This file contains the declarations for profiling metadata utility functions.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallVector class.
Class for arbitrary precision integers.
Definition: APInt.h:78
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1547
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1492
uint64_t getLimitedValue(uint64_t Limit=UINT64_MAX) const
If this value is smaller than the specified limit, return it, otherwise return the limit value.
Definition: APInt.h:475
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:157
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:148
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
ConstantAsMetadata * createConstant(Constant *C)
Return the given constant as metadata.
Definition: MDBuilder.cpp:24
MDNode * createBranchWeights(uint32_t TrueWeight, uint32_t FalseWeight, bool IsExpected=false)
Return metadata containing two branch weights.
Definition: MDBuilder.cpp:37
Metadata node.
Definition: Metadata.h:1069
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1430
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1543
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1436
size_t size() const
Definition: SmallVector.h:78
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:573
void resize(size_type N)
Definition: SmallVector.h:638
void push_back(const T &Elt)
Definition: SmallVector.h:413
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:480
bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights)
Retrieve the total of all weights from MD_prof data.
unsigned getBranchWeightOffset(const MDNode *ProfileData)
Return the offset to the first branch weight data.
bool isBranchWeightMD(const MDNode *ProfileData)
Checks if an MDNode contains Branch Weight Metadata.
MDNode * getBranchWeightMDNode(const Instruction &I)
Get the branch weights metadata node.
bool hasBranchWeightOrigin(const Instruction &I)
Check if Branch Weight Metadata has an "expected" field from an llvm.expect* intrinsic.
void setBranchWeights(Instruction &I, ArrayRef< uint32_t > Weights, bool IsExpected)
Create a new branch_weights metadata node and add or overwrite a prof metadata reference to instructi...
MDNode * getValidBranchWeightMDNode(const Instruction &I)
Get the valid branch weights metadata node.
bool hasValidBranchWeightMD(const Instruction &I)
Checks if an instruction has valid Branch Weight Metadata.
bool isValueProfileMD(const MDNode *ProfileData)
bool hasCountTypeMD(const Instruction &I)
unsigned getNumBranchWeights(const MDNode &ProfileData)
void extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl< uint32_t > &Weights)
Faster version of extractBranchWeights() that skips checks and must only be called with "branch_weigh...
bool hasProfMD(const Instruction &I)
Checks if an Instruction has MD_prof Metadata.
bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl< uint32_t > &Weights)
Extract branch weights from MD_prof metadata.
bool hasBranchWeightMD(const Instruction &I)
Checks if an instruction has Branch Weight Metadata.
const uint64_t NOMORE_ICP_MAGICNUM
Magic number in the value profile metadata showing a target has been promoted for the instruction and...
Definition: Metadata.h:57
void scaleProfData(Instruction &I, uint64_t S, uint64_t T)
Scales the profile data attached to 'I' by the ratio S/T.
void extractFromBranchWeightMD64(const MDNode *ProfileData, SmallVectorImpl< uint64_t > &Weights)
Faster version of extractBranchWeights() that skips checks and must only be called with "branch_weigh...