LLVM 18.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the targeting of the MachineLegalizer class for SPIR-V.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26// Generic opcodes looked up by the isTypeFoldingSupported() membership test.
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 = LLT::pointer(5, PSize); // Input
106
107 // TODO: remove copy-pasting here by using concatenation in some way.
108 auto allPtrsScalarsAndVectors = {
109 p0, p1, p2, p3, p4, p5, s1, s8, s16,
110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113
114 auto allScalarsAndVectors = {
115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118
119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123
124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125
126 auto allIntScalars = {s8, s16, s32, s64};
127
128 auto allFloatScalarsAndVectors = {
129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131
132 auto allFloatAndIntScalars = allIntScalars;
133
134 auto allPtrs = {p0, p1, p2, p3, p4, p5};
135 auto allWritablePtrs = {p0, p1, p3, p4};
136
137 for (auto Opc : TypeFoldingSupportingOpcs)
139
141
142 // TODO: add proper rules for vectors legalization.
143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144
145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147
149 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150
151 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152 .legalForCartesianProduct(allPtrs, allPtrs);
153
154 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155
156 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157
158 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159
160 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161 .legalForCartesianProduct(allIntScalarsAndVectors,
162 allFloatScalarsAndVectors);
163
164 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165 .legalForCartesianProduct(allFloatScalarsAndVectors,
166 allScalarsAndVectors);
167
168 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169 .legalFor(allIntScalarsAndVectors);
170
172 allIntScalarsAndVectors, allIntScalarsAndVectors);
173
174 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175
177 typeInSet(0, allPtrsScalarsAndVectors),
178 typeInSet(1, allPtrsScalarsAndVectors),
179 LegalityPredicate(([=](const LegalityQuery &Query) {
180 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181 }))));
182
184
186 .legalForCartesianProduct(allPtrs, allIntScalars);
188 .legalForCartesianProduct(allIntScalars, allPtrs);
190 allPtrs, allIntScalars);
191
192 // ST.canDirectlyComparePointers() for pointer args is supported in
193 // legalizeCustom().
195 all(typeInSet(0, allBoolScalarsAndVectors),
196 typeInSet(1, allPtrsScalarsAndVectors)));
197
199 all(typeInSet(0, allBoolScalarsAndVectors),
200 typeInSet(1, allFloatScalarsAndVectors)));
201
202 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207
208 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210
211 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212 // TODO: add proper legalization rules.
213 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214
215 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216 .alwaysLegal();
217
218 // Extensions.
219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220 .legalForCartesianProduct(allScalarsAndVectors);
221
222 // FP conversions.
223 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224 .legalForCartesianProduct(allFloatScalarsAndVectors);
225
226 // Pointer-handling.
227 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228
229 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231
232 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
233 // tighten these requirements. Many of these math functions are only legal on
234 // specific bitwidths, so they are not selectable for
235 // allFloatScalarsAndVectors.
237 G_FEXP,
238 G_FEXP2,
239 G_FLOG,
240 G_FLOG2,
241 G_FLOG10,
242 G_FABS,
243 G_FMINNUM,
244 G_FMAXNUM,
245 G_FCEIL,
246 G_FCOS,
247 G_FSIN,
248 G_FSQRT,
249 G_FFLOOR,
250 G_FRINT,
251 G_FNEARBYINT,
252 G_INTRINSIC_ROUND,
253 G_INTRINSIC_TRUNC,
254 G_FMINIMUM,
255 G_FMAXIMUM,
256 G_INTRINSIC_ROUNDEVEN})
257 .legalFor(allFloatScalarsAndVectors);
258
259 getActionDefinitionsBuilder(G_FCOPYSIGN)
260 .legalForCartesianProduct(allFloatScalarsAndVectors,
261 allFloatScalarsAndVectors);
262
264 allFloatScalarsAndVectors, allIntScalarsAndVectors);
265
266 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
268 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
269 .legalForCartesianProduct(allIntScalarsAndVectors,
270 allIntScalarsAndVectors);
271
272 // Struct return types become a single scalar, so cannot easily legalize.
273 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
274 }
275
277 verify(*ST.getInstrInfo());
278}
279
280static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
281 LegalizerHelper &Helper,
284 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
285 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
286 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
287 .addDef(ConvReg)
288 .addUse(Reg);
289 return ConvReg;
290}
291
293 MachineInstr &MI) const {
294 auto Opc = MI.getOpcode();
295 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
296 if (!isTypeFoldingSupported(Opc)) {
297 assert(Opc == TargetOpcode::G_ICMP);
298 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
299 auto &Op0 = MI.getOperand(2);
300 auto &Op1 = MI.getOperand(3);
301 Register Reg0 = Op0.getReg();
302 Register Reg1 = Op1.getReg();
304 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
305 if ((!ST->canDirectlyComparePointers() ||
307 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
308 LLT ConvT = LLT::scalar(ST->getPointerSize());
309 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
310 ST->getPointerSize());
311 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
312 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
313 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
314 }
315 return true;
316 }
317 // TODO: implement legalization for other opcodes.
318 return true;
319}
unsigned const MachineRegisterInfo * MRI
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
bool isTypeFoldingSupported(unsigned Opcode)
static const std::set< unsigned > TypeFoldingSupportingOpcs
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
bool isTypeFoldingSupported(unsigned Opcode)
static constexpr uint32_t Opcode
Definition: aarch32.h:200
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:748
@ ICMP_EQ
equal
Definition: InstrTypes.h:769
@ ICMP_NE
not equal
Definition: InstrTypes.h:770
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:285
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:49
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:92
void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:68
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
std::function< bool(const LegalityQuery &)> LegalityPredicate
The LegalityQuery object bundles together all the information that's needed to decide whether a given...
ArrayRef< LLT > Types