LLVM 23.0.0git
SPIRVISelLowering.cpp
Go to the documentation of this file.
1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
23#include "llvm/IR/IntrinsicsSPIRV.h"
24
25#define DEBUG_TYPE "spirv-lower"
26
27using namespace llvm;
28
30 const SPIRVSubtarget &ST)
31 : TargetLowering(TM, ST), STI(ST) {
32 // Even with SPV_ALTERA_arbitrary_precision_integers enabled, atomic sizes are
33 // limited by atomicrmw xchg operation, which only supports operand up to 64
34 // bits wide, as defined in SPIR-V legalizer. Currently, spirv-val doesn't
35 // consider 128-bit OpTypeInt as valid either.
38}
39
40// Returns true if the types logically match, as defined in
41// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
42static bool typesLogicallyMatch(const SPIRVTypeInst Ty1,
43 const SPIRVTypeInst Ty2,
45 if (Ty1->getOpcode() != Ty2->getOpcode())
46 return false;
47
48 if (Ty1->getNumOperands() != Ty2->getNumOperands())
49 return false;
50
51 if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
52 // Array must have the same size.
53 if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
54 return false;
55
56 SPIRVTypeInst ElemType1 =
58 SPIRVTypeInst ElemType2 =
60 return ElemType1 == ElemType2 ||
61 typesLogicallyMatch(ElemType1, ElemType2, GR);
62 }
63
64 if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
65 for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
66 SPIRVTypeInst ElemType1 =
68 SPIRVTypeInst ElemType2 =
70 if (ElemType1 != ElemType2 &&
71 !typesLogicallyMatch(ElemType1, ElemType2, GR))
72 return false;
73 }
74 return true;
75 }
76 return false;
77}
78
80 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
81 // This code avoids CallLowering fail inside getVectorTypeBreakdown
82 // on v3i1 arguments. Maybe we need to return 1 for all types.
83 // TODO: remove it once this case is supported by the default implementation.
84 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
85 (VT.getVectorElementType() == MVT::i1 ||
86 VT.getVectorElementType() == MVT::i8))
87 return 1;
88 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
89 return 1;
90 return getNumRegisters(Context, VT);
91}
92
95 EVT VT) const {
96 // This code avoids CallLowering fail inside getVectorTypeBreakdown
97 // on v3i1 arguments. Maybe we need to return i32 for all types.
98 // TODO: remove it once this case is supported by the default implementation.
99 if (VT.isVector() && VT.getVectorNumElements() == 3) {
100 if (VT.getVectorElementType() == MVT::i1)
101 return MVT::v4i1;
102 else if (VT.getVectorElementType() == MVT::i8)
103 return MVT::v4i8;
104 }
105 return getRegisterType(Context, VT);
106}
107
110 MachineFunction &MF, unsigned Intrinsic) const {
111 IntrinsicInfo Info;
112 unsigned AlignIdx = 3;
113 switch (Intrinsic) {
114 case Intrinsic::spv_load:
115 AlignIdx = 2;
116 [[fallthrough]];
117 case Intrinsic::spv_store: {
118 if (I.getNumOperands() >= AlignIdx + 1) {
119 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
120 Info.align = Align(AlignOp->getZExtValue());
121 }
122 Info.flags = static_cast<MachineMemOperand::Flags>(
123 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
124 Info.memVT = MVT::i64;
125 // TODO: take into account opaque pointers (don't use getElementType).
126 // MVT::getVT(PtrTy->getElementType());
127 Infos.push_back(Info);
128 return;
129 }
130 default:
131 break;
132 }
133}
134
135std::pair<unsigned, const TargetRegisterClass *>
137 StringRef Constraint,
138 MVT VT) const {
139 const TargetRegisterClass *RC = nullptr;
140 if (Constraint.starts_with("{"))
141 return std::make_pair(0u, RC);
142
143 if (VT.isFloatingPoint())
144 RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
145 else if (VT.isInteger())
146 RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
147 else
148 RC = &SPIRV::iIDRegClass;
149
150 return std::make_pair(0u, RC);
151}
152
154 const MachineInstr *Inst = MRI->getVRegDef(OpReg);
155 return Inst && Inst->getOpcode() == SPIRV::OpFunctionParameter
156 ? Inst->getOperand(1).getReg()
157 : OpReg;
158}
159
162 Register OpReg, unsigned OpIdx,
163 SPIRVTypeInst NewPtrType) {
164 MachineIRBuilder MIB(I);
165 Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
166 MIB.buildInstr(SPIRV::OpBitcast)
167 .addDef(NewReg)
168 .addUse(GR.getSPIRVTypeID(NewPtrType))
169 .addUse(OpReg)
171 *STI.getRegBankInfo());
172 I.getOperand(OpIdx).setReg(NewReg);
173}
174
176 SPIRVTypeInst OpType, bool ReuseType,
177 SPIRVTypeInst ResType,
178 const Type *ResTy) {
179 SPIRV::StorageClass::StorageClass SC =
180 static_cast<SPIRV::StorageClass::StorageClass>(
181 OpType->getOperand(1).getImm());
182 MachineIRBuilder MIB(I);
183 SPIRVTypeInst NewBaseType =
184 ReuseType ? ResType
186 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
187 return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
188}
189
190// Insert a bitcast before the instruction to keep SPIR-V code valid
191// when there is a type mismatch between the result and operand types.
192static void validatePtrTypes(const SPIRVSubtarget &STI,
194 MachineInstr &I, unsigned OpIdx,
195 SPIRVTypeInst ResType,
196 const Type *ResTy = nullptr) {
197 // Get operand type
198 MachineFunction *MF = I.getParent()->getParent();
199 Register OpReg = I.getOperand(OpIdx).getReg();
200 Register OpTypeReg = getTypeReg(MRI, OpReg);
201 const MachineInstr *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
202 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
203 return;
204 // Get operand's pointee type
205 Register ElemTypeReg = OpType->getOperand(2).getReg();
206 SPIRVTypeInst ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
207 if (!ElemType)
208 return;
209 // Check if we need a bitcast to make a statement valid
210 bool IsSameMF = MF == ResType->getParent()->getParent();
211 bool IsEqualTypes = IsSameMF ? ElemType == ResType
212 : GR.getTypeForSPIRVType(ElemType) == ResTy;
213 if (IsEqualTypes)
214 return;
215 // There is a type mismatch between results and operand types
216 // and we insert a bitcast before the instruction to keep SPIR-V code valid
217 SPIRVTypeInst NewPtrType =
218 createNewPtrType(GR, I, OpType, IsSameMF, ResType, ResTy);
219 if (!GR.isBitcastCompatible(NewPtrType, OpType))
221 "insert validation bitcast: incompatible result and operand types");
222 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
223}
224
225// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
226// that doesn't point to OpTypeEvent.
230 MachineInstr &I) {
231 constexpr unsigned OpIdx = 2;
232 MachineFunction *MF = I.getParent()->getParent();
233 Register OpReg = I.getOperand(OpIdx).getReg();
234 Register OpTypeReg = getTypeReg(MRI, OpReg);
235 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
236 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
237 return;
238 SPIRVTypeInst ElemType =
239 GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
240 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
241 return;
242 // Insert a bitcast before the instruction to keep SPIR-V code valid.
243 LLVMContext &Context = MF->getFunction().getContext();
244 SPIRVTypeInst NewPtrType =
245 createNewPtrType(GR, I, OpType, false, nullptr,
246 TargetExtType::get(Context, "spirv.Event"));
247 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
248}
249
253 Register PtrReg = I.getOperand(0).getReg();
254 MachineFunction *MF = I.getParent()->getParent();
255 Register PtrTypeReg = getTypeReg(MRI, PtrReg);
256 SPIRVTypeInst PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
257 SPIRVTypeInst PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
258 if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
259 (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
260 PonteeElemType->getOperand(1).getImm() == 8))
261 return;
262 // To keep the code valid a bitcast must be inserted
263 SPIRV::StorageClass::StorageClass SC =
264 static_cast<SPIRV::StorageClass::StorageClass>(
265 PtrType->getOperand(1).getImm());
266 MachineIRBuilder MIB(I);
267 LLVMContext &Context = MF->getFunction().getContext();
268 SPIRVTypeInst NewPtrType =
270 doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
271}
272
276 MachineInstr &I, unsigned OpIdx) {
277 MachineFunction *MF = I.getParent()->getParent();
278 Register OpReg = I.getOperand(OpIdx).getReg();
279 Register OpTypeReg = getTypeReg(MRI, OpReg);
280 SPIRVTypeInst OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
281 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
282 return;
283 SPIRVTypeInst ElemType =
284 GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
285 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
286 ElemType->getNumOperands() != 2)
287 return;
288 // It's a structure-wrapper around another type with a single member field.
289 SPIRVTypeInst MemberType =
290 GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
291 if (!MemberType)
292 return;
293 unsigned MemberTypeOp = MemberType->getOpcode();
294 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
295 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
296 return;
297 // It's a structure-wrapper around a valid type. Insert a bitcast before the
298 // instruction to keep SPIR-V code valid.
299 SPIRV::StorageClass::StorageClass SC =
300 static_cast<SPIRV::StorageClass::StorageClass>(
301 OpType->getOperand(1).getImm());
302 MachineIRBuilder MIB(I);
303 SPIRVTypeInst NewPtrType =
304 GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
305 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
306}
307
308// Insert a bitcast before the function call instruction to keep SPIR-V code
309// valid when there is a type mismatch between actual and expected types of an
310// argument:
311// %formal = OpFunctionParameter %formal_type
312// ...
313// %res = OpFunctionCall %ty %fun %actual ...
314// implies that %actual is of %formal_type; in the case of opaque pointers,
315// we may need to insert a bitcast to ensure this.
317 MachineRegisterInfo *DefMRI,
318 MachineRegisterInfo *CallMRI,
319 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
320 MachineInstr *FunDef) {
321 if (FunDef->getOpcode() != SPIRV::OpFunction)
322 return;
323 unsigned OpIdx = 3;
324 for (FunDef = FunDef->getNextNode();
325 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
326 OpIdx < FunCall.getNumOperands();
327 FunDef = FunDef->getNextNode(), OpIdx++) {
328 SPIRVTypeInst DefPtrType =
329 DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
330 SPIRVTypeInst DefElemType =
331 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
332 ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
333 DefPtrType->getParent()->getParent())
334 : nullptr;
335 if (DefElemType) {
336 const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
337 // validatePtrTypes() works in the context if the call site
338 // When we process historical records about forward calls
339 // we need to switch context to the (forward) call site and
340 // then restore it back to the current machine function.
341 MachineFunction *CurMF =
342 GR.setCurrentFunc(*FunCall.getParent()->getParent());
343 validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
344 DefElemTy);
345 GR.setCurrentFunc(*CurMF);
346 }
347 }
348}
349
350// Ensure there is no mismatch between actual and expected arg types: calls
351// with a processed definition. Returns the Function pointer if it's a forward
352// call (ahead of its definition), and nullptr otherwise.
354 MachineRegisterInfo *CallMRI,
356 MachineInstr &FunCall) {
357 const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
358 const Function *F = dyn_cast<Function>(GV);
359 MachineInstr *FunDef =
360 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
361 if (!FunDef)
362 return F;
363 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
364 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
365 return nullptr;
366}
367
368// Ensure there is no mismatch between actual and expected arg types: calls
369// ahead of a processed definition.
372 MachineInstr &FunDef) {
373 const Function *F = GR.getFunctionByDefinition(&FunDef);
375 for (MachineInstr *FunCall : *FwdCalls) {
376 MachineRegisterInfo *CallMRI =
377 &FunCall->getParent()->getParent()->getRegInfo();
378 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
379 }
380}
381
382// Validation of an access chain.
385 SPIRVTypeInst BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
386 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
387 SPIRVTypeInst BaseElemType =
388 GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
389 validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
390 }
391}
392
393// TODO: the logic of inserting additional bitcasts is to be moved
394// to pre-IRTranslation passes eventually.
396 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
397 // We'd like to avoid the needless second processing pass.
398 if (ProcessedMF.find(&MF) != ProcessedMF.end())
399 return;
400
402 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
403 GR.setCurrentFunc(MF);
404 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
406 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
407 MBBI != MBBE;) {
408 MachineInstr &MI = *MBBI++;
409 switch (MI.getOpcode()) {
410 case SPIRV::OpAtomicLoad:
411 case SPIRV::OpAtomicExchange:
412 case SPIRV::OpAtomicCompareExchange:
413 case SPIRV::OpAtomicCompareExchangeWeak:
414 case SPIRV::OpAtomicIIncrement:
415 case SPIRV::OpAtomicIDecrement:
416 case SPIRV::OpAtomicIAdd:
417 case SPIRV::OpAtomicISub:
418 case SPIRV::OpAtomicSMin:
419 case SPIRV::OpAtomicUMin:
420 case SPIRV::OpAtomicSMax:
421 case SPIRV::OpAtomicUMax:
422 case SPIRV::OpAtomicAnd:
423 case SPIRV::OpAtomicOr:
424 case SPIRV::OpAtomicXor:
425 // for the above listed instructions
426 // OpAtomicXXX <ResType>, ptr %Op, ...
427 // implies that %Op is a pointer to <ResType>
428 case SPIRV::OpLoad:
429 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
431 break;
432
433 validatePtrTypes(STI, MRI, GR, MI, 2,
434 GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
435 break;
436 case SPIRV::OpAtomicStore:
437 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
438 // implies that %Op points to the <Obj>'s type
439 validatePtrTypes(STI, MRI, GR, MI, 0,
440 GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
441 break;
442 case SPIRV::OpStore:
443 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
444 validatePtrTypes(STI, MRI, GR, MI, 0,
445 GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
446 break;
447 case SPIRV::OpPtrCastToGeneric:
448 case SPIRV::OpGenericCastToPtr:
449 case SPIRV::OpGenericCastToPtrExplicit:
450 validateAccessChain(STI, MRI, GR, MI);
451 break;
452 case SPIRV::OpPtrAccessChain:
453 case SPIRV::OpInBoundsPtrAccessChain:
454 if (MI.getNumOperands() == 4)
455 validateAccessChain(STI, MRI, GR, MI);
456 break;
457
458 case SPIRV::OpFunctionCall:
459 // ensure there is no mismatch between actual and expected arg types:
460 // calls with a processed definition
461 if (MI.getNumOperands() > 3)
462 if (const Function *F = validateFunCall(STI, MRI, GR, MI))
463 GR.addForwardCall(F, &MI);
464 break;
465 case SPIRV::OpFunction:
466 // ensure there is no mismatch between actual and expected arg types:
467 // calls ahead of a processed definition
468 validateForwardCalls(STI, MRI, GR, MI);
469 break;
470
471 // ensure that LLVM IR add/sub instructions result in logical SPIR-V
472 // instructions when applied to bool type
473 case SPIRV::OpIAddS:
474 case SPIRV::OpIAddV:
475 case SPIRV::OpISubS:
476 case SPIRV::OpISubV:
477 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
478 SPIRV::OpTypeBool))
479 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
480 break;
481
482 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
483 // instructions when applied to bool type
484 case SPIRV::OpBitwiseOrS:
485 case SPIRV::OpBitwiseOrV:
486 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
487 SPIRV::OpTypeBool))
488 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
489 break;
490 case SPIRV::OpBitwiseAndS:
491 case SPIRV::OpBitwiseAndV:
492 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
493 SPIRV::OpTypeBool))
494 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
495 break;
496 case SPIRV::OpBitwiseXorS:
497 case SPIRV::OpBitwiseXorV:
498 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
499 SPIRV::OpTypeBool))
500 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
501 break;
502 case SPIRV::OpLifetimeStart:
503 case SPIRV::OpLifetimeStop:
504 if (MI.getOperand(1).getImm() > 0)
505 validateLifetimeStart(STI, MRI, GR, MI);
506 break;
507 case SPIRV::OpGroupAsyncCopy:
508 validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
509 validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
510 break;
511 case SPIRV::OpGroupWaitEvents:
512 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
514 break;
515 case SPIRV::OpConstantI: {
516 SPIRVTypeInst Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
517 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
518 MI.getOperand(2).getImm() == 0) {
519 // Validate the null constant of a target extension type
520 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
521 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
522 MI.removeOperand(i);
523 }
524 } break;
525 case SPIRV::OpExtInst: {
526 // prefetch
527 if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
528 MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
529 continue;
530 switch (MI.getOperand(3).getImm()) {
531 case SPIRV::OpenCLExtInst::frexp:
532 case SPIRV::OpenCLExtInst::lgamma_r:
533 case SPIRV::OpenCLExtInst::remquo: {
534 // The last operand must be of a pointer to i32 or vector of i32
535 // values.
536 MachineIRBuilder MIB(MI);
537 SPIRVTypeInst Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
538 SPIRVTypeInst RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
539 assert(RetType && "Expected return type");
540 validatePtrTypes(STI, MRI, GR, MI, MI.getNumOperands() - 1,
541 RetType->getOpcode() != SPIRV::OpTypeVector
542 ? Int32Type
544 Int32Type, RetType->getOperand(2).getImm(),
545 MIB, false));
546 } break;
547 case SPIRV::OpenCLExtInst::fract:
548 case SPIRV::OpenCLExtInst::modf:
549 case SPIRV::OpenCLExtInst::sincos:
550 // The last operand must be of a pointer to the base type represented
551 // by the previous operand.
552 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
553 "Expected v-reg");
555 STI, MRI, GR, MI, MI.getNumOperands() - 1,
557 MI.getOperand(MI.getNumOperands() - 2).getReg()));
558 break;
559 case SPIRV::OpenCLExtInst::prefetch:
560 // Expected `ptr` type is a pointer to float, integer or vector, but
561 // the pontee value can be wrapped into a struct.
562 assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
563 "Expected v-reg");
565 MI.getNumOperands() - 2);
566 break;
567 }
568 } break;
569 }
570 }
571 }
572 ProcessedMF.insert(&MF);
574}
575
576// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
577// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
578// match or if the instruction was modified to make them match.
580 MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
581 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
582 SPIRVTypeInst PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
583 SPIRVTypeInst PointeeType = GR.getPointeeType(PtrType);
584 SPIRVTypeInst OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
585
586 if (PointeeType == OpType)
587 return true;
588
589 if (typesLogicallyMatch(PointeeType, OpType, GR)) {
590 // Apply OpCopyLogical to OpIdx.
591 if (I.getOperand(OpIdx).isDef() &&
592 insertLogicalCopyOnResult(I, PointeeType)) {
593 return true;
594 }
595
596 llvm_unreachable("Unable to add OpCopyLogical yet.");
597 return false;
598 }
599
600 return false;
601}
602
604 MachineInstr &I, SPIRVTypeInst NewResultType) const {
605 MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
606 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
607
608 Register NewResultReg =
609 createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
610 Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
611
612 assert(llvm::size(I.defs()) == 1 && "Expected only one def");
613 MachineOperand &OldResult = *I.defs().begin();
614 Register OldResultReg = OldResult.getReg();
615 MachineOperand &OldType = *I.uses().begin();
616 Register OldTypeReg = OldType.getReg();
617
618 OldResult.setReg(NewResultReg);
619 OldType.setReg(NewTypeReg);
620
621 MachineIRBuilder MIB(*I.getNextNode());
622 MIB.buildInstr(SPIRV::OpCopyLogical)
623 .addDef(OldResultReg)
624 .addUse(OldTypeReg)
625 .addUse(NewResultReg)
626 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
627 *STI.getRegBankInfo());
628 return true;
629}
630
646
649 // TODO: Pointer operand should be cast to integer in atomicrmw xchg, since
650 // SPIR-V only supports atomic exchange for integer and floating-point types.
652}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator MBBI
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register const TargetRegisterInfo * TRI
MachineInstr unsigned OpIdx
static bool typesLogicallyMatch(const SPIRVTypeInst Ty1, const SPIRVTypeInst Ty2, SPIRVGlobalRegistry &GR)
static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVTypeInst ResType, const Type *ResTy=nullptr)
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)
Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVTypeInst NewPtrType)
static SPIRVTypeInst createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVTypeInst OpType, bool ReuseType, SPIRVTypeInst ResType, const Type *ResTy)
This file describes how to lower LLVM code to machine code.
an instruction that atomically reads a memory location, combines it with another value,...
@ FAdd
*p = old + v
@ FSub
*p = old - v
@ UIncWrap
Increment one up to a maximum value.
@ FMin
*p = minnum(old, v) minnum matches the behavior of llvm.minnum.
@ FMax
*p = maxnum(old, v) maxnum matches the behavior of llvm.maxnum.
@ UDecWrap
Decrement one until a minimum value or zero.
BinOp getOperation() const
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
This is an important class for using LLVM in a threaded context.
Definition LLVMContext.h:68
Machine Value Type.
bool isVector() const
Return true if this is a vector value type.
bool isInteger() const
Return true if this is an integer or a vector integer type.
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineInstrBundleIterator< MachineInstr > iterator
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
BasicBlockListType::iterator iterator
void insert(iterator MBBI, MachineBasicBlock *MBB)
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
void constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
int64_t getImm() const
LLVM_ABI void setReg(Register Reg)
Change the register this operand corresponds to.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
void addForwardCall(const Function *F, MachineInstr *MI)
SPIRVTypeInst getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
SPIRVTypeInst getOrCreateSPIRVVectorType(SPIRVTypeInst BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)
SPIRVTypeInst getResultType(Register VReg, MachineFunction *MF=nullptr)
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
bool isBitcastCompatible(SPIRVTypeInst Type1, SPIRVTypeInst Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
SPIRVTypeInst getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)
Register getSPIRVTypeID(SPIRVTypeInst SpirvType) const
SPIRVTypeInst getPointeeType(SPIRVTypeInst PtrType)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
AtomicExpansionKind shouldCastAtomicRMWIInIR(AtomicRMWInst *RMWI) const override
Returns how the given atomicrmw should be cast by the IR-level AtomicExpand pass.
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, unsigned OpIdx) const
unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override
Return the number of registers that this ValueType will eventually require.
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
void getTgtMemIntrinsic(SmallVectorImpl< IntrinsicInfo > &Infos, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
bool insertLogicalCopyOnResult(MachineInstr &I, SPIRVTypeInst NewResultType) const
SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST)
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:258
static LLVM_ABI TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})
Return a target extension type having the specified name and optional type and integer parameters.
Definition Type.cpp:907
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
virtual AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
TargetLowering(const TargetLowering &)=delete
Primary interface to the complete machine description for the target machine.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1669
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
Register createVirtualRegister(SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition Alignment.h:39
Extended Value Type.
Definition ValueTypes.h:35
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition ValueTypes.h:373
bool isVector() const
Return true if this is a vector value type.
Definition ValueTypes.h:168
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition ValueTypes.h:328
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition ValueTypes.h:336
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition ValueTypes.h:152