LLVM 22.0.0git
SPIRVModuleAnalysis.cpp
Go to the documentation of this file.
1//===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The analysis collects instructions that should be output at the module level
10// and performs the global register numbering.
11//
12// The results of this analysis are used in AsmPrinter to rename registers
13// globally and to output required instructions at the module level.
14//
15//===----------------------------------------------------------------------===//
16
17#include "SPIRVModuleAnalysis.h"
20#include "SPIRV.h"
21#include "SPIRVSubtarget.h"
22#include "SPIRVTargetMachine.h"
23#include "SPIRVUtils.h"
24#include "llvm/ADT/STLExtras.h"
27
28using namespace llvm;
29
30#define DEBUG_TYPE "spirv-module-analysis"
31
32static cl::opt<bool>
33 SPVDumpDeps("spv-dump-deps",
34 cl::desc("Dump MIR with SPIR-V dependencies info"),
35 cl::Optional, cl::init(false));
36
38 AvoidCapabilities("avoid-spirv-capabilities",
39 cl::desc("SPIR-V capabilities to avoid if there are "
40 "other options enabling a feature"),
42 cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader",
43 "SPIR-V Shader capability")));
44// Use sets instead of cl::list to check "if contains" condition
49
51
52INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
53 true)
54
55// Retrieve an unsigned from an MDNode with a list of them as operands.
56static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
57 unsigned DefaultVal = 0) {
58 if (MdNode && OpIndex < MdNode->getNumOperands()) {
59 const auto &Op = MdNode->getOperand(OpIndex);
60 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
61 }
62 return DefaultVal;
63}
64
66getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
67 unsigned i, const SPIRVSubtarget &ST,
69 // A set of capabilities to avoid if there is another option.
70 AvoidCapabilitiesSet AvoidCaps;
71 if (!ST.isShader())
72 AvoidCaps.S.insert(SPIRV::Capability::Shader);
73 else
74 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
75
76 VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
77 VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
78 VersionTuple SPIRVVersion = ST.getSPIRVVersion();
79 bool MinVerOK = SPIRVVersion.empty() || SPIRVVersion >= ReqMinVer;
80 bool MaxVerOK =
81 ReqMaxVer.empty() || SPIRVVersion.empty() || SPIRVVersion <= ReqMaxVer;
83 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
84 if (ReqCaps.empty()) {
85 if (ReqExts.empty()) {
86 if (MinVerOK && MaxVerOK)
87 return {true, {}, {}, ReqMinVer, ReqMaxVer};
88 return {false, {}, {}, VersionTuple(), VersionTuple()};
89 }
90 } else if (MinVerOK && MaxVerOK) {
91 if (ReqCaps.size() == 1) {
92 auto Cap = ReqCaps[0];
93 if (Reqs.isCapabilityAvailable(Cap)) {
95 SPIRV::OperandCategory::CapabilityOperand, Cap));
96 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
97 }
98 } else {
99 // By SPIR-V specification: "If an instruction, enumerant, or other
100 // feature specifies multiple enabling capabilities, only one such
101 // capability needs to be declared to use the feature." However, one
102 // capability may be preferred over another. We use command line
103 // argument(s) and AvoidCapabilities to avoid selection of certain
104 // capabilities if there are other options.
105 CapabilityList UseCaps;
106 for (auto Cap : ReqCaps)
107 if (Reqs.isCapabilityAvailable(Cap))
108 UseCaps.push_back(Cap);
109 for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
110 auto Cap = UseCaps[i];
111 if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
113 SPIRV::OperandCategory::CapabilityOperand, Cap));
114 return {true, {Cap}, std::move(ReqExts), ReqMinVer, ReqMaxVer};
115 }
116 }
117 }
118 }
119 // If there are no capabilities, or we can't satisfy the version or
120 // capability requirements, use the list of extensions (if the subtarget
121 // can handle them all).
122 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
123 return ST.canUseExtension(Ext);
124 })) {
125 return {true,
126 {},
127 std::move(ReqExts),
128 VersionTuple(),
129 VersionTuple()}; // TODO: add versions to extensions.
130 }
131 return {false, {}, {}, VersionTuple(), VersionTuple()};
132}
133
// Resets the per-module state in MAI and derives the module's base info —
// addressing model, memory model, source language/version and source
// extensions — from named metadata, falling back to subtarget-based defaults.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit metadata wins: operand 0 is the addressing model, operand 1
    // the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // TODO: Add support for VulkanMemoryModel.
    MAI.Mem = ST->isShader() ? SPIRV::MemoryModel::GLSL450
                             : SPIRV::MemoryModel::OpenCL;
    if (MAI.Mem == SPIRV::MemoryModel::OpenCL) {
      // Physical addressing width follows the target pointer size.
      unsigned PtrSize = ST->getPointerSize();
      MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
                 : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
                                 : SPIRV::AddressingModel::Logical;
    } else {
      // TODO: Add support for PhysicalStorageBufferAddress.
      MAI.Addr = SPIRV::AddressingModel::Logical;
    }
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Prevent Major part of OpenCL version to be 0
    MAI.SrcLangVersion =
        (std::max(1U, MajorNum) * 100 + MinorNum) * 1000 + RevNum;
  } else {
    // If there is no information about OpenCL version we are forced to generate
    // OpenCL 1.0 by default for the OpenCL environment to avoid puzzling
    // run-times with Unknown/0.0 version output. For a reference, LLVM-SPIRV
    // Translator avoids potential issues with run-times in a similar manner.
    if (!ST->isShader()) {
      MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_CPP;
      MAI.SrcLangVersion = 100000;
    } else {
      MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
      MAI.SrcLangVersion = 0;
    }
  }

  // Collect source-level extension strings declared by the front end.
  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  if (!ST->isShader()) {
    // TODO: check if it's required by default.
    MAI.ExtInstSetMap[static_cast<unsigned>(
        SPIRV::InstructionSet::OpenCL_std)] = MAI.getNextIDRegister();
  }
}
220
221// Appends the signature of the decoration instructions that decorate R to
222// Signature.
223static void appendDecorationsForReg(const MachineRegisterInfo &MRI, Register R,
224 InstrSignature &Signature) {
225 for (MachineInstr &UseMI : MRI.use_instructions(R)) {
226 // We don't handle OpDecorateId because getting the register alias for the
227 // ID can cause problems, and we do not need it for now.
228 if (UseMI.getOpcode() != SPIRV::OpDecorate &&
229 UseMI.getOpcode() != SPIRV::OpMemberDecorate)
230 continue;
231
232 for (unsigned I = 0; I < UseMI.getNumOperands(); ++I) {
233 const MachineOperand &MO = UseMI.getOperand(I);
234 if (MO.isReg())
235 continue;
236 Signature.push_back(hash_value(MO));
237 }
238 }
239}
240
241// Returns a representation of an instruction as a vector of MachineOperand
242// hash values, see llvm::hash_value(const MachineOperand &MO) for details.
243// This creates a signature of the instruction with the same content
244// that MachineOperand::isIdenticalTo uses for comparison.
245static InstrSignature instrToSignature(const MachineInstr &MI,
247 bool UseDefReg) {
248 Register DefReg;
249 InstrSignature Signature{MI.getOpcode()};
250 for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
251 // The only decorations that can be applied more than once to a given <id>
252 // or structure member are FuncParamAttr (38), UserSemantic (5635),
253 // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
254 // the rest of decorations, we will only add to the signature the Opcode,
255 // the id to which it applies, and the decoration id, disregarding any
256 // decoration flags. This will ensure that any subsequent decoration with
257 // the same id will be deemed as a duplicate. Then, at the call site, we
258 // will be able to handle duplicates in the best way.
259 unsigned Opcode = MI.getOpcode();
260 if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
261 unsigned DecorationID = MI.getOperand(1).getImm();
262 if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
263 DecorationID != SPIRV::Decoration::UserSemantic &&
264 DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
265 DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
266 continue;
267 }
268 const MachineOperand &MO = MI.getOperand(i);
269 size_t h;
270 if (MO.isReg()) {
271 if (!UseDefReg && MO.isDef()) {
272 assert(!DefReg.isValid() && "Multiple def registers.");
273 DefReg = MO.getReg();
274 continue;
275 }
276 Register RegAlias = MAI.getRegisterAlias(MI.getMF(), MO.getReg());
277 if (!RegAlias.isValid()) {
278 LLVM_DEBUG({
279 dbgs() << "Unexpectedly, no global id found for the operand ";
280 MO.print(dbgs());
281 dbgs() << "\nInstruction: ";
282 MI.print(dbgs());
283 dbgs() << "\n";
284 });
285 report_fatal_error("All v-regs must have been mapped to global id's");
286 }
287 // mimic llvm::hash_value(const MachineOperand &MO)
288 h = hash_combine(MO.getType(), (unsigned)RegAlias, MO.getSubReg(),
289 MO.isDef());
290 } else {
291 h = hash_value(MO);
292 }
293 Signature.push_back(h);
294 }
295
296 if (DefReg.isValid()) {
297 // Decorations change the semantics of the current instruction. So two
298 // identical instruction with different decorations cannot be merged. That
299 // is why we add the decorations to the signature.
300 appendDecorationsForReg(MI.getMF()->getRegInfo(), DefReg, Signature);
301 }
302 return Signature;
303}
304
305bool SPIRVModuleAnalysis::isDeclSection(const MachineRegisterInfo &MRI,
306 const MachineInstr &MI) {
307 unsigned Opcode = MI.getOpcode();
308 switch (Opcode) {
309 case SPIRV::OpTypeForwardPointer:
310 // omit now, collect later
311 return false;
312 case SPIRV::OpVariable:
313 return static_cast<SPIRV::StorageClass::StorageClass>(
314 MI.getOperand(2).getImm()) != SPIRV::StorageClass::Function;
315 case SPIRV::OpFunction:
316 case SPIRV::OpFunctionParameter:
317 return true;
318 }
319 if (GR->hasConstFunPtr() && Opcode == SPIRV::OpUndef) {
320 Register DefReg = MI.getOperand(0).getReg();
321 for (MachineInstr &UseMI : MRI.use_instructions(DefReg)) {
322 if (UseMI.getOpcode() != SPIRV::OpConstantFunctionPointerINTEL)
323 continue;
324 // it's a dummy definition, FP constant refers to a function,
325 // and this is resolved in another way; let's skip this definition
326 assert(UseMI.getOperand(2).isReg() &&
327 UseMI.getOperand(2).getReg() == DefReg);
328 MAI.setSkipEmission(&MI);
329 return false;
330 }
331 }
332 return TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
333 TII->isInlineAsmDefInstr(MI);
334}
335
// This is a special case of a function pointer referring to a possibly
// forward function declaration. The operand is a dummy OpUndef that
// requires a special treatment.
void SPIRVModuleAnalysis::visitFunPtrUse(
    Register OpReg, InstrGRegsMap &SignatureToGReg,
    std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
    const MachineInstr &MI) {
  // Resolve operand #2 of OpConstantFunctionPointerINTEL to the operand of
  // the actual OpFunction definition it refers to.
  const MachineOperand *OpFunDef =
      GR->getFunctionDefinitionByUse(&MI.getOperand(2));
  assert(OpFunDef && OpFunDef->isReg());
  // find the actual function definition and number it globally in advance
  const MachineInstr *OpDefMI = OpFunDef->getParent();
  assert(OpDefMI && OpDefMI->getOpcode() == SPIRV::OpFunction);
  const MachineFunction *FunDefMF = OpDefMI->getParent()->getParent();
  const MachineRegisterInfo &FunDefMRI = FunDefMF->getRegInfo();
  // Visit the OpFunction and the OpFunctionParameter instructions that
  // immediately follow it, so they all get global ids before the pointer use.
  do {
    visitDecl(FunDefMRI, SignatureToGReg, GlobalToGReg, FunDefMF, *OpDefMI);
    OpDefMI = OpDefMI->getNextNode();
  } while (OpDefMI && (OpDefMI->getOpcode() == SPIRV::OpFunction ||
                       OpDefMI->getOpcode() == SPIRV::OpFunctionParameter));
  // associate the function pointer with the newly assigned global number
  MCRegister GlobalFunDefReg =
      MAI.getRegisterAlias(FunDefMF, OpFunDef->getReg());
  assert(GlobalFunDefReg.isValid() &&
         "Function definition must refer to a global register");
  MAI.setRegisterAlias(MF, OpReg, GlobalFunDefReg);
}
363
364// Depth first recursive traversal of dependencies. Repeated visits are guarded
365// by MAI.hasRegisterAlias().
366void SPIRVModuleAnalysis::visitDecl(
367 const MachineRegisterInfo &MRI, InstrGRegsMap &SignatureToGReg,
368 std::map<const Value *, unsigned> &GlobalToGReg, const MachineFunction *MF,
369 const MachineInstr &MI) {
370 unsigned Opcode = MI.getOpcode();
371
372 // Process each operand of the instruction to resolve dependencies
373 for (const MachineOperand &MO : MI.operands()) {
374 if (!MO.isReg() || MO.isDef())
375 continue;
376 Register OpReg = MO.getReg();
377 // Handle function pointers special case
378 if (Opcode == SPIRV::OpConstantFunctionPointerINTEL &&
379 MRI.getRegClass(OpReg) == &SPIRV::pIDRegClass) {
380 visitFunPtrUse(OpReg, SignatureToGReg, GlobalToGReg, MF, MI);
381 continue;
382 }
383 // Skip already processed instructions
384 if (MAI.hasRegisterAlias(MF, MO.getReg()))
385 continue;
386 // Recursively visit dependencies
387 if (const MachineInstr *OpDefMI = MRI.getUniqueVRegDef(OpReg)) {
388 if (isDeclSection(MRI, *OpDefMI))
389 visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, *OpDefMI);
390 continue;
391 }
392 // Handle the unexpected case of no unique definition for the SPIR-V
393 // instruction
394 LLVM_DEBUG({
395 dbgs() << "Unexpectedly, no unique definition for the operand ";
396 MO.print(dbgs());
397 dbgs() << "\nInstruction: ";
398 MI.print(dbgs());
399 dbgs() << "\n";
400 });
402 "No unique definition is found for the virtual register");
403 }
404
405 MCRegister GReg;
406 bool IsFunDef = false;
407 if (TII->isSpecConstantInstr(MI)) {
408 GReg = MAI.getNextIDRegister();
409 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
410 } else if (Opcode == SPIRV::OpFunction ||
411 Opcode == SPIRV::OpFunctionParameter) {
412 GReg = handleFunctionOrParameter(MF, MI, GlobalToGReg, IsFunDef);
413 } else if (Opcode == SPIRV::OpTypeStruct ||
414 Opcode == SPIRV::OpConstantComposite) {
415 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
416 const MachineInstr *NextInstr = MI.getNextNode();
417 while (NextInstr &&
418 ((Opcode == SPIRV::OpTypeStruct &&
419 NextInstr->getOpcode() == SPIRV::OpTypeStructContinuedINTEL) ||
420 (Opcode == SPIRV::OpConstantComposite &&
421 NextInstr->getOpcode() ==
422 SPIRV::OpConstantCompositeContinuedINTEL))) {
423 MCRegister Tmp = handleTypeDeclOrConstant(*NextInstr, SignatureToGReg);
424 MAI.setRegisterAlias(MF, NextInstr->getOperand(0).getReg(), Tmp);
425 MAI.setSkipEmission(NextInstr);
426 NextInstr = NextInstr->getNextNode();
427 }
428 } else if (TII->isTypeDeclInstr(MI) || TII->isConstantInstr(MI) ||
429 TII->isInlineAsmDefInstr(MI)) {
430 GReg = handleTypeDeclOrConstant(MI, SignatureToGReg);
431 } else if (Opcode == SPIRV::OpVariable) {
432 GReg = handleVariable(MF, MI, GlobalToGReg);
433 } else {
434 LLVM_DEBUG({
435 dbgs() << "\nInstruction: ";
436 MI.print(dbgs());
437 dbgs() << "\n";
438 });
439 llvm_unreachable("Unexpected instruction is visited");
440 }
441 MAI.setRegisterAlias(MF, MI.getOperand(0).getReg(), GReg);
442 if (!IsFunDef)
443 MAI.setSkipEmission(&MI);
444}
445
446MCRegister SPIRVModuleAnalysis::handleFunctionOrParameter(
447 const MachineFunction *MF, const MachineInstr &MI,
448 std::map<const Value *, unsigned> &GlobalToGReg, bool &IsFunDef) {
449 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
450 assert(GObj && "Unregistered global definition");
451 const Function *F = dyn_cast<Function>(GObj);
452 if (!F)
453 F = dyn_cast<Argument>(GObj)->getParent();
454 assert(F && "Expected a reference to a function or an argument");
455 IsFunDef = !F->isDeclaration();
456 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
457 if (!Inserted)
458 return It->second;
459 MCRegister GReg = MAI.getNextIDRegister();
460 It->second = GReg;
461 if (!IsFunDef)
462 MAI.MS[SPIRV::MB_ExtFuncDecls].push_back(&MI);
463 return GReg;
464}
465
467SPIRVModuleAnalysis::handleTypeDeclOrConstant(const MachineInstr &MI,
468 InstrGRegsMap &SignatureToGReg) {
469 InstrSignature MISign = instrToSignature(MI, MAI, false);
470 auto [It, Inserted] = SignatureToGReg.try_emplace(MISign);
471 if (!Inserted)
472 return It->second;
473 MCRegister GReg = MAI.getNextIDRegister();
474 It->second = GReg;
475 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
476 return GReg;
477}
478
479MCRegister SPIRVModuleAnalysis::handleVariable(
480 const MachineFunction *MF, const MachineInstr &MI,
481 std::map<const Value *, unsigned> &GlobalToGReg) {
482 MAI.GlobalVarList.push_back(&MI);
483 const Value *GObj = GR->getGlobalObject(MF, MI.getOperand(0).getReg());
484 assert(GObj && "Unregistered global definition");
485 auto [It, Inserted] = GlobalToGReg.try_emplace(GObj);
486 if (!Inserted)
487 return It->second;
488 MCRegister GReg = MAI.getNextIDRegister();
489 It->second = GReg;
490 MAI.MS[SPIRV::MB_TypeConstVars].push_back(&MI);
491 return GReg;
492}
493
// Walks every machine function and collects module-level declarations
// (types, constants, global variables, function headers) in dependency order
// via visitDecl, and records OpExtension/OpCapability instructions.
void SPIRVModuleAnalysis::collectDeclarations(const Module &M) {
  InstrGRegsMap SignatureToGReg;
  std::map<const Value *, unsigned> GlobalToGReg;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    const MachineRegisterInfo &MRI = MF->getRegInfo();
    // PastHeader: 0 = before this function's own OpFunction, 1 = inside its
    // header (OpFunction seen, parameters may follow), 2 = past the header.
    unsigned PastHeader = 0;
    for (MachineBasicBlock &MBB : *MF) {
      for (MachineInstr &MI : MBB) {
        if (MI.getNumOperands() == 0)
          continue;
        unsigned Opcode = MI.getOpcode();
        if (Opcode == SPIRV::OpFunction) {
          // The first OpFunction is this function's own header: skip it.
          // Subsequent OpFunction instructions fall through to the visitor.
          if (PastHeader == 0) {
            PastHeader = 1;
            continue;
          }
        } else if (Opcode == SPIRV::OpFunctionParameter) {
          // Parameters still belonging to this function's header are skipped.
          if (PastHeader < 2)
            continue;
        } else if (PastHeader > 0) {
          PastHeader = 2;
        }

        const MachineOperand &DefMO = MI.getOperand(0);
        switch (Opcode) {
        case SPIRV::OpExtension:
          MAI.Reqs.addExtension(SPIRV::Extension::Extension(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          break;
        case SPIRV::OpCapability:
          MAI.Reqs.addCapability(SPIRV::Capability::Capability(DefMO.getImm()));
          MAI.setSkipEmission(&MI);
          if (PastHeader > 0)
            PastHeader = 2;
          break;
        default:
          // Any not-yet-numbered declaration-section instruction starts a
          // depth-first dependency visit.
          if (DefMO.isReg() && isDeclSection(MRI, MI) &&
              !MAI.hasRegisterAlias(MF, DefMO.getReg()))
            visitDecl(MRI, SignatureToGReg, GlobalToGReg, MF, MI);
        }
      }
    }
  }
}
541
542// Look for IDs declared with Import linkage, and map the corresponding function
543// to the register defining that variable (which will usually be the result of
544// an OpFunction). This lets us call externally imported functions using
545// the correct ID registers.
546void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
547 const Function *F) {
548 if (MI.getOpcode() == SPIRV::OpDecorate) {
549 // If it's got Import linkage.
550 auto Dec = MI.getOperand(1).getImm();
551 if (Dec == SPIRV::Decoration::LinkageAttributes) {
552 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
553 if (Lnk == SPIRV::LinkageType::Import) {
554 // Map imported function name to function ID register.
555 const Function *ImportedFunc =
556 F->getParent()->getFunction(getStringImm(MI, 2));
557 Register Target = MI.getOperand(0).getReg();
558 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
559 }
560 }
561 } else if (MI.getOpcode() == SPIRV::OpFunction) {
562 // Record all internal OpFunction declarations.
563 Register Reg = MI.defs().begin()->getReg();
564 MCRegister GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
565 assert(GlobalReg.isValid());
566 MAI.FuncMap[F] = GlobalReg;
567 }
568}
569
570// Collect the given instruction in the specified MS. We assume global register
571// numbering has already occurred by this point. We can directly compare reg
572// arguments when detecting duplicates.
573static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
575 bool Append = true) {
576 MAI.setSkipEmission(&MI);
577 InstrSignature MISign = instrToSignature(MI, MAI, true);
578 auto FoundMI = IS.insert(std::move(MISign));
579 if (!FoundMI.second) {
580 if (MI.getOpcode() == SPIRV::OpDecorate) {
581 assert(MI.getNumOperands() >= 2 &&
582 "Decoration instructions must have at least 2 operands");
583 assert(MSType == SPIRV::MB_Annotations &&
584 "Only OpDecorate instructions can be duplicates");
585 // For FPFastMathMode decoration, we need to merge the flags of the
586 // duplicate decoration with the original one, so we need to find the
587 // original instruction that has the same signature. For the rest of
588 // instructions, we will simply skip the duplicate.
589 if (MI.getOperand(1).getImm() != SPIRV::Decoration::FPFastMathMode)
590 return; // Skip duplicates of other decorations.
591
592 const SPIRV::InstrList &Decorations = MAI.MS[MSType];
593 for (const MachineInstr *OrigMI : Decorations) {
594 if (instrToSignature(*OrigMI, MAI, true) == MISign) {
595 assert(OrigMI->getNumOperands() == MI.getNumOperands() &&
596 "Original instruction must have the same number of operands");
597 assert(
598 OrigMI->getNumOperands() == 3 &&
599 "FPFastMathMode decoration must have 3 operands for OpDecorate");
600 unsigned OrigFlags = OrigMI->getOperand(2).getImm();
601 unsigned NewFlags = MI.getOperand(2).getImm();
602 if (OrigFlags == NewFlags)
603 return; // No need to merge, the flags are the same.
604
605 // Emit warning about possible conflict between flags.
606 unsigned FinalFlags = OrigFlags | NewFlags;
607 llvm::errs()
608 << "Warning: Conflicting FPFastMathMode decoration flags "
609 "in instruction: "
610 << *OrigMI << "Original flags: " << OrigFlags
611 << ", new flags: " << NewFlags
612 << ". They will be merged on a best effort basis, but not "
613 "validated. Final flags: "
614 << FinalFlags << "\n";
615 MachineInstr *OrigMINonConst = const_cast<MachineInstr *>(OrigMI);
616 MachineOperand &OrigFlagsOp = OrigMINonConst->getOperand(2);
617 OrigFlagsOp = MachineOperand::CreateImm(FinalFlags);
618 return; // Merge done, so we found a duplicate; don't add it to MAI.MS
619 }
620 }
621 assert(false && "No original instruction found for the duplicate "
622 "OpDecorate, but we found one in IS.");
623 }
624 return; // insert failed, so we found a duplicate; don't add it to MAI.MS
625 }
626 // No duplicates, so add it.
627 if (Append)
628 MAI.MS[MSType].push_back(&MI);
629 else
630 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
631}
632
// Some global instructions make reference to function-local ID regs, so cannot
// be correctly collected until these registers are globally numbered.
void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
  // Signatures of already-collected instructions, used for deduplication.
  InstrTraces IS;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    if (F->isDeclaration())
      continue;
    MachineFunction *MF = MMI->getMachineFunction(*F);
    assert(MF);

    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        const unsigned OpCode = MI.getOpcode();
        // Route each module-level instruction to its module section.
        if (OpCode == SPIRV::OpString) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugStrings, IS);
        } else if (OpCode == SPIRV::OpExtInst && MI.getOperand(2).isImm() &&
                   MI.getOperand(2).getImm() ==
                       SPIRV::InstructionSet::
                           NonSemantic_Shader_DebugInfo_100) {
          // Only a fixed subset of non-semantic debug-info instructions is
          // global; all others stay function-local.
          MachineOperand Ins = MI.getOperand(3);
          namespace NS = SPIRV::NonSemanticExtInst;
          static constexpr int64_t GlobalNonSemanticDITy[] = {
              NS::DebugSource, NS::DebugCompilationUnit, NS::DebugInfoNone,
              NS::DebugTypeBasic, NS::DebugTypePointer};
          bool IsGlobalDI = false;
          for (unsigned Idx = 0; Idx < std::size(GlobalNonSemanticDITy); ++Idx)
            IsGlobalDI |= Ins.getImm() == GlobalNonSemanticDITy[Idx];
          if (IsGlobalDI)
            collectOtherInstr(MI, MAI, SPIRV::MB_NonSemanticGlobalDI, IS);
        } else if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
          collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames, IS);
        } else if (OpCode == SPIRV::OpEntryPoint) {
          collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints, IS);
        } else if (TII->isAliasingInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_AliasingInsts, IS);
        } else if (TII->isDecorationInstr(MI)) {
          collectOtherInstr(MI, MAI, SPIRV::MB_Annotations, IS);
          collectFuncNames(MI, &*F);
        } else if (TII->isConstantInstr(MI)) {
          // Now OpSpecConstant*s are not in DT,
          // but they need to be collected anyway.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS);
        } else if (OpCode == SPIRV::OpFunction) {
          collectFuncNames(MI, &*F);
        } else if (OpCode == SPIRV::OpTypeForwardPointer) {
          // Forward pointers are prepended so they precede their uses.
          collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, IS, false);
        }
      }
  }
}
685
686// Number registers in all functions globally from 0 onwards and store
687// the result in global register alias table. Some registers are already
688// numbered.
689void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
690 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
691 if ((*F).isDeclaration())
692 continue;
693 MachineFunction *MF = MMI->getMachineFunction(*F);
694 assert(MF);
695 for (MachineBasicBlock &MBB : *MF) {
696 for (MachineInstr &MI : MBB) {
697 for (MachineOperand &Op : MI.operands()) {
698 if (!Op.isReg())
699 continue;
700 Register Reg = Op.getReg();
701 if (MAI.hasRegisterAlias(MF, Reg))
702 continue;
703 MCRegister NewReg = MAI.getNextIDRegister();
704 MAI.setRegisterAlias(MF, Reg, NewReg);
705 }
706 if (MI.getOpcode() != SPIRV::OpExtInst)
707 continue;
708 auto Set = MI.getOperand(2).getImm();
709 auto [It, Inserted] = MAI.ExtInstSetMap.try_emplace(Set);
710 if (Inserted)
711 It->second = MAI.getNextIDRegister();
712 }
713 }
714 }
715}
716
717// RequirementHandler implementations.
719 SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
720 const SPIRVSubtarget &ST) {
721 addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
722}
723
724void SPIRV::RequirementHandler::recursiveAddCapabilities(
725 const CapabilityList &ToPrune) {
726 for (const auto &Cap : ToPrune) {
727 AllCaps.insert(Cap);
728 CapabilityList ImplicitDecls =
729 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
730 recursiveAddCapabilities(ImplicitDecls);
731 }
732}
733
735 for (const auto &Cap : ToAdd) {
736 bool IsNewlyInserted = AllCaps.insert(Cap).second;
737 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
738 continue;
739 CapabilityList ImplicitDecls =
740 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
741 recursiveAddCapabilities(ImplicitDecls);
742 MinimalCaps.push_back(Cap);
743 }
744}
745
747 const SPIRV::Requirements &Req) {
748 if (!Req.IsSatisfiable)
749 report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");
750
751 if (Req.Cap.has_value())
752 addCapabilities({Req.Cap.value()});
753
754 addExtensions(Req.Exts);
755
756 if (!Req.MinVer.empty()) {
757 if (!MaxVersion.empty() && Req.MinVer > MaxVersion) {
758 LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
759 << " and <= " << MaxVersion << "\n");
760 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
761 }
762
763 if (MinVersion.empty() || Req.MinVer > MinVersion)
764 MinVersion = Req.MinVer;
765 }
766
767 if (!Req.MaxVer.empty()) {
768 if (!MinVersion.empty() && Req.MaxVer < MinVersion) {
769 LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
770 << " and >= " << MinVersion << "\n");
771 report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
772 }
773
774 if (MaxVersion.empty() || Req.MaxVer < MaxVersion)
775 MaxVersion = Req.MaxVer;
776 }
777}
778
780 const SPIRVSubtarget &ST) const {
781 // Report as many errors as possible before aborting the compilation.
782 bool IsSatisfiable = true;
783 auto TargetVer = ST.getSPIRVVersion();
784
785 if (!MaxVersion.empty() && !TargetVer.empty() && MaxVersion < TargetVer) {
787 dbgs() << "Target SPIR-V version too high for required features\n"
788 << "Required max version: " << MaxVersion << " target version "
789 << TargetVer << "\n");
790 IsSatisfiable = false;
791 }
792
793 if (!MinVersion.empty() && !TargetVer.empty() && MinVersion > TargetVer) {
794 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
795 << "Required min version: " << MinVersion
796 << " target version " << TargetVer << "\n");
797 IsSatisfiable = false;
798 }
799
800 if (!MinVersion.empty() && !MaxVersion.empty() && MinVersion > MaxVersion) {
802 dbgs()
803 << "Version is too low for some features and too high for others.\n"
804 << "Required SPIR-V min version: " << MinVersion
805 << " required SPIR-V max version " << MaxVersion << "\n");
806 IsSatisfiable = false;
807 }
808
809 AvoidCapabilitiesSet AvoidCaps;
810 if (!ST.isShader())
811 AvoidCaps.S.insert(SPIRV::Capability::Shader);
812 else
813 AvoidCaps.S.insert(SPIRV::Capability::Kernel);
814
815 for (auto Cap : MinimalCaps) {
816 if (AvailableCaps.contains(Cap) && !AvoidCaps.S.contains(Cap))
817 continue;
818 LLVM_DEBUG(dbgs() << "Capability not supported: "
820 OperandCategory::CapabilityOperand, Cap)
821 << "\n");
822 IsSatisfiable = false;
823 }
824
825 for (auto Ext : AllExtensions) {
826 if (ST.canUseExtension(Ext))
827 continue;
828 LLVM_DEBUG(dbgs() << "Extension not supported: "
830 OperandCategory::ExtensionOperand, Ext)
831 << "\n");
832 IsSatisfiable = false;
833 }
834
835 if (!IsSatisfiable)
836 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
837}
838
839// Add the given capabilities and all their implicitly defined capabilities too.
841 for (const auto Cap : ToAdd)
842 if (AvailableCaps.insert(Cap).second)
843 addAvailableCaps(getSymbolicOperandCapabilities(
844 SPIRV::OperandCategory::CapabilityOperand, Cap));
845}
846
848 const Capability::Capability ToRemove,
849 const Capability::Capability IfPresent) {
850 if (AllCaps.contains(IfPresent))
851 AllCaps.erase(ToRemove);
852}
853
854namespace llvm {
855namespace SPIRV {
856void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
857 // Provided by both all supported Vulkan versions and OpenCl.
858 addAvailableCaps({Capability::Shader, Capability::Linkage, Capability::Int8,
859 Capability::Int16});
860
861 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 3)))
862 addAvailableCaps({Capability::GroupNonUniform,
863 Capability::GroupNonUniformVote,
864 Capability::GroupNonUniformArithmetic,
865 Capability::GroupNonUniformBallot,
866 Capability::GroupNonUniformClustered,
867 Capability::GroupNonUniformShuffle,
868 Capability::GroupNonUniformShuffleRelative});
869
870 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
871 addAvailableCaps({Capability::DotProduct, Capability::DotProductInputAll,
872 Capability::DotProductInput4x8Bit,
873 Capability::DotProductInput4x8BitPacked,
874 Capability::DemoteToHelperInvocation});
875
876 // Add capabilities enabled by extensions.
877 for (auto Extension : ST.getAllAvailableExtensions()) {
878 CapabilityList EnabledCapabilities =
880 addAvailableCaps(EnabledCapabilities);
881 }
882
883 if (!ST.isShader()) {
884 initAvailableCapabilitiesForOpenCL(ST);
885 return;
886 }
887
888 if (ST.isShader()) {
889 initAvailableCapabilitiesForVulkan(ST);
890 return;
891 }
892
893 report_fatal_error("Unimplemented environment for SPIR-V generation.");
894}
895
896void RequirementHandler::initAvailableCapabilitiesForOpenCL(
897 const SPIRVSubtarget &ST) {
898 // Add the min requirements for different OpenCL and SPIR-V versions.
899 addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
900 Capability::Kernel, Capability::Vector16,
901 Capability::Groups, Capability::GenericPointer,
902 Capability::StorageImageWriteWithoutFormat,
903 Capability::StorageImageReadWithoutFormat});
904 if (ST.hasOpenCLFullProfile())
905 addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
906 if (ST.hasOpenCLImageSupport()) {
907 addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
908 Capability::Image1D, Capability::SampledBuffer,
909 Capability::ImageBuffer});
910 if (ST.isAtLeastOpenCLVer(VersionTuple(2, 0)))
911 addAvailableCaps({Capability::ImageReadWrite});
912 }
913 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 1)) &&
914 ST.isAtLeastOpenCLVer(VersionTuple(2, 2)))
915 addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
916 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 4)))
917 addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
918 Capability::SignedZeroInfNanPreserve,
919 Capability::RoundingModeRTE,
920 Capability::RoundingModeRTZ});
921 // TODO: verify if this needs some checks.
922 addAvailableCaps({Capability::Float16, Capability::Float64});
923
924 // TODO: add OpenCL extensions.
925}
926
927void RequirementHandler::initAvailableCapabilitiesForVulkan(
928 const SPIRVSubtarget &ST) {
929
930 // Core in Vulkan 1.1 and earlier.
931 addAvailableCaps({Capability::Int64, Capability::Float16, Capability::Float64,
932 Capability::GroupNonUniform, Capability::Image1D,
933 Capability::SampledBuffer, Capability::ImageBuffer,
934 Capability::UniformBufferArrayDynamicIndexing,
935 Capability::SampledImageArrayDynamicIndexing,
936 Capability::StorageBufferArrayDynamicIndexing,
937 Capability::StorageImageArrayDynamicIndexing,
938 Capability::DerivativeControl});
939
940 // Became core in Vulkan 1.2
941 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 5))) {
943 {Capability::ShaderNonUniformEXT, Capability::RuntimeDescriptorArrayEXT,
944 Capability::InputAttachmentArrayDynamicIndexingEXT,
945 Capability::UniformTexelBufferArrayDynamicIndexingEXT,
946 Capability::StorageTexelBufferArrayDynamicIndexingEXT,
947 Capability::UniformBufferArrayNonUniformIndexingEXT,
948 Capability::SampledImageArrayNonUniformIndexingEXT,
949 Capability::StorageBufferArrayNonUniformIndexingEXT,
950 Capability::StorageImageArrayNonUniformIndexingEXT,
951 Capability::InputAttachmentArrayNonUniformIndexingEXT,
952 Capability::UniformTexelBufferArrayNonUniformIndexingEXT,
953 Capability::StorageTexelBufferArrayNonUniformIndexingEXT});
954 }
955
956 // Became core in Vulkan 1.3
957 if (ST.isAtLeastSPIRVVer(VersionTuple(1, 6)))
958 addAvailableCaps({Capability::StorageImageWriteWithoutFormat,
959 Capability::StorageImageReadWithoutFormat});
960}
961
962} // namespace SPIRV
963} // namespace llvm
964
965// Add the required capabilities from a decoration instruction (including
966// BuiltIns).
967static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
969 const SPIRVSubtarget &ST) {
970 int64_t DecOp = MI.getOperand(DecIndex).getImm();
971 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
972 Reqs.addRequirements(getSymbolicOperandRequirements(
973 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
974
975 if (Dec == SPIRV::Decoration::BuiltIn) {
976 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
977 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
978 Reqs.addRequirements(getSymbolicOperandRequirements(
979 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
980 } else if (Dec == SPIRV::Decoration::LinkageAttributes) {
981 int64_t LinkageOp = MI.getOperand(MI.getNumOperands() - 1).getImm();
982 SPIRV::LinkageType::LinkageType LnkType =
983 static_cast<SPIRV::LinkageType::LinkageType>(LinkageOp);
984 if (LnkType == SPIRV::LinkageType::LinkOnceODR)
985 Reqs.addExtension(SPIRV::Extension::SPV_KHR_linkonce_odr);
986 } else if (Dec == SPIRV::Decoration::CacheControlLoadINTEL ||
987 Dec == SPIRV::Decoration::CacheControlStoreINTEL) {
988 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_cache_controls);
989 } else if (Dec == SPIRV::Decoration::HostAccessINTEL) {
990 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_global_variable_host_access);
991 } else if (Dec == SPIRV::Decoration::InitModeINTEL ||
992 Dec == SPIRV::Decoration::ImplementInRegisterMapINTEL) {
993 Reqs.addExtension(
994 SPIRV::Extension::SPV_INTEL_global_variable_fpga_decorations);
995 } else if (Dec == SPIRV::Decoration::NonUniformEXT) {
996 Reqs.addRequirements(SPIRV::Capability::ShaderNonUniformEXT);
997 } else if (Dec == SPIRV::Decoration::FPMaxErrorDecorationINTEL) {
998 Reqs.addRequirements(SPIRV::Capability::FPMaxErrorINTEL);
999 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_fp_max_error);
1000 } else if (Dec == SPIRV::Decoration::FPFastMathMode) {
1001 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
1002 Reqs.addRequirements(SPIRV::Capability::FloatControls2);
1003 Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1004 }
1005 }
1006}
1007
1008// Add requirements for image handling.
1009static void addOpTypeImageReqs(const MachineInstr &MI,
1011 const SPIRVSubtarget &ST) {
1012 assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
1013 // The operand indices used here are based on the OpTypeImage layout, which
1014 // the MachineInstr follows as well.
1015 int64_t ImgFormatOp = MI.getOperand(7).getImm();
1016 auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
1017 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
1018 ImgFormat, ST);
1019
1020 bool IsArrayed = MI.getOperand(4).getImm() == 1;
1021 bool IsMultisampled = MI.getOperand(5).getImm() == 1;
1022 bool NoSampler = MI.getOperand(6).getImm() == 2;
1023 // Add dimension requirements.
1024 assert(MI.getOperand(2).isImm());
1025 switch (MI.getOperand(2).getImm()) {
1026 case SPIRV::Dim::DIM_1D:
1027 Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
1028 : SPIRV::Capability::Sampled1D);
1029 break;
1030 case SPIRV::Dim::DIM_2D:
1031 if (IsMultisampled && NoSampler)
1032 Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
1033 break;
1034 case SPIRV::Dim::DIM_Cube:
1035 Reqs.addRequirements(SPIRV::Capability::Shader);
1036 if (IsArrayed)
1037 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
1038 : SPIRV::Capability::SampledCubeArray);
1039 break;
1040 case SPIRV::Dim::DIM_Rect:
1041 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
1042 : SPIRV::Capability::SampledRect);
1043 break;
1044 case SPIRV::Dim::DIM_Buffer:
1045 Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
1046 : SPIRV::Capability::SampledBuffer);
1047 break;
1048 case SPIRV::Dim::DIM_SubpassData:
1049 Reqs.addRequirements(SPIRV::Capability::InputAttachment);
1050 break;
1051 }
1052
1053 // Has optional access qualifier.
1054 if (!ST.isShader()) {
1055 if (MI.getNumOperands() > 8 &&
1056 MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
1057 Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
1058 else
1059 Reqs.addRequirements(SPIRV::Capability::ImageBasic);
1060 }
1061}
1062
1063static bool isBFloat16Type(const SPIRVType *TypeDef) {
1064 return TypeDef && TypeDef->getNumOperands() == 3 &&
1065 TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
1066 TypeDef->getOperand(1).getImm() == 16 &&
1067 TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
1068}
1069
1070// Add requirements for handling atomic float instructions
1071#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
1072 "The atomic float instruction requires the following SPIR-V " \
1073 "extension: SPV_EXT_shader_atomic_float" ExtName
1074static void AddAtomicFloatRequirements(const MachineInstr &MI,
1076 const SPIRVSubtarget &ST) {
1077 assert(MI.getOperand(1).isReg() &&
1078 "Expect register operand in atomic float instruction");
1079 Register TypeReg = MI.getOperand(1).getReg();
1080 SPIRVType *TypeDef = MI.getMF()->getRegInfo().getVRegDef(TypeReg);
1081 if (TypeDef->getOpcode() != SPIRV::OpTypeFloat)
1082 report_fatal_error("Result type of an atomic float instruction must be a "
1083 "floating-point type scalar");
1084
1085 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1086 unsigned Op = MI.getOpcode();
1087 if (Op == SPIRV::OpAtomicFAddEXT) {
1088 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add))
1090 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
1091 switch (BitWidth) {
1092 case 16:
1093 if (isBFloat16Type(TypeDef)) {
1094 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1096 "The atomic bfloat16 instruction requires the following SPIR-V "
1097 "extension: SPV_INTEL_16bit_atomics",
1098 false);
1099 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1100 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
1101 } else {
1102 if (!ST.canUseExtension(
1103 SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
1104 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
1105 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
1106 Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
1107 }
1108 break;
1109 case 32:
1110 Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
1111 break;
1112 case 64:
1113 Reqs.addCapability(SPIRV::Capability::AtomicFloat64AddEXT);
1114 break;
1115 default:
1117 "Unexpected floating-point type width in atomic float instruction");
1118 }
1119 } else {
1120 if (!ST.canUseExtension(
1121 SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max))
1122 report_fatal_error(ATOM_FLT_REQ_EXT_MSG("_min_max"), false);
1123 Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
1124 switch (BitWidth) {
1125 case 16:
1126 if (isBFloat16Type(TypeDef)) {
1127 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
1129 "The atomic bfloat16 instruction requires the following SPIR-V "
1130 "extension: SPV_INTEL_16bit_atomics",
1131 false);
1132 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
1133 Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
1134 } else {
1135 Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
1136 }
1137 break;
1138 case 32:
1139 Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
1140 break;
1141 case 64:
1142 Reqs.addCapability(SPIRV::Capability::AtomicFloat64MinMaxEXT);
1143 break;
1144 default:
1146 "Unexpected floating-point type width in atomic float instruction");
1147 }
1148 }
1149}
1150
1151bool isUniformTexelBuffer(MachineInstr *ImageInst) {
1152 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1153 return false;
1154 uint32_t Dim = ImageInst->getOperand(2).getImm();
1155 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1156 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 1;
1157}
1158
1159bool isStorageTexelBuffer(MachineInstr *ImageInst) {
1160 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1161 return false;
1162 uint32_t Dim = ImageInst->getOperand(2).getImm();
1163 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1164 return Dim == SPIRV::Dim::DIM_Buffer && Sampled == 2;
1165}
1166
1167bool isSampledImage(MachineInstr *ImageInst) {
1168 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1169 return false;
1170 uint32_t Dim = ImageInst->getOperand(2).getImm();
1171 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1172 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 1;
1173}
1174
1175bool isInputAttachment(MachineInstr *ImageInst) {
1176 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1177 return false;
1178 uint32_t Dim = ImageInst->getOperand(2).getImm();
1179 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1180 return Dim == SPIRV::Dim::DIM_SubpassData && Sampled == 2;
1181}
1182
1183bool isStorageImage(MachineInstr *ImageInst) {
1184 if (ImageInst->getOpcode() != SPIRV::OpTypeImage)
1185 return false;
1186 uint32_t Dim = ImageInst->getOperand(2).getImm();
1187 uint32_t Sampled = ImageInst->getOperand(6).getImm();
1188 return Dim != SPIRV::Dim::DIM_Buffer && Sampled == 2;
1189}
1190
1191bool isCombinedImageSampler(MachineInstr *SampledImageInst) {
1192 if (SampledImageInst->getOpcode() != SPIRV::OpTypeSampledImage)
1193 return false;
1194
1195 const MachineRegisterInfo &MRI = SampledImageInst->getMF()->getRegInfo();
1196 Register ImageReg = SampledImageInst->getOperand(1).getReg();
1197 auto *ImageInst = MRI.getUniqueVRegDef(ImageReg);
1198 return isSampledImage(ImageInst);
1199}
1200
1201bool hasNonUniformDecoration(Register Reg, const MachineRegisterInfo &MRI) {
1202 for (const auto &MI : MRI.reg_instructions(Reg)) {
1203 if (MI.getOpcode() != SPIRV::OpDecorate)
1204 continue;
1205
1206 uint32_t Dec = MI.getOperand(1).getImm();
1207 if (Dec == SPIRV::Decoration::NonUniformEXT)
1208 return true;
1209 }
1210 return false;
1211}
1212
1213void addOpAccessChainReqs(const MachineInstr &Instr,
1215 const SPIRVSubtarget &Subtarget) {
1216 const MachineRegisterInfo &MRI = Instr.getMF()->getRegInfo();
1217 // Get the result type. If it is an image type, then the shader uses
1218 // descriptor indexing. The appropriate capabilities will be added based
1219 // on the specifics of the image.
1220 Register ResTypeReg = Instr.getOperand(1).getReg();
1221 MachineInstr *ResTypeInst = MRI.getUniqueVRegDef(ResTypeReg);
1222
1223 assert(ResTypeInst->getOpcode() == SPIRV::OpTypePointer);
1224 uint32_t StorageClass = ResTypeInst->getOperand(1).getImm();
1225 if (StorageClass != SPIRV::StorageClass::StorageClass::UniformConstant &&
1226 StorageClass != SPIRV::StorageClass::StorageClass::Uniform &&
1227 StorageClass != SPIRV::StorageClass::StorageClass::StorageBuffer) {
1228 return;
1229 }
1230
1231 bool IsNonUniform =
1232 hasNonUniformDecoration(Instr.getOperand(0).getReg(), MRI);
1233
1234 auto FirstIndexReg = Instr.getOperand(3).getReg();
1235 bool FirstIndexIsConstant =
1236 Subtarget.getInstrInfo()->isConstantInstr(*MRI.getVRegDef(FirstIndexReg));
1237
1238 if (StorageClass == SPIRV::StorageClass::StorageClass::StorageBuffer) {
1239 if (IsNonUniform)
1240 Handler.addRequirements(
1241 SPIRV::Capability::StorageBufferArrayNonUniformIndexingEXT);
1242 else if (!FirstIndexIsConstant)
1243 Handler.addRequirements(
1244 SPIRV::Capability::StorageBufferArrayDynamicIndexing);
1245 return;
1246 }
1247
1248 Register PointeeTypeReg = ResTypeInst->getOperand(2).getReg();
1249 MachineInstr *PointeeType = MRI.getUniqueVRegDef(PointeeTypeReg);
1250 if (PointeeType->getOpcode() != SPIRV::OpTypeImage &&
1251 PointeeType->getOpcode() != SPIRV::OpTypeSampledImage &&
1252 PointeeType->getOpcode() != SPIRV::OpTypeSampler) {
1253 return;
1254 }
1255
1256 if (isUniformTexelBuffer(PointeeType)) {
1257 if (IsNonUniform)
1258 Handler.addRequirements(
1259 SPIRV::Capability::UniformTexelBufferArrayNonUniformIndexingEXT);
1260 else if (!FirstIndexIsConstant)
1261 Handler.addRequirements(
1262 SPIRV::Capability::UniformTexelBufferArrayDynamicIndexingEXT);
1263 } else if (isInputAttachment(PointeeType)) {
1264 if (IsNonUniform)
1265 Handler.addRequirements(
1266 SPIRV::Capability::InputAttachmentArrayNonUniformIndexingEXT);
1267 else if (!FirstIndexIsConstant)
1268 Handler.addRequirements(
1269 SPIRV::Capability::InputAttachmentArrayDynamicIndexingEXT);
1270 } else if (isStorageTexelBuffer(PointeeType)) {
1271 if (IsNonUniform)
1272 Handler.addRequirements(
1273 SPIRV::Capability::StorageTexelBufferArrayNonUniformIndexingEXT);
1274 else if (!FirstIndexIsConstant)
1275 Handler.addRequirements(
1276 SPIRV::Capability::StorageTexelBufferArrayDynamicIndexingEXT);
1277 } else if (isSampledImage(PointeeType) ||
1278 isCombinedImageSampler(PointeeType) ||
1279 PointeeType->getOpcode() == SPIRV::OpTypeSampler) {
1280 if (IsNonUniform)
1281 Handler.addRequirements(
1282 SPIRV::Capability::SampledImageArrayNonUniformIndexingEXT);
1283 else if (!FirstIndexIsConstant)
1284 Handler.addRequirements(
1285 SPIRV::Capability::SampledImageArrayDynamicIndexing);
1286 } else if (isStorageImage(PointeeType)) {
1287 if (IsNonUniform)
1288 Handler.addRequirements(
1289 SPIRV::Capability::StorageImageArrayNonUniformIndexingEXT);
1290 else if (!FirstIndexIsConstant)
1291 Handler.addRequirements(
1292 SPIRV::Capability::StorageImageArrayDynamicIndexing);
1293 }
1294}
1295
1296static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
1297 if (TypeInst->getOpcode() != SPIRV::OpTypeImage)
1298 return false;
1299 assert(TypeInst->getOperand(7).isImm() && "The image format must be an imm.");
1300 return TypeInst->getOperand(7).getImm() == 0;
1301}
1302
1303static void AddDotProductRequirements(const MachineInstr &MI,
1305 const SPIRVSubtarget &ST) {
1306 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_integer_dot_product))
1307 Reqs.addExtension(SPIRV::Extension::SPV_KHR_integer_dot_product);
1308 Reqs.addCapability(SPIRV::Capability::DotProduct);
1309
1310 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1311 assert(MI.getOperand(2).isReg() && "Unexpected operand in dot");
1312 // We do not consider what the previous instruction is. This is just used
1313 // to get the input register and to check the type.
1314 const MachineInstr *Input = MRI.getVRegDef(MI.getOperand(2).getReg());
1315 assert(Input->getOperand(1).isReg() && "Unexpected operand in dot input");
1316 Register InputReg = Input->getOperand(1).getReg();
1317
1318 SPIRVType *TypeDef = MRI.getVRegDef(InputReg);
1319 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1320 assert(TypeDef->getOperand(1).getImm() == 32);
1321 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8BitPacked);
1322 } else if (TypeDef->getOpcode() == SPIRV::OpTypeVector) {
1323 SPIRVType *ScalarTypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
1324 assert(ScalarTypeDef->getOpcode() == SPIRV::OpTypeInt);
1325 if (ScalarTypeDef->getOperand(1).getImm() == 8) {
1326 assert(TypeDef->getOperand(2).getImm() == 4 &&
1327 "Dot operand of 8-bit integer type requires 4 components");
1328 Reqs.addCapability(SPIRV::Capability::DotProductInput4x8Bit);
1329 } else {
1330 Reqs.addCapability(SPIRV::Capability::DotProductInputAll);
1331 }
1332 }
1333}
1334
1335void addPrintfRequirements(const MachineInstr &MI,
1337 const SPIRVSubtarget &ST) {
1338 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1339 const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg());
1340 if (PtrType) {
1341 MachineOperand ASOp = PtrType->getOperand(1);
1342 if (ASOp.isImm()) {
1343 unsigned AddrSpace = ASOp.getImm();
1344 if (AddrSpace != SPIRV::StorageClass::UniformConstant) {
1345 if (!ST.canUseExtension(
1347 SPV_EXT_relaxed_printf_string_address_space)) {
1348 report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is "
1349 "required because printf uses a format string not "
1350 "in constant address space.",
1351 false);
1352 }
1353 Reqs.addExtension(
1354 SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space);
1355 }
1356 }
1357 }
1358}
1359
1360void addInstrRequirements(const MachineInstr &MI,
1362 const SPIRVSubtarget &ST) {
1363 SPIRV::RequirementHandler &Reqs = MAI.Reqs;
1364 switch (MI.getOpcode()) {
1365 case SPIRV::OpMemoryModel: {
1366 int64_t Addr = MI.getOperand(0).getImm();
1367 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
1368 Addr, ST);
1369 int64_t Mem = MI.getOperand(1).getImm();
1370 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
1371 ST);
1372 break;
1373 }
1374 case SPIRV::OpEntryPoint: {
1375 int64_t Exe = MI.getOperand(0).getImm();
1376 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
1377 Exe, ST);
1378 break;
1379 }
1380 case SPIRV::OpExecutionMode:
1381 case SPIRV::OpExecutionModeId: {
1382 int64_t Exe = MI.getOperand(1).getImm();
1383 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
1384 Exe, ST);
1385 break;
1386 }
1387 case SPIRV::OpTypeMatrix:
1388 Reqs.addCapability(SPIRV::Capability::Matrix);
1389 break;
1390 case SPIRV::OpTypeInt: {
1391 unsigned BitWidth = MI.getOperand(1).getImm();
1392 if (BitWidth == 64)
1393 Reqs.addCapability(SPIRV::Capability::Int64);
1394 else if (BitWidth == 16)
1395 Reqs.addCapability(SPIRV::Capability::Int16);
1396 else if (BitWidth == 8)
1397 Reqs.addCapability(SPIRV::Capability::Int8);
1398 break;
1399 }
1400 case SPIRV::OpDot: {
1401 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1402 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1403 if (isBFloat16Type(TypeDef))
1404 Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
1405 break;
1406 }
1407 case SPIRV::OpTypeFloat: {
1408 unsigned BitWidth = MI.getOperand(1).getImm();
1409 if (BitWidth == 64)
1410 Reqs.addCapability(SPIRV::Capability::Float64);
1411 else if (BitWidth == 16) {
1412 if (isBFloat16Type(&MI)) {
1413 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bfloat16))
1414 report_fatal_error("OpTypeFloat type with bfloat requires the "
1415 "following SPIR-V extension: SPV_KHR_bfloat16",
1416 false);
1417 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
1418 Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
1419 } else {
1420 Reqs.addCapability(SPIRV::Capability::Float16);
1421 }
1422 }
1423 break;
1424 }
1425 case SPIRV::OpTypeVector: {
1426 unsigned NumComponents = MI.getOperand(2).getImm();
1427 if (NumComponents == 8 || NumComponents == 16)
1428 Reqs.addCapability(SPIRV::Capability::Vector16);
1429 break;
1430 }
1431 case SPIRV::OpTypePointer: {
1432 auto SC = MI.getOperand(1).getImm();
1433 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
1434 ST);
1435 // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
1436 // capability.
1437 if (ST.isShader())
1438 break;
1439 assert(MI.getOperand(2).isReg());
1440 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1441 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
1442 if ((TypeDef->getNumOperands() == 2) &&
1443 (TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
1444 (TypeDef->getOperand(1).getImm() == 16))
1445 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
1446 break;
1447 }
1448 case SPIRV::OpExtInst: {
1449 if (MI.getOperand(2).getImm() ==
1450 static_cast<int64_t>(
1451 SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) {
1452 Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info);
1453 break;
1454 }
1455 if (MI.getOperand(3).getImm() ==
1456 static_cast<int64_t>(SPIRV::OpenCLExtInst::printf)) {
1457 addPrintfRequirements(MI, Reqs, ST);
1458 break;
1459 }
1460 // TODO: handle bfloat16 extended instructions when
1461 // SPV_INTEL_bfloat16_arithmetic is enabled.
1462 break;
1463 }
1464 case SPIRV::OpAliasDomainDeclINTEL:
1465 case SPIRV::OpAliasScopeDeclINTEL:
1466 case SPIRV::OpAliasScopeListDeclINTEL: {
1467 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_memory_access_aliasing);
1468 Reqs.addCapability(SPIRV::Capability::MemoryAccessAliasingINTEL);
1469 break;
1470 }
1471 case SPIRV::OpBitReverse:
1472 case SPIRV::OpBitFieldInsert:
1473 case SPIRV::OpBitFieldSExtract:
1474 case SPIRV::OpBitFieldUExtract:
1475 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions)) {
1476 Reqs.addCapability(SPIRV::Capability::Shader);
1477 break;
1478 }
1479 Reqs.addExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
1480 Reqs.addCapability(SPIRV::Capability::BitInstructions);
1481 break;
1482 case SPIRV::OpTypeRuntimeArray:
1483 Reqs.addCapability(SPIRV::Capability::Shader);
1484 break;
1485 case SPIRV::OpTypeOpaque:
1486 case SPIRV::OpTypeEvent:
1487 Reqs.addCapability(SPIRV::Capability::Kernel);
1488 break;
1489 case SPIRV::OpTypePipe:
1490 case SPIRV::OpTypeReserveId:
1491 Reqs.addCapability(SPIRV::Capability::Pipes);
1492 break;
1493 case SPIRV::OpTypeDeviceEvent:
1494 case SPIRV::OpTypeQueue:
1495 case SPIRV::OpBuildNDRange:
1496 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
1497 break;
1498 case SPIRV::OpDecorate:
1499 case SPIRV::OpDecorateId:
1500 case SPIRV::OpDecorateString:
1501 addOpDecorateReqs(MI, 1, Reqs, ST);
1502 break;
1503 case SPIRV::OpMemberDecorate:
1504 case SPIRV::OpMemberDecorateString:
1505 addOpDecorateReqs(MI, 2, Reqs, ST);
1506 break;
1507 case SPIRV::OpInBoundsPtrAccessChain:
1508 Reqs.addCapability(SPIRV::Capability::Addresses);
1509 break;
1510 case SPIRV::OpConstantSampler:
1511 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
1512 break;
1513 case SPIRV::OpInBoundsAccessChain:
1514 case SPIRV::OpAccessChain:
1515 addOpAccessChainReqs(MI, Reqs, ST);
1516 break;
1517 case SPIRV::OpTypeImage:
1518 addOpTypeImageReqs(MI, Reqs, ST);
1519 break;
1520 case SPIRV::OpTypeSampler:
1521 if (!ST.isShader()) {
1522 Reqs.addCapability(SPIRV::Capability::ImageBasic);
1523 }
1524 break;
1525 case SPIRV::OpTypeForwardPointer:
1526 // TODO: check if it's OpenCL's kernel.
1527 Reqs.addCapability(SPIRV::Capability::Addresses);
1528 break;
1529 case SPIRV::OpAtomicFlagTestAndSet:
1530 case SPIRV::OpAtomicLoad:
1531 case SPIRV::OpAtomicStore:
1532 case SPIRV::OpAtomicExchange:
1533 case SPIRV::OpAtomicCompareExchange:
1534 case SPIRV::OpAtomicIIncrement:
1535 case SPIRV::OpAtomicIDecrement:
1536 case SPIRV::OpAtomicIAdd:
1537 case SPIRV::OpAtomicISub:
1538 case SPIRV::OpAtomicUMin:
1539 case SPIRV::OpAtomicUMax:
1540 case SPIRV::OpAtomicSMin:
1541 case SPIRV::OpAtomicSMax:
1542 case SPIRV::OpAtomicAnd:
1543 case SPIRV::OpAtomicOr:
1544 case SPIRV::OpAtomicXor: {
1545 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1546 const MachineInstr *InstrPtr = &MI;
1547 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
1548 assert(MI.getOperand(3).isReg());
1549 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
1550 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
1551 }
1552 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
1553 Register TypeReg = InstrPtr->getOperand(1).getReg();
1554 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
1555 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
1556 unsigned BitWidth = TypeDef->getOperand(1).getImm();
1557 if (BitWidth == 64)
1558 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
1559 }
1560 break;
1561 }
1562 case SPIRV::OpGroupNonUniformIAdd:
1563 case SPIRV::OpGroupNonUniformFAdd:
1564 case SPIRV::OpGroupNonUniformIMul:
1565 case SPIRV::OpGroupNonUniformFMul:
1566 case SPIRV::OpGroupNonUniformSMin:
1567 case SPIRV::OpGroupNonUniformUMin:
1568 case SPIRV::OpGroupNonUniformFMin:
1569 case SPIRV::OpGroupNonUniformSMax:
1570 case SPIRV::OpGroupNonUniformUMax:
1571 case SPIRV::OpGroupNonUniformFMax:
1572 case SPIRV::OpGroupNonUniformBitwiseAnd:
1573 case SPIRV::OpGroupNonUniformBitwiseOr:
1574 case SPIRV::OpGroupNonUniformBitwiseXor:
1575 case SPIRV::OpGroupNonUniformLogicalAnd:
1576 case SPIRV::OpGroupNonUniformLogicalOr:
1577 case SPIRV::OpGroupNonUniformLogicalXor: {
1578 assert(MI.getOperand(3).isImm());
1579 int64_t GroupOp = MI.getOperand(3).getImm();
1580 switch (GroupOp) {
1581 case SPIRV::GroupOperation::Reduce:
1582 case SPIRV::GroupOperation::InclusiveScan:
1583 case SPIRV::GroupOperation::ExclusiveScan:
1584 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
1585 break;
1586 case SPIRV::GroupOperation::ClusteredReduce:
1587 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
1588 break;
1589 case SPIRV::GroupOperation::PartitionedReduceNV:
1590 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
1591 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
1592 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
1593 break;
1594 }
1595 break;
1596 }
1597 case SPIRV::OpGroupNonUniformShuffle:
1598 case SPIRV::OpGroupNonUniformShuffleXor:
1599 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
1600 break;
1601 case SPIRV::OpGroupNonUniformShuffleUp:
1602 case SPIRV::OpGroupNonUniformShuffleDown:
1603 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
1604 break;
1605 case SPIRV::OpGroupAll:
1606 case SPIRV::OpGroupAny:
1607 case SPIRV::OpGroupBroadcast:
1608 case SPIRV::OpGroupIAdd:
1609 case SPIRV::OpGroupFAdd:
1610 case SPIRV::OpGroupFMin:
1611 case SPIRV::OpGroupUMin:
1612 case SPIRV::OpGroupSMin:
1613 case SPIRV::OpGroupFMax:
1614 case SPIRV::OpGroupUMax:
1615 case SPIRV::OpGroupSMax:
1616 Reqs.addCapability(SPIRV::Capability::Groups);
1617 break;
1618 case SPIRV::OpGroupNonUniformElect:
1619 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1620 break;
1621 case SPIRV::OpGroupNonUniformAll:
1622 case SPIRV::OpGroupNonUniformAny:
1623 case SPIRV::OpGroupNonUniformAllEqual:
1624 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
1625 break;
1626 case SPIRV::OpGroupNonUniformBroadcast:
1627 case SPIRV::OpGroupNonUniformBroadcastFirst:
1628 case SPIRV::OpGroupNonUniformBallot:
1629 case SPIRV::OpGroupNonUniformInverseBallot:
1630 case SPIRV::OpGroupNonUniformBallotBitExtract:
1631 case SPIRV::OpGroupNonUniformBallotBitCount:
1632 case SPIRV::OpGroupNonUniformBallotFindLSB:
1633 case SPIRV::OpGroupNonUniformBallotFindMSB:
1634 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
1635 break;
1636 case SPIRV::OpSubgroupShuffleINTEL:
1637 case SPIRV::OpSubgroupShuffleDownINTEL:
1638 case SPIRV::OpSubgroupShuffleUpINTEL:
1639 case SPIRV::OpSubgroupShuffleXorINTEL:
1640 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1641 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1642 Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
1643 }
1644 break;
1645 case SPIRV::OpSubgroupBlockReadINTEL:
1646 case SPIRV::OpSubgroupBlockWriteINTEL:
1647 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1648 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1649 Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
1650 }
1651 break;
1652 case SPIRV::OpSubgroupImageBlockReadINTEL:
1653 case SPIRV::OpSubgroupImageBlockWriteINTEL:
1654 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
1655 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
1656 Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
1657 }
1658 break;
1659 case SPIRV::OpSubgroupImageMediaBlockReadINTEL:
1660 case SPIRV::OpSubgroupImageMediaBlockWriteINTEL:
1661 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_media_block_io)) {
1662 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_media_block_io);
1663 Reqs.addCapability(SPIRV::Capability::SubgroupImageMediaBlockIOINTEL);
1664 }
1665 break;
1666 case SPIRV::OpAssumeTrueKHR:
1667 case SPIRV::OpExpectKHR:
1668 if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
1669 Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
1670 Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
1671 }
1672 break;
1673 case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
1674 case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
1675 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
1676 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
1677 Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
1678 }
1679 break;
1680 case SPIRV::OpConstantFunctionPointerINTEL:
1681 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1682 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1683 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1684 }
1685 break;
1686 case SPIRV::OpGroupNonUniformRotateKHR:
1687 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
1688 report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
1689 "following SPIR-V extension: SPV_KHR_subgroup_rotate",
1690 false);
1691 Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
1692 Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
1693 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
1694 break;
1695 case SPIRV::OpGroupIMulKHR:
1696 case SPIRV::OpGroupFMulKHR:
1697 case SPIRV::OpGroupBitwiseAndKHR:
1698 case SPIRV::OpGroupBitwiseOrKHR:
1699 case SPIRV::OpGroupBitwiseXorKHR:
1700 case SPIRV::OpGroupLogicalAndKHR:
1701 case SPIRV::OpGroupLogicalOrKHR:
1702 case SPIRV::OpGroupLogicalXorKHR:
1703 if (ST.canUseExtension(
1704 SPIRV::Extension::SPV_KHR_uniform_group_instructions)) {
1705 Reqs.addExtension(SPIRV::Extension::SPV_KHR_uniform_group_instructions);
1706 Reqs.addCapability(SPIRV::Capability::GroupUniformArithmeticKHR);
1707 }
1708 break;
1709 case SPIRV::OpReadClockKHR:
1710 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_shader_clock))
1711 report_fatal_error("OpReadClockKHR instruction requires the "
1712 "following SPIR-V extension: SPV_KHR_shader_clock",
1713 false);
1714 Reqs.addExtension(SPIRV::Extension::SPV_KHR_shader_clock);
1715 Reqs.addCapability(SPIRV::Capability::ShaderClockKHR);
1716 break;
1717 case SPIRV::OpFunctionPointerCallINTEL:
1718 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
1719 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1720 Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
1721 }
1722 break;
1723 case SPIRV::OpAtomicFAddEXT:
1724 case SPIRV::OpAtomicFMinEXT:
1725 case SPIRV::OpAtomicFMaxEXT:
1726 AddAtomicFloatRequirements(MI, Reqs, ST);
1727 break;
1728 case SPIRV::OpConvertBF16ToFINTEL:
1729 case SPIRV::OpConvertFToBF16INTEL:
1730 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1731 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1732 Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1733 }
1734 break;
1735 case SPIRV::OpRoundFToTF32INTEL:
1736 if (ST.canUseExtension(
1737 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1738 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1739 Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
1740 }
1741 break;
1742 case SPIRV::OpVariableLengthArrayINTEL:
1743 case SPIRV::OpSaveMemoryINTEL:
1744 case SPIRV::OpRestoreMemoryINTEL:
1745 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_variable_length_array)) {
1746 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_variable_length_array);
1747 Reqs.addCapability(SPIRV::Capability::VariableLengthArrayINTEL);
1748 }
1749 break;
1750 case SPIRV::OpAsmTargetINTEL:
1751 case SPIRV::OpAsmINTEL:
1752 case SPIRV::OpAsmCallINTEL:
1753 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_inline_assembly)) {
1754 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_inline_assembly);
1755 Reqs.addCapability(SPIRV::Capability::AsmINTEL);
1756 }
1757 break;
1758 case SPIRV::OpTypeCooperativeMatrixKHR: {
1759 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1761 "OpTypeCooperativeMatrixKHR type requires the "
1762 "following SPIR-V extension: SPV_KHR_cooperative_matrix",
1763 false);
1764 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1765 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1766 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1767 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
1768 if (isBFloat16Type(TypeDef))
1769 Reqs.addCapability(SPIRV::Capability::BFloat16CooperativeMatrixKHR);
1770 break;
1771 }
1772 case SPIRV::OpArithmeticFenceEXT:
1773 if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1774 report_fatal_error("OpArithmeticFenceEXT requires the "
1775 "following SPIR-V extension: SPV_EXT_arithmetic_fence",
1776 false);
1777 Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1778 Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1779 break;
1780 case SPIRV::OpControlBarrierArriveINTEL:
1781 case SPIRV::OpControlBarrierWaitINTEL:
1782 if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_split_barrier)) {
1783 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_split_barrier);
1784 Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
1785 }
1786 break;
1787 case SPIRV::OpCooperativeMatrixMulAddKHR: {
1788 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1789 report_fatal_error("Cooperative matrix instructions require the "
1790 "following SPIR-V extension: "
1791 "SPV_KHR_cooperative_matrix",
1792 false);
1793 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1794 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1795 constexpr unsigned MulAddMaxSize = 6;
1796 if (MI.getNumOperands() != MulAddMaxSize)
1797 break;
1798 const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1799 if (CoopOperands &
1800 SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1801 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1802 report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1803 "require the following SPIR-V extension: "
1804 "SPV_INTEL_joint_matrix",
1805 false);
1806 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1807 Reqs.addCapability(
1808 SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1809 }
1810 if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1811 MatrixAAndBBFloat16ComponentsINTEL ||
1812 CoopOperands &
1813 SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1814 CoopOperands & SPIRV::CooperativeMatrixOperands::
1815 MatrixResultBFloat16ComponentsINTEL) {
1816 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1817 report_fatal_error("***BF16ComponentsINTEL type interpretations "
1818 "require the following SPIR-V extension: "
1819 "SPV_INTEL_joint_matrix",
1820 false);
1821 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1822 Reqs.addCapability(
1823 SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1824 }
1825 break;
1826 }
1827 case SPIRV::OpCooperativeMatrixLoadKHR:
1828 case SPIRV::OpCooperativeMatrixStoreKHR:
1829 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1830 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1831 case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1832 if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1833 report_fatal_error("Cooperative matrix instructions require the "
1834 "following SPIR-V extension: "
1835 "SPV_KHR_cooperative_matrix",
1836 false);
1837 Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1838 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1839
1840 // Check Layout operand in case if it's not a standard one and add the
1841 // appropriate capability.
1842 std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1843 {SPIRV::OpCooperativeMatrixLoadKHR, 3},
1844 {SPIRV::OpCooperativeMatrixStoreKHR, 2},
1845 {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1846 {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1847 {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1848
1849 const auto OpCode = MI.getOpcode();
1850 const unsigned LayoutNum = LayoutToInstMap[OpCode];
1851 Register RegLayout = MI.getOperand(LayoutNum).getReg();
1852 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1853 MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1854 if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1855 const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1856 if (LayoutVal ==
1857 static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1858 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1859 report_fatal_error("PackedINTEL layout require the following SPIR-V "
1860 "extension: SPV_INTEL_joint_matrix",
1861 false);
1862 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1863 Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1864 }
1865 }
1866
1867 // Nothing to do.
1868 if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1869 OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1870 break;
1871
1872 std::string InstName;
1873 switch (OpCode) {
1874 case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1875 InstName = "OpCooperativeMatrixPrefetchINTEL";
1876 break;
1877 case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1878 InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1879 break;
1880 case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1881 InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1882 break;
1883 }
1884
1885 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1886 const std::string ErrorMsg =
1887 InstName + " instruction requires the "
1888 "following SPIR-V extension: SPV_INTEL_joint_matrix";
1889 report_fatal_error(ErrorMsg.c_str(), false);
1890 }
1891 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1892 if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1893 Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1894 break;
1895 }
1896 Reqs.addCapability(
1897 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1898 break;
1899 }
1900 case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1901 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1902 report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1903 "instructions require the following SPIR-V extension: "
1904 "SPV_INTEL_joint_matrix",
1905 false);
1906 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1907 Reqs.addCapability(
1908 SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1909 break;
1910 case SPIRV::OpReadPipeBlockingALTERA:
1911 case SPIRV::OpWritePipeBlockingALTERA:
1912 if (ST.canUseExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes)) {
1913 Reqs.addExtension(SPIRV::Extension::SPV_ALTERA_blocking_pipes);
1914 Reqs.addCapability(SPIRV::Capability::BlockingPipesALTERA);
1915 }
1916 break;
1917 case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1918 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1919 report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1920 "following SPIR-V extension: SPV_INTEL_joint_matrix",
1921 false);
1922 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1923 Reqs.addCapability(
1924 SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1925 break;
1926 case SPIRV::OpConvertHandleToImageINTEL:
1927 case SPIRV::OpConvertHandleToSamplerINTEL:
1928 case SPIRV::OpConvertHandleToSampledImageINTEL: {
1929 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bindless_images))
1930 report_fatal_error("OpConvertHandleTo[Image/Sampler/SampledImage]INTEL "
1931 "instructions require the following SPIR-V extension: "
1932 "SPV_INTEL_bindless_images",
1933 false);
1934 SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
1935 SPIRV::AddressingModel::AddressingModel AddrModel = MAI.Addr;
1936 SPIRVType *TyDef = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
1937 if (MI.getOpcode() == SPIRV::OpConvertHandleToImageINTEL &&
1938 TyDef->getOpcode() != SPIRV::OpTypeImage) {
1939 report_fatal_error("Incorrect return type for the instruction "
1940 "OpConvertHandleToImageINTEL",
1941 false);
1942 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSamplerINTEL &&
1943 TyDef->getOpcode() != SPIRV::OpTypeSampler) {
1944 report_fatal_error("Incorrect return type for the instruction "
1945 "OpConvertHandleToSamplerINTEL",
1946 false);
1947 } else if (MI.getOpcode() == SPIRV::OpConvertHandleToSampledImageINTEL &&
1948 TyDef->getOpcode() != SPIRV::OpTypeSampledImage) {
1949 report_fatal_error("Incorrect return type for the instruction "
1950 "OpConvertHandleToSampledImageINTEL",
1951 false);
1952 }
1953 SPIRVType *SpvTy = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
1954 unsigned Bitwidth = GR->getScalarOrVectorBitWidth(SpvTy);
1955 if (!(Bitwidth == 32 && AddrModel == SPIRV::AddressingModel::Physical32) &&
1956 !(Bitwidth == 64 && AddrModel == SPIRV::AddressingModel::Physical64)) {
1958 "Parameter value must be a 32-bit scalar in case of "
1959 "Physical32 addressing model or a 64-bit scalar in case of "
1960 "Physical64 addressing model",
1961 false);
1962 }
1963 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bindless_images);
1964 Reqs.addCapability(SPIRV::Capability::BindlessImagesINTEL);
1965 break;
1966 }
1967 case SPIRV::OpSubgroup2DBlockLoadINTEL:
1968 case SPIRV::OpSubgroup2DBlockLoadTransposeINTEL:
1969 case SPIRV::OpSubgroup2DBlockLoadTransformINTEL:
1970 case SPIRV::OpSubgroup2DBlockPrefetchINTEL:
1971 case SPIRV::OpSubgroup2DBlockStoreINTEL: {
1972 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_2d_block_io))
1973 report_fatal_error("OpSubgroup2DBlock[Load/LoadTranspose/LoadTransform/"
1974 "Prefetch/Store]INTEL instructions require the "
1975 "following SPIR-V extension: SPV_INTEL_2d_block_io",
1976 false);
1977 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_2d_block_io);
1978 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockIOINTEL);
1979
1980 const auto OpCode = MI.getOpcode();
1981 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransposeINTEL) {
1982 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransposeINTEL);
1983 break;
1984 }
1985 if (OpCode == SPIRV::OpSubgroup2DBlockLoadTransformINTEL) {
1986 Reqs.addCapability(SPIRV::Capability::Subgroup2DBlockTransformINTEL);
1987 break;
1988 }
1989 break;
1990 }
1991 case SPIRV::OpKill: {
1992 Reqs.addCapability(SPIRV::Capability::Shader);
1993 } break;
1994 case SPIRV::OpDemoteToHelperInvocation:
1995 Reqs.addCapability(SPIRV::Capability::DemoteToHelperInvocation);
1996
1997 if (ST.canUseExtension(
1998 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation)) {
1999 if (!ST.isAtLeastSPIRVVer(llvm::VersionTuple(1, 6)))
2000 Reqs.addExtension(
2001 SPIRV::Extension::SPV_EXT_demote_to_helper_invocation);
2002 }
2003 break;
2004 case SPIRV::OpSDot:
2005 case SPIRV::OpUDot:
2006 case SPIRV::OpSUDot:
2007 case SPIRV::OpSDotAccSat:
2008 case SPIRV::OpUDotAccSat:
2009 case SPIRV::OpSUDotAccSat:
2010 AddDotProductRequirements(MI, Reqs, ST);
2011 break;
2012 case SPIRV::OpImageRead: {
2013 Register ImageReg = MI.getOperand(2).getReg();
2014 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2015 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2016 // OpImageRead and OpImageWrite can use Unknown Image Formats
2017 // when the Kernel capability is declared. In the OpenCL environment we are
2018 // not allowed to produce
2019 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2020 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2021
2022 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2023 Reqs.addCapability(SPIRV::Capability::StorageImageReadWithoutFormat);
2024 break;
2025 }
2026 case SPIRV::OpImageWrite: {
2027 Register ImageReg = MI.getOperand(0).getReg();
2028 SPIRVType *TypeDef = ST.getSPIRVGlobalRegistry()->getResultType(
2029 ImageReg, const_cast<MachineFunction *>(MI.getMF()));
2030 // OpImageRead and OpImageWrite can use Unknown Image Formats
2031 // when the Kernel capability is declared. In the OpenCL environment we are
2032 // not allowed to produce
2033 // StorageImageReadWithoutFormat/StorageImageWriteWithoutFormat, see
2034 // https://github.com/KhronosGroup/SPIRV-Headers/issues/487
2035
2036 if (isImageTypeWithUnknownFormat(TypeDef) && ST.isShader())
2037 Reqs.addCapability(SPIRV::Capability::StorageImageWriteWithoutFormat);
2038 break;
2039 }
2040 case SPIRV::OpTypeStructContinuedINTEL:
2041 case SPIRV::OpConstantCompositeContinuedINTEL:
2042 case SPIRV::OpSpecConstantCompositeContinuedINTEL:
2043 case SPIRV::OpCompositeConstructContinuedINTEL: {
2044 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_long_composites))
2046 "Continued instructions require the "
2047 "following SPIR-V extension: SPV_INTEL_long_composites",
2048 false);
2049 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_long_composites);
2050 Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
2051 break;
2052 }
2053 case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
2054 if (!ST.canUseExtension(
2055 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
2057 "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
2058 "following SPIR-V "
2059 "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
2060 false);
2061 Reqs.addExtension(
2062 SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
2063 Reqs.addCapability(
2064 SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
2065 break;
2066 }
2067 case SPIRV::OpBitwiseFunctionINTEL: {
2068 if (!ST.canUseExtension(
2069 SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
2071 "OpBitwiseFunctionINTEL instruction requires the following SPIR-V "
2072 "extension: SPV_INTEL_ternary_bitwise_function",
2073 false);
2074 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_ternary_bitwise_function);
2075 Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
2076 break;
2077 }
2078 case SPIRV::OpCopyMemorySized: {
2079 Reqs.addCapability(SPIRV::Capability::Addresses);
2080 // TODO: Add UntypedPointersKHR when implemented.
2081 break;
2082 }
2083 case SPIRV::OpPredicatedLoadINTEL:
2084 case SPIRV::OpPredicatedStoreINTEL: {
2085 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_predicated_io))
2087 "OpPredicated[Load/Store]INTEL instructions require "
2088 "the following SPIR-V extension: SPV_INTEL_predicated_io",
2089 false);
2090 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_predicated_io);
2091 Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
2092 break;
2093 }
2094 case SPIRV::OpFAddS:
2095 case SPIRV::OpFSubS:
2096 case SPIRV::OpFMulS:
2097 case SPIRV::OpFDivS:
2098 case SPIRV::OpFRemS:
2099 case SPIRV::OpFMod:
2100 case SPIRV::OpFNegate:
2101 case SPIRV::OpFAddV:
2102 case SPIRV::OpFSubV:
2103 case SPIRV::OpFMulV:
2104 case SPIRV::OpFDivV:
2105 case SPIRV::OpFRemV:
2106 case SPIRV::OpFNegateV: {
2107 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2108 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
2109 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2110 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2111 if (isBFloat16Type(TypeDef)) {
2112 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2114 "Arithmetic instructions with bfloat16 arguments require the "
2115 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2116 false);
2117 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2118 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2119 }
2120 break;
2121 }
2122 case SPIRV::OpOrdered:
2123 case SPIRV::OpUnordered:
2124 case SPIRV::OpFOrdEqual:
2125 case SPIRV::OpFOrdNotEqual:
2126 case SPIRV::OpFOrdLessThan:
2127 case SPIRV::OpFOrdLessThanEqual:
2128 case SPIRV::OpFOrdGreaterThan:
2129 case SPIRV::OpFOrdGreaterThanEqual:
2130 case SPIRV::OpFUnordEqual:
2131 case SPIRV::OpFUnordNotEqual:
2132 case SPIRV::OpFUnordLessThan:
2133 case SPIRV::OpFUnordLessThanEqual:
2134 case SPIRV::OpFUnordGreaterThan:
2135 case SPIRV::OpFUnordGreaterThanEqual: {
2136 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
2137 MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
2138 SPIRVType *TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
2139 if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
2140 TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
2141 if (isBFloat16Type(TypeDef)) {
2142 if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
2144 "Relational instructions with bfloat16 arguments require the "
2145 "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
2146 false);
2147 Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
2148 Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
2149 }
2150 break;
2151 }
2152 case SPIRV::OpDPdxCoarse:
2153 case SPIRV::OpDPdyCoarse: {
2154 Reqs.addCapability(SPIRV::Capability::DerivativeControl);
2155 break;
2156 }
2157
2158 default:
2159 break;
2160 }
2161
2162 // If we require capability Shader, then we can remove the requirement for
2163 // the BitInstructions capability, since Shader is a superset capability
2164 // of BitInstructions.
2165 Reqs.removeCapabilityIf(SPIRV::Capability::BitInstructions,
2166 SPIRV::Capability::Shader);
2167}
2168
// Module-level requirement collection: walks every machine function in M and
// records per-instruction capability/extension requirements into MAI, then
// derives additional requirements from the "spirv.ExecutionMode" named
// metadata and from per-function metadata/attributes (work-group sizes,
// subgroup size hints, optnone, maximal reconvergence).
//
// NOTE(review): several lines of this listing appear truncated (the `MF`
// initialization — presumably `MMI->getMachineFunction(*F)` — and the
// `MAI.Reqs.getAndAddRequirements(` call heads are not visible); the
// annotations below follow the surviving code — confirm against upstream.
2169 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
2170                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
2171   // Collect requirements for existing instructions.
2172   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    // Skip IR functions with no machine-level counterpart.
2174     if (!MF)
2175       continue;
2176     for (const MachineBasicBlock &MBB : *MF)
2177       for (const MachineInstr &MI : MBB)
2178         addInstrRequirements(MI, MAI, ST);
2179   }
2180   // Collect requirements for OpExecutionMode instructions.
2181   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2182   if (Node) {
    // Track which float-controls extensions the execution modes seen below
    // demand, so they can be added once after the scan.
2183     bool RequireFloatControls = false, RequireIntelFloatControls2 = false,
2184          RequireKHRFloatControls2 = false,
2185          VerLower14 = !ST.isAtLeastSPIRVVer(VersionTuple(1, 4));
2186     bool HasIntelFloatControls2 =
2187         ST.canUseExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2188     bool HasKHRFloatControls2 =
2189         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2190     for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2191       MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2192       const MDOperand &MDOp = MDN->getOperand(1);
2193       if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
2194         Constant *C = CMeta->getValue();
2195         if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
    // EM is the numeric SPIR-V execution-mode operand from the metadata.
2196           auto EM = Const->getZExtValue();
2197           // SPV_KHR_float_controls is not available until v1.4:
2198           // add SPV_KHR_float_controls if the version is too low
2199           switch (EM) {
2200           case SPIRV::ExecutionMode::DenormPreserve:
2201           case SPIRV::ExecutionMode::DenormFlushToZero:
2202           case SPIRV::ExecutionMode::RoundingModeRTE:
2203           case SPIRV::ExecutionMode::RoundingModeRTZ:
2204             RequireFloatControls = VerLower14;
2206                 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2207             break;
2208           case SPIRV::ExecutionMode::RoundingModeRTPINTEL:
2209           case SPIRV::ExecutionMode::RoundingModeRTNINTEL:
2210           case SPIRV::ExecutionMode::FloatingPointModeALTINTEL:
2211           case SPIRV::ExecutionMode::FloatingPointModeIEEEINTEL:
    // Intel-specific modes are only honored when the extension is enabled.
2212             if (HasIntelFloatControls2) {
2213               RequireIntelFloatControls2 = true;
2215                   SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2216             }
2217             break;
2218           case SPIRV::ExecutionMode::FPFastMathDefault: {
2219             if (HasKHRFloatControls2) {
2220               RequireKHRFloatControls2 = true;
2222                   SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2223             }
2224             break;
2225           }
2226           case SPIRV::ExecutionMode::ContractionOff:
2227           case SPIRV::ExecutionMode::SignedZeroInfNanPreserve:
    // With SPV_KHR_float_controls2 these legacy modes are expressed through
    // FPFastMathDefault instead; otherwise fall back to the mode itself.
2228             if (HasKHRFloatControls2) {
2229               RequireKHRFloatControls2 = true;
2231                   SPIRV::OperandCategory::ExecutionModeOperand,
2232                   SPIRV::ExecutionMode::FPFastMathDefault, ST);
2233             } else {
2235                   SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2236             }
2237             break;
2238           default:
2240                 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
2241           }
2242         }
2243       }
2244     }
    // Materialize the extensions implied by the modes scanned above.
