LLVM  6.0.0svn
AMDGPUOpenCLImageTypeLoweringPass.cpp
Go to the documentation of this file.
1 //===- AMDGPUOpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type and kernel_arg_base_type set to "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //
26 //===----------------------------------------------------------------------===//
27 
28 #include "AMDGPU.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/IR/Argument.h"
33 #include "llvm/IR/DerivedTypes.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Instruction.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/Metadata.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Type.h"
41 #include "llvm/IR/Use.h"
42 #include "llvm/IR/User.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Casting.h"
48 #include <cassert>
49 #include <cstddef>
50 #include <cstdint>
51 #include <tuple>
52 
53 using namespace llvm;
54 
55 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
56 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
57 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
59  "llvm.OpenCL.sampler.get.resource.id";
60 
61 static StringRef ImageSizeArgMDType = "__llvm_image_size";
62 static StringRef ImageFormatArgMDType = "__llvm_image_format";
63 
64 static StringRef KernelsMDNodeName = "opencl.kernels";
66  "kernel_arg_addr_space",
67  "kernel_arg_access_qual",
68  "kernel_arg_type",
69  "kernel_arg_base_type",
70  "kernel_arg_type_qual"};
71 static const unsigned NumKernelArgMDNodes = 5;
72 
73 namespace {
74 
75 using MDVector = SmallVector<Metadata *, 8>;
76 struct KernelArgMD {
77  MDVector ArgVector[NumKernelArgMDNodes];
78 };
79 
80 } // end anonymous namespace
81 
82 static inline bool
83 IsImageType(StringRef TypeString) {
84  return TypeString == "image2d_t" || TypeString == "image3d_t";
85 }
86 
87 static inline bool
88 IsSamplerType(StringRef TypeString) {
89  return TypeString == "sampler_t";
90 }
91 
92 static Function *
94  if (!Node)
95  return nullptr;
96 
97  size_t NumOps = Node->getNumOperands();
98  if (NumOps != NumKernelArgMDNodes + 1)
99  return nullptr;
100 
101  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
102  if (!F)
103  return nullptr;
104 
105  // Sanity checks.
106  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
107  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
108  MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
109  if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
110  return nullptr;
111  if (!ArgNode->getOperand(0))
112  return nullptr;
113 
114  // FIXME: It should be possible to do image lowering when some metadata
115  // args missing or not in the expected order.
116  MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
117  if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
118  return nullptr;
119  }
120 
121  return F;
122 }
123 
124 static StringRef
125 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
126  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
127  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
128 }
129 
130 static StringRef
131 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
132  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
133  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
134 }
135 
136 static MDVector
137 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
138  MDVector Res;
139  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
140  MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
141  Res.push_back(Node->getOperand(OpIdx));
142  }
143  return Res;
144 }
145 
146 static void
147 PushArgMD(KernelArgMD &MD, const MDVector &V) {
148  assert(V.size() == NumKernelArgMDNodes);
149  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
150  MD.ArgVector[i].push_back(V[i]);
151  }
152 }
153 
154 namespace {
155 
156 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
157  static char ID;
158 
160  Type *Int32Type;
161  Type *ImageSizeType;
162  Type *ImageFormatType;
163  SmallVector<Instruction *, 4> InstsToErase;
164 
165  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
166  Argument &ImageSizeArg,
167  Argument &ImageFormatArg) {
168  bool Modified = false;
169 
170  for (auto &Use : ImageArg.uses()) {
171  auto Inst = dyn_cast<CallInst>(Use.getUser());
172  if (!Inst) {
173  continue;
174  }
175 
176  Function *F = Inst->getCalledFunction();
177  if (!F)
178  continue;
179 
180  Value *Replacement = nullptr;
181  StringRef Name = F->getName();
182  if (Name.startswith(GetImageResourceIDFunc)) {
183  Replacement = ConstantInt::get(Int32Type, ResourceID);
184  } else if (Name.startswith(GetImageSizeFunc)) {
185  Replacement = &ImageSizeArg;
186  } else if (Name.startswith(GetImageFormatFunc)) {
187  Replacement = &ImageFormatArg;
188  } else {
189  continue;
190  }
191 
192  Inst->replaceAllUsesWith(Replacement);
193  InstsToErase.push_back(Inst);
194  Modified = true;
195  }
196 
197  return Modified;
198  }
199 
200  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
201  bool Modified = false;
202 
203  for (const auto &Use : SamplerArg.uses()) {
204  auto Inst = dyn_cast<CallInst>(Use.getUser());
205  if (!Inst) {
206  continue;
207  }
208 
209  Function *F = Inst->getCalledFunction();
210  if (!F)
211  continue;
212 
213  Value *Replacement = nullptr;
214  StringRef Name = F->getName();
215  if (Name == GetSamplerResourceIDFunc) {
216  Replacement = ConstantInt::get(Int32Type, ResourceID);
217  } else {
218  continue;
219  }
220 
221  Inst->replaceAllUsesWith(Replacement);
222  InstsToErase.push_back(Inst);
223  Modified = true;
224  }
225 
226  return Modified;
227  }
228 
229  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
230  uint32_t NumReadOnlyImageArgs = 0;
231  uint32_t NumWriteOnlyImageArgs = 0;
232  uint32_t NumSamplerArgs = 0;
233 
234  bool Modified = false;
235  InstsToErase.clear();
236  for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
237  Argument &Arg = *ArgI;
238  StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
239 
240  // Handle image types.
241  if (IsImageType(Type)) {
242  StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
243  uint32_t ResourceID;
244  if (AccessQual == "read_only") {
245  ResourceID = NumReadOnlyImageArgs++;
246  } else if (AccessQual == "write_only") {
247  ResourceID = NumWriteOnlyImageArgs++;
248  } else {
249  llvm_unreachable("Wrong image access qualifier.");
250  }
251 
252  Argument &SizeArg = *(++ArgI);
253  Argument &FormatArg = *(++ArgI);
254  Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
255 
256  // Handle sampler type.
257  } else if (IsSamplerType(Type)) {
258  uint32_t ResourceID = NumSamplerArgs++;
259  Modified |= replaceSamplerUses(Arg, ResourceID);
260  }
261  }
262  for (unsigned i = 0; i < InstsToErase.size(); ++i) {
263  InstsToErase[i]->eraseFromParent();
264  }
265 
266  return Modified;
267  }
268 
269  std::tuple<Function *, MDNode *>
270  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
271  bool Modified = false;
272 
273  FunctionType *FT = F->getFunctionType();
274  SmallVector<Type *, 8> ArgTypes;
275 
276  // Metadata operands for new MDNode.
277  KernelArgMD NewArgMDs;
278  PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
279 
280  // Add implicit arguments to the signature.
281  for (unsigned i = 0; i < FT->getNumParams(); ++i) {
282  ArgTypes.push_back(FT->getParamType(i));
283  MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
284  PushArgMD(NewArgMDs, ArgMD);
285 
286  if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
287  continue;
288 
289  // Add size implicit argument.
290  ArgTypes.push_back(ImageSizeType);
291  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
292  PushArgMD(NewArgMDs, ArgMD);
293 
294  // Add format implicit argument.
295  ArgTypes.push_back(ImageFormatType);
296  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
297  PushArgMD(NewArgMDs, ArgMD);
298 
299  Modified = true;
300  }
301  if (!Modified) {
302  return std::make_tuple(nullptr, nullptr);
303  }
304 
305  // Create function with new signature and clone the old body into it.
306  auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
307  auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
308  ValueToValueMapTy VMap;
309  auto NewFArgIt = NewF->arg_begin();
310  for (auto &Arg: F->args()) {
311  auto ArgName = Arg.getName();
312  NewFArgIt->setName(ArgName);
313  VMap[&Arg] = &(*NewFArgIt++);
314  if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
315  (NewFArgIt++)->setName(Twine("__size_") + ArgName);
316  (NewFArgIt++)->setName(Twine("__format_") + ArgName);
317  }
318  }
320  CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
321 
322  // Build new MDNode.
323  SmallVector<Metadata *, 6> KernelMDArgs;
324  KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
325  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
326  KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
327  MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
328 
329  return std::make_tuple(NewF, NewMDNode);
330  }
331 
332  bool transformKernels(Module &M) {
333  NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
334  if (!KernelsMDNode)
335  return false;
336 
337  bool Modified = false;
338  for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
339  MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
340  Function *F = GetFunctionFromMDNode(KernelMDNode);
341  if (!F)
342  continue;
343 
344  Function *NewF;
345  MDNode *NewMDNode;
346  std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
347  if (NewF) {
348  // Replace old function and metadata with new ones.
349  F->eraseFromParent();
350  M.getFunctionList().push_back(NewF);
351  M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
352  NewF->getAttributes());
353  KernelsMDNode->setOperand(i, NewMDNode);
354 
355  F = NewF;
356  KernelMDNode = NewMDNode;
357  Modified = true;
358  }
359 
360  Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
361  }
362 
363  return Modified;
364  }
365 
366 public:
367  AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
368 
369  bool runOnModule(Module &M) override {
370  Context = &M.getContext();
371  Int32Type = Type::getInt32Ty(M.getContext());
372  ImageSizeType = ArrayType::get(Int32Type, 3);
373  ImageFormatType = ArrayType::get(Int32Type, 2);
374 
375  return transformKernels(M);
376  }
377 
378  StringRef getPassName() const override {
379  return "AMDGPU OpenCL Image Type Pass";
380  }
381 };
382 
383 } // end anonymous namespace
384 
386 
388  return new AMDGPUOpenCLImageTypeLoweringPass();
389 }
iterator_range< use_iterator > uses()
Definition: Value.h:356
This class represents an incoming formal argument to a Function.
Definition: Argument.h:30
LLVMContext & Context
MDNode * getOperand(unsigned i) const
Definition: Metadata.cpp:1073
Compute iterated dominance frontiers using a linear time algorithm.
Definition: AllocatorList.h:24
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:135
static bool IsImageType(StringRef TypeString)
Constant * getOrInsertFunction(StringRef Name, FunctionType *T, AttributeList AttributeList)
Look up the specified function in the module symbol table.
Definition: Module.cpp:142
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:63
LLVM_ATTRIBUTE_ALWAYS_INLINE size_type size() const
Definition: SmallVector.h:136
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:446
static StringRef AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
This class represents a function call, abstracting a target machine's calling convention.
This file contains the declarations for metadata subclasses.
static StringRef GetSamplerResourceIDFunc
static StringRef GetImageSizeFunc
arg_iterator arg_end()
Definition: Function.h:612
Metadata node.
Definition: Metadata.h:862
F(f)
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1067
This defines the Use class.
void setOperand(unsigned I, MDNode *New)
Definition: Metadata.cpp:1081
A tuple of MDNodes.
Definition: Metadata.h:1323
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, bool ModuleLevelChanges, SmallVectorImpl< ReturnInst *> &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values...
static StringRef KernelsMDNodeName
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
static StringRef KernelArgMDNodeNames[]
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:237
A Use represents the edge between a Value definition and its users.
Definition: Use.h:56
unsigned getNumOperands() const
Definition: Metadata.cpp:1069
LLVM_NODISCARD LLVM_ATTRIBUTE_ALWAYS_INLINE bool startswith(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:267
User * getUser() const LLVM_READONLY
Returns the User that contains this Use.
Definition: Use.cpp:41
Class to represent function types.
Definition: DerivedTypes.h:103
NamedMDNode * getNamedMetadata(const Twine &Name) const
Return the first NamedMDNode in the module with the specified name.
Definition: Module.cpp:242
static StringRef GetImageResourceIDFunc
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:205
LinkageTypes getLinkage() const
Definition: GlobalValue.h:441
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:430
static ConstantAsMetadata * get(Constant *C)
Definition: Metadata.h:408
StringRef getString() const
Definition: Metadata.cpp:456
static MDVector GetArgMD(MDNode *KernelMDNode, unsigned OpIdx)
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata *> MDs)
Definition: Metadata.h:1164
static StringRef ImageSizeArgMDType
const FunctionListType & getFunctionList() const
Get the Module's list of functions (constant).
Definition: Module.h:507
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:46
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:69
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:297
arg_iterator arg_begin()
Definition: Function.h:603
static Function * GetFunctionFromMDNode(MDNode *Node)
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
static StringRef ImageFormatArgMDType
This is a &#39;vector&#39; (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:864
Module.h This file contains the declarations for the Module class.
static const unsigned NumKernelArgMDNodes
Type * getReturnType() const
Definition: DerivedTypes.h:124
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:560
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:145
void push_back(pointer val)
Definition: ilist.h:326
static StringRef ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
Definition: Argument.h:48
static void PushArgMD(KernelArgMD &MD, const MDVector &V)
amdgpu Simplify well known AMD library false Value Value * Arg
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:176
StringRef getName() const
Return a constant reference to the value&#39;s name.
Definition: Value.cpp:220
static StringRef GetImageFormatFunc
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:225
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:568
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:323
void eraseFromParent()
eraseFromParent - This method unlinks &#39;this&#39; from the containing module and deletes it...
Definition: Function.cpp:202
static bool IsSamplerType(StringRef TypeString)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:73
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:49
A single uniqued string.
Definition: Metadata.h:602
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:136
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1073
ModulePass * createAMDGPUOpenCLImageTypeLoweringPass()
iterator_range< arg_iterator > args()
Definition: Function.h:621