LLVM 23.0.0git
SPIRVModuleAnalysis.cpp
Go to the documentation of this file.
1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
17// TODO: uses of report_fatal_error (which is also deprecated) /
18// ReportFatalUsageError in this file should be refactored, as per LLVM
19// best practices, to rely on the Diagnostic infrastructure.
20
21#include "SPIRVModuleAnalysis.h"
24#include "SPIRV.h"
25#include "SPIRVSubtarget.h"
26#include "SPIRVTargetMachine.h"
27#include "SPIRVUtils.h"
28#include "llvm/ADT/STLExtras.h"
31
32using namespace llvm;
33
34#define DEBUG_TYPE "spirv-module-analysis"
35
// Boolean command-line option "spv-dump-deps"; optional and false by default.
// Its description says it dumps MIR with SPIR-V dependencies info.
36static cl::opt<bool>
37 SPVDumpDeps("spv-dump-deps",
38 cl::desc("Dump MIR with SPIR-V dependencies info"),
39 cl::Optional, cl::init(false));
40
// Command-line option "avoid-spirv-capabilities": capabilities to avoid when
// another option can enable the same feature. The only value listed here is
// "Shader".
// NOTE(review): the declaration line (option type) and one argument line seem
// to be missing from this listing -- confirm against the real file.
42 AvoidCapabilities("avoid-spirv-capabilities",
43 cl::desc("SPIR-V capabilities to avoid if there are "
44 "other options enabling a feature"),
46 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
47 "SPIR-V Shader capability")));
48// Use sets instead of cl::list to check "if contains" condition
53
55
// Registers the SPIRVModuleAnalysis pass under DEBUG_TYPE
// ("spirv-module-analysis") with the display name "SPIRV module analysis".
// NOTE(review): the two trailing `true` flags are presumably the cfg-only and
// is-analysis markers of INITIALIZE_PASS -- confirm against PassSupport.h.
56INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
57 true)
58
59// Retrieve an unsigned from an MDNode with a list of them as operands.
60static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
61 unsigned DefaultVal = 0) {
62 if (MdNode && OpIndex < MdNode->getNumOperands()) {
63 const auto &Op = MdNode->getOperand(OpIndex);
64 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
65 }
66 return DefaultVal;
67}
68
// Works out what is needed to use enumerant `i` of operand category
// `Category` on subtarget `ST`: whether it is satisfiable at all, which single
// capability (if any) to declare, which extensions to enable, and the
// min/max SPIR-V version bounds.
//
// Decision order, as implemented below:
//  * no capabilities and no extensions required: satisfiable iff the target
//    version lies inside [ReqMinVer, ReqMaxVer];
//  * capabilities required and the version is acceptable: choose one available
//    capability, skipping those in the avoid set unless it is the last option;
//  * otherwise: fall back to the extension list, which is accepted only if the
//    subtarget can use every extension in it.
//
// NOTE(review): this listing appears to have lost several lines of this
// function (the return type, the final parameter that declares `Reqs`, the
// declaration of `ReqCaps`, and the heads of the two statements that end in
// "CapabilityOperand, Cap));") -- confirm against the real file.
70getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
71 unsigned i, const SPIRVSubtarget &ST,
73 // A set of capabilities to avoid if there is another option.
74 AvoidCapabilitiesSet AvoidCaps;
75 if (!ST.isShader())
76 AvoidCaps.S.insert(SPIRV::Capability::Shader);
77 else
78 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
79
// An empty target version, and an empty maximum version, are both treated as
// "no restriction" by the two checks below.
80 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
81 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
82 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
83 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
84 bool MaxVerOK =
85 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
87 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
88 if (ReqCaps.empty()) {
89 if (ReqExts.empty()) {
90 if (MinVerOK && MaxVerOK)
91 return {true, {}, {}, ReqMinVer, ReqMaxVer};
92 return {false, {}, {}, VersionTuple(), VersionTuple()};
93 }
94 } else if (MinVerOK && MaxVerOK) {
95 if (ReqCaps.size() == 1) {
96 auto Cap = ReqCaps[0];
97 if (Reqs.isCapabilityAvailable(Cap)) {
99 SPIRV::OperandCategory::CapabilityOperand, Cap));
100 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
101 }
102 } else {
103 // By SPIR-V specification: "If an instruction, enumerant, or other
104 // feature specifies multiple enabling capabilities, only one such
105 // capability needs to be declared to use the feature." However, one
106 // capability may be preferred over another. We use command line
107 // argument(s) and AvoidCapabilities to avoid selection of certain
108 // capabilities if there are other options.
109 CapabilityList UseCaps;
110 for (auto Cap : ReqCaps)
111 if (Reqs.isCapabilityAvailable(Cap))
112 UseCaps.push_back(Cap);
// Take the first available capability that is not in the avoid set; the last
// candidate is taken regardless. Note that this loop's `i` shadows the
// function parameter `i`.
113 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
114 auto Cap = UseCaps[i];
115 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
117 SPIRV::OperandCategory::CapabilityOperand, Cap));
118 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
119 }
120 }
121 }
122 }
123 // If there are no capabilities, or we can't satisfy the version or
124 // capability requirements, use the list of extensions (if the subtarget
125 // can handle them all).
126 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
127 return ST.canUseExtension(Ext);
128 })) {
129 return {true,
130 {},
131 std::move(ReqExts),
132 VersionTuple(),
133 VersionTuple()}; // TODO: add versions to extensions.
134 }
135 return {false, {}, {}, VersionTuple(), VersionTuple()};
136}
137
// Resets all per-module analysis state in MAI and then derives the module's
// base properties from module metadata and the subtarget:
//  * addressing and memory model ("spirv.MemoryModel", else subtarget
//    defaults),
//  * source language and version ("opencl.ocl.version" and optionally
//    "opencl.cxx.version", else defaults),
//  * source extensions ("opencl.used.extensions"),
// and records the requirements these choices imply. For non-shader targets an
// ID is also reserved for the OpenCL_std extended instruction set.
//
// NOTE(review): one line seems to be missing from this listing just before
// the "opencl cxx version is not compatible ..." message -- confirm against
// the real file.
138void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
139 MAI.MaxID = 0;
140 for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
141 MAI.MS[i].clear();
142 MAI.RegisterAliasTable.clear();
143 MAI.InstrsToDelete.clear();
144 MAI.GlobalObjMap.clear();
145 MAI.GlobalVarList.clear();
146 MAI.ExtInstSetMap.clear();
147 MAI.Reqs.clear();
148 MAI.Reqs.initAvailableCapabilities(*ST);
149
150 // TODO: determine memory model and source language from the configuration.
// NOTE(review): operand 0 of "spirv.MemoryModel" is read without checking the
// operand count -- presumably the metadata is always well-formed; verify.
151 if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
152 auto MemMD = MemModel->getOperand(0);
153 MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
154 getMetadataUInt(MemMD, 0));
155 MAI.Mem =
156 static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
157 } else {
158 // TODO: Add support for VulkanMemoryModel.
159 MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
160 : SPIRV::MemoryModel::OpenCL;
161 if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
162 unsigned PtrSize = ST->getPointerSize();
163 MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
164 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
165 : SPIRV::AddressingModel::Logical;
166 } else {
167 // TODO: Add support for PhysicalStorageBufferAddress.
168 MAI.Addr = SPIRV::AddressingModel::Logical;
169 }
170 }
171 // Get the OpenCL version number from metadata.
172 // TODO: support other source languages.
173 if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
174 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
175 // Construct version literal in accordance with SPIRV-LLVM-Translator.
176 // TODO: support multiple OCL version metadata.
177 assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
178 auto VersionMD = VerNode->getOperand(0);
179 unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
180 unsigned MinorNum = getMetadataUInt(VersionMD, 1);
181 unsigned RevNum = getMetadataUInt(VersionMD, 2);
182 // Prevent Major part of OpenCL version to be 0
// The literal is encoded as (major * 100 + minor) * 1000 + revision.
183 MAI.SrcLangVersion =
184 (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
185 // When opencl.cxx.version is also present, validate compatibility
186 // and use C++ for OpenCL as source language with the C++ version.
187 if (auto *CxxVerNode = M.getNamedMetadata("opencl.cxx.version")) {
188 assert(CxxVerNode->getNumOperands() > 0 && "Invalid SPIR");
189 auto *CxxMD = CxxVerNode->getOperand(0);
190 unsigned CxxVer =
191 (getMetadataUInt(CxxMD, 0) * 100 + getMetadataUInt(CxxMD, 1)) * 1000 +
192 getMetadataUInt(CxxMD, 2);
// Only two pairings are accepted: OCL literal 200000 with C++ literal 100000,
// and OCL literal 300000 with C++ literal 202100000.
193 if ((MAI.SrcLangVersion == 200000 && CxxVer == 100000) ||
194 (MAI.SrcLangVersion == 300000 && CxxVer == 202100000)) {
195 MAI.SrcLang = SPIRV::SourceLanguage::CPP_for_OpenCL;
196 MAI.SrcLangVersion = CxxVer;
197 } else {
199 "opencl cxx version is not compatible with opencl c version!");
200 }
201 }
202 } else {
203 // If there is no information about OpenCL version we are forced to generate
204 // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
205 // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
206 // Translator avoids potential issues with run-times in a similar manner.
207 if (!ST->isShader()) {
208 MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
209 MAI.SrcLangVersion = 100000;
210 } else {
211 MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
212 MAI.SrcLangVersion = 0;
213 }
214 }
215
// Every string operand of every non-empty "opencl.used.extensions" node is
// recorded as a source extension.
216 if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
217 for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
218 MDNode *MD = ExtNode->getOperand(I);
219 if (!MD || MD->getNumOperands() == 0)
220 continue;
221 for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
222 MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
223 }
224 }
225
226 // Update required capabilities for this memory model, addressing model and
227 // source language.
228 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
229 MAI.Mem, *ST);
230 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
231 MAI.SrcLang, *ST);
232 MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
233 MAI.Addr, *ST);
234
235 if (MAI.Mem == SPIRV::MemoryModel::VulkanKHR)
236 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_vulkan_memory_model);
237
238 if (!ST->isShader()) {
239 // TODO: check if it's required by default.
240 MAI.ExtInstSetMap[static_cast<unsigned>(
241 SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
242 }
243}
244
245// Appends the signature of the decoration instructions that decorate R to
246// Signature.
247static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
248 InstrSignature &Signature) {
249 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
250 // We don't handle OpDecorateId because getting the register alias for the
251 // ID can cause problems, and we do not need it for now.
252 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
253 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
254 continue;
255
256 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
257 const MachineOperand &MO = UseMI.getOperand(I);
258 if (MO.isReg())
259 continue;
260 Signature.push_back(hash_value(MO));
261 }
262 }
263}
264
265// Returns a representation of an instruction as a vector of MachineOperand
266// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
267// This creates a signature of the instruction with the same content
268// that MachineOperand::isIdenticalTo uses for comparison.
//
// The signature starts with the opcode. Register operands are hashed through
// their global alias from MAI; a register without a valid alias is a fatal
// error. When UseDefReg is false, def register operands are left out of the
// hash and the decorations attached to the (single) def register are appended
// to the signature instead.
//
// NOTE(review): a parameter line seems to be missing from this listing between
// `MI` and `UseDefReg` (the body uses `MAI`) -- confirm against the real file.
269static InstrSignature instrToSignature(const MachineInstr &MI,
271 bool UseDefReg) {
272 Register DefReg;
273 InstrSignature Signature{MI.getOpcode()};
274 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
275 // The only decorations that can be applied more than once to a given <id>
276 // or structure member are FuncParamAttr (38), UserSemantic (5635),
277 // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
278 // the rest of decorations, we will only add to the signature the Opcode,
279 // the id to which it applies, and the decoration id, disregarding any
280 // decoration flags. This will ensure that any subsequent decoration with
281 // the same id will be deemed as a duplicate. Then, at the call site, we
282 // will be able to handle duplicates in the best way.
283 unsigned Opcode = MI.getOpcode();
284 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
285 unsigned DecorationID = MI.getOperand(1).getImm();
286 if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
287 DecorationID != SPIRV::Decoration::UserSemantic &&
288 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
289 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
290 continue;
291 }
292 const MachineOperand &MO = MI.getOperand(i);
293 size_t h;
294 if (MO.isReg()) {
// Remember the def register (at most one is expected) and leave it out of
// the hash when the caller asked for a def-independent signature.
295 if (!UseDefReg && MO.isDef()) {
296 assert(!DefReg.isValid() && "Multiple def registers.");
297 DefReg = MO.getReg();
298 continue;
299 }
300 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
301 if (!RegAlias.isValid()) {
302 LLVM_DEBUG({
303 dbgs() << "Unexpectedly, no global id found for the operand ";
304 MO.print(dbgs());
305 dbgs() << "\nInstruction: ";
306 MI.print(dbgs());
307 dbgs() << "\n";
308 });
309 report_fatal_error("All v-regs must have been mapped to global id's");
310 }
311 // mimic llvm::hash_value(const MachineOperand &MO)
312 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
313 MO.isDef());
314 } else {
315 h = hash_value(MO);
316 }
317 Signature.push_back(h);
318 }
319
320 if (DefReg.isValid()) {
321 // Decorations change the semantics of the current instruction. So two
322 // identical instruction with different decorations cannot be merged. That
323 // is why we add the decorations to the signature.
324 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
325 }
326 return Signature;
327}
328
329bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
330 const MachineInstr &MI) {
331 unsigned Opcode = MI.getOpcode();
332 switch (Opcode) {
333 case SPIRV::OpTypeForwardPointer:
334 // omit now, collect later
335 return false;
336 case SPIRV::OpVariable:
337 return static_cast<SPIRV::StorageClass::StorageClass>(
338 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
339 case SPIRV::OpFunction:
340 case SPIRV::OpFunctionParameter:
341 return true;
342 }
343 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
344 Register DefReg = MI.getOperand(0).getReg();
345 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
346 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
347 continue;
348 // it's a dummy definition, FP constant refers to a function,
349 // and this is resolved in another way; let's skip this definition
350 assert(UseMI.getOperand(2).isReg() &&
351 UseMI.getOperand(2).getReg() == DefReg);
352 MAI.setSkipEmission(&MI);
353 return false;
354 }
355 }
356 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
357 TII->isInlineAsmDefInstr(MI);
358}
359
360// This is a special case of a function pointer referring to a possibly
361// forward function declaration. The operand is a dummy OpUndef that
362// requires a special treatment.
363// FunPtrOp is the MachineOperand previously recorded via
364// SPIRVGlobalRegistry::recordFunctionPointer, identifying which Function
365// this placeholder refers to.
366void SPIRVModuleAnalysis::visitFunPtrUse(
367 Register OpReg, const MachineOperand *FunPtrOp,
368 InstrGRegsMap &SignatureToGReg,
369 std::map<const Value *, unsigned> &GlobalToGReg,
370 const MachineFunction *MF) {
371 const MachineOperand *OpFunDef = GR->getFunctionDefinitionByUse(FunPtrOp);
372 assert(OpFunDef && OpFunDef->isReg());
373 // find the actual function definition and number it globally in advance
374 const MachineInstr *OpDefMI = OpFunDef->getParent();
375 assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
376 const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
377 const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
378 do {
379 visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
380 OpDefMI = OpDefMI->getNextNode();
381 } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
382 OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
383 // associate the function pointer with the newly assigned global number
384 MCRegister GlobalFunDefReg =
385 MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
386 assert(GlobalFunDefReg.isValid() &&
387 "Function definition must refer to a global register");
388 MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
389}
390
391// Depth first recursive traversal of dependencies. Repeated visits are guarded
392// by MAI.hasRegisterAlias().
//
// First every register use operand of MI is resolved (recursing into its
// defining instruction when that one belongs to the declaration section).
// Then MI itself is dispatched by kind to obtain its global register, the def
// register in operand 0 is aliased to it, and MI is marked as skipped for
// emission unless it belongs to a function definition.
//
// NOTE(review): one line seems to be missing from this listing just before
// the "No unique definition is found ..." message -- confirm against the real
// file.
393void SPIRVModuleAnalysis::visitDecl(
394 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
395 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
396 const MachineInstr &MI) {
397 unsigned Opcode = MI.getOpcode();
398
399 // Process each operand of the instruction to resolve dependencies
400 for (const MachineOperand &MO : MI.operands()) {
401 if (!MO.isReg() || MO.isDef())
402 continue;
403 Register OpReg = MO.getReg();
404 // Handle function pointers special case
405 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
406 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
407 visitFunPtrUse(OpReg, &MI.getOperand(2), SignatureToGReg, GlobalToGReg,
408 MF);
409 continue;
410 }
411 // Skip already processed instructions
412 if (MAI.hasRegisterAlias(MF, MO.getReg()))
413 continue;
414 // Recursively visit dependencies
// Operands whose defining instruction is not a declaration are left alone.
415 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
416 if (isDeclSection(MRI, *OpDefMI))
417 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
418 continue;
419 }
420 // Handle the unexpected case of no unique definition for the SPIR-V
421 // instruction
422 LLVM_DEBUG({
423 dbgs() << "Unexpectedly, no unique definition for the operand ";
424 MO.print(dbgs());
425 dbgs() << "\nInstruction: ";
426 MI.print(dbgs());
427 dbgs() << "\n";
428 });
430 "No unique definition is found for the virtual register");
431 }
432
433 MCRegister GReg;
434 bool IsFunDef = false;
// Spec constants always receive a fresh ID; they are not deduplicated here.
435 if (TII->isSpecConstantInstr(MI)) {
436 GReg = MAI.getNextIDRegister();
437 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
438 } else if (Opcode == SPIRV::OpFunction ||
439 Opcode == SPIRV::OpFunctionParameter) {
440 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
441 } else if (Opcode == SPIRV::OpTypeStruct ||
442 Opcode == SPIRV::OpConstantComposite) {
443 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
// Also process the matching "...ContinuedINTEL" instructions that directly
// follow a struct type or composite constant.
444 const MachineInstr *NextInstr = MI.getNextNode();
445 while (NextInstr &&
446 ((Opcode == SPIRV::OpTypeStruct &&
447 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
448 (Opcode == SPIRV::OpConstantComposite &&
449 NextInstr->getOpcode() ==
450 SPIRV::OpConstantCompositeContinuedINTEL))) {
451 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
452 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
453 MAI.setSkipEmission(NextInstr);
454 NextInstr = NextInstr->getNextNode();
455 }
456 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
457 TII->isInlineAsmDefInstr(MI)) {
458 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
459 } else if (Opcode == SPIRV::OpVariable) {
460 GReg = handleVariable(MF, MI, GlobalToGReg);
461 } else {
462 LLVM_DEBUG({
463 dbgs() << "\nInstruction: ";
464 MI.print(dbgs());
465 dbgs() << "\n";
466 });
467 llvm_unreachable("Unexpected instruction is visited");
468 }
469 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
470 if (!IsFunDef)
471 MAI.setSkipEmission(&MI);
472}
473
474MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
475 const MachineFunction *MF, const MachineInstr &MI,
476 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
477 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
478 assert(GObj && "Unregistered global definition");
479 const Function *F = dyn_cast<Function>(GObj);
480 if (!F)
481 F = dyn_cast<Argument>(GObj)->getParent();
482 assert(F && "Expected a reference to a function or an argument");
483 IsFunDef = !F->isDeclaration();
484 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
485 if (!Inserted)
486 return It->second;
487 MCRegister GReg = MAI.getNextIDRegister();
488 It->second = GReg;
489 if (!IsFunDef)
490 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
491 return GReg;
492}
493
// Deduplicates a type declaration or constant by its def-independent
// signature. The first instruction with a given signature receives a fresh ID
// and is queued in the MB_TypeConstVars section; later instructions with the
// same signature return the ID already stored in SignatureToGReg.
// NOTE(review): the return-type line seems to be missing from this listing
// (the body returns an MCRegister) -- confirm against the real file.
495SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
496 InstrGRegsMap &SignatureToGReg) {
497 InstrSignature MISign = instrToSignature(MI, MAI, false);
498 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
499 if (!Inserted)
500 return It->second;
501 MCRegister GReg = MAI.getNextIDRegister();
502 It->second = GReg;
503 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
504 return GReg;
505}
506
507MCRegister SPIRVModuleAnalysis::handleVariable(
508 const MachineFunction *MF, const MachineInstr &MI,
509 std::map<const Value *, unsigned> &GlobalToGReg) {
510 MAI.GlobalVarList.push_back(&MI);
511 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
512 assert(GObj && "Unregistered global definition");
513 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
514 if (!Inserted)
515 return It->second;
516 MCRegister GReg = MAI.getNextIDRegister();
517 It->second = GReg;
518 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
519 if (const auto *GV = dyn_cast<GlobalVariable>(GObj))
520 MAI.GlobalObjMap[GV] = GReg;
521 return GReg;
522}
523
524void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
525 InstrGRegsMap SignatureToGReg;
526 std::map<const Value *, unsigned> GlobalToGReg;
527 for (const Function &F : M) {
528 MachineFunction *MF = MMI->getMachineFunction(F);
529 if (!MF)
530 continue;
531 const MachineRegisterInfo &MRI = MF->getRegInfo();
532 unsigned PastHeader = 0;
533 for (MachineBasicBlock &MBB : *MF) {
534 for (MachineInstr &MI : MBB) {
535 if (MI.getNumOperands() == 0)
536 continue;
537 unsigned Opcode = MI.getOpcode();
538 if (Opcode == SPIRV::OpFunction) {
539 if (PastHeader == 0) {
540 PastHeader = 1;
541 continue;
542 }
543 } else if (Opcode == SPIRV::OpFunctionParameter) {
544 if (PastHeader < 2)
545 continue;
546 } else if (PastHeader > 0) {
547 PastHeader = 2;
548 }
549
550 const MachineOperand &DefMO = MI.getOperand(0);
551 switch (Opcode) {
552 case SPIRV::OpExtension:
553 MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
554 MAI.setSkipEmission(&MI);
555 break;
556 case SPIRV::OpCapability:
557 MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
558 MAI.setSkipEmission(&MI);
559 if (PastHeader > 0)
560 PastHeader = 2;
561 break;
562 default:
563 if (DefMO.isReg() && isDeclSection(MRI, MI) &&
564 !MAI.hasRegisterAlias(MF, DefMO.getReg()))
565 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
566 }
567 }
568 }
569 }
570}
571
572// Look for IDs declared with Import linkage, and map the corresponding function
573// to the register defining that variable (which will usually be the result of
574// an OpFunction). This lets us call externally imported functions using
575// the correct ID registers.
576void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
577 const Function *F) {
578 if (MI.getOpcode() == SPIRV::OpDecorate) {
579 // If it's got Import linkage.
580 auto Dec = MI.getOperand(1).getImm();
581 if (Dec == SPIRV::Decoration::LinkageAttributes) {
582 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
583 if (Lnk == SPIRV::LinkageType::Import) {
584 // Map imported function name to function ID register.
585 const Function *ImportedFunc =
586 F->getParent()->getFunction(getStringImm(MI, 2));
587 Register Target = MI.getOperand(0).getReg();
588 MAI.GlobalObjMap[ImportedFunc] =
589 MAI.getRegisterAlias(MI.getMF(), Target);
590 }
591 }
592 } else if (MI.getOpcode() == SPIRV::OpFunction) {
593 // Record all internal OpFunction declarations.
594 Register Reg = MI.defs().begin()->getReg();
595 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
596 assert(GlobalReg.isValid());
597 MAI.GlobalObjMap[F] = GlobalReg;
598 }
599}
600
601// Collect the given instruction in the specified MS. We assume global register
602// numbering has already occurred by this point. We can directly compare reg
603// arguments when detecting duplicates.
//
// MI is always marked as skipped for emission. Its full signature (def
// registers included) is inserted into IS; if an equal signature is already
// present MI is a duplicate and is not added to MAI.MS[MSType]. Duplicate
// FPFastMathMode OpDecorate instructions additionally have their flags OR-ed
// into the original instruction, with a warning printed to errs() when the
// flags differ. Non-duplicates are appended to the section when Append is
// true and inserted at its front otherwise.
//
// NOTE(review): a parameter line seems to be missing from this listing between
// `MAI` and `Append` (the body uses `MSType` and `IS`) -- confirm against the
// real file.
604static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
606 bool Append = true) {
607 MAI.setSkipEmission(&MI);
608 InstrSignature MISign = instrToSignature(MI, MAI, true);
// NOTE(review): MISign is passed through std::move here but is compared
// against again in the duplicate search below; that relies on a failed insert
// leaving its argument untouched -- confirm, or insert a copy instead.
609 auto FoundMI = IS.insert(std::move(MISign));
610 if (!FoundMI.second) {
611 if (MI.getOpcode() == SPIRV::OpDecorate) {
612 assert(MI.getNumOperands() >= 2 &&
613 "Decoration instructions must have at least 2 operands");
614 assert(MSType == SPIRV::MB_Annotations &&
615 "Only OpDecorate instructions can be duplicates");
616 // For FPFastMathMode decoration, we need to merge the flags of the
617 // duplicate decoration with the original one, so we need to find the
618 // original instruction that has the same signature. For the rest of
619 // instructions, we will simply skip the duplicate.
620 if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
621 return; // Skip duplicates of other decorations.
622
// Linear search of the section for the instruction that produced the
// matching signature.
623 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
624 for (const MachineInstr *OrigMI : Decorations) {
625 if (instrToSignature(*OrigMI, MAI, true) == MISign) {
626 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
627 "Original instruction must have the same number of operands");
628 assert(
629 OrigMI->getNumOperands() == 3 &&
630 "FPFastMathMode decoration must have 3 operands for OpDecorate");
631 unsigned OrigFlags = OrigMI->getOperand(2).getImm();
632 unsigned NewFlags = MI.getOperand(2).getImm();
633 if (OrigFlags == NewFlags)
634 return; // No need to merge, the flags are the same.
635
636 // Emit warning about possible conflict between flags.
637 unsigned FinalFlags = OrigFlags | NewFlags;
638 llvm::errs()
639 << "Warning: Conflicting FPFastMathMode decoration flags "
640 "in instruction: "
641 << *OrigMI << "Original flags: " << OrigFlags
642 << ", new flags: " << NewFlags
643 << ". They will be merged on a best effort basis, but not "
644 "validated. Final flags: "
645 << FinalFlags << "\n";
// The section stores const pointers, so constness is cast away to rewrite the
// flags operand of the original instruction in place.
646 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
647 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
648 OrigFlagsOp = MachineOperand::CreateImm(FinalFlags);
649 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
650 }
651 }
652 assert(false && "No original instruction found for the duplicate "
653 "OpDecorate, but we found one in IS.");
654 }
655 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
656 }
657 // No duplicates, so add it.
658 if (Append)
659 MAI.MS[MSType].push_back(&MI);
660 else
661 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
662}
663
664// Some global instructions make reference to function-local ID regs, so cannot
665// be correctly collected until these registers are globally numbered.
666void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
667 InstrTraces IS;
668 for (const Function &F : M) {
669 if (F.isDeclaration())
670 continue;
671 MachineFunction *MF = MMI->getMachineFunction(F);
672 assert(MF);
673
674 for (MachineBasicBlock &MBB : *MF)
675 for (MachineInstr &MI : MBB) {
676 if (MAI.getSkipEmission(&MI))
677 continue;
678 const unsigned OpCode = MI.getOpcode();
679 if (OpCode == SPIRV::OpString) {
680 collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
681 } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
682 MI.getOperand(2).getImm() ==
683 SPIRV::InstructionSet::
684 NonSemantic_Shader_DebugInfo_100) {
685 // TODO: This branch is dead. SPIRVNonSemanticDebugHandler emits NSDI
686 // instructions directly as MCInsts at print time; no
687 // MachineInstructions with the NSDI ext set are created anymore.
688 // Remove this block and
689 // MB_NonSemanticGlobalDI once per-function NSDI emission is confirmed
690 // not to need MIR routing.
691 MachineOperand Ins = MI.getOperand(3);
692 namespace NS = SPIRV::NonSemanticExtInst;
693 static constexpr int64_t GlobalNonSemanticDITy[] = {
694 NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
695 NS::DebugTypeBasic, NS::DebugTypePointer};
696 bool IsGlobalDI = false;
697 for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
698 IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
699 if (IsGlobalDI)
700 collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
701 } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
702 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
703 } else if (OpCode == SPIRV::OpEntryPoint) {
704 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
705 } else if (TII->isAliasingInstr(MI)) {
706 collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
707 } else if (TII->isDecorationInstr(MI)) {
708 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
709 collectFuncNames(MI, &F);
710 } else if (TII->isConstantInstr(MI)) {
711 // Now OpSpecConstant*s are not in DT,
712 // but they need to be collected anyway.
713 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
714 } else if (OpCode == SPIRV::OpFunction) {
715 collectFuncNames(MI, &F);
716 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
717 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
718 }
719 }
720 }
721}
722
723// Number registers in all functions globally from 0 onwards and store
724// the result in global register alias table. Some registers are already
725// numbered.
726void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
727 for (const Function &F : M) {
728 if (F.isDeclaration())
729 continue;
730 MachineFunction *MF = MMI->getMachineFunction(F);
731 assert(MF);
732 for (MachineBasicBlock &MBB : *MF) {
733 for (MachineInstr &MI : MBB) {
734 for (MachineOperand &Op : MI.operands()) {
735 if (!Op.isReg())
736 continue;
737 Register Reg = Op.getReg();
738 if (MAI.hasRegisterAlias(MF, Reg))
739 continue;
740 MCRegister NewReg = MAI.getNextIDRegister();
741 MAI.setRegisterAlias(MF, Reg, NewReg);
742 }
743 if (MI.getOpcode() != SPIRV::OpExtInst)
744 continue;
745 auto Set = MI.getOperand(2).getImm();
746 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
747 if (Inserted)
748 It->second = MAI.getNextIDRegister();
749 }
750 }
751 }
752}
753
754// RequirementHandler implementations.
// Computes the requirements of enumerant `i` in `Category` for subtarget `ST`
// and adds them to this handler.
// NOTE(review): the first line of the signature seems to be missing from this
// listing -- confirm against the real file.
756 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
757 const SPIRVSubtarget &ST) {
758 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
759}
760
761void SPIRV::RequirementHandler::recursiveAddCapabilities(
762 const CapabilityList &ToPrune) {
763 for (const auto &Cap : ToPrune) {
764 AllCaps.insert(Cap);
765 CapabilityList ImplicitDecls =
766 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
767 recursiveAddCapabilities(ImplicitDecls);
768 }
769}
770
// Adds each capability in ToAdd that has not been recorded in AllCaps yet:
// the capability goes into both AllCaps and MinimalCaps, while the
// capabilities it implies are added to AllCaps only.
// NOTE(review): the signature line seems to be missing from this listing --
// confirm against the real file.
772 for (const auto &Cap : ToAdd) {
773 bool IsNewlyInserted = AllCaps.insert(Cap).second;
774 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
775 continue;
776 CapabilityList ImplicitDecls =
777 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
778 recursiveAddCapabilities(ImplicitDecls);
779 MinimalCaps.push_back(Cap);
780 }
781}
782
// Merges Req into this handler: its capability (if any) and its extensions
// are added, and the accumulated version window is narrowed. MinVersion only
// ever increases and MaxVersion only ever decreases. An unsatisfiable Req, or
// one whose bound conflicts with the opposite accumulated bound, is a fatal
// error.
// NOTE(review): the first line of the signature seems to be missing from this
// listing -- confirm against the real file.
784 const SPIRV::Requirements &Req) {
785 if (!Req.IsSatisfiable)
786 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
787
788 if (Req.Cap.has_value())
789 addCapabilities({Req.Cap.value()});
790
791 addExtensions(Req.Exts);
792
793 if (!Req.MinVer.empty()) {
794 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
795 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
796 << " and <= " << MaxVersion << "\n");
797 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
798 }
799
800 if (MinVersion.empty() || Req.MinVer > MinVersion)
801 MinVersion = Req.MinVer;
802 }
803
804 if (!Req.MaxVer.empty()) {
805 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
806 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
807 << " and >= " << MinVersion << "\n");
808 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
809 }
810
811 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
812 MaxVersion = Req.MaxVer;
813 }
814}
815
// Verifies that the accumulated requirements can be met by subtarget ST:
// the target version must lie within [MinVersion, MaxVersion], that window
// must not be empty, every capability in MinimalCaps must be available and
// not in the avoid set, and every extension in AllExtensions must be usable.
// All failed checks are evaluated (details go to the debug stream) before a
// single fatal error is raised.
// NOTE(review): the first line of the signature and several lines inside the
// debug-output statements seem to be missing from this listing -- confirm
// against the real file.
817 const SPIRVSubtarget &ST) const {
818 // Report as many errors as possible before aborting the compilation.
819 bool IsSatisfiable = true;
820 auto TargetVer = ST.getSPIRVVersion();
821
822 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
824 dbgs() << "Target SPIR-V version too high for required features\n"
825 << "Required max version: " << MaxVersion << " target version "
826 << TargetVer << "\n");
827 IsSatisfiable = false;
828 }
829
830 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
831 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
832 << "Required min version: " << MinVersion
833 << " target version " << TargetVer << "\n");
834 IsSatisfiable = false;
835 }
836
837 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
839 dbgs()
840 << "Version is too low for some features and too high for others.\n"
841 << "Required SPIR-V min version: " << MinVersion
842 << " required SPIR-V max version " << MaxVersion << "\n");
843 IsSatisfiable = false;
844 }
845
// Same avoid-set rule as in getSymbolicOperandRequirements: Shader is avoided
// for non-shader targets, Kernel for shader targets.
846 AvoidCapabilitiesSet AvoidCaps;
847 if (!ST.isShader())
848 AvoidCaps.S.insert(SPIRV::Capability::Shader);
849 else
850 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
851
852 for (auto Cap : MinimalCaps) {
853 if (AvailableCaps.contains(Cap) && !AvoidCaps.S.contains(Cap))
854 continue;
855 LLVM_DEBUG(dbgs() << "Capability not supported: "
857 OperandCategory::CapabilityOperand, Cap)
858 << "\n");
859 IsSatisfiable = false;
860 }
861
862 for (auto Ext : AllExtensions) {
863 if (ST.canUseExtension(Ext))
864 continue;
865 LLVM_DEBUG(dbgs() << "Extension not supported: "
867 OperandCategory::ExtensionOperand, Ext)
868 << "\n");
869 IsSatisfiable = false;
870 }
871
872 if (!IsSatisfiable)
873 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
874}
875
876// Add the given capabilities and all their implicitly defined capabilities too.
// Recursion only descends into a capability the first time it is inserted
// into AvailableCaps (insert(...).second is true), so capabilities that are
// already recorded are not expanded again.
878 for (const auto Cap : ToAdd)
879 if (AvailableCaps.insert(Cap).second)
880 addAvailableCaps(getSymbolicOperandCapabilities(
881 SPIRV::OperandCategory::CapabilityOperand, Cap));
882}
883
// Erases ToRemove from AllCaps, but only when AllCaps currently contains
// IfPresent. Only AllCaps is modified here.
885 const Capability::Capability ToRemove,
886 const Capability::Capability IfPresent) {
887 if (AllCaps.contains(IfPresent))
888 AllCaps.erase(ToRemove);
889}
890
891namespace llvm {
892namespace SPIRV {
// Populates the set of available capabilities for the subtarget ST:
//  1. a fixed base set;
//  2. additional sets gated on the SPIR-V version (>= 1.3 and >= 1.6);
//  3. capabilities associated with each extension reported by
//     ST.getAllAvailableExtensions();
//  4. the environment-specific set, chosen by ST.isShader().
893void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
894 // Provided by both all supported Vulkan versions and OpenCl.
895 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
896 Capability::Int16});
897
898 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
899 addAvailableCaps({Capability::GroupNonUniform,
900 Capability::GroupNonUniformVote,
901 Capability::GroupNonUniformArithmetic,
902 Capability::GroupNonUniformBallot,
903 Capability::GroupNonUniformClustered,
904 Capability::GroupNonUniformShuffle,
905 Capability::GroupNonUniformShuffleRelative,
906 Capability::GroupNonUniformQuad});
907
908 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
909 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
910 Capability::DotProductInput4x8Bit,
911 Capability::DotProductInput4x8BitPacked,
912 Capability::DemoteToHelperInvocation});
913
914 // Add capabilities enabled by extensions.
915 for (auto Extension : ST.getAllAvailableExtensions()) {
916 CapabilityList EnabledCapabilities =
918 addAvailableCaps(EnabledCapabilities);
919 }
920
921 if (!ST.isShader()) {
922 initAvailableCapabilitiesForOpenCL(ST);
923 return;
924 }
925
// NOTE(review): the branch above already returned when !ST.isShader(), so
// this condition looks always true at this point and the report_fatal_error
// below looks unreachable - confirm before relying on it.
926 if (ST.isShader()) {
927 initAvailableCapabilitiesForVulkan(ST);
928 return;
929 }
930
931 report_fatal_error("Unimplemented environment for SPIR-V generation.");
932}
933
934void RequirementHandler::initAvailableCapabilitiesForOpenCL(
935 const SPIRVSubtarget &ST) {
936 // Add the min requirements for different OpenCL and SPIR-V versions.
937 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
938 Capability::Kernel, Capability::Vector16,
939 Capability::Groups, Capability::GenericPointer,
940 Capability::StorageImageWriteWithoutFormat,
941 Capability::StorageImageReadWithoutFormat});
942 if (ST.hasOpenCLFullProfile())
943 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
944 if (ST.hasOpenCLImageSupport()) {
945 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
946 Capability::Image1D, Capability::SampledBuffer,
947 Capability::ImageBuffer});
948 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
949 addAvailableCaps({Capability::ImageReadWrite});
950 }
951 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
952 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
953 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
954 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
955 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
956 Capability::SignedZeroInfNanPreserve,
957 Capability::RoundingModeRTE,
958 Capability::RoundingModeRTZ});
959 // TODO: verify if this needs some checks.
960 addAvailableCaps({Capability::Float16, Capability::Float64});
961
962 // TODO: add OpenCL extensions.
963}
964
// Registers the capabilities treated as available when generating SPIR-V for
// a Vulkan (shader) environment. A fixed base set is always added; two
// further sets are gated on the subtarget's SPIR-V version (>= 1.5 and
// >= 1.6), which the inline comments associate with Vulkan 1.2 and 1.3.
965void RequirementHandler::initAvailableCapabilitiesForVulkan(
966 const SPIRVSubtarget &ST) {
967
968 // Core in Vulkan 1.1 and earlier.
969 addAvailableCaps({Capability::Int64,
970 Capability::Float16,
971 Capability::Float64,
972 Capability::GroupNonUniform,
973 Capability::Image1D,
974 Capability::SampledBuffer,
975 Capability::ImageBuffer,
976 Capability::UniformBufferArrayDynamicIndexing,
977 Capability::SampledImageArrayDynamicIndexing,
978 Capability::StorageBufferArrayDynamicIndexing,
979 Capability::StorageImageArrayDynamicIndexing,
980 Capability::DerivativeControl,
981 Capability::MinLod,
982 Capability::ImageQuery,
983 Capability::ImageGatherExtended,
984 Capability::Addresses,
985 Capability::VulkanMemoryModelKHR,
986 Capability::StorageImageExtendedFormats,
987 Capability::StorageImageMultisample,
988 Capability::ImageMSArray});
989
990 // Became core in Vulkan 1.2
991 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
993 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
994 Capability::InputAttachmentArrayDynamicIndexingEXT,
995 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
996 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
997 Capability::UniformBufferArrayNonUniformIndexingEXT,
998 Capability::SampledImageArrayNonUniformIndexingEXT,
999 Capability::StorageBufferArrayNonUniformIndexingEXT,
1000 Capability::StorageImageArrayNonUniformIndexingEXT,
1001 Capability::InputAttachmentArrayNonUniformIndexingEXT,
1002 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
1003 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
1004 }
1005
1006 // Became core in Vulkan 1.3
1007 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
1008 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
1009 Capability::StorageImageReadWithoutFormat});
1010}
1011
1012} // namespace SPIRV
1013} // namespace llvm
1014
1015// Add the required capabilities from a decoration instruction (including
1016// BuiltIns).
// MI is the decoration instruction; DecIndex is the index of the operand that
// holds the decoration kind (an immediate). The generic requirements of the
// decoration are added first, then decoration-specific ones:
//  - BuiltIn: requirements of the built-in stored at operand DecIndex + 1;
//  - LinkageAttributes: extensions/capabilities selected by the linkage type
//    stored in the last operand of MI;
//  - the remaining branches add a fixed extension and/or capability.
1017static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
1019 const SPIRVSubtarget &ST) {
1020 int64_t DecOp = MI.getOperand(DecIndex).getImm();
1021 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
1022 Reqs.addRequirements(getSymbolicOperandRequirements(
1023 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
1024
1025 if (Dec == SPIRV::Decoration::BuiltIn) {
1026 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
1027 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
1028 Reqs.addRequirements(getSymbolicOperandRequirements(
1029 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
1030 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
// The linkage type is read from the final operand of the instruction.
1031 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
1032 SPIRV::LinkageType::LinkageType LnkType =
1033 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
1034 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
1035 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
1036 else if (LnkType == SPIRV::LinkageType::WeakAMD) {
1037 Reqs.addExtension(SPIRV::Extension::SPV_AMD_weak_linkage);
1038 Reqs.addCapability(SPIRV::Capability::WeakLinkageAMD);
1039 }
1040 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
1041 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
1042 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
1043 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
1044 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
1045 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
1046 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
1047 Reqs.addExtension(
1048 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
1049 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
1050 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
1051 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
1052 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
1053 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
1054 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
// Only request float_controls2 when the subtarget can use the extension;
// otherwise nothing extra is added for this decoration here.
1055 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
1056 Reqs.addRequirements(SPIRV::Capability::FloatControls2);
1057 Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1058 }
1059 }
1060}
1061
1062// Add requirements for image handling.
// MI is an OpTypeImage instruction with at least 8 operands. Operand indices
// read below: 1 = sampled type (register), 2 = dimension, 4 = arrayed flag,
// 5 = multisampled flag, 6 = sampled mode (2 is treated as "no sampler"),
// 7 = image format, 8 = optional access qualifier.
// Requirements are added for the image format, the dimension (refined by the
// arrayed / multisampled / sampler flags), a 64-bit integer sampled type,
// and, for non-shader targets, the access qualifier.
1063static void addOpTypeImageReqs(const MachineInstr &MI,
1065 const SPIRVSubtarget &ST) {
1066 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1067 // The operand indices used here are based on the OpTypeImage layout, which
1068 // the MachineInstr follows as well.
1069 int64_t ImgFormatOp = MI.getOperand(7).getImm();
1070 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1071 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
1072 ImgFormat, ST);
1073
1074 bool IsArrayed = MI.getOperand(4).getImm() == 1;
1075 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
1076 bool NoSampler = MI.getOperand(6).getImm() == 2;
1077 // Add dimension requirements.
1078 assert(MI.getOperand(2).isImm());
1079 switch (MI.getOperand(2).getImm()) {
1080 case SPIRV::Dim::DIM_1D:
1081 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
1082 : SPIRV::Capability::Sampled1D);
1083 break;
1084 case SPIRV::Dim::DIM_2D:
1085 if (IsMultisampled && NoSampler)
1086 Reqs.addRequirements(SPIRV::Capability::StorageImageMultisample);
1087 if (IsMultisampled && IsArrayed)
1088 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
1089 break;
1090 case SPIRV::Dim::DIM_3D:
1091 break;
1092 case SPIRV::Dim::DIM_Cube:
1093 Reqs.addRequirements(SPIRV::Capability::Shader);
1094 if (IsArrayed)
1095 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
1096 : SPIRV::Capability::SampledCubeArray);
1097 break;
1098 case SPIRV::Dim::DIM_Rect:
1099 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
1100 : SPIRV::Capability::SampledRect);
1101 break;
1102 case SPIRV::Dim::DIM_Buffer:
1103 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
1104 : SPIRV::Capability::SampledBuffer);
1105 break;
1106 case SPIRV::Dim::DIM_SubpassData:
1107 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
1108 break;
1109 }
1110
1111 // Check if the sampled type is a 64-bit integer, which requires
1112 // Int64ImageEXT capability.
1113 assert(MI.getOperand(1).isReg());
1114 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1115 SPIRVTypeInst SampledTypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1116 if (SampledTypeDef.isTypeIntN(64)) {
1117 Reqs.addCapability(SPIRV::Capability::Int64ImageEXT);
1118 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_image_int64);
1119 }
1120
1121 // Has optional access qualifier.
// Non-shader targets only: a ReadWrite qualifier at operand 8 requires
// ImageReadWrite; any other case (including no qualifier) requires
// ImageBasic.
1122 if (!ST.isShader()) {
1123 if (MI.getNumOperands() > 8 &&
1124 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1125 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
1126 else
1127 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
1128 }
1129}
1130
1131static bool isBFloat16Type(SPIRVTypeInst TypeDef) {
1132 return TypeDef && TypeDef->getNumOperands() == 3 &&
1133 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1134 TypeDef->getOperand(1).getImm() == 16 &&
1135 TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1136}
1137
1138// Add requirements for handling atomic float instructions
// Builds an error-message string literal by adjacent-literal concatenation.
// ExtName must itself be a string literal; it is appended directly after
// "SPV_EXT_shader_atomic_float", so e.g. "_min_max" yields a message ending
// in "SPV_EXT_shader_atomic_float_min_max".
1139#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
1140 "The atomic float instruction requires the following SPIR-V " \
1141 "extension: SPV_EXT_shader_atomic_float" ExtName
// Handles an atomic float instruction whose result type (the definition of
// the register in operand 1) is a vector. The vector type's operand 2 is
// read as the component count and its operand 1 as the element type.
// Validation performed before anything is added:
//  - component count must be 2 or 4;
//  - element type must be a 16-bit OpTypeFloat and must not be bfloat16;
//  - the subtarget must be able to use SPV_NV_shader_atomic_fp16_vector.
// On success the extension and the AtomicFloat16VectorNV capability are
// added to Reqs.
1142static void AddAtomicVectorFloatRequirements(const MachineInstr &MI,
1144 const SPIRVSubtarget &ST) {
1145 SPIRVTypeInst VecTypeDef =
1146 MI.getMF()->getRegInfo().getVRegDef(MI.getOperand(1).getReg());
1147
1148 const unsigned Rank = VecTypeDef->getOperand(2).getImm();
1149 if (Rank != 2 && Rank != 4)
1150 reportFatalUsageError("Result type of an atomic vector float instruction "
1151 "must be a 2-component or 4 component vector");
1152
1153 SPIRVTypeInst EltTypeDef =
1154 MI.getMF()->getRegInfo().getVRegDef(VecTypeDef->getOperand(1).getReg());
1155
1156 if (EltTypeDef->getOpcode() != SPIRV::OpTypeFloat ||
1157 EltTypeDef->getOperand(1).getImm() != 16)
1159 "The element type for the result type of an atomic vector float "
1160 "instruction must be a 16-bit floating-point scalar");
1161
1162 if (isBFloat16Type(EltTypeDef))
1164 "The element type for the result type of an atomic vector float "
1165 "instruction cannot be a bfloat16 scalar");
1166 if (!ST.canUseExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector))
1168 "The atomic float16 vector instruction requires the following SPIR-V "
1169 "extension: SPV_NV_shader_atomic_fp16_vector");
1170
1171 Reqs.addExtension(SPIRV::Extension::SPV_NV_shader_atomic_fp16_vector);
1172 Reqs.addCapability(SPIRV::Capability::AtomicFloat16VectorNV);
1173}
1174
// Adds the extensions and capabilities needed by an atomic float
// instruction MI. The result type is the definition of the register in
// operand 1:
//  - a vector type is delegated to AddAtomicVectorFloatRequirements;
//  - any other non-OpTypeFloat type is a fatal error;
//  - for a scalar float, the opcode selects the "add" path
//    (OpAtomicFAddEXT) or the "min/max" path (every other opcode reaching
//    this function), and the bit width (16/32/64) selects the capability.
// 16-bit bfloat16 types additionally require SPV_INTEL_16bit_atomics.
1175static void AddAtomicFloatRequirements(const MachineInstr &MI,
1177 const SPIRVSubtarget &ST) {
1178 assert(MI.getOperand(1).isReg() &&
1179 "Expect register operand in atomic float instruction");
1180 Register TypeReg = MI.getOperand(1).getReg();
1181 SPIRVTypeInst TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1182
1183 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
1184 return AddAtomicVectorFloatRequirements(MI, Reqs, ST);
1185
1186 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1187 report_fatal_error("Result type of an atomic float instruction must be a "
1188 "floating-point type scalar");
1189
1190 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1191 unsigned Op = MI.getOpcode();
1192 if (Op == SPIRV::OpAtomicFAddEXT) {
1193 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1195 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1196 switch (BitWidth) {
1197 case 16:
1198 if (isBFloat16Type(TypeDef)) {
1199 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1201 "The atomic bfloat16 instruction requires the following SPIR-V "
1202 "extension: SPV_INTEL_16bit_atomics",
1203 false);
1204 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1205 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
1206 } else {
1207 if (!ST.canUseExtension(
1208 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1209 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1210 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1211 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1212 }
1213 break;
1214 case 32:
1215 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1216 break;
1217 case 64:
1218 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1219 break;
1220 default:
1222 "Unexpected floating-point type width in atomic float instruction");
1223 }
1224 } else {
1225 if (!ST.canUseExtension(
1226 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1227 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1228 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1229 switch (BitWidth) {
1230 case 16:
1231 if (isBFloat16Type(TypeDef)) {
1232 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1234 "The atomic bfloat16 instruction requires the following SPIR-V "
1235 "extension: SPV_INTEL_16bit_atomics",
1236 false);
1237 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1238 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
1239 } else {
// Unlike the add path, no additional 16-bit-specific extension is checked or
// added here; only the capability is requested.
1240 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1241 }
1242 break;
1243 case 32:
1244 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1245 break;
1246 case 64:
1247 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1248 break;
1249 default:
1251 "Unexpected floating-point type width in atomic float instruction");
1252 }
1253 }
1254}
1255
1256bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1257 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1258 return false;
1259 uint32_t Dim = ImageInst->getOperand(2).getImm();
1260 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1261 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1262}
1263
1264bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1265 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1266 return false;
1267 uint32_t Dim = ImageInst->getOperand(2).getImm();
1268 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1269 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1270}
1271
1272bool isSampledImage(MachineInstr *ImageInst) {
1273 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1274 return false;
1275 uint32_t Dim = ImageInst->getOperand(2).getImm();
1276 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1277 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1278}
1279
1280bool isInputAttachment(MachineInstr *ImageInst) {
1281 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1282 return false;
1283 uint32_t Dim = ImageInst->getOperand(2).getImm();
1284 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1285 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1286}
1287
1288bool isStorageImage(MachineInstr *ImageInst) {
1289 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1290 return false;
1291 uint32_t Dim = ImageInst->getOperand(2).getImm();
1292 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1293 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1294}
1295
1296bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1297 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1298 return false;
1299
1300 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1301 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1302 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1303 return isSampledImage(ImageInst);
1304}
1305
1306bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1307 for (const auto &MI : MRI.reg_instructions(Reg)) {
1308 if (MI.getOpcode() != SPIRV::OpDecorate)
1309 continue;
1310
1311 uint32_t Dec = MI.getOperand(1).getImm();
1312 if (Dec == SPIRV::Decoration::NonUniformEXT)
1313 return true;
1314 }
1315 return false;
1316}
1317
// Adds descriptor-indexing capabilities for an access-chain instruction.
// Operands read from Instr: 0 = result register (checked for a NonUniformEXT
// decoration), 1 = result type register (asserted to be an OpTypePointer),
// 3 = first index register (checked for being a constant instruction).
// Nothing is added unless the pointer's storage class is UniformConstant,
// Uniform or StorageBuffer. In every category below, a non-uniform result
// selects the "NonUniformIndexing" capability; otherwise a non-constant
// first index selects the "DynamicIndexing" capability; otherwise nothing is
// added.
1318void addOpAccessChainReqs(const MachineInstr &Instr,
1320 const SPIRVSubtarget &Subtarget) {
1321 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1322 // Get the result type. If it is an image type, then the shader uses
1323 // descriptor indexing. The appropriate capabilities will be added based
1324 // on the specifics of the image.
1325 Register ResTypeReg = Instr.getOperand(1).getReg();
1326 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1327
1328 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1329 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1330 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1331 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1332 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1333 return;
1334 }
1335
1336 bool IsNonUniform =
1337 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1338
1339 auto FirstIndexReg = Instr.getOperand(3).getReg();
1340 bool FirstIndexIsConstant =
1341 Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
1342
// StorageBuffer pointers are handled without inspecting the pointee type.
1343 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1344 if (IsNonUniform)
1345 Handler.addRequirements(
1346 SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1347 else if (!FirstIndexIsConstant)
1348 Handler.addRequirements(
1349 SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1350 return;
1351 }
1352
// For the remaining storage classes only image, sampled-image and sampler
// pointee types are of interest.
1353 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1354 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1355 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1356 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1357 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1358 return;
1359 }
1360
1361 if (isUniformTexelBuffer(PointeeType)) {
1362 if (IsNonUniform)
1363 Handler.addRequirements(
1364 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1365 else if (!FirstIndexIsConstant)
1366 Handler.addRequirements(
1367 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1368 } else if (isInputAttachment(PointeeType)) {
1369 if (IsNonUniform)
1370 Handler.addRequirements(
1371 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1372 else if (!FirstIndexIsConstant)
1373 Handler.addRequirements(
1374 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1375 } else if (isStorageTexelBuffer(PointeeType)) {
1376 if (IsNonUniform)
1377 Handler.addRequirements(
1378 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1379 else if (!FirstIndexIsConstant)
1380 Handler.addRequirements(
1381 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1382 } else if (isSampledImage(PointeeType) ||
1383 isCombinedImageSampler(PointeeType) ||
1384 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1385 if (IsNonUniform)
1386 Handler.addRequirements(
1387 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1388 else if (!FirstIndexIsConstant)
1389 Handler.addRequirements(
1390 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1391 } else if (isStorageImage(PointeeType)) {
1392 if (IsNonUniform)
1393 Handler.addRequirements(
1394 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1395 else if (!FirstIndexIsConstant)
1396 Handler.addRequirements(
1397 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1398 }
1399}
1400
1401static bool isImageTypeWithUnknownFormat(SPIRVTypeInst TypeInst) {
1402 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1403 return false;
1404 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1405 return TypeInst->getOperand(7).getImm() == 0;
1406}
1407
// Adds the requirements for an integer dot-product instruction MI.
// SPV_KHR_integer_dot_product is added only when the subtarget can use it;
// the DotProduct capability is added unconditionally. The input-kind
// capability is chosen from the type of the instruction's first input
// (found through the definition of the register in operand 2):
//  - scalar OpTypeInt (asserted 32-bit): DotProductInput4x8BitPacked;
//  - OpTypeVector of 8-bit ints (asserted 4 components):
//    DotProductInput4x8Bit;
//  - OpTypeVector of any other int width: DotProductInputAll.
// Any other input type adds no input-kind capability.
1408static void AddDotProductRequirements(const MachineInstr &MI,
1410 const SPIRVSubtarget &ST) {
1411 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1412 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1413 Reqs.addCapability(SPIRV::Capability::DotProduct);
1414
1415 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1416 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1417 // We do not consider what the previous instruction is. This is just used
1418 // to get the input register and to check the type.
1419 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1420 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1421 Register InputReg = Input->getOperand(1).getReg();
1422
1423 SPIRVTypeInst TypeDef = MRI.getVRegDef(InputReg);
1424 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1425 assert(TypeDef->getOperand(1).getImm() == 32);
1426 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1427 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1428 SPIRVTypeInst ScalarTypeDef =
1429 MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1430 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1431 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1432 assert(TypeDef->getOperand(2).getImm() == 4 &&
1433 "Dot operand of 8-bit integer type requires 4 components");
1434 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1435 } else {
1436 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1437 }
1438 }
1439}
1440
// Adds SPV_EXT_relaxed_printf_string_address_space when the printf format
// string is not in the UniformConstant storage class. The SPIR-V type of the
// register in operand 4 of MI is looked up through the global registry; if a
// type is found and its operand 1 is an immediate, that immediate is
// compared with UniformConstant. It is a fatal error if the extension is
// needed but the subtarget cannot use it. Nothing happens when no type is
// found or operand 1 is not an immediate.
1441void addPrintfRequirements(const MachineInstr &MI,
1443 const SPIRVSubtarget &ST) {
1444 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1445 SPIRVTypeInst PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
1446 if (PtrType) {
1447 MachineOperand ASOp = PtrType->getOperand(1);
1448 if (ASOp.isImm()) {
1449 unsigned AddrSpace = ASOp.getImm();
1450 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1451 if (!ST.canUseExtension(
1453 SPV_EXT_relaxed_printf_string_address_space)) {
1454 report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
1455 "required because printf uses a format string not "
1456 "in constant address space.",
1457 false);
1458 }
1459 Reqs.addExtension(
1460 SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1461 }
1462 }
1463 }
1464}
1465
// Adds the requirements of every image operand set in the bit mask stored at
// operand OpIdx of MI. Does nothing when MI has no operand at OpIdx. Each of
// the 32 mask bits is tested individually, and the single-bit value
// (1U << I) is looked up in the ImageOperandOperand category.
1466static void addImageOperandReqs(const MachineInstr &MI,
1468 const SPIRVSubtarget &ST, unsigned OpIdx) {
1469 if (MI.getNumOperands() <= OpIdx)
1470 return;
1471 uint32_t Mask = MI.getOperand(OpIdx).getImm();
1472 for (uint32_t I = 0; I < 32; ++I)
1473 if (Mask & (1U << I))
1474 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageOperandOperand,
1475 1U << I, ST);
1476}
1477
1478void addInstrRequirements(const MachineInstr &MI,
1480 const SPIRVSubtarget &ST) {
1481 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1482 unsigned Op = MI.getOpcode();
1483 switch (Op) {
1484 case SPIRV::OpMemoryModel: {
1485 int64_t Addr = MI.getOperand(0).getImm();
1486 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1487 Addr, ST);
1488 int64_t Mem = MI.getOperand(1).getImm();
1489 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1490 ST);
1491 break;
1492 }
1493 case SPIRV::OpEntryPoint: {
1494 int64_t Exe = MI.getOperand(0).getImm();
1495 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1496 Exe, ST);
1497 break;
1498 }
1499 case SPIRV::OpExecutionMode:
1500 case SPIRV::OpExecutionModeId: {
1501 int64_t Exe = MI.getOperand(1).getImm();
1502 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1503 Exe, ST);
1504 break;
1505 }
1506 case SPIRV::OpTypeMatrix:
1507 Reqs.addCapability(SPIRV::Capability::Matrix);
1508 break;
1509 case SPIRV::OpTypeInt: {
1510 unsigned BitWidth = MI.getOperand(1).getImm();
1511 if (BitWidth == 64)
1512 Reqs.addCapability(SPIRV::Capability::Int64);
1513 else if (BitWidth == 16)
1514 Reqs.addCapability(SPIRV::Capability::Int16);
1515 else if (BitWidth == 8)
1516 Reqs.addCapability(SPIRV::Capability::Int8);
1517 else if (BitWidth == 4 &&
1518 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1519 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_int4);
1520 Reqs.addCapability(SPIRV::Capability::Int4TypeINTEL);
1521 } else if (BitWidth != 32) {
1522 if (!ST.canUseExtension(
1523 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers))
1525 "OpTypeInt type with a width other than 8, 16, 32 or 64 bits "
1526 "requires the following SPIR-V extension: "
1527 "SPV_ALTERA_arbitrary_precision_integers");
1528 Reqs.addExtension(
1529 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_integers);
1530 Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionIntegersALTERA);
1531 }
1532 break;
1533 }
1534 case SPIRV::OpDot: {
1535 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1536 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1537 if (isBFloat16Type(TypeDef))
1538 Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1539 break;
1540 }
1541 case SPIRV::OpTypeFloat: {
1542 unsigned BitWidth = MI.getOperand(1).getImm();
1543 if (BitWidth == 64)
1544 Reqs.addCapability(SPIRV::Capability::Float64);
1545 else if (BitWidth == 16) {
1546 if (isBFloat16Type(&MI)) {
1547 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1548 report_fatal_error("OpTypeFloat type with bfloat requires the "
1549 "following SPIR-V extension: SPV_KHR_bfloat16",
1550 false);
1551 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1552 Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1553 } else {
1554 Reqs.addCapability(SPIRV::Capability::Float16);
1555 }
1556 }
1557 break;
1558 }
1559 case SPIRV::OpTypeVector: {
1560 unsigned NumComponents = MI.getOperand(2).getImm();
1561 if (NumComponents == 8 || NumComponents == 16)
1562 Reqs.addCapability(SPIRV::Capability::Vector16);
1563
1564 assert(MI.getOperand(1).isReg());
1565 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1566 SPIRVTypeInst ElemTypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1567 if (ElemTypeDef->getOpcode() == SPIRV::OpTypePointer &&
1568 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter)) {
1569 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_masked_gather_scatter);
1570 Reqs.addCapability(SPIRV::Capability::MaskedGatherScatterINTEL);
1571 }
1572 break;
1573 }
1574 case SPIRV::OpTypePointer: {
1575 auto SC = MI.getOperand(1).getImm();
1576 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1577 ST);
1578 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1579 // capability.
1580 if (ST.isShader())
1581 break;
1582 assert(MI.getOperand(2).isReg());
1583 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1584 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1585 if ((TypeDef->getNumOperands() == 2) &&
1586 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1587 (TypeDef->getOperand(1).getImm() == 16))
1588 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1589 break;
1590 }
1591 case SPIRV::OpExtInst: {
1592 if (MI.getOperand(2).getImm() ==
1593 static_cast<int64_t>(
1594 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1595 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1596 break;
1597 }
1598 if (MI.getOperand(3).getImm() ==
1599 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1600 addPrintfRequirements(MI, Reqs, ST);
1601 break;
1602 }
1603 // TODO: handle bfloat16 extended instructions when
1604 // SPV_INTEL_bfloat16_arithmetic is enabled.
1605 break;
1606 }
1607 case SPIRV::OpAliasDomainDeclINTEL:
1608 case SPIRV::OpAliasScopeDeclINTEL:
1609 case SPIRV::OpAliasScopeListDeclINTEL: {
1610 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1611 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1612 break;
1613 }
1614 case SPIRV::OpBitReverse:
1615 case SPIRV::OpBitFieldInsert:
1616 case SPIRV::OpBitFieldSExtract:
1617 case SPIRV::OpBitFieldUExtract:
1618 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1619 Reqs.addCapability(SPIRV::Capability::Shader);
1620 break;
1621 }
1622 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1623 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1624 break;
1625 case SPIRV::OpTypeRuntimeArray:
1626 Reqs.addCapability(SPIRV::Capability::Shader);
1627 break;
1628 case SPIRV::OpTypeOpaque:
1629 case SPIRV::OpTypeEvent:
1630 Reqs.addCapability(SPIRV::Capability::Kernel);
1631 break;
1632 case SPIRV::OpTypePipe:
1633 case SPIRV::OpTypeReserveId:
1634 Reqs.addCapability(SPIRV::Capability::Pipes);
1635 break;
1636 case SPIRV::OpTypeDeviceEvent:
1637 case SPIRV::OpTypeQueue:
1638 case SPIRV::OpBuildNDRange:
1639 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1640 break;
1641 case SPIRV::OpDecorate:
1642 case SPIRV::OpDecorateId:
1643 case SPIRV::OpDecorateString:
1644 addOpDecorateReqs(MI, 1, Reqs, ST);
1645 break;
1646 case SPIRV::OpMemberDecorate:
1647 case SPIRV::OpMemberDecorateString:
1648 addOpDecorateReqs(MI, 2, Reqs, ST);
1649 break;
1650 case SPIRV::OpInBoundsPtrAccessChain:
1651 Reqs.addCapability(SPIRV::Capability::Addresses);
1652 break;
1653 case SPIRV::OpConstantSampler:
1654 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1655 break;
1656 case SPIRV::OpInBoundsAccessChain:
1657 case SPIRV::OpAccessChain:
1658 addOpAccessChainReqs(MI, Reqs, ST);
1659 break;
1660 case SPIRV::OpTypeImage:
1661 addOpTypeImageReqs(MI, Reqs, ST);
1662 break;
1663 case SPIRV::OpTypeSampler:
1664 if (!ST.isShader()) {
1665 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1666 }
1667 break;
1668 case SPIRV::OpTypeForwardPointer:
1669 // TODO: check if it's OpenCL's kernel.
1670 Reqs.addCapability(SPIRV::Capability::Addresses);
1671 break;
1672 case SPIRV::OpAtomicFlagTestAndSet:
1673 case SPIRV::OpAtomicLoad:
1674 case SPIRV::OpAtomicStore:
1675 case SPIRV::OpAtomicExchange:
1676 case SPIRV::OpAtomicCompareExchange:
1677 case SPIRV::OpAtomicIIncrement:
1678 case SPIRV::OpAtomicIDecrement:
1679 case SPIRV::OpAtomicIAdd:
1680 case SPIRV::OpAtomicISub:
1681 case SPIRV::OpAtomicUMin:
1682 case SPIRV::OpAtomicUMax:
1683 case SPIRV::OpAtomicSMin:
1684 case SPIRV::OpAtomicSMax:
1685 case SPIRV::OpAtomicAnd:
1686 case SPIRV::OpAtomicOr:
1687 case SPIRV::OpAtomicXor: {
1688 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1689 const MachineInstr *InstrPtr = &MI;
1690 if (Op == SPIRV::OpAtomicStore) {
1691 assert(MI.getOperand(3).isReg());
1692 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1693 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1694 }
1695 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1696 Register TypeReg = InstrPtr->getOperand(1).getReg();
1697 SPIRVTypeInst TypeDef = MRI.getVRegDef(TypeReg);
1698
1699 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1700 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1701 if (BitWidth == 64)
1702 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1703 else if (BitWidth == 16) {
1704 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1706 "16-bit integer atomic operations require the following SPIR-V "
1707 "extension: SPV_INTEL_16bit_atomics",
1708 false);
1709 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1710 switch (Op) {
1711 case SPIRV::OpAtomicLoad:
1712 case SPIRV::OpAtomicStore:
1713 case SPIRV::OpAtomicExchange:
1714 case SPIRV::OpAtomicCompareExchange:
1715 case SPIRV::OpAtomicCompareExchangeWeak:
1716 Reqs.addCapability(
1717 SPIRV::Capability::AtomicInt16CompareExchangeINTEL);
1718 break;
1719 default:
1720 Reqs.addCapability(SPIRV::Capability::Int16AtomicsINTEL);
1721 break;
1722 }
1723 }
1724 } else if (isBFloat16Type(TypeDef)) {
1725 if (is_contained({SPIRV::OpAtomicLoad, SPIRV::OpAtomicStore,
1726 SPIRV::OpAtomicExchange},
1727 Op)) {
1728 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1730 "The atomic bfloat16 instruction requires the following SPIR-V "
1731 "extension: SPV_INTEL_16bit_atomics",
1732 false);
1733 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1734 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16LoadStoreINTEL);
1735 }
1736 }
1737 break;
1738 }
1739 case SPIRV::OpGroupNonUniformIAdd:
1740 case SPIRV::OpGroupNonUniformFAdd:
1741 case SPIRV::OpGroupNonUniformIMul:
1742 case SPIRV::OpGroupNonUniformFMul:
1743 case SPIRV::OpGroupNonUniformSMin:
1744 case SPIRV::OpGroupNonUniformUMin:
1745 case SPIRV::OpGroupNonUniformFMin:
1746 case SPIRV::OpGroupNonUniformSMax:
1747 case SPIRV::OpGroupNonUniformUMax:
1748 case SPIRV::OpGroupNonUniformFMax:
1749 case SPIRV::OpGroupNonUniformBitwiseAnd:
1750 case SPIRV::OpGroupNonUniformBitwiseOr:
1751 case SPIRV::OpGroupNonUniformBitwiseXor:
1752 case SPIRV::OpGroupNonUniformLogicalAnd:
1753 case SPIRV::OpGroupNonUniformLogicalOr:
1754 case SPIRV::OpGroupNonUniformLogicalXor: {
1755 assert(MI.getOperand(3).isImm());
1756 int64_t GroupOp = MI.getOperand(3).getImm();
1757 switch (GroupOp) {
1758 case SPIRV::GroupOperation::Reduce:
1759 case SPIRV::GroupOperation::InclusiveScan:
1760 case SPIRV::GroupOperation::ExclusiveScan:
1761 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1762 break;
1763 case SPIRV::GroupOperation::ClusteredReduce:
1764 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1765 break;
1766 case SPIRV::GroupOperation::PartitionedReduceNV:
1767 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1768 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1769 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1770 break;
1771 }
1772 break;
1773 }
1774 case SPIRV::OpGroupNonUniformQuadSwap:
1775 Reqs.addCapability(SPIRV::Capability::GroupNonUniformQuad);
1776 break;
1777 case SPIRV::OpImageQueryLod:
1778 Reqs.addCapability(SPIRV::Capability::ImageQuery);
1779 break;
1780 case SPIRV::OpImageQuerySize:
1781 case SPIRV::OpImageQuerySizeLod:
1782 case SPIRV::OpImageQueryLevels:
1783 case SPIRV::OpImageQuerySamples:
1784 if (ST.isShader())
1785 Reqs.addCapability(SPIRV::Capability::ImageQuery);
1786 break;
1787 case SPIRV::OpImageQueryFormat: {
1788 Register ResultReg = MI.getOperand(0).getReg();
1789 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1790 static const unsigned CompareOps[] = {
1791 SPIRV::OpIEqual, SPIRV::OpINotEqual,
1792 SPIRV::OpUGreaterThan, SPIRV::OpUGreaterThanEqual,
1793 SPIRV::OpULessThan, SPIRV::OpULessThanEqual,
1794 SPIRV::OpSGreaterThan, SPIRV::OpSGreaterThanEqual,
1795 SPIRV::OpSLessThan, SPIRV::OpSLessThanEqual};
1796
1797 auto CheckAndAddExtension = [&](int64_t ImmVal) {
1798 if (ImmVal == 4323 || ImmVal == 4324) {
1799 if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_image_raw10_raw12))
1800 Reqs.addExtension(SPIRV::Extension::SPV_EXT_image_raw10_raw12);
1801 else
1802 report_fatal_error("This requires the "
1803 "SPV_EXT_image_raw10_raw12 extension");
1804 }
1805 };
1806
1807 for (MachineInstr &UseInst : MRI.use_instructions(ResultReg)) {
1808 unsigned Opc = UseInst.getOpcode();
1809
1810 if (Opc == SPIRV::OpSwitch) {
1811 for (const MachineOperand &Op : UseInst.operands())
1812 if (Op.isImm())
1813 CheckAndAddExtension(Op.getImm());
1814 } else if (llvm::is_contained(CompareOps, Opc)) {
1815 for (unsigned i = 1; i < UseInst.getNumOperands(); ++i) {
1816 Register UseReg = UseInst.getOperand(i).getReg();
1817 MachineInstr *ConstInst = MRI.getVRegDef(UseReg);
1818 if (ConstInst && ConstInst->getOpcode() == SPIRV::OpConstantI) {
1819 int64_t ImmVal = ConstInst->getOperand(2).getImm();
1820 if (ImmVal)
1821 CheckAndAddExtension(ImmVal);
1822 }
1823 }
1824 }
1825 }
1826 break;
1827 }
1828
1829 case SPIRV::OpGroupNonUniformShuffle:
1830 case SPIRV::OpGroupNonUniformShuffleXor:
1831 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1832 break;
1833 case SPIRV::OpGroupNonUniformShuffleUp:
1834 case SPIRV::OpGroupNonUniformShuffleDown:
1835 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1836 break;
1837 case SPIRV::OpGroupAll:
1838 case SPIRV::OpGroupAny:
1839 case SPIRV::OpGroupBroadcast:
1840 case SPIRV::OpGroupIAdd:
1841 case SPIRV::OpGroupFAdd:
1842 case SPIRV::OpGroupFMin:
1843 case SPIRV::OpGroupUMin:
1844 case SPIRV::OpGroupSMin:
1845 case SPIRV::OpGroupFMax:
1846 case SPIRV::OpGroupUMax:
1847 case SPIRV::OpGroupSMax:
1848 Reqs.addCapability(SPIRV::Capability::Groups);
1849 break;
1850 case SPIRV::OpGroupNonUniformElect:
1851 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1852 break;
1853 case SPIRV::OpGroupNonUniformAll:
1854 case SPIRV::OpGroupNonUniformAny:
1855 case SPIRV::OpGroupNonUniformAllEqual:
1856 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1857 break;
1858 case SPIRV::OpGroupNonUniformBroadcast:
1859 case SPIRV::OpGroupNonUniformBroadcastFirst:
1860 case SPIRV::OpGroupNonUniformBallot:
1861 case SPIRV::OpGroupNonUniformInverseBallot:
1862 case SPIRV::OpGroupNonUniformBallotBitExtract:
1863 case SPIRV::OpGroupNonUniformBallotBitCount:
1864 case SPIRV::OpGroupNonUniformBallotFindLSB:
1865 case SPIRV::OpGroupNonUniformBallotFindMSB:
1866 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1867 break;
1868 case SPIRV::OpSubgroupShuffleINTEL:
1869 case SPIRV::OpSubgroupShuffleDownINTEL:
1870 case SPIRV::OpSubgroupShuffleUpINTEL:
1871 case SPIRV::OpSubgroupShuffleXorINTEL:
1872 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1873 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1874 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1875 }
1876 break;
1877 case SPIRV::OpSubgroupBlockReadINTEL:
1878 case SPIRV::OpSubgroupBlockWriteINTEL:
1879 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1880 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1881 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1882 }
1883 break;
1884 case SPIRV::OpSubgroupImageBlockReadINTEL:
1885 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1886 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1887 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1888 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1889 }
1890 break;
1891 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1892 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1893 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1894 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1895 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1896 }
1897 break;
1898 case SPIRV::OpAssumeTrueKHR:
1899 case SPIRV::OpExpectKHR:
1900 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1901 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1902 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1903 }
1904 break;
1905 case SPIRV::OpFmaKHR:
1906 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_fma)) {
1907 Reqs.addExtension(SPIRV::Extension::SPV_KHR_fma);
1908 Reqs.addCapability(SPIRV::Capability::FmaKHR);
1909 }
1910 break;
1911 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1912 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1913 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1914 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1915 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1916 }
1917 break;
1918 case SPIRV::OpConstantFunctionPointerINTEL:
1919 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1920 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1921 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1922 }
1923 break;
1924 case SPIRV::OpGroupNonUniformRotateKHR:
1925 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1926 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1927 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1928 false);
1929 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1930 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1931 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1932 break;
1933 case SPIRV::OpFixedCosALTERA:
1934 case SPIRV::OpFixedSinALTERA:
1935 case SPIRV::OpFixedCosPiALTERA:
1936 case SPIRV::OpFixedSinPiALTERA:
1937 case SPIRV::OpFixedExpALTERA:
1938 case SPIRV::OpFixedLogALTERA:
1939 case SPIRV::OpFixedRecipALTERA:
1940 case SPIRV::OpFixedSqrtALTERA:
1941 case SPIRV::OpFixedSinCosALTERA:
1942 case SPIRV::OpFixedSinCosPiALTERA:
1943 case SPIRV::OpFixedRsqrtALTERA:
1944 if (!ST.canUseExtension(
1945 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point))
1946 report_fatal_error("This instruction requires the "
1947 "following SPIR-V extension: "
1948 "SPV_ALTERA_arbitrary_precision_fixed_point",
1949 false);
1950 Reqs.addExtension(
1951 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_fixed_point);
1952 Reqs.addCapability(SPIRV::Capability::ArbitraryPrecisionFixedPointALTERA);
1953 break;
1954 case SPIRV::OpGroupIMulKHR:
1955 case SPIRV::OpGroupFMulKHR:
1956 case SPIRV::OpGroupBitwiseAndKHR:
1957 case SPIRV::OpGroupBitwiseOrKHR:
1958 case SPIRV::OpGroupBitwiseXorKHR:
1959 case SPIRV::OpGroupLogicalAndKHR:
1960 case SPIRV::OpGroupLogicalOrKHR:
1961 case SPIRV::OpGroupLogicalXorKHR:
1962 if (ST.canUseExtension(
1963 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1964 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1965 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1966 }
1967 break;
1968 case SPIRV::OpReadClockKHR:
1969 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1970 report_fatal_error("OpReadClockKHR instruction requires the "
1971 "following SPIR-V extension: SPV_KHR_shader_clock",
1972 false);
1973 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1974 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1975 break;
1976 case SPIRV::OpAbortKHR:
1977 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_abort))
1978 report_fatal_error("OpAbortKHR instruction requires the "
1979 "following SPIR-V extension: SPV_KHR_abort",
1980 false);
1981 Reqs.addExtension(SPIRV::Extension::SPV_KHR_abort);
1982 Reqs.addCapability(SPIRV::Capability::AbortKHR);
1983 break;
1984 case SPIRV::OpFunctionPointerCallINTEL:
1985 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1986 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1987 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1988 }
1989 break;
1990 case SPIRV::OpAtomicFAddEXT:
1991 case SPIRV::OpAtomicFMinEXT:
1992 case SPIRV::OpAtomicFMaxEXT:
1993 AddAtomicFloatRequirements(MI, Reqs, ST);
1994 break;
1995 case SPIRV::OpConvertBF16ToFINTEL:
1996 case SPIRV::OpConvertFToBF16INTEL:
1997 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1998 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1999 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
2000 }
2001 break;
2002 case SPIRV::OpRoundFToTF32INTEL:
2003 if (ST.canUseExtension(
2004 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
2005 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
2006 Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
2007 }
2008 break;
2009 case SPIRV::OpVariableLengthArrayINTEL:
2010 case SPIRV::OpSaveMemoryINTEL:
2011 case SPIRV::OpRestoreMemoryINTEL:
2012 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
2013 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
2014 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
2015 }
2016 break;
2017 case SPIRV::OpAsmTargetINTEL:
2018 case SPIRV::OpAsmINTEL:
2019 case SPIRV::OpAsmCallINTEL:
2020 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
2021 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
2022 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
2023 }
2024 break;
2025 case SPIRV::OpTypeCooperativeMatrixKHR: {
2026 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2028 "OpTypeCooperativeMatrixKHR type requires the "
2029 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
2030 false);
2031 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2032 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2033 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2034 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2035 if (isBFloat16Type(TypeDef))
2036 Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
2037 break;
2038 }
2039 case SPIRV::OpArithmeticFenceEXT:
2040 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
2041 report_fatal_error("OpArithmeticFenceEXT requires the "
2042 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
2043 false);
2044 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
2045 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
2046 break;
2047 case SPIRV::OpControlBarrierArriveINTEL:
2048 case SPIRV::OpControlBarrierWaitINTEL:
2049 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
2050 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
2051 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
2052 }
2053 break;
2054 case SPIRV::OpCooperativeMatrixMulAddKHR: {
2055 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2056 report_fatal_error("Cooperative matrix instructions require the "
2057 "following SPIR-V extension: "
2058 "SPV_KHR_cooperative_matrix",
2059 false);
2060 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2061 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2062 constexpr unsigned MulAddMaxSize = 6;
2063 if (MI.getNumOperands() != MulAddMaxSize)
2064 break;
2065 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
2066 if (CoopOperands &
2067 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
2068 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2069 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
2070 "require the following SPIR-V extension: "
2071 "SPV_INTEL_joint_matrix",
2072 false);
2073 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2074 Reqs.addCapability(
2075 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
2076 }
2077 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
2078 MatrixAAndBBFloat16ComponentsINTEL ||
2079 CoopOperands &
2080 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
2081 CoopOperands & SPIRV::CooperativeMatrixOperands::
2082 MatrixResultBFloat16ComponentsINTEL) {
2083 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2084 report_fatal_error("***BF16ComponentsINTEL type interpretations "
2085 "require the following SPIR-V extension: "
2086 "SPV_INTEL_joint_matrix",
2087 false);
2088 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2089 Reqs.addCapability(
2090 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
2091 }
2092 break;
2093 }
2094 case SPIRV::OpCooperativeMatrixLoadKHR:
2095 case SPIRV::OpCooperativeMatrixStoreKHR:
2096 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2097 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2098 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
2099 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
2100 report_fatal_error("Cooperative matrix instructions require the "
2101 "following SPIR-V extension: "
2102 "SPV_KHR_cooperative_matrix",
2103 false);
2104 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
2105 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
2106
2107 // Inspect the Layout operand and, if it is not a standard layout, add the
2108 // appropriate capability.
2109 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
2110 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
2111 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
2112 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
2113 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
2114 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
2115
2116 const unsigned LayoutNum = LayoutToInstMap[Op];
2117 Register RegLayout = MI.getOperand(LayoutNum).getReg();
2118 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2119 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
2120 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
2121 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
2122 if (LayoutVal ==
2123 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
2124 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2125 report_fatal_error("PackedINTEL layout require the following SPIR-V "
2126 "extension: SPV_INTEL_joint_matrix",
2127 false);
2128 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2129 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
2130 }
2131 }
2132
2133 // Nothing to do.
2134 if (Op == SPIRV::OpCooperativeMatrixLoadKHR ||
2135 Op == SPIRV::OpCooperativeMatrixStoreKHR)
2136 break;
2137
2138 std::string InstName;
2139 switch (Op) {
2140 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
2141 InstName = "OpCooperativeMatrixPrefetchINTEL";
2142 break;
2143 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
2144 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
2145 break;
2146 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
2147 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
2148 break;
2149 }
2150
2151 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
2152 const std::string ErrorMsg =
2153 InstName + " instruction requires the "
2154 "following SPIR-V extension: SPV_INTEL_joint_matrix";
2155 report_fatal_error(ErrorMsg.c_str(), false);
2156 }
2157 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2158 if (Op == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2159 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
2160 break;
2161 }
2162 Reqs.addCapability(
2163 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2164 break;
2165 }
2166 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
2167 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2168 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
2169 "instructions require the following SPIR-V extension: "
2170 "SPV_INTEL_joint_matrix",
2171 false);
2172 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2173 Reqs.addCapability(
2174 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
2175 break;
2176 case SPIRV::OpReadPipeBlockingALTERA:
2177 case SPIRV::OpWritePipeBlockingALTERA:
2178 if (ST.canUseExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
2179 Reqs.addExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes);
2180 Reqs.addCapability(SPIRV::Capability::BlockingPipesALTERA);
2181 }
2182 break;
2183 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
2184 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
2185 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
2186 "following SPIR-V extension: SPV_INTEL_joint_matrix",
2187 false);
2188 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
2189 Reqs.addCapability(
2190 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
2191 break;
2192 case SPIRV::OpConvertHandleToImageINTEL:
2193 case SPIRV::OpConvertHandleToSamplerINTEL:
2194 case SPIRV::OpConvertHandleToSampledImageINTEL: {
2195 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
2196 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
2197 "instructions require the following SPIR-V extension: "
2198 "SPV_INTEL_bindless_images",
2199 false);
2200 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
2201 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
2202 SPIRVTypeInst TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
2203 if (Op == SPIRV::OpConvertHandleToImageINTEL &&
2204 TyDef->getOpcode() != SPIRV::OpTypeImage) {
2205 report_fatal_error("Incorrect return type for the instruction "
2206 "OpConvertHandleToImageINTEL",
2207 false);
2208 } else if (Op == SPIRV::OpConvertHandleToSamplerINTEL &&
2209 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
2210 report_fatal_error("Incorrect return type for the instruction "
2211 "OpConvertHandleToSamplerINTEL",
2212 false);
2213 } else if (Op == SPIRV::OpConvertHandleToSampledImageINTEL &&
2214 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
2215 report_fatal_error("Incorrect return type for the instruction "
2216 "OpConvertHandleToSampledImageINTEL",
2217 false);
2218 }
2219 SPIRVTypeInst SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
2220 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
2221 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
2222 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
2224 "Parameter value must be a 32-bit scalar in case of "
2225 "Physical32 addressing model or a 64-bit scalar in case of "
2226 "Physical64 addressing model",
2227 false);
2228 }
2229 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
2230 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
2231 break;
2232 }
2233 case SPIRV::OpSubgroup2DBlockLoadINTEL:
2234 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
2235 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
2236 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
2237 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
2238 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
2239 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
2240 "Prefetch/Store]INTEL instructions require the "
2241 "following SPIR-V extension: SPV_INTEL_2d_block_io",
2242 false);
2243 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
2244 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
2245
2246 if (Op == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
2247 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
2248 break;
2249 }
2250 if (Op == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
2251 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
2252 break;
2253 }
2254 break;
2255 }
2256 case SPIRV::OpKill: {
2257 Reqs.addCapability(SPIRV::Capability::Shader);
2258 } break;
2259 case SPIRV::OpDemoteToHelperInvocation:
2260 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
2261
2262 if (ST.canUseExtension(
2263 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
2264 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
2265 Reqs.addExtension(
2266 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2267 }
2268 break;
2269 case SPIRV::OpSDot:
2270 case SPIRV::OpUDot:
2271 case SPIRV::OpSUDot:
2272 case SPIRV::OpSDotAccSat:
2273 case SPIRV::OpUDotAccSat:
2274 case SPIRV::OpSUDotAccSat:
2275 AddDotProductRequirements(MI, Reqs, ST);
2276 break;
2277 case SPIRV::OpImageSampleImplicitLod:
2278 Reqs.addCapability(SPIRV::Capability::Shader);
2279 addImageOperandReqs(MI, Reqs, ST, 4);
2280 break;
2281 case SPIRV::OpImageSampleExplicitLod:
2282 addImageOperandReqs(MI, Reqs, ST, 4);
2283 break;
2284 case SPIRV::OpImageSampleDrefImplicitLod:
2285 Reqs.addCapability(SPIRV::Capability::Shader);
2286 addImageOperandReqs(MI, Reqs, ST, 5);
2287 break;
2288 case SPIRV::OpImageSampleDrefExplicitLod:
2289 Reqs.addCapability(SPIRV::Capability::Shader);
2290 addImageOperandReqs(MI, Reqs, ST, 5);
2291 break;
2292 case SPIRV::OpImageFetch:
2293 Reqs.addCapability(SPIRV::Capability::Shader);
2294 addImageOperandReqs(MI, Reqs, ST, 4);
2295 break;
2296 case SPIRV::OpImageDrefGather:
2297 case SPIRV::OpImageGather:
2298 Reqs.addCapability(SPIRV::Capability::Shader);
2299 addImageOperandReqs(MI, Reqs, ST, 5);
2300 break;
2301 case SPIRV::OpImageRead: {
2302 Register ImageReg = MI.getOperand(2).getReg();
2303 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2304 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2305 // OpImageRead and OpImageWrite can use Unknown Image Formats
2306 // when the Kernel capability is declared. In the OpenCL environment we are
2307 // not allowed to produce
2308 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2309 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2310
2311 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2312 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
2313 break;
2314 }
2315 case SPIRV::OpImageWrite: {
2316 Register ImageReg = MI.getOperand(0).getReg();
2317 SPIRVTypeInst TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2318 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2319 // OpImageRead and OpImageWrite can use Unknown Image Formats
2320 // when the Kernel capability is declared. In the OpenCL environment we are
2321 // not allowed to produce
2322 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2323 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2324
2325 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2326 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
2327 break;
2328 }
2329 case SPIRV::OpTypeStructContinuedINTEL:
2330 case SPIRV::OpConstantCompositeContinuedINTEL:
2331 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2332 case SPIRV::OpCompositeConstructContinuedINTEL: {
2333 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
2335 "Continued instructions require the "
2336 "following SPIR-V extension: SPV_INTEL_long_composites",
2337 false);
2338 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
2339 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
2340 break;
2341 }
2342 case SPIRV::OpArbitraryFloatEQALTERA:
2343 case SPIRV::OpArbitraryFloatGEALTERA:
2344 case SPIRV::OpArbitraryFloatGTALTERA:
2345 case SPIRV::OpArbitraryFloatLEALTERA:
2346 case SPIRV::OpArbitraryFloatLTALTERA:
2347 case SPIRV::OpArbitraryFloatCbrtALTERA:
2348 case SPIRV::OpArbitraryFloatCosALTERA:
2349 case SPIRV::OpArbitraryFloatCosPiALTERA:
2350 case SPIRV::OpArbitraryFloatExp10ALTERA:
2351 case SPIRV::OpArbitraryFloatExp2ALTERA:
2352 case SPIRV::OpArbitraryFloatExpALTERA:
2353 case SPIRV::OpArbitraryFloatExpm1ALTERA:
2354 case SPIRV::OpArbitraryFloatHypotALTERA:
2355 case SPIRV::OpArbitraryFloatLog10ALTERA:
2356 case SPIRV::OpArbitraryFloatLog1pALTERA:
2357 case SPIRV::OpArbitraryFloatLog2ALTERA:
2358 case SPIRV::OpArbitraryFloatLogALTERA:
2359 case SPIRV::OpArbitraryFloatRecipALTERA:
2360 case SPIRV::OpArbitraryFloatSinCosALTERA:
2361 case SPIRV::OpArbitraryFloatSinCosPiALTERA:
2362 case SPIRV::OpArbitraryFloatSinALTERA:
2363 case SPIRV::OpArbitraryFloatSinPiALTERA:
2364 case SPIRV::OpArbitraryFloatSqrtALTERA:
2365 case SPIRV::OpArbitraryFloatACosALTERA:
2366 case SPIRV::OpArbitraryFloatACosPiALTERA:
2367 case SPIRV::OpArbitraryFloatAddALTERA:
2368 case SPIRV::OpArbitraryFloatASinALTERA:
2369 case SPIRV::OpArbitraryFloatASinPiALTERA:
2370 case SPIRV::OpArbitraryFloatATan2ALTERA:
2371 case SPIRV::OpArbitraryFloatATanALTERA:
2372 case SPIRV::OpArbitraryFloatATanPiALTERA:
2373 case SPIRV::OpArbitraryFloatCastFromIntALTERA:
2374 case SPIRV::OpArbitraryFloatCastALTERA:
2375 case SPIRV::OpArbitraryFloatCastToIntALTERA:
2376 case SPIRV::OpArbitraryFloatDivALTERA:
2377 case SPIRV::OpArbitraryFloatMulALTERA:
2378 case SPIRV::OpArbitraryFloatPowALTERA:
2379 case SPIRV::OpArbitraryFloatPowNALTERA:
2380 case SPIRV::OpArbitraryFloatPowRALTERA:
2381 case SPIRV::OpArbitraryFloatRSqrtALTERA:
2382 case SPIRV::OpArbitraryFloatSubALTERA: {
2383 if (!ST.canUseExtension(
2384 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point))
2386 "Floating point instructions can't be translated correctly without "
2387 "enabled SPV_ALTERA_arbitrary_precision_floating_point extension!",
2388 false);
2389 Reqs.addExtension(
2390 SPIRV::Extension::SPV_ALTERA_arbitrary_precision_floating_point);
2391 Reqs.addCapability(
2392 SPIRV::Capability::ArbitraryPrecisionFloatingPointALTERA);
2393 break;
2394 }
2395 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2396 if (!ST.canUseExtension(
2397 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2399 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2400 "following SPIR-V "
2401 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2402 false);
2403 Reqs.addExtension(
2404 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2405 Reqs.addCapability(
2406 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2407 break;
2408 }
2409 case SPIRV::OpBitwiseFunctionINTEL: {
2410 if (!ST.canUseExtension(
2411 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2413 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2414 "extension: SPV_INTEL_ternary_bitwise_function",
2415 false);
2416 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2417 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2418 break;
2419 }
2420 case SPIRV::OpCopyMemorySized: {
2421 Reqs.addCapability(SPIRV::Capability::Addresses);
2422 // TODO: Add UntypedPointersKHR when implemented.
2423 break;
2424 }
2425 case SPIRV::OpPredicatedLoadINTEL:
2426 case SPIRV::OpPredicatedStoreINTEL: {
2427 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
2429 "OpPredicated[Load/Store]INTEL instructions require "
2430 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2431 false);
2432 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
2433 Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
2434 break;
2435 }
2436 case SPIRV::OpFAddS:
2437 case SPIRV::OpFSubS:
2438 case SPIRV::OpFMulS:
2439 case SPIRV::OpFDivS:
2440 case SPIRV::OpFRemS:
2441 case SPIRV::OpFMod:
2442 case SPIRV::OpFNegate:
2443 case SPIRV::OpFAddV:
2444 case SPIRV::OpFSubV:
2445 case SPIRV::OpFMulV:
2446 case SPIRV::OpFDivV:
2447 case SPIRV::OpFRemV:
2448 case SPIRV::OpFNegateV: {
2449 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2450 SPIRVTypeInst TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2451 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2452 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2453 if (isBFloat16Type(TypeDef)) {
2454 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2456 "Arithmetic instructions with bfloat16 arguments require the "
2457 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2458 false);
2459 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2460 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2461 }
2462 break;
2463 }
2464 case SPIRV::OpOrdered:
2465 case SPIRV::OpUnordered:
2466 case SPIRV::OpFOrdEqual:
2467 case SPIRV::OpFOrdNotEqual:
2468 case SPIRV::OpFOrdLessThan:
2469 case SPIRV::OpFOrdLessThanEqual:
2470 case SPIRV::OpFOrdGreaterThan:
2471 case SPIRV::OpFOrdGreaterThanEqual:
2472 case SPIRV::OpFUnordEqual:
2473 case SPIRV::OpFUnordNotEqual:
2474 case SPIRV::OpFUnordLessThan:
2475 case SPIRV::OpFUnordLessThanEqual:
2476 case SPIRV::OpFUnordGreaterThan:
2477 case SPIRV::OpFUnordGreaterThanEqual: {
2478 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2479 MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
2480 SPIRVTypeInst TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
2481 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2482 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2483 if (isBFloat16Type(TypeDef)) {
2484 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2486 "Relational instructions with bfloat16 arguments require the "
2487 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2488 false);
2489 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2490 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2491 }
2492 break;
2493 }
2494 case SPIRV::OpDPdxCoarse:
2495 case SPIRV::OpDPdyCoarse:
2496 case SPIRV::OpDPdxFine:
2497 case SPIRV::OpDPdyFine: {
2498 Reqs.addCapability(SPIRV::Capability::DerivativeControl);
2499 break;
2500 }
2501 case SPIRV::OpLoopControlINTEL: {
2502 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_unstructured_loop_controls);
2503 Reqs.addCapability(SPIRV::Capability::UnstructuredLoopControlsINTEL);
2504 break;
2505 }
2506
2507 default:
2508 break;
2509 }
2510
2511 // If we require capability Shader, then we can remove the requirement for
2512 // the BitInstructions capability, since Shader is a superset capability
2513 // of BitInstructions.
2514 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
2515 SPIRV::Capability::Shader);
2516}
2517
2518static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2519 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2520 // Collect requirements for existing instructions.
2521 for (const Function &F : M) {
2523 if (!MF)
2524 continue;
2525 for (const MachineBasicBlock &MBB : *MF)
2526 for (const MachineInstr &MI : MBB)
2527 addInstrRequirements(MI, MAI, ST);
2528 }
2529 // Collect requirements for OpExecutionMode instructions.
2530 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2531 if (Node) {
2532 bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2533 RequireKHRFloatControls2 = false,
2534 VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
2535 bool HasIntelFloatControls2 =
2536 ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2537 bool HasKHRFloatControls2 =
2538 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2539 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2540 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2541 const MDOperand &MDOp = MDN->getOperand(1);
2542 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
2543 Constant *C = CMeta->getValue();
2544 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
2545 auto EM = Const->getZExtValue();
2546 // SPV_KHR_float_controls is not available until v1.4:
2547 // add SPV_KHR_float_controls if the version is too low
2548 switch (EM) {
2549 case SPIRV::ExecutionMode::DenormPreserve:
2550 case SPIRV::ExecutionMode::DenormFlushToZero:
2551 case SPIRV::ExecutionMode::RoundingModeRTE:
2552 case SPIRV::ExecutionMode::RoundingModeRTZ:
2553 RequireFloatControls = VerLower14;
2555 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2556 break;
2557 case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2558 case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2559 case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2560 case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
2561 if (HasIntelFloatControls2) {
2562 RequireIntelFloatControls2 = true;
2564 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2565 }
2566 break;
2567 case SPIRV::ExecutionMode::FPFastMathDefault: {
2568 if (HasKHRFloatControls2) {
2569 RequireKHRFloatControls2 = true;
2571 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2572 }
2573 break;
2574 }
2575 case SPIRV::ExecutionMode::ContractionOff:
2576 case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
2577 if (HasKHRFloatControls2) {
2578 RequireKHRFloatControls2 = true;
2580 SPIRV::OperandCategory::ExecutionModeOperand,
2581 SPIRV::ExecutionMode::FPFastMathDefault, ST);
2582 } else {
2584 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2585 }
2586 break;
2587 default:
2589 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2590 }
2591 }
2592 }
2593 }
2594 if (RequireFloatControls &&
2595 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
2596 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
2597 if (RequireIntelFloatControls2)
2598 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2599 if (RequireKHRFloatControls2)
2600 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2601 }
2602 for (const Function &F : M) {
2603 if (F.isDeclaration())
2604 continue;
2605 if (F.getMetadata("reqd_work_group_size"))
2607 SPIRV::OperandCategory::ExecutionModeOperand,
2608 SPIRV::ExecutionMode::LocalSize, ST);
2609 if (F.getFnAttribute("hlsl.numthreads").isValid()) {
2611 SPIRV::OperandCategory::ExecutionModeOperand,
2612 SPIRV::ExecutionMode::LocalSize, ST);
2613 }
2614 if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
2615 MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2616 }
2617 if (F.getMetadata("work_group_size_hint"))
2619 SPIRV::OperandCategory::ExecutionModeOperand,
2620 SPIRV::ExecutionMode::LocalSizeHint, ST);
2621 if (F.getMetadata("intel_reqd_sub_group_size"))
2623 SPIRV::OperandCategory::ExecutionModeOperand,
2624 SPIRV::ExecutionMode::SubgroupSize, ST);
2625 if (F.getMetadata("max_work_group_size"))
2627 SPIRV::OperandCategory::ExecutionModeOperand,
2628 SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
2629 if (F.getMetadata("vec_type_hint"))
2631 SPIRV::OperandCategory::ExecutionModeOperand,
2632 SPIRV::ExecutionMode::VecTypeHint, ST);
2633
2634 if (F.hasOptNone()) {
2635 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
2636 MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
2637 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
2638 } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
2639 MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
2640 MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
2641 }
2642 }
2643 }
2644}
2645
2646static unsigned getFastMathFlags(const MachineInstr &I,
2647 const SPIRVSubtarget &ST) {
2648 unsigned Flags = SPIRV::FPFastMathMode::None;
2649 bool CanUseKHRFloatControls2 =
2650 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2651 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
2652 Flags |= SPIRV::FPFastMathMode::NotNaN;
2653 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
2654 Flags |= SPIRV::FPFastMathMode::NotInf;
2655 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
2656 Flags |= SPIRV::FPFastMathMode::NSZ;
2657 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
2658 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2659 if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2660 Flags |= SPIRV::FPFastMathMode::AllowContract;
2661 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
2662 if (CanUseKHRFloatControls2)
2663 // LLVM reassoc maps to SPIRV transform, see
2664 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2665 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2666 // AllowContract too, as required by SPIRV spec. Also, we used to map
2667 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2668 // replaced by turning all the other bits instead. Therefore, we're
2669 // enabling every bit here except None and Fast.
2670 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2671 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2672 SPIRV::FPFastMathMode::AllowTransform |
2673 SPIRV::FPFastMathMode::AllowReassoc |
2674 SPIRV::FPFastMathMode::AllowContract;
2675 else
2676 Flags |= SPIRV::FPFastMathMode::Fast;
2677 }
2678
2679 if (CanUseKHRFloatControls2) {
2680 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2681 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2682 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2683 "anymore.");
2684
2685 // Error out if AllowTransform is enabled without AllowReassoc and
2686 // AllowContract.
2687 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2688 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2689 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2690 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2691 "AllowContract flags to be enabled as well.");
2692 }
2693
2694 return Flags;
2695}
2696
2697static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2698 if (ST.isKernel())
2699 return true;
2700 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2701 return false;
2702 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2703}
2704
2705static void handleMIFlagDecoration(
2706 MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2708 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
2709 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
2710 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2711 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2712 .IsSatisfiable) {
2713 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2714 SPIRV::Decoration::NoSignedWrap, {});
2715 }
2716 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2717 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2718 SPIRV::Decoration::NoUnsignedWrap, ST,
2719 Reqs)
2720 .IsSatisfiable) {
2721 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2722 SPIRV::Decoration::NoUnsignedWrap, {});
2723 }
2724 // In Kernel environments, FPFastMathMode on OpExtInst is valid per core
2725 // spec. For other instruction types, SPV_KHR_float_controls2 is required.
2726 bool CanUseFM =
2727 TII.canUseFastMathFlags(
2728 I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) ||
2729 (ST.isKernel() && I.getOpcode() == SPIRV::OpExtInst);
2730 if (!CanUseFM)
2731 return;
2732
2733 unsigned FMFlags = getFastMathFlags(I, ST);
2734 if (FMFlags == SPIRV::FPFastMathMode::None) {
2735 // We also need to check if any FPFastMathDefault info was set for the
2736 // types used in this instruction.
2737 if (FPFastMathDefaultInfoVec.empty())
2738 return;
2739
2740 // There are three types of instructions that can use fast math flags:
2741 // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2742 // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2743 // 3. Extended instructions (ExtInst)
2744 // For arithmetic instructions, the floating point type can be in the
2745 // result type or in the operands, but they all must be the same.
2746 // For the relational and logical instructions, the floating point type
2747 // can only be in the operands 1 and 2, not the result type. Also, the
2748 // operands must have the same type. For the extended instructions, the
2749 // floating point type can be in the result type or in the operands. It's
2750 // unclear if the operands and the result type must be the same. Let's
2751 // assume they must be. Therefore, for 1. and 2., we can check the first
2752 // operand type, and for 3. we can check the result type.
2753 assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2754 Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2755 ? I.getOperand(1).getReg()
2756 : I.getOperand(2).getReg();
2757 SPIRVTypeInst ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
2758 const Type *Ty = GR->getTypeForSPIRVType(ResType);
2759 Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
2760
2761 // Match instruction type with the FPFastMathDefaultInfoVec.
2762 bool Emit = false;
2763 for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2764 if (Ty == Elem.Ty) {
2765 FMFlags = Elem.FastMathFlags;
2766 Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2767 Elem.FPFastMathDefault;
2768 break;
2769 }
2770 }
2771
2772 if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2773 return;
2774 }
2775 if (isFastMathModeAvailable(ST)) {
2776 Register DstReg = I.getOperand(0).getReg();
2777 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2778 {FMFlags});
2779 }
2780}
2781
2782// Walk all functions and add decorations related to MI flags.
2783static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2784 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2786 const SPIRVGlobalRegistry *GR) {
2787 for (const Function &F : M) {
2789 if (!MF)
2790 continue;
2791
2792 for (auto &MBB : *MF)
2793 for (auto &MI : MBB)
2794 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
2796 }
2797}
2798
2799static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2800 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2802 for (const Function &F : M) {
2804 if (!MF)
2805 continue;
2806 if (MF->getFunction()
2808 .isValid())
2809 continue;
2810 MachineRegisterInfo &MRI = MF->getRegInfo();
2811 for (auto &MBB : *MF) {
2812 if (!MBB.hasName() || MBB.empty())
2813 continue;
2814 // Emit basic block names.
2816 MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2817 buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
2818 MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2819 MAI.setRegisterAlias(MF, Reg, GlobalReg);
2820 }
2821 }
2822}
2823
2824// patching Instruction::PHI to SPIRV::OpPhi
2825static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2826 const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2827 for (const Function &F : M) {
2829 if (!MF)
2830 continue;
2831 for (auto &MBB : *MF) {
2832 for (MachineInstr &MI : MBB.phis()) {
2833 MI.setDesc(TII.get(SPIRV::OpPhi));
2834 Register ResTypeReg = GR->getSPIRVTypeID(
2835 GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2836 MI.insert(MI.operands_begin() + 1,
2837 {MachineOperand::CreateReg(ResTypeReg, false)});
2838 }
2839 }
2840
2841 MF->getProperties().setNoPHIs();
2842 }
2843}
2844
2846 const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2847 auto it = MAI.FPFastMathDefaultInfoMap.find(F);
2848 if (it != MAI.FPFastMathDefaultInfoMap.end())
2849 return it->second;
2850
2851 // If the map does not contain the entry, create a new one. Initialize it to
2852 // contain all 3 elements sorted by bit width of target type: {half, float,
2853 // double}.
2854 SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2855 FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
2856 SPIRV::FPFastMathMode::None);
2857 FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
2858 SPIRV::FPFastMathMode::None);
2859 FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
2860 SPIRV::FPFastMathMode::None);
2861 return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2862}
2863
2865 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2866 const Type *Ty) {
2867 size_t BitWidth = Ty->getScalarSizeInBits();
2868 int Index =
2870 BitWidth);
2871 assert(Index >= 0 && Index < 3 &&
2872 "Expected FPFastMathDefaultInfo for half, float, or double");
2873 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2874 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2875 return FPFastMathDefaultInfoVec[Index];
2876}
2877
2878static void collectFPFastMathDefaults(const Module &M,
2880 const SPIRVSubtarget &ST) {
2881 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
2882 return;
2883
2884 // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2885 // We need the entry point (function) as the key, and the target
2886 // type and flags as the value.
2887 // We also need to check ContractionOff and SignedZeroInfNanPreserve
2888 // execution modes, as they are now deprecated and must be replaced
2889 // with FPFastMathDefaultInfo.
2890 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2891 if (!Node)
2892 return;
2893
2894 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2895 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2896 assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2897 const Function *F = cast<Function>(
2898 cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
2899 const auto EM =
2901 cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
2902 ->getZExtValue();
2903 if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
2904 assert(MDN->getNumOperands() == 4 &&
2905 "Expected 4 operands for FPFastMathDefault");
2906
2907 const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
2908 unsigned Flags =
2910 cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
2911 ->getZExtValue();
2912 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2915 getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
2916 Info.FastMathFlags = Flags;
2917 Info.FPFastMathDefault = true;
2918 } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
2919 assert(MDN->getNumOperands() == 2 &&
2920 "Expected no operands for ContractionOff");
2921
2922 // We need to save this info for every possible FP type, i.e. {half,
2923 // float, double, fp128}.
2924 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2926 for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2927 Info.ContractionOff = true;
2928 }
2929 } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
2930 assert(MDN->getNumOperands() == 3 &&
2931 "Expected 1 operand for SignedZeroInfNanPreserve");
2932 unsigned TargetWidth =
2934 cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
2935 ->getZExtValue();
2936 // We need to save this info only for the FP type with TargetWidth.
2937 SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2941 assert(Index >= 0 && Index < 3 &&
2942 "Expected FPFastMathDefaultInfo for half, float, or double");
2943 assert(FPFastMathDefaultInfoVec.size() == 3 &&
2944 "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2945 FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
2946 }
2947 }
2948}
2949
2951 AU.addRequired<TargetPassConfig>();
2952 AU.addRequired<MachineModuleInfoWrapperPass>();
2953}
2954
2956 SPIRVTargetMachine &TM =
2957 getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
2958 ST = TM.getSubtargetImpl();
2959 GR = ST->getSPIRVGlobalRegistry();
2960 TII = ST->getInstrInfo();
2961
2963
2964 setBaseInfo(M);
2965
2966 patchPhis(M, GR, *TII, MMI);
2967
2968 addMBBNames(M, *TII, MMI, *ST, MAI);
2969 collectFPFastMathDefaults(M, MAI, *ST);
2970 addDecorations(M, *TII, MMI, *ST, MAI, GR);
2971
2972 collectReqs(M, MAI, MMI, *ST);
2973
2974 // Process type/const/global var/func decl instructions, number their
2975 // destination registers from 0 to N, collect Extensions and Capabilities.
2976 collectReqs(M, MAI, MMI, *ST);
2977 collectDeclarations(M);
2978
2979 // Number rest of registers from N+1 onwards.
2980 numberRegistersGlobally(M);
2981
2982 // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2983 processOtherInstrs(M);
2984
2985 // If there are no entry points, we need the Linkage capability.
2986 if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2987 MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2988
2989 // Set maximum ID used.
2990 GR->setBound(MAI.MaxID);
2991
2992 return false;
2993}
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define DEBUG_TYPE
static Register UseReg(const MachineOperand &MO)
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
#define T
MachineInstr unsigned OpIdx
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static SPIRV::FPFastMathDefaultInfoVector & getOrCreateFPFastMathDefaultInfoVec(const Module &M, DenseMap< Function *, SPIRV::FPFastMathDefaultInfoVector > &FPFastMathDefaultInfoMap, Function *F)
static SPIRV::FPFastMathDefaultInfo & getFPFastMathDefaultInfo(SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, const Type *Ty)
#define ATOM_FLT_REQ_EXT_MSG(ExtName)
static cl::opt< bool > SPVDumpDeps("spv-dump-deps", cl::desc("Dump MIR with SPIR-V dependencies info"), cl::Optional, cl::init(false))
static cl::list< SPIRV::Capability::Capability > AvoidCapabilities("avoid-spirv-capabilities", cl::desc("SPIR-V capabilities to avoid if there are " "other options enabling a feature"), cl::Hidden, cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", "SPIR-V Shader capability")))
unsigned OpIndex
#define SPIRV_BACKEND_SERVICE_FUN_NAME
Definition SPIRVUtils.h:528
This file contains some templates that are useful if you are working with the STL at all.
#define LLVM_DEBUG(...)
Definition Debug.h:119
Target-Independent Code Generator Pass Configuration Options pass.
The Input class is used to parse a yaml document into in-memory structs and vectors.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
bool isValid() const
Return true if the attribute is any kind of attribute.
Definition Attributes.h:261
This is the shared class of boolean and integer constants.
Definition Constants.h:87
This is an important base class in LLVM.
Definition Constant.h:43
Attribute getFnAttribute(Attribute::AttrKind Kind) const
Return the attribute for the given attribute kind.
Definition Function.cpp:763
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Wrapper class representing physical registers. Should be passed by value.
Definition MCRegister.h:41
constexpr bool isValid() const
Definition MCRegister.h:84
Metadata node.
Definition Metadata.h:1080
const MDOperand & getOperand(unsigned I) const
Definition Metadata.h:1444
unsigned getNumOperands() const
Return number of MDNode operands.
Definition Metadata.h:1450
Tracking metadata reference owned by Metadata.
Definition Metadata.h:902
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
Function & getFunction()
Return the LLVM function that this machine code represents.
const MachineFunctionProperties & getProperties() const
Get the function properties.
Register getReg(unsigned Idx) const
Get the register for the operand index.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
LLVM_ABI const MachineFunction * getMF() const
Return the function that contains the basic block that this instruction belongs to.
const MachineOperand & getOperand(unsigned i) const
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isImm() const
isImm - Tests if this is a MO_Immediate operand.
LLVM_ABI void print(raw_ostream &os, const TargetRegisterInfo *TRI=nullptr) const
Print the MachineOperand to os.
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
static MachineOperand CreateImm(int64_t Val)
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
const TargetRegisterClass * getRegClass(Register Reg) const
Return the register class of the specified virtual register.
LLVM_ABI MachineInstr * getVRegDef(Register Reg) const
getVRegDef - Return the machine instr that defines the specified virtual register or null if none is ...
LLVM_ABI void setRegClass(Register Reg, const TargetRegisterClass *RC)
setRegClass - Set the register class of the specified virtual register.
LLVM_ABI Register createGenericVirtualRegister(LLT Ty, StringRef Name="")
Create and return a new generic virtual register with low-level type Ty.
iterator_range< reg_instr_iterator > reg_instructions(Register Reg) const
iterator_range< use_instr_iterator > use_instructions(Register Reg) const
LLVM_ABI MachineInstr * getUniqueVRegDef(Register Reg) const
getUniqueVRegDef - Return the unique machine instr that defines the specified virtual register or nul...
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition Pass.cpp:140
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isValid() const
Definition Register.h:112
unsigned getScalarOrVectorBitWidth(SPIRVTypeInst Type) const
const Type * getTypeForSPIRVType(SPIRVTypeInst Ty) const
Register getSPIRVTypeID(SPIRVTypeInst SpirvType) const
SPIRVTypeInst getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
bool isConstantInstr(const MachineInstr &MI) const
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVSubtarget * getSubtargetImpl() const
bool isTypeIntN(unsigned N=0) const
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition SmallSet.h:134
bool contains(const T &V) const
Check if the SmallSet contains the given element.
Definition SmallSet.h:229
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition SmallSet.h:184
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:290
static LLVM_ABI Type * getDoubleTy(LLVMContext &C)
Definition Type.cpp:291
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
Definition Type.cpp:290
static LLVM_ABI Type * getHalfTy(LLVMContext &C)
Definition Type.cpp:288
Represents a version number in the form major[.minor[.subminor[.build]]].
bool empty() const
Determine whether this version information is empty (e.g., all version components are zero).
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr std::underlying_type_t< E > Mask()
Get a bitmask with 1s in all places up to the high-order bit of E's largest value.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
SmallVector< const MachineInstr * > InstrList
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:668
NodeAddr< InstrNode * > Instr
Definition RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
void buildOpName(Register Target, const StringRef &Name, MachineIRBuilder &MIRBuilder)
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
hash_code hash_value(const FixedPointSemantics &Val)
ExtensionList getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
CapabilityList getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
SmallVector< SPIRV::Extension::Extension, 8 > ExtensionList
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
SmallVector< size_t > InstrSignature
VersionTuple getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, SPIRV::Decoration::Decoration Dec, const std::vector< uint32_t > &DecArgs, StringRef StrImm)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
CapabilityList getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:163
std::string getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category, int32_t Value)
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
VersionTuple getSymbolicOperandMinVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
SmallVector< SPIRV::Capability::Capability, 8 > CapabilityList
std::set< InstrSignature > InstrTraces
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:325
std::map< SmallVector< size_t >, unsigned > InstrGRegsMap
LLVM_ABI void reportFatalUsageError(Error Err)
Report a fatal error that does not indicate a bug in LLVM.
Definition Error.cpp:177
#define N
SmallSet< SPIRV::Capability::Capability, 4 > S
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
SPIRV::ModuleAnalysisInfo MAI
bool runOnModule(Module &M) override
runOnModule - Virtual method overridden by subclasses to process the module being operated on.
static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth)
Definition SPIRVUtils.h:148
void setSkipEmission(const MachineInstr *MI)
MCRegister getRegisterAlias(const MachineFunction *MF, Register Reg)
MCRegister getOrCreateMBBRegister(const MachineBasicBlock &MBB)
InstrList MS[NUM_MODULE_SECTIONS]
AddressingModel::AddressingModel Addr
void setRegisterAlias(const MachineFunction *MF, Register Reg, MCRegister AliasReg)
DenseMap< const Function *, SPIRV::FPFastMathDefaultInfoVector > FPFastMathDefaultInfoMap
void addCapabilities(const CapabilityList &ToAdd)
bool isCapabilityAvailable(Capability::Capability Cap) const
void checkSatisfiable(const SPIRVSubtarget &ST) const
void getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category, uint32_t i, const SPIRVSubtarget &ST)
void addExtension(Extension::Extension ToAdd)
void initAvailableCapabilities(const SPIRVSubtarget &ST)
void removeCapabilityIf(const Capability::Capability ToRemove, const Capability::Capability IfPresent)
void addCapability(Capability::Capability ToAdd)
void addAvailableCaps(const CapabilityList &ToAdd)
void addRequirements(const Requirements &Req)
const std::optional< Capability::Capability > Cap