2245     if (RequireFloatControls &&
2246         ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls))
2247       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls);
2248     if (RequireIntelFloatControls2)
2249       MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_float_controls2);
2250     if (RequireKHRFloatControls2)
2251       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2252   }
    // Per-function requirements driven by IR metadata and attributes.
2253   for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
2254     const Function &F = *FI;
2255     if (F.isDeclaration())
2256       continue;
2257     if (F.getMetadata("reqd_work_group_size"))
2259           SPIRV::OperandCategory::ExecutionModeOperand,
2260           SPIRV::ExecutionMode::LocalSize, ST);
2261     if (F.getFnAttribute("hlsl.numthreads").isValid()) {
2263           SPIRV::OperandCategory::ExecutionModeOperand,
2264           SPIRV::ExecutionMode::LocalSize, ST);
2265     }
2266     if (F.getFnAttribute("enable-maximal-reconvergence").getValueAsBool()) {
2267       MAI.Reqs.addExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence);
2268     }
2269     if (F.getMetadata("work_group_size_hint"))
2271           SPIRV::OperandCategory::ExecutionModeOperand,
2272           SPIRV::ExecutionMode::LocalSizeHint, ST);
2273     if (F.getMetadata("intel_reqd_sub_group_size"))
2275           SPIRV::OperandCategory::ExecutionModeOperand,
2276           SPIRV::ExecutionMode::SubgroupSize, ST);
2277     if (F.getMetadata("max_work_group_size"))
2279           SPIRV::OperandCategory::ExecutionModeOperand,
2280           SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, ST);
2281     if (F.getMetadata("vec_type_hint"))
2283           SPIRV::OperandCategory::ExecutionModeOperand,
2284           SPIRV::ExecutionMode::VecTypeHint, ST);
2285
    // optnone: prefer the INTEL extension, fall back to the EXT variant.
2286     if (F.hasOptNone()) {
2287       if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
2288         MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
2289         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneINTEL);
2290       } else if (ST.canUseExtension(SPIRV::Extension::SPV_EXT_optnone)) {
2291         MAI.Reqs.addExtension(SPIRV::Extension::SPV_EXT_optnone);
2292         MAI.Reqs.addCapability(SPIRV::Capability::OptNoneEXT);
2293       }
2294     }
2295   }
2296 }
2297
2298static unsigned getFastMathFlags(const MachineInstr &I,
2299 const SPIRVSubtarget &ST) {
2300 unsigned Flags = SPIRV::FPFastMathMode::None;
2301 bool CanUseKHRFloatControls2 =
2302 ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2303 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
2304 Flags |= SPIRV::FPFastMathMode::NotNaN;
2305 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
2306 Flags |= SPIRV::FPFastMathMode::NotInf;
2307 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
2308 Flags |= SPIRV::FPFastMathMode::NSZ;
2309 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
2310 Flags |= SPIRV::FPFastMathMode::AllowRecip;
2311 if (I.getFlag(MachineInstr::MIFlag::FmContract) && CanUseKHRFloatControls2)
2312 Flags |= SPIRV::FPFastMathMode::AllowContract;
2313 if (I.getFlag(MachineInstr::MIFlag::FmReassoc)) {
2314 if (CanUseKHRFloatControls2)
2315 // LLVM reassoc maps to SPIRV transform, see
2316 // https://github.com/KhronosGroup/SPIRV-Registry/issues/326 for details.
2317 // Because we are enabling AllowTransform, we must enable AllowReassoc and
2318 // AllowContract too, as required by SPIRV spec. Also, we used to map
2319 // MIFlag::FmReassoc to FPFastMathMode::Fast, which now should instead by
2320 // replaced by turning all the other bits instead. Therefore, we're
2321 // enabling every bit here except None and Fast.
2322 Flags |= SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
2323 SPIRV::FPFastMathMode::NSZ | SPIRV::FPFastMathMode::AllowRecip |
2324 SPIRV::FPFastMathMode::AllowTransform |
2325 SPIRV::FPFastMathMode::AllowReassoc |
2326 SPIRV::FPFastMathMode::AllowContract;
2327 else
2328 Flags |= SPIRV::FPFastMathMode::Fast;
2329 }
2330
2331 if (CanUseKHRFloatControls2) {
2332 // Error out if SPIRV::FPFastMathMode::Fast is enabled.
2333 assert(!(Flags & SPIRV::FPFastMathMode::Fast) &&
2334 "SPIRV::FPFastMathMode::Fast is deprecated and should not be used "
2335 "anymore.");
2336
2337 // Error out if AllowTransform is enabled without AllowReassoc and
2338 // AllowContract.
2339 assert((!(Flags & SPIRV::FPFastMathMode::AllowTransform) ||
2340 ((Flags & SPIRV::FPFastMathMode::AllowReassoc &&
2341 Flags & SPIRV::FPFastMathMode::AllowContract))) &&
2342 "SPIRV::FPFastMathMode::AllowTransform requires AllowReassoc and "
2343 "AllowContract flags to be enabled as well.");
2344 }
2345
2346 return Flags;
2347}
2348
2349static bool isFastMathModeAvailable(const SPIRVSubtarget &ST) {
2350 if (ST.isKernel())
2351 return true;
2352 if (ST.getSPIRVVersion() < VersionTuple(1, 2))
2353 return false;
2354 return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
2355}
2356
// Attach SPIR-V decorations derived from the MI flags of instruction I:
// NoSignedWrap / NoUnsignedWrap for nsw/nuw (when satisfiable for ST), and
// FPFastMathMode for fast-math flags. When the instruction itself carries no
// fast-math flags, a module-level FPFastMathDefault entry matching the
// instruction's float type may still force a decoration.
//
// NOTE(review): one signature line appears truncated in this listing — the
// `Reqs` and `GR` parameters used in the body are not visible above; confirm
// against the upstream file.
2357 static void handleMIFlagDecoration(
2358     MachineInstr &I, const SPIRVSubtarget &ST, const SPIRVInstrInfo &TII,
2360     SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec) {
  // nsw -> NoSignedWrap, only if the decoration is satisfiable for ST.
2361   if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
2362       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2363                                      SPIRV::Decoration::NoSignedWrap, ST, Reqs)
2364           .IsSatisfiable) {
2365     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2366                     SPIRV::Decoration::NoSignedWrap, {});
2367   }
  // nuw -> NoUnsignedWrap, same satisfiability gate.
2368   if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
2369       getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
2370                                      SPIRV::Decoration::NoUnsignedWrap, ST,
2371                                      Reqs)
2372           .IsSatisfiable) {
2373     buildOpDecorate(I.getOperand(0).getReg(), I, TII,
2374                     SPIRV::Decoration::NoUnsignedWrap, {});
2375   }
  // Bail out for opcodes that cannot carry fast-math decorations at all.
2376   if (!TII.canUseFastMathFlags(
2377           I, ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)))
2378     return;
2379
2380   unsigned FMFlags = getFastMathFlags(I, ST);
2381   if (FMFlags == SPIRV::FPFastMathMode::None) {
2382     // We also need to check if any FPFastMathDefault info was set for the
2383     // types used in this instruction.
2384     if (FPFastMathDefaultInfoVec.empty())
2385       return;
2386
2387     // There are three types of instructions that can use fast math flags:
2388     // 1. Arithmetic instructions (FAdd, FMul, FSub, FDiv, FRem, etc.)
2389     // 2. Relational instructions (FCmp, FOrd, FUnord, etc.)
2390     // 3. Extended instructions (ExtInst)
2391     // For arithmetic instructions, the floating point type can be in the
2392     // result type or in the operands, but they all must be the same.
2393     // For the relational and logical instructions, the floating point type
2394     // can only be in the operands 1 and 2, not the result type. Also, the
2395     // operands must have the same type. For the extended instructions, the
2396     // floating point type can be in the result type or in the operands. It's
2397     // unclear if the operands and the result type must be the same. Let's
2398     // assume they must be. Therefore, for 1. and 2., we can check the first
2399     // operand type, and for 3. we can check the result type.
2400     assert(I.getNumOperands() >= 3 && "Expected at least 3 operands");
2401     Register ResReg = I.getOpcode() == SPIRV::OpExtInst
2402                           ? I.getOperand(1).getReg()
2403                           : I.getOperand(2).getReg();
2404     SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResReg, I.getMF());
2405     const Type *Ty = GR->getTypeForSPIRVType(ResType);
  // Defaults are recorded per scalar type; strip a vector wrapper first.
