Line data Source code
1 : //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 : //
3 : // The LLVM Compiler Infrastructure
4 : //
5 : // This file is distributed under the University of Illinois Open Source
6 : // License. See LICENSE.TXT for details.
7 : //
8 : //===----------------------------------------------------------------------===//
9 : //
10 : /// \file
11 : /// This pass performs the following type substitutions on all
12 : /// non-compute shaders:
13 : ///
14 : /// v16i8 => i128
15 : /// - v16i8 is used for constant memory resource descriptors. This type is
16 : /// legal for some compute APIs, and we don't want to declare it as legal
17 : /// in the backend, because we want the legalizer to expand all v16i8
18 : /// operations.
19 : /// v1* => *
20 : /// - Having v1* types complicates the legalizer and we can easily replace
21 : /// them with the element type.
22 : //===----------------------------------------------------------------------===//
23 :
24 : #include "AMDGPU.h"
25 : #include "Utils/AMDGPUBaseInfo.h"
26 : #include "llvm/IR/IRBuilder.h"
27 : #include "llvm/IR/InstVisitor.h"
28 :
29 : using namespace llvm;
30 :
31 : namespace {
32 :
33 2758 : class SITypeRewriter : public FunctionPass,
34 : public InstVisitor<SITypeRewriter> {
35 :
36 : static char ID;
37 : Module *Mod;
38 : Type *v16i8;
39 : Type *v4i32;
40 :
41 : public:
42 2774 : SITypeRewriter() : FunctionPass(ID) { }
43 : bool doInitialization(Module &M) override;
44 : bool runOnFunction(Function &F) override;
45 0 : StringRef getPassName() const override { return "SI Type Rewriter"; }
46 : void visitLoadInst(LoadInst &I);
47 : void visitCallInst(CallInst &I);
48 : void visitBitCast(BitCastInst &I);
49 : };
50 :
51 : } // End anonymous namespace
52 :
53 : char SITypeRewriter::ID = 0;
54 :
55 1383 : bool SITypeRewriter::doInitialization(Module &M) {
56 1383 : Mod = &M;
57 1383 : v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
58 1383 : v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
59 1383 : return false;
60 : }
61 :
62 13243 : bool SITypeRewriter::runOnFunction(Function &F) {
63 13243 : if (!AMDGPU::isShader(F.getCallingConv()))
64 : return false;
65 :
66 518 : visit(F);
67 518 : visit(F);
68 :
69 518 : return false;
70 : }
71 :
72 486 : void SITypeRewriter::visitLoadInst(LoadInst &I) {
73 486 : Value *Ptr = I.getPointerOperand();
74 486 : Type *PtrTy = Ptr->getType();
75 486 : Type *ElemTy = PtrTy->getPointerElementType();
76 1458 : IRBuilder<> Builder(&I);
77 486 : if (ElemTy == v16i8) {
78 196 : Value *BitCast = Builder.CreateBitCast(Ptr,
79 196 : PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
80 98 : LoadInst *Load = Builder.CreateLoad(BitCast);
81 196 : SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
82 196 : I.getAllMetadataOtherThanDebugLoc(MD);
83 277 : for (unsigned i = 0, e = MD.size(); i != e; ++i) {
84 243 : Load->setMetadata(MD[i].first, MD[i].second);
85 : }
86 196 : Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
87 98 : I.replaceAllUsesWith(BitCastLoad);
88 98 : I.eraseFromParent();
89 : }
90 486 : }
91 :
92 4022 : void SITypeRewriter::visitCallInst(CallInst &I) {
93 8561 : IRBuilder<> Builder(&I);
94 :
95 4539 : SmallVector <Value*, 8> Args;
96 4539 : SmallVector <Type*, 8> Types;
97 4022 : bool NeedToReplace = false;
98 4022 : Function *F = I.getCalledFunction();
99 4022 : if (!F)
100 3505 : return;
101 :
102 8525 : std::string Name = F->getName();
103 25770 : for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
104 17762 : Value *Arg = I.getArgOperand(i);
105 17762 : if (Arg->getType() == v16i8) {
106 1034 : Args.push_back(Builder.CreateBitCast(Arg, v4i32));
107 517 : Types.push_back(v4i32);
108 517 : NeedToReplace = true;
109 1034 : Name = Name + ".v4i32";
110 54130 : } else if (Arg->getType()->isVectorTy() &&
111 19640 : Arg->getType()->getVectorNumElements() == 1 &&
112 0 : Arg->getType()->getVectorElementType() ==
113 0 : Type::getInt32Ty(I.getContext())){
114 0 : Type *ElementTy = Arg->getType()->getVectorElementType();
115 0 : std::string TypeName = "i32";
116 0 : InsertElementInst *Def = cast<InsertElementInst>(Arg);
117 0 : Args.push_back(Def->getOperand(1));
118 0 : Types.push_back(ElementTy);
119 0 : std::string VecTypeName = "v1" + TypeName;
120 0 : Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
121 0 : NeedToReplace = true;
122 : } else {
123 17245 : Args.push_back(Arg);
124 17245 : Types.push_back(Arg->getType());
125 : }
126 : }
127 :
128 4004 : if (!NeedToReplace) {
129 : return;
130 : }
131 517 : Function *NewF = Mod->getFunction(Name);
132 517 : if (!NewF) {
133 140 : NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
134 : NewF->setAttributes(F->getAttributes());
135 : }
136 1551 : I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
137 517 : I.eraseFromParent();
138 : }
139 :
140 1570 : void SITypeRewriter::visitBitCast(BitCastInst &I) {
141 3620 : IRBuilder<> Builder(&I);
142 3140 : if (I.getDestTy() != v4i32) {
143 1090 : return;
144 : }
145 :
146 1440 : if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
147 864 : if (Op->getSrcTy() == v4i32) {
148 864 : I.replaceAllUsesWith(Op->getOperand(0));
149 432 : I.eraseFromParent();
150 : }
151 : }
152 : }
153 :
154 1387 : FunctionPass *llvm::createSITypeRewriter() {
155 2774 : return new SITypeRewriter();
156 : }
|