LLVM  4.0.0
NVPTXLowerAggrCopies.cpp
Go to the documentation of this file.
1 //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // \file
11 // Lower aggregate copies and the memset, memcpy and memmove intrinsics into
12 // loops when the size is large or is not a compile-time constant.
13 //
14 //===----------------------------------------------------------------------===//
15 
#include "NVPTXLowerAggrCopies.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 
30 #define DEBUG_TYPE "nvptx"
31 
32 using namespace llvm;
33 
34 namespace {
35 
36 // actual analysis class, which is a functionpass
37 struct NVPTXLowerAggrCopies : public FunctionPass {
38  static char ID;
39 
40  NVPTXLowerAggrCopies() : FunctionPass(ID) {}
41 
42  void getAnalysisUsage(AnalysisUsage &AU) const override {
44  }
45 
46  bool runOnFunction(Function &F) override;
47 
48  static const unsigned MaxAggrCopySize = 128;
49 
50  StringRef getPassName() const override {
51  return "Lower aggregate copies/intrinsics into loops";
52  }
53 };
54 
56 
57 // Lower memcpy to loop.
58 void convertMemCpyToLoop(Instruction *ConvertedInst, Value *SrcAddr,
59  Value *DstAddr, Value *CopyLen, bool SrcIsVolatile,
60  bool DstIsVolatile, LLVMContext &Context,
61  Function &F) {
62  Type *TypeOfCopyLen = CopyLen->getType();
63 
64  BasicBlock *OrigBB = ConvertedInst->getParent();
65  BasicBlock *NewBB =
66  ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split");
67  BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB);
68 
69  OrigBB->getTerminator()->setSuccessor(0, LoopBB);
70  IRBuilder<> Builder(OrigBB->getTerminator());
71 
72  // SrcAddr and DstAddr are expected to be pointer types,
73  // so no check is made here.
74  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
75  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
76 
77  // Cast pointers to (char *)
78  SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS));
79  DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS));
80 
81  IRBuilder<> LoopBuilder(LoopBB);
82  PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
83  LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
84 
85  // load from SrcAddr+LoopIndex
86  // TODO: we can leverage the align parameter of llvm.memcpy for more efficient
87  // word-sized loads and stores.
88  Value *Element =
89  LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP(
90  LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex),
91  SrcIsVolatile);
92  // store at DstAddr+LoopIndex
93  LoopBuilder.CreateStore(Element,
94  LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(),
95  DstAddr, LoopIndex),
96  DstIsVolatile);
97 
98  // The value for LoopIndex coming from backedge is (LoopIndex + 1)
99  Value *NewIndex =
100  LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
101  LoopIndex->addIncoming(NewIndex, LoopBB);
102 
103  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
104  NewBB);
105 }
106 
107 // Lower memmove to IR. memmove is required to correctly copy overlapping memory
108 // regions; therefore, it has to check the relative positions of the source and
109 // destination pointers and choose the copy direction accordingly.
110 //
111 // The code below is an IR rendition of this C function:
112 //
113 // void* memmove(void* dst, const void* src, size_t n) {
114 // unsigned char* d = dst;
115 // const unsigned char* s = src;
116 // if (s < d) {
117 // // copy backwards
118 // while (n--) {
119 // d[n] = s[n];
120 // }
121 // } else {
122 // // copy forward
123 // for (size_t i = 0; i < n; ++i) {
124 // d[i] = s[i];
125 // }
126 // }
127 // return dst;
128 // }
129 void convertMemMoveToLoop(Instruction *ConvertedInst, Value *SrcAddr,
130  Value *DstAddr, Value *CopyLen, bool SrcIsVolatile,
131  bool DstIsVolatile, LLVMContext &Context,
132  Function &F) {
133  Type *TypeOfCopyLen = CopyLen->getType();
134  BasicBlock *OrigBB = ConvertedInst->getParent();
135 
136  // Create the a comparison of src and dst, based on which we jump to either
137  // the forward-copy part of the function (if src >= dst) or the backwards-copy
138  // part (if src < dst).
139  // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
140  // structure. Its block terminators (unconditional branches) are replaced by
141  // the appropriate conditional branches when the loop is built.
142  ICmpInst *PtrCompare = new ICmpInst(ConvertedInst, ICmpInst::ICMP_ULT,
143  SrcAddr, DstAddr, "compare_src_dst");
144  TerminatorInst *ThenTerm, *ElseTerm;
145  SplitBlockAndInsertIfThenElse(PtrCompare, ConvertedInst, &ThenTerm,
146  &ElseTerm);
147 
148  // Each part of the function consists of two blocks:
149  // copy_backwards: used to skip the loop when n == 0
150  // copy_backwards_loop: the actual backwards loop BB
151  // copy_forward: used to skip the loop when n == 0
152  // copy_forward_loop: the actual forward loop BB
153  BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
154  CopyBackwardsBB->setName("copy_backwards");
155  BasicBlock *CopyForwardBB = ElseTerm->getParent();
156  CopyForwardBB->setName("copy_forward");
157  BasicBlock *ExitBB = ConvertedInst->getParent();
158  ExitBB->setName("memmove_done");
159 
160  // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
161  // between both backwards and forward copy clauses.
162  ICmpInst *CompareN =
163  new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen,
164  ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
165 
166  // Copying backwards.
167  BasicBlock *LoopBB =
168  BasicBlock::Create(Context, "copy_backwards_loop", &F, CopyForwardBB);
169  IRBuilder<> LoopBuilder(LoopBB);
170  PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
171  Value *IndexPtr = LoopBuilder.CreateSub(
172  LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
173  Value *Element = LoopBuilder.CreateLoad(
174  LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element");
175  LoopBuilder.CreateStore(Element,
176  LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr));
177  LoopBuilder.CreateCondBr(
178  LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
179  ExitBB, LoopBB);
180  LoopPhi->addIncoming(IndexPtr, LoopBB);
181  LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
182  BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
183  ThenTerm->eraseFromParent();
184 
185  // Copying forward.
186  BasicBlock *FwdLoopBB =
187  BasicBlock::Create(Context, "copy_forward_loop", &F, ExitBB);
188  IRBuilder<> FwdLoopBuilder(FwdLoopBB);
189  PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
190  Value *FwdElement = FwdLoopBuilder.CreateLoad(
191  FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element");
192  FwdLoopBuilder.CreateStore(
193  FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi));
194  Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
195  FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
196  FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
197  ExitBB, FwdLoopBB);
198  FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
199  FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
200 
201  BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
202  ElseTerm->eraseFromParent();
203 }
204 
205 // Lower memset to loop.
206 void convertMemSetToLoop(Instruction *ConvertedInst, Value *DstAddr,
207  Value *CopyLen, Value *SetValue, LLVMContext &Context,
208  Function &F) {
209  BasicBlock *OrigBB = ConvertedInst->getParent();
210  BasicBlock *NewBB =
211  ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split");
212  BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB);
213 
214  OrigBB->getTerminator()->setSuccessor(0, LoopBB);
215  IRBuilder<> Builder(OrigBB->getTerminator());
216 
217  // Cast pointer to the type of value getting stored
218  unsigned dstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
219  DstAddr = Builder.CreateBitCast(DstAddr,
220  PointerType::get(SetValue->getType(), dstAS));
221 
222  IRBuilder<> LoopBuilder(LoopBB);
223  PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0);
224  LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB);
225 
226  LoopBuilder.CreateStore(
227  SetValue,
228  LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex),
229  false);
230 
231  Value *NewIndex =
232  LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1));
233  LoopIndex->addIncoming(NewIndex, LoopBB);
234 
235  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
236  NewBB);
237 }
238 
239 bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
240  SmallVector<LoadInst *, 4> AggrLoads;
242 
243  const DataLayout &DL = F.getParent()->getDataLayout();
244  LLVMContext &Context = F.getParent()->getContext();
245 
246  // Collect all aggregate loads and mem* calls.
247  for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
248  for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;
249  ++II) {
250  if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
251  if (!LI->hasOneUse())
252  continue;
253 
254  if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
255  continue;
256 
257  if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
258  if (SI->getOperand(0) != LI)
259  continue;
260  AggrLoads.push_back(LI);
261  }
262  } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(II)) {
263  // Convert intrinsic calls with variable size or with constant size
264  // larger than the MaxAggrCopySize threshold.
265  if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
266  if (LenCI->getZExtValue() >= MaxAggrCopySize) {
267  MemCalls.push_back(IntrCall);
268  }
269  } else {
270  MemCalls.push_back(IntrCall);
271  }
272  }
273  }
274  }
275 
276  if (AggrLoads.size() == 0 && MemCalls.size() == 0) {
277  return false;
278  }
279 
280  //
281  // Do the transformation of an aggr load/copy/set to a loop
282  //
283  for (LoadInst *LI : AggrLoads) {
284  StoreInst *SI = dyn_cast<StoreInst>(*LI->user_begin());
285  Value *SrcAddr = LI->getOperand(0);
286  Value *DstAddr = SI->getOperand(1);
287  unsigned NumLoads = DL.getTypeStoreSize(LI->getType());
288  Value *CopyLen = ConstantInt::get(Type::getInt32Ty(Context), NumLoads);
289 
290  convertMemCpyToLoop(/* ConvertedInst */ SI,
291  /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
292  /* CopyLen */ CopyLen,
293  /* SrcIsVolatile */ LI->isVolatile(),
294  /* DstIsVolatile */ SI->isVolatile(),
295  /* Context */ Context,
296  /* Function F */ F);
297 
298  SI->eraseFromParent();
299  LI->eraseFromParent();
300  }
301 
302  // Transform mem* intrinsic calls.
303  for (MemIntrinsic *MemCall : MemCalls) {
304  if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
305  convertMemCpyToLoop(/* ConvertedInst */ Memcpy,
306  /* SrcAddr */ Memcpy->getRawSource(),
307  /* DstAddr */ Memcpy->getRawDest(),
308  /* CopyLen */ Memcpy->getLength(),
309  /* SrcIsVolatile */ Memcpy->isVolatile(),
310  /* DstIsVolatile */ Memcpy->isVolatile(),
311  /* Context */ Context,
312  /* Function F */ F);
313  } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
314  convertMemMoveToLoop(/* ConvertedInst */ Memmove,
315  /* SrcAddr */ Memmove->getRawSource(),
316  /* DstAddr */ Memmove->getRawDest(),
317  /* CopyLen */ Memmove->getLength(),
318  /* SrcIsVolatile */ Memmove->isVolatile(),
319  /* DstIsVolatile */ Memmove->isVolatile(),
320  /* Context */ Context,
321  /* Function F */ F);
322 
323  } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
324  convertMemSetToLoop(/* ConvertedInst */ Memset,
325  /* DstAddr */ Memset->getRawDest(),
326  /* CopyLen */ Memset->getLength(),
327  /* SetValue */ Memset->getValue(),
328  /* Context */ Context,
329  /* Function F */ F);
330  }
331  MemCall->eraseFromParent();
332  }
333 
334  return true;
335 }
336 
337 } // namespace
338 
339 namespace llvm {
341 }
342 
343 INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
344  "Lower aggregate copies, and llvm.mem* intrinsics into loops",
345  false, false)
346 
348  return new NVPTXLowerAggrCopies();
349 }
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:76
A parsed version of the target data layout string in and methods for querying it. ...
Definition: DataLayout.h:102
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition: IRBuilder.h:699
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
void SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, TerminatorInst **ThenTerm, TerminatorInst **ElseTerm, MDNode *BranchWeights=nullptr)
SplitBlockAndInsertIfThenElse is similar to SplitBlockAndInsertIfThen, but also creates the ElseBlock...
LLVMContext & Context
bool isVolatile() const
Return true if this is a store to a volatile memory location.
Definition: Instructions.h:336
Value * CreateICmpULT(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1478
static void SetValue(Value *V, GenericValue Val, ExecutionContext &SF)
Definition: Execution.cpp:42
iterator end()
Definition: Function.h:537
static PointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space...
Definition: Type.cpp:655
unsigned less than
Definition: InstrTypes.h:905
This class wraps the llvm.memset intrinsic.
An instruction for reading from memory.
Definition: Instructions.h:164
This class wraps the llvm.memmove intrinsic.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:588
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:813
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:257
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Definition: IRBuilder.h:1094
#define F(x, y, z)
Definition: MD5.cpp:51
void setSuccessor(unsigned idx, BasicBlock *B)
Update the specified successor to point at the provided block.
Definition: InstrTypes.h:84
Value * CreateSub(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Definition: IRBuilder.h:835
An instruction for storing to memory.
Definition: Instructions.h:300
iterator begin()
Definition: Function.h:535
Value * CreateInBoundsGEP(Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="")
Definition: IRBuilder.h:1158
LoadInst * CreateLoad(Value *Ptr, const char *Name)
Definition: IRBuilder.h:1082
Subclasses of this class are all able to terminate a basic block.
Definition: InstrTypes.h:52
void initializeNVPTXLowerAggrCopiesPass(PassRegistry &)
LLVM Basic Block Representation.
Definition: BasicBlock.h:51
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:48
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Represent the analysis usage information of a pass.
This instruction compares its operands according to the predicate given to the constructor.
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:298
Value * getOperand(unsigned i) const
Definition: User.h:145
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:93
Value * CreateICmpEQ(Value *LHS, Value *RHS, const Twine &Name="")
Definition: IRBuilder.h:1466
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition: IRBuilder.h:1574
This is the common base class for memset/memcpy/memmove.
Iterator for intrusive lists based on ilist_node.
This is the shared class of boolean and integer constants.
Definition: Constants.h:88
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:843
Module.h This file contains the declarations for the Module class.
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:230
This class wraps the llvm.memcpy intrinsic.
static Constant * get(Type *Ty, uint64_t V, bool isSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
Definition: Constants.cpp:558
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
INITIALIZE_PASS(NVPTXLowerAggrCopies,"nvptx-lower-aggr-copies","Lower aggregate copies, and llvm.mem* intrinsics into loops", false, false) FunctionPass *llvm
IntegerType * getInt8Ty()
Fetch the type representing an 8-bit integer.
Definition: IRBuilder.h:337
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
Definition: Module.cpp:384
FunctionPass * createLowerAggrCopies()
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:169
TerminatorInst * getTerminator()
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.cpp:124
LLVM_ATTRIBUTE_ALWAYS_INLINE size_type size() const
Definition: SmallVector.h:135
LLVM_NODISCARD std::enable_if<!is_simple_type< Y >::value, typename cast_retty< X, const Y >::ret_type >::type dyn_cast(const Y &Val)
Definition: Casting.h:287
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="")
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:374
uint64_t getTypeStoreSize(Type *Ty) const
Returns the maximum number of bytes that may be overwritten by storing the specified type...
Definition: DataLayout.h:391
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:537
LLVM Value Representation.
Definition: Value.h:71
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:47
PassRegistry - This class manages the registration and intitialization of the pass subsystem as appli...
Definition: PassRegistry.h:40
const BasicBlock * getParent() const
Definition: Instruction.h:62
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:222