2406     Ty = Ty->isVectorTy() ? cast<VectorType>(Ty)->getElementType() : Ty;
2407
2408     // Match instruction type with the FPFastMathDefaultInfoVec.
2409     bool Emit = false;
2410     for (SPIRV::FPFastMathDefaultInfo &Elem : FPFastMathDefaultInfoVec) {
2411       if (Ty == Elem.Ty) {
2412         FMFlags = Elem.FastMathFlags;
2413         Emit = Elem.ContractionOff || Elem.SignedZeroInfNanPreserve ||
2414                Elem.FPFastMathDefault;
2415         break;
2416       }
2417     }
2418
  // Nothing to decorate if neither flags nor a forced default apply.
2419     if (FMFlags == SPIRV::FPFastMathMode::None && !Emit)
2420       return;
2421   }
2422   if (isFastMathModeAvailable(ST)) {
2423     Register DstReg = I.getOperand(0).getReg();
2424     buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2425                     {FMFlags});
2426   }
2427 }
2428
2429// Walk all functions and add decorations related to MI flags.
// NOTE(review): Doxygen-scrape artifact — each line below is prefixed with its
// original file line number. Skipped gutter numbers are lines dropped by the
// extraction: 2432 is presumably the `SPIRV::ModuleAnalysisInfo &MAI`
// parameter (MAI is used below), and 2435 presumably
// `MachineFunction *MF = MMI->getMachineFunction(*F);` (MF is used below).
// TODO confirm against the upstream SPIRVModuleAnalysis.cpp.
2430 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
2431                            MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2433                            const SPIRVGlobalRegistry *GR) {
2434   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
       // Skip IR functions without a MachineFunction (e.g. declarations).
2436     if (!MF)
2437       continue;
2438
       // Decorate every instruction based on its MI flags, using the
       // per-function FPFastMathDefault info gathered earlier.
2439     for (auto &MBB : *MF)
2440       for (auto &MI : MBB)
2441         handleMIFlagDecoration(MI, ST, TII, MAI.Reqs, GR,
2442                                MAI.FPFastMathDefaultInfoMap[&(*F)]);
2443   }
2444 }
2445
// Emit OpName instructions for named, non-empty basic blocks and alias the
// name-carrying virtual register to the block's global MBB register.
// NOTE(review): Doxygen-scrape artifact — gutter gaps are dropped lines:
// 2448 presumably `SPIRV::ModuleAnalysisInfo &MAI) {`, 2450 presumably
// `MachineFunction *MF = MMI->getMachineFunction(*F);`, and 2453 presumably
// `MachineRegisterInfo &MRI = MF->getRegInfo();` (MRI is used below).
// TODO confirm against upstream.
2446 static void addMBBNames(const Module &M, const SPIRVInstrInfo &TII,
2447                         MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
2449   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2451     if (!MF)
2452       continue;
2454     for (auto &MBB : *MF) {
2455       if (!MBB.hasName() || MBB.empty())
2456         continue;
2457       // Emit basic block names.
         // A fresh virtual register is created solely to carry the OpName,
         // then aliased to the block's global register.
2458       Register Reg = MRI.createGenericVirtualRegister(LLT::scalar(64));
2459       MRI.setRegClass(Reg, &SPIRV::IDRegClass);
2460       buildOpName(Reg, MBB.getName(), *std::prev(MBB.end()), TII);
2461       MCRegister GlobalReg = MAI.getOrCreateMBBRegister(MBB);
2462       MAI.setRegisterAlias(MF, Reg, GlobalReg);
2463     }
2464   }
2465 }
2466
2467// Patch PHI pseudo-instructions into SPIRV::OpPhi instructions.
// NOTE(review): Doxygen-scrape artifact — gutter gap 2471 is a dropped line,
// presumably `MachineFunction *MF = MMI->getMachineFunction(*F);` given the
// uses of MF below. TODO confirm against upstream.
2468 static void patchPhis(const Module &M, SPIRVGlobalRegistry *GR,
2469                       const SPIRVInstrInfo &TII, MachineModuleInfo *MMI) {
2470   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
2472     if (!MF)
2473       continue;
2474     for (auto &MBB : *MF) {
2475       for (MachineInstr &MI : MBB.phis()) {
           // Rewrite the opcode, then insert the result-type register
           // immediately after the def operand (operand index 1).
2476         MI.setDesc(TII.get(SPIRV::OpPhi));
2477         Register ResTypeReg = GR->getSPIRVTypeID(
2478             GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg(), MF));
2479         MI.insert(MI.operands_begin() + 1,
2480                   {MachineOperand::CreateReg(ResTypeReg, false)});
2481       }
2482     }
2483
       // All PHIs have been rewritten; record the NoPHIs property.
