Line data Source code
1 : //===- AMDGPURewriteOutArgumentsPass.cpp - Create struct returns ----------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : /// \file This pass attempts to replace out argument usage with a return of a
11 : /// struct.
12 : ///
13 : /// We can support returning a lot of values directly in registers, but
14 : /// idiomatic C code frequently uses a pointer argument to return a second value
15 : /// rather than returning a struct by value. GPU stack access is also quite
16 : /// painful, so we want to avoid that if possible. Passing a stack object
17 : /// pointer to a function also requires an additional address expansion code
18 : /// sequence to convert the pointer to be relative to the kernel's scratch wave
19 : /// offset register since the callee doesn't know what stack frame the incoming
20 : /// pointer is relative to.
21 : ///
22 : /// The goal is to try rewriting code that looks like this:
23 : ///
24 : /// int foo(int a, int b, int* out) {
25 : /// *out = bar();
26 : /// return a + b;
27 : /// }
28 : ///
29 : /// into something like this:
30 : ///
31 : /// std::pair<int, int> foo(int a, int b) {
32 : /// return std::make_pair(a + b, bar());
33 : /// }
34 : ///
35 : /// Typically the incoming pointer is a simple alloca for a temporary variable
36 : /// used to call the API, which, if replaced with a struct return, will be easily
37 : /// SROA'd out when the stub function we create is inlined.
38 : ///
39 : /// This pass introduces the struct return, but leaves the unused pointer
40 : /// arguments and introduces a new stub function calling the struct returning
41 : /// body. DeadArgumentElimination should be run after this to clean these up.
42 : //
43 : //===----------------------------------------------------------------------===//
44 :
45 : #include "AMDGPU.h"
46 : #include "Utils/AMDGPUBaseInfo.h"
47 : #include "llvm/Analysis/MemoryDependenceAnalysis.h"
48 : #include "llvm/ADT/DenseMap.h"
49 : #include "llvm/ADT/STLExtras.h"
50 : #include "llvm/ADT/SmallSet.h"
51 : #include "llvm/ADT/SmallVector.h"
52 : #include "llvm/ADT/Statistic.h"
53 : #include "llvm/Analysis/MemoryLocation.h"
54 : #include "llvm/IR/Argument.h"
55 : #include "llvm/IR/Attributes.h"
56 : #include "llvm/IR/BasicBlock.h"
57 : #include "llvm/IR/Constants.h"
58 : #include "llvm/IR/DataLayout.h"
59 : #include "llvm/IR/DerivedTypes.h"
60 : #include "llvm/IR/Function.h"
61 : #include "llvm/IR/IRBuilder.h"
62 : #include "llvm/IR/Instructions.h"
63 : #include "llvm/IR/Module.h"
64 : #include "llvm/IR/Type.h"
65 : #include "llvm/IR/Use.h"
66 : #include "llvm/IR/User.h"
67 : #include "llvm/IR/Value.h"
68 : #include "llvm/Pass.h"
69 : #include "llvm/Support/Casting.h"
70 : #include "llvm/Support/CommandLine.h"
71 : #include "llvm/Support/Debug.h"
72 : #include "llvm/Support/raw_ostream.h"
73 : #include <cassert>
74 : #include <utility>
75 :
76 : #define DEBUG_TYPE "amdgpu-rewrite-out-arguments"
77 :
78 : using namespace llvm;
79 :
// When set, out-argument pointers in any address space may be rewritten, not
// only pointers in the alloca (private) address space. Off by default since
// only private pointers are known to point at stack temporaries.
static cl::opt<bool> AnyAddressSpace(
  "amdgpu-any-address-space-out-arguments",
  cl::desc("Replace pointer out arguments with "
           "struct returns for non-private address space"),
  cl::Hidden,
  cl::init(false));

