LLVM  3.7.0
NVPTXLowerKernelArgs.cpp
Go to the documentation of this file.
1 //===-- NVPTXLowerKernelArgs.cpp - Lower kernel arguments -----------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Pointer arguments to kernel functions need to be lowered specially.
11 //
12 // 1. Copy byval struct args to local memory. This is a preparation for handling
13 // cases like
14 //
15 // kernel void foo(struct A arg, ...)
16 // {
17 // struct A *p = &arg;
18 // ...
19 // ... = p->field1 ... (there is no generic address for .param)
20 // p->field2 = ... (there is no write access to .param)
21 // }
22 //
23 // 2. Convert non-byval pointer arguments of CUDA kernels to pointers in the
24 // global address space. This allows later optimizations to emit
25 // ld.global.*/st.global.* for accessing these pointer arguments. For
26 // example,
27 //
28 // define void @foo(float* %input) {
29 // %v = load float, float* %input, align 4
30 // ...
31 // }
32 //
33 // becomes
34 //
35 // define void @foo(float* %input) {
36 // %input2 = addrspacecast float* %input to float addrspace(1)*
37 // %input3 = addrspacecast float addrspace(1)* %input2 to float*
38 // %v = load float, float* %input3, align 4
39 // ...
40 // }
41 //
42 // Later, NVPTXFavorNonGenericAddrSpaces will optimize it to
43 //
44 // define void @foo(float* %input) {
45 // %input2 = addrspacecast float* %input to float addrspace(1)*
46 // %v = load float, float addrspace(1)* %input2, align 4
47 // ...
48 // }
49 //
50 // TODO: merge this pass with NVPTXFavorNonGenericAddrSpaces so that other passes
51 // don't cancel the addrspacecast pair this pass emits.
52 //===----------------------------------------------------------------------===//
53 
54 #include "NVPTX.h"
55 #include "NVPTXUtilities.h"
56 #include "NVPTXTargetMachine.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/Instructions.h"
59 #include "llvm/IR/Module.h"
60 #include "llvm/IR/Type.h"
61 #include "llvm/Pass.h"
62 
63 using namespace llvm;
64 
65 namespace llvm {
67 }
68 
69 namespace {
70 class NVPTXLowerKernelArgs : public FunctionPass {
71  bool runOnFunction(Function &F) override;
72 
73  // handle byval parameters
74  void handleByValParam(Argument *);
75  // handle non-byval pointer parameters
76  void handlePointerParam(Argument *);
77 
78 public:
79  static char ID; // Pass identification, replacement for typeid
80  NVPTXLowerKernelArgs(const NVPTXTargetMachine *TM = nullptr)
81  : FunctionPass(ID), TM(TM) {}
82  const char *getPassName() const override {
83  return "Lower pointer arguments of CUDA kernels";
84  }
85 
86 private:
87  const NVPTXTargetMachine *TM;
88 };
89 } // namespace
90 
92 
93 INITIALIZE_PASS(NVPTXLowerKernelArgs, "nvptx-lower-kernel-args",
94  "Lower kernel arguments (NVPTX)", false, false)
95 
96 // =============================================================================
97 // If the function had a byval struct ptr arg, say foo(%struct.x *byval %d),
98 // then add the following instructions to the first basic block:
99 //
100 // %temp = alloca %struct.x, align 8
101 // %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
102 // %tv = load %struct.x addrspace(101)* %tempd
103 // store %struct.x %tv, %struct.x* %temp, align 8
104 //
105 // The above code allocates some space in the stack and copies the incoming
106 // struct from param space to local space.
107 // Then replace all occurences of %d by %temp.
108 // =============================================================================
109 void NVPTXLowerKernelArgs::handleByValParam(Argument *Arg) {
110  Function *Func = Arg->getParent();
111  Instruction *FirstInst = &(Func->getEntryBlock().front());
112  PointerType *PType = dyn_cast<PointerType>(Arg->getType());
113 
114  assert(PType && "Expecting pointer type in handleByValParam");
115 
116  Type *StructType = PType->getElementType();
117  AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst);
118  // Set the alignment to alignment of the byval parameter. This is because,
119  // later load/stores assume that alignment, and we are going to replace
120  // the use of the byval parameter with this alloca instruction.
121  AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
122  Arg->replaceAllUsesWith(AllocA);
123 
124  Value *ArgInParam = new AddrSpaceCastInst(
125  Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
126  FirstInst);
127  LoadInst *LI = new LoadInst(ArgInParam, Arg->getName(), FirstInst);
128  new StoreInst(LI, AllocA, FirstInst);
129 }
130 
131 void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
132  assert(!Arg->hasByValAttr() &&
133  "byval params should be handled by handleByValParam");
134 
135  // Do nothing if the argument already points to the global address space.
137  return;
138 
139  Instruction *FirstInst = Arg->getParent()->getEntryBlock().begin();
140  Instruction *ArgInGlobal = new AddrSpaceCastInst(
143  Arg->getName(), FirstInst);
144  Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
145  Arg->getName(), FirstInst);
146  // Replace with ArgInGeneric all uses of Args except ArgInGlobal.
147  Arg->replaceAllUsesWith(ArgInGeneric);
148  ArgInGlobal->setOperand(0, Arg);
149 }
150 
151 
152 // =============================================================================
153 // Main function for this pass.
154 // =============================================================================
155 bool NVPTXLowerKernelArgs::runOnFunction(Function &F) {
156  // Skip non-kernels. See the comments at the top of this file.
157  if (!isKernelFunction(F))
158  return false;
159 
160  for (Argument &Arg : F.args()) {
161  if (Arg.getType()->isPointerTy()) {
162  if (Arg.hasByValAttr())
163  handleByValParam(&Arg);
164  else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
165  handlePointerParam(&Arg);
166  }
167  }
168  return true;
169 }
170 
171 FunctionPass *
173  return new NVPTXLowerKernelArgs(TM);
174 }
LLVM Argument representation.
Definition: Argument.h:35
void setAlignment(unsigned Align)
static PointerType * get(Type *ElementType, unsigned AddressSpace)
PointerType::get - This constructs a pointer to an object of the specified type in a numbered address...
Definition: Type.cpp:738
bool isKernelFunction(const llvm::Function &)
const Instruction & front() const
Definition: BasicBlock.h:243
F(f)
LoadInst - an instruction for reading from memory.
Definition: Instructions.h:177
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
Definition: Type.cpp:216
Type * getPointerElementType() const
Definition: Type.h:366
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:188
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:231
void initializeNVPTXLowerKernelArgsPass(PassRegistry &)
This class represents a conversion between pointers from one address space to another.
StructType - Class to represent struct types.
Definition: DerivedTypes.h:191
StoreInst - an instruction for storing to memory.
Definition: Instructions.h:316
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:351
Type * getElementType() const
Definition: DerivedTypes.h:323
PointerType - Class to represent pointers.
Definition: DerivedTypes.h:449
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
const Function * getParent() const
Definition: Argument.h:49
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:294
bool isPointerTy() const
isPointerTy - True if this is an instance of PointerType.
Definition: Type.h:217
FunctionPass * createNVPTXLowerKernelArgsPass(const NVPTXTargetMachine *TM)
bool hasByValAttr() const
Return true if this argument has the byval attribute on it in its containing function.
Definition: Function.cpp:90
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:222
const BasicBlock & getEntryBlock() const
Definition: Function.h:442
void setOperand(unsigned i, Value *Val)
Definition: User.h:122
INITIALIZE_PASS(NVPTXLowerKernelArgs,"nvptx-lower-kernel-args","Lower kernel arguments (NVPTX)", false, false) void NVPTXLowerKernelArgs
LLVM_ATTRIBUTE_UNUSED_RESULT std::enable_if< !is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:285
NVPTXTargetMachine.
unsigned getParamAlignment(unsigned i) const
Extract the alignment for a call or parameter (0=unknown).
Definition: Function.h:261
LLVM Value Representation.
Definition: Value.h:69
PassRegistry - This class manages the registration and intitialization of the pass subsystem as appli...
Definition: PassRegistry.h:41
iterator_range< arg_iterator > args()
Definition: Function.h:489
AllocaInst - an instruction to allocate memory on the stack.
Definition: Instructions.h:76