2484     MF->getProperties().setNoPHIs();
2485   }
2486 }
2487
// Lazily fetch (or create) the per-function vector of FPFastMathDefault
// entries. A newly created vector holds exactly three elements, ordered by
// target-type bit width: {half, float, double}, all initialized to
// FPFastMathMode::None.
// NOTE(review): Doxygen-scrape artifact — gutter gap 2488 is the dropped
// function-name line, presumably
// `static SPIRV::FPFastMathDefaultInfoVector &getOrCreateFPFastMathDefaultInfoVec(`.
// TODO confirm against upstream.
2489     const Module &M, SPIRV::ModuleAnalysisInfo &MAI, const Function *F) {
2490   auto it = MAI.FPFastMathDefaultInfoMap.find(F);
2491   if (it != MAI.FPFastMathDefaultInfoMap.end())
2492     return it->second;
2493
2494   // If the map does not contain the entry, create a new one. Initialize it to
2495   // contain all 3 elements sorted by bit width of target type: {half, float,
2496   // double}.
2497   SPIRV::FPFastMathDefaultInfoVector FPFastMathDefaultInfoVec;
2498   FPFastMathDefaultInfoVec.emplace_back(Type::getHalfTy(M.getContext()),
2499                                         SPIRV::FPFastMathMode::None);
2500   FPFastMathDefaultInfoVec.emplace_back(Type::getFloatTy(M.getContext()),
2501                                         SPIRV::FPFastMathMode::None);
2502   FPFastMathDefaultInfoVec.emplace_back(Type::getDoubleTy(M.getContext()),
2503                                         SPIRV::FPFastMathMode::None);
2504   return MAI.FPFastMathDefaultInfoMap[F] = std::move(FPFastMathDefaultInfoVec);
2505 }
2506
// Return the FPFastMathDefaultInfo slot matching Ty's scalar bit width.
// The vector is expected to hold exactly {half, float, double} (see the
// creation helper above in the original file).
// NOTE(review): Doxygen-scrape artifact — gutter gaps are dropped lines:
// 2507 presumably `static SPIRV::FPFastMathDefaultInfo &getFPFastMathDefaultInfo(`
// and 2512 presumably a call to
// `SPIRV::FPFastMathDefaultInfoVector::computeFPFastMathDefaultInfoVecIndex(`
// feeding Index. TODO confirm against upstream.
2508     SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec,
2509     const Type *Ty) {
2510   size_t BitWidth = Ty->getScalarSizeInBits();
2511   int Index =
2513           BitWidth);
2514   assert(Index >= 0 && Index < 3 &&
2515          "Expected FPFastMathDefaultInfo for half, float, or double");
2516   assert(FPFastMathDefaultInfoVec.size() == 3 &&
2517          "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2518   return FPFastMathDefaultInfoVec[Index];
2519 }
2520
// Scan the module-level !spirv.ExecutionMode metadata and record, per entry
// point, the FPFastMathDefault information (plus the deprecated
// ContractionOff and SignedZeroInfNanPreserve modes) needed later when
// decorating instructions. Only active when SPV_KHR_float_controls2 is usable.
// NOTE(review): Doxygen-scrape artifact — each line is prefixed with its
// original file line number; skipped gutter numbers are dropped lines
// (e.g. the `SPIRV::ModuleAnalysisInfo &MAI` parameter, the
// `mdconst::extract<ConstantInt>(` casts, and the
// `getOrCreateFPFastMathDefaultInfoVec(M, MAI, F)` calls that produce the
// references assigned below). TODO confirm against upstream.
2521 static void collectFPFastMathDefaults(const Module &M,
2523                                       const SPIRVSubtarget &ST) {
2524   if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2))
2525     return;
2526
2527   // Store the FPFastMathDefaultInfo in the FPFastMathDefaultInfoMap.
2528   // We need the entry point (function) as the key, and the target
2529   // type and flags as the value.
2530   // We also need to check ContractionOff and SignedZeroInfNanPreserve
2531   // execution modes, as they are now deprecated and must be replaced
2532   // with FPFastMathDefaultInfo.
2533   auto Node = M.getNamedMetadata("spirv.ExecutionMode");
2534   if (!Node)
2535     return;
2536
     // Each operand of !spirv.ExecutionMode is an MDNode of the form
     // (function, execution-mode id, mode-specific payload...).
