23#include "llvm/IR/IntrinsicsSPIRV.h"
26#define DEBUG_TYPE "spirv-prelegalizer"
54 cast<Constant>(cast<ConstantAsMetadata>(
55 MI.getOperand(3).getMetadata()->getOperand(0))
57 if (
auto *GV = dyn_cast<GlobalValue>(Const)) {
60 GR->
add(GV, &MF, SrcReg);
62 RegsAlreadyAddedToDT[&
MI] = Reg;
66 if (
auto *ConstVec = dyn_cast<ConstantDataVector>(Const)) {
67 auto *BuildVec =
MRI.getVRegDef(SrcReg);
69 BuildVec->getOpcode() == TargetOpcode::G_BUILD_VECTOR);
70 for (
unsigned i = 0; i < ConstVec->getNumElements(); ++i) {
73 Constant *ElemConst = ConstVec->getElementAsConstant(i);
76 GR->
add(ElemConst, &MF, BuildVec->getOperand(1 + i).getReg());
78 BuildVec->getOperand(1 + i).setReg(ElemReg);
81 GR->
add(Const, &MF, SrcReg);
82 if (Const->getType()->isTargetExtTy()) {
85 if (SrcMI && SrcMI->
getOpcode() == TargetOpcode::G_CONSTANT)
86 TargetExtConstTypes[SrcMI] = Const->getType();
89 RegsAlreadyAddedToDT[&
MI] = Reg;
92 assert(
MI.getOperand(2).isReg() &&
"Reg operand is expected");
94 if (SrcMI &&
isSpvIntrinsic(*SrcMI, Intrinsic::spv_const_composite))
103 Reg = RegsAlreadyAddedToDT[
MI];
104 auto *RC =
MRI.getRegClassOrNull(
MI->getOperand(0).getReg());
105 if (!
MRI.getRegClassOrNull(Reg) && RC)
106 MRI.setRegClass(Reg, RC);
107 MRI.replaceRegWith(
MI->getOperand(0).getReg(), Reg);
108 MI->eraseFromParent();
111 MI->eraseFromParent();
117 const unsigned AssignNameOperandShift = 2;
122 unsigned NumOp =
MI.getNumExplicitDefs() + AssignNameOperandShift;
123 while (
MI.getOperand(NumOp).isReg()) {
127 MI.removeOperand(NumOp);
136 MI->eraseFromParent();
176 MI->eraseFromParent();
197 assert(
MI &&
"Machine instr is expected");
198 if (
MI->getOperand(0).isReg()) {
202 switch (
MI->getOpcode()) {
203 case TargetOpcode::G_CONSTANT: {
205 Type *Ty =
MI->getOperand(1).getCImm()->getType();
209 case TargetOpcode::G_GLOBAL_VALUE: {
214 Global->getType()->getAddressSpace());
218 case TargetOpcode::G_ZEXT: {
219 if (
MI->getOperand(1).isReg()) {
221 MRI.getVRegDef(
MI->getOperand(1).getReg())) {
224 unsigned ExpectedBW =
225 std::max(
MRI.getType(Reg).getScalarSizeInBits(), CurrentBW);
236 case TargetOpcode::G_PTRTOINT:
238 MRI.getType(Reg).getScalarSizeInBits(), MIB);
240 case TargetOpcode::G_TRUNC:
241 case TargetOpcode::G_ADDRSPACE_CAST:
242 case TargetOpcode::G_PTR_ADD:
243 case TargetOpcode::COPY: {
255 if (!
MRI.getRegClassOrNull(Reg))
256 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
262static std::pair<Register, unsigned>
267 assert(SpvType &&
"VReg is expected to have SPIRV type");
268 LLT SrcLLT =
MRI.getType(SrcReg);
270 bool IsFloat = SpvType->
getOpcode() == SPIRV::OpTypeFloat;
272 SpvType->
getOpcode() == SPIRV::OpTypeVector &&
275 IsFloat |= IsVectorFloat;
276 auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
277 auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
286 GetIdOp = SPIRV::GET_vpID64;
287 DstClass = &SPIRV::vpID64RegClass;
289 GetIdOp = SPIRV::GET_pID64;
290 DstClass = &SPIRV::pID64RegClass;
294 GetIdOp = SPIRV::GET_vpID32;
295 DstClass = &SPIRV::vpID32RegClass;
297 GetIdOp = SPIRV::GET_pID32;
298 DstClass = &SPIRV::pID32RegClass;
304 GetIdOp = SPIRV::GET_vfID;
305 DstClass = &SPIRV::vfIDRegClass;
307 GetIdOp = SPIRV::GET_vID;
308 DstClass = &SPIRV::vIDRegClass;
311 Register IdReg =
MRI.createGenericVirtualRegister(NewT);
312 MRI.setRegClass(IdReg, DstClass);
313 return {IdReg, GetIdOp};
326 assert((Ty || SpirvTy) &&
"Either LLVM or SPIRV type is expected.");
328 (Def->getNextNode() ? Def->getNextNode()->getIterator()
329 : Def->getParent()->end()));
331 Register NewReg =
MRI.createGenericVirtualRegister(
MRI.getType(Reg));
332 if (
auto *RC =
MRI.getRegClassOrNull(Reg)) {
333 MRI.setRegClass(NewReg, RC);
335 MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
336 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
344 const uint32_t Flags = Def->getFlags();
350 Def->getOperand(0).setReg(NewReg);
356 assert(
MI.getNumDefs() > 0 &&
MRI.hasOneUse(
MI.getOperand(0).getReg()));
358 *(
MRI.use_instr_begin(
MI.getOperand(0).getReg()));
362 MI.getOperand(0).setReg(NewReg);
364 (
MI.getNextNode() ?
MI.getNextNode()->getIterator()
365 :
MI.getParent()->end()));
366 for (
auto &
Op :
MI.operands()) {
367 if (!
Op.isReg() ||
Op.isDef())
371 Op.setReg(IdOpInfo.first);
391 bool ReachedBegin =
false;
405 assert(Def &&
"Expecting an instruction that defines the register");
407 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
415 assert(Def &&
"Expecting an instruction that defines the register");
417 if (Def->getOpcode() != TargetOpcode::G_GLOBAL_VALUE)
420 }
else if (
MI.getOpcode() == TargetOpcode::G_CONSTANT ||
421 MI.getOpcode() == TargetOpcode::G_FCONSTANT ||
422 MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR) {
429 if (
MRI.hasOneUse(Reg)) {
436 if (
MI.getOpcode() == TargetOpcode::G_CONSTANT) {
437 auto TargetExtIt = TargetExtConstTypes.
find(&
MI);
438 Ty = TargetExtIt == TargetExtConstTypes.
end()
439 ?
MI.getOperand(1).getCImm()->getType()
440 : TargetExtIt->second;
441 }
else if (
MI.getOpcode() == TargetOpcode::G_FCONSTANT) {
442 Ty =
MI.getOperand(1).getFPImm()->getType();
444 assert(
MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
445 Type *ElemTy =
nullptr;
449 if (ElemMI->
getOpcode() == TargetOpcode::G_CONSTANT)
451 else if (ElemMI->
getOpcode() == TargetOpcode::G_FCONSTANT)
456 MI.getNumExplicitOperands() -
MI.getNumExplicitDefs();
457 Ty = VectorType::get(ElemTy, NumElts,
false);
460 }
else if (
MI.getOpcode() == TargetOpcode::G_TRUNC ||
461 MI.getOpcode() == TargetOpcode::G_ZEXT ||
462 MI.getOpcode() == TargetOpcode::G_PTRTOINT ||
463 MI.getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
464 MI.getOpcode() == TargetOpcode::COPY ||
465 MI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
476 MI->eraseFromParent();
498 if (
MI.getOpcode() != SPIRV::ASSIGN_TYPE)
501 unsigned Opcode =
MRI.getVRegDef(SrcReg)->getOpcode();
505 bool IsDstPtr =
MRI.getType(DstReg).isPointer();
506 bool isDstVec =
MRI.getType(DstReg).isVector();
507 if (IsDstPtr || isDstVec)
508 MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
511 if (Opcode == TargetOpcode::G_CONSTANT &&
MRI.hasOneUse(DstReg)) {
513 if (
UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
537 for (
unsigned i = 2; i <
MI.getNumOperands(); ++i) {
545 BuildMBB->
getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
556 for (
auto &SwIt : Switches) {
560 for (
unsigned i = 0; i < Ins.size(); ++i) {
561 if (Ins[i]->
getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
563 Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
564 auto It = BB2MBB.
find(CaseBB);
565 if (It == BB2MBB.
end())
567 "block in a switch statement");
569 MI.getParent()->addSuccessor(It->second);
576 for (
unsigned i =
MI.getNumOperands() - 1; i > 1; --i)
578 for (
auto &MO : NewOps)
583 Next =
MI.getNextNode();
585 if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
590 BlockAddrI->eraseFromParent();
631 GR->setCurrentFunc(MF);
649char SPIRVPreLegalizer::
ID = 0;
652 return new SPIRVPreLegalizer();
unsigned const MachineRegisterInfo * MRI
MachineInstrBuilder & UseMI
This file contains the simple types necessary to represent the attributes associated with functions a...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
static void removeImplicitFallthroughs(MachineFunction &MF, MachineIRBuilder MIB)
static bool isImplicitFallthrough(MachineBasicBlock &MBB)
bool isTypeFoldingSupported(unsigned Opcode)
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB)
static void processInstrsWithTypeFolding(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB)
static SPIRVType * propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR, MachineRegisterInfo &MRI, MachineIRBuilder &MIB)
static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB)
static std::pair< Register, unsigned > createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI, const SPIRVGlobalRegistry &GR)
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder MIB, DenseMap< MachineInstr *, Type * > &TargetExtConstTypes)
static void addConstantsToTrack(MachineFunction &MF, SPIRVGlobalRegistry *GR, DenseMap< MachineInstr *, Type * > &TargetExtConstTypes)
static void foldConstantsIntoIntrinsics(MachineFunction &MF)
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
LLVM Basic Block Representation.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
This is an important base class in LLVM.
This class represents an Operation in the Expression.
iterator find(const_arg_type_t< KeyT > Val)
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
FunctionPass class - This class is used to implement most global optimizations.
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
constexpr bool isVector() const
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits)
Get a low-level pointer in the given address space.
constexpr bool isPointer() const
static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits)
Get a low-level fixed-width vector of some number of elements and element width.
const BasicBlock * getBasicBlock() const
Return the LLVM basic block that this instance corresponded to originally.
bool canFallThrough()
Return true if the block can implicitly transfer control to the block after it by falling off the end...
iterator_range< succ_iterator > successors()
reverse_iterator rbegin()
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
virtual bool runOnMachineFunction(MachineFunction &MF)=0
runOnMachineFunction - This method must be overloaded to perform the desired machine code transformat...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Helper class to build MachineInstr.
MachineInstrBuilder buildBr(MachineBasicBlock &Dest)
Build and insert G_BR Dest.
void setInsertPt(MachineBasicBlock &MBB, MachineBasicBlock::iterator II)
Set the insertion point before the specified position.
MachineInstrBuilder buildInstr(unsigned Opcode)
Build and insert <empty> = Opcode <empty>.
MachineFunction & getMF()
Getter for the function we currently build.
MachineInstrBuilder buildBitcast(const DstOp &Dst, const SrcOp &Src)
Build and insert Dst = G_BITCAST Src.
MachineRegisterInfo * getMRI()
Getter for MRI.
const MachineInstrBuilder & addUse(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register use operand.
const MachineInstrBuilder & setMIFlags(unsigned Flags) const
const MachineInstrBuilder & addDef(Register RegNo, unsigned Flags=0, unsigned SubReg=0) const
Add a virtual register definition operand.
MachineBasicBlock iterator that automatically skips over MIs that are inside bundles (i....
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineOperand & getOperand(unsigned i) const
MachineOperand class - Representation of each machine instruction operand.
const ConstantInt * getCImm() const
static MachineOperand CreateCImm(const ConstantInt *CI)
void setReg(Register Reg)
Change the register this operand corresponds to.
const BlockAddress * getBlockAddress() const
static MachineOperand CreateImm(int64_t Val)
bool isBlockAddress() const
isBlockAddress - Tests if this is a MO_BlockAddress operand.
Register getReg() const
getReg - Returns the register number.
const ConstantFP * getFPImm() const
static MachineOperand CreateMBB(MachineBasicBlock *MBB, unsigned TargetFlags=0)
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
void replaceRegWith(Register FromReg, Register ToReg)
replaceRegWith - Replace all instances of FromReg with ToReg in the machine function.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
Wrapper class representing virtual and physical registers.
constexpr bool isValid() const
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void add(const Constant *C, MachineFunction *MF, Register R)
unsigned getScalarOrVectorComponentCount(Register VReg) const
unsigned getPointerSize() const
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
SPIRVType * getOrCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ=SPIRV::AccessQualifier::ReadWrite, bool EmitIR=true)
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, MachineFunction &MF)
Register find(const MachineInstr *MI, MachineFunction *MF)
SPIRVType * getOrCreateSPIRVPointerType(SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SClass=SPIRV::StorageClass::Function)
SPIRVType * getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder)
SPIRVType * getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
Type * getDeducedGlobalValueType(const GlobalValue *Global)
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const
const SPIRVInstrInfo * getInstrInfo() const override
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
The instances of the Type class are immutable: once they are created, they are never changed.
static TypedPointerType * get(Type *ElementType, unsigned AddressSpace)
This constructs a pointer to an object of the specified type in a numbered address space.
Type * getType() const
All values are typed, get the type of this value.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
This is an optimization pass for GlobalISel generic memory operations.
FunctionPass * createSPIRVPreLegalizerPass()
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, MachineRegisterInfo &MRI)
Helper external function for inserting ASSIGN_TYPE instuction between Reg and its definition,...
iterator_range< po_iterator< T > > post_order(const T &G)
void report_fatal_error(Error Err, bool gen_crash_diag=true)
Report a serious error, calling any installed error handler.
@ Global
Append to llvm.global_dtors.
SPIRV::StorageClass::StorageClass addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI)
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR)
MachineInstr * getDefInstrMaybeConstant(Register &ConstReg, const MachineRegisterInfo *MRI)
Type * getMDOperandAsType(const MDNode *N, unsigned I)
void initializeSPIRVPreLegalizerPass(PassRegistry &)
bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID)