LLVM 20.0.0git
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
1//===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
/// \file This pass attempts to make use of reqd_work_group_size metadata
10/// to eliminate loads from the dispatch packet and to constant fold OpenCL
11/// get_local_size-like functions.
12//
13//===----------------------------------------------------------------------===//
14
15#include "AMDGPU.h"
19#include "llvm/CodeGen/Passes.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/Function.h"
24#include "llvm/IR/IntrinsicsAMDGPU.h"
25#include "llvm/IR/MDBuilder.h"
27#include "llvm/Pass.h"
28
29#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
30
31using namespace llvm;
32
33namespace {
34
// Field offsets in hsa_kernel_dispatch_packet_t.
// Byte offsets used when the base pointer is the dispatch pointer (pre-v5
// code objects). Workgroup sizes are loaded as 2-byte fields and grid sizes
// as 4-byte fields (see the LoadSize checks in processUse).
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
45
// Field offsets to implicit kernel argument pointer.
// Byte offsets used when the base pointer is the implicit-arg pointer
// (code object v5 and above). Block counts are loaded as 4-byte fields;
// group sizes and remainders as 2-byte fields (see processUse).
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};
60
// Legacy pass-manager wrapper. The actual folding is done per call site in
// processUse(); runOnModule only walks users of the base pointer intrinsic.
class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  // This pass only folds values/metadata in place; it invalidates no
  // analyses.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};