2537   for (unsigned i = 0; i < Node->getNumOperands(); i++) {
2538     MDNode *MDN = cast<MDNode>(Node->getOperand(i));
2539     assert(MDN->getNumOperands() >= 2 && "Expected at least 2 operands");
2540     const Function *F = cast<Function>(
2541         cast<ConstantAsMetadata>(MDN->getOperand(0))->getValue());
2543     const auto EM =
2544             cast<ConstantAsMetadata>(MDN->getOperand(1))->getValue())
2545             ->getZExtValue();
2546     if (EM == SPIRV::ExecutionMode::FPFastMathDefault) {
       // Payload: (target fp type, fast-math flags). Record the flags in the
       // slot matching the type.
2547       assert(MDN->getNumOperands() == 4 &&
2548              "Expected 4 operands for FPFastMathDefault");
2549
2550       const Type *T = cast<ValueAsMetadata>(MDN->getOperand(2))->getType();
2551       unsigned Flags =
2553             cast<ConstantAsMetadata>(MDN->getOperand(3))->getValue())
2554             ->getZExtValue();
2555       SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2558           getFPFastMathDefaultInfo(FPFastMathDefaultInfoVec, T);
2559       Info.FastMathFlags = Flags;
2560       Info.FPFastMathDefault = true;
2561     } else if (EM == SPIRV::ExecutionMode::ContractionOff) {
       // Deprecated mode with no payload: applies to every tracked FP type.
2562       assert(MDN->getNumOperands() == 2 &&
2563              "Expected no operands for ContractionOff");
2564
2565       // We need to save this info for every possible FP type, i.e. {half,
2566       // float, double, fp128}.
2567       SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2569       for (SPIRV::FPFastMathDefaultInfo &Info : FPFastMathDefaultInfoVec) {
2570         Info.ContractionOff = true;
2571       }
2572     } else if (EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) {
       // Deprecated mode whose payload is a target bit width: mark only the
       // slot for that width.
2573       assert(MDN->getNumOperands() == 3 &&
2574              "Expected 1 operand for SignedZeroInfNanPreserve");
2575       unsigned TargetWidth =
2577             cast<ConstantAsMetadata>(MDN->getOperand(2))->getValue())
2578             ->getZExtValue();
2579       // We need to save this info only for the FP type with TargetWidth.
2580       SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec =
2584       assert(Index >= 0 && Index < 3 &&
2585              "Expected FPFastMathDefaultInfo for half, float, or double");
2586       assert(FPFastMathDefaultInfoVec.size() == 3 &&
2587              "Expected FPFastMathDefaultInfoVec to have exactly 3 elements");
2588       FPFastMathDefaultInfoVec[Index].SignedZeroInfNanPreserve = true;
2589     }
2590   }
2591 }
2592
2594
// Declare the analyses this module pass depends on: TargetPassConfig
// (presumably used to reach the SPIRVTargetMachine in runOnModule — confirm)
// and MachineModuleInfo (to map IR functions to MachineFunctions).
// NOTE(review): Doxygen-scrape artifact — gutter gap 2595 is the dropped
// signature line, presumably
// `void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {`.
2596   AU.addRequired<TargetPassConfig>();
2597   AU.addRequired<MachineModuleInfoWrapperPass>();
2598 }
2599
// Module entry point of the analysis: caches subtarget/registry/instr-info,
// then runs the collection pipeline (PHI patching, MBB names, FP fast-math
// defaults, decorations, requirements, declarations, global register
// numbering, remaining instructions) and finally records the max ID bound.
// NOTE(review): Doxygen-scrape artifact — gutter gaps are dropped lines:
// 2600 presumably `bool SPIRVModuleAnalysis::runOnModule(Module &M) {`,
// 2602 presumably `getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();`,
// and 2607 presumably the MMI initialization from
// MachineModuleInfoWrapperPass. TODO confirm against upstream.
2601   SPIRVTargetMachine &TM =
2603   ST = TM.getSubtargetImpl();
2604   GR = ST->getSPIRVGlobalRegistry();
2605   TII = ST->getInstrInfo();
2606
2608
2609   setBaseInfo(M);
2610
2611   patchPhis(M, GR, *TII, MMI);
2612
2613   addMBBNames(M, *TII, MMI, *ST, MAI);
2614   collectFPFastMathDefaults(M, MAI, *ST);
2615   addDecorations(M, *TII, MMI, *ST, MAI, GR);
2616
2617   collectReqs(M, MAI, MMI, *ST);
2618
     // NOTE(review): collectReqs is invoked a second time just below (original
     // lines 2617 and 2621). This looks redundant — at best idempotent,
     // duplicated work. Verify against upstream whether one call can go.
