LLVM  3.7.0
SITypeRewriter.cpp
Go to the documentation of this file.
1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass performs the following type substitutions on all
12 /// non-compute shaders:
13 ///
14 /// v16i8 => v4i32
15 /// - v16i8 is used for constant memory resource descriptors. This type is
16 /// legal for some compute APIs, and we don't want to declare it as legal
17 /// in the backend, because we want the legalizer to expand all v16i8
18 /// operations.
19 /// v1* => *
20 /// - Having v1* types complicates the legalizer and we can easily replace
21 /// them with the element type.
22 //===----------------------------------------------------------------------===//
23 
24 #include "AMDGPU.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstVisitor.h"
27 
28 using namespace llvm;
29 
30 namespace {
31 
32 class SITypeRewriter : public FunctionPass,
33  public InstVisitor<SITypeRewriter> {
34 
35  static char ID;
36  Module *Mod;
37  Type *v16i8;
38  Type *v4i32;
39 
40 public:
41  SITypeRewriter() : FunctionPass(ID) { }
42  bool doInitialization(Module &M) override;
43  bool runOnFunction(Function &F) override;
44  const char *getPassName() const override {
45  return "SI Type Rewriter";
46  }
47  void visitLoadInst(LoadInst &I);
48  void visitCallInst(CallInst &I);
49  void visitBitCast(BitCastInst &I);
50 };
51 
52 } // End anonymous namespace
53 
54 char SITypeRewriter::ID = 0;
55 
56 bool SITypeRewriter::doInitialization(Module &M) {
57  Mod = &M;
58  v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
60  return false;
61 }
62 
63 bool SITypeRewriter::runOnFunction(Function &F) {
64  Attribute A = F.getFnAttribute("ShaderType");
65 
66  unsigned ShaderType = ShaderType::COMPUTE;
67  if (A.isStringAttribute()) {
68  StringRef Str = A.getValueAsString();
69  Str.getAsInteger(0, ShaderType);
70  }
71  if (ShaderType == ShaderType::COMPUTE)
72  return false;
73 
74  visit(F);
75  visit(F);
76 
77  return false;
78 }
79 
80 void SITypeRewriter::visitLoadInst(LoadInst &I) {
81  Value *Ptr = I.getPointerOperand();
82  Type *PtrTy = Ptr->getType();
83  Type *ElemTy = PtrTy->getPointerElementType();
84  IRBuilder<> Builder(&I);
85  if (ElemTy == v16i8) {
86  Value *BitCast = Builder.CreateBitCast(Ptr,
88  LoadInst *Load = Builder.CreateLoad(BitCast);
91  for (unsigned i = 0, e = MD.size(); i != e; ++i) {
92  Load->setMetadata(MD[i].first, MD[i].second);
93  }
94  Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
95  I.replaceAllUsesWith(BitCastLoad);
96  I.eraseFromParent();
97  }
98 }
99 
100 void SITypeRewriter::visitCallInst(CallInst &I) {
101  IRBuilder<> Builder(&I);
102 
105  bool NeedToReplace = false;
106  Function *F = I.getCalledFunction();
107  std::string Name = F->getName();
108  for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
109  Value *Arg = I.getArgOperand(i);
110  if (Arg->getType() == v16i8) {
111  Args.push_back(Builder.CreateBitCast(Arg, v4i32));
112  Types.push_back(v4i32);
113  NeedToReplace = true;
114  Name = Name + ".v4i32";
115  } else if (Arg->getType()->isVectorTy() &&
116  Arg->getType()->getVectorNumElements() == 1 &&
117  Arg->getType()->getVectorElementType() ==
119  Type *ElementTy = Arg->getType()->getVectorElementType();
120  std::string TypeName = "i32";
121  InsertElementInst *Def = cast<InsertElementInst>(Arg);
122  Args.push_back(Def->getOperand(1));
123  Types.push_back(ElementTy);
124  std::string VecTypeName = "v1" + TypeName;
125  Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
126  NeedToReplace = true;
127  } else {
128  Args.push_back(Arg);
129  Types.push_back(Arg->getType());
130  }
131  }
132 
133  if (!NeedToReplace) {
134  return;
135  }
136  Function *NewF = Mod->getFunction(Name);
137  if (!NewF) {
138  NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
139  NewF->setAttributes(F->getAttributes());
140  }
141  I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
142  I.eraseFromParent();
143 }
144 
145 void SITypeRewriter::visitBitCast(BitCastInst &I) {
146  IRBuilder<> Builder(&I);
147  if (I.getDestTy() != v4i32) {
148  return;
149  }
150 
151  if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
152  if (Op->getSrcTy() == v4i32) {
153  I.replaceAllUsesWith(Op->getOperand(0));
154  I.eraseFromParent();
155  }
156  }
157 }
158 
160  return new SITypeRewriter();
161 }
std::enable_if< std::numeric_limits< T >::is_signed, bool >::type getAsInteger(unsigned Radix, T &Result) const
Parse the current string as an integer of the specified radix.
Definition: StringRef.h:347
iplist< Instruction >::iterator eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing basic block and deletes it...
Definition: Instruction.cpp:70
Base class for instruction visitors.
Definition: InstVisitor.h:81
A Module instance is used to store all the information related to an LLVM module. ...
Definition: Module.h:114
CallInst - This class represents a function call, abstracting a target machine's calling convention...
static PointerType * get(Type *ElementType, unsigned AddressSpace)
PointerType::get - This constructs a pointer to an object of the specified type in a numbered address...
Definition: Type.cpp:738
Externally visible function.
Definition: GlobalValue.h:40
Type * getReturnType() const
Definition: Function.cpp:233
Attribute getFnAttribute(Attribute::AttrKind Kind) const
Return the attribute for the given attribute kind.
Definition: Function.h:225
F(f)
LoadInst - an instruction for reading from memory.
Definition: Instructions.h:177
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
Definition: Type.cpp:216
Type * getPointerElementType() const
Definition: Type.h:366
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:188
unsigned getNumArgOperands() const
getNumArgOperands - Return the number of call arguments.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:517
Type * getVectorElementType() const
Definition: Type.h:364
FunctionPass * createSITypeRewriter()
This class represents a no-op cast from one type to another.
static FunctionType * get(Type *Result, ArrayRef< Type * > Params, bool isVarArg)
FunctionType::get - This static method is the primary way of constructing a FunctionType.
Definition: Type.cpp:361
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:351
InsertElementInst - This instruction inserts a single (scalar) element into a VectorType value...
The instances of the Type class are immutable: once they are created, they are never changed...
Definition: Type.h:45
void getAllMetadataOtherThanDebugLoc(SmallVectorImpl< std::pair< unsigned, MDNode * >> &MDs) const
getAllMetadataOtherThanDebugLoc - This does the same thing as getAllMetadata, except that it filters ...
Definition: Instruction.h:190
bool isVectorTy() const
isVectorTy - True if this is an instance of VectorType.
Definition: Type.h:226
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:294
Value * getOperand(unsigned i) const
Definition: User.h:118
Value * getPointerOperand()
Definition: Instructions.h:284
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:519
void setMetadata(unsigned KindID, MDNode *Node)
setMetadata - Set the metadata of the specified kind to the specified node.
Definition: Metadata.cpp:1083
unsigned getVectorNumElements() const
Definition: Type.cpp:212
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:861
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:222
Type * getDestTy() const
Return the destination type, as a convenience.
Definition: InstrTypes.h:656
Function * getCalledFunction() const
getCalledFunction - Return the function called, or null if this is an indirect function invocation...
AttributeSet getAttributes() const
Return the attribute list for this Function.
Definition: Function.h:181
Value * getArgOperand(unsigned i) const
getArgOperand/setArgOperand - Return/set the i-th call argument.
static IntegerType * getInt32Ty(LLVMContext &C)
Definition: Type.cpp:239
#define I(x, y, z)
Definition: MD5.cpp:54
bool isStringAttribute() const
Return true if the attribute is a string (target-dependent) attribute.
Definition: Attributes.cpp:115
void setAttributes(AttributeSet attrs)
Set the attribute list for this Function.
Definition: Function.h:184
StringRef getValueAsString() const
Return the attribute's value as a string.
Definition: Attributes.cpp:140
LLVM Value Representation.
Definition: Value.h:69
static VectorType * get(Type *ElementType, unsigned NumElements)
VectorType::get - This static method is the primary way to construct an VectorType.
Definition: Type.cpp:713
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:40
static Function * Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N="", Module *M=nullptr)
Definition: Function.h:121
static IntegerType * getInt8Ty(LLVMContext &C)
Definition: Type.cpp:237
LLVMContext & getContext() const
Get the global data context.
Definition: Module.h:265