77
78Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
79 auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
80 : Intrinsic::amdgcn_dispatch_ptr;
81 return Intrinsic::getDeclarationIfExists(&M, IntrinsicId);
82}
83
84} // end anonymous namespace
85
87 uint32_t MaxNumGroups) {
88 if (MaxNumGroups == 0 || MaxNumGroups == std::numeric_limits<uint32_t>::max())
89 return;
90
91 if (!Load->getType()->isIntegerTy(32))
92 return;
93
94 // TODO: If there is existing range metadata, preserve it if it is stricter.
95 MDBuilder MDB(Load->getContext());
96 MDNode *Range = MDB.createRange(APInt(32, 1), APInt(32, MaxNumGroups + 1));
97 Load->setMetadata(LLVMContext::MD_range, Range);
98}
99
100static bool processUse(CallInst *CI, bool IsV5OrAbove) {
101 Function *F = CI->getParent()->getParent();
102
103 auto *MD = F->getMetadata("reqd_work_group_size");
104 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
105
106 const bool HasUniformWorkGroupSize =
107 F->getFnAttribute("uniform-work-group-size").getValueAsBool();
108
109 SmallVector<unsigned> MaxNumWorkgroups =
110 AMDGPU::getIntegerVecAttribute(*F, "amdgpu-max-num-workgroups", 3);
111
112 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize &&
113 none_of(MaxNumWorkgroups, [](unsigned X) { return X != 0; }))
114 return false;
115
116 Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
117 Value *GroupSizes[3] = {nullptr, nullptr, nullptr};
118 Value *Remainders[3] = {nullptr, nullptr, nullptr};
119 Value *GridSizes[3] = {nullptr, nullptr, nullptr};
120
121 const DataLayout &DL = F->getDataLayout();
122
123 // We expect to see several GEP users, casted to the appropriate type and
124 // loaded.
125 for (User *U : CI->users()) {
126 if (!U->hasOneUse())
127 continue;
128
129 int64_t Offset = 0;
130 auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
131 auto *BCI = dyn_cast<BitCastInst>(U);
132 if (!Load && !BCI) {
134 continue;
135 Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
136 BCI = dyn_cast<BitCastInst>(*U->user_begin());
137 }
138
139 if (BCI) {
140 if (!BCI->hasOneUse())
141 continue;
142 Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
143 }
144
145 if (!Load || !Load->isSimple())
146 continue;
147
148 unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
149
150 // TODO: Handle merged loads.
151 if (IsV5OrAbove) { // Base is ImplicitArgPtr.
152 switch (Offset) {
153 case HIDDEN_BLOCK_COUNT_X:
154 if (LoadSize == 4) {
155 BlockCounts[0] = Load;
156 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[0]);
157 }
158 break;
159 case HIDDEN_BLOCK_COUNT_Y:
160 if (LoadSize == 4) {
161 BlockCounts[1] = Load;
162 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[1]);
163 }
164 break;
165 case HIDDEN_BLOCK_COUNT_Z:
166 if (LoadSize == 4) {
167 BlockCounts[2] = Load;
168 annotateGridSizeLoadWithRangeMD(Load, MaxNumWorkgroups[2]);
169 }
170 break;
171 case HIDDEN_GROUP_SIZE_X:
172 if (LoadSize == 2)
173 GroupSizes[0] = Load;
174 break;
175 case HIDDEN_GROUP_SIZE_Y:
176 if (LoadSize == 2)
177 GroupSizes[1] = Load;
178 break;
179 case HIDDEN_GROUP_SIZE_Z:
180 if (LoadSize == 2)
181 GroupSizes[2] = Load;
182 break;
183 case HIDDEN_REMAINDER_X:
184 if (LoadSize == 2)
185 Remainders[0] = Load;
186 break;
187 case HIDDEN_REMAINDER_Y:
188 if (LoadSize == 2)
189 Remainders[1] = Load;
190 break;
191 case HIDDEN_REMAINDER_Z:
192 if (LoadSize == 2)
193 Remainders[2] = Load;
194 break;
195 default:
196 break;
197 }
198 } else { // Base is DispatchPtr.
199 switch (Offset) {
200 case WORKGROUP_SIZE_X:
201 if (LoadSize == 2)
202 GroupSizes[0] = Load;
203 break;
204 case WORKGROUP_SIZE_Y:
205 if (LoadSize == 2)
206 GroupSizes[1] = Load;
207 break;
208 case WORKGROUP_SIZE_Z:
209 if (LoadSize == 2)
210 GroupSizes[2] = Load;
211 break;
212 case GRID_SIZE_X:
213 if (LoadSize == 4)
214 GridSizes[0] = Load;
215 break;
216 case GRID_SIZE_Y:
217 if (LoadSize == 4)
218 GridSizes[1] = Load;
219 break;
220 case GRID_SIZE_Z:
221 if (LoadSize == 4)
222 GridSizes[2] = Load;
223 break;
224 default:
225 break;
226 }
227 }
228 }
229
230 bool MadeChange = false;
231 if (IsV5OrAbove && HasUniformWorkGroupSize) {
232 // Under v5 __ockl_get_local_size returns the value computed by the expression:
233 //
234 // workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
235 //
236 // For functions with the attribute uniform-work-group-size=true. we can evaluate
237 // workgroup_id < hidden_block_count as true, and thus hidden_group_size is returned
238 // for __ockl_get_local_size.
239 for (int I = 0; I < 3; ++I) {
240 Value *BlockCount = BlockCounts[I];
241 if (!BlockCount)
242 continue;
243
244 using namespace llvm::PatternMatch;
245 auto GroupIDIntrin =
246 I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
247 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
248 : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
249
250 for (User *ICmp : BlockCount->users()) {
251 if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
252 m_Specific(BlockCount)))) {
253 ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
254 MadeChange = true;
255 }
256 }
257 }
258
259 // All remainders should be 0 with uniform work group size.
260 for (Value *Remainder : Remainders) {
261 if (!Remainder)
262 continue;
263 Remainder->replaceAllUsesWith(Constant::getNullValue(Remainder->getType()));
264 MadeChange = true;
265 }
266 } else if (HasUniformWorkGroupSize) { // Pre-V5.
267 // Pattern match the code used to handle partial workgroup dispatches in the
268 // library implementation of get_local_size, so the entire function can be
269 // constant folded with a known group size.
270 //
271 // uint r = grid_size - group_id * group_size;
272 // get_local_size = (r < group_size) ? r : group_size;
273 //
274 // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
275 // the grid_size is required to be a multiple of group_size). In this case:
276 //
277 // grid_size - (group_id * group_size) < group_size
278 // ->
279 // grid_size < group_size + (group_id * group_size)
280 //
281 // (grid_size / group_size) < 1 + group_id
282 //
283 // grid_size / group_size is at least 1, so we can conclude the select
284 // condition is false (except for group_id == 0, where the select result is
285 // the same).
286 for (int I = 0; I < 3; ++I) {
287 Value *GroupSize = GroupSizes[I];
288 Value *GridSize = GridSizes[I];
289 if (!GroupSize || !GridSize)
290 continue;
291
292 using namespace llvm::PatternMatch;
293 auto GroupIDIntrin =
294 I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
295 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
296 : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
297
298 for (User *U : GroupSize->users()) {
299 auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
300 if (!ZextGroupSize)
301 continue;
302
303 for (User *UMin : ZextGroupSize->users()) {
304 if (match(UMin,
305 m_UMin(m_Sub(m_Specific(GridSize),
306 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
307 m_Specific(ZextGroupSize)))) {
308 if (HasReqdWorkGroupSize) {
309 ConstantInt *KnownSize
310 = mdconst::extract<ConstantInt>(MD->getOperand(I));
311 UMin->replaceAllUsesWith(ConstantFoldIntegerCast(
312 KnownSize, UMin->getType(), false, DL));
313 } else {
314 UMin->replaceAllUsesWith(ZextGroupSize);
315 }
316
317 MadeChange = true;
318 }
319 }
320 }
321 }
322 }
323
324 // If reqd_work_group_size is set, we can replace work group size with it.
325 if (!HasReqdWorkGroupSize)
326 return MadeChange;
327
328 for (int I = 0; I < 3; I++) {
329 Value *GroupSize = GroupSizes[I];
330 if (!GroupSize)
331 continue;
332
333 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
334 GroupSize->replaceAllUsesWith(
335 ConstantFoldIntegerCast(KnownSize, GroupSize->getType(), false, DL));
336 MadeChange = true;
337 }
338
339 return MadeChange;
340}
341
342
343// TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
344// TargetPassConfig for subtarget.
345bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
346 bool MadeChange = false;
347 bool IsV5OrAbove =
349 Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);
350
351 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
352 return false;
353
355 for (auto *U : BasePtr->users()) {
356 CallInst *CI = cast<CallInst>(U);
357 if (HandledUses.insert(CI).second) {
358 if (processUse(CI, IsV5OrAbove))
359 MadeChange = true;
360 }
361 }
362
363 return MadeChange;
364}
365
366
// Register the legacy pass (no analysis dependencies between BEGIN and END).
INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

