LLVM  16.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the targeting of the Machinelegalizer class for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28  TargetOpcode::G_ADD,
29  TargetOpcode::G_FADD,
30  TargetOpcode::G_SUB,
31  TargetOpcode::G_FSUB,
32  TargetOpcode::G_MUL,
33  TargetOpcode::G_FMUL,
34  TargetOpcode::G_SDIV,
35  TargetOpcode::G_UDIV,
36  TargetOpcode::G_FDIV,
37  TargetOpcode::G_SREM,
38  TargetOpcode::G_UREM,
39  TargetOpcode::G_FREM,
40  TargetOpcode::G_FNEG,
41  TargetOpcode::G_CONSTANT,
42  TargetOpcode::G_FCONSTANT,
43  TargetOpcode::G_AND,
44  TargetOpcode::G_OR,
45  TargetOpcode::G_XOR,
46  TargetOpcode::G_SHL,
47  TargetOpcode::G_ASHR,
48  TargetOpcode::G_LSHR,
49  TargetOpcode::G_SELECT,
50  TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
53 bool isTypeFoldingSupported(unsigned Opcode) {
54  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
58  using namespace TargetOpcode;
59 
60  this->ST = &ST;
61  GR = ST.getSPIRVGlobalRegistry();
62 
63  const LLT s1 = LLT::scalar(1);
64  const LLT s8 = LLT::scalar(8);
65  const LLT s16 = LLT::scalar(16);
66  const LLT s32 = LLT::scalar(32);
67  const LLT s64 = LLT::scalar(64);
68 
69  const LLT v16s64 = LLT::fixed_vector(16, 64);
70  const LLT v16s32 = LLT::fixed_vector(16, 32);
71  const LLT v16s16 = LLT::fixed_vector(16, 16);
72  const LLT v16s8 = LLT::fixed_vector(16, 8);
73  const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75  const LLT v8s64 = LLT::fixed_vector(8, 64);
76  const LLT v8s32 = LLT::fixed_vector(8, 32);
77  const LLT v8s16 = LLT::fixed_vector(8, 16);
78  const LLT v8s8 = LLT::fixed_vector(8, 8);
79  const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81  const LLT v4s64 = LLT::fixed_vector(4, 64);
82  const LLT v4s32 = LLT::fixed_vector(4, 32);
83  const LLT v4s16 = LLT::fixed_vector(4, 16);
84  const LLT v4s8 = LLT::fixed_vector(4, 8);
85  const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87  const LLT v3s64 = LLT::fixed_vector(3, 64);
88  const LLT v3s32 = LLT::fixed_vector(3, 32);
89  const LLT v3s16 = LLT::fixed_vector(3, 16);
90  const LLT v3s8 = LLT::fixed_vector(3, 8);
91  const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93  const LLT v2s64 = LLT::fixed_vector(2, 64);
94  const LLT v2s32 = LLT::fixed_vector(2, 32);
95  const LLT v2s16 = LLT::fixed_vector(2, 16);
96  const LLT v2s8 = LLT::fixed_vector(2, 8);
97  const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99  const unsigned PSize = ST.getPointerSize();
100  const LLT p0 = LLT::pointer(0, PSize); // Function
101  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104  const LLT p4 = LLT::pointer(4, PSize); // Generic
105  const LLT p5 = LLT::pointer(5, PSize); // Input
106 
107  // TODO: remove copy-pasting here by using concatenation in some way.
108  auto allPtrsScalarsAndVectors = {
109  p0, p1, p2, p3, p4, p5, s1, s8, s16,
110  s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111  v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112  v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113 
114  auto allScalarsAndVectors = {
115  s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116  v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117  v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118 
119  auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120  v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121  v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123 
124  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125 
126  auto allIntScalars = {s8, s16, s32, s64};
127 
128  auto allFloatScalarsAndVectors = {
129  s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130  v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131 
132  auto allFloatAndIntScalars = allIntScalars;
133 
134  auto allPtrs = {p0, p1, p2, p3, p4, p5};
135  auto allWritablePtrs = {p0, p1, p3, p4};
136 
137  for (auto Opc : TypeFoldingSupportingOpcs)
138  getActionDefinitionsBuilder(Opc).custom();
139 
140  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141 
142  // TODO: add proper rules for vectors legalization.
143  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144 
145  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146  .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147 
148  getActionDefinitionsBuilder(G_MEMSET).legalIf(
149  all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));
150 
151  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
152  .legalForCartesianProduct(allPtrs, allPtrs);
153 
154  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
155 
156  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
157 
158  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
159 
160  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
161  .legalForCartesianProduct(allIntScalarsAndVectors,
162  allFloatScalarsAndVectors);
163 
164  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
165  .legalForCartesianProduct(allFloatScalarsAndVectors,
166  allScalarsAndVectors);
167 
168  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
169  .legalFor(allIntScalarsAndVectors);
170 
171  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
172  allIntScalarsAndVectors, allIntScalarsAndVectors);
173 
174  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
175 
176  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
177  typeInSet(0, allPtrsScalarsAndVectors),
178  typeInSet(1, allPtrsScalarsAndVectors),
179  LegalityPredicate(([=](const LegalityQuery &Query) {
180  return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
181  }))));
182 
183  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
184 
185  getActionDefinitionsBuilder(G_INTTOPTR)
186  .legalForCartesianProduct(allPtrs, allIntScalars);
187  getActionDefinitionsBuilder(G_PTRTOINT)
188  .legalForCartesianProduct(allIntScalars, allPtrs);
189  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
190  allPtrs, allIntScalars);
191 
192  // ST.canDirectlyComparePointers() for pointer args is supported in
193  // legalizeCustom().
194  getActionDefinitionsBuilder(G_ICMP).customIf(
195  all(typeInSet(0, allBoolScalarsAndVectors),
196  typeInSet(1, allPtrsScalarsAndVectors)));
197 
198  getActionDefinitionsBuilder(G_FCMP).legalIf(
199  all(typeInSet(0, allBoolScalarsAndVectors),
200  typeInSet(1, allFloatScalarsAndVectors)));
201 
202  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
203  G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
204  G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
205  G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206  .legalForCartesianProduct(allIntScalars, allWritablePtrs);
207 
208  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209  .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210 
211  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
212  // TODO: add proper legalization rules.
213  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
214 
215  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
216  .alwaysLegal();
217 
218  // Extensions.
219  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
220  .legalForCartesianProduct(allScalarsAndVectors);
221 
222  // FP conversions.
223  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
224  .legalForCartesianProduct(allFloatScalarsAndVectors);
225 
226  // Pointer-handling.
227  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
228 
229  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
230  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
231 
232  getActionDefinitionsBuilder({G_FPOW,
233  G_FEXP,
234  G_FEXP2,
235  G_FLOG,
236  G_FLOG2,
237  G_FABS,
238  G_FMINNUM,
239  G_FMAXNUM,
240  G_FCEIL,
241  G_FCOS,
242  G_FSIN,
243  G_FSQRT,
244  G_FFLOOR,
245  G_FRINT,
246  G_FNEARBYINT,
247  G_INTRINSIC_ROUND,
248  G_INTRINSIC_TRUNC,
249  G_FMINIMUM,
250  G_FMAXIMUM,
251  G_INTRINSIC_ROUNDEVEN})
252  .legalFor(allFloatScalarsAndVectors);
253 
254  getActionDefinitionsBuilder(G_FCOPYSIGN)
255  .legalForCartesianProduct(allFloatScalarsAndVectors,
256  allFloatScalarsAndVectors);
257 
258  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
259  allFloatScalarsAndVectors, allIntScalarsAndVectors);
260 
261  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
262  getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors);
263 
264  getActionDefinitionsBuilder(
265  {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
266  .legalForCartesianProduct(allIntScalarsAndVectors,
267  allIntScalarsAndVectors);
268 
269  // Struct return types become a single scalar, so cannot easily legalize.
270  getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
271  }
272 
273  getLegacyLegalizerInfo().computeTables();
274  verify(*ST.getInstrInfo());
275 }
276 
277 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
278  LegalizerHelper &Helper,
280  SPIRVGlobalRegistry *GR) {
281  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
282  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
283  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
284  .addDef(ConvReg)
285  .addUse(Reg);
286  return ConvReg;
287 }
288 
290  MachineInstr &MI) const {
291  auto Opc = MI.getOpcode();
292  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
293  if (!isTypeFoldingSupported(Opc)) {
294  assert(Opc == TargetOpcode::G_ICMP);
295  assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
296  auto &Op0 = MI.getOperand(2);
297  auto &Op1 = MI.getOperand(3);
298  Register Reg0 = Op0.getReg();
299  Register Reg1 = Op1.getReg();
301  static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
302  if ((!ST->canDirectlyComparePointers() ||
304  MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
305  LLT ConvT = LLT::scalar(ST->getPointerSize());
306  Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
307  ST->getPointerSize());
308  SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
309  Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
310  Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
311  }
312  return true;
313  }
314  // TODO: implement legalization for other opcodes.
315  return true;
316 }
MI
IRTranslator LLVM IR MI
Definition: IRTranslator.cpp:109
MachineInstr.h
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
SPIRVLegalizerInfo.h
llvm::CmpInst::ICMP_EQ
@ ICMP_EQ
equal
Definition: InstrTypes.h:741
llvm::CmpInst::Predicate
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:720
llvm::MachineRegisterInfo
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Definition: MachineRegisterInfo.h:50
llvm::X86Disassembler::Reg
Reg
All possible values of the reg field in the ModR/M byte.
Definition: X86DisassemblerDecoder.h:462
llvm::CmpInst::ICMP_NE
@ ICMP_NE
not equal
Definition: InstrTypes.h:742
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::SPIRVGlobalRegistry::assignSPIRVTypeToVReg
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
Definition: SPIRVGlobalRegistry.cpp:56
llvm::SPIRVSubtarget
Definition: SPIRVSubtarget.h:35
convertPtrToInt
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
Definition: SPIRVLegalizerInfo.cpp:277
MachineIRBuilder.h
SPIRVSubtarget.h
llvm::LegalizerHelper
Definition: LegalizerHelper.h:46
MachineRegisterInfo.h
llvm::LLT::fixed_vector
static LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelTypeImpl.h:74
llvm::MachineInstrBuilder::addDef
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Definition: MachineInstrBuilder.h:116
TargetOpcodes.h
llvm::MachineIRBuilder::getMF
MachineFunction & getMF()
Getter for the function we currently build.
Definition: MachineIRBuilder.h:271
llvm::SPIRVLegalizerInfo::legalizeCustom
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override
Called for instructions with the Custom LegalizationAction.
Definition: SPIRVLegalizerInfo.cpp:289
llvm::LegalityPredicates
Definition: LegalizerInfo.h:203
isTypeFoldingSupported
bool isTypeFoldingSupported(unsigned Opcode)
Definition: SPIRVLegalizerInfo.cpp:53
llvm::LLT::pointer
static LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelTypeImpl.h:49
llvm::SPIRVLegalizerInfo::SPIRVLegalizerInfo
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
Definition: SPIRVLegalizerInfo.cpp:57
llvm::MachineInstr
Representation of each machine instruction.
Definition: MachineInstr.h:66
SPIRVGlobalRegistry.h
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::LegalityPredicates::all
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
Definition: LegalizerInfo.h:228
SPIRV.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::LLT::isPointer
bool isPointer() const
Definition: LowLevelTypeImpl.h:120
llvm::LegalityPredicate
std::function< bool(const LegalityQuery &)> LegalityPredicate
Definition: LegalizerInfo.h:199
llvm::LegalityPredicates::typeInSet
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Definition: LegalityPredicates.cpp:34
llvm::MachineRegisterInfo::createGenericVirtualRegister
Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
Definition: MachineRegisterInfo.cpp:186
llvm::SPIRVGlobalRegistry
Definition: SPIRVGlobalRegistry.h:27
llvm::MachineInstrBuilder::addUse
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
Definition: MachineInstrBuilder.h:123
llvm::MachineIRBuilder::buildInstr
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
Definition: MachineIRBuilder.h:383
llvm::LegalityQuery
The LegalityQuery object bundles together all the information that's needed to decide whether a given...
Definition: LegalizerInfo.h:108
Cond
SmallVector< MachineOperand, 4 > Cond
Definition: BasicBlockSections.cpp:138
MRI
unsigned const MachineRegisterInfo * MRI
Definition: AArch64AdvSIMDScalarPass.cpp:105
llvm::Register
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
verify
ppc ctr loops verify
Definition: PPCCTRLoopsVerify.cpp:76
LegalizerHelper.h
s1
int s1
Definition: README.txt:182
llvm::MachineRegisterInfo::getType
LLT getType(Register Reg) const
Get the low-level type of Reg or LLT{} if Reg is not a generic (target independent) virtual register.
Definition: MachineRegisterInfo.h:745
llvm::LegalityQuery::Types
ArrayRef< LLT > Types
Definition: LegalizerInfo.h:110
llvm::LegalizerHelper::MIRBuilder
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
Definition: LegalizerHelper.h:50
TypeFoldingSupportingOpcs
static const std::set< unsigned > TypeFoldingSupportingOpcs
Definition: SPIRVLegalizerInfo.cpp:27
llvm::IntegerType::get
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
llvm::LLT::scalar
static LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelTypeImpl.h:42
llvm::LegalizeActions
Definition: LegalizerInfo.h:42
llvm::LLT
Definition: LowLevelTypeImpl.h:39