LLVM 19.0.0git
SPIRVISelLowering.cpp
Go to the documentation of this file.
1//===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SPIRVTargetLowering class.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVISelLowering.h"
14#include "SPIRV.h"
15#include "SPIRVInstrInfo.h"
17#include "SPIRVRegisterInfo.h"
18#include "SPIRVSubtarget.h"
19#include "SPIRVTargetMachine.h"
22#include "llvm/IR/IntrinsicsSPIRV.h"
23
24#define DEBUG_TYPE "spirv-lower"
25
26using namespace llvm;
27
// NOTE(review): the line that opens this definition (return type and
// qualified function name) is not present in this listing; what follows
// starts in the middle of the parameter list.
//
// Reports the number of registers used for a value of type VT:
//   - 1 for a 3-element vector whose element type is i1 or i8;
//   - 1 for a non-vector integer of at most 64 bits;
//   - getNumRegisters(Context, VT) for everything else.
// CC is accepted but never read in the body below.
29 LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
30 // This code avoids a CallLowering failure inside getVectorTypeBreakdown
31 // on v3i1 arguments (the check below also matches v3i8). Maybe we need to
31 // return 1 for all types.
32 // TODO: remove it once this case is supported by the default implementation.
33 if (VT.isVector() && VT.getVectorNumElements() == 3 &&
34 (VT.getVectorElementType() == MVT::i1 ||
35 VT.getVectorElementType() == MVT::i8))
36 return 1;
// Scalar (non-vector) integers no wider than 64 bits are reported as a
// single register.
37 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64)
38 return 1;
// Anything else is delegated to getNumRegisters().
39 return getNumRegisters(Context, VT);
40}
41
// NOTE(review): the opening lines of this definition (listing lines 42-43:
// return type, qualified name and the leading parameters) are not present in
// this listing; only the tail of the parameter list is visible. The body
// reads `Context` and `VT`.
//
// Picks the register type for a value of type VT:
//   - v4i1 for a 3-element vector of i1;
//   - v4i8 for a 3-element vector of i8;
//   - getRegisterType(Context, VT) for everything else, including 3-element
//     vectors with any other element type.
44 EVT VT) const {
45 // This code avoids a CallLowering failure inside getVectorTypeBreakdown
46 // on v3i1 arguments. Maybe we need to return i32 for all types.
47 // TODO: remove it once this case is supported by the default implementation.
48 if (VT.isVector() && VT.getVectorNumElements() == 3) {
49 if (VT.getVectorElementType() == MVT::i1)
50 return MVT::v4i1;
51 else if (VT.getVectorElementType() == MVT::i8)
52 return MVT::v4i8;
53 }
// No special case matched: fall back to getRegisterType().
54 return getRegisterType(Context, VT);
55}
56
// NOTE(review): listing lines 57 and 59 (the opening line of the definition
// and one parameter line) are not present in this listing. The body writes
// to an `Info` object that is presumably declared on the missing first line.
//
// Fills `Info` for the spv_load and spv_store intrinsics and returns true for
// them; returns false for every other intrinsic ID.
//
// Operand layout as used by the code below:
//   spv_load : flags at operand 1, alignment at operand 2
//   spv_store: flags at operand 2, alignment at operand 3
58 const CallInst &I,
60 unsigned Intrinsic) const {
// Index of the alignment operand; 3 is the spv_store position.
61 unsigned AlignIdx = 3;
62 switch (Intrinsic) {
63 case Intrinsic::spv_load:
// spv_load keeps its alignment one operand earlier, then shares the
// spv_store handling below.
64 AlignIdx = 2;
65 [[fallthrough]];
66 case Intrinsic::spv_store: {
// The alignment is only read when the call has an operand at AlignIdx.
// cast<> (rather than dyn_cast<>) is used, so the operand is expected to be
// a ConstantInt.
67 if (I.getNumOperands() >= AlignIdx + 1) {
68 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx));
69 Info.align = Align(AlignOp->getZExtValue());
70 }
// The operand just before the alignment carries the memory-operand flags as
// an integer constant.
// NOTE(review): unlike the alignment read above, this access is not guarded
// by an operand-count check — confirm the intrinsics always provide it.
71 Info.flags = static_cast<MachineMemOperand::Flags>(
72 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue());
// The memory type is fixed to i64 regardless of the accessed type.
73 Info.memVT = MVT::i64;
74 // TODO: take into account opaque pointers (don't use getElementType).
75 // MVT::getVT(PtrTy->getElementType());
76 return true;
// NOTE(review): this `break` follows an unconditional return and can never
// execute.
77 break;
78 }
79 default:
80 break;
81 }
82 return false;
83}
84
85// Insert a bitcast before the instruction to keep SPIR-V code valid
86// when there is a type mismatch between results and operand types.
//
// Looks at operand `OpIdx` of `I`. If that operand is typed as an
// OpTypePointer whose pointee type differs from the expected type, an
// OpBitcast to a pointer-to-expected-type (same storage class) is built and
// operand `OpIdx` is rewritten to use the bitcast result.
//
// The expected type is given as `ResType`; `ResTy` is only consulted when
// `ResType` belongs to a different machine function than `I`.
// The function returns without changes when `ResType` is null, when the
// operand has no known type, when that type is not OpTypePointer, when the
// pointee type is unknown, or when the types already match.
87static void validatePtrTypes(const SPIRVSubtarget &STI,
// NOTE(review): listing line 88 is not present here. The body uses `MRI` and
// `GR`, so it presumably declares those two parameters.
89 MachineInstr &I, unsigned OpIdx,
90 SPIRVType *ResType, const Type *ResTy = nullptr) {
91 // Get operand type
92 MachineFunction *MF = I.getParent()->getParent();
93 Register OpReg = I.getOperand(OpIdx).getReg();
94 SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
// When the operand is defined by an OpFunctionParameter, the type is looked
// up through that instruction's operand 1; otherwise through OpReg itself.
95 Register OpTypeReg =
96 TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
97 ? TypeInst->getOperand(1).getReg()
98 : OpReg;
99 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
100 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
101 return;
102 // Get operand's pointee type
103 Register ElemTypeReg = OpType->getOperand(2).getReg();
104 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF);
105 if (!ElemType)
106 return;
107 // Check if we need a bitcast to make a statement valid
// Within one machine function the SPIRVType pointers are compared directly;
// across functions the comparison is done on the LLVM IR types instead.
108 bool IsSameMF = MF == ResType->getParent()->getParent();
109 bool IsEqualTypes = IsSameMF ? ElemType == ResType
110 : GR.getTypeForSPIRVType(ElemType) == ResTy;
111 if (IsEqualTypes)
112 return;
113 // There is a type mismatch between results and operand types
114 // and we insert a bitcast before the instruction to keep SPIR-V code valid
// The storage class of the original pointer type (operand 1 of the
// OpTypePointer) is reused for the new pointer type.
115 SPIRV::StorageClass::StorageClass SC =
116 static_cast<SPIRV::StorageClass::StorageClass>(
117 OpType->getOperand(1).getImm());
118 MachineIRBuilder MIB(I);
119 SPIRVType *NewBaseType =
120 IsSameMF ? ResType
// NOTE(review): listing line 121 is not present here. It presumably holds
// the else-arm of this conditional — a call that builds a SPIRVType from
// `ResTy` and whose argument list continues on the next line.
122 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
123 SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
124 if (!GR.isBitcastCompatible(NewPtrType, OpType))
// NOTE(review): listing line 125 is not present here. Presumably it opens an
// error-reporting call that takes the message on the next line.
126 "insert validation bitcast: incompatible result and operand types");
// The bitcast result lives in a fresh generic 32-bit scalar virtual register.
127 Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
128 bool Res = MIB.buildInstr(SPIRV::OpBitcast)
129 .addDef(NewReg)
130 .addUse(GR.getSPIRVTypeID(NewPtrType))
131 .addUse(OpReg)
// NOTE(review): listing line 132 is not present here. Presumably it starts
// the call that constrains the new instruction's register uses and whose
// last argument is on the next line.
133 *STI.getRegBankInfo());
134 if (!Res)
135 report_fatal_error("insert validation bitcast: cannot constrain all uses");
// Register the new value (register class and SPIR-V type) and redirect the
// checked operand to it.
136 MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
137 GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
138 I.getOperand(OpIdx).setReg(NewReg);
139}
140
141// Insert a bitcast before the function call instruction to keep SPIR-V code
142// valid when there is a type mismatch between actual and expected types of an
143// argument:
144// %formal = OpFunctionParameter %formal_type
145// ...
146// %res = OpFunctionCall %ty %fun %actual ...
147// implies that %actual is of %formal_type. With opaque pointers that is not
148// guaranteed, so we may need to insert a bitcast to ensure this.
//
// `FunDef` must point at an OpFunction; otherwise nothing is done. The
// OpFunctionParameter instructions that follow it are paired, in order, with
// the operands of `FunCall` starting at operand 3. Only parameters whose type
// is an OpTypePointer with a known pointee type are checked.
// NOTE(review): listing line 149 (the opening line of this definition, with
// the return type, name and first parameter) is not present here; the body
// uses an `STI` value that is presumably declared on it.
150 MachineRegisterInfo *DefMRI,
151 MachineRegisterInfo *CallMRI,
152 SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
153 MachineInstr *FunDef) {
154 if (FunDef->getOpcode() != SPIRV::OpFunction)
155 return;
// First call operand that is an actual argument (see the OpFunctionCall
// layout in the header comment: result, type, callee, then arguments).
156 unsigned OpIdx = 3;
// Walk the formal parameters and the call arguments in lockstep; stop at the
// first instruction that is not an OpFunctionParameter or when the call runs
// out of operands.
157 for (FunDef = FunDef->getNextNode();
158 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
159 OpIdx < FunCall.getNumOperands();
160 FunDef = FunDef->getNextNode(), OpIdx++) {
// Operand 1 of the parameter names its type; for pointer types, operand 2 of
// the OpTypePointer names the pointee type, which is resolved in the
// function that contains the type definition.
161 SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
162 SPIRVType *DefElemType =
163 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
164 ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(),
165 DefPtrType->getParent()->getParent())
166 : nullptr;
167 if (DefElemType) {
168 const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
169 // validatePtrTypes() works in the context of the call site.
170 // When we process historical records about forward calls
171 // we need to switch context to the (forward) call site and
172 // then restore it back to the current machine function.
173 MachineFunction *CurMF =
174 GR.setCurrentFunc(*FunCall.getParent()->getParent());
175 validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
176 DefElemTy);
177 GR.setCurrentFunc(*CurMF);
178 }
179 }
180}
181
182// Ensure there is no mismatch between actual and expected arg types: calls
183// with a processed definition. Return Function pointer if it's a forward
184// call (ahead of definition), and nullptr otherwise.
//
// The callee is taken from operand 2 of `FunCall`. When the registry already
// has a definition for it, the call's arguments are validated against that
// definition right away and nullptr is returned.
// NOTE(review): listing lines 185 and 187 (the opening line of this
// definition and one parameter line) are not present here; the body uses
// `STI` and `GR`, which are presumably declared on them.
186 MachineRegisterInfo *CallMRI,
188 MachineInstr &FunCall) {
189 const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
// NOTE(review): dyn_cast<> yields nullptr when GV is not a Function; F is
// passed on without a null check, and a null F would also be what is
// returned below — confirm callers only see Function callees here.
190 const Function *F = dyn_cast<Function>(GV);
// The registry hands back a const instruction; constness is cast away
// because validateFunCallMachineDef() takes a mutable MachineInstr pointer.
191 MachineInstr *FunDef =
192 const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
// No definition recorded yet: report the callee so the call can be revisited
// later.
193 if (!FunDef)
194 return F;
195 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo();
196 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef);
197 return nullptr;
198}
199
200// Ensure there is no mismatch between actual and expected arg types: calls
201// ahead of a processed definition.
//
// For the function defined by `FunDef`, every call instruction in `FwdCalls`
// is validated against this definition, each using the register info of the
// machine function that contains the call.
// NOTE(review): listing lines 202-203 (the opening of this definition,
// presumably declaring `STI`, `DefMRI` and `GR`) are not present here.
204 MachineInstr &FunDef) {
205 const Function *F = GR.getFunctionByDefinition(&FunDef);
// NOTE(review): listing line 206 is not present here. It must introduce
// `FwdCalls`, presumably from a registry lookup keyed by F that guards the
// loop below — confirm against the original file.
207 for (MachineInstr *FunCall : *FwdCalls) {
208 MachineRegisterInfo *CallMRI =
209 &FunCall->getParent()->getParent()->getRegInfo();
210 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
211 }
212}
213
214// Validation of an access chain.
//
// Takes the type assigned to operand 0 of `I`. If it is an OpTypePointer,
// its pointee type (operand 2 of the pointer type) becomes the expected
// pointee type for operand 2 of `I`, which is then checked — and bitcast if
// needed — by validatePtrTypes().
// NOTE(review): listing lines 215-216 (the opening of this definition,
// presumably declaring `STI`, `MRI`, `GR` and `I`) are not present here.
217 SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
218 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
219 SPIRVType *BaseElemType =
220 GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
221 validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
222 }
223}
224
225// TODO: the logic of inserting additional bitcasts is to be moved
226// to pre-IRTranslation passes eventually
//
// Walks every instruction of `MF` once and, depending on the opcode:
//   - validates pointer operand types of loads, stores, atomics, access
//     chains and pointer casts (inserting bitcasts where types mismatch);
//   - validates call arguments against the callee's parameters, recording
//     calls whose callee has no definition yet;
//   - turns bitwise or/and/xor on bool operands into their logical
//     counterparts.
// NOTE(review): listing line 227 (the opening line of this definition) is
// not present here.
228 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
229 // We'd like to avoid the needless second processing pass.
230 if (ProcessedMF.find(&MF) != ProcessedMF.end())
231 return;
232
// NOTE(review): listing lines 233-234 are not present here. `GR` and `MRI`
// are used below without a visible declaration, so they are presumably set
// up on those lines; where `STI` comes from is not visible either.
235 GR.setCurrentFunc(MF);
236 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
// NOTE(review): listing line 237 is not present here; it presumably derives
// `MBB` from the block iterator `I`.
238 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
239 MBBI != MBBE;) {
// The iterator is advanced before MI is handled, so the walk continues with
// the instruction that followed MI even if new instructions get inserted in
// front of MI.
240 MachineInstr &MI = *MBBI++;
241 switch (MI.getOpcode()) {
242 case SPIRV::OpAtomicLoad:
243 case SPIRV::OpAtomicExchange:
244 case SPIRV::OpAtomicCompareExchange:
245 case SPIRV::OpAtomicCompareExchangeWeak:
246 case SPIRV::OpAtomicIIncrement:
247 case SPIRV::OpAtomicIDecrement:
248 case SPIRV::OpAtomicIAdd:
249 case SPIRV::OpAtomicISub:
250 case SPIRV::OpAtomicSMin:
251 case SPIRV::OpAtomicUMin:
252 case SPIRV::OpAtomicSMax:
253 case SPIRV::OpAtomicUMax:
254 case SPIRV::OpAtomicAnd:
255 case SPIRV::OpAtomicOr:
256 case SPIRV::OpAtomicXor:
257 // for the above listed instructions
258 // OpAtomicXXX <ResType>, ptr %Op, ...
259 // implies that %Op is a pointer to <ResType>
260 case SPIRV::OpLoad:
261 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
// Pointer is operand 2; the expected pointee type is the type of the result
// (operand 0).
262 validatePtrTypes(STI, MRI, GR, MI, 2,
263 GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
264 break;
265 case SPIRV::OpAtomicStore:
266 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
267 // implies that %Op points to the <Obj>'s type
268 validatePtrTypes(STI, MRI, GR, MI, 0,
269 GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
270 break;
271 case SPIRV::OpStore:
272 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
273 validatePtrTypes(STI, MRI, GR, MI, 0,
274 GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
275 break;
276 case SPIRV::OpPtrCastToGeneric:
277 validateAccessChain(STI, MRI, GR, MI);
278 break;
279 case SPIRV::OpInBoundsPtrAccessChain:
// Only the four-operand form is validated.
280 if (MI.getNumOperands() == 4)
281 validateAccessChain(STI, MRI, GR, MI);
282 break;
283
284 case SPIRV::OpFunctionCall:
285 // ensure there is no mismatch between actual and expected arg types:
286 // calls with a processed definition
// More than three operands means the call carries at least one argument.
// A non-null result from validateFunCall() is the callee of a forward call,
// which is recorded for validation once its definition is processed.
287 if (MI.getNumOperands() > 3)
288 if (const Function *F = validateFunCall(STI, MRI, GR, MI))
289 GR.addForwardCall(F, &MI);
290 break;
291 case SPIRV::OpFunction:
292 // ensure there is no mismatch between actual and expected arg types:
293 // calls ahead of a processed definition
294 validateForwardCalls(STI, MRI, GR, MI);
295 break;
296
297 // ensure that LLVM IR bitwise instructions result in logical SPIR-V
298 // instructions when applied to bool type
// In each case below only the instruction descriptor is swapped; the
// operands are left as they are. The check looks at operand 1.
299 case SPIRV::OpBitwiseOrS:
300 case SPIRV::OpBitwiseOrV:
301 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
302 SPIRV::OpTypeBool))
303 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
304 break;
305 case SPIRV::OpBitwiseAndS:
306 case SPIRV::OpBitwiseAndV:
307 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
308 SPIRV::OpTypeBool))
309 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
310 break;
311 case SPIRV::OpBitwiseXorS:
312 case SPIRV::OpBitwiseXorV:
// Xor on bools is expressed as OpLogicalNotEqual.
313 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
314 SPIRV::OpTypeBool))
315 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
316 break;
317 }
318 }
319 }
// Remember this function so a repeated invocation returns early above.
320 ProcessedMF.insert(&MF);
// NOTE(review): listing line 321 is not present here — confirm against the
// original file what runs between the insert above and the closing brace.
322}
unsigned const MachineRegisterInfo * MRI
MachineBasicBlock & MBB
MachineBasicBlock MachineBasicBlock::iterator MBBI
Analysis containing CSE Info
Definition: CSEInfo.cpp:27
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
LLVMContext & Context
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVType *ResType, const Type *ResTy=nullptr)
This class represents a function call, abstracting a target machine's calling convention.
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Definition: LowLevelType.h:42
This is an important class for using LLVM in a threaded context.
Definition: LLVMContext.h:67
Machine Value Type.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
bool constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
Definition: MachineInstr.h:69
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
Definition: MachineInstr.h:546
const MachineBasicBlock * getParent() const
Definition: MachineInstr.h:329
unsigned getNumOperands() const
Returns the total number of operands.
Definition: MachineInstr.h:549
const MachineOperand & getOperand(unsigned i) const
Definition: MachineInstr.h:556
Flags
Flags values. These may be or'd together.
const GlobalValue * getGlobal() const
int64_t getImm() const
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
Definition: Register.h:19
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void addForwardCall(const Function *F, MachineInstr *MI)
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
bool isBitcastCompatible(const SPIRVType *Type1, const SPIRVType *Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVType * getOrCreateSPIRVPointerType(SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SClass=SPIRV::StorageClass::Function)
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
virtual unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const
Return the number of registers that this ValueType will eventually require.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:316
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
Definition: Error.cpp:156
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
Extended Value Type.
Definition: ValueTypes.h:34
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
Definition: ValueTypes.h:358
bool isVector() const
Return true if this is a vector value type.
Definition: ValueTypes.h:167
EVT getVectorElementType() const
Given a vector type, return the type of each element.
Definition: ValueTypes.h:318
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
Definition: ValueTypes.h:326
bool isInteger() const
Return true if this is an integer or a vector integer type.
Definition: ValueTypes.h:151