LLVM  14.0.0git
AMDGPULowerKernelAttributes.cpp
Go to the documentation of this file.
1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file This pass attempts to make use of reqd_work_group_size metadata
10 /// to eliminate loads from the dispatch packet and to constant fold OpenCL
11 /// get_local_size-like functions.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "AMDGPU.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;
30 
31 namespace {
32 
/// Field offsets, in bytes, in hsa_kernel_dispatch_packet_t.
///
/// These are the constant offsets from the dispatch pointer that processUse()
/// looks for. It only accepts loads whose store size is 2 bytes at the
/// WORKGROUP_SIZE_* offsets and 4 bytes at the GRID_SIZE_* offsets.
enum DispatchPackedOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
43 
44 class AMDGPULowerKernelAttributes : public ModulePass {
45 public:
46  static char ID;
47 
48  AMDGPULowerKernelAttributes() : ModulePass(ID) {}
49 
50  bool runOnModule(Module &M) override;
51 
52  StringRef getPassName() const override {
53  return "AMDGPU Kernel Attributes";
54  }
55 
56  void getAnalysisUsage(AnalysisUsage &AU) const override {
57  AU.setPreservesAll();
58  }
59 };
60 
61 } // end anonymous namespace
62 
63 static bool processUse(CallInst *CI) {
64  Function *F = CI->getParent()->getParent();
65 
66  auto MD = F->getMetadata("reqd_work_group_size");
67  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;
68 
69  const bool HasUniformWorkGroupSize =
70  F->getFnAttribute("uniform-work-group-size").getValueAsBool();
71 
72  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
73  return false;
74 
75  Value *WorkGroupSizeX = nullptr;
76  Value *WorkGroupSizeY = nullptr;
77  Value *WorkGroupSizeZ = nullptr;
78 
79  Value *GridSizeX = nullptr;
80  Value *GridSizeY = nullptr;
81  Value *GridSizeZ = nullptr;
82 
83  const DataLayout &DL = F->getParent()->getDataLayout();
84 
85  // We expect to see several GEP users, casted to the appropriate type and
86  // loaded.
87  for (User *U : CI->users()) {
88  if (!U->hasOneUse())
89  continue;
90 
91  int64_t Offset = 0;
93  continue;
94 
95  auto *BCI = dyn_cast<BitCastInst>(*U->user_begin());
96  if (!BCI || !BCI->hasOneUse())
97  continue;
98 
99  auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
100  if (!Load || !Load->isSimple())
101  continue;
102 
103  unsigned LoadSize = DL.getTypeStoreSize(Load->getType());
104 
105  // TODO: Handle merged loads.
106  switch (Offset) {
107  case WORKGROUP_SIZE_X:
108  if (LoadSize == 2)
109  WorkGroupSizeX = Load;
110  break;
111  case WORKGROUP_SIZE_Y:
112  if (LoadSize == 2)
113  WorkGroupSizeY = Load;
114  break;
115  case WORKGROUP_SIZE_Z:
116  if (LoadSize == 2)
117  WorkGroupSizeZ = Load;
118  break;
119  case GRID_SIZE_X:
120  if (LoadSize == 4)
121  GridSizeX = Load;
122  break;
123  case GRID_SIZE_Y:
124  if (LoadSize == 4)
125  GridSizeY = Load;
126  break;
127  case GRID_SIZE_Z:
128  if (LoadSize == 4)
129  GridSizeZ = Load;
130  break;
131  default:
132  break;
133  }
134  }
135 
136  // Pattern match the code used to handle partial workgroup dispatches in the
137  // library implementation of get_local_size, so the entire function can be
138  // constant folded with a known group size.
139  //
140  // uint r = grid_size - group_id * group_size;
141  // get_local_size = (r < group_size) ? r : group_size;
142  //
143  // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
144  // the grid_size is required to be a multiple of group_size). In this case:
145  //
146  // grid_size - (group_id * group_size) < group_size
147  // ->
148  // grid_size < group_size + (group_id * group_size)
149  //
150  // (grid_size / group_size) < 1 + group_id
151  //
152  // grid_size / group_size is at least 1, so we can conclude the select
153  // condition is false (except for group_id == 0, where the select result is
154  // the same).
155 
156  bool MadeChange = false;
157  Value *WorkGroupSizes[3] = { WorkGroupSizeX, WorkGroupSizeY, WorkGroupSizeZ };
158  Value *GridSizes[3] = { GridSizeX, GridSizeY, GridSizeZ };
159 
160  for (int I = 0; HasUniformWorkGroupSize && I < 3; ++I) {
161  Value *GroupSize = WorkGroupSizes[I];
162  Value *GridSize = GridSizes[I];
163  if (!GroupSize || !GridSize)
164  continue;
165 
166  for (User *U : GroupSize->users()) {
167  auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
168  if (!ZextGroupSize)
169  continue;
170 
171  for (User *ZextUser : ZextGroupSize->users()) {
172  auto *SI = dyn_cast<SelectInst>(ZextUser);
173  if (!SI)
174  continue;
175 
176  using namespace llvm::PatternMatch;
177  auto GroupIDIntrin = I == 0 ?
178  m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>() :
179  (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>() :
180  m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
181 
182  auto SubExpr = m_Sub(m_Specific(GridSize),
183  m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize)));
184 
185  ICmpInst::Predicate Pred;
186  if (match(SI,
187  m_Select(m_ICmp(Pred, SubExpr, m_Specific(ZextGroupSize)),
188  SubExpr,
189  m_Specific(ZextGroupSize))) &&
190  Pred == ICmpInst::ICMP_ULT) {
191  if (HasReqdWorkGroupSize) {
192  ConstantInt *KnownSize
193  = mdconst::extract<ConstantInt>(MD->getOperand(I));
194  SI->replaceAllUsesWith(ConstantExpr::getIntegerCast(KnownSize,
195  SI->getType(),
196  false));
197  } else {
198  SI->replaceAllUsesWith(ZextGroupSize);
199  }
200 
201  MadeChange = true;
202  }
203  }
204  }
205  }
206 
207  if (!HasReqdWorkGroupSize)
208  return MadeChange;
209 
210  // Eliminate any other loads we can from the dispatch packet.
211  for (int I = 0; I < 3; ++I) {
212  Value *GroupSize = WorkGroupSizes[I];
213  if (!GroupSize)
214  continue;
215 
216  ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
217  GroupSize->replaceAllUsesWith(
219  GroupSize->getType(),
220  false));
221  MadeChange = true;
222  }
223 
224  return MadeChange;
225 }
226 
227 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
228 // TargetPassConfig for subtarget.
229 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
230  StringRef DispatchPtrName
231  = Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
232 
233  Function *DispatchPtr = M.getFunction(DispatchPtrName);
234  if (!DispatchPtr) // Dispatch ptr not used.
235  return false;
236 
237  bool MadeChange = false;
238 
239  SmallPtrSet<Instruction *, 4> HandledUses;
240  for (auto *U : DispatchPtr->users()) {
241  CallInst *CI = cast<CallInst>(U);
242  if (HandledUses.insert(CI).second) {
243  if (processUse(CI))
244  MadeChange = true;
245  }
246  }
247 
248  return MadeChange;
249 }
250 
251 INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
252  "AMDGPU Kernel Attributes", false, false)
253 INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
255 
256 char AMDGPULowerKernelAttributes::ID = 0;
257 
259  return new AMDGPULowerKernelAttributes();
260 }
261 
264  StringRef DispatchPtrName =
265  Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
266 
267  Function *DispatchPtr = F.getParent()->getFunction(DispatchPtrName);
268  if (!DispatchPtr) // Dispatch ptr not used.
269  return PreservedAnalyses::all();
270 
271  for (Instruction &I : instructions(F)) {
272  if (CallInst *CI = dyn_cast<CallInst>(&I)) {
273  if (CI->getCalledFunction() == DispatchPtr)
274  processUse(CI);
275  }
276  }
277 
278  return PreservedAnalyses::all();
279 }
llvm::PreservedAnalyses
A set of analyses that are preserved following a run of a transformation pass.
Definition: PassManager.h:155
llvm
This file implements support for optimizing divisions by a constant.
Definition: AllocatorList.h:23
M
We currently emits eax Perhaps this is what we really should generate is Is imull three or four cycles eax eax The current instruction priority is based on pattern complexity The former is more complex because it folds a load so the latter will not be emitted Perhaps we should use AddedComplexity to give LEA32r a higher priority We should always try to match LEA first since the LEA matching code does some estimate to determine whether the match is profitable if we care more about code then imull is better It s two bytes shorter than movl leal On a Pentium M
Definition: README.txt:252
llvm::DataLayout
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:113
llvm::CmpInst::Predicate
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:720
llvm::ModulePass
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:238
llvm::BasicBlock::getParent
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:107
InstIterator.h
llvm::Function
Definition: Function.h:62
Pass.h
llvm::Intrinsic::getName
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:879
ValueTracking.h
llvm::SPII::Load
@ Load
Definition: SparcInstrInfo.h:32
Offset
uint64_t Offset
Definition: ELFObjHandler.cpp:81
llvm::SmallPtrSet< Instruction *, 4 >
llvm::AMDGPULowerKernelAttributesPass::run
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Definition: AMDGPULowerKernelAttributes.cpp:263
F
#define F(x, y, z)
Definition: MD5.cpp:56
llvm::ConstantInt
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
llvm::PatternMatch::m_Select
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
Definition: PatternMatch.h:1472
Constants.h
llvm::PatternMatch::match
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
llvm::User
Definition: User.h:44
llvm::CallBase::getCalledFunction
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation.
Definition: InstrTypes.h:1383
llvm::AnalysisUsage
Represent the analysis usage information of a pass.
Definition: PassAnalysisSupport.h:47
false
Definition: StackSlotColoring.cpp:142
llvm::Instruction
Definition: Instruction.h:45
PatternMatch.h
INITIALIZE_PASS_END
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:58
Passes.h
llvm::instructions
inst_range instructions(Function *F)
Definition: InstIterator.h:133
llvm::omp::Kernel
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
processUse
static bool processUse(CallInst *CI)
Definition: AMDGPULowerKernelAttributes.cpp:63
I
#define I(x, y, z)
Definition: MD5.cpp:59
Attributes
AMDGPU Kernel Attributes
Definition: AMDGPULowerKernelAttributes.cpp:254
TargetPassConfig.h
llvm::PatternMatch::m_Sub
BinaryOp_match< LHS, RHS, Instruction::Sub > m_Sub(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1020
SI
StandardInstrumentations SI(Debug, VerifyEach)
llvm::Module
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:58
AMDGPU.h
llvm::CmpInst::ICMP_ULT
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:745
llvm::Value::getType
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:256
llvm::Value::replaceAllUsesWith
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:532
DL
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Definition: AArch64SLSHardening.cpp:76
llvm::ConstantExpr::getIntegerCast
static Constant * getIntegerCast(Constant *C, Type *Ty, bool IsSigned)
Create a ZExt, Bitcast or Trunc for integer -> integer casts.
Definition: Constants.cpp:2071
llvm::createAMDGPULowerKernelAttributesPass
ModulePass * createAMDGPULowerKernelAttributesPass()
Definition: AMDGPULowerKernelAttributes.cpp:258
llvm::AnalysisUsage::setPreservesAll
void setPreservesAll()
Set by analyses that do not transform their input at all.
Definition: PassAnalysisSupport.h:130
llvm::PreservedAnalyses::all
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: PassManager.h:161
Function.h
llvm::GetPointerBaseWithConstantOffset
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
Definition: ValueTracking.h:279
Instructions.h
llvm::PatternMatch::m_Specific
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
Definition: PatternMatch.h:802
DEBUG_TYPE
#define DEBUG_TYPE
Definition: AMDGPULowerKernelAttributes.cpp:27
llvm::PatternMatch::m_ICmp
CmpClass_match< LHS, RHS, ICmpInst, ICmpInst::Predicate > m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R)
Definition: PatternMatch.h:1404
llvm::Instruction::getParent
const BasicBlock * getParent() const
Definition: Instruction.h:94
llvm::PatternMatch
Definition: PatternMatch.h:47
llvm::AnalysisManager
A container for analyses that lazily runs them and caches their results.
Definition: InstructionSimplify.h:44
llvm::CallInst
This class represents a function call, abstracting a target machine's calling convention.
Definition: Instructions.h:1475
INITIALIZE_PASS_BEGIN
INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE, "AMDGPU Kernel Attributes", false, false) INITIALIZE_PASS_END(AMDGPULowerKernelAttributes
llvm::Value
LLVM Value Representation.
Definition: Value.h:75
llvm::Value::users
iterator_range< user_iterator > users()
Definition: Value.h:422
llvm::PatternMatch::m_Mul
BinaryOp_match< LHS, RHS, Instruction::Mul > m_Mul(const LHS &L, const RHS &R)
Definition: PatternMatch.h:1075
llvm::SmallPtrSetImpl::insert
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:364
llvm::Intrinsic::ID
unsigned ID
Definition: TargetTransformInfo.h:37