2619   // Process type/const/global var/func decl instructions, number their
2620   // destination registers from 0 to N, collect Extensions and Capabilities.
2621   collectReqs(M, MAI, MMI, *ST);
2622   collectDeclarations(M);
2623
2624   // Number rest of registers from N+1 onwards.
2625   numberRegistersGlobally(M);
2626
2627   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
2628   processOtherInstrs(M);
2629
2630   // If there are no entry points, we need the Linkage capability.
2631   if (MAI.MS[SPIRV::MB_EntryPoints].empty())
2632     MAI.Reqs.addCapability(SPIRV::Capability::Linkage);
2633
2634   // Set maximum ID used.
2635   GR->setBound(MAI.MaxID);
2636
2637   return false;
2638 }
unsigned const MachineRegisterInfo * MRI
MachineInstrBuilder & UseMI
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
aarch64 promote const
ReachingDefInfo InstSet & ToRemove
MachineBasicBlock & MBB
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
Analysis containing CSE Info
Definition CSEInfo.cpp:27
#define clEnumValN(ENUMVAL, FLAGNAME, DESC)
#define DEBUG_TYPE
const HexagonInstrInfo * TII
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Register Reg
Promote Memory to Register
Definition Mem2Reg.cpp:110
#define T
#define INITIALIZE_PASS(passName, arg, name, cfg, analysis)
Definition PassSupport.h:56
static SPIRV::FPFastMathDefaultInfoVector & getOrCreateFPFastMathDefaultInfoVec(const Module &M, DenseMap< Function *, SPIRV::FPFastMathDefaultInfoVector > &FPFastMathDefaultInfoMap, Function *F)
static SPIRV::FPFastMathDefaultInfo & getFPFastMathDefaultInfo(SPIRV::FPFastMathDefaultInfoVector &FPFastMathDefaultInfoVec, const Type *Ty)
#define ATOM_FLT_REQ_EXT_MSG(ExtName)
static cl::opt< bool > SPVDumpDeps("spv-dump-deps", cl::desc("Dump MIR with SPIR-V dependencies info"), cl::Optional, cl::init(false))
unsigned unsigned DefaultVal
unsigned OpIndex
static cl::list< SPIRV::Capability::Capability > AvoidCapabilities("avoid-spirv-capabilities", cl::desc("SPIR-V capabilities to avoid if there are " "other options enabling a feature"), cl::ZeroOrMore, cl::Hidden, cl::values(clEnumValN(SPIRV::Capability::Shader, "Shader", "SPIR-V Shader capability")))
This file contains some templates that are useful if you are working with the STL at all.
#define LLVM_DEBUG(...)
Definition Debug.h:114
Target-Independent Code Generator Pass Configuration Options pass.
The Input class is used to parse a yaml document into in-memory structs and vectors.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
This is an important base class in LLVM.
Definition Constant.h:43
static constexpr LLT scalar(unsigned SizeInBits)
Get a low-level scalar or aggregate "bag of bits".
Wrapper class representing physical registers. Should be passed by value.
Definition MCRegister.h:41
constexpr bool isValid() const
Definition MCRegister.h:84
Metadata node.
Definition Metadata.h:1078
const MDOperand & getOperand(unsigned I) const
Definition Metadata.h:1442
unsigned getNumOperands() const
Return number of MDNode operands.
Definition Metadata.h:1448
Tracking metadata reference owned by Metadata.
Definition Metadata.h:900
const MachineFunction * getParent() const
Return the MachineFunction containing this basic block.
MachineRegisterInfo & getRegInfo()
getRegInfo - Return information about the registers currently in use.
const MachineFunctionProperties & getProperties() const
Get the function properties.
Register getReg(unsigned Idx) const
Get the register for the operand index.
Representation of each machine instruction.
unsigned getOpcode() const
Returns the opcode of this MachineInstr.
const MachineBasicBlock * getParent() const
unsigned getNumOperands() const
Returns the total number of operands.
LLVM_ABI const MachineFunction * getMF() const
Return the function that contains the basic block that this instruction belongs to.
const MachineOperand & getOperand(unsigned i) const
This class contains meta information specific to a module.
LLVM_ABI MachineFunction * getMachineFunction(const Function &F) const
Returns the MachineFunction associated to IR function F if there is one, otherwise nullptr.
MachineOperand class - Representation of each machine instruction operand.
unsigned getSubReg() const
int64_t getImm() const
bool isReg() const
isReg - Tests if this is a MO_Register operand.
bool isImm() const
isImm - Tests if this is a MO_Immediate operand.
LLVM_ABI void print(raw_ostream &os, const TargetRegisterInfo *TRI=nullptr) const
Print the MachineOperand to os.
MachineInstr * getParent()
getParent - Return the instruction that this operand belongs to.
static MachineOperand CreateImm(int64_t Val)
MachineOperandType getType() const
getType - Returns the MachineOperandType for this operand.
Register getReg() const
getReg - Returns the register number.
MachineRegisterInfo - Keep track of information for virtual and physical registers,...
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
virtual void print(raw_ostream &OS, const Module *M) const
print - Print out the internal state of the pass.
Definition Pass.cpp:140
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
Wrapper class representing virtual and physical registers.
Definition Register.h:20
constexpr bool isValid() const
Definition Register.h:112
SPIRVType * getSPIRVTypeForVReg(Register VReg, const MachineFunction *MF=nullptr) const
const Type * getTypeForSPIRVType(const SPIRVType *Ty) const
Register getSPIRVTypeID(const SPIRVType *SpirvType) const
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const
bool isConstantInstr(const MachineInstr &MI) const
const SPIRVInstrInfo * getInstrInfo() const override
SPIRVGlobalRegistry * getSPIRVGlobalRegistry() const
const SPIRVSubtarget * getSubtargetImpl() const
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition SmallSet.h:133
bool contains(const T &V) const
Check if the SmallSet contains the given element.
Definition SmallSet.h:228
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition SmallSet.h:183
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
Target-Independent Code Generator Pass Configuration Options.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
bool isVectorTy() const
True if this is an instance of VectorType.
Definition Type.h:273
static LLVM_ABI Type * getDoubleTy(LLVMContext &C)
Definition Type.cpp:285
static LLVM_ABI Type * getFloatTy(LLVMContext &C)
Definition Type.cpp:284
static LLVM_ABI Type * getHalfTy(LLVMContext &C)
Definition Type.cpp:282
Represents a version number in the form major[.minor[.subminor[.build]]].
bool empty() const
Determine whether this version information is empty (e.g., all version components are zero).
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
SmallVector< const MachineInstr * > InstrList
ValuesClass values(OptsTy... Options)
Helper to build a ValuesClass by forwarding a variable number of arguments as an initializer list to ...
initializer< Ty > init(const Ty &Val)
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
Definition Metadata.h:667
NodeAddr< InstrNode * > Instr
Definition RDFGraph.h:389
This is an optimization pass for GlobalISel generic memory operations.
void buildOpName(Register Target, const StringRef &Name, MachineIRBuilder &MIRBuilder)
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex)
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1725
hash_code hash_value(const FixedPointSemantics &Val)
ExtensionList getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
CapabilityList getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
SmallVector< SPIRV::Extension::Extension, 8 > ExtensionList
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
SmallVector< size_t > InstrSignature
VersionTuple getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, SPIRV::Decoration::Decoration Dec, const std::vector< uint32_t > &DecArgs, StringRef StrImm)
MachineInstr * getImm(const MachineOperand &MO, const MachineRegisterInfo *MRI)
CapabilityList getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
LLVM_ABI void report_fatal_error(Error Err, bool gen_crash_diag=true)
Definition Error.cpp:167
const MachineInstr SPIRVType
std::string getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category, int32_t Value)
LLVM_ABI raw_fd_ostream & errs()
This returns a reference to a raw_ostream for standard error.
DWARFExpression::Operation Op
VersionTuple getSymbolicOperandMinVersion(SPIRV::OperandCategory::OperandCategory Category, uint32_t Value)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
SmallVector< SPIRV::Capability::Capability, 8 > CapabilityList
std::set< InstrSignature > InstrTraces
hash_code hash_combine(const Ts &...args)
Combine values into a single hash_code.
Definition Hashing.h:592
std::map< SmallVector< size_t >, unsigned > InstrGRegsMap
#define N
SmallSet< SPIRV::Capability::Capability, 4 > S
static struct SPIRV::ModuleAnalysisInfo MAI
bool runOnModule(Module &M) override
runOnModule - Virtual method overriden by subclasses to process the module being operated on.
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
static size_t computeFPFastMathDefaultInfoVecIndex(size_t BitWidth)
Definition SPIRVUtils.h:146
void setSkipEmission(const MachineInstr *MI)
MCRegister getRegisterAlias(const MachineFunction *MF, Register Reg)
MCRegister getOrCreateMBBRegister(const MachineBasicBlock &MBB)
InstrList MS[NUM_MODULE_SECTIONS]
AddressingModel::AddressingModel Addr
void setRegisterAlias(const MachineFunction *MF, Register Reg, MCRegister AliasReg)
DenseMap< const Function *, SPIRV::FPFastMathDefaultInfoVector > FPFastMathDefaultInfoMap
void addCapabilities(const CapabilityList &ToAdd)
bool isCapabilityAvailable(Capability::Capability Cap) const
void checkSatisfiable(const SPIRVSubtarget &ST) const
void getAndAddRequirements(SPIRV::OperandCategory::OperandCategory Category, uint32_t i, const SPIRVSubtarget &ST)
void addExtension(Extension::Extension ToAdd)
void initAvailableCapabilities(const SPIRVSubtarget &ST)
void removeCapabilityIf(const Capability::Capability ToRemove, const Capability::Capability IfPresent)
void addCapability(Capability::Capability ToAdd)
void addAvailableCaps(const CapabilityList &ToAdd)
void addRequirements(const Requirements &Req)
const std::optional< Capability::Capability > Cap