LLVM  4.0.0
AMDGPUOpenCLImageTypeLoweringPass.cpp
Go to the documentation of this file.
1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type and kernel_arg_base_type set to
19 /// "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
26 
#include "AMDGPU.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include <algorithm>
#include <cassert>
#include <tuple>
36 
37 using namespace llvm;
38 
39 namespace {
40 
41 StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
42 StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
43 StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
44 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
45 
46 StringRef ImageSizeArgMDType = "__llvm_image_size";
47 StringRef ImageFormatArgMDType = "__llvm_image_format";
48 
49 StringRef KernelsMDNodeName = "opencl.kernels";
50 StringRef KernelArgMDNodeNames[] = {
51  "kernel_arg_addr_space",
52  "kernel_arg_access_qual",
53  "kernel_arg_type",
54  "kernel_arg_base_type",
55  "kernel_arg_type_qual"};
56 const unsigned NumKernelArgMDNodes = 5;
57 
58 typedef SmallVector<Metadata *, 8> MDVector;
59 struct KernelArgMD {
60  MDVector ArgVector[NumKernelArgMDNodes];
61 };
62 
63 } // end anonymous namespace
64 
65 static inline bool
66 IsImageType(StringRef TypeString) {
67  return TypeString == "image2d_t" || TypeString == "image3d_t";
68 }
69 
70 static inline bool
71 IsSamplerType(StringRef TypeString) {
72  return TypeString == "sampler_t";
73 }
74 
75 static Function *
77  if (!Node)
78  return nullptr;
79 
80  size_t NumOps = Node->getNumOperands();
81  if (NumOps != NumKernelArgMDNodes + 1)
82  return nullptr;
83 
84  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
85  if (!F)
86  return nullptr;
87 
88  // Sanity checks.
89  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
90  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
91  MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
92  if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
93  return nullptr;
94  if (!ArgNode->getOperand(0))
95  return nullptr;
96 
97  // FIXME: It should be possible to do image lowering when some metadata
98  // args missing or not in the expected order.
99  MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
100  if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
101  return nullptr;
102  }
103 
104  return F;
105 }
106 
107 static StringRef
108 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
109  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
110  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
111 }
112 
113 static StringRef
114 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
115  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
116  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
117 }
118 
119 static MDVector
120 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
121  MDVector Res;
122  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
123  MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
124  Res.push_back(Node->getOperand(OpIdx));
125  }
126  return Res;
127 }
128 
129 static void
130 PushArgMD(KernelArgMD &MD, const MDVector &V) {
131  assert(V.size() == NumKernelArgMDNodes);
132  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
133  MD.ArgVector[i].push_back(V[i]);
134  }
135 }
136 
137 namespace {
138 
139 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
140  static char ID;
141 
143  Type *Int32Type;
144  Type *ImageSizeType;
145  Type *ImageFormatType;
146  SmallVector<Instruction *, 4> InstsToErase;
147 
148  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
149  Argument &ImageSizeArg,
150  Argument &ImageFormatArg) {
151  bool Modified = false;
152 
153  for (auto &Use : ImageArg.uses()) {
154  auto Inst = dyn_cast<CallInst>(Use.getUser());
155  if (!Inst) {
156  continue;
157  }
158 
159  Function *F = Inst->getCalledFunction();
160  if (!F)
161  continue;
162 
163  Value *Replacement = nullptr;
164  StringRef Name = F->getName();
165  if (Name.startswith(GetImageResourceIDFunc)) {
166  Replacement = ConstantInt::get(Int32Type, ResourceID);
167  } else if (Name.startswith(GetImageSizeFunc)) {
168  Replacement = &ImageSizeArg;
169  } else if (Name.startswith(GetImageFormatFunc)) {
170  Replacement = &ImageFormatArg;
171  } else {
172  continue;
173  }
174 
175  Inst->replaceAllUsesWith(Replacement);
176  InstsToErase.push_back(Inst);
177  Modified = true;
178  }
179 
180  return Modified;
181  }
182 
183  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
184  bool Modified = false;
185 
186  for (const auto &Use : SamplerArg.uses()) {
187  auto Inst = dyn_cast<CallInst>(Use.getUser());
188  if (!Inst) {
189  continue;
190  }
191 
192  Function *F = Inst->getCalledFunction();
193  if (!F)
194  continue;
195 
196  Value *Replacement = nullptr;
197  StringRef Name = F->getName();
198  if (Name == GetSamplerResourceIDFunc) {
199  Replacement = ConstantInt::get(Int32Type, ResourceID);
200  } else {
201  continue;
202  }
203 
204  Inst->replaceAllUsesWith(Replacement);
205  InstsToErase.push_back(Inst);
206  Modified = true;
207  }
208 
209  return Modified;
210  }
211 
212  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
213  uint32_t NumReadOnlyImageArgs = 0;
214  uint32_t NumWriteOnlyImageArgs = 0;
215  uint32_t NumSamplerArgs = 0;
216 
217  bool Modified = false;
218  InstsToErase.clear();
219  for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
220  Argument &Arg = *ArgI;
221  StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
222 
223  // Handle image types.
224  if (IsImageType(Type)) {
225  StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
226  uint32_t ResourceID;
227  if (AccessQual == "read_only") {
228  ResourceID = NumReadOnlyImageArgs++;
229  } else if (AccessQual == "write_only") {
230  ResourceID = NumWriteOnlyImageArgs++;
231  } else {
232  llvm_unreachable("Wrong image access qualifier.");
233  }
234 
235  Argument &SizeArg = *(++ArgI);
236  Argument &FormatArg = *(++ArgI);
237  Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
238 
239  // Handle sampler type.
240  } else if (IsSamplerType(Type)) {
241  uint32_t ResourceID = NumSamplerArgs++;
242  Modified |= replaceSamplerUses(Arg, ResourceID);
243  }
244  }
245  for (unsigned i = 0; i < InstsToErase.size(); ++i) {
246  InstsToErase[i]->eraseFromParent();
247  }
248 
249  return Modified;
250  }
251 
252  std::tuple<Function *, MDNode *>
253  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
254  bool Modified = false;
255 
256  FunctionType *FT = F->getFunctionType();
257  SmallVector<Type *, 8> ArgTypes;
258 
259  // Metadata operands for new MDNode.
260  KernelArgMD NewArgMDs;
261  PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
262 
263  // Add implicit arguments to the signature.
264  for (unsigned i = 0; i < FT->getNumParams(); ++i) {
265  ArgTypes.push_back(FT->getParamType(i));
266  MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
267  PushArgMD(NewArgMDs, ArgMD);
268 
269  if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
270  continue;
271 
272  // Add size implicit argument.
273  ArgTypes.push_back(ImageSizeType);
274  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
275  PushArgMD(NewArgMDs, ArgMD);
276 
277  // Add format implicit argument.
278  ArgTypes.push_back(ImageFormatType);
279  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
280  PushArgMD(NewArgMDs, ArgMD);
281 
282  Modified = true;
283  }
284  if (!Modified) {
285  return std::make_tuple(nullptr, nullptr);
286  }
287 
288  // Create function with new signature and clone the old body into it.
289  auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
290  auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
291  ValueToValueMapTy VMap;
292  auto NewFArgIt = NewF->arg_begin();
293  for (auto &Arg: F->args()) {
294  auto ArgName = Arg.getName();
295  NewFArgIt->setName(ArgName);
296  VMap[&Arg] = &(*NewFArgIt++);
297  if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
298  (NewFArgIt++)->setName(Twine("__size_") + ArgName);
299  (NewFArgIt++)->setName(Twine("__format_") + ArgName);
300  }
301  }
303  CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
304 
305  // Build new MDNode.
307  KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
308  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
309  KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
310  MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
311 
312  return std::make_tuple(NewF, NewMDNode);
313  }
314 
315  bool transformKernels(Module &M) {
316  NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
317  if (!KernelsMDNode)
318  return false;
319 
320  bool Modified = false;
321  for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
322  MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
323  Function *F = GetFunctionFromMDNode(KernelMDNode);
324  if (!F)
325  continue;
326 
327  Function *NewF;
328  MDNode *NewMDNode;
329  std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
330  if (NewF) {
331  // Replace old function and metadata with new ones.
332  F->eraseFromParent();
333  M.getFunctionList().push_back(NewF);
334  M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
335  NewF->getAttributes());
336  KernelsMDNode->setOperand(i, NewMDNode);
337 
338  F = NewF;
339  KernelMDNode = NewMDNode;
340  Modified = true;
341  }
342 
343  Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
344  }
345 
346  return Modified;
347  }
348 
349  public:
350  AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
351 
352  bool runOnModule(Module &M) override {
353  Context = &M.getContext();
354  Int32Type = Type::getInt32Ty(M.getContext());
355  ImageSizeType = ArrayType::get(Int32Type, 3);
356  ImageFormatType = ArrayType::get(Int32Type, 2);
357 
358  return transformKernels(M);
359  }
360 
361  StringRef getPassName() const override {
362  return "AMDGPU OpenCL Image Type Pass";
363  }
364 };
365 
367 
368 } // end anonymous namespace
369 
371  return new AMDGPUOpenCLImageTypeLoweringPass();
372 }
LinkageTypes getLinkage() const
Definition: GlobalValue.h:429
iterator_range< use_iterator > uses()
Definition: Value.h:326
LLVM Argument representation.
Definition: Argument.h:34
LLVMContext & Context
size_t i
static bool IsImageType(StringRef TypeString)
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:52
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:137
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:414
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1040
static StringRef AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
This class represents a function call, abstracting a target machine's calling convention.
arg_iterator arg_end()
Definition: Function.h:559
Metadata node.
Definition: Metadata.h:830
void setOperand(unsigned I, MDNode *New)
Definition: Metadata.cpp:1050
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:191
A tuple of MDNodes.
Definition: Metadata.h:1282
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, bool ModuleLevelChanges, SmallVectorImpl< ReturnInst * > &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values...
Class to represent function types.
Definition: DerivedTypes.h:102
#define F(x, y, z)
Definition: MD5.cpp:51
LLVM_NODISCARD LLVM_ATTRIBUTE_ALWAYS_INLINE bool startswith(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:264
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:291
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:401
static ConstantAsMetadata * get(Constant *C)
Definition: Metadata.h:392
static MDVector GetArgMD(MDNode *KernelMDNode, unsigned OpIdx)
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:133
Constant * getOrInsertFunction(StringRef Name, FunctionType *T, AttributeSet AttributeList)
Look up the specified function in the module symbol table.
Definition: Module.cpp:123
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:48
This file contains the declarations for the subclasses of Constant, which represent the different fla...
MDNode * getOperand(unsigned i) const
Definition: Metadata.cpp:1042
User * getUser() const
Returns the User that contains this Use.
Definition: Use.cpp:41
arg_iterator arg_begin()
Definition: Function.h:550
static Function * GetFunctionFromMDNode(MDNode *Node)
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const FunctionListType & getFunctionList() const
Get the Module's list of functions (constant).
Definition: Module.h:478
StringRef getString() const
Definition: Metadata.cpp:424
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1034
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:843
Module.h This file contains the declarations for the Module class.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:558
AttributeSet getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:176
void push_back(pointer val)
Definition: ilist.h:326
static StringRef ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
void eraseFromParent() override
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:246
static void PushArgMD(KernelArgMD &MD, const MDVector &V)
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata * > MDs)
Definition: Metadata.h:1132
NamedMDNode * getNamedMetadata(const Twine &Name) const
Return the first NamedMDNode in the module with the specified name.
Definition: Module.cpp:265
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:169
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.cpp:230
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:235
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:606
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:287
static bool IsSamplerType(StringRef TypeString)
Type * getReturnType() const
Definition: DerivedTypes.h:123
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:71
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
Definition: Function.cpp:57
unsigned getNumOperands() const
Definition: Metadata.cpp:1038
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:47
A single uniqued string.
Definition: Metadata.h:586
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:117
ModulePass * createAMDGPUOpenCLImageTypeLoweringPass()
iterator_range< arg_iterator > args()
Definition: Function.h:568
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:222