LCOV - code coverage report
Current view: top level - lib/Target/AMDGPU - R600OpenCLImageTypeLoweringPass.cpp (source / functions) Hit Total Coverage
Test: llvm-toolchain.info Lines: 108 112 96.4 %
Date: 2018-10-20 13:21:21 Functions: 12 13 92.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //===- R600OpenCLImageTypeLoweringPass.cpp ------------------------------===//
       2             : //
       3             : //                     The LLVM Compiler Infrastructure
       4             : //
       5             : // This file is distributed under the University of Illinois Open Source
       6             : // License. See LICENSE.TXT for details.
       7             : //
       8             : //===----------------------------------------------------------------------===//
       9             : //
      10             : /// \file
      11             : /// This pass resolves calls to OpenCL image attribute, image resource ID and
      12             : /// sampler resource ID getter functions.
      13             : ///
      14             : /// Image attributes (size and format) are expected to be passed to the kernel
      15             : /// as kernel arguments immediately following the image argument itself,
      16             : /// therefore this pass adds image size and format arguments to the kernel
      17             : /// functions in the module. The kernel functions with image arguments are
      18             : /// re-created using the new signature. The new arguments are added to the
      19             : /// kernel metadata with kernel_arg_type set to "image_size" or "image_format".
      20             : /// Note: this pass may invalidate pointers to functions.
      21             : ///
      22             : /// Resource IDs of read-only images, write-only images and samplers are
      23             : /// defined to be their index among the kernel arguments of the same
      24             : /// type and access qualifier.
      25             : //
      26             : //===----------------------------------------------------------------------===//
      27             : 
      28             : #include "AMDGPU.h"
      29             : #include "llvm/ADT/SmallVector.h"
      30             : #include "llvm/ADT/StringRef.h"
      31             : #include "llvm/ADT/Twine.h"
      32             : #include "llvm/IR/Argument.h"
      33             : #include "llvm/IR/DerivedTypes.h"
      34             : #include "llvm/IR/Constants.h"
      35             : #include "llvm/IR/Function.h"
      36             : #include "llvm/IR/Instruction.h"
      37             : #include "llvm/IR/Instructions.h"
      38             : #include "llvm/IR/Metadata.h"
      39             : #include "llvm/IR/Module.h"
      40             : #include "llvm/IR/Type.h"
      41             : #include "llvm/IR/Use.h"
      42             : #include "llvm/IR/User.h"
      43             : #include "llvm/Pass.h"
      44             : #include "llvm/Support/Casting.h"
      45             : #include "llvm/Support/ErrorHandling.h"
      46             : #include "llvm/Transforms/Utils/Cloning.h"
      47             : #include "llvm/Transforms/Utils/ValueMapper.h"
      48             : #include <cassert>
      49             : #include <cstddef>
      50             : #include <cstdint>
      51             : #include <tuple>
      52             : 
      53             : using namespace llvm;
      54             : 
      55             : static StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
      56             : static StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
      57             : static StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
      58             : static StringRef GetSamplerResourceIDFunc =
      59             :     "llvm.OpenCL.sampler.get.resource.id";
      60             : 
      61             : static StringRef ImageSizeArgMDType =   "__llvm_image_size";
      62             : static StringRef ImageFormatArgMDType = "__llvm_image_format";
      63             : 
      64             : static StringRef KernelsMDNodeName = "opencl.kernels";
      65             : static StringRef KernelArgMDNodeNames[] = {
      66             :   "kernel_arg_addr_space",
      67             :   "kernel_arg_access_qual",
      68             :   "kernel_arg_type",
      69             :   "kernel_arg_base_type",
      70             :   "kernel_arg_type_qual"};
      71             : static const unsigned NumKernelArgMDNodes = 5;
      72             : 
      73             : namespace {
      74             : 
      75             : using MDVector = SmallVector<Metadata *, 8>;
      76         396 : struct KernelArgMD {
      77             :   MDVector ArgVector[NumKernelArgMDNodes];
      78             : };
      79             : 
      80             : } // end anonymous namespace
      81             : 
      82             : static inline bool
      83         276 : IsImageType(StringRef TypeString) {
      84         276 :   return TypeString == "image2d_t" || TypeString == "image3d_t";
      85             : }
      86             : 
      87             : static inline bool
      88             : IsSamplerType(StringRef TypeString) {
      89             :   return TypeString == "sampler_t";
      90             : }
      91             : 
      92             : static Function *
      93          41 : GetFunctionFromMDNode(MDNode *Node) {
      94          41 :   if (!Node)
      95             :     return nullptr;
      96             : 
      97          41 :   size_t NumOps = Node->getNumOperands();
      98          41 :   if (NumOps != NumKernelArgMDNodes + 1)
      99             :     return nullptr;
     100             : 
     101             :   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
     102             :   if (!F)
     103             :     return nullptr;
     104             : 
     105             :   // Sanity checks.
     106          34 :   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
     107         202 :   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
     108         169 :     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
     109         169 :     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
     110             :       return nullptr;
     111         169 :     if (!ArgNode->getOperand(0))
     112             :       return nullptr;
     113             : 
     114             :     // FIXME: It should be possible to do image lowering when some metadata
     115             :     // args missing or not in the expected order.
     116             :     MDString *StringNode = dyn_cast<MDString>(ArgNode->getOperand(0));
     117         169 :     if (!StringNode || StringNode->getString() != KernelArgMDNodeNames[i])
     118           1 :       return nullptr;
     119             :   }
     120             : 
     121             :   return F;
     122             : }
     123             : 
     124             : static StringRef
     125          55 : AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
     126             :   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
     127         110 :   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
     128             : }
     129             : 
     130             : static StringRef
     131         276 : ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
     132             :   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
     133         552 :   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
     134             : }
     135             : 
     136             : static MDVector
     137         128 : GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
     138             :   MDVector Res;
     139         768 :   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
     140         640 :     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
     141        1280 :     Res.push_back(Node->getOperand(OpIdx));
     142             :   }
     143         128 :   return Res;
     144             : }
     145             : 
     146             : static void
     147             : PushArgMD(KernelArgMD &MD, const MDVector &V) {
     148             :   assert(V.size() == NumKernelArgMDNodes);
     149        1428 :   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
     150        2380 :     MD.ArgVector[i].push_back(V[i]);
     151             :   }
     152             : }
     153             : 
     154             : namespace {
     155             : 
     156             : class R600OpenCLImageTypeLoweringPass : public ModulePass {
     157             :   static char ID;
     158             : 
     159             :   LLVMContext *Context;
     160             :   Type *Int32Type;
     161             :   Type *ImageSizeType;
     162             :   Type *ImageFormatType;
     163             :   SmallVector<Instruction *, 4> InstsToErase;
     164             : 
     165          55 :   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
     166             :                         Argument &ImageSizeArg,
     167             :                         Argument &ImageFormatArg) {
     168             :     bool Modified = false;
     169             : 
     170          85 :     for (auto &Use : ImageArg.uses()) {
     171          30 :       auto Inst = dyn_cast<CallInst>(Use.getUser());
     172             :       if (!Inst) {
     173             :         continue;
     174             :       }
     175             : 
     176             :       Function *F = Inst->getCalledFunction();
     177             :       if (!F)
     178             :         continue;
     179             : 
     180             :       Value *Replacement = nullptr;
     181          30 :       StringRef Name = F->getName();
     182             :       if (Name.startswith(GetImageResourceIDFunc)) {
     183          20 :         Replacement = ConstantInt::get(Int32Type, ResourceID);
     184             :       } else if (Name.startswith(GetImageSizeFunc)) {
     185           6 :         Replacement = &ImageSizeArg;
     186             :       } else if (Name.startswith(GetImageFormatFunc)) {
     187           4 :         Replacement = &ImageFormatArg;
     188             :       } else {
     189             :         continue;
     190             :       }
     191             : 
     192          30 :       Inst->replaceAllUsesWith(Replacement);
     193          30 :       InstsToErase.push_back(Inst);
     194             :       Modified = true;
     195             :     }
     196             : 
     197          55 :     return Modified;
     198             :   }
     199             : 
     200           7 :   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
     201             :     bool Modified = false;
     202             : 
     203          10 :     for (const auto &Use : SamplerArg.uses()) {
     204           3 :       auto Inst = dyn_cast<CallInst>(Use.getUser());
     205             :       if (!Inst) {
     206           0 :         continue;
     207             :       }
     208             : 
     209             :       Function *F = Inst->getCalledFunction();
     210             :       if (!F)
     211             :         continue;
     212             : 
     213             :       Value *Replacement = nullptr;
     214           3 :       StringRef Name = F->getName();
     215             :       if (Name == GetSamplerResourceIDFunc) {
     216           3 :         Replacement = ConstantInt::get(Int32Type, ResourceID);
     217             :       } else {
     218             :         continue;
     219             :       }
     220             : 
     221           3 :       Inst->replaceAllUsesWith(Replacement);
     222           3 :       InstsToErase.push_back(Inst);
     223             :       Modified = true;
     224             :     }
     225             : 
     226           7 :     return Modified;
     227             :   }
     228             : 
     229          33 :   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
     230             :     uint32_t NumReadOnlyImageArgs = 0;
     231             :     uint32_t NumWriteOnlyImageArgs = 0;
     232             :     uint32_t NumSamplerArgs = 0;
     233             : 
     234             :     bool Modified = false;
     235             :     InstsToErase.clear();
     236         223 :     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
     237             :       Argument &Arg = *ArgI;
     238          95 :       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
     239             : 
     240             :       // Handle image types.
     241          95 :       if (IsImageType(Type)) {
     242          55 :         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
     243             :         uint32_t ResourceID;
     244             :         if (AccessQual == "read_only") {
     245          32 :           ResourceID = NumReadOnlyImageArgs++;
     246             :         } else if (AccessQual == "write_only") {
     247          23 :           ResourceID = NumWriteOnlyImageArgs++;
     248             :         } else {
     249           0 :           llvm_unreachable("Wrong image access qualifier.");
     250             :         }
     251             : 
     252          55 :         Argument &SizeArg = *(++ArgI);
     253          55 :         Argument &FormatArg = *(++ArgI);
     254          55 :         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
     255             : 
     256             :       // Handle sampler type.
     257             :       } else if (IsSamplerType(Type)) {
     258           7 :         uint32_t ResourceID = NumSamplerArgs++;
     259           7 :         Modified |= replaceSamplerUses(Arg, ResourceID);
     260             :       }
     261             :     }
     262          66 :     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
     263          33 :       InstsToErase[i]->eraseFromParent();
     264             :     }
     265             : 
     266          33 :     return Modified;
     267             :   }
     268             : 
     269             :   std::tuple<Function *, MDNode *>
     270          33 :   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
     271             :     bool Modified = false;
     272             : 
     273             :     FunctionType *FT = F->getFunctionType();
     274             :     SmallVector<Type *, 8> ArgTypes;
     275             : 
     276             :     // Metadata operands for new MDNode.
     277             :     KernelArgMD NewArgMDs;
     278          66 :     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
     279             : 
     280             :     // Add implicit arguments to the signature.
     281         256 :     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
     282         190 :       ArgTypes.push_back(FT->getParamType(i));
     283          95 :       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
     284             :       PushArgMD(NewArgMDs, ArgMD);
     285             : 
     286          95 :       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
     287             :         continue;
     288             : 
     289             :       // Add size implicit argument.
     290          55 :       ArgTypes.push_back(ImageSizeType);
     291         110 :       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
     292             :       PushArgMD(NewArgMDs, ArgMD);
     293             : 
     294             :       // Add format implicit argument.
     295          55 :       ArgTypes.push_back(ImageFormatType);
     296         110 :       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
     297             :       PushArgMD(NewArgMDs, ArgMD);
     298             : 
     299             :       Modified = true;
     300             :     }
     301          33 :     if (!Modified) {
     302             :       return std::make_tuple(nullptr, nullptr);
     303             :     }
     304             : 
     305             :     // Create function with new signature and clone the old body into it.
     306          60 :     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
     307          60 :     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
     308          30 :     ValueToValueMapTy VMap;
     309             :     auto NewFArgIt = NewF->arg_begin();
     310         116 :     for (auto &Arg: F->args()) {
     311          86 :       auto ArgName = Arg.getName();
     312         172 :       NewFArgIt->setName(ArgName);
     313          86 :       VMap[&Arg] = &(*NewFArgIt++);
     314          86 :       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
     315          55 :         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
     316          55 :         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
     317             :       }
     318             :     }
     319             :     SmallVector<ReturnInst*, 8> Returns;
     320          30 :     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
     321             : 
     322             :     // Build new MDNode.
     323             :     SmallVector<Metadata *, 6> KernelMDArgs;
     324          30 :     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
     325         180 :     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
     326         300 :       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
     327          30 :     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
     328             : 
     329             :     return std::make_tuple(NewF, NewMDNode);
     330             :   }
     331             : 
     332         281 :   bool transformKernels(Module &M) {
     333         281 :     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
     334         281 :     if (!KernelsMDNode)
     335             :       return false;
     336             : 
     337             :     bool Modified = false;
     338          47 :     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
     339          41 :       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
     340          41 :       Function *F = GetFunctionFromMDNode(KernelMDNode);
     341          41 :       if (!F)
     342           8 :         continue;
     343             : 
     344             :       Function *NewF;
     345             :       MDNode *NewMDNode;
     346          33 :       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
     347          33 :       if (NewF) {
     348             :         // Replace old function and metadata with new ones.
     349          30 :         F->eraseFromParent();
     350          30 :         M.getFunctionList().push_back(NewF);
     351          30 :         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
     352             :                               NewF->getAttributes());
     353          30 :         KernelsMDNode->setOperand(i, NewMDNode);
     354             : 
     355             :         F = NewF;
     356             :         KernelMDNode = NewMDNode;
     357             :         Modified = true;
     358             :       }
     359             : 
     360          33 :       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
     361             :     }
     362             : 
     363             :     return Modified;
     364             :   }
     365             : 
     366             : public:
     367         564 :   R600OpenCLImageTypeLoweringPass() : ModulePass(ID) {}
     368             : 
     369         281 :   bool runOnModule(Module &M) override {
     370         281 :     Context = &M.getContext();
     371         281 :     Int32Type = Type::getInt32Ty(M.getContext());
     372         281 :     ImageSizeType = ArrayType::get(Int32Type, 3);
     373         281 :     ImageFormatType = ArrayType::get(Int32Type, 2);
     374             : 
     375         281 :     return transformKernels(M);
     376             :   }
     377             : 
     378           0 :   StringRef getPassName() const override {
     379           0 :     return "R600 OpenCL Image Type Pass";
     380             :   }
     381             : };
     382             : 
     383             : } // end anonymous namespace
     384             : 
     385             : char R600OpenCLImageTypeLoweringPass::ID = 0;
     386             : 
     387         282 : ModulePass *llvm::createR600OpenCLImageTypeLoweringPass() {
     388         282 :   return new R600OpenCLImageTypeLoweringPass();
     389             : }

Generated by: LCOV version 1.13