LLVM 18.0.0git
DXILOpBuilder.cpp
Go to the documentation of this file.
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
/// \file This file contains a helper class for building DXIL op functions.
10//===----------------------------------------------------------------------===//
11
#include "DXILOpBuilder.h"
#include "DXILConstants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/DXILOperationCommon.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace llvm::dxil;
21
22constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23
namespace {

/// Flags naming each type a DXIL operation may be overloaded on.
/// Every enumerator is a distinct power of two, so a set of permitted
/// overloads can be stored as a bitmask and tested with a bitwise AND
/// (see the OverloadTys check in getOrCreateDXILOpFunction).
enum OverloadKind : uint16_t {
  VOID = 1 << 0,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};

} // namespace
41
42static const char *getOverloadTypeName(OverloadKind Kind) {
43 switch (Kind) {
44 case OverloadKind::HALF:
45 return "f16";
46 case OverloadKind::FLOAT:
47 return "f32";
48 case OverloadKind::DOUBLE:
49 return "f64";
50 case OverloadKind::I1:
51 return "i1";
52 case OverloadKind::I8:
53 return "i8";
54 case OverloadKind::I16:
55 return "i16";
56 case OverloadKind::I32:
57 return "i32";
58 case OverloadKind::I64:
59 return "i64";
60 case OverloadKind::VOID:
61 case OverloadKind::ObjectType:
62 case OverloadKind::UserDefineType:
63 break;
64 }
65 llvm_unreachable("invalid overload type for name");
66 return "void";
67}
68
69static OverloadKind getOverloadKind(Type *Ty) {
70 Type::TypeID T = Ty->getTypeID();
71 switch (T) {
72 case Type::VoidTyID:
73 return OverloadKind::VOID;
74 case Type::HalfTyID:
75 return OverloadKind::HALF;
76 case Type::FloatTyID:
77 return OverloadKind::FLOAT;
79 return OverloadKind::DOUBLE;
80 case Type::IntegerTyID: {
81 IntegerType *ITy = cast<IntegerType>(Ty);
82 unsigned Bits = ITy->getBitWidth();
83 switch (Bits) {
84 case 1:
85 return OverloadKind::I1;
86 case 8:
87 return OverloadKind::I8;
88 case 16:
89 return OverloadKind::I16;
90 case 32:
91 return OverloadKind::I32;
92 case 64:
93 return OverloadKind::I64;
94 default:
95 llvm_unreachable("invalid overload type");
96 return OverloadKind::VOID;
97 }
98 }
100 return OverloadKind::UserDefineType;
101 case Type::StructTyID:
102 return OverloadKind::ObjectType;
103 default:
104 llvm_unreachable("invalid overload type");
105 return OverloadKind::VOID;
106 }
107}
108
109static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110 if (Kind < OverloadKind::UserDefineType) {
111 return getOverloadTypeName(Kind);
112 } else if (Kind == OverloadKind::UserDefineType) {
113 StructType *ST = cast<StructType>(Ty);
114 return ST->getStructName().str();
115 } else if (Kind == OverloadKind::ObjectType) {
116 StructType *ST = cast<StructType>(Ty);
117 return ST->getStructName().str();
118 } else {
119 std::string Str;
121 Ty->print(OS);
122 return OS.str();
123 }
124}
125
126// Static properties.
128 dxil::OpCode OpCode;
129 // Offset in DXILOpCodeNameTable.
131 dxil::OpCodeClass OpCodeClass;
132 // Offset in DXILOpCodeClassNameTable.
136 int OverloadParamIndex; // parameter index which control the overload.
137 // When < 0, should be only 1 overload type.
138 unsigned NumOfParameters; // Number of parameters include return value.
139 unsigned ParameterTableOffset; // Offset in ParameterTable.
140};
141
142// Include getOpCodeClassName getOpCodeProperty, getOpCodeName and
143// getOpCodeParameterKind which generated by tableGen.
144#define DXIL_OP_OPERATION_TABLE
145#include "DXILOperation.inc"
146#undef DXIL_OP_OPERATION_TABLE
147
148static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149 const OpCodeProperty &Prop) {
150 if (Kind == OverloadKind::VOID) {
151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152 }
153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154 getTypeName(Kind, Ty))
155 .str();
156}
157
158static std::string constructOverloadTypeName(OverloadKind Kind,
159 StringRef TypeName) {
160 if (Kind == OverloadKind::VOID)
161 return TypeName.str();
162
163 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164 return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165}
166
168 ArrayRef<Type *> EltTys,
169 LLVMContext &Ctx) {
171 if (ST)
172 return ST;
173
174 return StructType::create(Ctx, EltTys, Name);
175}
176
177static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178 OverloadKind Kind = getOverloadKind(OverloadTy);
179 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181 Type::getInt32Ty(Ctx)};
182 return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183}
184
186 return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx);
187}
188
190 auto &Ctx = OverloadTy->getContext();
191 switch (Kind) {
192 case ParameterKind::VOID:
193 return Type::getVoidTy(Ctx);
194 case ParameterKind::HALF:
195 return Type::getHalfTy(Ctx);
196 case ParameterKind::FLOAT:
197 return Type::getFloatTy(Ctx);
198 case ParameterKind::DOUBLE:
199 return Type::getDoubleTy(Ctx);
200 case ParameterKind::I1:
201 return Type::getInt1Ty(Ctx);
202 case ParameterKind::I8:
203 return Type::getInt8Ty(Ctx);
204 case ParameterKind::I16:
205 return Type::getInt16Ty(Ctx);
206 case ParameterKind::I32:
207 return Type::getInt32Ty(Ctx);
208 case ParameterKind::I64:
209 return Type::getInt64Ty(Ctx);
210 case ParameterKind::OVERLOAD:
211 return OverloadTy;
212 case ParameterKind::RESOURCE_RET:
213 return getResRetType(OverloadTy, Ctx);
214 case ParameterKind::DXIL_HANDLE:
215 return getHandleType(Ctx);
216 default:
217 break;
218 }
219 llvm_unreachable("Invalid parameter kind");
220 return nullptr;
221}
222
224 Type *OverloadTy) {
225 SmallVector<Type *> ArgTys;
226
227 auto ParamKinds = getOpCodeParameterKind(*Prop);
228
229 for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
230 ParameterKind Kind = ParamKinds[I];
231 ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
232 }
233 return FunctionType::get(
234 ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
235}
236
237static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
238 Type *OverloadTy, Module &M) {
239 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
240
241 OverloadKind Kind = getOverloadKind(OverloadTy);
242 // FIXME: find the issue and report error in clang instead of check it in
243 // backend.
244 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
245 llvm_unreachable("invalid overload");
246 }
247
248 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
249 // Dependent on name to dedup.
250 if (auto *Fn = M.getFunction(FnName))
251 return FunctionCallee(Fn);
252
253 FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
254 return M.getOrInsertFunction(FnName, DXILOpFT);
255}
256
257namespace llvm {
258namespace dxil {
259
260CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
262 auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
263 SmallVector<Value *> FullArgs;
264 FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
265 FullArgs.append(Args.begin(), Args.end());
266 return B.CreateCall(Fn, FullArgs);
267}
268
270 bool NoOpCodeParam) {
271
272 const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
273 if (Prop->OverloadParamIndex < 0) {
274 auto &Ctx = FT->getContext();
275 // When only has 1 overload type, just return it.
276 switch (Prop->OverloadTys) {
277 case OverloadKind::VOID:
278 return Type::getVoidTy(Ctx);
279 case OverloadKind::HALF:
280 return Type::getHalfTy(Ctx);
281 case OverloadKind::FLOAT:
282 return Type::getFloatTy(Ctx);
283 case OverloadKind::DOUBLE:
284 return Type::getDoubleTy(Ctx);
285 case OverloadKind::I1:
286 return Type::getInt1Ty(Ctx);
287 case OverloadKind::I8:
288 return Type::getInt8Ty(Ctx);
289 case OverloadKind::I16:
290 return Type::getInt16Ty(Ctx);
291 case OverloadKind::I32:
292 return Type::getInt32Ty(Ctx);
293 case OverloadKind::I64:
294 return Type::getInt64Ty(Ctx);
295 default:
296 llvm_unreachable("invalid overload type");
297 return nullptr;
298 }
299 }
300
301 // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
302 Type *OverloadType = FT->getReturnType();
303 if (Prop->OverloadParamIndex != 0) {
304 // Skip Return Type and Type for DXIL opcode.
305 const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
306 OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
307 }
308
309 auto ParamKinds = getOpCodeParameterKind(*Prop);
310 auto Kind = ParamKinds[Prop->OverloadParamIndex];
311 // For ResRet and CBufferRet, OverloadTy is in field of StructType.
312 if (Kind == ParameterKind::CBUFFER_RET ||
314 auto *ST = cast<StructType>(OverloadType);
315 OverloadType = ST->getElementType(0);
316 }
317 return OverloadType;
318}
319
320const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
321 return ::getOpCodeName(DXILOp);
322}
323} // namespace dxil
324} // namespace llvm
static FunctionType * getDXILOpFunctionType(const OpCodeProperty *Prop, Type *OverloadTy)
static StructType * getResRetType(Type *OverloadTy, LLVMContext &Ctx)
static const char * getOverloadTypeName(OverloadKind Kind)
static OverloadKind getOverloadKind(Type *Ty)
static StructType * getOrCreateStructType(StringRef Name, ArrayRef< Type * > EltTys, LLVMContext &Ctx)
static StructType * getHandleType(LLVMContext &Ctx)
static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop)
constexpr StringLiteral DXILOpNamePrefix
static std::string constructOverloadTypeName(OverloadKind Kind, StringRef TypeName)
static Type * getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy)
static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp, Type *OverloadTy, Module &M)
std::string Name
#define I(x, y, z)
Definition: MD5.cpp:58
Module.h This file contains the declarations for the Module class.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
AttrKind
This enumeration lists the attributes that can be associated with parameters, function results,...
Definition: Attributes.h:84
This class represents a function call, abstracting a target machine's calling convention.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:165
Class to represent function types.
Definition: DerivedTypes.h:103
Type * getParamType(unsigned i) const
Parameter type accessors.
Definition: DerivedTypes.h:135
Type * getReturnType() const
Definition: DerivedTypes.h:124
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
Definition: IRBuilder.h:472
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args=std::nullopt, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition: IRBuilder.h:2374
Class to represent integer types.
Definition: DerivedTypes.h:40
unsigned getBitWidth() const
Get the number of bits in this IntegerType.
Definition: DerivedTypes.h:72
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
size_t size() const
Definition: SmallVector.h:91
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:941
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:687
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1200
A wrapper around a string literal that serves as a proxy for constructing global tables of StringRefs...
Definition: StringRef.h:857
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
Class to represent struct types.
Definition: DerivedTypes.h:213
static StructType * getTypeByName(LLVMContext &C, StringRef Name)
Return the type with the specified name, or null if there is none by that name.
Definition: Type.cpp:633
static StructType * create(LLVMContext &Context, StringRef Name)
This creates an identified struct.
Definition: Type.cpp:514
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition: Twine.h:81
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
static Type * getHalfTy(LLVMContext &C)
static Type * getDoubleTy(LLVMContext &C)
static IntegerType * getInt1Ty(LLVMContext &C)
TypeID
Definitions of all of the base types for the Type system.
Definition: Type.h:54
@ HalfTyID
16-bit floating point type
Definition: Type.h:56
@ VoidTyID
type with no size
Definition: Type.h:63
@ FloatTyID
32-bit floating point type
Definition: Type.h:58
@ StructTyID
Structures.
Definition: Type.h:74
@ IntegerTyID
Arbitrary bit width integers.
Definition: Type.h:71
@ DoubleTyID
64-bit floating point type
Definition: Type.h:59
@ PointerTyID
Pointers.
Definition: Type.h:73
void print(raw_ostream &O, bool IsForDebug=false, bool NoDetails=false) const
Print the current type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
LLVMContext & getContext() const
Return the LLVMContext in which this type was uniqued.
Definition: Type.h:129
static IntegerType * getInt8Ty(LLVMContext &C)
static PointerType * getInt8PtrTy(LLVMContext &C, unsigned AS=0)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
static Type * getFloatTy(LLVMContext &C)
TypeID getTypeID() const
Return the type id for the type.
Definition: Type.h:137
Type * getOverloadTy(dxil::OpCode OpCode, FunctionType *FT, bool NoOpCodeParam)
static const char * getOpCodeName(dxil::OpCode DXILOp)
CallInst * createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, llvm::iterator_range< Use * > Args)
A range adaptor for a pair of iterators.
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:642
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
StringRef getTypeName()
We provide a function which tries to compute the (demangled) name of a type statically.
Definition: TypeName.h:27
uint16_t OverloadTys
dxil::OpCodeClass OpCodeClass
unsigned OpCodeNameOffset
unsigned OpCodeClassNameOffset
unsigned ParameterTableOffset
unsigned NumOfParameters
llvm::Attribute::AttrKind FuncAttr
dxil::OpCode OpCode