Line data Source code
1 : //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : /// \file
11 : /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 : /// sampler resource ID getter functions.
13 : ///
14 : /// Image attributes (size and format) are expected to be passed to the kernel
15 : /// as kernel arguments immediately following the image argument itself,
16 : /// therefore this pass adds image size and format arguments to the kernel
17 : /// functions in the module. The kernel functions with image arguments are
18 : /// re-created using the new signature. The new arguments are added to the
19 : /// kernel metadata with type "__llvm_image_size" or "__llvm_image_format".
20 : /// Note: this pass may invalidate pointers to functions.
21 : ///
22 : /// Resource IDs of read-only images, write-only images and samplers are
23 : /// defined to be their index among the kernel arguments of the same
24 : /// type and access qualifier.
25 : //
26 : //===----------------------------------------------------------------------===//
27 :
28 : #include "AMDGPU.h"
29 : #include "llvm/ADT/SmallVector.h"
30 : #include "llvm/ADT/StringRef.h"
31 : #include "llvm/ADT/Twine.h"
32 : #include "llvm/IR/Argument.h"
33 : #include "llvm/IR/DerivedTypes.h"
34 : #include "llvm/IR/Constants.h"
35 : #include "llvm/IR/Function.h"
36 : #include "llvm/IR/Instruction.h"
37 : #include "llvm/IR/Instructions.h"
38 : #include "llvm/IR/Metadata.h"
39 : #include "llvm/IR/Module.h"
40 : #include "llvm/IR/Type.h"
41 : #include "llvm/IR/Use.h"
42 : #include "llvm/IR/User.h"
43 : #include "llvm/Pass.h"
44 : #include "llvm/Support/Casting.h"
45 : #include "llvm/Support/ErrorHandling.h"
46 : #include "llvm/Transforms/Utils/Cloning.h"
47 : #include "llvm/Transforms/Utils/ValueMapper.h"
48 : #include <cassert>
49 : #include <cstddef>
50 : #include <cstdint>
51 : #include <tuple>
52 :
53 : using namespace llvm;
54 :
55 : static StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
56 : static StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
57 : static StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
58 : static StringRef GetSamplerResourceIDFunc =
59 : "llvm.OpenCL.sampler.get.resource.id";
60 :
61 : static StringRef ImageSizeArgMDType = "__llvm_image_size";
62 : static StringRef ImageFormatArgMDType = "__llvm_image_format";
63 :
64 : static StringRef KernelsMDNodeName = "opencl.kernels";
65 : static StringRef KernelArgMDNodeNames[] = {
66 : "kernel_arg_addr_space",
67 : "kernel_arg_access_qual",
68 : "kernel_arg_type",
69 : "kernel_arg_base_type",
70 : "kernel_arg_type_qual"};
71 : static const unsigned NumKernelArgMDNodes = 5;
72 :
73 : namespace {
74 :
75 : using MDVector = SmallVector<Metadata *, 8>;
76 396 : struct KernelArgMD {
77 : MDVector ArgVector[NumKernelArgMDNodes];
78 : };
79 :
80 : } // end anonymous namespace
81 :
82 : static inline bool
83 276 : IsImageType(StringRef TypeString) {
84 276 : return TypeString == "image2d_t" || TypeString == "image3d_t";
85 : }
86 :
87 : static inline bool
88 : IsSamplerType(StringRef TypeString) {
89 : return TypeString == "sampler_t";
90 : }
91 :
92 : static Function *
93 41 : GetFunctionFromMDNode(MDNode *Node) {
94 41 : if (!Node)
95 : return nullptr;
96 :
97 41 : size_t NumOps = Node->getNumOperands();
98 41 : if (NumOps != NumKernelArgMDNodes + 1)
99 : return nullptr;
100 :
101 : auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
102 : if (!F)
103 : return nullptr;
104 :
105 : // Sanity checks.
106 34 : size_t ExpectNumArgNodeOps = F->arg_size() + 1;
107 202 : for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
108 169 : MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
109 169 : if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
110 : return nullptr;
111 169 : if (!ArgNode->getOperand(0))
112 : return nullptr;
113 :
114 : // FIXME: It should be possible to do image lowering when some metadata
115 : // args missing or not in the expected order.
116 : MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
117 169 : if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
118 1 : return nullptr;
119 : }
120 :
121 : return F;
122 : }
123 :
124 : static StringRef
125 55 : AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
126 : MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
127 110 : return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
128 : }
129 :
130 : static StringRef
131 276 : ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
132 : MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
133 552 : return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
134 : }
135 :
136 : static MDVector
137 128 : GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
138 : MDVector Res;
139 768 : for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
140 640 : MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
141 1280 : Res.push_back(Node->getOperand(OpIdx));
142 : }
143 128 : return Res;
144 : }
145 :
146 : static void
147 : PushArgMD(KernelArgMD &MD, const MDVector &V) {
148 : assert(V.size() == NumKernelArgMDNodes);
149 1428 : for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
150 2380 : MD.ArgVector[i].push_back(V[i]);
151 : }
152 : }
153 :
154 : namespace {
155 :
156 : class R600OpenCLImageTypeLoweringPass : public ModulePass {
157 : static char ID;
158 :
159 : LLVMContext *Context;
160 : Type *Int32Type;
161 : Type *ImageSizeType;
162 : Type *ImageFormatType;
163 : SmallVector<Instruction *, 4> InstsToErase;
164 :
165 55 : bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
166 : Argument &ImageSizeArg,
167 : Argument &ImageFormatArg) {
168 : bool Modified = false;
169 :
170 85 : for (auto &Use : ImageArg.uses()) {
171 30 : auto Inst = dyn_cast<CallInst>(Use.getUser());
172 : if (!Inst) {
173 : continue;
174 : }
175 :
176 : Function *F = Inst->getCalledFunction();
177 : if (!F)
178 : continue;
179 :
180 : Value *Replacement = nullptr;
181 30 : StringRef Name = F->getName();
182 : if (Name.startswith(GetImageResourceIDFunc)) {
183 20 : Replacement = ConstantInt::get(Int32Type, ResourceID);
184 : } else if (Name.startswith(GetImageSizeFunc)) {
185 6 : Replacement = &ImageSizeArg;
186 : } else if (Name.startswith(GetImageFormatFunc)) {
187 4 : Replacement = &ImageFormatArg;
188 : } else {
189 : continue;
190 : }
191 :
192 30 : Inst->replaceAllUsesWith(Replacement);
193 30 : InstsToErase.push_back(Inst);
194 : Modified = true;
195 : }
196 :
197 55 : return Modified;
198 : }
199 :
200 7 : bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
201 : bool Modified = false;
202 :
203 10 : for (const auto &Use : SamplerArg.uses()) {
204 3 : auto Inst = dyn_cast<CallInst>(Use.getUser());
205 : if (!Inst) {
206 0 : continue;
207 : }
208 :
209 : Function *F = Inst->getCalledFunction();
210 : if (!F)
211 : continue;
212 :
213 : Value *Replacement = nullptr;
214 3 : StringRef Name = F->getName();
215 : if (Name == GetSamplerResourceIDFunc) {
216 3 : Replacement = ConstantInt::get(Int32Type, ResourceID);
217 : } else {
218 : continue;
219 : }
220 :
221 3 : Inst->replaceAllUsesWith(Replacement);
222 3 : InstsToErase.push_back(Inst);
223 : Modified = true;
224 : }
225 :
226 7 : return Modified;
227 : }
228 :
229 33 : bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
230 : uint32_t NumReadOnlyImageArgs = 0;
231 : uint32_t NumWriteOnlyImageArgs = 0;
232 : uint32_t NumSamplerArgs = 0;
233 :
234 : bool Modified = false;
235 : InstsToErase.clear();
236 223 : for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
237 : Argument &Arg = *ArgI;
238 95 : StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
239 :
240 : // Handle image types.
241 95 : if (IsImageType(Type)) {
242 55 : StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
243 : uint32_t ResourceID;
244 : if (AccessQual == "read_only") {
245 32 : ResourceID = NumReadOnlyImageArgs++;
246 : } else if (AccessQual == "write_only") {
247 23 : ResourceID = NumWriteOnlyImageArgs++;
248 : } else {
249 0 : llvm_unreachable("Wrong image access qualifier.");
250 : }
251 :
252 55 : Argument &SizeArg = *(++ArgI);
253 55 : Argument &FormatArg = *(++ArgI);
254 55 : Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
255 :
256 : // Handle sampler type.
257 : } else if (IsSamplerType(Type)) {
258 7 : uint32_t ResourceID = NumSamplerArgs++;
259 7 : Modified |= replaceSamplerUses(Arg, ResourceID);
260 : }
261 : }
262 66 : for (unsigned i = 0; i < InstsToErase.size(); ++i) {
263 33 : InstsToErase[i]->eraseFromParent();
264 : }
265 :
266 33 : return Modified;
267 : }
268 :
269 : std::tuple<Function *, MDNode *>
270 33 : addImplicitArgs(Function *F, MDNode *KernelMDNode) {
271 : bool Modified = false;
272 :
273 : FunctionType *FT = F->getFunctionType();
274 : SmallVector<Type *, 8> ArgTypes;
275 :
276 : // Metadata operands for new MDNode.
277 : KernelArgMD NewArgMDs;
278 66 : PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
279 :
280 : // Add implicit arguments to the signature.
281 256 : for (unsigned i = 0; i < FT->getNumParams(); ++i) {
282 190 : ArgTypes.push_back(FT->getParamType(i));
283 95 : MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
284 : PushArgMD(NewArgMDs, ArgMD);
285 :
286 95 : if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
287 : continue;
288 :
289 : // Add size implicit argument.
290 55 : ArgTypes.push_back(ImageSizeType);
291 110 : ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
292 : PushArgMD(NewArgMDs, ArgMD);
293 :
294 : // Add format implicit argument.
295 55 : ArgTypes.push_back(ImageFormatType);
296 110 : ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
297 : PushArgMD(NewArgMDs, ArgMD);
298 :
299 : Modified = true;
300 : }
301 33 : if (!Modified) {
302 : return std::make_tuple(nullptr, nullptr);
303 : }
304 :
305 : // Create function with new signature and clone the old body into it.
306 60 : auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
307 60 : auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
308 30 : ValueToValueMapTy VMap;
309 : auto NewFArgIt = NewF->arg_begin();
310 116 : for (auto &Arg: F->args()) {
311 86 : auto ArgName = Arg.getName();
312 172 : NewFArgIt->setName(ArgName);
313 86 : VMap[&Arg] = &(*NewFArgIt++);
314 86 : if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
315 55 : (NewFArgIt++)->setName(Twine("__size_") + ArgName);
316 55 : (NewFArgIt++)->setName(Twine("__format_") + ArgName);
317 : }
318 : }
319 : SmallVector<ReturnInst*, 8> Returns;
320 30 : CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
321 :
322 : // Build new MDNode.
323 : SmallVector<Metadata *, 6> KernelMDArgs;
324 30 : KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
325 180 : for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
326 300 : KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
327 30 : MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
328 :
329 : return std::make_tuple(NewF, NewMDNode);
330 : }
331 :
332 281 : bool transformKernels(Module &M) {
333 281 : NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
334 281 : if (!KernelsMDNode)
335 : return false;
336 :
337 : bool Modified = false;
338 47 : for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
339 41 : MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
340 41 : Function *F = GetFunctionFromMDNode(KernelMDNode);
341 41 : if (!F)
342 8 : continue;
343 :
344 : Function *NewF;
345 : MDNode *NewMDNode;
346 33 : std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
347 33 : if (NewF) {
348 : // Replace old function and metadata with new ones.
349 30 : F->eraseFromParent();
350 30 : M.getFunctionList().push_back(NewF);
351 30 : M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
352 : NewF->getAttributes());
353 30 : KernelsMDNode->setOperand(i, NewMDNode);
354 :
355 : F = NewF;
356 : KernelMDNode = NewMDNode;
357 : Modified = true;
358 : }
359 :
360 33 : Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
361 : }
362 :
363 : return Modified;
364 : }
365 :
366 : public:
367 564 : R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
368 :
369 281 : bool runOnModule(Module &M) override {
370 281 : Context = &M.getContext();
371 281 : Int32Type = Type::getInt32Ty(M.getContext());
372 281 : ImageSizeType = ArrayType::get(Int32Type, 3);
373 281 : ImageFormatType = ArrayType::get(Int32Type, 2);
374 :
375 281 : return transformKernels(M);
376 : }
377 :
378 0 : StringRef getPassName() const override {
379 0 : return "R600 OpenCL Image Type Pass";
380 : }
381 : };
382 :
383 : } // end anonymous namespace
384 :
385 : char R600OpenCLImageTypeLoweringPass::ID = 0;
386 :
387 282 : ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
388 282 : return new R600OpenCLImageTypeLoweringPass();
389 : }
|