LLVM 20.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVLegalizerInfo class, which defines the
// GlobalISel legalization rules for the SPIR-V target.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
22
23using namespace llvm;
24using namespace llvm::LegalizeActions;
25using namespace llvm::LegalityPredicates;
26
27static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28 TargetOpcode::G_ADD,
29 TargetOpcode::G_FADD,
30 TargetOpcode::G_SUB,
31 TargetOpcode::G_FSUB,
32 TargetOpcode::G_MUL,
33 TargetOpcode::G_FMUL,
34 TargetOpcode::G_SDIV,
35 TargetOpcode::G_UDIV,
36 TargetOpcode::G_FDIV,
37 TargetOpcode::G_SREM,
38 TargetOpcode::G_UREM,
39 TargetOpcode::G_FREM,
40 TargetOpcode::G_FNEG,
41 TargetOpcode::G_CONSTANT,
42 TargetOpcode::G_FCONSTANT,
43 TargetOpcode::G_AND,
44 TargetOpcode::G_OR,
45 TargetOpcode::G_XOR,
46 TargetOpcode::G_SHL,
47 TargetOpcode::G_ASHR,
48 TargetOpcode::G_LSHR,
49 TargetOpcode::G_SELECT,
50 TargetOpcode::G_EXTRACT_VECTOR_ELT,
51};
52
53bool isTypeFoldingSupported(unsigned Opcode) {
54 return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55}
56
58 using namespace TargetOpcode;
59
60 this->ST = &ST;
61 GR = ST.getSPIRVGlobalRegistry();
62
63 const LLT s1 = LLT::scalar(1);
64 const LLT s8 = LLT::scalar(8);
65 const LLT s16 = LLT::scalar(16);
66 const LLT s32 = LLT::scalar(32);
67 const LLT s64 = LLT::scalar(64);
68
69 const LLT v16s64 = LLT::fixed_vector(16, 64);
70 const LLT v16s32 = LLT::fixed_vector(16, 32);
71 const LLT v16s16 = LLT::fixed_vector(16, 16);
72 const LLT v16s8 = LLT::fixed_vector(16, 8);
73 const LLT v16s1 = LLT::fixed_vector(16, 1);
74
75 const LLT v8s64 = LLT::fixed_vector(8, 64);
76 const LLT v8s32 = LLT::fixed_vector(8, 32);
77 const LLT v8s16 = LLT::fixed_vector(8, 16);
78 const LLT v8s8 = LLT::fixed_vector(8, 8);
79 const LLT v8s1 = LLT::fixed_vector(8, 1);
80
81 const LLT v4s64 = LLT::fixed_vector(4, 64);
82 const LLT v4s32 = LLT::fixed_vector(4, 32);
83 const LLT v4s16 = LLT::fixed_vector(4, 16);
84 const LLT v4s8 = LLT::fixed_vector(4, 8);
85 const LLT v4s1 = LLT::fixed_vector(4, 1);
86
87 const LLT v3s64 = LLT::fixed_vector(3, 64);
88 const LLT v3s32 = LLT::fixed_vector(3, 32);
89 const LLT v3s16 = LLT::fixed_vector(3, 16);
90 const LLT v3s8 = LLT::fixed_vector(3, 8);
91 const LLT v3s1 = LLT::fixed_vector(3, 1);
92
93 const LLT v2s64 = LLT::fixed_vector(2, 64);
94 const LLT v2s32 = LLT::fixed_vector(2, 32);
95 const LLT v2s16 = LLT::fixed_vector(2, 16);
96 const LLT v2s8 = LLT::fixed_vector(2, 8);
97 const LLT v2s1 = LLT::fixed_vector(2, 1);
98
99 const unsigned PSize = ST.getPointerSize();
100 const LLT p0 = LLT::pointer(0, PSize); // Function
101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104 const LLT p4 = LLT::pointer(4, PSize); // Generic
105 const LLT p5 =
106 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
107 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
108
109 // TODO: remove copy-pasting here by using concatenation in some way.
110 auto allPtrsScalarsAndVectors = {
111 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
112 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
113 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
115
116 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
117 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
118 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
119 v16s8, v16s16, v16s32, v16s64};
120
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
123 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
124 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
125
126 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
127 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
128 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
129 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
130
131 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
132
133 auto allIntScalars = {s8, s16, s32, s64};
134
135 auto allFloatScalars = {s16, s32, s64};
136
137 auto allFloatScalarsAndVectors = {
138 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
139 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
140
141 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
142 p2, p3, p4, p5, p6};
143
144 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
145 auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
146
147 for (auto Opc : TypeFoldingSupportingOpcs)
149
151
152 // TODO: add proper rules for vectors legalization.
154 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
155 .alwaysLegal();
156
157 // Vector Reduction Operations
159 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
160 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
161 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
162 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
163 .legalFor(allVectors)
164 .scalarize(1)
165 .lower();
166
167 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
168 .scalarize(2)
169 .lower();
170
171 // Merge/Unmerge
172 // TODO: add proper legalization rules.
173 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
174
175 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
176 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
177
179 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
180
181 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
182 .legalForCartesianProduct(allPtrs, allPtrs);
183
184 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
185
186 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);
187
188 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
189
190 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
191 .legalForCartesianProduct(allIntScalarsAndVectors,
192 allFloatScalarsAndVectors);
193
194 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
195 .legalForCartesianProduct(allFloatScalarsAndVectors,
196 allScalarsAndVectors);
197
198 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
199 .legalFor(allIntScalarsAndVectors);
200
202 allIntScalarsAndVectors, allIntScalarsAndVectors);
203
204 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
205
207 all(typeInSet(0, allPtrsScalarsAndVectors),
208 typeInSet(1, allPtrsScalarsAndVectors)));
209
210 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
211
212 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
213
215 .legalForCartesianProduct(allPtrs, allIntScalars);
217 .legalForCartesianProduct(allIntScalars, allPtrs);
219 allPtrs, allIntScalars);
220
221 // ST.canDirectlyComparePointers() for pointer args is supported in
222 // legalizeCustom().
224 all(typeInSet(0, allBoolScalarsAndVectors),
225 typeInSet(1, allPtrsScalarsAndVectors)));
226
228 all(typeInSet(0, allBoolScalarsAndVectors),
229 typeInSet(1, allFloatScalarsAndVectors)));
230
231 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
232 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
233 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
234 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
235 .legalForCartesianProduct(allIntScalars, allWritablePtrs);
236
238 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
239 .legalForCartesianProduct(allFloatScalars, allWritablePtrs);
240
241 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
242 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs);
243
244 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
245 // TODO: add proper legalization rules.
246 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
247
248 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
249 .alwaysLegal();
250
251 // Extensions.
252 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
253 .legalForCartesianProduct(allScalarsAndVectors);
254
255 // FP conversions.
256 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
257 .legalForCartesianProduct(allFloatScalarsAndVectors);
258
259 // Pointer-handling.
260 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
261
262 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
263 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
264
265 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
266 // tighten these requirements. Many of these math functions are only legal on
267 // specific bitwidths, so they are not selectable for
268 // allFloatScalarsAndVectors.
270 G_FEXP,
271 G_FEXP2,
272 G_FLOG,
273 G_FLOG2,
274 G_FLOG10,
275 G_FABS,
276 G_FMINNUM,
277 G_FMAXNUM,
278 G_FCEIL,
279 G_FCOS,
280 G_FSIN,
281 G_FTAN,
282 G_FACOS,
283 G_FASIN,
284 G_FATAN,
285 G_FCOSH,
286 G_FSINH,
287 G_FTANH,
288 G_FSQRT,
289 G_FFLOOR,
290 G_FRINT,
291 G_FNEARBYINT,
292 G_INTRINSIC_ROUND,
293 G_INTRINSIC_TRUNC,
294 G_FMINIMUM,
295 G_FMAXIMUM,
296 G_INTRINSIC_ROUNDEVEN})
297 .legalFor(allFloatScalarsAndVectors);
298
299 getActionDefinitionsBuilder(G_FCOPYSIGN)
300 .legalForCartesianProduct(allFloatScalarsAndVectors,
301 allFloatScalarsAndVectors);
302
304 allFloatScalarsAndVectors, allIntScalarsAndVectors);
305
306 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
308 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
309 .legalForCartesianProduct(allIntScalarsAndVectors,
310 allIntScalarsAndVectors);
311
312 // Struct return types become a single scalar, so cannot easily legalize.
313 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
314
315 // supported saturation arithmetic
316 getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})
317 .legalFor(allIntScalarsAndVectors);
318 }
319
321 verify(*ST.getInstrInfo());
322}
323
324static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
325 LegalizerHelper &Helper,
328 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
329 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
330 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
331 .addDef(ConvReg)
332 .addUse(Reg);
333 return ConvReg;
334}
335
338 LostDebugLocObserver &LocObserver) const {
339 auto Opc = MI.getOpcode();
340 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
341 if (!isTypeFoldingSupported(Opc)) {
342 assert(Opc == TargetOpcode::G_ICMP);
343 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
344 auto &Op0 = MI.getOperand(2);
345 auto &Op1 = MI.getOperand(3);
346 Register Reg0 = Op0.getReg();
347 Register Reg1 = Op1.getReg();
349 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
350 if ((!ST->canDirectlyComparePointers() ||
352 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
353 LLT ConvT = LLT::scalar(ST->getPointerSize());
354 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
355 ST->getPointerSize());
356 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
357 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
358 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
359 }
360 return true;
361 }
362 // TODO: implement legalization for other opcodes.
363 return true;
364}
unsigned const MachineRegisterInfo * MRI
static void scalarize(BinaryOperator *BO, SmallVectorImpl< BinaryOperator * > &Replace)
IRTranslator LLVM IR MI
This file declares the MachineIRBuilder class.
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
bool isTypeFoldingSupported(unsigned Opcode)
static const std::set< unsigned > TypeFoldingSupportingOpcs
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
bool isTypeFoldingSupported(unsigned Opcode)
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:757
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:266
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelType.h:57
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelType.h:100
void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & customIf(LegalityPredicate Predicate)
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:69
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
unsigned getPointerSize() const
bool canDirectlyComparePointers() const
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18