LLVM  12.0.0git
R600OpenCLImageTypeLoweringPass.cpp
Go to the documentation of this file.
//===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass resolves calls to OpenCL image attribute, image resource ID and
/// sampler resource ID getter functions.
///
/// Image attributes (size and format) are expected to be passed to the kernel
/// as kernel arguments immediately following the image argument itself,
/// therefore this pass adds image size and format arguments to the kernel
/// functions in the module. The kernel functions with image arguments are
/// re-created using the new signature. The new arguments are added to the
/// kernel metadata with type "__llvm_image_size" or "__llvm_image_format".
/// Note: this pass may invalidate pointers to functions.
///
/// Resource IDs of read-only images, write-only images and samplers are
/// defined to be their index among the kernel arguments of the same
/// type and access qualifier.
//
//===----------------------------------------------------------------------===//
26 
#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <cassert>
#include <tuple>
36 
37 using namespace llvm;
38 
39 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
40 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
41 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
43  "llvm.OpenCL.sampler.get.resource.id";
44 
45 static StringRef ImageSizeArgMDType = "__llvm_image_size";
46 static StringRef ImageFormatArgMDType = "__llvm_image_format";
47 
48 static StringRef KernelsMDNodeName = "opencl.kernels";
50  "kernel_arg_addr_space",
51  "kernel_arg_access_qual",
52  "kernel_arg_type",
53  "kernel_arg_base_type",
54  "kernel_arg_type_qual"};
55 static const unsigned NumKernelArgMDNodes = 5;
56 
57 namespace {
58 
59 using MDVector = SmallVector<Metadata *, 8>;
60 struct KernelArgMD {
61  MDVector ArgVector[NumKernelArgMDNodes];
62 };
63 
64 } // end anonymous namespace
65 
66 static inline bool
67 IsImageType(StringRef TypeString) {
68  return TypeString == "image2d_t" || TypeString == "image3d_t";
69 }
70 
71 static inline bool
72 IsSamplerType(StringRef TypeString) {
73  return TypeString == "sampler_t";
74 }
75 
76 static Function *
78  if (!Node)
79  return nullptr;
80 
81  size_t NumOps = Node->getNumOperands();
82  if (NumOps != NumKernelArgMDNodes + 1)
83  return nullptr;
84 
85  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
86  if (!F)
87  return nullptr;
88 
89  // Sanity checks.
90  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92  MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
93  if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94  return nullptr;
95  if (!ArgNode->getOperand(0))
96  return nullptr;
97 
98  // FIXME: It should be possible to do image lowering when some metadata
99  // args missing or not in the expected order.
100  MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
101  if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
102  return nullptr;
103  }
104 
105  return F;
106 }
107 
108 static StringRef
109 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
110  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
111  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
112 }
113 
114 static StringRef
115 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
116  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
117  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
118 }
119 
120 static MDVector
121 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
122  MDVector Res;
123  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
124  MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
125  Res.push_back(Node->getOperand(OpIdx));
126  }
127  return Res;
128 }
129 
130 static void
131 PushArgMD(KernelArgMD &MD, const MDVector &V) {
132  assert(V.size() == NumKernelArgMDNodes);
133  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
134  MD.ArgVector[i].push_back(V[i]);
135  }
136 }
137 
138 namespace {
139 
140 class R600OpenCLImageTypeLoweringPass : public ModulePass {
141  static char ID;
142 
144  Type *Int32Type;
145  Type *ImageSizeType;
146  Type *ImageFormatType;
147  SmallVector<Instruction *, 4> InstsToErase;
148 
149  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
150  Argument &ImageSizeArg,
151  Argument &ImageFormatArg) {
152  bool Modified = false;
153 
154  for (auto &Use : ImageArg.uses()) {
155  auto Inst = dyn_cast<CallInst>(Use.getUser());
156  if (!Inst) {
157  continue;
158  }
159 
160  Function *F = Inst->getCalledFunction();
161  if (!F)
162  continue;
163 
164  Value *Replacement = nullptr;
165  StringRef Name = F->getName();
166  if (Name.startswith(GetImageResourceIDFunc)) {
167  Replacement = ConstantInt::get(Int32Type, ResourceID);
168  } else if (Name.startswith(GetImageSizeFunc)) {
169  Replacement = &ImageSizeArg;
170  } else if (Name.startswith(GetImageFormatFunc)) {
171  Replacement = &ImageFormatArg;
172  } else {
173  continue;
174  }
175 
176  Inst->replaceAllUsesWith(Replacement);
177  InstsToErase.push_back(Inst);
178  Modified = true;
179  }
180 
181  return Modified;
182  }
183 
184  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
185  bool Modified = false;
186 
187  for (const auto &Use : SamplerArg.uses()) {
188  auto Inst = dyn_cast<CallInst>(Use.getUser());
189  if (!Inst) {
190  continue;
191  }
192 
193  Function *F = Inst->getCalledFunction();
194  if (!F)
195  continue;
196 
197  Value *Replacement = nullptr;
198  StringRef Name = F->getName();
200  Replacement = ConstantInt::get(Int32Type, ResourceID);
201  } else {
202  continue;
203  }
204 
205  Inst->replaceAllUsesWith(Replacement);
206  InstsToErase.push_back(Inst);
207  Modified = true;
208  }
209 
210  return Modified;
211  }
212 
213  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
214  uint32_t NumReadOnlyImageArgs = 0;
215  uint32_t NumWriteOnlyImageArgs = 0;
216  uint32_t NumSamplerArgs = 0;
217 
218  bool Modified = false;
219  InstsToErase.clear();
220  for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
221  Argument &Arg = *ArgI;
222  StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
223 
224  // Handle image types.
225  if (IsImageType(Type)) {
226  StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
227  uint32_t ResourceID;
228  if (AccessQual == "read_only") {
229  ResourceID = NumReadOnlyImageArgs++;
230  } else if (AccessQual == "write_only") {
231  ResourceID = NumWriteOnlyImageArgs++;
232  } else {
233  llvm_unreachable("Wrong image access qualifier.");
234  }
235 
236  Argument &SizeArg = *(++ArgI);
237  Argument &FormatArg = *(++ArgI);
238  Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
239 
240  // Handle sampler type.
241  } else if (IsSamplerType(Type)) {
242  uint32_t ResourceID = NumSamplerArgs++;
243  Modified |= replaceSamplerUses(Arg, ResourceID);
244  }
245  }
246  for (unsigned i = 0; i < InstsToErase.size(); ++i) {
247  InstsToErase[i]->eraseFromParent();
248  }
249 
250  return Modified;
251  }
252 
253  std::tuple<Function *, MDNode *>
254  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
255  bool Modified = false;
256 
257  FunctionType *FT = F->getFunctionType();
258  SmallVector<Type *, 8> ArgTypes;
259 
260  // Metadata operands for new MDNode.
261  KernelArgMD NewArgMDs;
262  PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
263 
264  // Add implicit arguments to the signature.
265  for (unsigned i = 0; i < FT->getNumParams(); ++i) {
266  ArgTypes.push_back(FT->getParamType(i));
267  MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
268  PushArgMD(NewArgMDs, ArgMD);
269 
270  if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
271  continue;
272 
273  // Add size implicit argument.
274  ArgTypes.push_back(ImageSizeType);
275  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
276  PushArgMD(NewArgMDs, ArgMD);
277 
278  // Add format implicit argument.
279  ArgTypes.push_back(ImageFormatType);
280  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
281  PushArgMD(NewArgMDs, ArgMD);
282 
283  Modified = true;
284  }
285  if (!Modified) {
286  return std::make_tuple(nullptr, nullptr);
287  }
288 
289  // Create function with new signature and clone the old body into it.
290  auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
291  auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
292  ValueToValueMapTy VMap;
293  auto NewFArgIt = NewF->arg_begin();
294  for (auto &Arg: F->args()) {
295  auto ArgName = Arg.getName();
296  NewFArgIt->setName(ArgName);
297  VMap[&Arg] = &(*NewFArgIt++);
298  if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
299  (NewFArgIt++)->setName(Twine("__size_") + ArgName);
300  (NewFArgIt++)->setName(Twine("__format_") + ArgName);
301  }
302  }
304  CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
305 
306  // Build new MDNode.
307  SmallVector<Metadata *, 6> KernelMDArgs;
308  KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
309  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
310  KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
311  MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
312 
313  return std::make_tuple(NewF, NewMDNode);
314  }
315 
316  bool transformKernels(Module &M) {
317  NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
318  if (!KernelsMDNode)
319  return false;
320 
321  bool Modified = false;
322  for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
323  MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
324  Function *F = GetFunctionFromMDNode(KernelMDNode);
325  if (!F)
326  continue;
327 
328  Function *NewF;
329  MDNode *NewMDNode;
330  std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
331  if (NewF) {
332  // Replace old function and metadata with new ones.
333  F->eraseFromParent();
334  M.getFunctionList().push_back(NewF);
335  M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
336  NewF->getAttributes());
337  KernelsMDNode->setOperand(i, NewMDNode);
338 
339  F = NewF;
340  KernelMDNode = NewMDNode;
341  Modified = true;
342  }
343 
344  Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
345  }
346 
347  return Modified;
348  }
349 
350 public:
351  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
352 
353  bool runOnModule(Module &M) override {
354  Context = &M.getContext();
355  Int32Type = Type::getInt32Ty(M.getContext());
356  ImageSizeType = ArrayType::get(Int32Type, 3);
357  ImageFormatType = ArrayType::get(Int32Type, 2);
358 
359  return transformKernels(M);
360  }
361 
362  StringRef getPassName() const override {
363  return "R600 OpenCL Image Type Pass";
364  }
365 };
366 
367 } // end anonymous namespace
368 
370 
372  return new R600OpenCLImageTypeLoweringPass();
373 }
iterator_range< use_iterator > uses()
Definition: Value.h:379
This class represents an incoming formal argument to a Function.
Definition: Argument.h:29
LLVMContext & Context
MDNode * getOperand(unsigned i) const
Definition: Metadata.cpp:1105
static StringRef AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
This class represents lattice values for constants.
Definition: AllocatorList.h:23
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:134
static StringRef GetImageFormatFunc
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:67
ModulePass * createR600OpenCLImageTypeLoweringPass()
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:454
This file contains the declarations for metadata subclasses.
Metadata node.
Definition: Metadata.h:870
F(f)
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1075
static StringRef ImageFormatArgMDType
void setOperand(unsigned I, MDNode *New)
Definition: Metadata.cpp:1113
static StringRef KernelsMDNodeName
static StringRef GetSamplerResourceIDFunc
A tuple of MDNodes.
Definition: Metadata.h:1350
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:80
A Use represents the edge between a Value definition and its users.
Definition: Use.h:44
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, bool ModuleLevelChanges, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values.
static bool IsImageType(StringRef TypeString)
unsigned getNumOperands() const
Definition: Metadata.cpp:1101
Class to represent function types.
Definition: DerivedTypes.h:102
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:321
static StringRef ImageSizeArgMDType
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:239
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:523
static ConstantAsMetadata * get(Constant *C)
Definition: Metadata.h:410
StringRef getString() const
Definition: Metadata.cpp:464
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:137
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:68
static MDVector GetArgMD(MDNode *KernelMDNode, unsigned OpIdx)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:138
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
static void PushArgMD(KernelArgMD &MD, const MDVector &V)
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
User * getUser() const
Returns the User that contains this Use.
Definition: Use.h:73
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1116
Type * getReturnType() const
Definition: DerivedTypes.h:123
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:867
static StringRef KernelArgMDNodeNames[]
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:165
static const unsigned NumKernelArgMDNodes
static StringRef GetImageResourceIDFunc
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1171
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:197
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:295
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:238
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:581
static Function * GetFunctionFromMDNode(MDNode *Node)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:75
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:57
A single uniqued string.
Definition: Metadata.h:602
static bool IsSamplerType(StringRef TypeString)
static StringRef GetImageSizeFunc
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1081
static StringRef ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx)