LLVM  9.0.0svn
R600OpenCLImageTypeLoweringPass.cpp
Go to the documentation of this file.
1 //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass resolves calls to OpenCL image attribute, image resource ID and
11 /// sampler resource ID getter functions.
12 ///
13 /// Image attributes (size and format) are expected to be passed to the kernel
14 /// as kernel arguments immediately following the image argument itself,
15 /// therefore this pass adds image size and format arguments to the kernel
16 /// functions in the module. The kernel functions with image arguments are
17 /// re-created using the new signature. The new arguments are added to the
18 /// kernel metadata with kernel_arg_type and kernel_arg_base_type set to "__llvm_image_size" or "__llvm_image_format".
19 /// Note: this pass may invalidate pointers to functions.
20 ///
21 /// Resource IDs of read-only images, write-only images and samplers are
22 /// defined to be their index among the kernel arguments of the same
23 /// type and access qualifier.
24 //
25 //===----------------------------------------------------------------------===//
26 
#include "AMDGPU.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
51 
52 using namespace llvm;
53 
54 static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
55 static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
56 static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
58  "llvm.OpenCL.sampler.get.resource.id";
59 
60 static StringRef ImageSizeArgMDType = "__llvm_image_size";
61 static StringRef ImageFormatArgMDType = "__llvm_image_format";
62 
63 static StringRef KernelsMDNodeName = "opencl.kernels";
65  "kernel_arg_addr_space",
66  "kernel_arg_access_qual",
67  "kernel_arg_type",
68  "kernel_arg_base_type",
69  "kernel_arg_type_qual"};
70 static const unsigned NumKernelArgMDNodes = 5;
71 
72 namespace {
73 
74 using MDVector = SmallVector<Metadata *, 8>;
75 struct KernelArgMD {
76  MDVector ArgVector[NumKernelArgMDNodes];
77 };
78 
79 } // end anonymous namespace
80 
81 static inline bool
82 IsImageType(StringRef TypeString) {
83  return TypeString == "image2d_t" || TypeString == "image3d_t";
84 }
85 
86 static inline bool
87 IsSamplerType(StringRef TypeString) {
88  return TypeString == "sampler_t";
89 }
90 
91 static Function *
93  if (!Node)
94  return nullptr;
95 
96  size_t NumOps = Node->getNumOperands();
97  if (NumOps != NumKernelArgMDNodes + 1)
98  return nullptr;
99 
100  auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
101  if (!F)
102  return nullptr;
103 
104  // Sanity checks.
105  size_t ExpectNumArgNodeOps = F->arg_size() + 1;
106  for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
107  MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
108  if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
109  return nullptr;
110  if (!ArgNode->getOperand(0))
111  return nullptr;
112 
113  // FIXME: It should be possible to do image lowering when some metadata
114  // args missing or not in the expected order.
115  MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
116  if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
117  return nullptr;
118  }
119 
120  return F;
121 }
122 
123 static StringRef
124 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
125  MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
126  return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
127 }
128 
129 static StringRef
130 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
131  MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
132  return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
133 }
134 
135 static MDVector
136 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
137  MDVector Res;
138  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
139  MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
140  Res.push_back(Node->getOperand(OpIdx));
141  }
142  return Res;
143 }
144 
145 static void
146 PushArgMD(KernelArgMD &MD, const MDVector &V) {
147  assert(V.size() == NumKernelArgMDNodes);
148  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
149  MD.ArgVector[i].push_back(V[i]);
150  }
151 }
152 
153 namespace {
154 
155 class R600OpenCLImageTypeLoweringPass : public ModulePass {
156  static char ID;
157 
159  Type *Int32Type;
160  Type *ImageSizeType;
161  Type *ImageFormatType;
162  SmallVector<Instruction *, 4> InstsToErase;
163 
164  bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
165  Argument &ImageSizeArg,
166  Argument &ImageFormatArg) {
167  bool Modified = false;
168 
169  for (auto &Use : ImageArg.uses()) {
170  auto Inst = dyn_cast<CallInst>(Use.getUser());
171  if (!Inst) {
172  continue;
173  }
174 
175  Function *F = Inst->getCalledFunction();
176  if (!F)
177  continue;
178 
179  Value *Replacement = nullptr;
180  StringRef Name = F->getName();
181  if (Name.startswith(GetImageResourceIDFunc)) {
182  Replacement = ConstantInt::get(Int32Type, ResourceID);
183  } else if (Name.startswith(GetImageSizeFunc)) {
184  Replacement = &ImageSizeArg;
185  } else if (Name.startswith(GetImageFormatFunc)) {
186  Replacement = &ImageFormatArg;
187  } else {
188  continue;
189  }
190 
191  Inst->replaceAllUsesWith(Replacement);
192  InstsToErase.push_back(Inst);
193  Modified = true;
194  }
195 
196  return Modified;
197  }
198 
199  bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
200  bool Modified = false;
201 
202  for (const auto &Use : SamplerArg.uses()) {
203  auto Inst = dyn_cast<CallInst>(Use.getUser());
204  if (!Inst) {
205  continue;
206  }
207 
208  Function *F = Inst->getCalledFunction();
209  if (!F)
210  continue;
211 
212  Value *Replacement = nullptr;
213  StringRef Name = F->getName();
214  if (Name == GetSamplerResourceIDFunc) {
215  Replacement = ConstantInt::get(Int32Type, ResourceID);
216  } else {
217  continue;
218  }
219 
220  Inst->replaceAllUsesWith(Replacement);
221  InstsToErase.push_back(Inst);
222  Modified = true;
223  }
224 
225  return Modified;
226  }
227 
228  bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
229  uint32_t NumReadOnlyImageArgs = 0;
230  uint32_t NumWriteOnlyImageArgs = 0;
231  uint32_t NumSamplerArgs = 0;
232 
233  bool Modified = false;
234  InstsToErase.clear();
235  for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
236  Argument &Arg = *ArgI;
237  StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
238 
239  // Handle image types.
240  if (IsImageType(Type)) {
241  StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
242  uint32_t ResourceID;
243  if (AccessQual == "read_only") {
244  ResourceID = NumReadOnlyImageArgs++;
245  } else if (AccessQual == "write_only") {
246  ResourceID = NumWriteOnlyImageArgs++;
247  } else {
248  llvm_unreachable("Wrong image access qualifier.");
249  }
250 
251  Argument &SizeArg = *(++ArgI);
252  Argument &FormatArg = *(++ArgI);
253  Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
254 
255  // Handle sampler type.
256  } else if (IsSamplerType(Type)) {
257  uint32_t ResourceID = NumSamplerArgs++;
258  Modified |= replaceSamplerUses(Arg, ResourceID);
259  }
260  }
261  for (unsigned i = 0; i < InstsToErase.size(); ++i) {
262  InstsToErase[i]->eraseFromParent();
263  }
264 
265  return Modified;
266  }
267 
268  std::tuple<Function *, MDNode *>
269  addImplicitArgs(Function *F, MDNode *KernelMDNode) {
270  bool Modified = false;
271 
272  FunctionType *FT = F->getFunctionType();
273  SmallVector<Type *, 8> ArgTypes;
274 
275  // Metadata operands for new MDNode.
276  KernelArgMD NewArgMDs;
277  PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
278 
279  // Add implicit arguments to the signature.
280  for (unsigned i = 0; i < FT->getNumParams(); ++i) {
281  ArgTypes.push_back(FT->getParamType(i));
282  MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
283  PushArgMD(NewArgMDs, ArgMD);
284 
285  if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
286  continue;
287 
288  // Add size implicit argument.
289  ArgTypes.push_back(ImageSizeType);
290  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
291  PushArgMD(NewArgMDs, ArgMD);
292 
293  // Add format implicit argument.
294  ArgTypes.push_back(ImageFormatType);
295  ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
296  PushArgMD(NewArgMDs, ArgMD);
297 
298  Modified = true;
299  }
300  if (!Modified) {
301  return std::make_tuple(nullptr, nullptr);
302  }
303 
304  // Create function with new signature and clone the old body into it.
305  auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
306  auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
307  ValueToValueMapTy VMap;
308  auto NewFArgIt = NewF->arg_begin();
309  for (auto &Arg: F->args()) {
310  auto ArgName = Arg.getName();
311  NewFArgIt->setName(ArgName);
312  VMap[&Arg] = &(*NewFArgIt++);
313  if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
314  (NewFArgIt++)->setName(Twine("__size_") + ArgName);
315  (NewFArgIt++)->setName(Twine("__format_") + ArgName);
316  }
317  }
319  CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
320 
321  // Build new MDNode.
322  SmallVector<Metadata *, 6> KernelMDArgs;
323  KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
324  for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
325  KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
326  MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
327 
328  return std::make_tuple(NewF, NewMDNode);
329  }
330 
331  bool transformKernels(Module &M) {
332  NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
333  if (!KernelsMDNode)
334  return false;
335 
336  bool Modified = false;
337  for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
338  MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
339  Function *F = GetFunctionFromMDNode(KernelMDNode);
340  if (!F)
341  continue;
342 
343  Function *NewF;
344  MDNode *NewMDNode;
345  std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
346  if (NewF) {
347  // Replace old function and metadata with new ones.
348  F->eraseFromParent();
349  M.getFunctionList().push_back(NewF);
350  M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
351  NewF->getAttributes());
352  KernelsMDNode->setOperand(i, NewMDNode);
353 
354  F = NewF;
355  KernelMDNode = NewMDNode;
356  Modified = true;
357  }
358 
359  Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
360  }
361 
362  return Modified;
363  }
364 
365 public:
366  R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
367 
368  bool runOnModule(Module &M) override {
369  Context = &M.getContext();
370  Int32Type = Type::getInt32Ty(M.getContext());
371  ImageSizeType = ArrayType::get(Int32Type, 3);
372  ImageFormatType = ArrayType::get(Int32Type, 2);
373 
374  return transformKernels(M);
375  }
376 
377  StringRef getPassName() const override {
378  return "R600 OpenCL Image Type Pass";
379  }
380 };
381 
382 } // end anonymous namespace
383 
385 
387  return new R600OpenCLImageTypeLoweringPass();
388 }
iterator_range< use_iterator > uses()
Definition: Value.h:354
This class represents an incoming formal argument to a Function.
Definition: Argument.h:29
LLVMContext & Context
MDNode * getOperand(unsigned i) const
Definition: Metadata.cpp:1080
static StringRef AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
This class represents lattice values for constants.
Definition: AllocatorList.h:23
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:135
static StringRef GetImageFormatFunc
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:65
ModulePass * createR600OpenCLImageTypeLoweringPass()
amdgpu Simplify well known AMD library false FunctionCallee Value const Twine & Name
static MDString * get(LLVMContext &Context, StringRef Str)
Definition: Metadata.cpp:453
LLVM_NODISCARD bool startswith(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:256
This class represents a function call, abstracting a target machine's calling convention.
This file contains the declarations for metadata subclasses.
arg_iterator arg_end()
Definition: Function.h:696
Metadata node.
Definition: Metadata.h:863
F(f)
const MDOperand & getOperand(unsigned I) const
Definition: Metadata.h:1068
This defines the Use class.
static StringRef ImageFormatArgMDType
void setOperand(unsigned I, MDNode *New)
Definition: Metadata.cpp:1088
static StringRef KernelsMDNodeName
static StringRef GetSamplerResourceIDFunc
A tuple of MDNodes.
Definition: Metadata.h:1325
void CloneFunctionInto(Function *NewFunc, const Function *OldFunc, ValueToValueMapTy &VMap, bool ModuleLevelChanges, SmallVectorImpl< ReturnInst *> &Returns, const char *NameSuffix="", ClonedCodeInfo *CodeInfo=nullptr, ValueMapTypeRemapper *TypeMapper=nullptr, ValueMaterializer *Materializer=nullptr)
Clone OldFunc into NewFunc, transforming the old arguments into references to VMap values...
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:80
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:244
A Use represents the edge between a Value definition and its users.
Definition: Use.h:55
static bool IsImageType(StringRef TypeString)
unsigned getNumOperands() const
Definition: Metadata.cpp:1076
User * getUser() const LLVM_READONLY
Returns the User that contains this Use.
Definition: Use.cpp:40
Class to represent function types.
Definition: DerivedTypes.h:103
NamedMDNode * getNamedMetadata(const Twine &Name) const
Return the first NamedMDNode in the module with the specified name.
Definition: Module.cpp:250
static StringRef ImageSizeArgMDType
AttributeList getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:223
LinkageTypes getLinkage() const
Definition: GlobalValue.h:460
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:429
static ConstantAsMetadata * get(Constant *C)
Definition: Metadata.h:409
StringRef getString() const
Definition: Metadata.cpp:463
static MDTuple * get(LLVMContext &Context, ArrayRef< Metadata *> MDs)
Definition: Metadata.h:1165
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &N="", Module *M=nullptr)
Definition: Function.h:135
const FunctionListType & getFunctionList() const
Get the Module's list of functions (constant).
Definition: Module.h:533
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:64
static MDVector GetArgMD(MDNode *KernelMDNode, unsigned OpIdx)
This file contains the declarations for the subclasses of Constant, which represent the different fla...
unsigned getNumParams() const
Return the number of fixed parameters this function type requires.
Definition: DerivedTypes.h:139
amdgpu Simplify well known AMD library false FunctionCallee Value * Arg
static FunctionType * get(Type *Result, ArrayRef< Type *> Params, bool isVarArg)
This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:296
arg_iterator arg_begin()
Definition: Function.h:687
size_t size() const
Definition: SmallVector.h:52
static void PushArgMD(KernelArgMD &MD, const MDVector &V)
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:837
Module.h This file contains the declarations for the Module class.
Type * getReturnType() const
Definition: DerivedTypes.h:124
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:631
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T, AttributeList AttributeList)
Look up the specified function in the module symbol table.
Definition: Module.cpp:143
static StringRef KernelArgMDNodeNames[]
FunctionType * getFunctionType() const
Returns the FunctionType for me.
Definition: Function.h:163
void push_back(pointer val)
Definition: ilist.h:311
static const unsigned NumKernelArgMDNodes
static StringRef GetImageResourceIDFunc
unsigned getArgNo() const
Return the index of this formal argument in its containing function.
Definition: Argument.h:47
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:175
StringRef getName() const
Return a constant reference to the value&#39;s name.
Definition: Value.cpp:214
ModulePass class - This class is used to implement unstructured interprocedural optimizations and ana...
Definition: Pass.h:224
static ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Definition: Type.cpp:580
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:332
static Function * GetFunctionFromMDNode(MDNode *Node)
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it...
Definition: Function.cpp:226
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
LLVM Value Representation.
Definition: Value.h:72
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:48
A single uniqued string.
Definition: Metadata.h:603
static bool IsSamplerType(StringRef TypeString)
static StringRef GetImageSizeFunc
unsigned getNumOperands() const
Return number of MDNode operands.
Definition: Metadata.h:1074
static StringRef ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx)
iterator_range< arg_iterator > args()
Definition: Function.h:705