// Upper bound (in 32-bit registers) on the total size of the synthesized
// struct return; arguments that would push past this limit are skipped.
static cl::opt<unsigned> MaxNumRetRegs(
  "amdgpu-max-return-arg-num-regs",
  cl::desc("Approximately limit number of return registers for replacing out arguments"),
  cl::Hidden,
  cl::init(16));

STATISTIC(NumOutArgumentsReplaced,
          "Number out arguments moved to struct return values");
STATISTIC(NumOutArgumentFunctionsReplaced,
          "Number of functions with out arguments moved to struct return values");
97 :
98 : namespace {
99 :
/// Function pass that replaces pointer out-arguments with an aggregate return
/// value, leaving behind an always-inline stub with the original signature.
class AMDGPURewriteOutArguments : public FunctionPass {
private:
  // Cached in doInitialization; valid for the lifetime of the module run.
  const DataLayout *DL = nullptr;
  // Refreshed per function in runOnFunction from the wrapper pass.
  MemoryDependenceResults *MDA = nullptr;

  // Returns true if every use of \p Arg is a simple store through the pointer
  // (possibly behind a single-use bitcast), i.e. it acts purely as an output.
  bool checkArgumentUses(Value &Arg) const;
  // Returns true if \p Arg is a pointer argument worth converting to a
  // struct-return member (address space, size, and use checks).
  bool isOutArgumentCandidate(Argument &Arg) const;

#ifndef NDEBUG
  // Debug-build-only sanity check used in an assert: recognizes the pattern
  // where a 3-element vector is stored through a 4-element vector pointer.
  bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const;
#endif

public:
  static char ID;

  AMDGPURewriteOutArguments() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MemoryDependenceWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;
};
125 :
126 : } // end anonymous namespace
127 :
128 85105 : INITIALIZE_PASS_BEGIN(AMDGPURewriteOutArguments, DEBUG_TYPE,
129 : "AMDGPU Rewrite Out Arguments", false, false)
130 85105 : INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
131 199024 : INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE,
132 : "AMDGPU Rewrite Out Arguments", false, false)
133 :
134 : char AMDGPURewriteOutArguments::ID = 0;
135 :
136 56 : bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
137 : const int MaxUses = 10;
138 : int UseCount = 0;
139 :
140 115 : for (Use &U : Arg.uses()) {
141 69 : StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
142 69 : if (UseCount > MaxUses)
143 : return false;
144 :
145 69 : if (!SI) {
146 : auto *BCI = dyn_cast<BitCastInst>(U.getUser());
147 19 : if (!BCI || !BCI->hasOneUse())
148 : return false;
149 :
150 : // We don't handle multiple stores currently, so stores to aggregate
151 : // pointers aren't worth the trouble since they are canonically split up.
152 19 : Type *DestEltTy = BCI->getType()->getPointerElementType();
153 : if (DestEltTy->isAggregateType())
154 : return false;
155 :
156 : // We could handle these if we had a convenient way to bitcast between
157 : // them.
158 17 : Type *SrcEltTy = Arg.getType()->getPointerElementType();
159 17 : if (SrcEltTy->isArrayTy())
160 : return false;
161 :
162 : // Special case handle structs with single members. It is useful to handle
163 : // some casts between structs and non-structs, but we can't bitcast
164 : // directly between them. directly bitcast between them. Blender uses
165 : // some casts that look like { <3 x float> }* to <4 x float>*
166 17 : if ((SrcEltTy->isStructTy() && (SrcEltTy->getNumContainedTypes() != 1)))
167 : return false;
168 :
169 : // Clang emits OpenCL 3-vector type accesses with a bitcast to the
170 : // equivalent 4-element vector and accesses that, and we're looking for
171 : // this pointer cast.
172 15 : if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
173 : return false;
174 :
175 12 : return checkArgumentUses(*BCI);
176 : }
177 :
178 48 : if (!SI->isSimple() ||
179 48 : U.getOperandNo() != StoreInst::getPointerOperandIndex())
180 1 : return false;
181 :
182 47 : ++UseCount;
183 : }
184 :
185 : // Skip unused arguments.
186 46 : return UseCount > 0;
187 : }
188 :
189 81 : bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
190 81 : const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
191 81 : PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
192 :
193 : // TODO: It might be useful for any out arguments, not just privates.
194 61 : if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
195 60 : !AnyAddressSpace) ||
196 117 : Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
197 57 : DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
198 25 : return false;
199 : }
200 :
201 56 : return checkArgumentUses(Arg);
202 : }
203 :
204 2 : bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
205 2 : DL = &M.getDataLayout();
206 2 : return false;
207 : }
208 :
209 : #ifndef NDEBUG
210 : bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const {
211 : VectorType *VT0 = dyn_cast<VectorType>(Ty0);
212 : VectorType *VT1 = dyn_cast<VectorType>(Ty1);
213 : if (!VT0 || !VT1)
214 : return false;
215 :
216 : if (VT0->getNumElements() != 3 ||
217 : VT1->getNumElements() != 4)
218 : return false;
219 :
220 : return DL->getTypeSizeInBits(VT0->getElementType()) ==
221 : DL->getTypeSizeInBits(VT1->getElementType());
222 : }
223 : #endif
224 :
/// Perform the rewrite on one function: find out-argument candidates, prove
/// via MemoryDependenceAnalysis that each return is dominated by a store to
/// the argument, then split the function into a ".body" clone returning a
/// struct plus an always-inline stub with the original signature.
bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  // TODO: Could probably handle variadic functions.
  if (F.isVarArg() || F.hasStructRetAttr() ||
      AMDGPU::isEntryFunctionCC(F.getCallingConv()))
    return false;

  MDA = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();

  // Register budget already consumed by the existing (non-void) return value,
  // approximated as 4 bytes per register.
  unsigned ReturnNumRegs = 0;
  SmallSet<int, 4> OutArgIndexes;
  SmallVector<Type *, 4> ReturnTypes;
  Type *RetTy = F.getReturnType();
  if (!RetTy->isVoidTy()) {
    ReturnNumRegs = DL->getTypeStoreSize(RetTy) / 4;

    if (ReturnNumRegs >= MaxNumRetRegs)
      return false;

    // The original return value becomes member 0 of the new struct.
    ReturnTypes.push_back(RetTy);
  }

  SmallVector<Argument *, 4> OutArgs;
  for (Argument &Arg : F.args()) {
    if (isOutArgumentCandidate(Arg)) {
      LLVM_DEBUG(dbgs() << "Found possible out argument " << Arg
                 << " in function " << F.getName() << '\n');
      OutArgs.push_back(&Arg);
    }
  }

  if (OutArgs.empty())
    return false;

  using ReplacementVec = SmallVector<std::pair<Argument *, Value *>, 4>;

  // For each return instruction, the (argument, stored value) pairs that will
  // be folded into the struct return at that exit.
  DenseMap<ReturnInst *, ReplacementVec> Replacements;

  // Collect exits; only blocks ending in a ReturnInst matter.
  SmallVector<ReturnInst *, 4> Returns;
  for (BasicBlock &BB : F) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(&BB.back()))
      Returns.push_back(RI);
  }

  if (Returns.empty())
    return false;

  bool Changing;

  do {
    Changing = false;

    // Keep retrying if we are able to successfully eliminate an argument. This
    // helps with cases with multiple arguments which may alias, such as in a
    // sincos implementation. If we have 2 stores to arguments, on the first
    // attempt the MDA query will succeed for the second store but not the
    // first. On the second iteration we've removed that out clobbering argument
    // (by effectively moving it into another function) and will find the second
    // argument is OK to move.
    for (Argument *OutArg : OutArgs) {
      bool ThisReplaceable = true;
      SmallVector<std::pair<ReturnInst *, StoreInst *>, 4> ReplaceableStores;

      Type *ArgTy = OutArg->getType()->getPointerElementType();

      // Skip this argument if converting it will push us over the register
      // count to return limit.

      // TODO: This is an approximation. When legalized this could be more. We
      // can ask TLI for exactly how many.
      unsigned ArgNumRegs = DL->getTypeStoreSize(ArgTy) / 4;
      if (ArgNumRegs + ReturnNumRegs > MaxNumRetRegs)
        continue;

      // An argument is convertible only if all exit blocks are able to replace
      // it.
      for (ReturnInst *RI : Returns) {
        BasicBlock *BB = RI->getParent();

        // Ask MDA for the memory operation the argument's value depends on at
        // this return; a definite store means the stored value can simply be
        // returned instead.
        MemDepResult Q = MDA->getPointerDependencyFrom(MemoryLocation(OutArg),
                                                       true, BB->end(), BB, RI);
        StoreInst *SI = nullptr;
        if (Q.isDef())
          SI = dyn_cast<StoreInst>(Q.getInst());

        if (SI) {
          LLVM_DEBUG(dbgs() << "Found out argument store: " << *SI << '\n');
          ReplaceableStores.emplace_back(RI, SI);
        } else {
          ThisReplaceable = false;
          break;
        }
      }

      if (!ThisReplaceable)
        continue; // Try the next argument candidate.

      for (std::pair<ReturnInst *, StoreInst *> Store : ReplaceableStores) {
        Value *ReplVal = Store.second->getValueOperand();

        auto &ValVec = Replacements[Store.first];
        if (llvm::find_if(ValVec,
              [OutArg](const std::pair<Argument *, Value *> &Entry) {
                 return Entry.first == OutArg;}) != ValVec.end()) {
          LLVM_DEBUG(dbgs()
                     << "Saw multiple out arg stores" << *OutArg << '\n');
          // It is possible to see stores to the same argument multiple times,
          // but we expect these would have been optimized out already.
          ThisReplaceable = false;
          break;
        }

        // The store is now redundant; its value travels via the return.
        ValVec.emplace_back(OutArg, ReplVal);
        Store.second->eraseFromParent();
      }

      if (ThisReplaceable) {
        ReturnTypes.push_back(ArgTy);
        OutArgIndexes.insert(OutArg->getArgNo());
        ++NumOutArgumentsReplaced;
        Changing = true;
      }
    }
  } while (Changing);

  if (Replacements.empty())
    return false;

  // Build the new aggregate return type: [original ret,] then one member per
  // replaced out argument, in discovery order.
  LLVMContext &Ctx = F.getParent()->getContext();
  StructType *NewRetTy = StructType::create(Ctx, ReturnTypes, F.getName());

  FunctionType *NewFuncTy = FunctionType::get(NewRetTy,
                                              F.getFunctionType()->params(),
                                              F.isVarArg());

  LLVM_DEBUG(dbgs() << "Computed new return type: " << *NewRetTy << '\n');

  Function *NewFunc = Function::Create(NewFuncTy, Function::PrivateLinkage,
                                       F.getName() + ".body");
  F.getParent()->getFunctionList().insert(F.getIterator(), NewFunc);
  NewFunc->copyAttributesFrom(&F);
  NewFunc->setComdat(F.getComdat());

  // We want to preserve the function and param attributes, but need to strip
  // off any return attributes, e.g. zeroext doesn't make sense with a struct.
  NewFunc->stealArgumentListFrom(F);

  AttrBuilder RetAttrs;
  RetAttrs.addAttribute(Attribute::SExt);
  RetAttrs.addAttribute(Attribute::ZExt);
  RetAttrs.addAttribute(Attribute::NoAlias);
  NewFunc->removeAttributes(AttributeList::ReturnIndex, RetAttrs);
  // TODO: How to preserve metadata?

  // Move the body of the function into the new rewritten function, and replace
  // this function with a stub.
  NewFunc->getBasicBlockList().splice(NewFunc->begin(), F.getBasicBlockList());

  // Rewrite each return in the moved body to build and return the struct.
  for (std::pair<ReturnInst *, ReplacementVec> &Replacement : Replacements) {
    ReturnInst *RI = Replacement.first;
    IRBuilder<> B(RI);
    B.SetCurrentDebugLocation(RI->getDebugLoc());

    int RetIdx = 0;
    Value *NewRetVal = UndefValue::get(NewRetTy);

    Value *RetVal = RI->getReturnValue();
    if (RetVal)
      NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);

    for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
      Argument *Arg = ReturnPoint.first;
      Value *Val = ReturnPoint.second;
      Type *EltTy = Arg->getType()->getPointerElementType();
      // The stored value may have reached the argument through a bitcast, so
      // its type can differ from the pointee type; reconcile it here.
      if (Val->getType() != EltTy) {
        Type *EffectiveEltTy = EltTy;
        if (StructType *CT = dyn_cast<StructType>(EltTy)) {
          // Single-member struct special case (see checkArgumentUses).
          assert(CT->getNumContainedTypes() == 1);
          EffectiveEltTy = CT->getContainedType(0);
        }

        if (DL->getTypeSizeInBits(EffectiveEltTy) !=
            DL->getTypeSizeInBits(Val->getType())) {
          // Only the vec4 -> vec3 shrink is expected here; take elements 0-2.
          assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
          Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
                                      { 0, 1, 2 });
        }

        Val = B.CreateBitCast(Val, EffectiveEltTy);

        // Re-create single element composite.
        if (EltTy != EffectiveEltTy)
          Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
      }

      NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
    }

    if (RetVal)
      RI->setOperand(0, NewRetVal);
    else {
      // Void returns must be replaced wholesale with a struct-returning ret.
      B.CreateRet(NewRetVal);
      RI->eraseFromParent();
    }
  }

  // Build the argument list for the stub's call into the new body.
  SmallVector<Value *, 16> StubCallArgs;
  for (Argument &Arg : F.args()) {
    if (OutArgIndexes.count(Arg.getArgNo())) {
      // It's easier to preserve the type of the argument list. We rely on
      // DeadArgumentElimination to take care of these.
      StubCallArgs.push_back(UndefValue::get(Arg.getType()));
    } else {
      StubCallArgs.push_back(&Arg);
    }
  }

  // The original function becomes a stub: call the body, then store each
  // struct member back through its out-argument pointer.
  BasicBlock *StubBB = BasicBlock::Create(Ctx, "", &F);
  IRBuilder<> B(StubBB);
  CallInst *StubCall = B.CreateCall(NewFunc, StubCallArgs);

  // Member 0 holds the original return value when the function is non-void.
  int RetIdx = RetTy->isVoidTy() ? 0 : 1;
  for (Argument &Arg : F.args()) {
    if (!OutArgIndexes.count(Arg.getArgNo()))
      continue;

    PointerType *ArgType = cast<PointerType>(Arg.getType());

    auto *EltTy = ArgType->getElementType();
    unsigned Align = Arg.getParamAlignment();
    if (Align == 0)
      Align = DL->getABITypeAlignment(EltTy);

    Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
    Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());

    // We can peek through bitcasts, so the type may not match.
    Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);

    B.CreateAlignedStore(Val, PtrVal, Align);
  }

  if (!RetTy->isVoidTy()) {
    B.CreateRet(B.CreateExtractValue(StubCall, 0));
  } else {
    B.CreateRetVoid();
  }

  // The function is now a stub we want to inline.
  F.addFnAttr(Attribute::AlwaysInline);

  ++NumOutArgumentFunctionsReplaced;
  return true;
}
481 :
482 0 : FunctionPass *llvm::createAMDGPURewriteOutArgumentsPass() {
483 0 : return new AMDGPURewriteOutArguments();
484 : }
|