LLVM 22.0.0git
SPIRVPostLegalizer.cpp
Go to the documentation of this file.
1//===-- SPIRVPostLegalizer.cpp - amend info after legalization -*- C++ -*-====//
2//
3// Amends info about instructions which may appear after the legalizer pass.
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10//
11// The pass partially applies pre-legalization logic to new instructions inserted
12// as a result of legalization:
13// - assigns SPIR-V types to registers for new instructions.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRV.h"
18#include "SPIRVSubtarget.h"
19#include "SPIRVUtils.h"
20#include "llvm/IR/IntrinsicsSPIRV.h"
21#include "llvm/Support/Debug.h"
22#include <stack>
23
24#define DEBUG_TYPE "spirv-postlegalizer"
25
26using namespace llvm;
27
28namespace {
29class SPIRVPostLegalizer : public MachineFunctionPass {
30public:
31 static char ID;
32 SPIRVPostLegalizer() : MachineFunctionPass(ID) {}
33 bool runOnMachineFunction(MachineFunction &MF) override;
34};
35} // namespace
36
37namespace llvm {
38// Defined in SPIRVPreLegalizer.cpp.
39extern void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
44 SPIRVType *KnownResType);
45} // namespace llvm
46
50 const LLT &Ty = MIB.getMRI()->getType(ResVReg);
51 return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
52}
53
57 Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
58 SPIRVType *ScalarType = nullptr;
59 if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
60 assert(DefType->getOpcode() == SPIRV::OpTypeVector);
61 ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
62 }
63
64 if (!ScalarType) {
65 // If we could not deduce the type from the source, try to deduce it from
66 // the uses of the results.
67 for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
68 for (const auto &Use :
69 MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
70 assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
71 "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
72 if (auto *VecType =
73 GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
74 ScalarType = GR->getScalarOrVectorComponentType(VecType);
75 break;
76 }
77 }
78 }
79 }
80
81 if (!ScalarType)
82 return false;
83
84 for (unsigned i = 0; i < I->getNumDefs(); ++i) {
85 Register DefReg = I->getOperand(i).getReg();
86 if (GR->getSPIRVTypeForVReg(DefReg))
87 continue;
88
89 LLT DefLLT = MRI.getType(DefReg);
90 SPIRVType *ResType =
91 DefLLT.isVector()
93 ScalarType, DefLLT.getNumElements(), *I,
95 : ScalarType;
96 setRegClassType(DefReg, ResType, GR, &MRI, MF);
97 }
98 return true;
99}
100
102 MachineIRBuilder &MIB,
104 unsigned OpIdx) {
105 Register OpReg = I->getOperand(OpIdx).getReg();
106 if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(OpReg)) {
107 if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
108 Register ResVReg = I->getOperand(0).getReg();
109 const LLT &ResLLT = MIB.getMRI()->getType(ResVReg);
110 if (ResLLT.isVector())
111 return GR->getOrCreateSPIRVVectorType(CompType, ResLLT.getNumElements(),
112 MIB, false);
113 return CompType;
114 }
115 }
116 return nullptr;
117}
118
120 MachineIRBuilder &MIB,
122 unsigned StartOp, unsigned EndOp) {
123 SPIRVType *ResType = nullptr;
124 for (unsigned i = StartOp; i < EndOp; ++i) {
125 if (SPIRVType *Type = deduceTypeFromSingleOperand(I, MIB, GR, i)) {
126#ifdef EXPENSIVE_CHECKS
127 assert(!ResType || Type == ResType && "Conflicting type from operands.");
128 ResType = Type;
129#else
130 return Type;
131#endif
132 }
133 }
134 return ResType;
135}
136
138 Register UseRegister,
140 MachineIRBuilder &MIB) {
141 for (const MachineOperand &MO : Use->defs()) {
142 if (!MO.isReg())
143 continue;
144 if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(MO.getReg())) {
145 if (SPIRVType *CompType = GR->getScalarOrVectorComponentType(OpType)) {
146 const LLT &ResLLT = MIB.getMRI()->getType(UseRegister);
147 if (ResLLT.isVector())
149 CompType, ResLLT.getNumElements(), MIB, false);
150 return CompType;
151 }
152 }
153 }
154 return nullptr;
155}
156
159 MachineIRBuilder &MIB) {
161 for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
162 SPIRVType *ResType = nullptr;
163 switch (Use.getOpcode()) {
164 case TargetOpcode::G_BUILD_VECTOR:
165 case TargetOpcode::G_EXTRACT_VECTOR_ELT:
166 case TargetOpcode::G_UNMERGE_VALUES:
167 LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
168 ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
169 break;
170 }
171 if (ResType)
172 return ResType;
173 }
174 return nullptr;
175}
176
179 MachineIRBuilder &MIB) {
180 Register ResVReg = I->getOperand(0).getReg();
181 switch (I->getOpcode()) {
182 case TargetOpcode::G_CONSTANT:
183 case TargetOpcode::G_ANYEXT:
184 return deduceIntTypeFromResult(ResVReg, MIB, GR);
185 case TargetOpcode::G_BUILD_VECTOR:
186 return deduceTypeFromOperandRange(I, MIB, GR, 1, I->getNumOperands());
187 case TargetOpcode::G_SHUFFLE_VECTOR:
188 return deduceTypeFromOperandRange(I, MIB, GR, 1, 3);
189 default:
190 if (I->getNumDefs() == 1 && I->getNumOperands() > 1 &&
191 I->getOperand(1).isReg())
192 return deduceTypeFromSingleOperand(I, MIB, GR, 1);
193 return nullptr;
194 }
195}
196
199 MachineIRBuilder &MIB) {
200 LLVM_DEBUG(dbgs() << "\nProcessing instruction: " << *I);
202 Register ResVReg = I->getOperand(0).getReg();
203
204 // G_UNMERGE_VALUES is handled separately because it has multiple definitions,
205 // unlike the other instructions which have a single result register. The main
206 // deduction logic is designed for the single-definition case.
207 if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
208 return deduceAndAssignTypeForGUnmerge(I, MF, GR);
209
210 LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
211 SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB);
212 if (!ResType) {
213 LLVM_DEBUG(dbgs() << "Inferring type from uses\n");
214 ResType = deduceTypeFromUses(ResVReg, MF, GR, MIB);
215 }
216
217 if (!ResType)
218 return false;
219
220 LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType);
221 GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF);
222
223 if (!MRI.getRegClassOrNull(ResVReg)) {
224 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
225 setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true);
226 }
227 return true;
228}
229
232 LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: "
233 << I;);
234 if (I.getNumDefs() == 0) {
235 LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n");
236 return false;
237 }
238
239 if (!I.isPreISelOpcode()) {
240 LLVM_DEBUG(dbgs() << "Instruction is not a generic instruction.\n");
241 return false;
242 }
243
244 Register ResultRegister = I.defs().begin()->getReg();
245 if (GR->getSPIRVTypeForVReg(ResultRegister)) {
246 LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n");
247 if (!MRI.getRegClassOrNull(ResultRegister)) {
248 LLVM_DEBUG(dbgs() << "Updating the register class.\n");
249 setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister),
250 GR, &MRI, *GR->CurMF, true);
251 }
252 return false;
253 }
254
255 return true;
256}
257
262 for (MachineBasicBlock &MBB : MF) {
263 for (MachineInstr &I : MBB) {
264 if (requiresSpirvType(I, GR, MRI)) {
265 Worklist.push_back(&I);
266 }
267 }
268 }
269
270 if (Worklist.empty()) {
271 LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n");
272 return;
273 }
274
275 LLVM_DEBUG(dbgs() << "Initial worklist:\n";
276 for (auto *I : Worklist) { I->dump(); });
277
278 bool Changed;
279 do {
280 Changed = false;
282
283 for (MachineInstr *I : Worklist) {
284 MachineIRBuilder MIB(*I);
285 if (deduceAndAssignSpirvType(I, MF, GR, MIB)) {
286 Changed = true;
287 } else {
288 NextWorklist.push_back(I);
289 }
290 }
291 Worklist = std::move(NextWorklist);
292 LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n");
293 } while (Changed);
294
295 if (Worklist.empty())
296 return;
297
298 for (auto *I : Worklist) {
299 MachineIRBuilder MIB(*I);
300 Register ResVReg = I->getOperand(0).getReg();
301 const LLT &ResLLT = MRI.getType(ResVReg);
302 SPIRVType *ResType = nullptr;
303 if (ResLLT.isVector()) {
305 ResLLT.getElementType().getSizeInBits(), MIB);
306 ResType = GR->getOrCreateSPIRVVectorType(
307 CompType, ResLLT.getNumElements(), MIB, false);
308 } else {
309 ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
310 }
311 LLVM_DEBUG(dbgs() << "Could not determine type for " << *I
312 << ", defaulting to " << *ResType << "\n");
313 setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
314 }
315}
316
319 LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function "
320 << MF.getName() << "\n");
322 for (MachineBasicBlock &MBB : MF) {
323 for (MachineInstr &MI : MBB) {
324 if (!isTypeFoldingSupported(MI.getOpcode()))
325 continue;
326 if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg())
327 continue;
328
329 LLVM_DEBUG(dbgs() << "Processing instruction: " << MI);
330
331 // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE
332 bool HasAssignType = false;
333 Register ResultRegister = MI.defs().begin()->getReg();
334 // All uses of Result register
335 for (MachineInstr &UseInstr :
336 MRI.use_nodbg_instructions(ResultRegister)) {
337 if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) {
338 HasAssignType = true;
339 LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: "
340 << UseInstr);
341 break;
342 }
343 }
344
345 if (!HasAssignType) {
346 Register ResultRegister = MI.defs().begin()->getReg();
347 SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister);
349 dbgs() << " Adding ASSIGN_TYPE for ResultRegister: "
350 << printReg(ResultRegister, MRI.getTargetRegisterInfo())
351 << " with type: " << *ResultType);
352 MachineIRBuilder MIB(MI);
353 insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI);
354 }
355 }
356 }
357}
358
359// Do a preorder traversal of the CFG starting from the BB |Start|.
360// point. Calls |op| on each basic block encountered during the traversal.
362 std::function<void(MachineBasicBlock *)> op) {
363 std::stack<MachineBasicBlock *> ToVisit;
365
366 ToVisit.push(&Start);
367 Seen.insert(ToVisit.top());
368 while (ToVisit.size() != 0) {
369 MachineBasicBlock *MBB = ToVisit.top();
370 ToVisit.pop();
371
372 op(MBB);
373
374 for (auto Succ : MBB->successors()) {
375 if (Seen.contains(Succ))
376 continue;
377 ToVisit.push(Succ);
378 Seen.insert(Succ);
379 }
380 }
381}
382
383// Do a preorder traversal of the CFG starting from the given function's entry
384// point. Calls |op| on each basic block encountered during the traversal.
385void visit(MachineFunction &MF, std::function<void(MachineBasicBlock *)> op) {
386 visit(MF, *MF.begin(), std::move(op));
387}
388
389bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
390 // Initialize the type registry.
391 const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
392 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
393 GR->setCurrentFunc(MF);
396 return true;
397}
398
399INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
400 false)
401
402char SPIRVPostLegalizer::ID = 0;
403
404FunctionPass *llvm::createSPIRVPostLegalizerPass() {
405 return new SPIRVPostLegalizer();
406}
unsigned const MachineRegisterInfo * MRI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock & MBB
#define DEBUG_TYPE
#define op(i)
IRTranslator LLVM IR MI
#define I(x, y, z)
Definition MD5.cpp:57
Register Reg
MachineInstr unsigned OpIdx
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB)
static SPIRVType * deduceIntTypeFromResult(Register ResVReg, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR)
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
static void registerSpirvTypeForNewInstructions(MachineFunction &MF, SPIRVGlobalRegistry *GR)
static SPIRVType * deduceResultTypeFromOperands(MachineInstr *I, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB)
static SPIRVType * deduceTypeFromOperandRange(MachineInstr *I, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR, unsigned StartOp, unsigned EndOp)
static SPIRVType * deduceTypeFromUses(Register Reg, MachineFunction &MF, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB)
static SPIRVType * deduceTypeForResultRegister(MachineInstr *Use, Register UseRegister, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB)
static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR)
static void ensureAssignTypeForTypeFolding(MachineFunction &MF, SPIRVGlobalRegistry *GR)
static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, MachineRegisterInfo &MRI)
static SPIRVType * deduceTypeFromSingleOperand(MachineInstr *I, MachineIRBuilder &MIB, SPIRVGlobalRegistry *GR, unsigned OpIdx)
#define LLVM_DEBUG(...)
Definition Debug.h:114
constexpr uint16_t getNumElements() const
Returns the number of elements in a vector LLT.
constexpr bool isVector() const
constexpr TypeSize getSizeInBits() const
Returns the total size of the type. Must only be called on sized types.
constexpr LLT getElementType() const
Returns the vector's element type. Only valid for vector types.
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
const TargetSubtargetInfo & getSubtarget() const
getSubtarget - Return the subtarget for which this machine code is being compiled.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Helper class to build MachineInstr.
MachineRegisterInfo * getMRI()
Getter for MRI.
Representation of each machine instruction.
MachineOperand class - Representation of each machine instruction operand.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
LLT getType(Register Reg) const
Get the low-level type of Reg or LLT{} if Reg is not a generic (target independent) virtual register.
Wrapper class representing virtual and physical registers.
Definition Register.h:20
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg, const MachineFunction &MF)
SPIRVType * getScalarOrVectorComponentType(Register VReg) const
SPIRVType * getOrCreateSPIRVVectorType(SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder, bool EmitIR)
SPIRVType * getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineIRBuilder &MIRBuilder)
const SPIRVInstrInfo * getInstrInfo() const override
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
Changed
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
This is an optimization pass for GlobalISel generic memory operations.
void insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, MachineRegisterInfo &MRI)
Helper external function for inserting ASSIGN_TYPE instuction between Reg and its definition,...
bool isTypeFoldingSupported(unsigned Opcode)
void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR, SPIRVType *KnownResType)
FunctionPass * createSPIRVPostLegalizerPass()
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI, const MachineFunction &MF, bool Force)
const MachineInstr SPIRVType
LLVM_ABI Printable printReg(Register Reg, const TargetRegisterInfo *TRI=nullptr, unsigned SubIdx=0, const MachineRegisterInfo *MRI=nullptr)
Prints virtual and physical registers with or without a TRI instance.