LLVM 23.0.0git
SPIRVLegalizerInfo.cpp
Go to the documentation of this file.
1//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIR-V target's LegalizerInfo, i.e. the legalization rules used by the GlobalISel Legalizer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVLegalizerInfo.h"
14#include "SPIRV.h"
15#include "SPIRVGlobalRegistry.h"
16#include "SPIRVSubtarget.h"
17#include "SPIRVUtils.h"
24#include "llvm/IR/IntrinsicsSPIRV.h"
25#include "llvm/Support/Debug.h"
27
28using namespace llvm;
29using namespace llvm::LegalizeActions;
30using namespace llvm::LegalityPredicates;
31
32#define DEBUG_TYPE "spirv-legalizer"
33
34LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) {
35 return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) {
36 const LLT Ty = Query.Types[TypeIdx];
37 return IsExtendedInts && Ty.isValid() && Ty.isScalar();
38 };
39}
40
42 using namespace TargetOpcode;
43
44 this->ST = &ST;
45 GR = ST.getSPIRVGlobalRegistry();
46
47 const LLT s1 = LLT::scalar(1);
48 const LLT s8 = LLT::scalar(8);
49 const LLT s16 = LLT::scalar(16);
50 const LLT s32 = LLT::scalar(32);
51 const LLT s64 = LLT::scalar(64);
52 const LLT s128 = LLT::scalar(128);
53
54 const LLT v16s64 = LLT::fixed_vector(16, 64);
55 const LLT v16s32 = LLT::fixed_vector(16, 32);
56 const LLT v16s16 = LLT::fixed_vector(16, 16);
57 const LLT v16s8 = LLT::fixed_vector(16, 8);
58 const LLT v16s1 = LLT::fixed_vector(16, 1);
59
60 const LLT v8s64 = LLT::fixed_vector(8, 64);
61 const LLT v8s32 = LLT::fixed_vector(8, 32);
62 const LLT v8s16 = LLT::fixed_vector(8, 16);
63 const LLT v8s8 = LLT::fixed_vector(8, 8);
64 const LLT v8s1 = LLT::fixed_vector(8, 1);
65
66 const LLT v4s64 = LLT::fixed_vector(4, 64);
67 const LLT v4s32 = LLT::fixed_vector(4, 32);
68 const LLT v4s16 = LLT::fixed_vector(4, 16);
69 const LLT v4s8 = LLT::fixed_vector(4, 8);
70 const LLT v4s1 = LLT::fixed_vector(4, 1);
71
72 const LLT v3s64 = LLT::fixed_vector(3, 64);
73 const LLT v3s32 = LLT::fixed_vector(3, 32);
74 const LLT v3s16 = LLT::fixed_vector(3, 16);
75 const LLT v3s8 = LLT::fixed_vector(3, 8);
76 const LLT v3s1 = LLT::fixed_vector(3, 1);
77
78 const LLT v2s64 = LLT::fixed_vector(2, 64);
79 const LLT v2s32 = LLT::fixed_vector(2, 32);
80 const LLT v2s16 = LLT::fixed_vector(2, 16);
81 const LLT v2s8 = LLT::fixed_vector(2, 8);
82 const LLT v2s1 = LLT::fixed_vector(2, 1);
83
84 const unsigned PSize = ST.getPointerSize();
85 const LLT p0 = LLT::pointer(0, PSize); // Function
86 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
87 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
88 const LLT p3 = LLT::pointer(3, PSize); // Workgroup
89 const LLT p4 = LLT::pointer(4, PSize); // Generic
90 const LLT p5 =
91 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
92 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
93 const LLT p7 = LLT::pointer(7, PSize); // Input
94 const LLT p8 = LLT::pointer(8, PSize); // Output
95 const LLT p9 =
96 LLT::pointer(9, PSize); // CodeSectionINTEL, SPV_INTEL_function_pointers
97 const LLT p10 = LLT::pointer(10, PSize); // Private
98 const LLT p11 = LLT::pointer(11, PSize); // StorageBuffer
99 const LLT p12 = LLT::pointer(12, PSize); // Uniform
100 const LLT p13 = LLT::pointer(13, PSize); // PushConstant
101
102 // TODO: remove copy-pasting here by using concatenation in some way.
103 auto allPtrsScalarsAndVectors = {
104 p0, p1, p2, p3, p4, p5, p6, p7, p8,
105 p9, p10, p11, p12, p13, s1, s8, s16, s32,
106 s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
107 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8,
108 v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
109
110 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
112 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
113 v16s8, v16s16, v16s32, v16s64};
114
115 auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64,
116 v3s1, v3s8, v3s16, v3s32, v3s64,
117 v4s1, v4s8, v4s16, v4s32, v4s64};
118
119 auto allScalars = {s1, s8, s16, s32, s64};
120
121 auto allScalarsAndVectors = {
122 s1, s8, s16, s32, s64, s128, v2s1, v2s8,
123 v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64,
124 v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
125 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
126
127 auto allIntScalarsAndVectors = {
128 s8, s16, s32, s64, s128, v2s8, v2s16, v2s32, v2s64,
129 v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8,
130 v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
131
132 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
133
134 auto allIntScalars = {s8, s16, s32, s64, s128};
135
136 auto allFloatScalarsAndF16Vector2AndVector4s = {s16, s32, s64, v2s16, v4s16};
137
138 auto allFloatScalarsAndVectors = {
139 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
140 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
141
142 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
143 p2, p3, p4, p5, p6, p7,
144 p8, p9, p10, p11, p12, p13};
145
146 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13};
147
148 auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors;
149
150 bool IsExtendedInts =
151 ST.canUseExtension(
152 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers) ||
153 ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
154 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
155 auto extendedScalarsAndVectors =
156 [IsExtendedInts](const LegalityQuery &Query) {
157 const LLT Ty = Query.Types[0];
158 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
159 };
160 auto extendedScalarsAndVectorsProduct = [IsExtendedInts](
161 const LegalityQuery &Query) {
162 const LLT Ty1 = Query.Types[0], Ty2 = Query.Types[1];
163 return IsExtendedInts && Ty1.isValid() && Ty2.isValid() &&
164 !Ty1.isPointerOrPointerVector() && !Ty2.isPointerOrPointerVector();
165 };
166 auto extendedPtrsScalarsAndVectors =
167 [IsExtendedInts](const LegalityQuery &Query) {
168 const LLT Ty = Query.Types[0];
169 return IsExtendedInts && Ty.isValid();
170 };
171
172 // The universal validation rules in the SPIR-V specification state that
173 // vector sizes are typically limited to 2, 3, or 4. However, larger vector
174 // sizes (8 and 16) are enabled when the Kernel capability is present. For
175 // shader execution models, vector sizes are strictly limited to 4. In
176 // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
177 // arbitrary sizes (e.g., 6 or 11) are not.
178 uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
179 LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");
180
181 for (auto Opc : getTypeFoldingSupportedOpcodes()) {
182 switch (Opc) {
183 case G_EXTRACT_VECTOR_ELT:
184 case G_UREM:
185 case G_SREM:
186 case G_UDIV:
187 case G_SDIV:
188 case G_FREM:
189 break;
190 default:
192 .customFor(allScalars)
193 .customFor(allowedVectorTypes)
197 0, ElementCount::getFixed(MaxVectorSize)))
198 .custom();
199 break;
200 }
201 }
202
203 getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
204 .customFor(allScalars)
205 .customFor(allowedVectorTypes)
209 0, ElementCount::getFixed(MaxVectorSize)))
210 .custom();
211
212 getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
213 .legalFor(allScalars)
214 .legalFor(allowedVectorTypes)
218 0, ElementCount::getFixed(MaxVectorSize)))
219 .alwaysLegal();
220
221 getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
222
223 getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
224 .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
226 .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
228 .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize));
229
230 getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
234 1, ElementCount::getFixed(MaxVectorSize)))
235 .custom();
236
237 getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
241 0, ElementCount::getFixed(MaxVectorSize)))
242 .custom();
243
244 // Illegal G_UNMERGE_VALUES instructions should be handled
245 // during the combine phase.
246 getActionDefinitionsBuilder(G_BUILD_VECTOR)
248
249 // When entering the legalizer, there should be no G_BITCAST instructions.
250 // They should all be calls to the `spv_bitcast` intrinsic. The call to
251 // the intrinsic will be converted to a G_BITCAST during legalization if
252 // the vectors are not legal. After using the rules to legalize a G_BITCAST,
253 // we turn it back into a call to the intrinsic with a custom rule to avoid
254 // potential machine verifier failures.
260 0, ElementCount::getFixed(MaxVectorSize)))
261 .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
262 .custom();
263
264 // If the result is still illegal, the combiner should be able to remove it.
265 getActionDefinitionsBuilder(G_CONCAT_VECTORS)
266 .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes);
267
268 getActionDefinitionsBuilder(G_SPLAT_VECTOR)
269 .legalFor(allowedVectorTypes)
273 .alwaysLegal();
274
275 // Vector Reduction Operations
277 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
278 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
279 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
280 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
281 .legalFor(allowedVectorTypes)
282 .scalarize(1)
283 .lower();
284
285 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
286 .scalarize(2)
287 .lower();
288
289 // Illegal G_UNMERGE_VALUES instructions should be handled
290 // during the combine phase.
291 getActionDefinitionsBuilder(G_UNMERGE_VALUES)
293
294 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
295 .unsupportedIf(LegalityPredicates::any(typeIs(0, p9), typeIs(1, p9)))
296 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs)));
297
299 .unsupportedIf(typeIs(0, p9))
300 .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allIntScalars)));
301
302 getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
305 all(typeIsNot(0, p9), typeIs(1, p9))))
306 .legalForCartesianProduct(allPtrs, allPtrs);
307
308 // Should we be legalizing bad scalar sizes like s5 here instead
309 // of handling them in the instruction selector?
310 getActionDefinitionsBuilder({G_LOAD, G_STORE})
311 .unsupportedIf(typeIs(1, p9))
312 .legalForCartesianProduct(allowedVectorTypes, allPtrs)
313 .legalForCartesianProduct(allPtrs, allPtrs)
314 .legalIf(isScalar(0))
315 .custom();
316
317 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS,
318 G_BITREVERSE, G_SADDSAT, G_UADDSAT, G_SSUBSAT,
319 G_USUBSAT, G_SCMP, G_UCMP})
320 .legalFor(allIntScalarsAndVectors)
321 .legalIf(extendedScalarsAndVectors);
322
323 getActionDefinitionsBuilder(G_STRICT_FLDEXP)
324 .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
325
326 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
327 .legalForCartesianProduct(allIntScalarsAndVectors,
328 allFloatScalarsAndVectors);
329
330 getActionDefinitionsBuilder({G_FPTOSI_SAT, G_FPTOUI_SAT})
331 .legalForCartesianProduct(allIntScalarsAndVectors,
332 allFloatScalarsAndVectors);
333
334 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
335 .legalForCartesianProduct(allFloatScalarsAndVectors,
336 allScalarsAndVectors);
337
339 .legalForCartesianProduct(allIntScalarsAndVectors)
340 .legalIf(extendedScalarsAndVectorsProduct);
341
342 // Extensions.
343 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
344 .legalForCartesianProduct(allScalarsAndVectors)
345 .legalIf(extendedScalarsAndVectorsProduct);
346
348 .legalFor(allPtrsScalarsAndVectors)
349 .legalIf(extendedPtrsScalarsAndVectors);
350
352 all(typeInSet(0, allPtrsScalarsAndVectors),
353 typeInSet(1, allPtrsScalarsAndVectors)));
354
355 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE})
356 .legalFor({s1, s128})
357 .legalFor(allFloatAndIntScalarsAndPtrs)
358 .legalFor(allowedVectorTypes)
362 0, ElementCount::getFixed(MaxVectorSize)));
363
364 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();
365
367 .legalForCartesianProduct(allPtrs, allIntScalars)
368 .legalIf(
369 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
371 .legalForCartesianProduct(allIntScalars, allPtrs)
372 .legalIf(
373 all(typeOfExtendedScalars(0, IsExtendedInts), typeInSet(1, allPtrs)));
375 .legalForCartesianProduct(allPtrs, allIntScalars)
376 .legalIf(
377 all(typeInSet(0, allPtrs), typeOfExtendedScalars(1, IsExtendedInts)));
378
379 // ST.canDirectlyComparePointers() for pointer args is supported in
380 // legalizeCustom().
383 all(typeIs(0, p9), typeInSet(1, allPtrs), typeIsNot(1, p9)),
384 all(typeInSet(0, allPtrs), typeIsNot(0, p9), typeIs(1, p9))))
385 .legalIf([IsExtendedInts](const LegalityQuery &Query) {
386 const LLT Ty = Query.Types[1];
387 return IsExtendedInts && Ty.isValid() && !Ty.isPointerOrPointerVector();
388 })
389 .customIf(all(typeInSet(0, allBoolScalarsAndVectors),
390 typeInSet(1, allPtrsScalarsAndVectors)));
391
393 all(typeInSet(0, allBoolScalarsAndVectors),
394 typeInSet(1, allFloatScalarsAndVectors)));
395
396 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
397 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
398 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
399 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
400 .legalForCartesianProduct(allIntScalars, allPtrs);
401
403 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
404 .legalForCartesianProduct(allFloatScalarsAndF16Vector2AndVector4s,
405 allPtrs);
406
407 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
408 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allPtrs);
409
410 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
411 // TODO: add proper legalization rules.
412 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
413
415 {G_UADDO, G_SADDO, G_USUBO, G_SSUBO, G_UMULO, G_SMULO})
416 .alwaysLegal();
417
418 getActionDefinitionsBuilder({G_LROUND, G_LLROUND})
419 .legalForCartesianProduct(allFloatScalarsAndVectors,
420 allIntScalarsAndVectors);
421
422 // FP conversions.
423 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
424 .legalForCartesianProduct(allFloatScalarsAndVectors);
425
426 // Pointer-handling.
427 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
428
429 getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs);
430
431 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
432 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});
433
435 allFloatScalarsAndVectors, {s32, v2s32, v3s32, v4s32, v8s32, v16s32});
436
437 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
438 // tighten these requirements. Many of these math functions are only legal on
439 // specific bitwidths, so they are not selectable for
440 // allFloatScalarsAndVectors.
441 // clang-format off
442 getActionDefinitionsBuilder({G_STRICT_FSQRT,
443 G_FPOW,
444 G_FEXP,
445 G_FMODF,
446 G_FSINCOS,
447 G_FEXP2,
448 G_FEXP10,
449 G_FLOG,
450 G_FLOG2,
451 G_FLOG10,
452 G_FABS,
453 G_FMINNUM,
454 G_FMAXNUM,
455 G_FCEIL,
456 G_FCOS,
457 G_FSIN,
458 G_FTAN,
459 G_FACOS,
460 G_FASIN,
461 G_FATAN,
462 G_FATAN2,
463 G_FCOSH,
464 G_FSINH,
465 G_FTANH,
466 G_FSQRT,
467 G_FFLOOR,
468 G_FRINT,
469 G_FNEARBYINT,
470 G_INTRINSIC_ROUND,
471 G_INTRINSIC_TRUNC,
472 G_FMINIMUM,
473 G_FMAXIMUM,
474 G_INTRINSIC_ROUNDEVEN})
475 .legalFor(allFloatScalarsAndVectors);
476 // clang-format on
477
478 getActionDefinitionsBuilder(G_FCOPYSIGN)
479 .legalForCartesianProduct(allFloatScalarsAndVectors,
480 allFloatScalarsAndVectors);
481
483 allFloatScalarsAndVectors, allIntScalarsAndVectors);
484
485 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
487 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
488 .legalForCartesianProduct(allIntScalarsAndVectors,
489 allIntScalarsAndVectors);
490
491 // Struct return types become a single scalar, so cannot easily legalize.
492 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
493 }
494
495 getActionDefinitionsBuilder(G_IS_FPCLASS).custom();
496
498 verify(*ST.getInstrInfo());
499}
500
503 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
504 Register DstReg = MI.getOperand(0).getReg();
505 Register SrcReg = MI.getOperand(1).getReg();
506 Register IdxReg = MI.getOperand(2).getReg();
507
508 MIRBuilder
509 .buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef<Register>{DstReg})
510 .addUse(SrcReg)
511 .addUse(IdxReg);
512 MI.eraseFromParent();
513 return true;
514}
515
518 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
519 Register DstReg = MI.getOperand(0).getReg();
520 Register SrcReg = MI.getOperand(1).getReg();
521 Register ValReg = MI.getOperand(2).getReg();
522 Register IdxReg = MI.getOperand(3).getReg();
523
524 MIRBuilder
525 .buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
526 .addUse(SrcReg)
527 .addUse(ValReg)
528 .addUse(IdxReg);
529 MI.eraseFromParent();
530 return true;
531}
532
534 LegalizerHelper &Helper,
537 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
538 MRI.setRegClass(ConvReg, GR->getRegClass(SpvType));
539 GR->assignSPIRVTypeToVReg(SpvType, ConvReg, Helper.MIRBuilder.getMF());
540 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
541 .addDef(ConvReg)
542 .addUse(Reg);
543 return ConvReg;
544}
545
546static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
547 if (!Ty.isVector())
548 return false;
549 unsigned NumElements = Ty.getNumElements();
550 unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
551 return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
552 NumElements > MaxVectorSize;
553}
554
557 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
558 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
559 Register DstReg = MI.getOperand(0).getReg();
560 Register PtrReg = MI.getOperand(1).getReg();
561 LLT DstTy = MRI.getType(DstReg);
562
563 if (!DstTy.isVector())
564 return true;
565
566 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
567 if (!needsVectorLegalization(DstTy, ST))
568 return true;
569
570 SmallVector<Register, 8> SplitRegs;
571 LLT EltTy = DstTy.getElementType();
572 unsigned NumElts = DstTy.getNumElements();
573
574 LLT PtrTy = MRI.getType(PtrReg);
575 auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
576
577 for (unsigned i = 0; i < NumElts; ++i) {
578 auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
579 Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
580
581 MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
582 .addImm(1) // InBounds
583 .addUse(PtrReg)
584 .addUse(Zero.getReg(0))
585 .addUse(Idx.getReg(0));
586
587 MachinePointerInfo EltPtrInfo;
588 Align EltAlign = Align(1);
589 if (!MI.memoperands_empty()) {
590 MachineMemOperand *MMO = *MI.memoperands_begin();
591 EltPtrInfo =
592 MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
593 EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
594 }
595
596 Register EltReg = MRI.createGenericVirtualRegister(EltTy);
597 MIRBuilder.buildLoad(EltReg, EltPtr, EltPtrInfo, EltAlign);
598 SplitRegs.push_back(EltReg);
599 }
600
601 MIRBuilder.buildBuildVector(DstReg, SplitRegs);
602 MI.eraseFromParent();
603 return true;
604}
605
608 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
609 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
610 Register ValReg = MI.getOperand(0).getReg();
611 Register PtrReg = MI.getOperand(1).getReg();
612 LLT ValTy = MRI.getType(ValReg);
613
614 assert(ValTy.isVector() && "Expected vector store");
615
616 SmallVector<Register, 8> SplitRegs;
617 LLT EltTy = ValTy.getElementType();
618 unsigned NumElts = ValTy.getNumElements();
619
620 for (unsigned i = 0; i < NumElts; ++i)
621 SplitRegs.push_back(MRI.createGenericVirtualRegister(EltTy));
622
623 MIRBuilder.buildUnmerge(SplitRegs, ValReg);
624
625 LLT PtrTy = MRI.getType(PtrReg);
626 auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
627
628 for (unsigned i = 0; i < NumElts; ++i) {
629 auto Idx = MIRBuilder.buildConstant(LLT::scalar(32), i);
630 Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
631
632 MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
633 .addImm(1) // InBounds
634 .addUse(PtrReg)
635 .addUse(Zero.getReg(0))
636 .addUse(Idx.getReg(0));
637
638 MachinePointerInfo EltPtrInfo;
639 Align EltAlign = Align(1);
640 if (!MI.memoperands_empty()) {
641 MachineMemOperand *MMO = *MI.memoperands_begin();
642 EltPtrInfo =
643 MMO->getPointerInfo().getWithOffset(i * EltTy.getSizeInBytes());
644 EltAlign = commonAlignment(MMO->getAlign(), i * EltTy.getSizeInBytes());
645 }
646
647 MIRBuilder.buildStore(SplitRegs[i], EltPtr, EltPtrInfo, EltAlign);
648 }
649
650 MI.eraseFromParent();
651 return true;
652}
653
656 LostDebugLocObserver &LocObserver) const {
657 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
658 switch (MI.getOpcode()) {
659 default:
660 // TODO: implement legalization for other opcodes.
661 return true;
662 case TargetOpcode::G_BITCAST:
663 return legalizeBitcast(Helper, MI);
664 case TargetOpcode::G_EXTRACT_VECTOR_ELT:
665 return legalizeExtractVectorElt(Helper, MI, GR);
666 case TargetOpcode::G_INSERT_VECTOR_ELT:
667 return legalizeInsertVectorElt(Helper, MI, GR);
668 case TargetOpcode::G_INTRINSIC:
669 case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
670 return legalizeIntrinsic(Helper, MI);
671 case TargetOpcode::G_IS_FPCLASS:
672 return legalizeIsFPClass(Helper, MI, LocObserver);
673 case TargetOpcode::G_ICMP: {
674 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
675 auto &Op0 = MI.getOperand(2);
676 auto &Op1 = MI.getOperand(3);
677 Register Reg0 = Op0.getReg();
678 Register Reg1 = Op1.getReg();
680 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
681 if ((!ST->canDirectlyComparePointers() ||
683 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
684 LLT ConvT = LLT::scalar(ST->getPointerSize());
685 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
686 ST->getPointerSize());
687 SPIRVTypeInst SpirvTy = GR->getOrCreateSPIRVType(
688 LLVMTy, Helper.MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
689 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
690 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
691 }
692 return true;
693 }
694 case TargetOpcode::G_LOAD:
695 return legalizeLoad(Helper, MI, GR);
696 case TargetOpcode::G_STORE:
697 return legalizeStore(Helper, MI, GR);
698 }
699}
700
703 Register SrcReg, LLT SrcTy,
704 MachinePointerInfo &PtrInfo, Align &VecAlign) {
705 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
706 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
707
708 VecAlign = Helper.getStackTemporaryAlignment(SrcTy);
709 auto StackTemp = Helper.createStackTemporary(
710 TypeSize::getFixed(SrcTy.getSizeInBytes()), VecAlign, PtrInfo);
711
712 // Set the type of StackTemp to a pointer to an array of the element type.
713 SPIRVTypeInst SpvSrcTy = GR->getSPIRVTypeForVReg(SrcReg);
714 SPIRVTypeInst EltSpvTy = GR->getScalarOrVectorComponentType(SpvSrcTy);
715 const Type *LLVMEltTy = GR->getTypeForSPIRVType(EltSpvTy);
716 const Type *LLVMArrTy =
717 ArrayType::get(const_cast<Type *>(LLVMEltTy), SrcTy.getNumElements());
718 SPIRVTypeInst ArrSpvTy = GR->getOrCreateSPIRVType(
719 LLVMArrTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
720 SPIRVTypeInst PtrToArrSpvTy = GR->getOrCreateSPIRVPointerType(
721 ArrSpvTy, MIRBuilder, SPIRV::StorageClass::Function);
722
723 Register StackReg = StackTemp.getReg(0);
724 MRI.setRegClass(StackReg, GR->getRegClass(PtrToArrSpvTy));
725 GR->assignSPIRVTypeToVReg(PtrToArrSpvTy, StackReg, MIRBuilder.getMF());
726
727 return StackTemp;
728}
729
732 LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n");
733 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
734 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
735 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
736
737 Register DstReg = MI.getOperand(0).getReg();
738 Register SrcReg = MI.getOperand(2).getReg();
739 LLT DstTy = MRI.getType(DstReg);
740 LLT SrcTy = MRI.getType(SrcReg);
741
742 // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
743 // allow using the generic legalization rules.
744 if (needsVectorLegalization(DstTy, ST) ||
745 needsVectorLegalization(SrcTy, ST)) {
746 LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
747 MIRBuilder.buildBitcast(DstReg, SrcReg);
748 MI.eraseFromParent();
749 }
750 return true;
751}
752
755 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
756 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
757 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
758
759 Register DstReg = MI.getOperand(0).getReg();
760 LLT DstTy = MRI.getType(DstReg);
761
762 if (needsVectorLegalization(DstTy, ST)) {
763 Register SrcReg = MI.getOperand(2).getReg();
764 Register ValReg = MI.getOperand(3).getReg();
765 LLT SrcTy = MRI.getType(SrcReg);
766 MachineOperand &IdxOperand = MI.getOperand(4);
767
768 if (getImm(IdxOperand, &MRI)) {
769 uint64_t IdxVal = foldImm(IdxOperand, &MRI);
770 if (IdxVal < SrcTy.getNumElements()) {
772 SPIRVTypeInst ElementType =
774 LLT ElementLLTTy = GR->getRegType(ElementType);
775 for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
776 Register Reg = MRI.createGenericVirtualRegister(ElementLLTTy);
777 MRI.setRegClass(Reg, GR->getRegClass(ElementType));
778 GR->assignSPIRVTypeToVReg(ElementType, Reg, *MI.getMF());
779 Regs.push_back(Reg);
780 }
781 MIRBuilder.buildUnmerge(Regs, SrcReg);
782 Regs[IdxVal] = ValReg;
783 MIRBuilder.buildBuildVector(DstReg, Regs);
784 MI.eraseFromParent();
785 return true;
786 }
787 }
788
789 LLT EltTy = SrcTy.getElementType();
790 Align VecAlign;
791 MachinePointerInfo PtrInfo;
792 auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
793 PtrInfo, VecAlign);
794
795 MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
796
797 Register IdxReg = IdxOperand.getReg();
798 LLT PtrTy = MRI.getType(StackTemp.getReg(0));
799 Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
800 auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
801
802 MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
803 .addImm(1) // InBounds
804 .addUse(StackTemp.getReg(0))
805 .addUse(Zero.getReg(0))
806 .addUse(IdxReg);
807
809 Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
810 MIRBuilder.buildStore(ValReg, EltPtr, EltPtrInfo, EltAlign);
811
812 MIRBuilder.buildLoad(DstReg, StackTemp, PtrInfo, VecAlign);
813 MI.eraseFromParent();
814 return true;
815 }
816 return true;
817}
818
821 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
822 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
823 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
824
825 Register SrcReg = MI.getOperand(2).getReg();
826 LLT SrcTy = MRI.getType(SrcReg);
827
828 if (needsVectorLegalization(SrcTy, ST)) {
829 Register DstReg = MI.getOperand(0).getReg();
830 MachineOperand &IdxOperand = MI.getOperand(3);
831
832 if (getImm(IdxOperand, &MRI)) {
833 uint64_t IdxVal = foldImm(IdxOperand, &MRI);
834 if (IdxVal < SrcTy.getNumElements()) {
835 LLT DstTy = MRI.getType(DstReg);
837 SPIRVTypeInst DstSpvTy = GR->getSPIRVTypeForVReg(DstReg);
838 for (unsigned I = 0, E = SrcTy.getNumElements(); I < E; ++I) {
839 if (I == IdxVal) {
840 Regs.push_back(DstReg);
841 } else {
842 Register Reg = MRI.createGenericVirtualRegister(DstTy);
843 MRI.setRegClass(Reg, GR->getRegClass(DstSpvTy));
844 GR->assignSPIRVTypeToVReg(DstSpvTy, Reg, *MI.getMF());
845 Regs.push_back(Reg);
846 }
847 }
848 MIRBuilder.buildUnmerge(Regs, SrcReg);
849 MI.eraseFromParent();
850 return true;
851 }
852 }
853
854 LLT EltTy = SrcTy.getElementType();
855 Align VecAlign;
856 MachinePointerInfo PtrInfo;
857 auto StackTemp = createStackTemporaryForVector(Helper, GR, SrcReg, SrcTy,
858 PtrInfo, VecAlign);
859
860 MIRBuilder.buildStore(SrcReg, StackTemp, PtrInfo, VecAlign);
861
862 Register IdxReg = IdxOperand.getReg();
863 LLT PtrTy = MRI.getType(StackTemp.getReg(0));
864 Register EltPtr = MRI.createGenericVirtualRegister(PtrTy);
865 auto Zero = MIRBuilder.buildConstant(LLT::scalar(32), 0);
866
867 MIRBuilder.buildIntrinsic(Intrinsic::spv_gep, ArrayRef<Register>{EltPtr})
868 .addImm(1) // InBounds
869 .addUse(StackTemp.getReg(0))
870 .addUse(Zero.getReg(0))
871 .addUse(IdxReg);
872
874 Align EltAlign = Helper.getStackTemporaryAlignment(EltTy);
875 MIRBuilder.buildLoad(DstReg, EltPtr, EltPtrInfo, EltAlign);
876
877 MI.eraseFromParent();
878 return true;
879 }
880 return true;
881}
882
885 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
886 MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
887 const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
888
889 Register DstReg = MI.getOperand(0).getReg();
890 LLT DstTy = MRI.getType(DstReg);
891
892 if (!needsVectorLegalization(DstTy, ST))
893 return true;
894
896 if (MI.getNumOperands() == 2) {
897 // The "null" case: no values are attached.
898 LLT EltTy = DstTy.getElementType();
899 auto Zero = MIRBuilder.buildConstant(EltTy, 0);
900 SPIRVTypeInst SpvDstTy = GR->getSPIRVTypeForVReg(DstReg);
901 SPIRVTypeInst SpvEltTy = GR->getScalarOrVectorComponentType(SpvDstTy);
902 GR->assignSPIRVTypeToVReg(SpvEltTy, Zero.getReg(0), MIRBuilder.getMF());
903 for (unsigned i = 0; i < DstTy.getNumElements(); ++i)
904 SrcRegs.push_back(Zero.getReg(0));
905 } else {
906 for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
907 SrcRegs.push_back(MI.getOperand(i).getReg());
908 }
909 }
910 MIRBuilder.buildBuildVector(DstReg, SrcRegs);
911 MI.eraseFromParent();
912 return true;
913}
914
916 MachineInstr &MI) const {
917 LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
918 auto IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
919 switch (IntrinsicID) {
920 case Intrinsic::spv_bitcast:
921 return legalizeSpvBitcast(Helper, MI, GR);
922 case Intrinsic::spv_insertelt:
923 return legalizeSpvInsertElt(Helper, MI, GR);
924 case Intrinsic::spv_extractelt:
925 return legalizeSpvExtractElt(Helper, MI, GR);
926 case Intrinsic::spv_const_composite:
927 return legalizeSpvConstComposite(Helper, MI, GR);
928 }
929 return true;
930}
931
932bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper,
933 MachineInstr &MI) const {
934 // Once the G_BITCAST is using vectors that are allowed, we turn it back into
935 // an spv_bitcast to avoid verifier problems when the register types are the
936 // same for the source and the result. Note that the SPIR-V types associated
937 // with the bitcast can be different even if the register types are the same.
938 MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
939 Register DstReg = MI.getOperand(0).getReg();
940 Register SrcReg = MI.getOperand(1).getReg();
941 SmallVector<Register, 1> DstRegs = {DstReg};
942 MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg);
943 MI.eraseFromParent();
944 return true;
945}
946
947// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
948// to ensure that all instructions created during the lowering have SPIR-V types
949// assigned to them.
950bool SPIRVLegalizerInfo::legalizeIsFPClass(
952 LostDebugLocObserver &LocObserver) const {
953 auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
954 FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());
955
956 auto &MIRBuilder = Helper.MIRBuilder;
957 auto &MF = MIRBuilder.getMF();
958 MachineRegisterInfo &MRI = MF.getRegInfo();
959
960 Type *LLVMDstTy =
961 IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
962 if (DstTy.isVector())
963 LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
964 SPIRVTypeInst SPIRVDstTy = GR->getOrCreateSPIRVType(
965 LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
966 /*EmitIR*/ true);
967
968 unsigned BitSize = SrcTy.getScalarSizeInBits();
969 const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());
970
971 LLT IntTy = LLT::scalar(BitSize);
972 Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
973 if (SrcTy.isVector()) {
974 IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
975 LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
976 }
977 SPIRVTypeInst SPIRVIntTy = GR->getOrCreateSPIRVType(
978 LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
979 /*EmitIR*/ true);
980
981 // Clang doesn't support capture of structured bindings:
982 LLT DstTyCopy = DstTy;
983 const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
984 // Assign this MI's (assumed only) destination to one of the two types we
985 // expect: either the G_IS_FPCLASS's destination type, or the integer type
986 // bitcast from the source type.
987 LLT MITy = MRI.getType(MI.getReg(0));
988 assert((MITy == IntTy || MITy == DstTyCopy) &&
989 "Unexpected LLT type while lowering G_IS_FPCLASS");
990 SPIRVTypeInst SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
991 GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
992 return MI;
993 };
994
995 // Helper to build and assign a constant in one go
996 const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) -> MachineInstrBuilder {
997 if (!Ty.isFixedVector())
998 return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
999 auto ScalarC = MIRBuilder.buildConstant(Ty.getScalarType(), C);
1000 assert((Ty == IntTy || Ty == DstTyCopy) &&
1001 "Unexpected LLT type while lowering constant for G_IS_FPCLASS");
1002 SPIRVTypeInst VecEltTy = GR->getOrCreateSPIRVType(
1003 (Ty == IntTy ? LLVMIntTy : LLVMDstTy)->getScalarType(), MIRBuilder,
1004 SPIRV::AccessQualifier::ReadWrite,
1005 /*EmitIR*/ true);
1006 GR->assignSPIRVTypeToVReg(VecEltTy, ScalarC.getReg(0), MF);
1007 return assignSPIRVTy(MIRBuilder.buildSplatBuildVector(Ty, ScalarC));
1008 };
1009
1010 if (Mask == fcNone) {
1011 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 0));
1012 MI.eraseFromParent();
1013 return true;
1014 }
1015 if (Mask == fcAllFlags) {
1016 MIRBuilder.buildCopy(DstReg, buildSPIRVConstant(DstTy, 1));
1017 MI.eraseFromParent();
1018 return true;
1019 }
1020
1021 // Note that rather than creating a COPY here (between a floating-point and
1022 // integer type of the same size) we create a SPIR-V bitcast immediately. We
1023 // can't create a G_BITCAST because the LLTs are the same, and we can't seem
1024 // to correctly lower COPYs to SPIR-V bitcasts at this moment.
1025 Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
1026 MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
1027 GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
1028 auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
1029 .addDef(ResVReg)
1030 .addUse(GR->getSPIRVTypeID(SPIRVIntTy))
1031 .addUse(SrcReg);
1032 AsInt = assignSPIRVTy(std::move(AsInt));
1033
1034 // Various masks.
1035 APInt SignBit = APInt::getSignMask(BitSize);
1036 APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
1037 APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
1038 APInt ExpMask = Inf;
1039 APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
1040 APInt QNaNBitMask =
1041 APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
1042 APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());
1043
1044 auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
1045 auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
1046 auto InfC = buildSPIRVConstant(IntTy, Inf);
1047 auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
1048 auto ZeroC = buildSPIRVConstant(IntTy, 0);
1049
1050 auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
1051 auto Sign = assignSPIRVTy(
1052 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));
1053
1054 auto Res = buildSPIRVConstant(DstTy, 0);
1055
1056 const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
1057 Res = assignSPIRVTy(
1058 MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
1059 };
1060
1061 // Tests that involve more than one class should be processed first.
1062 if ((Mask & fcFinite) == fcFinite) {
1063 // finite(V) ==> abs(V) u< exp_mask
1064 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
1065 ExpMaskC));
1066 Mask &= ~fcFinite;
1067 } else if ((Mask & fcFinite) == fcPosFinite) {
1068 // finite(V) && V > 0 ==> V u< exp_mask
1069 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
1070 ExpMaskC));
1071 Mask &= ~fcPosFinite;
1072 } else if ((Mask & fcFinite) == fcNegFinite) {
1073 // finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
1074 auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
1075 DstTy, Abs, ExpMaskC));
1076 appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
1077 Mask &= ~fcNegFinite;
1078 }
1079
1080 if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
1081 // fcZero | fcSubnormal => test all exponent bits are 0
1082 // TODO: Handle sign bit specific cases
1083 // TODO: Handle inverted case
1084 if (PartialCheck == (fcZero | fcSubnormal)) {
1085 auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
1086 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
1087 ExpBits, ZeroC));
1088 Mask &= ~PartialCheck;
1089 }
1090 }
1091
1092 // Check for individual classes.
1093 if (FPClassTest PartialCheck = Mask & fcZero) {
1094 if (PartialCheck == fcPosZero)
1095 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
1096 AsInt, ZeroC));
1097 else if (PartialCheck == fcZero)
1098 appendToRes(
1099 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
1100 else // fcNegZero
1101 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
1102 AsInt, SignBitC));
1103 }
1104
1105 if (FPClassTest PartialCheck = Mask & fcSubnormal) {
1106 // issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
1107 // issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
1108 auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
1109 auto OneC = buildSPIRVConstant(IntTy, 1);
1110 auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
1111 auto SubnormalRes = assignSPIRVTy(
1112 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
1113 buildSPIRVConstant(IntTy, AllOneMantissa)));
1114 if (PartialCheck == fcNegSubnormal)
1115 SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
1116 appendToRes(std::move(SubnormalRes));
1117 }
1118
1119 if (FPClassTest PartialCheck = Mask & fcInf) {
1120 if (PartialCheck == fcPosInf)
1121 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
1122 AsInt, InfC));
1123 else if (PartialCheck == fcInf)
1124 appendToRes(
1125 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
1126 else { // fcNegInf
1127 APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
1128 auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
1129 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
1130 AsInt, NegInfC));
1131 }
1132 }
1133
1134 if (FPClassTest PartialCheck = Mask & fcNan) {
1135 auto InfWithQnanBitC =
1136 buildSPIRVConstant(IntTy, std::move(Inf) | QNaNBitMask);
1137 if (PartialCheck == fcNan) {
1138 // isnan(V) ==> abs(V) u> int(inf)
1139 appendToRes(
1140 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
1141 } else if (PartialCheck == fcQNan) {
1142 // isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
1143 appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
1144 InfWithQnanBitC));
1145 } else { // fcSNan
1146 // issignaling(V) ==> abs(V) u> unsigned(Inf) &&
1147 // abs(V) u< (unsigned(Inf) | quiet_bit)
1148 auto IsNan = assignSPIRVTy(
1149 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
1150 auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
1151 CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
1152 appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
1153 }
1154 }
1155
1156 if (FPClassTest PartialCheck = Mask & fcNormal) {
1157 // isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
1158 // (max_exp-1))
1159 APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
1160 auto ExpMinusOne = assignSPIRVTy(
1161 MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
1162 APInt MaxExpMinusOne = std::move(ExpMask) - ExpLSB;
1163 auto NormalRes = assignSPIRVTy(
1164 MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
1165 buildSPIRVConstant(IntTy, MaxExpMinusOne)));
1166 if (PartialCheck == fcNegNormal)
1167 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
1168 else if (PartialCheck == fcPosNormal) {
1169 auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
1170 DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
1171 NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
1172 }
1173 appendToRes(std::move(NormalRes));
1174 }
1175
1176 MIRBuilder.buildCopy(DstReg, Res);
1177 MI.eraseFromParent();
1178 return true;
1179}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static void scalarize(Instruction *I, SmallVectorImpl< Instruction * > &Worklist)
Declares convenience wrapper classes for interpreting MachineInstr instances as specific generic oper...
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition MD5.cpp:57
This file declares the MachineIRBuilder class.
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
ppc ctr loops verify
const SmallVectorImpl< MachineOperand > & Cond
static bool legalizeSpvInsertElt(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST)
static MachineInstrBuilder createStackTemporaryForVector(LegalizerHelper &Helper, SPIRVGlobalRegistry *GR, Register SrcReg, LLT SrcTy, MachinePointerInfo &PtrInfo, Align &VecAlign)
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVTypeInst SpvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts)
static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeStore(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeSpvExtractElt(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeSpvBitcast(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeSpvConstComposite(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeLoad(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI, SPIRVGlobalRegistry *GR)
#define LLVM_DEBUG(...)
Definition Debug.h:114
APInt bitcastToAPInt() const
Definition APFloat.h:1404
static APFloat getLargest(const fltSemantics &Sem, bool Negative=false)
Returns the largest finite number in the given semantics.
Definition APFloat.h:1189
static APFloat getInf(const fltSemantics &Sem, bool Negative=false)
Factory for Positive and Negative Infinity.
Definition APFloat.h:1149
static APInt getAllOnes(unsigned numBits)
Return an APInt of a specified width with all bits set.
Definition APInt.h:235
static APInt getSignMask(unsigned BitWidth)
Get the SignMask for a specific bit width.
Definition APInt.h:230
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
static LLVM_ABI ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
static constexpr ElementCount getFixed(ScalarTy MinVal)
Definition TypeSize.h:309
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits)
Get a low-level vector of some number of elements and element width.
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
constexpr bool isValid() const
constexpr uint16_t getNumElements() const
Returns the number of elements in a vector LLT.
constexpr bool isVector() const
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
constexpr LLT getElementType() const
Returns the vector's element type. Only valid for vector types.
constexpr unsigned getAddressSpace() const
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
constexpr bool isPointerOrPointerVector() const
constexpr bool isFixedVector() const
Returns true if the LLT is a fixed vector.
constexpr LLT getScalarType() const
constexpr TypeSize getSizeInBytes() const
Returns the total size of the type in bytes, i.e.
LLVM_ABI void computeTables()
Compute any ancillary tables needed to quickly decide how an operation should be handled.
LegalizeRuleSet & legalFor(std::initializer_list< LLT > Types)
The instruction is legal when type index 0 is any type in the given list.
LegalizeRuleSet & fewerElementsIf(LegalityPredicate Predicate, LegalizeMutation Mutation)
Remove elements to reach the type selected by the mutation if the predicate is true.
LegalizeRuleSet & moreElementsToNextPow2(unsigned TypeIdx)
Add more elements to the vector to reach the next power of two.
LegalizeRuleSet & lower()
The instruction is lowered.
LegalizeRuleSet & scalarizeIf(LegalityPredicate Predicate, unsigned TypeIdx)
LegalizeRuleSet & lowerIf(LegalityPredicate Predicate)
The instruction is lowered if predicate is true.
LegalizeRuleSet & custom()
Unconditionally custom lower.
LegalizeRuleSet & unsupportedIf(LegalityPredicate Predicate)
LegalizeRuleSet & alwaysLegal()
LegalizeRuleSet & scalarize(unsigned TypeIdx)
LegalizeRuleSet & legalForCartesianProduct(std::initializer_list< LLT > Types)
The instruction is legal when type indexes 0 and 1 are both in the given list.
LegalizeRuleSet & legalIf(LegalityPredicate Predicate)
The instruction is legal if predicate is true.
LegalizeRuleSet & customFor(std::initializer_list< LLT > Types)
LLVM_ABI MachineInstrBuilder createStackTemporary(TypeSize Bytes, Align Alignment, MachinePointerInfo &PtrInfo)
Create a stack temporary based on the size in bytes and the alignment.
MachineIRBuilder & MIRBuilder
Expose MIRBuilder so clients can set their own RecordInsertInstruction functions.
LLVM_ABI Align getStackTemporaryAlignment(LLT Type, Align MinAlign=Align()) const
Return the alignment to use for a stack temporary object with the given type.
LegalizeRuleSet & getActionDefinitionsBuilder(unsigned Opcode)
Get the action definition builder for the given opcode.
const LegacyLegalizerInfo & getLegacyLegalizerInfo() const
Helper class to build MachineInstr.
LLVMContext & getContext() const
MachineInstrBuilder buildUnmerge(ArrayRef< LLT > Res, const SrcOp &Op)
Build and insert Res0, ... = G_UNMERGE_VALUES Op.
MachineInstrBuilder buildAnd(const DstOp &Dst, const SrcOp &Src0, const SrcOp &Src1)
Build and insert Res = G_AND Op0, Op1.
MachineInstrBuilder buildICmp(CmpInst::Predicate Pred, const DstOp &Res, const SrcOp &Op0, const SrcOp &Op1, std::optional< unsigned > Flags=std::nullopt)
Build and insert a Res = G_ICMP Pred, Op0, Op1.
MachineInstrBuilder buildSub(const DstOp &Dst, const SrcOp &Src0, const SrcOp &Src1, std::optional< unsigned > Flags=std::nullopt)
Build and insert Res = G_SUB Op0, Op1.
MachineInstrBuilder buildIntrinsic(Intrinsic::ID ID, ArrayRef< Register > Res, bool HasSideEffects, bool isConvergent)
Build and insert a G_INTRINSIC instruction.
MachineInstrBuilder buildSplatBuildVector(const DstOp &Res, const SrcOp &Src)
Build and insert Res = G_BUILD_VECTOR with Src replicated to fill the number of elements.
MachineInstrBuilder buildBuildVector(const DstOp &Res, ArrayRef< Register > Ops)
Build and insert Res = G_BUILD_VECTOR Op0, ...
MachineInstrBuilder buildLoad(const DstOp &Res, const SrcOp &Addr, MachineMemOperand &MMO)
Build and insert Res = G_LOAD Addr, MMO.
MachineInstrBuilder buildStore(const SrcOp &Val, const SrcOp &Addr, MachineMemOperand &MMO)
Build and insert G_STORE Val, Addr, MMO.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
MachineInstrBuilder buildBitcast(const DstOp &Dst, const SrcOp &Src)
Build and insert Dst = G_BITCAST Src.
MachineRegisterInfo * getMRI()
Getter for MRI.
MachineInstrBuilder buildOr(const DstOp &Dst, const SrcOp &Src0, const SrcOp &Src1, std::optional< unsigned > Flags=std::nullopt)
Build and insert Res = G_OR Op0, Op1.
MachineInstrBuilder buildCopy(const DstOp &Res, const SrcOp &Op)
Build and insert Res = COPY Op.
MachineInstrBuilder buildXor(const DstOp &Dst, const SrcOp &Src0, const SrcOp &Src1)
Build and insert Res = G_XOR Op0, Op1.
virtual MachineInstrBuilder buildConstant(const DstOp &Res, const ConstantInt &Val)
Build and insert Res = G_CONSTANT Val.
const MachineInstrBuilder & addUse(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
A description of a memory reference used in the backend.
const MachinePointerInfo & getPointerInfo() const
LLVM_ABI Align getAlign() const
Return the minimum known alignment in bytes of the actual memory reference.
MachineOperand class - Representation of each machine instruction operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
void assignSPIRVTypeToVReg(SPIRVTypeInst Type, Register VReg, const MachineFunction &MF)
const TargetRegisterClass * getRegClass(SPIRVTypeInst SpvType) const
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
LLT getRegType(SPIRVTypeInst SpvType) const
SPIRVTypeInst getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)
SPIRVTypeInst getScalarOrVectorComponentType(SPIRVTypeInst Type) const
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
SPIRVLegalizerInfo(const SPIRVSubtarget &ST)
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override
Called for instructions with the Custom LegalizationAction.
bool legalizeIntrinsic(LegalizerHelper &Helper, MachineInstr &MI) const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
static constexpr TypeSize getFixed(ScalarTy ExactSize)
Definition TypeSize.h:343
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI VectorType * get(Type *ElementType, ElementCount EC)
This static method is the primary way to construct an VectorType.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
LLVM_ABI LegalityPredicate isScalar(unsigned TypeIdx)
True iff the specified type index is a scalar.
LLVM_ABI LegalityPredicate numElementsNotPow2(unsigned TypeIdx)
True iff the specified type index is a vector whose element count is not a power of 2.
LLVM_ABI LegalityPredicate vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size)
True iff the specified type index is a vector with a number of elements that's less than or equal to ...
LLVM_ABI LegalityPredicate typeInSet(unsigned TypeIdx, std::initializer_list< LLT > TypesInit)
True iff the given type index is one of the specified types.
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx, unsigned Size)
True iff the specified type index is a vector with a number of elements that's greater than the given...
Predicate any(Predicate P0, Predicate P1)
True iff P0 or P1 are true.
LegalityPredicate typeIsNot(unsigned TypeIdx, LLT Type)
True iff the given type index is not the specified type.
Predicate all(Predicate P0, Predicate P1)
True iff P0 and P1 are true.
LLVM_ABI LegalityPredicate typeIs(unsigned TypeIdx, LLT TypesInit)
True iff the given type index is the specified type.
LLVM_ABI LegalizeMutation changeElementCountTo(unsigned TypeIdx, unsigned FromTypeIdx)
Keep the same scalar or element type as TypeIdx, but take the number of elements from FromTypeIdx.
LLVM_ABI LegalizeMutation changeElementSizeTo(unsigned TypeIdx, unsigned FromTypeIdx)
Change the scalar size or element size to have the same scalar size as type index FromIndex.
Invariant opcodes: All instruction sets have these as their low opcodes.
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
LLVM_ABI const llvm::fltSemantics & getFltSemanticForLLT(LLT Ty)
Get the appropriate floating point arithmetic semantic based on the bit size of the given scalar LLT.
std::function< bool(const LegalityQuery &)> LegalityPredicate
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
Definition MathExtras.h:279
FPClassTest
Floating-point class tests, supported by 'is_fpclass' intrinsic.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
const std::set< unsigned > & getTypeFoldingSupportedOpcodes()
int64_t foldImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
Align commonAlignment(Align A, uint64_t Offset)
Returns the alignment that satisfies both alignments.
Definition Alignment.h:201
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
The LegalityQuery object bundles together all the information that's needed to decide whether a given...
ArrayRef< LLT > Types
This class contains a discriminated union of information about pointers in memory operands,...
MachinePointerInfo getWithOffset(int64_t O) const