LLVM  15.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GlobalISel LegalizerInfo rules for SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "SPIRVLegalizerInfo.h"
14 #include "SPIRV.h"
15 #include "SPIRVGlobalRegistry.h"
16 #include "SPIRVSubtarget.h"
22 
23 using namespace llvm;
24 using namespace llvm::LegalizeActions;
25 using namespace llvm::LegalityPredicates;
26 
27 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28  TargetOpcode::G_ADD,
29  TargetOpcode::G_FADD,
30  TargetOpcode::G_SUB,
31  TargetOpcode::G_FSUB,
32  TargetOpcode::G_MUL,
33  TargetOpcode::G_FMUL,
34  TargetOpcode::G_SDIV,
35  TargetOpcode::G_UDIV,
36  TargetOpcode::G_FDIV,
37  TargetOpcode::G_SREM,
38  TargetOpcode::G_UREM,
39  TargetOpcode::G_FREM,
40  TargetOpcode::G_FNEG,
41  TargetOpcode::G_CONSTANT,
42  TargetOpcode::G_FCONSTANT,
43  TargetOpcode::G_AND,
44  TargetOpcode::G_OR,
45  TargetOpcode::G_XOR,
46  TargetOpcode::G_SHL,
47  TargetOpcode::G_ASHR,
48  TargetOpcode::G_LSHR,
49  TargetOpcode::G_SELECT,
50  TargetOpcode::G_EXTRACT_VECTOR_ELT,
51 };
52 
53 bool isTypeFoldingSupported(unsigned Opcode) {
54  return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55 }
56 
58  using namespace TargetOpcode;
59 
60  this->ST = &ST;
61  GR = ST.getSPIRVGlobalRegistry();
62 
63  const LLT s1 = LLT::scalar(1);
64  const LLT s8 = LLT::scalar(8);
65  const LLT s16 = LLT::scalar(16);
66  const LLT s32 = LLT::scalar(32);
67  const LLT s64 = LLT::scalar(64);
68 
69  const LLT v16s64 = LLT::fixed_vector(16, 64);
70  const LLT v16s32 = LLT::fixed_vector(16, 32);
71  const LLT v16s16 = LLT::fixed_vector(16, 16);
72  const LLT v16s8 = LLT::fixed_vector(16, 8);
73  const LLT v16s1 = LLT::fixed_vector(16, 1);
74 
75  const LLT v8s64 = LLT::fixed_vector(8, 64);
76  const LLT v8s32 = LLT::fixed_vector(8, 32);
77  const LLT v8s16 = LLT::fixed_vector(8, 16);
78  const LLT v8s8 = LLT::fixed_vector(8, 8);
79  const LLT v8s1 = LLT::fixed_vector(8, 1);
80 
81  const LLT v4s64 = LLT::fixed_vector(4, 64);
82  const LLT v4s32 = LLT::fixed_vector(4, 32);
83  const LLT v4s16 = LLT::fixed_vector(4, 16);
84  const LLT v4s8 = LLT::fixed_vector(4, 8);
85  const LLT v4s1 = LLT::fixed_vector(4, 1);
86 
87  const LLT v3s64 = LLT::fixed_vector(3, 64);
88  const LLT v3s32 = LLT::fixed_vector(3, 32);
89  const LLT v3s16 = LLT::fixed_vector(3, 16);
90  const LLT v3s8 = LLT::fixed_vector(3, 8);
91  const LLT v3s1 = LLT::fixed_vector(3, 1);
92 
93  const LLT v2s64 = LLT::fixed_vector(2, 64);
94  const LLT v2s32 = LLT::fixed_vector(2, 32);
95  const LLT v2s16 = LLT::fixed_vector(2, 16);
96  const LLT v2s8 = LLT::fixed_vector(2, 8);
97  const LLT v2s1 = LLT::fixed_vector(2, 1);
98 
99  const unsigned PSize = ST.getPointerSize();
100  const LLT p0 = LLT::pointer(0, PSize); // Function
101  const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
102  const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
103  const LLT p3 = LLT::pointer(3, PSize); // Workgroup
104  const LLT p4 = LLT::pointer(4, PSize); // Generic
105  const LLT p5 = LLT::pointer(5, PSize); // Input
106 
107  // TODO: remove copy-pasting here by using concatenation in some way.
108  auto allPtrsScalarsAndVectors = {
109  p0, p1, p2, p3, p4, p5, s1, s8, s16,
110  s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111  v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
112  v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
113 
114  auto allScalarsAndVectors = {
115  s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
116  v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
117  v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
118 
119  auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
120  v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
121  v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
122  v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
123 
124  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
125 
126  auto allIntScalars = {s8, s16, s32, s64};
127 
128  auto allFloatScalarsAndVectors = {
129  s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130  v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
131 
132  auto allFloatAndIntScalars = allIntScalars;
133 
134  auto allPtrs = {p0, p1, p2, p3, p4, p5};
135  auto allWritablePtrs = {p0, p1, p3, p4};
136 
137  for (auto Opc : TypeFoldingSupportingOpcs)
138  getActionDefinitionsBuilder(Opc).custom();
139 
140  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
141 
142  // TODO: add proper rules for vectors legalization.
143  getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
144 
145  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
146  .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
147 
148  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
149  .legalForCartesianProduct(allPtrs, allPtrs);
150 
151  getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
152 
153  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
154 
155  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
156 
157  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
158  .legalForCartesianProduct(allIntScalarsAndVectors,
159  allFloatScalarsAndVectors);
160 
161  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
162  .legalForCartesianProduct(allFloatScalarsAndVectors,
163  allScalarsAndVectors);
164 
165  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
166  .legalFor(allIntScalarsAndVectors);
167 
168  getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
169  allIntScalarsAndVectors, allIntScalarsAndVectors);
170 
171  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
172 
173  getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
174  typeInSet(0, allPtrsScalarsAndVectors),
175  typeInSet(1, allPtrsScalarsAndVectors),
176  LegalityPredicate(([=](const LegalityQuery &Query) {
177  return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
178  }))));
179 
180  getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
181 
182  getActionDefinitionsBuilder(G_INTTOPTR)
183  .legalForCartesianProduct(allPtrs, allIntScalars);
184  getActionDefinitionsBuilder(G_PTRTOINT)
185  .legalForCartesianProduct(allIntScalars, allPtrs);
186  getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
187  allPtrs, allIntScalars);
188 
189  // ST.canDirectlyComparePointers() for pointer args is supported in
190  // legalizeCustom().
191  getActionDefinitionsBuilder(G_ICMP).customIf(
192  all(typeInSet(0, allBoolScalarsAndVectors),
193  typeInSet(1, allPtrsScalarsAndVectors)));
194 
195  getActionDefinitionsBuilder(G_FCMP).legalIf(
196  all(typeInSet(0, allBoolScalarsAndVectors),
197  typeInSet(1, allFloatScalarsAndVectors)));
198 
199  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
200  G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
201  G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
202  G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
203  .legalForCartesianProduct(allIntScalars, allWritablePtrs);
204 
205  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
206  .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
207 
208  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
209  // TODO: add proper legalization rules.
210  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
211 
212  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
213  .alwaysLegal();
214 
215  // Extensions.
216  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
217  .legalForCartesianProduct(allScalarsAndVectors);
218 
219  // FP conversions.
220  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
221  .legalForCartesianProduct(allFloatScalarsAndVectors);
222 
223  // Pointer-handling.
224  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
225 
226  // Control-flow.
227  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
228 
229  getActionDefinitionsBuilder({G_FPOW,
230  G_FEXP,
231  G_FEXP2,
232  G_FLOG,
233  G_FLOG2,
234  G_FABS,
235  G_FMINNUM,
236  G_FMAXNUM,
237  G_FCEIL,
238  G_FCOS,
239  G_FSIN,
240  G_FSQRT,
241  G_FFLOOR,
242  G_FRINT,
243  G_FNEARBYINT,
244  G_INTRINSIC_ROUND,
245  G_INTRINSIC_TRUNC,
246  G_FMINIMUM,
247  G_FMAXIMUM,
248  G_INTRINSIC_ROUNDEVEN})
249  .legalFor(allFloatScalarsAndVectors);
250 
251  getActionDefinitionsBuilder(G_FCOPYSIGN)
252  .legalForCartesianProduct(allFloatScalarsAndVectors,
253  allFloatScalarsAndVectors);
254 
255  getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
256  allFloatScalarsAndVectors, allIntScalarsAndVectors);
257 
258  getLegacyLegalizerInfo().computeTables();
259  verify(*ST.getInstrInfo());
260 }
261 
262 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
263  LegalizerHelper &Helper,
265  SPIRVGlobalRegistry *GR) {
266  Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
267  GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());
268  Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
269  .addDef(ConvReg)
270  .addUse(Reg);
271  return ConvReg;
272 }
273 
275  MachineInstr &MI) const {
276  auto Opc = MI.getOpcode();
277  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
278  if (!isTypeFoldingSupported(Opc)) {
279  assert(Opc == TargetOpcode::G_ICMP);
280  assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
281  auto &Op0 = MI.getOperand(2);
282  auto &Op1 = MI.getOperand(3);
283  Register Reg0 = Op0.getReg();
284  Register Reg1 = Op1.getReg();
286  static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
287  if ((!ST->canDirectlyComparePointers() ||
289  MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
290  LLT ConvT = LLT::scalar(ST->getPointerSize());
291  Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
292  ST->getPointerSize());
293  SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
294  Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
295  Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
296  }
297  return true;
298  }
299  // TODO: implement legalization for other opcodes.
300  return true;
301 }
MI
IRTranslator LLVM IR MI
Definition: IRTranslator.cpp:104
MachineInstr.h
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:17
SPIRVLegalizerInfo.h
llvm::CmpInst::ICMP_EQ
@ ICMP_EQ
equal
Definition: InstrTypes.h:740
llvm::CmpInst::Predicate
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:719
llvm::MachineRegisterInfo
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Definition: MachineRegisterInfo.h:50
llvm::X86Disassembler::Reg
Reg
All possible values of the reg field in the ModR/M byte.
Definition: X86DisassemblerDecoder.h:462
llvm::CmpInst::ICMP_NE
@ ICMP_NE
not equal
Definition: InstrTypes.h:741
llvm::Type
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
llvm::SPIRVGlobalRegistry::assignSPIRVTypeToVReg
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
Definition: SPIRVGlobalRegistry.cpp:37
llvm::SPIRVSubtarget
Definition: SPIRVSubtarget.h:36
convertPtrToInt
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
Definition: SPIRVLegalizerInfo.cpp:262
MachineIRBuilder.h
SPIRVSubtarget.h
llvm::LegalizerHelper
Definition: LegalizerHelper.h:46
MachineRegisterInfo.h
llvm::LLT::fixed_vector
static LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
Definition: LowLevelTypeImpl.h:74
llvm::MachineInstrBuilder::addDef
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Definition: MachineInstrBuilder.h:116
TargetOpcodes.h
llvm::MachineIRBuilder::getMF
MachineFunction & getMF()
Getter for the function we currently build.
Definition: MachineIRBuilder.h:269
llvm::SPIRVLegalizerInfo::legalizeCustom
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override
Called for instructions with the Custom LegalizationAction.
Definition: SPIRVLegalizerInfo.cpp:274
llvm::LegalityPredicates
Definition: LegalizerInfo.h:203
isTypeFoldingSupported
bool isTypeFoldingSupported(unsigned Opcode)
Definition: SPIRVLegalizerInfo.cpp:53
llvm::LLT::pointer
static LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
Definition: LowLevelTypeImpl.h:49
llvm::SPIRVLegalizerInfo::SPIRVLegalizerInfo
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
Definition: SPIRVLegalizerInfo.cpp:57
llvm::MachineInstr
Representation of each machine instruction.
Definition: MachineInstr.h:66
SPIRVGlobalRegistry.h
llvm::ARM_MB::ST
@ ST
Definition: ARMBaseInfo.h:73
llvm::LegalityPredicates::all
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
Definition: LegalizerInfo.h:228
SPIRV.h
assert
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
llvm::LLT::isPointer
bool isPointer() const
Definition: LowLevelTypeImpl.h:120
llvm::LegalityPredicate
std::function< bool(const LegalityQuery &)> LegalityPredicate
Definition: LegalizerInfo.h:199
llvm::LegalityPredicates::typeInSet
LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
Definition: LegalityPredicates.cpp:34
llvm::MachineRegisterInfo::createGenericVirtualRegister
Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
Definition: MachineRegisterInfo.cpp:186
llvm::SPIRVGlobalRegistry
Definition: SPIRVGlobalRegistry.h:26
llvm::MachineInstrBuilder::addUse
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
Definition: MachineInstrBuilder.h:123
llvm::MachineIRBuilder::buildInstr
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
Definition: MachineIRBuilder.h:374
llvm::LegalityQuery
The LegalityQuery object bundles together all the information that's needed to decide whether a given...
Definition: LegalizerInfo.h:108
Cond
SmallVector< MachineOperand, 4 > Cond
Definition: BasicBlockSections.cpp:137
MRI
unsigned const MachineRegisterInfo * MRI
Definition: AArch64AdvSIMDScalarPass.cpp:105
llvm::Register
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
verify
ppc ctr loops verify
Definition: PPCCTRLoopsVerify.cpp:76
LegalizerHelper.h
s1
int s1
Definition: README.txt:182
llvm::MachineRegisterInfo::getType
LLT getType(Register Reg) const
Get the low-level type of Reg or LLT{} if Reg is not a generic (target independent) virtual register.
Definition: MachineRegisterInfo.h:740
llvm::LegalityQuery::Types
ArrayRef< LLT > Types
Definition: LegalizerInfo.h:110
llvm::LegalizerHelper::MIRBuilder
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
Definition: LegalizerHelper.h:50
TypeFoldingSupportingOpcs
static const std::set< unsigned > TypeFoldingSupportingOpcs
Definition: SPIRVLegalizerInfo.cpp:27
llvm::IntegerType::get
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:311
llvm::LLT::scalar
static LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelTypeImpl.h:42
llvm::LegalizeActions
Definition: LegalizerInfo.h:42
llvm::LLT
Definition: LowLevelTypeImpl.h:39