23#include "llvm/IR/IntrinsicsSPIRV.h"
25#define DEBUG_TYPE "spirv-lower"
51 if (Ty1->
getOpcode() == SPIRV::OpTypeArray) {
60 return ElemType1 == ElemType2 ||
64 if (Ty1->
getOpcode() == SPIRV::OpTypeStruct) {
70 if (ElemType1 != ElemType2 &&
112 unsigned AlignIdx = 3;
114 case Intrinsic::spv_load:
117 case Intrinsic::spv_store: {
118 if (
I.getNumOperands() >= AlignIdx + 1) {
120 Info.align =
Align(AlignOp->getZExtValue());
124 Info.memVT = MVT::i64;
135std::pair<unsigned, const TargetRegisterClass *>
141 return std::make_pair(0u, RC);
144 RC = VT.
isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
146 RC = VT.
isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
148 RC = &SPIRV::iIDRegClass;
150 return std::make_pair(0u, RC);
155 return Inst && Inst->
getOpcode() == SPIRV::OpFunctionParameter
172 I.getOperand(
OpIdx).setReg(NewReg);
179 SPIRV::StorageClass::StorageClass SC =
180 static_cast<SPIRV::StorageClass::StorageClass
>(
181 OpType->getOperand(1).
getImm());
186 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite,
false);
196 const Type *ResTy =
nullptr) {
202 if (!ResType || !OpType || OpType->
getOpcode() != SPIRV::OpTypePointer)
205 Register ElemTypeReg = OpType->getOperand(2).getReg();
211 bool IsEqualTypes = IsSameMF ? ElemType == ResType
221 "insert validation bitcast: incompatible result and operand types");
231 constexpr unsigned OpIdx = 2;
236 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
240 if (!ElemType || ElemType->
getOpcode() == SPIRV::OpTypeEvent)
253 Register PtrReg =
I.getOperand(0).getReg();
258 if (!PonteeElemType || PonteeElemType->
getOpcode() == SPIRV::OpTypeVoid ||
259 (PonteeElemType->
getOpcode() == SPIRV::OpTypeInt &&
263 SPIRV::StorageClass::StorageClass SC =
264 static_cast<SPIRV::StorageClass::StorageClass
>(
281 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
285 if (!ElemType || ElemType->
getOpcode() != SPIRV::OpTypeStruct ||
293 unsigned MemberTypeOp = MemberType->
getOpcode();
294 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
295 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
299 SPIRV::StorageClass::StorageClass SC =
300 static_cast<SPIRV::StorageClass::StorageClass
>(
301 OpType->getOperand(1).
getImm());
321 if (FunDef->
getOpcode() != SPIRV::OpFunction)
325 FunDef && FunDef->
getOpcode() == SPIRV::OpFunctionParameter &&
331 DefPtrType && DefPtrType->
getOpcode() == SPIRV::OpTypePointer
377 &FunCall->getParent()->getParent()->getRegInfo();
386 if (BaseTypeInst && BaseTypeInst->
getOpcode() == SPIRV::OpTypePointer) {
398 if (ProcessedMF.find(&MF) != ProcessedMF.end())
409 switch (
MI.getOpcode()) {
410 case SPIRV::OpAtomicLoad:
411 case SPIRV::OpAtomicExchange:
412 case SPIRV::OpAtomicCompareExchange:
413 case SPIRV::OpAtomicCompareExchangeWeak:
414 case SPIRV::OpAtomicIIncrement:
415 case SPIRV::OpAtomicIDecrement:
416 case SPIRV::OpAtomicIAdd:
417 case SPIRV::OpAtomicISub:
418 case SPIRV::OpAtomicSMin:
419 case SPIRV::OpAtomicUMin:
420 case SPIRV::OpAtomicSMax:
421 case SPIRV::OpAtomicUMax:
422 case SPIRV::OpAtomicAnd:
423 case SPIRV::OpAtomicOr:
424 case SPIRV::OpAtomicXor:
436 case SPIRV::OpAtomicStore:
447 case SPIRV::OpPtrCastToGeneric:
448 case SPIRV::OpGenericCastToPtr:
449 case SPIRV::OpGenericCastToPtrExplicit:
452 case SPIRV::OpPtrAccessChain:
453 case SPIRV::OpInBoundsPtrAccessChain:
454 if (
MI.getNumOperands() == 4)
458 case SPIRV::OpFunctionCall:
461 if (
MI.getNumOperands() > 3)
465 case SPIRV::OpFunction:
479 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
484 case SPIRV::OpBitwiseOrS:
485 case SPIRV::OpBitwiseOrV:
488 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr));
490 case SPIRV::OpBitwiseAndS:
491 case SPIRV::OpBitwiseAndV:
494 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd));
496 case SPIRV::OpBitwiseXorS:
497 case SPIRV::OpBitwiseXorV:
500 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
502 case SPIRV::OpLifetimeStart:
503 case SPIRV::OpLifetimeStop:
504 if (
MI.getOperand(1).getImm() > 0)
507 case SPIRV::OpGroupAsyncCopy:
511 case SPIRV::OpGroupWaitEvents:
515 case SPIRV::OpConstantI: {
517 if (
Type->getOpcode() != SPIRV::OpTypeInt &&
MI.getOperand(2).isImm() &&
518 MI.getOperand(2).getImm() == 0) {
520 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull));
521 for (
unsigned i =
MI.getNumOperands() - 1; i > 1; --i)
525 case SPIRV::OpExtInst: {
527 if (!
MI.getOperand(2).isImm() || !
MI.getOperand(3).isImm() ||
528 MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
530 switch (
MI.getOperand(3).getImm()) {
531 case SPIRV::OpenCLExtInst::frexp:
532 case SPIRV::OpenCLExtInst::lgamma_r:
533 case SPIRV::OpenCLExtInst::remquo: {
539 assert(RetType &&
"Expected return type");
541 RetType->
getOpcode() != SPIRV::OpTypeVector
547 case SPIRV::OpenCLExtInst::fract:
548 case SPIRV::OpenCLExtInst::modf:
549 case SPIRV::OpenCLExtInst::sincos:
552 assert(
MI.getOperand(
MI.getNumOperands() - 2).isReg() &&
555 STI,
MRI, GR,
MI,
MI.getNumOperands() - 1,
557 MI.getOperand(
MI.getNumOperands() - 2).getReg()));
559 case SPIRV::OpenCLExtInst::prefetch:
562 assert(
MI.getOperand(
MI.getNumOperands() - 2).isReg() &&
565 MI.getNumOperands() - 2);
586 if (PointeeType == OpType)
591 if (
I.getOperand(
OpIdx).isDef() &&
618 OldResult.
setReg(NewResultReg);
619 OldType.
setReg(NewTypeReg);
627 *STI.getRegBankInfo());
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator MBBI
Register const TargetRegisterInfo * TRI
MachineInstr unsigned OpIdx
static bool typesLogicallyMatch(const SPIRVTypeInst Ty1, const SPIRVTypeInst Ty2, SPIRVGlobalRegistry &GR)
static void validateLifetimeStart(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrTypes(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx, SPIRVTypeInst ResType, const Type *ResTy=nullptr)
static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, unsigned OpIdx)
Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg)
void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I)
void validateFunCallMachineDef(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall, MachineInstr *FunDef)
void validateForwardCalls(const SPIRVSubtarget &STI, MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunDef)
const Function * validateFunCall(const SPIRVSubtarget &STI, MachineRegisterInfo *CallMRI, SPIRVGlobalRegistry &GR, MachineInstr &FunCall)
static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, MachineInstr &I, Register OpReg, unsigned OpIdx, SPIRVTypeInst NewPtrType)
static SPIRVTypeInst createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, SPIRVTypeInst OpType, bool ReuseType, SPIRVTypeInst ResType, const Type *ResTy)
This file describes how to lower LLVM code to machine code.
an instruction that atomically reads a memory location, combines it with another value,...
@ UIncWrap
Increment one up to a maximum value.
@ FMin
*p = minnum(old, v) minnum matches the behavior of llvm.minnum.
@ FMax
*p = maxnum(old, v) maxnum matches the behavior of llvm.maxnum.
@ UDecWrap
Decrement one until a minimum value or zero.
BinOp getOperation() const
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
This is an important class for using LLVM in a threaded context.
bool isVector() const
Return true if this is a vector value type.
bool isInteger() const
Return true if this is an integer or a vector integer type.
bool isFloatingPoint() const
Return true if this is a FP or a vector FP type.
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineInstrBundleIterator< MachineInstr > iterator
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
BasicBlockListType::iterator iterator
void insert(iterator MBBI, MachineBasicBlock *MBB)
Helper class to build MachineInstr.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
void constrainAllUses(const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI) const
const MachineInstrBuilder & addUse(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & addDef(Register RegNo, RegState Flags={}, unsigned SubReg=0) const
Add a virtual register definition operand.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Retuns the total number of operands.
const MachineOperand & getOperand(unsigned i) const
Flags
Flags values. These may be or'd together.
MachineOperand class - Representation of each machine instruction operand.
const GlobalValue * getGlobal() const
LLVM_ABI void setReg(Register Reg)
Change the register this operand corresponds to.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
Wrapper class representing virtual and physical registers.
void addForwardCall(const Function *F, MachineInstr *MI)
SPIRVTypeInst getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
SPIRVTypeInst getOrCreateSPIRVVectorType(SPIRVTypeInst BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)
SPIRVTypeInst getResultType(Register VReg, MachineFunction *MF=nullptr)
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
bool isBitcastCompatible(SPIRVTypeInst Type1, SPIRVTypeInst Type2) const
const MachineInstr * getFunctionDefinition(const Function *F)
SPIRVTypeInst getOrCreateSPIRVPointerType(const Type *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC)
Register getSPIRVTypeID(SPIRVTypeInst SpirvType) const
SPIRVTypeInst getPointeeType(SPIRVTypeInst PtrType)
SmallPtrSet< MachineInstr *, 8 > * getForwardCalls(const Function *F)
SPIRVTypeInst getOrCreateSPIRVType(const Type *Type, MachineInstr &I, SPIRV::AccessQualifier::AccessQualifier AQ, bool EmitIR)
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const
MachineFunction * setCurrentFunc(MachineFunction &MF)
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Function * getFunctionByDefinition(const MachineInstr *MI)
const SPIRVInstrInfo * getInstrInfo() const override
const SPIRVRegisterInfo * getRegisterInfo() const override
const RegisterBankInfo * getRegBankInfo() const override
AtomicExpansionKind shouldCastAtomicRMWIInIR(AtomicRMWInst *RMWI) const override
Returns how the given atomic atomicrmw should be cast by the IR-level AtomicExpand pass.
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx, unsigned OpIdx) const
unsigned getNumRegisters(LLVMContext &Context, EVT VT, std::optional< MVT > RegisterVT=std::nullopt) const override
Return the number of registers that this ValueType will eventually require.
unsigned getNumRegistersForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain targets require unusual breakdowns of certain types.
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC, EVT VT) const override
Certain combinations of ABIs, Targets and features require that types are legal for some operations a...
AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const override
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void finalizeLowering(MachineFunction &MF) const override
Execute target specific actions to finalize target lowering.
void getTgtMemIntrinsic(SmallVectorImpl< IntrinsicInfo > &Infos, const CallBase &I, MachineFunction &MF, unsigned Intrinsic) const override
Given an intrinsic, checks if on the target the intrinsic will need to map to a MemIntrinsicNode (tou...
bool insertLogicalCopyOnResult(MachineInstr &I, SPIRVTypeInst NewResultType) const
SPIRVTargetLowering(const TargetMachine &TM, const SPIRVSubtarget &ST)
std::pair< unsigned, const TargetRegisterClass * > getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const override
Given a physical register constraint (e.g.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
void push_back(const T &Elt)
StringRef - Represent a constant reference to a string, i.e.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
static LLVM_ABI TargetExtType * get(LLVMContext &Context, StringRef Name, ArrayRef< Type * > Types={}, ArrayRef< unsigned > Ints={})
Return a target extension type having the specified name and optional type and integer parameters.
virtual void finalizeLowering(MachineFunction &MF) const
Execute target specific actions to finalize target lowering.
virtual AtomicExpansionKind shouldExpandAtomicRMWInIR(const AtomicRMWInst *RMW) const
Returns how the IR-level AtomicExpand pass should expand the given AtomicRMW, if at all.
void setMaxAtomicSizeInBitsSupported(unsigned SizeInBits)
Set the maximum atomic operation size supported by the backend.
void setMinCmpXchgSizeInBits(unsigned SizeInBits)
Sets the minimum cmpxchg or ll/sc size supported by the backend.
AtomicExpansionKind
Enum that specifies what an atomic load/AtomicRMWInst is expanded to, if at all.
MVT getRegisterType(MVT VT) const
Return the type of registers that this ValueType will eventually require.
TargetLowering(const TargetLowering &)=delete
Primary interface to the complete machine description for the target machine.
TargetRegisterInfo base class - We assume that the target defines a static array of TargetRegisterDes...
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
This namespace contains an enum with a value for every intrinsic/builtin function known by LLVM.
This is an optimization pass for GlobalISel generic memory operations.
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Register createVirtualRegister(SPIRVTypeInst SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
This struct is a compact representation of a valid (non-zero power of two) alignment.
TypeSize getSizeInBits() const
Return the size of the specified value type in bits.
bool isVector() const
Return true if this is a vector value type.
EVT getVectorElementType() const
Given a vector type, return the type of each element.
unsigned getVectorNumElements() const
Given a vector type, return the number of elements it contains.
bool isInteger() const
Return true if this is an integer or a vector integer type.