LLVM 17.0.0git
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
1//===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
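/// \file This pass attempts to make use of reqd_work_group_size metadata and
/// the uniform-work-group-size attribute to eliminate loads from the dispatch
/// packet (or, for code object v5 and above, from the implicit kernel
/// arguments) and to constant fold OpenCL get_local_size-like functions.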
12//
13//===----------------------------------------------------------------------===//
14
#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
27
// Command-line name of this pass; also passed as the `arg` of the legacy
// pass registration at the bottom of the file.
#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;
31
32namespace {
33
34// Field offsets in hsa_kernel_dispatch_packet_t.
35enum DispatchPackedOffsets {
36 WORKGROUP_SIZE_X = 4,
37 WORKGROUP_SIZE_Y = 6,
38 WORKGROUP_SIZE_Z = 8,
39
40 GRID_SIZE_X = 12,
41 GRID_SIZE_Y = 16,
42 GRID_SIZE_Z = 20
43};
44
45// Field offsets to implicit kernel argument pointer.
46enum ImplicitArgOffsets {
47 HIDDEN_BLOCK_COUNT_X = 0,
48 HIDDEN_BLOCK_COUNT_Y = 4,
49 HIDDEN_BLOCK_COUNT_Z = 8,
50
51 HIDDEN_GROUP_SIZE_X = 12,
52 HIDDEN_GROUP_SIZE_Y = 14,
53 HIDDEN_GROUP_SIZE_Z = 16,
54
55 HIDDEN_REMAINDER_X = 18,
56 HIDDEN_REMAINDER_Y = 20,
57 HIDDEN_REMAINDER_Z = 22,
58};
59
60class AMDGPULowerKernelAttributes : public ModulePass {
61public:
62 static char ID;
63
64 AMDGPULowerKernelAttributes() : ModulePass(ID) {}
65
66 bool runOnModule(Module &M) override;
67
68 StringRef getPassName() const override {
69 return "AMDGPU Kernel Attributes";
70 }
71
72 void getAnalysisUsage(AnalysisUsage &AU) const override {
73 AU.setPreservesAll();
74 }
75};
76
77Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
78 auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
79 : Intrinsic::amdgcn_dispatch_ptr;
80 StringRef Name = Intrinsic::getName(IntrinsicId);
81 return M.getFunction(Name);
82}
83
84} // end anonymous namespace
85
86static bool processUse(CallInst *CI, bool IsV5OrAbove) {
87 Function *F = CI->getParent()->getParent();
88
89 auto MD = F->getMetadata("reqd_work_group_size");
90 const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
91
92 const bool HasUniformWorkGroupSize =
93 F->getFnAttribute("uniform-work-group-size").getValueAsBool();
94
95 if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
96 return false;
97
98 Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
99 Value *GroupSizes[3] = {nullptr, nullptr, nullptr};
100 Value *Remainders[3] = {nullptr, nullptr, nullptr};
101 Value *GridSizes[3] = {nullptr, nullptr, nullptr};
102
103 const DataLayout &DL = F->getParent()->getDataLayout();
104
105 // We expect to see several GEP users, casted to the appropriate type and
106 // loaded.
107 for (User *U : CI->users()) {
108 if (!U->hasOneUse())
109 continue;
110
111 int64_t Offset = 0;
112 auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
113 auto *BCI = dyn_cast<BitCastInst>(U);
114 if (!Load && !BCI) {
116 continue;
117 Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
118 BCI = dyn_cast<BitCastInst>(*U->user_begin());
119 }
120
121 if (BCI) {
122 if (!BCI->hasOneUse())
123 continue;
124 Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
125 }
126
127 if (!Load || !Load->isSimple())
128 continue;
129
130 unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
131
132 // TODO: Handle merged loads.
133 if (IsV5OrAbove) { // Base is ImplicitArgPtr.
134 switch (Offset) {
135 case HIDDEN_BLOCK_COUNT_X:
136 if (LoadSize == 4)
137 BlockCounts[0] = Load;
138 break;
139 case HIDDEN_BLOCK_COUNT_Y:
140 if (LoadSize == 4)
141 BlockCounts[1] = Load;
142 break;
143 case HIDDEN_BLOCK_COUNT_Z:
144 if (LoadSize == 4)
145 BlockCounts[2] = Load;
146 break;
147 case HIDDEN_GROUP_SIZE_X:
148 if (LoadSize == 2)
149 GroupSizes[0] = Load;
150 break;
151 case HIDDEN_GROUP_SIZE_Y:
152 if (LoadSize == 2)
153 GroupSizes[1] = Load;
154 break;
155 case HIDDEN_GROUP_SIZE_Z:
156 if (LoadSize == 2)
157 GroupSizes[2] = Load;
158 break;
159 case HIDDEN_REMAINDER_X:
160 if (LoadSize == 2)
161 Remainders[0] = Load;
162 break;
163 case HIDDEN_REMAINDER_Y:
164 if (LoadSize == 2)
165 Remainders[1] = Load;
166 break;
167 case HIDDEN_REMAINDER_Z:
168 if (LoadSize == 2)
169 Remainders[2] = Load;
170 break;
171 default:
172 break;
173 }
174 } else { // Base is DispatchPtr.
175 switch (Offset) {
176 case WORKGROUP_SIZE_X:
177 if (LoadSize == 2)
178 GroupSizes[0] = Load;
179 break;
180 case WORKGROUP_SIZE_Y:
181 if (LoadSize == 2)
182 GroupSizes[1] = Load;
183 break;
184 case WORKGROUP_SIZE_Z:
185 if (LoadSize == 2)
186 GroupSizes[2] = Load;
187 break;
188 case GRID_SIZE_X:
189 if (LoadSize == 4)
190 GridSizes[0] = Load;
191 break;
192 case GRID_SIZE_Y:
193 if (LoadSize == 4)
194 GridSizes[1] = Load;
195 break;
196 case GRID_SIZE_Z:
197 if (LoadSize == 4)
198 GridSizes[2] = Load;
199 break;
200 default:
201 break;
202 }
203 }
204 }
205
206 bool MadeChange = false;
207 if (IsV5OrAbove && HasUniformWorkGroupSize) {
208 // Under v5 __ockl_get_local_size returns the value computed by the expression:
209 //
210 // workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
211 //
212 // For functions with the attribute uniform-work-group-size=true. we can evaluate
213 // workgroup_id < hidden_block_count as true, and thus hidden_group_size is returned
214 // for __ockl_get_local_size.
215 for (int I = 0; I < 3; ++I) {
216 Value *BlockCount = BlockCounts[I];
217 if (!BlockCount)
218 continue;
219
220 using namespace llvm::PatternMatch;
221 auto GroupIDIntrin =
222 I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
223 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
224 : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
225
226 for (User *ICmp : BlockCount->users()) {
228 if (match(ICmp, m_ICmp(Pred, GroupIDIntrin, m_Specific(BlockCount)))) {
229 if (Pred != ICmpInst::ICMP_ULT)
230 continue;
231 ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
232 MadeChange = true;
233 }
234 }
235 }
236
237 // All remainders should be 0 with uniform work group size.
238 for (Value *Remainder : Remainders) {
239 if (!Remainder)
240 continue;
241 Remainder->replaceAllUsesWith(Constant::getNullValue(Remainder->getType()));
242 MadeChange = true;
243 }
244 } else if (HasUniformWorkGroupSize) { // Pre-V5.
245 // Pattern match the code used to handle partial workgroup dispatches in the
246 // library implementation of get_local_size, so the entire function can be
247 // constant folded with a known group size.
248 //
249 // uint r = grid_size - group_id * group_size;
250 // get_local_size = (r < group_size) ? r : group_size;
251 //
252 // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
253 // the grid_size is required to be a multiple of group_size). In this case:
254 //
255 // grid_size - (group_id * group_size) < group_size
256 // ->
257 // grid_size < group_size + (group_id * group_size)
258 //
259 // (grid_size / group_size) < 1 + group_id
260 //
261 // grid_size / group_size is at least 1, so we can conclude the select
262 // condition is false (except for group_id == 0, where the select result is
263 // the same).
264 for (int I = 0; I < 3; ++I) {
265 Value *GroupSize = GroupSizes[I];
266 Value *GridSize = GridSizes[I];
267 if (!GroupSize || !GridSize)
268 continue;
269
270 using namespace llvm::PatternMatch;
271 auto GroupIDIntrin =
272 I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
273 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
274 : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
275
276 for (User *U : GroupSize->users()) {
277 auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
278 if (!ZextGroupSize)
279 continue;
280
281 for (User *UMin : ZextGroupSize->users()) {
282 if (match(UMin,
283 m_UMin(m_Sub(m_Specific(GridSize),
284 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
285 m_Specific(ZextGroupSize)))) {
286 if (HasReqdWorkGroupSize) {
287 ConstantInt *KnownSize
288 = mdconst::extract<ConstantInt>(MD->getOperand(I));
289 UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
290 KnownSize, UMin->getType(), false));
291 } else {
292 UMin->replaceAllUsesWith(ZextGroupSize);
293 }
294
295 MadeChange = true;
296 }
297 }
298 }
299 }
300 }
301
302 // If reqd_work_group_size is set, we can replace work group size with it.
303 if (!HasReqdWorkGroupSize)
304 return MadeChange;
305
306 for (int I = 0; I < 3; I++) {
307 Value *GroupSize = GroupSizes[I];
308 if (!GroupSize)
309 continue;
310
311 ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
312 GroupSize->replaceAllUsesWith(
313 ConstantExpr::getIntegerCast(KnownSize, GroupSize->getType(), false));
314 MadeChange = true;
315 }
316
317 return MadeChange;
318}
319
320
321// TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
322// TargetPassConfig for subtarget.
323bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
324 bool MadeChange = false;
325 bool IsV5OrAbove = AMDGPU::getCodeObjectVersion(M) >= AMDGPU::AMDHSA_COV5;
326 Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);
327
328 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
329 return false;
330
332 for (auto *U : BasePtr->users()) {
333 CallInst *CI = cast<CallInst>(U);
334 if (HandledUses.insert(CI).second) {
335 if (processUse(CI, IsV5OrAbove))
336 MadeChange = true;
337 }
338 }
339
340 return MadeChange;
341}
342
343
344INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
345 "AMDGPU Kernel Attributes", false, false)
346INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
347 "AMDGPU Kernel Attributes", false, false)
348
349char AMDGPULowerKernelAttributes::ID = 0;
350
352 return new AMDGPULowerKernelAttributes();
353}
354
357 bool IsV5OrAbove =
359 Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);
360
361 if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
362 return PreservedAnalyses::all();
363
364 for (Instruction &I : instructions(F)) {
365 if (CallInst *CI = dyn_cast<CallInst>(&I)) {
366 if (CI->getCalledFunction() == BasePtr)
367 processUse(CI, IsV5OrAbove);
368 }
369 }
370
371 return PreservedAnalyses::all();
372}
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
AMDGPU Kernel Attributes
#define DEBUG_TYPE
static bool processUse(CallInst *CI, bool IsV5OrAbove)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
std::string Name
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
print must be executed print the must be executed context for all instructions
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:59
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
Target-Independent Code Generator Pass Configuration Options pass.
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:620
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:112
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1412
This class represents a function call, abstracting a target machine's calling convention.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:711
static Constant * getIntegerCast(Constant *C, Type *Ty, bool IsSigned)
Create a ZExt, Bitcast or Trunc for integer -> integer casts.
Definition: Constants.cpp:2051
This is the shared class of boolean and integer constants.
Definition: Constants.h:78
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:833
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:356
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const BasicBlock * getParent() const
Definition: Instruction.h:90
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:251
virtual bool runOnModule(Module &M)=0
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
Definition: Pass.cpp:98
virtual StringRef getPassName() const
getPassName - Return a nice clean name for a pass.
Definition: Pass.cpp:81
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:152
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:158
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:365
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:450
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:535
iterator_range< user_iterator > users()
Definition: Value.h:421
unsigned getCodeObjectVersion(const Module &M)
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:992
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:772
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
Definition: PatternMatch.h:991
MaxMin_match< ICmpInst, LHS, RHS, umin_pred_ty > m_UMin(const LHS &L, const RHS &R)
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
@ Offset
Definition: DWP.cpp:406
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
ModulePass * createAMDGPULowerKernelAttributesPass()
@ UMin
Unisgned integer min implemented in terms of select(cmp()).
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)