// Pass identity; the address of ID is what the pass manager keys on.
char AMDGPULowerKernelAttributes::ID = 0;
373
375 return new AMDGPULowerKernelAttributes();
376}
377
380 bool IsV5OrAbove =
382 Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);
383
384 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
385 return PreservedAnalyses::all();
386
387 for (Instruction &I : instructions(F)) {
388 if (CallInst *CI = dyn_cast<CallInst>(&I)) {
389 if (CI->getCalledFunction() == BasePtr)
390 processUse(CI, IsV5OrAbove);
391 }
392 }
393
394 return PreservedAnalyses::all();
395}
AMDGPU Kernel Attributes
#define DEBUG_TYPE
static void annotateGridSizeLoadWithRangeMD(LoadInst *Load, uint32_t MaxNumGroups)
static bool processUse(CallInst *CI, bool IsV5OrAbove)
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
This file contains the declarations for the subclasses of Constant, which represent the different flavors of constant values that live in LLVM.
static GCMetadataPrinterRegistry::Add< ErlangGCPrinter > X("erlang", "erlang-compatible garbage collector")
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
Class for arbitrary precision integers.
Definition: APInt.h:78
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signature does not match the call signature.
Definition: InstrTypes.h:1349
This class represents a function call, abstracting a target machine's calling convention.
This is the shared class of boolean and integer constants.
Definition: Constants.h:83
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:866
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:373
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:63
An instruction for reading from memory.
Definition: Instructions.h:176
MDNode * createRange(const APInt &Lo, const APInt &Hi)
Return metadata describing the range [Lo, Hi).
Definition: MDBuilder.cpp:95
Metadata node.
Definition: Metadata.h:1069
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:384
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:519
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1196
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:51
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
const ParentTy * getParent() const
Definition: ilist_node.h:32
unsigned getAMDHSACodeObjectVersion(const Module &M)
SmallVector< unsigned > getIntegerVecAttribute(const Function &F, StringRef Name, unsigned Size, unsigned DefaultVal)
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
Function * getDeclarationIfExists(Module *M, ID id, ArrayRef< Type * > Tys, FunctionType *FT=nullptr)
This version supports overloaded intrinsics.
Definition: Intrinsics.cpp:746
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:885
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
SpecificCmpClass_match< LHS, RHS, ICmpInst > m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
MaxMin_match< ICmpInst, LHS, RHS, umin_pred_ty > m_UMin(const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:480
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
ModulePass * createAMDGPULowerKernelAttributesPass()
bool none_of(R &&Range, UnaryPredicate P)
Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1753
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
Constant * ConstantFoldIntegerCast(Constant *C, Type *DestTy, bool IsSigned, const DataLayout &DL)
Constant fold a zext, sext or trunc, depending on IsSigned and whether the DestTy is wider or narrowe...
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)